diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index d3460d7..7644cb6 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -975,8 +975,7 @@ fn compute_denorm_information<'input>( Statement::Label(_) => {} Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} - Statement::PackVector(_) => {} - Statement::UnpackVector(_) => {} + Statement::RepackVector(_) => {} } } denorm_methods.insert(method_key, flush_counter); @@ -1477,7 +1476,7 @@ fn convert_to_typed_statements( }; d.src_is_address = take_address; } - let mut visitor = VectorPackingVisitor::new(&mut result, id_defs); + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction( ast::Instruction::Mov( d, @@ -1488,12 +1487,14 @@ fn convert_to_typed_statements( ) .map(&mut visitor)?, ); - result.push(instruction); + visitor.func.push(instruction); + visitor.func.extend(visitor.post_stmts); } inst => { - let mut visitor = VectorPackingVisitor::new(&mut result, id_defs); + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction(inst.map(&mut visitor)?); - result.push(instruction); + visitor.func.push(instruction); + visitor.func.extend(visitor.post_stmts); } }, Statement::Label(i) => result.push(Statement::Label(i)), @@ -1505,24 +1506,52 @@ fn convert_to_typed_statements( Ok(result) } -struct VectorPackingVisitor<'a, 'b> { +struct VectorRepackVisitor<'a, 'b> { func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>, - post_stmts: Vec, + post_stmts: Option, } -impl<'a, 'b> VectorPackingVisitor<'a, 'b> { +impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { - VectorPackingVisitor { + VectorRepackVisitor { func, id_def, - post_stmts: Vec::new(), + post_stmts: None, } } + + fn convert_vector( + &mut self, + is_dst: bool, + vector_sema: ArgumentSemantics, + typ: &ast::Type, + idx: Vec, + ) -> Result { + // mov.u32 foobar, {a,b}; + let scalar_t = match typ { + ast::Type::Vector(scalar_t, _) => *scalar_t, + _ => return Err(TranslateError::MismatchedType), + }; + let temp_vec = self.id_def.new_non_variable(Some(typ.clone())); + let statement = Statement::RepackVector(RepackVector { + is_extract: is_dst, + typ: scalar_t, + packed: temp_vec, + unpacked: idx, + vector_sema, + }); + if is_dst { + self.post_stmts = Some(statement); + } else { + self.func.push(statement); + } + Ok(temp_vec) + } } impl<'a, 'b> ArgumentMapVisitor - for VectorPackingVisitor<'a, 'b> + for VectorRepackVisitor<'a, 'b> { fn id( &mut self, @@ -1555,7 +1584,12 @@ impl<'a, 'b> ArgumentMapVisitor ) -> Result, TranslateError> { match desc.op { ast::DstOperandVec::Normal(op) => self.dst_operand(desc.new_op(op), typ), - ast::DstOperandVec::Vector(vec) => todo!(), + ast::DstOperandVec::Vector(vec) => Ok(ast::DstOperand::Reg(self.convert_vector( + desc.is_dst, + desc.sema, + typ, + vec, + )?)), } } @@ -1566,7 +1600,12 @@ impl<'a, 'b> ArgumentMapVisitor ) -> Result, TranslateError> { match desc.op { ast::SrcOperandVec::Normal(op) => self.src_operand(desc.new_op(op), typ), - ast::SrcOperandVec::Vector(_) => todo!(), + ast::SrcOperandVec::Vector(vec) => Ok(ast::SrcOperand::Reg(self.convert_vector( + desc.is_dst, + desc.sema, + typ, + vec, + )?)), } } } @@ -1799,8 +1838,7 @@ fn normalize_labels( | Statement::Label(..) | Statement::Undef(..) | Statement::PtrAccess { .. } - | Statement::PackVector(..) - | Statement::UnpackVector(..) => {} + | Statement::RepackVector(..) => {} } } iter::once(Statement::Label(id_def.new_non_variable(None))) @@ -1944,6 +1982,9 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::PtrAccess(ptr_access) => { insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)? } + Statement::RepackVector(repack) => { + insert_mem_ssa_statement_default(id_def, &mut result, repack)? + } s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), _ => return Err(TranslateError::Unreachable), } @@ -2111,6 +2152,12 @@ fn expand_arguments<'a, 'b>( result.push(Statement::PtrAccess(new_inst)); result.extend(post_stmts); } + Statement::RepackVector(repack) => { + let mut visitor = FlattenArguments::new(&mut result, id_def); + let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts); + result.push(Statement::RepackVector(new_inst)); + result.extend(post_stmts); + } Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), @@ -2120,8 +2167,6 @@ fn expand_arguments<'a, 'b>( Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => { return Err(TranslateError::Unreachable) } - Statement::PackVector(_) => todo!(), - Statement::UnpackVector(_) => todo!(), } } Ok(result) @@ -2445,6 +2490,13 @@ fn insert_implicit_conversions( Some(state_space), )?; } + Statement::RepackVector(repack) => insert_implicit_conversions_impl( + &mut result, + id_def, + repack, + should_bitcast_wrapper, + None, + )?, s @ Statement::Conditional(_) | s @ Statement::Conversion(_) | s @ Statement::Label(_) @@ -2454,8 +2506,6 @@ fn insert_implicit_conversions( | s @ Statement::StoreVar(_, _) | s @ Statement::Undef(_, _) | s @ Statement::RetValue(_, _) => result.push(s), - Statement::PackVector(_) => todo!(), - Statement::UnpackVector(_) => todo!(), } } Ok(result) @@ -3081,8 +3131,38 @@ fn emit_function_body_ops( )?; builder.bitcast(result_type, Some(*dst), temp)?; } - Statement::PackVector(_) => todo!(), - Statement::UnpackVector(_) => todo!(), + Statement::RepackVector(repack) => { + if repack.is_extract { + let scalar_type = map.get_or_add_scalar(builder, repack.typ); + for (index, dst_id) in repack.unpacked.iter().enumerate() { + builder.composite_extract( + scalar_type, + Some(*dst_id), + repack.packed, + &[index as u32], + )?; + } + } else { + let vector_type = map.get_or_add( + builder, + SpirvType::Vector( + SpirvScalarKey::from(repack.typ), + repack.unpacked.len() as u8, + ), + ); + let mut temp_vec = builder.undef(vector_type, None); + for (index, src_id) in repack.unpacked.iter().enumerate() { + temp_vec = builder.composite_insert( + vector_type, + None, + *src_id, + temp_vec, + &[index as u32], + )?; + } + builder.copy_object(vector_type, Some(repack.packed), temp_vec)?; + } + } } } Ok(()) @@ -4334,9 +4414,7 @@ fn convert_to_stateful_memory_access<'a>( }, )?; result.push(new_statement); - for s in post_statements { - result.push(s); - } + result.extend(post_statements); } Statement::Call(call) => { let mut post_statements = Vec::new(); @@ -4354,9 +4432,25 @@ fn convert_to_stateful_memory_access<'a>( }, )?; result.push(new_statement); - for s in post_statements { - result.push(s); - } + result.extend(post_statements); + } + Statement::RepackVector(pack) => { + let mut post_statements = Vec::new(); + let new_statement = pack.visit_variable( + &mut |arg_desc: ArgumentDescriptor, expected_type| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &func_args_ptr, + &mut result, + &mut post_statements, + arg_desc, + expected_type, + ) + }, + )?; + result.push(new_statement); + result.extend(post_statements); } _ => return Err(TranslateError::Unreachable), } @@ -4810,8 +4904,7 @@ enum Statement { RetValue(ast::RetData, spirv::Word), Undef(ast::Type, spirv::Word), PtrAccess(PtrAccess

), - PackVector(PackVector), - UnpackVector(UnpackVector), + RepackVector(RepackVector), } impl ExpandedStatement { @@ -4898,14 +4991,101 @@ impl ExpandedStatement { offset_src: constant_src, }) } - Statement::PackVector(_) => todo!(), - Statement::UnpackVector(_) => todo!(), + Statement::RepackVector(_) => todo!(), } } } -struct PackVector {} -struct UnpackVector {} +struct RepackVector { + is_extract: bool, + typ: ast::ScalarType, + packed: spirv::Word, + unpacked: Vec, + vector_sema: ArgumentSemantics, +} + +impl RepackVector { + fn map< + From: ArgParamsEx, + To: ArgParamsEx, + V: ArgumentMapVisitor, + >( + self, + visitor: &mut V, + ) -> Result { + let scalar = visitor.id( + ArgumentDescriptor { + op: self.packed, + is_dst: !self.is_extract, + sema: ArgumentSemantics::Default, + }, + Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)), + )?; + let scalar_type = self.typ; + let is_extract = self.is_extract; + let vector_sema = self.vector_sema; + let vector = self + .unpacked + .into_iter() + .map(|id| { + visitor.id( + ArgumentDescriptor { + op: id, + is_dst: is_extract, + sema: vector_sema, + }, + Some(&ast::Type::Scalar(scalar_type)), + ) + }) + .collect::>()?; + Ok(RepackVector { + is_extract, + typ: self.typ, + packed: scalar, + unpacked: vector, + vector_sema, + }) + } +} + +impl VisitVariable for RepackVector { + fn visit_variable< + 'a, + F: FnMut( + ArgumentDescriptor, + Option<&ast::Type>, + ) -> Result, + >( + self, + f: &mut F, + ) -> Result { + Ok(TypedStatement::RepackVector( + self.map::(f)?, + )) + } +} + +impl VisitVariableExpanded for RepackVector { + fn visit_variable_extended< + F: FnMut( + ArgumentDescriptor, + Option<&ast::Type>, + ) -> Result, + >( + self, + f: &mut F, + ) -> Result { + Ok(ExpandedStatement::RepackVector( + self.map::(f)?, + )) + } +} + +struct UnpackVector { + typ: ast::ScalarType, + dst: Vec, + src: spirv::Word, +} struct ResolvedCall { pub uniform: bool, @@ -6737,22 +6917,6 @@ impl ast::Arg5Setp { } } -impl ast::Type { - fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> { - match self { - ast::Type::Vector(t, len) => Ok((*t, *len)), - _ => Err(TranslateError::MismatchedType), - } - } - - fn get_scalar(&self) -> Result { - match self { - ast::Type::Scalar(t) => Ok(*t), - _ => Err(TranslateError::MismatchedType), - } - } -} - impl ast::SrcOperand { fn map_variable Result>( self,