diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs new file mode 100644 index 0000000..d94786e --- /dev/null +++ b/ptx/src/pass/fix_special_registers.rs @@ -0,0 +1,183 @@ +use super::*; +use std::collections::HashMap; + +fn run<'a, 'b, 'input>( + ptx_impl_imports: &'a mut HashMap>, + typed_statements: Vec, + numeric_id_defs: &'a mut NumericIdResolver<'b>, +) -> Result, TranslateError> { + let result = Vec::with_capacity(typed_statements.len()); + let mut sreg_sresolver = SpecialRegisterResolver { + ptx_impl_imports, + numeric_id_defs, + result, + }; + for statement in typed_statements { + let statement = statement.visit_map(&mut sreg_sresolver)?; + sreg_sresolver.result.push(statement); + } + Ok(sreg_sresolver.result) +} + +struct SpecialRegisterResolver<'a, 'b, 'input> { + ptx_impl_imports: &'a mut HashMap>, + numeric_id_defs: &'a mut NumericIdResolver<'b>, + result: Vec, +} + +impl<'a, 'b, 'input> ast::VisitorMap + for SpecialRegisterResolver<'a, 'b, 'input> +{ + fn visit( + &mut self, + operand: TypedOperand, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index)) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + self.replace_sreg(args, is_dst, None) + } +} + +impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { + fn replace_sreg( + &mut self, + name: SpirvWord, + is_dst: bool, + vector_index: Option, + ) -> Result { + if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) { + if is_dst { + return Err(TranslateError::MismatchedType); + } + let input_arguments = match (vector_index, sreg.get_function_input_type()) { + (Some(idx), Some(inp_type)) => { + if inp_type != ast::ScalarType::U8 { + return Err(TranslateError::Unreachable); + } + let constant = self.numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: constant, + typ: inp_type, + value: ast::ImmediateValue::U64(idx as u64), + })); + vec![( + TypedOperand::Reg(constant), + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + )] + } + (None, None) => Vec::new(), + _ => return Err(TranslateError::MismatchedType), + }; + let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); + let return_type = sreg.get_function_return_type(); + let fn_result = self.numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + ))); + let return_arguments = vec![( + fn_result, + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + )]; + let fn_call = register_external_fn_call( + self.numeric_id_defs, + self.ptx_impl_imports, + ocl_fn_name.to_string(), + return_arguments.iter().map(|(_, typ, space)| (typ, *space)), + input_arguments.iter().map(|(_, typ, space)| (typ, *space)), + )?; + let data = ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + input_arguments: input_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + }; + let arguments = ast::CallArgs { + return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(), + func: fn_call, + input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(), + }; + self.result + .push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + Ok(fn_result) + } else { + Ok(name) + } + } +} + +fn register_external_fn_call<'a>( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + name: String, + return_arguments: impl Iterator, + input_arguments: impl Iterator, +) -> Result { + match ptx_impl_imports.entry(name) { + hash_map::Entry::Vacant(entry) => { + let fn_id = id_defs.register_intermediate(None); + let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); + let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); + let func_decl = ast::MethodDeclaration:: { + return_arguments, + name: ast::MethodName::Func(fn_id), + input_arguments, + shared_mem: None, + }; + let func = Function { + func_decl: Rc::new(RefCell::new(func_decl)), + globals: Vec::new(), + body: None, + import_as: Some(entry.key().clone()), + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + }; + entry.insert(Directive::Method(func)); + Ok(fn_id) + } + hash_map::Entry::Occupied(entry) => match entry.get() { + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => Ok(fn_id), + ast::MethodName::Kernel(_) => Err(error_unreachable()), + }, + _ => Err(error_unreachable()), + }, + } +} + +fn fn_arguments_to_variables<'a>( + id_defs: &mut NumericIdResolver, + args: impl Iterator, +) -> Vec> { + args.map(|(typ, space)| ast::Variable { + align: None, + v_type: typ.clone(), + state_space: space, + name: id_defs.register_intermediate(None), + array_init: Vec::new(), + }) + .collect::>() +} \ No newline at end of file diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 3968d3d..b3bfa72 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -9,6 +9,7 @@ use std::{ }; mod convert_to_typed; +mod fix_special_registers; mod normalize_identifiers; mod normalize_predicates; @@ -735,6 +736,235 @@ enum Statement { FunctionPointer(FunctionPointerDetails), } +impl> Statement, T> { + fn visit_map, Err>( + self, + visitor: &mut impl ast::VisitorMap, + ) -> std::result::Result, T>, Err> { + Ok(match self { + Statement::Instruction(i) => { + return ast::visit_map(i, visitor).map(Statement::Instruction) + } + Statement::Label(label) => { + Statement::Label(visitor.visit_ident(label, None, false, false)?) + } + Statement::Variable(var) => { + let name = visitor.visit_ident( + var.name, + Some((&var.v_type, var.state_space)), + true, + false, + )?; + Statement::Variable(ast::Variable { + align: var.align, + v_type: var.v_type, + state_space: var.state_space, + name, + array_init: var.array_init, + }) + } + Statement::Conditional(conditional) => { + let predicate = visitor.visit_ident(conditional.predicate, None, false, false)?; + let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?; + let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?; + Statement::Conditional(BrachCondition { + predicate, + if_true, + if_false, + }) + } + Statement::LoadVar(LoadVarDetails { + arg, + typ, + member_index, + }) => { + let dst = visitor.visit_ident( + arg.dst, + Some((&typ, ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + arg.src, + Some((&typ, ast::StateSpace::Local)), + false, + false, + )?; + Statement::LoadVar(LoadVarDetails { + arg: ast::LdArgs { dst, src }, + typ, + member_index, + }) + } + Statement::StoreVar(StoreVarDetails { + arg, + typ, + member_index, + }) => { + let src1 = visitor.visit_ident( + arg.src1, + Some((&typ, ast::StateSpace::Local)), + false, + false, + )?; + let src2 = visitor.visit_ident( + arg.src2, + Some((&typ, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::StoreVar(StoreVarDetails { + arg: ast::StArgs { src1, src2 }, + typ, + member_index, + }) + } + Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + to_type, + from_space, + to_space, + kind, + }) => { + let dst = visitor.visit_ident( + dst, + Some((&to_type, ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + src, + Some((&from_type, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + to_type, + from_space, + to_space, + kind, + }) + } + Statement::Constant(ConstantDefinition { dst, typ, value }) => { + let dst = visitor.visit_ident( + dst, + Some((&typ.into(), ast::StateSpace::Reg)), + true, + false, + )?; + Statement::Constant(ConstantDefinition { dst, typ, value }) + } + Statement::RetValue(data, value) => { + // TODO: + // We should report type here + let value = visitor.visit_ident(value, None, false, false)?; + Statement::RetValue(data, value) + } + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) => { + let dst = + visitor.visit_ident(dst, Some((&underlying_type, state_space)), true, false)?; + let ptr_src = visitor.visit_ident( + ptr_src, + Some((&underlying_type, state_space)), + false, + false, + )?; + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) + } + Statement::RepackVector(RepackVectorDetails { + is_extract, + typ, + packed, + unpacked, + relaxed_type_check, + }) => { + let (packed, unpacked) = if is_extract { + let unpacked = unpacked + .into_iter() + .map(|ident| { + visitor.visit_ident( + ident, + Some((&typ.into(), ast::StateSpace::Reg)), + true, + relaxed_type_check, + ) + }) + .collect::, _>>()?; + let packed = visitor.visit_ident( + packed, + Some(( + &ast::Type::Vector(typ, unpacked.len() as u8), + ast::StateSpace::Reg, + )), + false, + false, + )?; + (packed, unpacked) + } else { + let packed = visitor.visit_ident( + packed, + Some(( + &ast::Type::Vector(typ, unpacked.len() as u8), + ast::StateSpace::Reg, + )), + true, + false, + )?; + let unpacked = unpacked + .into_iter() + .map(|ident| { + visitor.visit_ident( + ident, + Some((&typ.into(), ast::StateSpace::Reg)), + false, + relaxed_type_check, + ) + }) + .collect::, _>>()?; + (packed, unpacked) + }; + Statement::RepackVector(RepackVectorDetails { + is_extract, + typ, + packed, + unpacked, + relaxed_type_check, + }) + } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + let dst = visitor.visit_ident( + dst, + Some(( + &ast::Type::Scalar(ast::ScalarType::U64), + ast::StateSpace::Reg, + )), + true, + false, + )?; + let src = visitor.visit_ident(src, None, false, false)?; + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) + } + }) + } +} + struct BrachCondition { predicate: SpirvWord, if_true: SpirvWord, @@ -743,7 +973,6 @@ struct BrachCondition { struct LoadVarDetails { arg: ast::LdArgs, typ: ast::Type, - state_space: ast::StateSpace, // (index, vector_width) // HACK ALERT // For some reason IGC explodes when you try to load from builtin vectors @@ -798,7 +1027,7 @@ struct RepackVectorDetails { typ: ast::ScalarType, packed: SpirvWord, unpacked: Vec, - relaxed_type_check: bool + relaxed_type_check: bool, } struct FunctionPointerDetails { @@ -876,6 +1105,20 @@ enum TypedOperand { VecMember(SpirvWord, u8), } +impl TypedOperand { + fn map( + self, + fn_: impl FnOnce(SpirvWord, Option) -> Result, + ) -> Result { + Ok(match self { + TypedOperand::Reg(reg) => TypedOperand::Reg(fn_(reg, None)?), + TypedOperand::RegOffset(reg, off) => TypedOperand::RegOffset(fn_(reg, None)?, off), + TypedOperand::Imm(imm) => TypedOperand::Imm(imm), + TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx), + }) + } +} + impl ast::Operand for TypedOperand { type Ident = SpirvWord;