mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Add more passes
This commit is contained in:
parent
c84d257bb7
commit
7bd4179d1d
13 changed files with 1208 additions and 105 deletions
|
@ -18,6 +18,8 @@ bit-vec = "0.6"
|
|||
half ="1.6"
|
||||
bitflags = "1.2"
|
||||
rustc-hash = "2.0.0"
|
||||
strum = "0.26"
|
||||
strum_macros = "0.26"
|
||||
|
||||
[dependencies.lalrpop-util]
|
||||
version = "0.19.12"
|
||||
|
|
141
ptx/src/pass/deparamize_functions.rs
Normal file
141
ptx/src/pass/deparamize_functions.rs
Normal file
|
@ -0,0 +1,141 @@
|
|||
use std::collections::BTreeMap;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
if method.func_decl.name.is_kernel() {
|
||||
return Ok(method);
|
||||
}
|
||||
let is_declaration = method.body.is_none();
|
||||
let mut body = Vec::new();
|
||||
let mut remap_returns = Vec::new();
|
||||
for arg in method.func_decl.return_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
for arg in method.func_decl.input_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
body.push(Statement::Instruction(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: arg.v_type.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: old_name,
|
||||
src2: arg.name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
if remap_returns.is_empty() {
|
||||
return Ok(method);
|
||||
}
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
for statement in statements {
|
||||
run_statement(&remap_returns, &mut body, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(body)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||
for (old_name, new_name, type_) in remap_returns.iter().cloned() {
|
||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Reg,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_,
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: new_name,
|
||||
src: old_name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
result.push(statement);
|
||||
}
|
||||
statement => {
|
||||
result.push(statement);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
|
@ -308,6 +308,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
Statement::PtrAccess(_) => todo!(),
|
||||
Statement::RepackVector(_) => todo!(),
|
||||
Statement::FunctionPointer(_) => todo!(),
|
||||
Statement::VectorAccess(_) => todo!(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1561,6 +1561,7 @@ fn emit_function_body_ops<'input>(
|
|||
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
|
||||
}
|
||||
}
|
||||
Statement::VectorAccess(vector_access) => todo!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
|
289
ptx/src/pass/expand_operands.rs
Normal file
289
ptx/src/pass/expand_operands.rs
Normal file
|
@ -0,0 +1,289 @@
|
|||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<UnconditionalDirective<'input>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: Function2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(resolver, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut visitor = FlattenArguments::new(resolver, result);
|
||||
let new_statement = statement.visit_map(&mut visitor)?;
|
||||
visitor.result.push(new_statement);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct FlattenArguments<'a, 'input> {
|
||||
result: &'a mut Vec<ExpandedStatement>,
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
post_stmts: Vec<ExpandedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> FlattenArguments<'a, 'input> {
|
||||
fn new(
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
result: &'a mut Vec<ExpandedStatement>,
|
||||
) -> Self {
|
||||
FlattenArguments {
|
||||
result,
|
||||
resolver,
|
||||
post_stmts: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
fn reg_offset(
|
||||
&mut self,
|
||||
reg: SpirvWord,
|
||||
offset: i32,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
||||
(type_, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
if state_space == ast::StateSpace::Reg {
|
||||
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
|
||||
if *reg_space != ast::StateSpace::Reg {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let reg_scalar_type = match reg_type {
|
||||
ast::Type::Scalar(underlying_type) => *underlying_type,
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let reg_type = reg_type.clone();
|
||||
let id_constant_stmt = self
|
||||
.resolver
|
||||
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: reg_scalar_type,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let arith_details = match reg_scalar_type.kind() {
|
||||
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
}),
|
||||
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
})
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let id_add_result = self
|
||||
.resolver
|
||||
.register_unnamed(Some((reg_type, state_space)));
|
||||
self.result
|
||||
.push(Statement::Instruction(ast::Instruction::Add {
|
||||
data: arith_details,
|
||||
arguments: ast::AddArgs {
|
||||
dst: id_add_result,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
}));
|
||||
Ok(id_add_result)
|
||||
} else {
|
||||
let id_constant_stmt = self.resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::S64),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: ast::ScalarType::S64,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let dst = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), state_space)));
|
||||
self.result.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: type_.clone(),
|
||||
state_space: state_space,
|
||||
dst,
|
||||
ptr_src: reg,
|
||||
offset_src: id_constant_stmt,
|
||||
}));
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
fn immediate(
|
||||
&mut self,
|
||||
value: ast::ImmediateValue,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (scalar_t, state_space) =
|
||||
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
||||
(*scalar, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
let id = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id,
|
||||
typ: scalar_t,
|
||||
value,
|
||||
}));
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
fn vec_member(
|
||||
&mut self,
|
||||
vector_src: SpirvWord,
|
||||
member: u8,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if is_dst {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
|
||||
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
||||
(*vector_width, *scalar_t, *space)
|
||||
}
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let temporary = self
|
||||
.resolver
|
||||
.register_unnamed(Some((scalar_type.into(), space)));
|
||||
self.result.push(Statement::VectorAccess(VectorAccess {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
dst: temporary,
|
||||
src: vector_src,
|
||||
member: member,
|
||||
}));
|
||||
Ok(temporary)
|
||||
}
|
||||
|
||||
fn vec_pack(
|
||||
&mut self,
|
||||
vecs: Vec<SpirvWord>,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (scalar_t, state_space) = match type_space {
|
||||
Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space),
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let temp_vec = self
|
||||
.resolver
|
||||
.register_unnamed(Some((scalar_t.into(), state_space)));
|
||||
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||
is_extract: is_dst,
|
||||
typ: scalar_t,
|
||||
packed: temp_vec,
|
||||
unpacked: vecs,
|
||||
relaxed_type_check,
|
||||
});
|
||||
if is_dst {
|
||||
self.post_stmts.push(statement);
|
||||
} else {
|
||||
self.result.push(statement);
|
||||
}
|
||||
Ok(temp_vec)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
|
||||
for FlattenArguments<'a, 'b>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: ast::ParsedOperand<SpirvWord>,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
match args {
|
||||
ast::ParsedOperand::Reg(r) => self.reg(r),
|
||||
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
|
||||
ast::ParsedOperand::RegOffset(reg, offset) => {
|
||||
self.reg_offset(reg, offset, type_space, is_dst)
|
||||
}
|
||||
ast::ParsedOperand::VecMember(vec, member) => {
|
||||
self.vec_member(vec, member, type_space, is_dst)
|
||||
}
|
||||
ast::ParsedOperand::VecPack(vecs) => {
|
||||
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
name: <TypedOperand as ast::Operand>::Ident,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
|
||||
self.reg(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for FlattenArguments<'_, '_> {
|
||||
fn drop(&mut self) {
|
||||
self.result.extend(self.post_stmts.drain(..));
|
||||
}
|
||||
}
|
209
ptx/src/pass/fix_special_registers2.rs
Normal file
209
ptx/src/pass/fix_special_registers2.rs
Normal file
|
@ -0,0 +1,209 @@
|
|||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
special_registers: &'a SpecialRegistersMap2,
|
||||
directives: Vec<UnconditionalDirective<'input>>,
|
||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||
let declarations = SpecialRegistersMap2::generate_declarations(resolver);
|
||||
let mut result = Vec::with_capacity(declarations.len() + directives.len());
|
||||
let mut sreg_to_function =
|
||||
FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
|
||||
for (sreg, declaration) in declarations {
|
||||
let name = if let ast::MethodName::Func(name) = declaration.name {
|
||||
name
|
||||
} else {
|
||||
return Err(error_unreachable());
|
||||
};
|
||||
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
||||
func_decl: declaration,
|
||||
globals: Vec::new(),
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
}));
|
||||
sreg_to_function.insert(sreg, name);
|
||||
}
|
||||
let mut visitor = SpecialRegisterResolver {
|
||||
resolver,
|
||||
special_registers,
|
||||
sreg_to_function,
|
||||
result: Vec::new(),
|
||||
};
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(&mut visitor, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
directive: UnconditionalDirective<'input>,
|
||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
method: UnconditionalFunction<'input>,
|
||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(visitor, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
result: &mut Vec<UnconditionalStatement>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let converted_statement = statement.visit_map(visitor)?;
|
||||
result.extend(visitor.result.drain(..));
|
||||
result.push(converted_statement);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct SpecialRegisterResolver<'a, 'input> {
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
special_registers: &'a SpecialRegistersMap2,
|
||||
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
||||
result: Vec<UnconditionalStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input>
|
||||
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
|
||||
for SpecialRegisterResolver<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
operand: ast::ParsedOperand<SpirvWord>,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
|
||||
map_operand(operand, &mut |ident, vector_index| {
|
||||
self.replace_sreg(ident, vector_index, is_dst)
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.replace_sreg(args, None, is_dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
||||
fn replace_sreg(
|
||||
&mut self,
|
||||
name: SpirvWord,
|
||||
vector_index: Option<u8>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if let Some(sreg) = self.special_registers.get(name) {
|
||||
if is_dst {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||
(Some(idx), Some(inp_type)) => {
|
||||
if inp_type != ast::ScalarType::U8 {
|
||||
return Err(TranslateError::Unreachable);
|
||||
}
|
||||
let constant = self.resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(inp_type),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: constant,
|
||||
typ: inp_type,
|
||||
value: ast::ImmediateValue::U64(idx as u64),
|
||||
}));
|
||||
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
|
||||
}
|
||||
(None, None) => Vec::new(),
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let return_type = sreg.get_function_return_type();
|
||||
let fn_result = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
|
||||
let return_arguments = vec![(
|
||||
fn_result,
|
||||
ast::Type::Scalar(return_type),
|
||||
ast::StateSpace::Reg,
|
||||
)];
|
||||
let data = ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
};
|
||||
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
|
||||
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||
func: self.sreg_to_function[&sreg],
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
|
||||
.collect(),
|
||||
};
|
||||
self.result
|
||||
.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data,
|
||||
arguments,
|
||||
}));
|
||||
Ok(fn_result)
|
||||
} else {
|
||||
Ok(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_operand<T, U, Err>(
|
||||
this: ast::ParsedOperand<T>,
|
||||
fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
|
||||
) -> Result<ast::ParsedOperand<U>, Err> {
|
||||
Ok(match this {
|
||||
ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?),
|
||||
ast::ParsedOperand::RegOffset(ident, offset) => {
|
||||
ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset)
|
||||
}
|
||||
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
|
||||
ast::ParsedOperand::VecMember(ident, member) => {
|
||||
ast::ParsedOperand::Reg(fn_(ident, Some(member))?)
|
||||
}
|
||||
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
||||
idents
|
||||
.into_iter()
|
||||
.map(|ident| fn_(ident, None))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
})
|
||||
}
|
273
ptx/src/pass/insert_explicit_load_store.rs
Normal file
273
ptx/src/pass/insert_explicit_load_store.rs
Normal file
|
@ -0,0 +1,273 @@
|
|||
use super::*;
|
||||
use ptx_parser::VisitorMap;
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
// This pass:
|
||||
// * Turns all .local, .param and .reg in-body variables into .local variables
|
||||
// (if _not_ an input method argument)
|
||||
// * Inserts explicit `ld`/`st` for newly converted .reg variables
|
||||
// * Fixup state space of all existing `ld`/`st` instructions into newly
|
||||
// converted variables
|
||||
// * Turns `.entry` input arguments into param::entry and all related `.param`
|
||||
// loads into `param::entry` loads
|
||||
// * All `.func` input arguments are turned into `.reg` arguments by another
|
||||
// pass, so we do nothing there
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
let visitor = InsertMemSSAVisitor::new(resolver);
|
||||
Directive2::Method(run_method(visitor, method)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'a, 'input>(
|
||||
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
||||
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let mut func_decl = method.func_decl;
|
||||
for arg in func_decl.return_arguments.iter_mut() {
|
||||
visitor.visit_variable(arg);
|
||||
}
|
||||
let is_kernel = func_decl.name.is_kernel();
|
||||
// let mut prelude = Vec::with_capacity(method.body.as_ref().map(Vec::len).unwrap_or(0));
|
||||
if is_kernel {
|
||||
for arg in func_decl.input_arguments.iter_mut() {
|
||||
let old_name = arg.name;
|
||||
let old_space = arg.state_space;
|
||||
let new_space = ast::StateSpace::ParamEntry;
|
||||
let new_name = visitor
|
||||
.resolver
|
||||
.register_unnamed(Some((arg.v_type.clone(), new_space)));
|
||||
visitor.input_argument(old_name, new_name, old_space);
|
||||
arg.name = new_name;
|
||||
arg.state_space = new_space;
|
||||
}
|
||||
};
|
||||
let body = method
|
||||
.body
|
||||
.map(move |statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(&mut visitor, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
|
||||
result: &mut Vec<ExpandedStatement>,
|
||||
statement: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Variable(mut var) => {
|
||||
visitor.visit_variable(&mut var);
|
||||
result.push(Statement::Variable(var));
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
|
||||
let instruction = visitor.visit_ld(data, arguments)?;
|
||||
let instruction = ast::visit_map(instruction, visitor)?;
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::St {
|
||||
data,
|
||||
mut arguments,
|
||||
}) => {
|
||||
let instruction = visitor.visit_st(data, arguments)?;
|
||||
let instruction = ast::visit_map(instruction, visitor)?;
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
s => result.push(s.visit_map(visitor)?),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct InsertMemSSAVisitor<'a, 'input> {
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
variables: FxHashMap<SpirvWord, RemapAction>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||
Self {
|
||||
resolver,
|
||||
variables: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn input_argument(
|
||||
&mut self,
|
||||
old_name: SpirvWord,
|
||||
new_name: SpirvWord,
|
||||
old_space: ast::StateSpace,
|
||||
) -> Result<(), TranslateError> {
|
||||
if old_space != ast::StateSpace::Param {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::LDStSpaceChange {
|
||||
name: new_name,
|
||||
old_space,
|
||||
new_space: ast::StateSpace::ParamEntry,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn variable(
|
||||
&mut self,
|
||||
old_name: SpirvWord,
|
||||
new_name: SpirvWord,
|
||||
old_space: ast::StateSpace,
|
||||
) -> Result<(), TranslateError> {
|
||||
Ok(match old_space {
|
||||
ast::StateSpace::Reg => {
|
||||
self.variables
|
||||
.insert(old_name, RemapAction::PreLdPostSt(new_name));
|
||||
}
|
||||
ast::StateSpace::Param => {
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space: ast::StateSpace::Local,
|
||||
name: new_name,
|
||||
},
|
||||
);
|
||||
}
|
||||
// Good as-is
|
||||
ast::StateSpace::Local => {}
|
||||
// Will be pulled into global scope later
|
||||
ast::StateSpace::Generic
|
||||
| ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::Shared => {}
|
||||
ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
|
||||
return Err(error_unreachable())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_st(
|
||||
&self,
|
||||
mut data: ast::StData,
|
||||
mut arguments: ast::StArgs<SpirvWord>,
|
||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&arguments.src1) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
} => {
|
||||
if data.state_space != *old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
data.state_space = *new_space;
|
||||
arguments.src1 = *name;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ast::Instruction::St { data, arguments })
|
||||
}
|
||||
|
||||
fn visit_ld(
|
||||
&self,
|
||||
mut data: ast::LdDetails,
|
||||
mut arguments: ast::LdArgs<SpirvWord>,
|
||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&arguments.src) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
} => {
|
||||
if data.state_space != *old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
data.state_space = *new_space;
|
||||
arguments.src = *name;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ast::Instruction::Ld { data, arguments })
|
||||
}
|
||||
|
||||
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) {
|
||||
if var.state_space != ast::StateSpace::Local {
|
||||
let old_name = var.name;
|
||||
let old_space = var.state_space;
|
||||
let new_space = ast::StateSpace::Local;
|
||||
let new_name = self
|
||||
.resolver
|
||||
.register_unnamed(Some((var.v_type.clone(), new_space)));
|
||||
self.variable(old_name, new_name, old_space);
|
||||
var.name = new_name;
|
||||
var.state_space = new_space;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
|
||||
for InsertMemSSAVisitor<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.visit(args, type_space, is_dst, relaxed_type_check)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum RemapAction {
|
||||
PreLdPostSt(SpirvWord),
|
||||
LDStSpaceChange {
|
||||
old_space: ast::StateSpace,
|
||||
new_space: ast::StateSpace,
|
||||
name: SpirvWord,
|
||||
},
|
||||
}
|
|
@ -45,6 +45,13 @@ pub(super) fn run(
|
|||
Statement::RepackVector(repack),
|
||||
)?;
|
||||
}
|
||||
Statement::VectorAccess(vector_access) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::VectorAccess(vector_access),
|
||||
)?;
|
||||
}
|
||||
s @ Statement::Conditional(_)
|
||||
| s @ Statement::Conversion(_)
|
||||
| s @ Statement::Label(_)
|
||||
|
|
|
@ -13,15 +13,21 @@ use std::{
|
|||
mem,
|
||||
rc::Rc,
|
||||
};
|
||||
use strum::IntoEnumIterator;
|
||||
use strum_macros::EnumIter;
|
||||
|
||||
mod convert_dynamic_shared_memory_usage;
|
||||
mod convert_to_stateful_memory_access;
|
||||
mod convert_to_typed;
|
||||
mod deparamize_functions;
|
||||
pub(crate) mod emit_llvm;
|
||||
mod emit_spirv;
|
||||
mod expand_arguments;
|
||||
mod expand_operands;
|
||||
mod extract_globals;
|
||||
mod fix_special_registers;
|
||||
mod fix_special_registers2;
|
||||
mod insert_explicit_load_store;
|
||||
mod insert_implicit_conversions;
|
||||
mod insert_mem_ssa_statements;
|
||||
mod normalize_identifiers;
|
||||
|
@ -68,6 +74,20 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
|||
})
|
||||
}
|
||||
|
||||
pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
|
||||
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
|
||||
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
||||
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
|
||||
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
|
||||
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
|
||||
let directives = resolve_function_pointers::run(directives)?;
|
||||
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
|
||||
let directives = expand_operands::run(&mut flat_resolver, directives)?;
|
||||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn translate_directive<'input, 'a>(
|
||||
id_defs: &'a mut GlobalStringIdResolver<'input>,
|
||||
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||
|
@ -323,7 +343,7 @@ pub struct KernelInfo {
|
|||
pub uses_shared_mem: bool,
|
||||
}
|
||||
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
|
||||
enum PtxSpecialRegister {
|
||||
Tid,
|
||||
Ntid,
|
||||
|
@ -346,6 +366,17 @@ impl PtxSpecialRegister {
|
|||
}
|
||||
}
|
||||
|
||||
fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Tid => "%tid",
|
||||
Self::Ntid => "%ntid",
|
||||
Self::Ctaid => "%ctaid",
|
||||
Self::Nctaid => "%nctaid",
|
||||
Self::Clock => "%clock",
|
||||
Self::LanemaskLt => "%lanemask_lt",
|
||||
}
|
||||
}
|
||||
|
||||
fn get_type(self) -> ast::Type {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid
|
||||
|
@ -726,6 +757,7 @@ enum Statement<I, P: ast::Operand> {
|
|||
PtrAccess(PtrAccess<P>),
|
||||
RepackVector(RepackVectorDetails),
|
||||
FunctionPointer(FunctionPointerDetails),
|
||||
VectorAccess(VectorAccess),
|
||||
}
|
||||
|
||||
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||
|
@ -894,6 +926,36 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
|||
offset_src,
|
||||
})
|
||||
}
|
||||
Statement::VectorAccess(VectorAccess {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
dst,
|
||||
src: vector_src,
|
||||
member,
|
||||
}) => {
|
||||
let dst: SpirvWord = visitor.visit_ident(
|
||||
dst,
|
||||
Some((&scalar_type.into(), ast::StateSpace::Reg)),
|
||||
true,
|
||||
false,
|
||||
)?;
|
||||
let src = visitor.visit_ident(
|
||||
vector_src,
|
||||
Some((
|
||||
&ast::Type::Vector(vector_width, scalar_type),
|
||||
ast::StateSpace::Reg,
|
||||
)),
|
||||
false,
|
||||
false,
|
||||
)?;
|
||||
Statement::VectorAccess(VectorAccess {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
dst,
|
||||
src,
|
||||
member,
|
||||
})
|
||||
}
|
||||
Statement::RepackVector(RepackVectorDetails {
|
||||
is_extract,
|
||||
typ,
|
||||
|
@ -1448,6 +1510,7 @@ fn compute_denorm_information<'input>(
|
|||
Statement::Label(_) => {}
|
||||
Statement::Variable(_) => {}
|
||||
Statement::PtrAccess { .. } => {}
|
||||
Statement::VectorAccess { .. } => {}
|
||||
Statement::RepackVector(_) => {}
|
||||
Statement::FunctionPointer(_) => {}
|
||||
}
|
||||
|
@ -1668,7 +1731,7 @@ pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
|
|||
}
|
||||
|
||||
pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
|
||||
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
|
||||
pub globals: Vec<ast::Variable<SpirvWord>>,
|
||||
pub body: Option<Vec<Statement<Instruction, Operand>>>,
|
||||
import_as: Option<String>,
|
||||
|
@ -1712,10 +1775,31 @@ struct GlobalStringIdentResolver2<'input> {
|
|||
}
|
||||
|
||||
impl<'input> GlobalStringIdentResolver2<'input> {
|
||||
fn register_intermediate(
|
||||
fn new(spirv_word: SpirvWord) -> Self {
|
||||
Self {
|
||||
current_id: spirv_word,
|
||||
ident_map: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn register_named(
|
||||
&mut self,
|
||||
name: Cow<'input, str>,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
) -> SpirvWord {
|
||||
let new_id = self.current_id;
|
||||
self.ident_map.insert(
|
||||
new_id,
|
||||
IdentEntry {
|
||||
name: Some(name),
|
||||
type_space,
|
||||
},
|
||||
);
|
||||
self.current_id.0 += 1;
|
||||
new_id
|
||||
}
|
||||
|
||||
fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
|
||||
let new_id = self.current_id;
|
||||
self.ident_map.insert(
|
||||
new_id,
|
||||
|
@ -1727,9 +1811,191 @@ impl<'input> GlobalStringIdentResolver2<'input> {
|
|||
self.current_id.0 += 1;
|
||||
new_id
|
||||
}
|
||||
|
||||
fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> {
|
||||
match self.ident_map.get(&id) {
|
||||
Some(IdentEntry {
|
||||
type_space: Some(type_space),
|
||||
..
|
||||
}) => Ok(type_space),
|
||||
_ => Err(error_unknown_symbol()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct IdentEntry<'input> {
|
||||
name: Option<Cow<'input, str>>,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
}
|
||||
|
||||
struct ScopedResolver<'input, 'b> {
|
||||
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
|
||||
scopes: Vec<ScopeMarker<'input>>,
|
||||
}
|
||||
|
||||
impl<'input, 'b> ScopedResolver<'input, 'b> {
|
||||
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||
Self {
|
||||
flat_resolver,
|
||||
scopes: vec![ScopeMarker::new()],
|
||||
}
|
||||
}
|
||||
|
||||
fn start_scope(&mut self) {
|
||||
self.scopes.push(ScopeMarker::new());
|
||||
}
|
||||
|
||||
fn end_scope(&mut self) {
|
||||
let scope = self.scopes.pop().unwrap();
|
||||
scope.flush(self.flat_resolver);
|
||||
}
|
||||
|
||||
fn add(
|
||||
&mut self,
|
||||
name: Cow<'input, str>,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let result = self.flat_resolver.current_id;
|
||||
self.flat_resolver.current_id.0 += 1;
|
||||
let current_scope = self.scopes.last_mut().unwrap();
|
||||
if current_scope
|
||||
.name_to_ident
|
||||
.insert(name.clone(), result)
|
||||
.is_some()
|
||||
{
|
||||
return Err(error_unknown_symbol());
|
||||
}
|
||||
current_scope.ident_map.insert(
|
||||
result,
|
||||
IdentEntry {
|
||||
name: Some(name),
|
||||
type_space,
|
||||
},
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
|
||||
self.scopes
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|resolver| resolver.name_to_ident.get(name).copied())
|
||||
.ok_or_else(|| error_unreachable())
|
||||
}
|
||||
|
||||
fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
|
||||
let current_scope = self.scopes.last().unwrap();
|
||||
current_scope
|
||||
.name_to_ident
|
||||
.get(label)
|
||||
.copied()
|
||||
.ok_or_else(|| error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
struct ScopeMarker<'input> {
|
||||
ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
||||
name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
|
||||
}
|
||||
|
||||
impl<'input> ScopeMarker<'input> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
ident_map: FxHashMap::default(),
|
||||
name_to_ident: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
|
||||
resolver.ident_map.extend(self.ident_map);
|
||||
}
|
||||
}
|
||||
|
||||
struct SpecialRegistersMap2 {
|
||||
reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
||||
id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
|
||||
}
|
||||
|
||||
impl SpecialRegistersMap2 {
|
||||
fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
|
||||
let mut result = SpecialRegistersMap2 {
|
||||
reg_to_id: FxHashMap::default(),
|
||||
id_to_reg: FxHashMap::default(),
|
||||
};
|
||||
for sreg in PtxSpecialRegister::iter() {
|
||||
let text = sreg.as_str();
|
||||
let id = resolver.add(
|
||||
Cow::Borrowed(text),
|
||||
Some((sreg.get_type(), ast::StateSpace::Reg)),
|
||||
)?;
|
||||
result.reg_to_id.insert(sreg, id);
|
||||
result.id_to_reg.insert(id, sreg);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
|
||||
self.id_to_reg.get(&id).copied()
|
||||
}
|
||||
|
||||
fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
|
||||
match self.reg_to_id.entry(reg) {
|
||||
hash_map::Entry::Occupied(e) => *e.get(),
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
let numeric_id = SpirvWord(current_id.0);
|
||||
current_id.0 += 1;
|
||||
e.insert(numeric_id);
|
||||
self.id_to_reg.insert(numeric_id, reg);
|
||||
numeric_id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_declarations<'a, 'input>(
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
) -> impl ExactSizeIterator<
|
||||
Item = (
|
||||
PtxSpecialRegister,
|
||||
ast::MethodDeclaration<'input, SpirvWord>,
|
||||
),
|
||||
> + 'a {
|
||||
PtxSpecialRegister::iter().map(|sreg| {
|
||||
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
||||
let name =
|
||||
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
|
||||
let return_type = sreg.get_function_return_type();
|
||||
let input_type = sreg.get_function_return_type();
|
||||
(
|
||||
sreg,
|
||||
ast::MethodDeclaration {
|
||||
return_arguments: vec![ast::Variable {
|
||||
align: None,
|
||||
v_type: return_type.into(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: resolver
|
||||
.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
name: name,
|
||||
input_arguments: vec![ast::Variable {
|
||||
align: None,
|
||||
v_type: input_type.into(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: resolver
|
||||
.register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
shared_mem: None,
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VectorAccess {
|
||||
scalar_type: ast::ScalarType,
|
||||
vector_width: u8,
|
||||
dst: SpirvWord,
|
||||
src: SpirvWord,
|
||||
member: u8,
|
||||
}
|
||||
|
|
|
@ -2,21 +2,21 @@ use super::*;
|
|||
use ptx_parser as ast;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
fn_defs: &mut GlobalStringIdentResolver2<'input>,
|
||||
pub(crate) fn run<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
|
||||
let mut resolver = NameResolver::new(fn_defs);
|
||||
resolver.start_scope();
|
||||
let result = directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(&mut resolver, directive))
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
resolver.end_scope();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<NormalizedDirective2<'input>, TranslateError> {
|
||||
Ok(match directive {
|
||||
|
@ -30,7 +30,7 @@ fn run_directive<'input, 'b>(
|
|||
}
|
||||
|
||||
fn run_method<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
linkage: ast::LinkingDirective,
|
||||
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<NormalizedFunction2<'input>, TranslateError> {
|
||||
|
@ -41,11 +41,7 @@ fn run_method<'input, 'b>(
|
|||
}
|
||||
};
|
||||
resolver.start_scope();
|
||||
let func_decl = Rc::new(RefCell::new(run_function_decl(
|
||||
resolver,
|
||||
method.func_directive,
|
||||
name,
|
||||
)?));
|
||||
let func_decl = run_function_decl(resolver, method.func_directive, name)?;
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
|
@ -66,7 +62,7 @@ fn run_method<'input, 'b>(
|
|||
}
|
||||
|
||||
fn run_function_decl<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||
name: ast::MethodName<'input, SpirvWord>,
|
||||
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
|
||||
|
@ -90,7 +86,7 @@ fn run_function_decl<'input, 'b>(
|
|||
}
|
||||
|
||||
fn run_variable<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
variable: ast::Variable<&'input str>,
|
||||
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
||||
Ok(ast::Variable {
|
||||
|
@ -106,7 +102,7 @@ fn run_variable<'input, 'b>(
|
|||
}
|
||||
|
||||
fn run_statements<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<(), TranslateError> {
|
||||
|
@ -148,7 +144,7 @@ fn run_statements<'input, 'b>(
|
|||
}
|
||||
|
||||
fn run_instruction<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||
ast::visit_map(instruction, &mut |name: &'input str,
|
||||
|
@ -163,7 +159,7 @@ fn run_instruction<'input, 'b>(
|
|||
}
|
||||
|
||||
fn run_multivariable<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
variable: ast::MultiVariable<&'input str>,
|
||||
) -> Result<(), TranslateError> {
|
||||
|
@ -201,86 +197,3 @@ fn run_multivariable<'input, 'b>(
|
|||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct NameResolver<'input, 'b> {
|
||||
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
|
||||
scopes: Vec<ScopeStringIdentResolver<'input>>,
|
||||
}
|
||||
|
||||
impl<'input, 'b> NameResolver<'input, 'b> {
|
||||
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||
Self {
|
||||
flat_resolver,
|
||||
scopes: vec![ScopeStringIdentResolver::new()],
|
||||
}
|
||||
}
|
||||
|
||||
fn start_scope(&mut self) {
|
||||
self.scopes.push(ScopeStringIdentResolver::new());
|
||||
}
|
||||
|
||||
fn end_scope(&mut self) {
|
||||
let scope = self.scopes.pop().unwrap();
|
||||
scope.flush(self.flat_resolver);
|
||||
}
|
||||
|
||||
fn add(
|
||||
&mut self,
|
||||
name: Cow<'input, str>,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let result = self.flat_resolver.current_id;
|
||||
self.flat_resolver.current_id.0 += 1;
|
||||
let current_scope = self.scopes.last_mut().unwrap();
|
||||
if current_scope
|
||||
.name_to_ident
|
||||
.insert(name.clone(), result)
|
||||
.is_some()
|
||||
{
|
||||
return Err(error_unknown_symbol());
|
||||
}
|
||||
current_scope.ident_map.insert(
|
||||
result,
|
||||
IdentEntry {
|
||||
name: Some(name),
|
||||
type_space,
|
||||
},
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
|
||||
self.scopes
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|resolver| resolver.name_to_ident.get(name).copied())
|
||||
.ok_or_else(|| error_unreachable())
|
||||
}
|
||||
|
||||
fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
|
||||
let current_scope = self.scopes.last().unwrap();
|
||||
current_scope
|
||||
.name_to_ident
|
||||
.get(label)
|
||||
.copied()
|
||||
.ok_or_else(|| error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
struct ScopeStringIdentResolver<'input> {
|
||||
ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
||||
name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
|
||||
}
|
||||
|
||||
impl<'input> ScopeStringIdentResolver<'input> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
ident_map: FxHashMap::default(),
|
||||
name_to_ident: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
|
||||
resolver.ident_map.extend(self.ident_map);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ pub(super) fn run(
|
|||
| Statement::Constant(..)
|
||||
| Statement::Label(..)
|
||||
| Statement::PtrAccess { .. }
|
||||
| Statement::VectorAccess { .. }
|
||||
| Statement::RepackVector(..)
|
||||
| Statement::FunctionPointer(..) => {}
|
||||
}
|
||||
|
|
|
@ -55,8 +55,8 @@ fn run_statement<'input>(
|
|||
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||
Statement::Instruction((predicate, instruction)) => {
|
||||
if let Some(pred) = predicate {
|
||||
let if_true = resolver.register_intermediate(None);
|
||||
let if_false = resolver.register_intermediate(None);
|
||||
let if_true = resolver.register_unnamed(None);
|
||||
let if_false = resolver.register_unnamed(None);
|
||||
let folded_bra = match &instruction {
|
||||
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||
_ => None,
|
||||
|
|
|
@ -20,7 +20,7 @@ fn run_directive<'input>(
|
|||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
{
|
||||
let func_decl = method.func_decl.borrow();
|
||||
let func_decl = &method.func_decl;
|
||||
match func_decl.name {
|
||||
ptx_parser::MethodName::Kernel(_) => {}
|
||||
ptx_parser::MethodName::Func(name) => {
|
||||
|
|
Loading…
Add table
Reference in a new issue