diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs new file mode 100644 index 0000000..3dfef55 --- /dev/null +++ b/ptx/src/pass/convert_to_typed.rs @@ -0,0 +1,140 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run( + func: Vec, + fn_defs: &GlobalFnDeclResolver, + id_defs: &mut NumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::::with_capacity(func.len()); + for s in func { + match s { + Statement::Instruction(inst) => match inst { + ast::Instruction::Mov { + data, + arguments: + ast::MovArgs { + dst: ast::ParsedOperand::Reg(dst_reg), + src: ast::ParsedOperand::Reg(src_reg), + }, + } if fn_defs.fns.contains_key(&src_reg) => { + if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { + return Err(TranslateError::MismatchedType); + } + result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { + dst: dst_reg, + src: src_reg, + })); + } + ast::Instruction::Call(call) => { + let resolver = fn_defs.get_fn_sig_resolver(call.func)?; + let resolved_call = resolver.resolve_in_spirv_repr(call)?; + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let reresolved_call = resolved_call.visit(&mut visitor)?; + visitor.func.push(reresolved_call); + visitor.func.extend(visitor.post_stmts); + } + inst => { + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let instruction = Statement::Instruction(inst.map(&mut visitor)?); + visitor.func.push(instruction); + visitor.func.extend(visitor.post_stmts); + } + }, + Statement::Label(i) => result.push(Statement::Label(i)), + Statement::Variable(v) => result.push(Statement::Variable(v)), + Statement::Conditional(c) => result.push(Statement::Conditional(c)), + _ => return Err(error_unreachable()), + } + } + Ok(result) +} + +struct VectorRepackVisitor<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut NumericIdResolver<'a>, + post_stmts: Option, +} + +impl<'a, 'b> VectorRepackVisitor<'a, 'b> { + fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { + VectorRepackVisitor { + func, + id_def, + post_stmts: None, + } + } + + fn convert_vector( + &mut self, + is_dst: bool, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, + typ: &ast::Type, + state_space: ast::StateSpace, + idx: Vec, + ) -> Result { + // mov.u32 foobar, {a,b}; + let scalar_t = match typ { + ast::Type::Vector(scalar_t, _) => *scalar_t, + _ => return Err(TranslateError::MismatchedType), + }; + let temp_vec = self + .id_def + .register_intermediate(Some((typ.clone(), state_space))); + let statement = Statement::RepackVector(RepackVectorDetails { + is_extract: is_dst, + typ: scalar_t, + packed: temp_vec, + unpacked: idx, + non_default_implicit_conversion, + }); + if is_dst { + self.post_stmts = Some(statement); + } else { + self.func.push(statement); + } + Ok(temp_vec) + } +} + +impl<'a, 'b> ast::VisitorMap, TypedOperand, TranslateError> + for VectorRepackVisitor<'a, 'b> +{ + fn visit_ident( + &mut self, + ident: SpirvWord, + _: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + _: bool, + ) -> Result { + Ok(ident) + } + + fn visit( + &mut self, + op: ast::ParsedOperand, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + ) -> Result { + Ok(match op { + ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg), + ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), + ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x), + ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), + ast::ParsedOperand::VecPack(vec) => { + let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?; + TypedOperand::Reg(self.convert_vector( + is_dst, + desc.non_default_implicit_conversion, + type_, + space, + vec, + )?) + } + }) + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 934a472..bedf46a 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -8,7 +8,9 @@ use std::{ rc::Rc, }; -pub(crate) mod normalize; +mod convert_to_typed; +mod normalize_identifiers; +mod normalize_predicates; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -161,13 +163,13 @@ fn to_ssa<'input, 'b>( }) } }; - let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?; + let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?; + let mut numeric_id_defs = id_defs.finish(); + let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?; + let typed_statements = + convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; todo!() /* - let mut numeric_id_defs = id_defs.finish(); - let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; - let typed_statements = - convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; let typed_statements = fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; let (func_decl, typed_statements) = @@ -856,4 +858,33 @@ pub(crate) struct Function<'input> { linkage: ast::LinkingDirective, } -type ExpandedStatement = Statement, SpirvWord>; \ No newline at end of file +type ExpandedStatement = Statement, SpirvWord>; + +type NormalizedStatement = Statement< + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +type UnconditionalStatement = + Statement>, ast::ParsedOperand>; + +type TypedStatement = Statement, TypedOperand>; + +#[derive(Copy, Clone)] +enum TypedOperand { + Reg(SpirvWord), + RegOffset(SpirvWord, i32), + Imm(ast::ImmediateValue), + VecMember(SpirvWord, u8), +} + +impl ast::Operand for TypedOperand { + type Ident = SpirvWord; + + fn from_ident(ident: Self::Ident) -> Self { + TypedOperand::Reg(ident) + } +} diff --git a/ptx/src/pass/normalize.rs b/ptx/src/pass/normalize_identifiers.rs similarity index 91% rename from ptx/src/pass/normalize.rs rename to ptx/src/pass/normalize_identifiers.rs index 68ac26e..6588d63 100644 --- a/ptx/src/pass/normalize.rs +++ b/ptx/src/pass/normalize_identifiers.rs @@ -1,14 +1,6 @@ use super::*; use ptx_parser as ast; -type NormalizedStatement = Statement< - ( - Option>, - ast::Instruction>, - ), - ast::ParsedOperand, ->; - pub(crate) fn run<'input, 'b>( id_defs: &mut FnStringIdResolver<'input, 'b>, fn_defs: &GlobalFnDeclResolver<'input, 'b>, diff --git a/ptx/src/pass/normalize_predicates.rs b/ptx/src/pass/normalize_predicates.rs new file mode 100644 index 0000000..c971cfa --- /dev/null +++ b/ptx/src/pass/normalize_predicates.rs @@ -0,0 +1,44 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run( + func: Vec, + id_def: &mut NumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func { + match s { + Statement::Label(id) => result.push(Statement::Label(id)), + Statement::Instruction((pred, inst)) => { + if let Some(pred) = pred { + let if_true = id_def.register_intermediate(None); + let if_false = id_def.register_intermediate(None); + let folded_bra = match &inst { + ast::Instruction::Bra { arguments, .. } => Some(arguments.src), + _ => None, + }; + let mut branch = BrachCondition { + predicate: pred.label, + if_true: folded_bra.unwrap_or(if_true), + if_false, + }; + if pred.not { + std::mem::swap(&mut branch.if_true, &mut branch.if_false); + } + result.push(Statement::Conditional(branch)); + if folded_bra.is_none() { + result.push(Statement::Label(if_true)); + result.push(Statement::Instruction(inst)); + } + result.push(Statement::Label(if_false)); + } else { + result.push(Statement::Instruction(inst)); + } + } + Statement::Variable(var) => result.push(Statement::Variable(var)), + // Blocks are flattened when resolving ids + _ => return Err(error_unreachable()), + } + } + Ok(result) +}