diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 8fcf82a..413d3f5 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -717,7 +717,7 @@ pub enum SrcOperand { Reg(Id), RegOffset(Id, i32), Imm(ImmediateValue), - VecIndex(Id, u8), + VecMember(Id, u8), } #[derive(Clone)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 91abdde..6d9f93d 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -1856,7 +1856,7 @@ SrcOperand: ast::SrcOperand<&'input str> = { => ast::SrcOperand::Imm(x), => { let (reg, idx) = mem_op; - ast::SrcOperand::VecIndex(reg, idx) + ast::SrcOperand::VecMember(reg, idx) } } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7644cb6..c31e0a2 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -27,6 +27,16 @@ quick_error! { } } +#[cfg(debug_assertions)] +fn error_unreachable() -> TranslateError { + unreachable!() +} + +#[cfg(not(debug_assertions))] +fn error_unreachable() -> TranslateError { + TranslateError::Unreachable +} + #[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), @@ -82,7 +92,7 @@ impl ast::Type { ast::Type::Pointer(ast::PointerType::Scalar(t), space) => { ast::Type::Pointer(ast::PointerType::Pointer(t, space), space) } - ast::Type::Pointer(_, _) => return Err(TranslateError::Unreachable), + ast::Type::Pointer(_, _) => return Err(error_unreachable()), }) } } @@ -364,7 +374,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, &components) } ast::Type::Array(typ, dims) => match dims.as_slice() { - [] => return Err(TranslateError::Unreachable), + [] => return Err(error_unreachable()), [dim] => { let result_type = self .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); @@ -791,13 +801,14 @@ fn convert_dynamic_shared_memory_usage<'input>( ast::PointerStateSpace::Shared, )), }); - let shared_var_st = ExpandedStatement::StoreVar( - ast::Arg2St { + let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { src1: shared_var_id, src2: shared_id_param, }, - ast::Type::Scalar(ast::ScalarType::B8), - ); + typ: ast::Type::Scalar(ast::ScalarType::B8), + member_index: None, + }); let mut new_statements = vec![shared_var, shared_var_st]; replace_uses_of_shared_memory( &mut new_statements, @@ -963,10 +974,9 @@ fn compute_denorm_information<'input>( denorm_count_map_update(&mut flush_counter, width, flush); } } - Statement::LoadVar(_, _) => {} - Statement::StoreVar(_, _) => {} + Statement::LoadVar(..) => {} + Statement::StoreVar(..) => {} Statement::Call(_) => {} - Statement::Composite(_) => {} Statement::Conditional(_) => {} Statement::Conversion(_) => {} Statement::Constant(_) => {} @@ -1500,7 +1510,7 @@ fn convert_to_typed_statements( Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), Statement::Conditional(c) => result.push(Statement::Conditional(c)), - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } Ok(result) @@ -1534,7 +1544,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { _ => return Err(TranslateError::MismatchedType), }; let temp_vec = self.id_def.new_non_variable(Some(typ.clone())); - let statement = Statement::RepackVector(RepackVector { + let statement = Statement::RepackVector(RepackVectorDetails { is_extract: is_dst, typ: scalar_t, packed: temp_vec, @@ -1556,7 +1566,7 @@ impl<'a, 'b> ArgumentMapVisitor fn id( &mut self, desc: ArgumentDescriptor, - typ: Option<&ast::Type>, + _: Option<&ast::Type>, ) -> Result { Ok(desc.op) } @@ -1564,7 +1574,7 @@ impl<'a, 'b> ArgumentMapVisitor fn dst_operand( &mut self, desc: ArgumentDescriptor>, - typ: &ast::Type, + _: &ast::Type, ) -> Result, TranslateError> { Ok(desc.op) } @@ -1572,7 +1582,7 @@ impl<'a, 'b> ArgumentMapVisitor fn src_operand( &mut self, desc: ArgumentDescriptor>, - typ: &ast::Type, + _: &ast::Type, ) -> Result, TranslateError> { Ok(desc.op) } @@ -1827,8 +1837,7 @@ fn normalize_labels( labels_in_use.insert(cond.if_true); labels_in_use.insert(cond.if_false); } - Statement::Composite(..) - | Statement::Call(..) + Statement::Call(..) | Statement::Variable(..) | Statement::LoadVar(..) | Statement::StoreVar(..) @@ -1885,7 +1894,7 @@ fn normalize_predicates( } Statement::Variable(var) => result.push(Statement::Variable(var)), // Blocks are flattened when resolving ids - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } Ok(result) @@ -1912,7 +1921,7 @@ fn insert_mem_ssa_statements<'a, 'b>( array_init: arg.array_init.clone(), })); } - None => return Err(TranslateError::Unreachable), + None => return Err(error_unreachable()), } } for spirv_arg in fn_decl.input.iter_mut() { @@ -1926,13 +1935,14 @@ fn insert_mem_ssa_statements<'a, 'b>( name: spirv_arg.name, array_init: spirv_arg.array_init.clone(), })); - result.push(Statement::StoreVar( - ast::Arg2St { + result.push(Statement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { src1: spirv_arg.name, src2: new_id, }, typ, - )); + member_index: None, + })); spirv_arg.name = new_id; } None => {} @@ -1949,13 +1959,14 @@ fn insert_mem_ssa_statements<'a, 'b>( if let &[out_param] = &fn_decl.output.as_slice() { let (typ, _) = id_def.get_typed(out_param.name)?; let new_id = id_def.new_non_variable(Some(typ.clone())); - result.push(Statement::LoadVar( - ast::Arg2 { + result.push(Statement::LoadVar(LoadVarDetails { + arg: ast::Arg2 { dst: new_id, src: out_param.name, }, - typ.clone(), - )); + typ: typ.clone(), + member_index: None, + })); result.push(Statement::RetValue(d, new_id)); } else { result.push(Statement::Instruction(ast::Instruction::Ret(d))) @@ -1966,13 +1977,14 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Conditional(mut bra) => { let generated_id = id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred))); - result.push(Statement::LoadVar( - Arg2 { + result.push(Statement::LoadVar(LoadVarDetails { + arg: Arg2 { dst: generated_id, src: bra.predicate, }, - ast::Type::Scalar(ast::ScalarType::Pred), - )); + typ: ast::Type::Scalar(ast::ScalarType::Pred), + member_index: None, + })); bra.predicate = generated_id; result.push(Statement::Conditional(bra)); } @@ -1986,7 +1998,7 @@ fn insert_mem_ssa_statements<'a, 'b>( insert_mem_ssa_statement_default(id_def, &mut result, repack)? } s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } Ok(result) @@ -2018,71 +2030,56 @@ fn type_to_variable_type( scalar_type .clone() .try_into() - .map_err(|_| TranslateError::Unreachable)?, - (*space) - .try_into() - .map_err(|_| TranslateError::Unreachable)?, + .map_err(|_| error_unreachable())?, + (*space).try_into().map_err(|_| error_unreachable())?, ))) } ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None, - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }) } -trait VisitVariable: Sized { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +trait Visitable: Sized { + fn visit( self, - f: &mut F, - ) -> Result; -} -trait VisitVariableExpanded { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result; + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, To>, TranslateError>; } -struct VisitArgumentDescriptor<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> { +struct VisitArgumentDescriptor< + 'a, + Ctor: FnOnce(spirv::Word) -> Statement, U>, + U: ArgParamsEx, +> { desc: ArgumentDescriptor, typ: &'a ast::Type, stmt_ctor: Ctor, } -impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded - for VisitArgumentDescriptor<'a, Ctor> +impl< + 'a, + Ctor: FnOnce(spirv::Word) -> Statement, U>, + T: ArgParamsEx, + U: ArgParamsEx, + > Visitable for VisitArgumentDescriptor<'a, Ctor, U> { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( + fn visit( self, - f: &mut F, - ) -> Result { - f(self.desc, Some(self.typ)).map(self.stmt_ctor) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?)) } } -fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( +fn insert_mem_ssa_statement_default<'a, S: Visitable>( id_def: &mut NumericIdResolver, result: &mut Vec, - stmt: F, + stmt: S, ) -> Result<(), TranslateError> { let mut post_statements = Vec::new(); - let new_statement = stmt.visit_variable( - &mut |desc: ArgumentDescriptor, expected_type| { + let new_statement = + stmt.visit(&mut |desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { if expected_type.is_none() { return Ok(desc.op); }; @@ -2092,25 +2089,26 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( } let generated_id = id_def.new_non_variable(Some(var_type.clone())); if !desc.is_dst { - result.push(Statement::LoadVar( - Arg2 { + result.push(Statement::LoadVar(LoadVarDetails { + arg: Arg2 { dst: generated_id, src: desc.op, }, - var_type, - )); + typ: var_type, + member_index: None, + })); } else { - post_statements.push(Statement::StoreVar( - Arg2St { + post_statements.push(Statement::StoreVar(StoreVarDetails { + arg: Arg2St { src1: desc.op, src2: generated_id, }, - var_type, - )); + typ: var_type, + member_index: None, + })); } Ok(generated_id) - }, - )?; + })?; result.push(new_statement); result.append(&mut post_statements); Ok(()) @@ -2160,13 +2158,11 @@ fn expand_arguments<'a, 'b>( } Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), - Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), - Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), + Statement::LoadVar(details) => result.push(Statement::LoadVar(details)), + Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), - Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => { - return Err(TranslateError::Unreachable) - } + Statement::Constant(_) | Statement::Undef(_, _) => return Err(error_unreachable()), } } Ok(result) @@ -2190,27 +2186,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } } - fn insert_composite_read( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver<'a>, - typ: (ast::ScalarType, u8), - scalar_dst: Option, - scalar_sema_override: Option, - composite_src: (spirv::Word, u8), - ) -> spirv::Word { - let new_id = - scalar_dst.unwrap_or_else(|| id_def.new_non_variable(ast::Type::Scalar(typ.0))); - func.push(Statement::Composite(CompositeRead { - typ: typ.0, - dst: new_id, - dst_semantics_override: scalar_sema_override, - src_composite: composite_src.0, - src_index: composite_src.1 as u32, - src_len: typ.1 as u32, - })); - new_id - } - fn reg( &mut self, desc: ArgumentDescriptor, @@ -2350,7 +2325,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr ) -> Result { match desc.op { ast::DstOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)), - ast::DstOperand::VecMember(..) => Err(TranslateError::Unreachable), + ast::DstOperand::VecMember(..) => Err(error_unreachable()), } } @@ -2365,7 +2340,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr ast::SrcOperand::RegOffset(reg, offset) => { self.reg_offset(desc.new_op((reg, offset)), typ) } - ast::SrcOperand::VecIndex(..) => Err(TranslateError::Unreachable), + ast::SrcOperand::VecMember(..) => Err(error_unreachable()), } } @@ -2376,7 +2351,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr ) -> Result { match desc.op { ast::DstOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)), - ast::DstOperand::VecMember(..) => Err(TranslateError::Unreachable), + ast::DstOperand::VecMember(..) => Err(error_unreachable()), } } @@ -2391,7 +2366,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr ast::SrcOperand::RegOffset(reg, offset) => { self.reg_offset(desc.new_op((reg, offset)), typ) } - ast::SrcOperand::VecIndex(..) => Err(TranslateError::Unreachable), + ast::SrcOperand::VecMember(..) => Err(error_unreachable()), } } } @@ -2451,13 +2426,6 @@ fn insert_implicit_conversions( state_space, )?; } - Statement::Composite(composite) => insert_implicit_conversions_impl( - &mut result, - id_def, - composite, - should_bitcast_wrapper, - None, - )?, Statement::PtrAccess(PtrAccess { underlying_type, state_space, @@ -2502,8 +2470,8 @@ fn insert_implicit_conversions( | s @ Statement::Label(_) | s @ Statement::Constant(_) | s @ Statement::Variable(_) - | s @ Statement::LoadVar(_, _) - | s @ Statement::StoreVar(_, _) + | s @ Statement::LoadVar(..) + | s @ Statement::StoreVar(..) | s @ Statement::Undef(_, _) | s @ Statement::RetValue(_, _) => result.push(s), } @@ -2514,7 +2482,7 @@ fn insert_implicit_conversions( fn insert_implicit_conversions_impl( func: &mut Vec, id_def: &mut MutableNumericIdResolver, - stmt: impl VisitVariableExpanded, + stmt: impl Visitable, default_conversion_fn: for<'a> fn( &'a ast::Type, &'a ast::Type, @@ -2523,62 +2491,64 @@ fn insert_implicit_conversions_impl( state_space: Option, ) -> Result<(), TranslateError> { let mut post_conv = Vec::new(); - let statement = stmt.visit_variable_extended(&mut |desc, typ| { - let instr_type = match typ { - None => return Ok(desc.op), - Some(t) => t, - }; - let operand_type = id_def.get_typed(desc.op)?; - let mut conversion_fn = default_conversion_fn; - match desc.sema { - ArgumentSemantics::Default => {} - ArgumentSemantics::DefaultRelaxed => { - if desc.is_dst { - conversion_fn = should_convert_relaxed_dst_wrapper; - } else { - conversion_fn = should_convert_relaxed_src_wrapper; + let statement = stmt.visit( + &mut |desc: ArgumentDescriptor, typ: Option<&ast::Type>| { + let instr_type = match typ { + None => return Ok(desc.op), + Some(t) => t, + }; + let operand_type = id_def.get_typed(desc.op)?; + let mut conversion_fn = default_conversion_fn; + match desc.sema { + ArgumentSemantics::Default => {} + ArgumentSemantics::DefaultRelaxed => { + if desc.is_dst { + conversion_fn = should_convert_relaxed_dst_wrapper; + } else { + conversion_fn = should_convert_relaxed_src_wrapper; + } } - } - ArgumentSemantics::PhysicalPointer => { - conversion_fn = bitcast_physical_pointer; - } - ArgumentSemantics::RegisterPointer => { - conversion_fn = bitcast_register_pointer; - } - ArgumentSemantics::Address => { - conversion_fn = force_bitcast_ptr_to_bit; - } - }; - match conversion_fn(&operand_type, instr_type, state_space)? { - Some(conv_kind) => { - let conv_output = if desc.is_dst { - &mut post_conv - } else { - &mut *func - }; - let mut from = instr_type.clone(); - let mut to = operand_type; - let mut src = id_def.new_non_variable(instr_type.clone()); - let mut dst = desc.op; - let result = Ok(src); - if !desc.is_dst { - mem::swap(&mut src, &mut dst); - mem::swap(&mut from, &mut to); + ArgumentSemantics::PhysicalPointer => { + conversion_fn = bitcast_physical_pointer; } - conv_output.push(Statement::Conversion(ImplicitConversion { - src, - dst, - from, - to, - kind: conv_kind, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, - })); - result + ArgumentSemantics::RegisterPointer => { + conversion_fn = bitcast_register_pointer; + } + ArgumentSemantics::Address => { + conversion_fn = force_bitcast_ptr_to_bit; + } + }; + match conversion_fn(&operand_type, instr_type, state_space)? { + Some(conv_kind) => { + let conv_output = if desc.is_dst { + &mut post_conv + } else { + &mut *func + }; + let mut from = instr_type.clone(); + let mut to = operand_type; + let mut src = id_def.new_non_variable(instr_type.clone()); + let mut dst = desc.op; + let result = Ok(src); + if !desc.is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from, &mut to); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from, + to, + kind: conv_kind, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, + })); + result + } + None => Ok(desc.op), } - None => Ok(desc.op), - } - })?; + }, + )?; func.push(statement); func.append(&mut post_conv); Ok(()) @@ -3079,26 +3049,22 @@ fn emit_function_body_ops( builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; } }, - Statement::LoadVar(arg, typ) => { + Statement::LoadVar(LoadVarDetails { + arg, + typ, + member_index, + }) => { let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); builder.load(type_id, Some(arg.dst), arg.src, None, [])?; } - Statement::StoreVar(arg, _) => { + Statement::StoreVar(StoreVarDetails { + arg, member_index, .. + }) => { builder.store(arg.src1, arg.src2, None, [])?; } Statement::RetValue(_, id) => { builder.ret_value(*id)?; } - Statement::Composite(c) => { - let result_type = map.get_or_add_scalar(builder, c.typ.into()); - let result_id = Some(c.dst); - builder.composite_extract( - result_type, - result_id, - c.src_composite, - [c.src_index], - )?; - } Statement::Undef(t, id) => { let result_type = map.get_or_add(builder, SpirvType::from(t.clone())); builder.undef(result_type, Some(*id)); @@ -3180,7 +3146,7 @@ fn insert_shift_hack( 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), 4 => return Ok(offset_var), - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }; Ok(builder.u_convert(result_type, None, offset_var)?) } @@ -3260,7 +3226,7 @@ fn emit_atom( let spirv_op = match op { ast::AtomUIntOp::Add => dr::Builder::atomic_i_add, ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => { - return Err(TranslateError::Unreachable); + return Err(error_unreachable()); } ast::AtomUIntOp::Min => dr::Builder::atomic_u_min, ast::AtomUIntOp::Max => dr::Builder::atomic_u_max, @@ -4346,7 +4312,7 @@ fn convert_to_stateful_memory_access<'a>( Some(src2) if remapped_ids.contains_key(src2) => { (remapped_ids.get(src2).unwrap(), arg.src1) } - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }; let dst = arg.dst.unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { @@ -4375,7 +4341,7 @@ fn convert_to_stateful_memory_access<'a>( Some(src2) if remapped_ids.contains_key(src2) => { (remapped_ids.get(src2).unwrap(), arg.src1) } - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }; let offset_neg = id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64))); @@ -4400,8 +4366,9 @@ fn convert_to_stateful_memory_access<'a>( } Statement::Instruction(inst) => { let mut post_statements = Vec::new(); - let new_statement = inst.visit_variable( - &mut |arg_desc: ArgumentDescriptor, expected_type| { + let new_statement = inst.visit( + &mut |arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4418,8 +4385,9 @@ fn convert_to_stateful_memory_access<'a>( } Statement::Call(call) => { let mut post_statements = Vec::new(); - let new_statement = call.visit_variable( - &mut |arg_desc: ArgumentDescriptor, expected_type| { + let new_statement = call.visit( + &mut |arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4436,8 +4404,9 @@ fn convert_to_stateful_memory_access<'a>( } Statement::RepackVector(pack) => { let mut post_statements = Vec::new(); - let new_statement = pack.visit_variable( - &mut |arg_desc: ArgumentDescriptor, expected_type| { + let new_statement = pack.visit( + &mut |arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4452,7 +4421,7 @@ fn convert_to_stateful_memory_access<'a>( result.push(new_statement); result.extend(post_statements); } - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } for arg in func_args.input.iter_mut() { @@ -4517,7 +4486,7 @@ fn convert_to_stateful_memory_access_postprocess( None => match func_args_ptr.get(&arg_desc.op) { Some(new_id) => { if arg_desc.is_dst { - return Err(TranslateError::Unreachable); + return Err(error_unreachable()); } // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { @@ -4896,15 +4865,14 @@ enum Statement { // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Call(ResolvedCall

), - LoadVar(ast::Arg2, ast::Type), - StoreVar(ast::Arg2St, ast::Type), - Composite(CompositeRead), + LoadVar(LoadVarDetails), + StoreVar(StoreVarDetails), Conversion(ImplicitConversion), Constant(ConstantDefinition), RetValue(ast::RetData, spirv::Word), Undef(ast::Type, spirv::Word), PtrAccess(PtrAccess

), - RepackVector(RepackVector), + RepackVector(RepackVectorDetails), } impl ExpandedStatement { @@ -4916,19 +4884,19 @@ impl ExpandedStatement { Statement::Variable(var) } Statement::Instruction(inst) => inst - .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| { + .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| { Ok(f(arg.op, arg.is_dst)) }) .unwrap(), - Statement::LoadVar(mut arg, typ) => { - arg.dst = f(arg.dst, true); - arg.src = f(arg.src, false); - Statement::LoadVar(arg, typ) + Statement::LoadVar(mut details) => { + details.arg.dst = f(details.arg.dst, true); + details.arg.src = f(details.arg.src, false); + Statement::LoadVar(details) } - Statement::StoreVar(mut arg, typ) => { - arg.src1 = f(arg.src1, false); - arg.src2 = f(arg.src2, false); - Statement::StoreVar(arg, typ) + Statement::StoreVar(mut details) => { + details.arg.src1 = f(details.arg.src1, false); + details.arg.src2 = f(details.arg.src2, false); + Statement::StoreVar(details) } Statement::Call(mut call) => { for (id, typ) in call.ret_params.iter_mut() { @@ -4945,11 +4913,6 @@ impl ExpandedStatement { } Statement::Call(call) } - Statement::Composite(mut composite) => { - composite.dst = f(composite.dst, true); - composite.src_composite = f(composite.src_composite, false); - Statement::Composite(composite) - } Statement::Conditional(mut conditional) => { conditional.predicate = f(conditional.predicate, false); conditional.if_true = f(conditional.if_true, false); @@ -4996,7 +4959,19 @@ impl ExpandedStatement { } } -struct RepackVector { +struct LoadVarDetails { + arg: ast::Arg2, + typ: ast::Type, + member_index: Option, +} + +struct StoreVarDetails { + arg: ast::Arg2St, + typ: ast::Type, + member_index: Option, +} + +struct RepackVectorDetails { is_extract: bool, typ: ast::ScalarType, packed: spirv::Word, @@ -5004,7 +4979,7 @@ struct RepackVector { vector_sema: ArgumentSemantics, } -impl RepackVector { +impl RepackVectorDetails { fn map< From: ArgParamsEx, To: ArgParamsEx, @@ -5012,7 +4987,7 @@ impl RepackVector { >( self, visitor: &mut V, - ) -> Result { + ) -> Result { let scalar = visitor.id( ArgumentDescriptor { op: self.packed, @@ -5038,7 +5013,7 @@ impl RepackVector { ) }) .collect::>()?; - Ok(RepackVector { + Ok(RepackVectorDetails { is_extract, typ: self.typ, packed: scalar, @@ -5048,45 +5023,17 @@ impl RepackVector { } } -impl VisitVariable for RepackVector { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, U: ArgParamsEx> Visitable + for RepackVectorDetails +{ + fn visit( self, - f: &mut F, - ) -> Result { - Ok(TypedStatement::RepackVector( - self.map::(f)?, - )) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::RepackVector(self.map::<_, _, _>(visitor)?)) } } -impl VisitVariableExpanded for RepackVector { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - Ok(ExpandedStatement::RepackVector( - self.map::(f)?, - )) - } -} - -struct UnpackVector { - typ: ast::ScalarType, - dst: Vec, - src: spirv::Word, -} - struct ResolvedCall { pub uniform: bool, pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, @@ -5157,32 +5104,14 @@ impl> ResolvedCall { } } -impl VisitVariable for ResolvedCall { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, U: ArgParamsEx> Visitable + for ResolvedCall +{ + fn visit( self, - f: &mut F, - ) -> Result { - Ok(Statement::Call(self.map(f)?)) - } -} - -impl VisitVariableExpanded for ResolvedCall { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - Ok(Statement::Call(self.map(f)?)) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::Call(self.map(visitor)?)) } } @@ -5235,18 +5164,14 @@ impl> PtrAccess

{ } } -impl VisitVariable for PtrAccess { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, U: ArgParamsEx> Visitable + for PtrAccess +{ + fn visit( self, - f: &mut F, - ) -> Result { - Ok(Statement::PtrAccess(self.map(f)?)) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::PtrAccess(self.map(visitor)?)) } } @@ -5480,7 +5405,7 @@ where ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(id)?), ast::SrcOperand::RegOffset(id, imm) => ast::SrcOperand::RegOffset(self(id)?, imm), ast::SrcOperand::Imm(imm) => ast::SrcOperand::Imm(imm), - ast::SrcOperand::VecIndex(id, member) => ast::SrcOperand::VecIndex(self(id)?, member), + ast::SrcOperand::VecMember(id, member) => ast::SrcOperand::VecMember(self(id)?, member), }) } @@ -5567,7 +5492,7 @@ impl ast::Instruction { ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?) } // Call instruction is converted to a call statement early on - ast::Instruction::Call(_) => return Err(TranslateError::Unreachable), + ast::Instruction::Call(_) => return Err(error_unreachable()), ast::Instruction::Ld(d, a) => { let new_args = a.map(visitor, &d)?; ast::Instruction::Ld(d, new_args) @@ -5760,18 +5685,12 @@ impl ast::Instruction { } } -impl VisitVariable for ast::Instruction { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl Visitable for ast::Instruction { + fn visit( self, - f: &mut F, - ) -> Result { - Ok(Statement::Instruction(self.map(f)?)) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::Instruction(self.map(visitor)?)) } } @@ -5810,32 +5729,14 @@ impl ImplicitConversion { } } -impl VisitVariable for ImplicitConversion { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, To: ArgParamsEx> Visitable + for ImplicitConversion +{ + fn visit( self, - f: &mut F, - ) -> Result { - self.map(f) - } -} - -impl VisitVariableExpanded for ImplicitConversion { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - self.map(f) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, To>, TranslateError> { + Ok(self.map(visitor)?) } } @@ -5861,7 +5762,14 @@ where ) -> Result, TranslateError> { Ok(match desc.op { ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(desc.new_op(id), Some(typ))?), - ast::DstOperand::VecMember(_, _) => todo!(), + ast::DstOperand::VecMember(reg, index) => { + let scalar_type = match typ { + ast::Type::Scalar(scalar_t) => *scalar_t, + _ => return Err(error_unreachable()), + }; + let vec_type = ast::Type::Vector(scalar_type, index + 1); + ast::DstOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) + } }) } @@ -5876,7 +5784,14 @@ where ast::SrcOperand::RegOffset(id, imm) => { ast::SrcOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) } - ast::SrcOperand::VecIndex(_, _) => todo!(), + ast::SrcOperand::VecMember(reg, index) => { + let scalar_type = match typ { + ast::Type::Scalar(scalar_t) => *scalar_t, + _ => return Err(error_unreachable()), + }; + let vec_type = ast::Type::Vector(scalar_type, index + 1); + ast::SrcOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) + } }) } @@ -5887,7 +5802,14 @@ where ) -> Result, TranslateError> { Ok(match desc.op { ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(desc.new_op(id), Some(typ))?), - ast::DstOperand::VecMember(_, _) => todo!(), + ast::DstOperand::VecMember(reg, index) => { + let scalar_type = match typ { + ast::Type::Scalar(scalar_t) => *scalar_t, + _ => return Err(error_unreachable()), + }; + let vec_type = ast::Type::Vector(scalar_type, index + 1); + ast::DstOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) + } }) } @@ -5902,7 +5824,14 @@ where ast::SrcOperand::RegOffset(id, imm) => { ast::SrcOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) } - ast::SrcOperand::VecIndex(_, _) => todo!(), + ast::SrcOperand::VecMember(reg, index) => { + let scalar_type = match typ { + ast::Type::Scalar(scalar_t) => *scalar_t, + _ => return Err(error_unreachable()), + }; + let vec_type = ast::Type::Vector(scalar_type, index + 1); + ast::SrcOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) + } }) } } @@ -5925,7 +5854,7 @@ impl ast::Type { kind, ))) } - _ => Err(TranslateError::Unreachable), + _ => Err(error_unreachable()), } } @@ -6165,67 +6094,9 @@ impl ast::Instruction { } } -impl VisitVariableExpanded for ast::Instruction { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - Ok(Statement::Instruction(self.map(f)?)) - } -} - type Arg2 = ast::Arg2; type Arg2St = ast::Arg2St; -struct CompositeRead { - pub typ: ast::ScalarType, - pub dst: spirv::Word, - pub dst_semantics_override: Option, - pub src_composite: spirv::Word, - pub src_index: u32, - pub src_len: u32, -} - -impl VisitVariableExpanded for CompositeRead { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - let dst_sema = self - .dst_semantics_override - .unwrap_or(ArgumentSemantics::Default); - Ok(Statement::Composite(CompositeRead { - dst: f( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: dst_sema, - }, - Some(&ast::Type::Scalar(self.typ)), - )?, - src_composite: f( - ArgumentDescriptor { - op: self.src_composite, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(self.typ, self.src_len as u8)), - )?, - ..self - })) - } -} - struct ConstantDefinition { pub dst: spirv::Word, pub typ: ast::ScalarType, @@ -6926,7 +6797,7 @@ impl ast::SrcOperand { ast::SrcOperand::Reg(reg) => ast::SrcOperand::Reg(f(reg)?), ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset(f(reg)?, offset), ast::SrcOperand::Imm(x) => ast::SrcOperand::Imm(x), - ast::SrcOperand::VecIndex(reg, idx) => ast::SrcOperand::VecIndex(f(reg)?, idx), + ast::SrcOperand::VecMember(reg, idx) => ast::SrcOperand::VecMember(f(reg)?, idx), }) } } @@ -6935,7 +6806,7 @@ impl ast::DstOperand { fn to_src_operand(self) -> ast::SrcOperand { match self { ast::DstOperand::Reg(reg) => ast::SrcOperand::Reg(reg), - ast::DstOperand::VecMember(reg, idx) => ast::SrcOperand::VecIndex(reg, idx), + ast::DstOperand::VecMember(reg, idx) => ast::SrcOperand::VecMember(reg, idx), } } } @@ -6943,7 +6814,7 @@ impl ast::DstOperand { fn unwrap_reg(&self) -> Result { match self { ast::DstOperand::Reg(reg) => Ok(*reg), - ast::DstOperand::VecMember(..) => Err(TranslateError::Unreachable), + ast::DstOperand::VecMember(..) => Err(error_unreachable()), } } } @@ -7164,7 +7035,7 @@ impl ast::SrcOperand { match self { ast::SrcOperand::Reg(r) | ast::SrcOperand::RegOffset(r, _) => Some(r), ast::SrcOperand::Imm(_) => None, - ast::SrcOperand::VecIndex(reg, _) => Some(reg), + ast::SrcOperand::VecMember(reg, _) => Some(reg), } } } @@ -7262,7 +7133,7 @@ fn bitcast_physical_pointer( if let Some(space) = ss { Ok(Some(ConversionKind::BitToPtr(space))) } else { - Err(TranslateError::Unreachable) + Err(error_unreachable()) } } ast::Type::Scalar(ast::ScalarType::B32)