From 5976e55b2b01b9805ab05d81b338cfd76ee11a1d Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 4 Dec 2020 00:15:42 +0100 Subject: [PATCH] Start fixing builtins --- ptx/src/translate.rs | 209 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 178 insertions(+), 31 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 20578eb..d371293 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1030,7 +1030,7 @@ fn emit_builtins( map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver, ) { - for (reg, id) in id_defs.special_registers.iter() { + for (reg, id) in id_defs.special_registers.builtins() { let result_type = map.get_or_add( builder, SpirvType::Pointer( @@ -1038,9 +1038,9 @@ fn emit_builtins( spirv::StorageClass::Input, ), ); - builder.variable(result_type, Some(*id), spirv::StorageClass::Input, None); + builder.variable(result_type, Some(id), spirv::StorageClass::Input, None); builder.decorate( - *id, + id, spirv::Decoration::BuiltIn, &[dr::Operand::BuiltIn(reg.get_builtin())], ); @@ -1086,11 +1086,7 @@ fn emit_function_header<'a>( .iter() .filter_map(|(k, t)| t.as_ref().map(|_| *k)) .collect::>(); - let mut interface = defined_globals - .special_registers - .iter() - .map(|(_, id)| *id) - .collect::>(); + let mut interface = defined_globals.special_registers.interface(); for ast::Variable { name, .. } in synthetic_globals { interface.push(*name); } @@ -1320,6 +1316,7 @@ fn to_ssa<'input, 'b>( convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; let typed_statements = convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; + let typed_statements = fix_builtins(typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, @@ -1343,6 +1340,75 @@ fn to_ssa<'input, 'b>( }) } +fn fix_builtins( + typed_statements: Vec, + numeric_id_defs: &mut NumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(typed_statements.len()); + let mut visitor = FixBuiltinsVisitor {}; + for s in typed_statements { + match s { + Statement::Instruction(inst) => { + result.push(Statement::Instruction(inst.map(&mut visitor)?)) + } + s => result.push(s), + } + } + Ok(result) +} + +struct FixBuiltinsVisitor {} + +impl ArgumentMapVisitor for FixBuiltinsVisitor { + fn id( + &mut self, + desc: ArgumentDescriptor, + typ: Option<&ast::Type>, + ) -> Result { + todo!() + } + + fn operand( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + todo!() + } + + fn id_or_vector( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + todo!() + } + + fn operand_or_vector( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + todo!() + } + + fn src_call_operand( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + todo!() + } + + fn src_member_operand( + &mut self, + desc: ArgumentDescriptor<(spirv::Word, u8)>, + typ: (ast::ScalarType, u8), + ) -> Result<(spirv::Word, u8), TranslateError> { + todo!() + } +} + fn extract_globals<'input, 'b>( sorted_statements: Vec, ptx_impl_imports: &mut HashMap, @@ -4599,9 +4665,13 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { Tid, + Tid64, Ntid, + Ntid64, Ctaid, + Ctaid64, Nctaid, + Nctaid64, } impl PtxSpecialRegister { @@ -4618,27 +4688,116 @@ impl PtxSpecialRegister { fn get_type(self) -> ast::Type { match self { PtxSpecialRegister::Tid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Tid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Ntid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Ntid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Ctaid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Ctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Nctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), } } fn get_builtin(self) -> spirv::BuiltIn { match self { - PtxSpecialRegister::Tid => spirv::BuiltIn::LocalInvocationId, - PtxSpecialRegister::Ntid => spirv::BuiltIn::WorkgroupSize, - PtxSpecialRegister::Ctaid => spirv::BuiltIn::WorkgroupId, - PtxSpecialRegister::Nctaid => spirv::BuiltIn::NumWorkgroups, + PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { + spirv::BuiltIn::LocalInvocationId + } + PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => spirv::BuiltIn::WorkgroupSize, + PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => spirv::BuiltIn::WorkgroupId, + PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => { + spirv::BuiltIn::NumWorkgroups + } } } + + fn push_conversion( + self, + id_def: &mut NumericIdResolver, + func: &mut Vec, + mut composite_read: CompositeRead, + ) { + todo!() + /* + match self { + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => { + if composite_read.src_index == 3 { + func.push(Statement::Constant(ConstantDefinition { + dst: composite_read.dst, + typ: ast::ScalarType::U32, + value: ast::ImmediateValue::U64(0), + })); + } else { + let dst = composite_read.dst; + let temp_dst = + id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::U64))); + composite_read.dst = temp_dst; + func.push(Statement::Composite(composite_read)); + func.push(Statement::Conversion(ImplicitConversion { + src: temp_dst, + dst: dst, + from: ast::Type::Scalar(ast::ScalarType::U64), + to: ast::Type::Scalar(ast::ScalarType::U32), + kind: ConversionKind::Default, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, + })); + } + } + } + */ + } +} + +struct SpecialRegistersMap { + reg_to_id: HashMap, + id_to_reg: HashMap, +} + +impl SpecialRegistersMap { + fn new() -> Self { + SpecialRegistersMap { + reg_to_id: HashMap::new(), + id_to_reg: HashMap::new(), + } + } + + fn builtins<'a>(&'a self) -> impl Iterator + 'a { + self.reg_to_id.iter().map(|(reg, id)| (*reg, *id)) + } + + fn interface(&self) -> Vec { + self.id_to_reg.iter().map(|(id, _)| *id).collect::>() + } + + fn get(&self, id: spirv::Word) -> Option { + self.id_to_reg.get(&id).copied() + } + + fn get_or_add(&mut self, current_id: &mut spirv::Word, reg: PtxSpecialRegister) -> spirv::Word { + match self.reg_to_id.entry(reg) { + hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Vacant(e) => { + let numeric_id = *current_id; + *current_id += 1; + e.insert(numeric_id); + self.id_to_reg.insert(numeric_id, reg); + numeric_id + } + } + } + + fn replace(&mut self) {} } struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, variables_type_check: HashMap>, - special_registers: HashMap, + special_registers: SpecialRegistersMap, fns: HashMap, } @@ -4653,7 +4812,7 @@ impl<'a> GlobalStringIdResolver<'a> { current_id: start_id, variables: HashMap::new(), variables_type_check: HashMap::new(), - special_registers: HashMap::new(), + special_registers: SpecialRegistersMap::new(), fns: HashMap::new(), } } @@ -4768,7 +4927,7 @@ struct FnStringIdResolver<'input, 'b> { current_id: &'b mut spirv::Word, global_variables: &'b HashMap, spirv::Word>, global_type_check: &'b HashMap>, - special_registers: &'b mut HashMap, + special_registers: &'b mut SpecialRegistersMap, variables: Vec, spirv::Word>>, type_check: HashMap>, } @@ -4779,11 +4938,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { current_id: self.current_id, global_type_check: self.global_type_check, type_check: self.type_check, - special_registers: self - .special_registers - .iter() - .map(|(reg, id)| (*id, *reg)) - .collect(), + special_registers: self.special_registers, } } @@ -4807,15 +4962,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { None => { let sreg = PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?; - match self.special_registers.entry(sreg) { - hash_map::Entry::Occupied(e) => Ok(*e.get()), - hash_map::Entry::Vacant(e) => { - let numeric_id = *self.current_id; - *self.current_id += 1; - e.insert(numeric_id); - Ok(numeric_id) - } - } + Ok(self.special_registers.get_or_add(self.current_id, sreg)) } } } @@ -4858,7 +5005,7 @@ struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, global_type_check: &'b HashMap>, type_check: HashMap>, - special_registers: HashMap, + special_registers: &'b mut SpecialRegistersMap, } impl<'b> NumericIdResolver<'b> { @@ -4870,7 +5017,7 @@ impl<'b> NumericIdResolver<'b> { match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), - None => match self.special_registers.get(&id) { + None => match self.special_registers.get(id) { Some(x) => Ok((x.get_type(), true)), None => match self.global_type_check.get(&id) { Some(Some(result)) => Ok(result.clone()),