From 00b8d8d87f408231a2ae8528dee5f30b0fb90525 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 4 Dec 2020 19:50:08 +0100 Subject: [PATCH] Start refactoring vector-handling code --- ptx/src/ast.rs | 122 ++--- ptx/src/ptx.lalrpop | 12 +- ptx/src/translate.rs | 1120 ++++++++++++++---------------------------- 3 files changed, 435 insertions(+), 819 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 367f060..8fcf82a 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -557,7 +557,7 @@ pub enum Instruction { Mul(MulDetails, Arg3

), Add(ArithDetails, Arg3

), Setp(SetpData, Arg4Setp

), - SetpBool(SetpBoolData, Arg5

), + SetpBool(SetpBoolData, Arg5Setp

), Not(BooleanType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtDetails, Arg2

), @@ -614,16 +614,15 @@ pub struct CallInst { pub uniform: bool, pub ret_params: Vec, pub func: P::Id, - pub param_list: Vec, + pub param_list: Vec, } pub trait ArgParams { type Id; - type Operand; - type IdOrVector; - type OperandOrVector; - type CallOperand; - type SrcMemberOperand; + type DstOperand; + type SrcOperand; + type DstOperandVec; + type SrcOperandVec; } pub struct ParsedArgParams<'a> { @@ -632,11 +631,10 @@ pub struct ParsedArgParams<'a> { impl<'a> ArgParams for ParsedArgParams<'a> { type Id = &'a str; - type Operand = Operand<&'a str>; - type CallOperand = CallOperand<&'a str>; - type IdOrVector = IdOrVector<&'a str>; - type OperandOrVector = OperandOrVector<&'a str>; - type SrcMemberOperand = (&'a str, u8); + type DstOperand = DstOperand<&'a str>; + type SrcOperand = SrcOperand<&'a str>; + type DstOperandVec = DstOperandVec<&'a str>; + type SrcOperandVec = SrcOperandVec<&'a str>; } pub struct Arg1 { @@ -644,67 +642,54 @@ pub struct Arg1 { } pub struct Arg1Bar { - pub src: P::Operand, + pub src: P::SrcOperand, } pub struct Arg2 { - pub dst: P::Id, - pub src: P::Operand, + pub dst: P::DstOperand, + pub src: P::SrcOperand, } pub struct Arg2Ld { - pub dst: P::IdOrVector, - pub src: P::Operand, + pub dst: P::DstOperandVec, + pub src: P::SrcOperand, } pub struct Arg2St { - pub src1: P::Operand, - pub src2: P::OperandOrVector, + pub src1: P::SrcOperand, + pub src2: P::SrcOperandVec, } -pub enum Arg2Mov { - Normal(Arg2MovNormal

), - Member(Arg2MovMember

), -} - -pub struct Arg2MovNormal { - pub dst: P::IdOrVector, - pub src: P::OperandOrVector, -} - -// We duplicate dst here because during further compilation -// composite dst and composite src will receive different ids -pub enum Arg2MovMember { - Dst((P::Id, u8), P::Id, P::Id), - Src(P::Id, P::SrcMemberOperand), - Both((P::Id, u8), P::Id, P::SrcMemberOperand), +pub struct Arg2Mov { + pub dst: P::DstOperandVec, + pub src: P::SrcOperandVec, } pub struct Arg3 { - pub dst: P::Id, - pub src1: P::Operand, - pub src2: P::Operand, + pub dst: P::DstOperand, + pub src1: P::SrcOperand, + pub src2: P::SrcOperand, } pub struct Arg4 { - pub dst: P::Id, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, + pub dst: P::DstOperand, + pub src1: P::SrcOperand, + pub src2: P::SrcOperand, + pub src3: P::SrcOperand, } pub struct Arg4Setp { pub dst1: P::Id, pub dst2: Option, - pub src1: P::Operand, - pub src2: P::Operand, + pub src1: P::SrcOperand, + pub src2: P::SrcOperand, } -pub struct Arg5 { +pub struct Arg5Setp { pub dst1: P::Id, pub dst2: Option, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, + pub src1: P::SrcOperand, + pub src2: P::SrcOperand, + pub src3: P::SrcOperand, } #[derive(Copy, Clone)] @@ -716,38 +701,29 @@ pub enum ImmediateValue { } #[derive(Copy, Clone)] -pub enum Operand { +pub enum DstOperand { Reg(ID), - RegOffset(ID, i32), - Imm(ImmediateValue), + VecMember(ID, u8), +} + +#[derive(Clone)] +pub enum DstOperandVec { + Normal(DstOperand), + Vector(Vec), } #[derive(Copy, Clone)] -pub enum CallOperand { - Reg(ID), +pub enum SrcOperand { + Reg(Id), + RegOffset(Id, i32), Imm(ImmediateValue), + VecIndex(Id, u8), } -pub enum IdOrVector { - Reg(ID), - Vec(Vec), -} - -pub enum OperandOrVector { - Reg(ID), - RegOffset(ID, i32), - Imm(ImmediateValue), - Vec(Vec), -} - -impl From> for OperandOrVector { - fn from(this: Operand) -> Self { - match this { - Operand::Reg(r) => OperandOrVector::Reg(r), - Operand::RegOffset(r, imm) => OperandOrVector::RegOffset(r, imm), - Operand::Imm(imm) => OperandOrVector::Imm(imm), - } - } +#[derive(Clone)] +pub enum SrcOperandVec { + Normal(SrcOperand), + Vector(Vec), } pub enum VectorPrefix { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index d2c235a..ff48cd7 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -921,7 +921,7 @@ InstAdd: ast::Instruction> = { // TODO: support f16 setp InstSetp: ast::Instruction> = { "setp" => ast::Instruction::Setp(d, a), - "setp" => ast::Instruction::SetpBool(d, a), + "setp" => ast::Instruction::SetpBool(d, a), }; SetpMode: ast::SetpData = { @@ -1775,9 +1775,9 @@ Operand: ast::Operand<&'input str> = { => ast::Operand::Imm(x) }; -CallOperand: ast::CallOperand<&'input str> = { - => ast::CallOperand::Reg(r), - => ast::CallOperand::Imm(x) +CallOperand: ast::SrcOperand<&'input str> = { + => ast::SrcOperand::Reg(r), + => ast::SrcOperand::Imm(x) }; // TODO: start parsing whole constants sub-language: @@ -1875,8 +1875,8 @@ Arg4Setp: ast::Arg4Setp> = { }; // TODO: pass src3 negation somewhere -Arg5: ast::Arg5> = { - "," "," "," "!"? => ast::Arg5{<>} +Arg5Setp: ast::Arg5Setp> = { + "," "," "," "!"? => ast::Arg5Setp{<>} }; ArgCall: (Vec<&'input str>, &'input str, Vec>) = { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 15211ab..f0b9161 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -975,6 +975,8 @@ fn compute_denorm_information<'input>( Statement::Label(_) => {} Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} + Statement::PackVector(_) => {} + Statement::UnpackVector(_) => {} } } denorm_methods.insert(method_key, flush_counter); @@ -1307,7 +1309,7 @@ fn to_ssa<'input, 'b>( let mut numeric_id_defs = id_defs.finish(); let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = - convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?; + convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; let typed_statements = convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( @@ -1431,7 +1433,7 @@ fn normalize_variable_decls(directives: &mut Vec) { fn convert_to_typed_statements( func: Vec, fn_defs: &GlobalFnDeclResolver, - id_defs: &NumericIdResolver, + id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { let mut result = Vec::::with_capacity(func.len()); for s in func { @@ -1447,7 +1449,7 @@ fn convert_to_typed_statements( .partition(|(_, arg_type)| arg_type.is_param()); let normalized_input_args = out_params .into_iter() - .map(|(id, typ)| (ast::CallOperand::Reg(id), typ)) + .map(|(id, typ)| (ast::SrcOperand::Reg(id), typ)) .chain(in_args.into_iter()) .collect(); let resolved_call = ResolvedCall { @@ -1458,192 +1460,38 @@ fn convert_to_typed_statements( }; result.push(Statement::Call(resolved_call)); } - ast::Instruction::Ld(d, arg) => { - result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast()))); - } - ast::Instruction::St(d, arg) => { - result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast()))); - } - ast::Instruction::Mov(mut d, args) => match args { - ast::Arg2Mov::Normal(arg) => { - if let Some(src_id) = arg.src.single_underlying() { - let (typ, _) = id_defs.get_typed(*src_id)?; - let take_address = match typ { - ast::Type::Scalar(_) => false, - ast::Type::Vector(_, _) => false, - ast::Type::Array(_, _) => true, - ast::Type::Pointer(_, _) => true, - }; - d.src_is_address = take_address; - } - result.push(Statement::Instruction(ast::Instruction::Mov( - d, - ast::Arg2Mov::Normal(arg.cast()), - ))); - } - ast::Arg2Mov::Member(args) => { - if let Some(dst_typ) = args.vector_dst() { - match id_defs.get_typed(*dst_typ)? { - (ast::Type::Vector(_, len), _) => { - d.dst_width = len; - } - _ => return Err(TranslateError::MismatchedType), - } + ast::Instruction::Mov( + mut d, + ast::Arg2Mov { + dst, + src: ast::SrcOperandVec::Normal(src), + }, + ) => { + if let Some(src_id) = src.underlying() { + let (typ, _) = id_defs.get_typed(*src_id)?; + let take_address = match typ { + ast::Type::Scalar(_) => false, + ast::Type::Vector(_, _) => false, + ast::Type::Array(_, _) => true, + ast::Type::Pointer(_, _) => true, }; - if let Some((src_typ, _)) = args.vector_src() { - match id_defs.get_typed(*src_typ)? { - (ast::Type::Vector(_, len), _) => { - d.src_width = len; - } - _ => return Err(TranslateError::MismatchedType), - } - }; - result.push(Statement::Instruction(ast::Instruction::Mov( - d, - ast::Arg2Mov::Member(args.cast()), - ))); + d.src_is_address = take_address; } - }, - ast::Instruction::Mul(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Mul(d, a.cast()))) + let mut visitor = VectorPackingVisitor::new(&mut result, id_defs); + result.push(Statement::Instruction( + ast::Instruction::Mov( + d, + ast::Arg2Mov { + dst, + src: ast::SrcOperandVec::Normal(src), + }, + ) + .map(&mut visitor)?, + )); } - ast::Instruction::Add(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Add(d, a.cast()))) - } - ast::Instruction::Setp(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Setp(d, a.cast()))) - } - ast::Instruction::SetpBool(d, a) => result.push(Statement::Instruction( - ast::Instruction::SetpBool(d, a.cast()), - )), - ast::Instruction::Not(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Not(d, a.cast()))) - } - ast::Instruction::Bra(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Bra(d, a.cast()))) - } - ast::Instruction::Cvt(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Cvt(d, a.cast()))) - } - ast::Instruction::Cvta(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Cvta(d, a.cast()))) - } - ast::Instruction::Shl(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Shl(d, a.cast()))) - } - ast::Instruction::Ret(d) => { - result.push(Statement::Instruction(ast::Instruction::Ret(d))) - } - ast::Instruction::Abs(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Abs(d, a.cast()))) - } - ast::Instruction::Mad(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast()))) - } - ast::Instruction::Shr(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast()))) - } - ast::Instruction::Or(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast()))) - } - ast::Instruction::Sub(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast()))) - } - ast::Instruction::Min(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast()))) - } - ast::Instruction::Max(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast()))) - } - ast::Instruction::Rcp(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Rcp(d, a.cast()))) - } - ast::Instruction::And(d, a) => { - result.push(Statement::Instruction(ast::Instruction::And(d, a.cast()))) - } - ast::Instruction::Selp(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast()))) - } - ast::Instruction::Bar(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Bar(d, a.cast()))) - } - ast::Instruction::Atom(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Atom(d, a.cast()))) - } - ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction( - ast::Instruction::AtomCas(d, a.cast()), - )), - ast::Instruction::Div(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast()))) - } - ast::Instruction::Sqrt(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast()))) - } - ast::Instruction::Rsqrt(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast()))) - } - ast::Instruction::Neg(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast()))) - } - ast::Instruction::Sin { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Sin { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Cos { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Cos { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Lg2 { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Lg2 { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Ex2 { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Ex2 { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Clz { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Clz { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Brev { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Brev { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Popc { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Popc { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Xor { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Xor { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Bfe { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Bfe { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Rem { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Rem { - typ, - arg: arg.cast(), - })) + inst => { + let mut visitor = VectorPackingVisitor::new(&mut result, id_defs); + result.push(Statement::Instruction(inst.map(&mut visitor)?)); } }, Statement::Label(i) => result.push(Statement::Label(i)), @@ -1655,6 +1503,72 @@ fn convert_to_typed_statements( Ok(result) } +struct VectorPackingVisitor<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut NumericIdResolver<'a>, + post_stmts: Vec, +} + +impl<'a, 'b> VectorPackingVisitor<'a, 'b> { + fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { + VectorPackingVisitor { + func, + id_def, + post_stmts: Vec::new(), + } + } +} + +impl<'a, 'b> ArgumentMapVisitor + for VectorPackingVisitor<'a, 'b> +{ + fn id( + &mut self, + desc: ArgumentDescriptor, + typ: Option<&ast::Type>, + ) -> Result { + Ok(desc.op) + } + + fn dst_operand( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(desc.op) + } + + fn src_operand( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(desc.op) + } + + fn dst_operand_vec( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + match desc.op { + ast::DstOperandVec::Normal(op) => self.dst_operand(desc.new_op(op), typ), + ast::DstOperandVec::Vector(vec) => todo!(), + } + } + + fn src_operand_vec( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + match desc.op { + ast::SrcOperandVec::Normal(op) => self.src_operand(desc.new_op(op), typ), + ast::SrcOperandVec::Vector(_) => todo!(), + } + } +} + //TODO: share common code between this and to_ptx_impl_bfe_call fn to_ptx_impl_atomic_call( id_defs: &mut NumericIdResolver, @@ -1872,17 +1786,19 @@ fn normalize_labels( labels_in_use.insert(cond.if_true); labels_in_use.insert(cond.if_false); } - Statement::Composite(_) - | Statement::Call(_) - | Statement::Variable(_) - | Statement::LoadVar(_, _) - | Statement::StoreVar(_, _) - | Statement::RetValue(_, _) - | Statement::Conversion(_) - | Statement::Constant(_) - | Statement::Label(_) - | Statement::Undef(_, _) - | Statement::PtrAccess { .. } => {} + Statement::Composite(..) + | Statement::Call(..) + | Statement::Variable(..) + | Statement::LoadVar(..) + | Statement::StoreVar(..) + | Statement::RetValue(..) + | Statement::Conversion(..) + | Statement::Constant(..) + | Statement::Label(..) + | Statement::Undef(..) + | Statement::PtrAccess { .. } + | Statement::PackVector(..) + | Statement::UnpackVector(..) => {} } } iter::once(Statement::Label(id_def.new_non_variable(None))) @@ -2202,6 +2118,8 @@ fn expand_arguments<'a, 'b>( Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => { return Err(TranslateError::Unreachable) } + Statement::PackVector(_) => todo!(), + Statement::UnpackVector(_) => todo!(), } } Ok(result) @@ -2398,6 +2316,8 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.func.push(Statement::Undef(typ.clone(), new_id)); for (idx, id) in desc.op.iter().enumerate() { let newer_id = self.id_def.new_non_variable(typ.clone()); + todo!(); + /* self.func.push(Statement::Instruction(ast::Instruction::Mov( ast::MovDetails { typ: ast::Type::Scalar(scalar_type), @@ -2412,6 +2332,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { *id, )), ))); + */ new_id = newer_id; } Ok(new_id) @@ -2441,61 +2362,36 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr self.reg(desc, t) } - fn operand( + fn dst_operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: &ast::Type, ) -> Result { - match desc.op { - ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::Operand::Imm(x) => self.immediate(desc.new_op(x), typ), - ast::Operand::RegOffset(reg, offset) => { - self.reg_offset(desc.new_op((reg, offset)), typ) - } - } + todo!() } - fn src_call_operand( + fn src_operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: &ast::Type, ) -> Result { - match desc.op { - ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)), - ast::CallOperand::Imm(x) => self.immediate(desc.new_op(x), typ), - } + todo!() } - fn src_member_operand( + fn dst_operand_vec( &mut self, - desc: ArgumentDescriptor<(spirv::Word, u8)>, - typ: (ast::ScalarType, u8), - ) -> Result { - self.member_src(desc, typ) - } - - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: &ast::Type, ) -> Result { - match desc.op { - ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::IdOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ), - } + todo!() } - fn operand_or_vector( + fn src_operand_vec( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: &ast::Type, ) -> Result { - match desc.op { - ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::OperandOrVector::RegOffset(r, imm) => self.reg_offset(desc.new_op((r, imm)), typ), - ast::OperandOrVector::Imm(imm) => self.immediate(desc.new_op(imm), typ), - ast::OperandOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ), - } + todo!() } } @@ -2543,7 +2439,7 @@ fn insert_implicit_conversions( if let ast::Instruction::AtomCas(d, _) = &inst { state_space = Some(d.space.to_ld_ss()); } - if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst { + if let ast::Instruction::Mov(..) = &inst { default_conversion_fn = should_bitcast_packed; } insert_implicit_conversions_impl( @@ -2861,38 +2757,11 @@ fn emit_function_body_ops( } // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, - ast::Instruction::Mov(d, arg) => match arg { - ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src }) - | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => { - let result_type = map - .get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); - builder.copy_object(result_type, Some(*dst), *src)?; - } - ast::Arg2Mov::Member(ast::Arg2MovMember::Dst( - dst, - composite_src, - scalar_src, - )) - | ast::Arg2Mov::Member(ast::Arg2MovMember::Both( - dst, - composite_src, - scalar_src, - )) => { - let scalar_type = d.typ.get_scalar()?; - let result_type = map.get_or_add( - builder, - SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)), - ); - let result_id = Some(dst.0); - builder.composite_insert( - result_type, - result_id, - *scalar_src, - *composite_src, - [dst.1 as u32], - )?; - } - }, + ast::Instruction::Mov(d, arg) => { + let result_type = + map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); + builder.copy_object(result_type, Some(arg.dst), arg.src)?; + } ast::Instruction::Mul(mul, arg) => match mul { ast::MulDetails::Signed(ref ctr) => { emit_mul_sint(builder, map, opencl, ctr, arg)? @@ -3254,6 +3123,8 @@ fn emit_function_body_ops( )?; builder.bitcast(result_type, Some(*dst), temp)?; } + Statement::PackVector(_) => todo!(), + Statement::UnpackVector(_) => todo!(), } } Ok(()) @@ -4290,9 +4161,9 @@ fn convert_to_stateful_memory_access<'a>( }, arg, )) => { - if let Some(src) = arg.src.underlying() { - if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, arg.dst) { - stateful_markers.push((arg.dst, *src)); + if let (ast::DstOperand::Reg(dst), Some(src)) = (arg.dst, arg.src.underlying()) { + if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) { + stateful_markers.push((dst, *src)); } } } @@ -4320,7 +4191,7 @@ fn convert_to_stateful_memory_access<'a>( }, arg, )) => { - if let (ast::IdOrVector::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) { + if let (ast::DstOperand::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) { if func_args_64bit.contains(src) { multi_hash_map_append(&mut stateful_init_reg, *dst, *src); } @@ -4369,13 +4240,17 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) => { - if let Some(src1) = arg.src1.underlying() { + if let (ast::DstOperand::Reg(dst), Some(src1)) = + (arg.dst, arg.src1.underlying()) + { if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) { - regs_ptr_new.insert(arg.dst); + regs_ptr_new.insert(dst); } - } else if let Some(src2) = arg.src2.underlying() { + } else if let (ast::DstOperand::Reg(dst), Some(src2)) = + (arg.dst, arg.src2.underlying()) + { if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) { - regs_ptr_new.insert(arg.dst); + regs_ptr_new.insert(dst); } } } @@ -4435,10 +4310,11 @@ fn convert_to_stateful_memory_access<'a>( } _ => return Err(TranslateError::Unreachable), }; + let dst = arg.dst.unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, - dst: *remapped_ids.get(&arg.dst).unwrap(), + dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: offset, })) @@ -4472,15 +4348,16 @@ fn convert_to_stateful_memory_access<'a>( }, ast::Arg2 { src: offset, - dst: offset_neg, + dst: ast::DstOperand::Reg(offset_neg), }, ))); + let dst = arg.dst.unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, - dst: *remapped_ids.get(&arg.dst).unwrap(), + dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, - offset_src: ast::Operand::Reg(offset_neg), + offset_src: ast::SrcOperand::Reg(offset_neg), })) } Statement::Instruction(inst) => { @@ -4617,13 +4494,18 @@ fn convert_to_stateful_memory_access_postprocess( } fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { - if !remapped_ids.contains_key(&arg.dst) { - return false; - } - match arg.src1.underlying() { - Some(src1) if remapped_ids.contains_key(src1) => true, - Some(src2) if remapped_ids.contains_key(src2) => true, - _ => false, + match arg.dst { + ast::DstOperand::VecMember(..) => return false, + ast::DstOperand::Reg(dst) => { + if !remapped_ids.contains_key(&dst) { + return false; + } + match arg.src1.underlying() { + Some(src1) if remapped_ids.contains_key(src1) => true, + Some(src2) if remapped_ids.contains_key(src2) => true, + _ => false, + } + } } } @@ -4970,6 +4852,8 @@ enum Statement { RetValue(ast::RetData, spirv::Word), Undef(ast::Type, spirv::Word), PtrAccess(PtrAccess

), + PackVector(PackVector), + UnpackVector(UnpackVector), } impl ExpandedStatement { @@ -5056,19 +4940,24 @@ impl ExpandedStatement { offset_src: constant_src, }) } + Statement::PackVector(_) => todo!(), + Statement::UnpackVector(_) => todo!(), } } } +struct PackVector {} +struct UnpackVector {} + struct ResolvedCall { pub uniform: bool, - pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>, - pub func: spirv::Word, - pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>, + pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, + pub func: P::Id, + pub param_list: Vec<(P::SrcOperand, ast::FnArgumentType)>, } impl ResolvedCall { - fn cast>(self) -> ResolvedCall { + fn cast>(self) -> ResolvedCall { ResolvedCall { uniform: self.uniform, ret_params: self.ret_params, @@ -5110,7 +4999,7 @@ impl> ResolvedCall { .param_list .into_iter() .map::, _>(|(id, typ)| { - let new_id = visitor.src_call_operand( + let new_id = visitor.src_operand( ArgumentDescriptor { op: id, is_dst: false, @@ -5190,7 +5079,7 @@ impl> PtrAccess

{ }, Some(&ptr_type), )?; - let new_constant_src = visitor.operand( + let new_constant_src = visitor.src_operand( ArgumentDescriptor { op: self.offset_src, is_dst: false, @@ -5243,11 +5132,10 @@ enum NormalizedArgParams {} impl ast::ArgParams for NormalizedArgParams { type Id = spirv::Word; - type Operand = ast::Operand; - type CallOperand = ast::CallOperand; - type IdOrVector = ast::IdOrVector; - type OperandOrVector = ast::OperandOrVector; - type SrcMemberOperand = (spirv::Word, u8); + type DstOperand = ast::DstOperand; + type SrcOperand = ast::SrcOperand; + type DstOperandVec = ast::DstOperandVec; + type SrcOperandVec = ast::SrcOperandVec; } impl ArgParamsEx for NormalizedArgParams { @@ -5273,11 +5161,10 @@ enum TypedArgParams {} impl ast::ArgParams for TypedArgParams { type Id = spirv::Word; - type Operand = ast::Operand; - type CallOperand = ast::CallOperand; - type IdOrVector = ast::IdOrVector; - type OperandOrVector = ast::OperandOrVector; - type SrcMemberOperand = (spirv::Word, u8); + type DstOperand = ast::DstOperand; + type SrcOperand = ast::SrcOperand; + type DstOperandVec = ast::DstOperand; + type SrcOperandVec = ast::SrcOperand; } impl ArgParamsEx for TypedArgParams { @@ -5296,11 +5183,10 @@ type ExpandedStatement = Statement, Expanded impl ast::ArgParams for ExpandedArgParams { type Id = spirv::Word; - type Operand = spirv::Word; - type CallOperand = spirv::Word; - type IdOrVector = spirv::Word; - type OperandOrVector = spirv::Word; - type SrcMemberOperand = spirv::Word; + type DstOperand = spirv::Word; + type SrcOperand = spirv::Word; + type DstOperandVec = spirv::Word; + type SrcOperandVec = spirv::Word; } impl ArgParamsEx for ExpandedArgParams { @@ -5354,31 +5240,26 @@ pub trait ArgumentMapVisitor { desc: ArgumentDescriptor, typ: Option<&ast::Type>, ) -> Result; - fn operand( + fn dst_operand( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, typ: &ast::Type, - ) -> Result; - fn id_or_vector( + ) -> Result; + fn src_operand( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, typ: &ast::Type, - ) -> Result; - fn operand_or_vector( + ) -> Result; + fn dst_operand_vec( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, typ: &ast::Type, - ) -> Result; - fn src_call_operand( + ) -> Result; + fn src_operand_vec( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, typ: &ast::Type, - ) -> Result; - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor, - typ: (ast::ScalarType, u8), - ) -> Result; + ) -> Result; } impl ArgumentMapVisitor for T @@ -5396,15 +5277,7 @@ where self(desc, t) } - fn operand( - &mut self, - desc: ArgumentDescriptor, - t: &ast::Type, - ) -> Result { - self(desc, Some(t)) - } - - fn id_or_vector( + fn dst_operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, @@ -5412,7 +5285,7 @@ where self(desc, Some(typ)) } - fn operand_or_vector( + fn src_operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, @@ -5420,20 +5293,20 @@ where self(desc, Some(typ)) } - fn src_call_operand( + fn dst_operand_vec( &mut self, desc: ArgumentDescriptor, - t: &ast::Type, + typ: &ast::Type, ) -> Result { - self(desc, Some(t)) + self(desc, Some(typ)) } - fn src_member_operand( + fn src_operand_vec( &mut self, desc: ArgumentDescriptor, - (scalar_type, _): (ast::ScalarType, u8), + typ: &ast::Type, ) -> Result { - self(desc.new_op(desc.op), Some(&ast::Type::Scalar(scalar_type))) + self(desc, Some(typ)) } } @@ -5449,65 +5322,62 @@ where self(desc.op) } - fn operand( + fn dst_operand( &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)), - ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)), - ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), - } + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(id)?), + ast::DstOperand::VecMember(id, member) => ast::DstOperand::VecMember(self(id)?, member), + }) } - fn id_or_vector( + fn src_operand( &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)), - ast::IdOrVector::Vec(ids) => Ok(ast::IdOrVector::Vec( - ids.into_iter().map(self).collect::>()?, - )), - } + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(id)?), + ast::SrcOperand::RegOffset(id, imm) => ast::SrcOperand::RegOffset(self(id)?, imm), + ast::SrcOperand::Imm(imm) => ast::SrcOperand::Imm(imm), + ast::SrcOperand::VecIndex(id, member) => ast::SrcOperand::VecIndex(self(id)?, member), + }) } - fn operand_or_vector( + fn dst_operand_vec( &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)), - ast::OperandOrVector::RegOffset(id, imm) => { - Ok(ast::OperandOrVector::RegOffset(self(id)?, imm)) + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::DstOperandVec::Normal(inner_op) => { + ast::DstOperandVec::Normal(self.dst_operand(desc.new_op(inner_op), typ)?) } - ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)), - ast::OperandOrVector::Vec(ids) => Ok(ast::OperandOrVector::Vec( - ids.into_iter().map(self).collect::>()?, - )), - } + ast::DstOperandVec::Vector(ids) => ast::DstOperandVec::Vector( + ids.into_iter() + .map(|id| self.id(desc.new_op(id), Some(typ))) + .collect::, _>>()?, + ), + }) } - fn src_call_operand( + fn src_operand_vec( &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)), - ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)), - } - } - - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor<(&str, u8)>, - _: (ast::ScalarType, u8), - ) -> Result<(spirv::Word, u8), TranslateError> { - Ok((self(desc.op.0)?, desc.op.1)) + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::SrcOperandVec::Normal(inner_op) => { + ast::SrcOperandVec::Normal(self.src_operand(desc.new_op(inner_op), typ)?) + } + ast::SrcOperandVec::Vector(ids) => ast::SrcOperandVec::Vector( + ids.into_iter() + .map(|id| self.id(desc.new_op(id), Some(typ))) + .collect::, _>>()?, + ), + }) } } @@ -5522,7 +5392,7 @@ pub struct PtrAccess { state_space: ast::LdStateSpace, dst: spirv::Word, ptr_src: spirv::Word, - offset_src: P::Operand, + offset_src: P::SrcOperand, } #[derive(Copy, Clone, PartialEq, Eq, Debug)] @@ -5846,81 +5716,56 @@ where self(desc, t) } - fn operand( + fn dst_operand( &mut self, - desc: ArgumentDescriptor>, - t: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)), - ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), - ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset( - self(desc.new_op(id), Some(t))?, - imm, - )), - } - } - - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor>, - t: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)), - ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)), - } - } - - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)), - ast::IdOrVector::Vec(ref ids) => Ok(ast::IdOrVector::Vec( - ids.iter() - .map(|id| self(desc.new_op(*id), Some(typ))) - .collect::>()?, - )), - } + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(desc.new_op(id), Some(typ))?), + ast::DstOperand::VecMember(_, _) => todo!(), + }) } - fn operand_or_vector( + fn src_operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::OperandOrVector::Reg(id) => { - Ok(ast::OperandOrVector::Reg(self(desc.new_op(id), Some(typ))?)) + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(desc.new_op(id), Some(typ))?), + ast::SrcOperand::Imm(imm) => ast::SrcOperand::Imm(imm), + ast::SrcOperand::RegOffset(id, imm) => { + ast::SrcOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) } - ast::OperandOrVector::RegOffset(id, imm) => Ok(ast::OperandOrVector::RegOffset( - self(desc.new_op(id), Some(typ))?, - imm, - )), - ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)), - ast::OperandOrVector::Vec(ref ids) => Ok(ast::OperandOrVector::Vec( - ids.iter() - .map(|id| self(desc.new_op(*id), Some(typ))) - .collect::>()?, - )), - } + ast::SrcOperand::VecIndex(_, _) => todo!(), + }) } - fn src_member_operand( + fn dst_operand_vec( &mut self, - desc: ArgumentDescriptor<(spirv::Word, u8)>, - (scalar_type, vector_len): (ast::ScalarType, u8), - ) -> Result<(spirv::Word, u8), TranslateError> { - Ok(( - self( - desc.new_op(desc.op.0), - Some(&ast::Type::Vector(scalar_type.into(), vector_len)), - )?, - desc.op.1, - )) + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(desc.new_op(id), Some(typ))?), + ast::DstOperand::VecMember(_, _) => todo!(), + }) + } + + fn src_operand_vec( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(desc.new_op(id), Some(typ))?), + ast::SrcOperand::Imm(imm) => ast::SrcOperand::Imm(imm), + ast::SrcOperand::RegOffset(id, imm) => { + ast::SrcOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) + } + ast::SrcOperand::VecIndex(_, _) => todo!(), + }) } } @@ -6330,10 +6175,6 @@ impl From for ast::Type { } impl ast::Arg1 { - fn cast>(self) -> ast::Arg1 { - ast::Arg1 { src: self.src } - } - fn map>( self, visitor: &mut V, @@ -6352,15 +6193,11 @@ impl ast::Arg1 { } impl ast::Arg1Bar { - fn cast>(self) -> ast::Arg1Bar { - ast::Arg1Bar { src: self.src } - } - fn map>( self, visitor: &mut V, ) -> Result, TranslateError> { - let new_src = visitor.operand( + let new_src = visitor.src_operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6373,27 +6210,20 @@ impl ast::Arg1Bar { } impl ast::Arg2 { - fn cast>(self) -> ast::Arg2 { - ast::Arg2 { - src: self.src, - dst: self.dst, - } - } - fn map>( self, visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { - let new_dst = visitor.id( + let new_dst = visitor.dst_operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(t), + t, )?; - let new_src = visitor.operand( + let new_src = visitor.src_operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6413,15 +6243,15 @@ impl ast::Arg2 { dst_t: &ast::Type, src_t: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.dst_operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(dst_t), + dst_t, )?; - let src = visitor.operand( + let src = visitor.src_operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6434,21 +6264,12 @@ impl ast::Arg2 { } impl ast::Arg2Ld { - fn cast>( - self, - ) -> ast::Arg2Ld { - ast::Arg2Ld { - dst: self.dst, - src: self.src, - } - } - fn map>( self, visitor: &mut V, details: &ast::LdDetails, ) -> Result, TranslateError> { - let dst = visitor.id_or_vector( + let dst = visitor.dst_operand_vec( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6458,7 +6279,7 @@ impl ast::Arg2Ld { )?; let is_logical_ptr = details.state_space == ast::LdStateSpace::Param || details.state_space == ast::LdStateSpace::Local; - let src = visitor.operand( + let src = visitor.src_operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6478,15 +6299,6 @@ impl ast::Arg2Ld { } impl ast::Arg2St { - fn cast>( - self, - ) -> ast::Arg2St { - ast::Arg2St { - src1: self.src1, - src2: self.src2, - } - } - fn map>( self, visitor: &mut V, @@ -6494,7 +6306,7 @@ impl ast::Arg2St { ) -> Result, TranslateError> { let is_logical_ptr = details.state_space == ast::StStateSpace::Param || details.state_space == ast::StStateSpace::Local; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6509,7 +6321,7 @@ impl ast::Arg2St { details.state_space.to_ld_ss(), ), )?; - let src2 = visitor.operand_or_vector( + let src2 = visitor.src_operand_vec( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6527,29 +6339,7 @@ impl ast::Arg2Mov { visitor: &mut V, details: &ast::MovDetails, ) -> Result, TranslateError> { - Ok(match self { - ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?), - ast::Arg2Mov::Member(arg) => ast::Arg2Mov::Member(arg.map(visitor, details)?), - }) - } -} - -impl ast::Arg2MovNormal

{ - fn cast>( - self, - ) -> ast::Arg2MovNormal { - ast::Arg2MovNormal { - dst: self.dst, - src: self.src, - } - } - - fn map>( - self, - visitor: &mut V, - details: &ast::MovDetails, - ) -> Result, TranslateError> { - let dst = visitor.id_or_vector( + let dst = visitor.dst_operand_vec( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6557,7 +6347,7 @@ impl ast::Arg2MovNormal

{ }, &details.typ.clone().into(), )?; - let src = visitor.operand_or_vector( + let src = visitor.src_operand_vec( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6569,144 +6359,11 @@ impl ast::Arg2MovNormal

{ }, &details.typ.clone().into(), )?; - Ok(ast::Arg2MovNormal { dst, src }) - } -} - -impl ast::Arg2MovMember { - fn cast>( - self, - ) -> ast::Arg2MovMember { - match self { - ast::Arg2MovMember::Dst(dst, src1, src2) => ast::Arg2MovMember::Dst(dst, src1, src2), - ast::Arg2MovMember::Src(dst, src) => ast::Arg2MovMember::Src(dst, src), - ast::Arg2MovMember::Both(dst, src1, src2) => ast::Arg2MovMember::Both(dst, src1, src2), - } - } - - fn vector_dst(&self) -> Option<&T::Id> { - match self { - ast::Arg2MovMember::Src(_, _) => None, - ast::Arg2MovMember::Dst((d, _), _, _) | ast::Arg2MovMember::Both((d, _), _, _) => { - Some(d) - } - } - } - - fn vector_src(&self) -> Option<&T::SrcMemberOperand> { - match self { - ast::Arg2MovMember::Src(_, d) | ast::Arg2MovMember::Both(_, _, d) => Some(d), - ast::Arg2MovMember::Dst(_, _, _) => None, - } - } -} - -impl ast::Arg2MovMember { - fn map>( - self, - visitor: &mut V, - details: &ast::MovDetails, - ) -> Result, TranslateError> { - match self { - ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => { - let scalar_type = details.typ.get_scalar()?; - let dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let src1 = visitor.id( - ArgumentDescriptor { - op: composite_src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let src2 = visitor.id( - ArgumentDescriptor { - op: scalar_src, - is_dst: false, - sema: if details.src_is_address { - ArgumentSemantics::Address - } else if details.relaxed_src2_conv { - ArgumentSemantics::DefaultRelaxed - } else { - ArgumentSemantics::Default - }, - }, - Some(&details.typ.clone().into()), - )?; - Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2)) - } - ast::Arg2MovMember::Src(dst, src) => { - let dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(&details.typ.clone().into()), - )?; - let scalar_typ = details.typ.get_scalar()?; - let src = visitor.src_member_operand( - ArgumentDescriptor { - op: src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - (scalar_typ.into(), details.src_width), - )?; - Ok(ast::Arg2MovMember::Src(dst, src)) - } - ast::Arg2MovMember::Both((dst, len), composite_src, src) => { - let scalar_type = details.typ.get_scalar()?; - let dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let composite_src = visitor.id( - ArgumentDescriptor { - op: composite_src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let src = visitor.src_member_operand( - ArgumentDescriptor { - op: src, - is_dst: false, - sema: if details.relaxed_src2_conv { - ArgumentSemantics::DefaultRelaxed - } else { - ArgumentSemantics::Default - }, - }, - (scalar_type.into(), details.src_width), - )?; - Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src)) - } - } + Ok(ast::Arg2Mov { dst, src }) } } impl ast::Arg3 { - fn cast>(self) -> ast::Arg3 { - ast::Arg3 { - dst: self.dst, - src1: self.src1, - src2: self.src2, - } - } - fn map_non_shift>( self, visitor: &mut V, @@ -6718,15 +6375,15 @@ impl ast::Arg3 { } else { None }; - let dst = visitor.id( + let dst = visitor.dst_operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(wide_type.as_ref().unwrap_or(typ)), + wide_type.as_ref().unwrap_or(typ), )?; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6734,7 +6391,7 @@ impl ast::Arg3 { }, typ, )?; - let src2 = visitor.operand( + let src2 = visitor.src_operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6750,15 +6407,15 @@ impl ast::Arg3 { visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.dst_operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(t), + t, )?; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6766,7 +6423,7 @@ impl ast::Arg3 { }, t, )?; - let src2 = visitor.operand( + let src2 = visitor.src_operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6784,15 +6441,15 @@ impl ast::Arg3 { state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); - let dst = visitor.id( + let dst = visitor.dst_operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(scalar_type)), + &ast::Type::Scalar(scalar_type), )?; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6803,7 +6460,7 @@ impl ast::Arg3 { state_space.to_ld_ss(), ), )?; - let src2 = visitor.operand( + let src2 = visitor.src_operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6816,15 +6473,6 @@ impl ast::Arg3 { } impl ast::Arg4 { - fn cast>(self) -> ast::Arg4 { - ast::Arg4 { - dst: self.dst, - src1: self.src1, - src2: self.src2, - src3: self.src3, - } - } - fn map>( self, visitor: &mut V, @@ -6836,15 +6484,15 @@ impl ast::Arg4 { } else { None }; - let dst = visitor.id( + let dst = visitor.dst_operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(wide_type.as_ref().unwrap_or(t)), + wide_type.as_ref().unwrap_or(t), )?; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6852,7 +6500,7 @@ impl ast::Arg4 { }, t, )?; - let src2 = visitor.operand( + let src2 = visitor.src_operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6860,7 +6508,7 @@ impl ast::Arg4 { }, t, )?; - let src3 = visitor.operand( + let src3 = visitor.src_operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -6881,15 +6529,15 @@ impl ast::Arg4 { visitor: &mut V, t: ast::SelpType, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.dst_operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(t.into())), + &ast::Type::Scalar(t.into()), )?; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6897,7 +6545,7 @@ impl ast::Arg4 { }, &ast::Type::Scalar(t.into()), )?; - let src2 = visitor.operand( + let src2 = visitor.src_operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6905,7 +6553,7 @@ impl ast::Arg4 { }, &ast::Type::Scalar(t.into()), )?; - let src3 = visitor.operand( + let src3 = visitor.src_operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -6928,15 +6576,15 @@ impl ast::Arg4 { state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); - let dst = visitor.id( + let dst = visitor.dst_operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(scalar_type)), + &ast::Type::Scalar(scalar_type), )?; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6947,7 +6595,7 @@ impl ast::Arg4 { state_space.to_ld_ss(), ), )?; - let src2 = visitor.operand( + let src2 = visitor.src_operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6955,7 +6603,7 @@ impl ast::Arg4 { }, &ast::Type::Scalar(scalar_type), )?; - let src3 = visitor.operand( + let src3 = visitor.src_operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -6976,15 +6624,15 @@ impl ast::Arg4 { visitor: &mut V, typ: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.dst_operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(typ), + typ, )?; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6993,7 +6641,7 @@ impl ast::Arg4 { typ, )?; let u32_type = ast::Type::Scalar(ast::ScalarType::U32); - let src2 = visitor.operand( + let src2 = visitor.src_operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -7001,7 +6649,7 @@ impl ast::Arg4 { }, &u32_type, )?; - let src3 = visitor.operand( + let src3 = visitor.src_operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -7019,15 +6667,6 @@ impl ast::Arg4 { } impl ast::Arg4Setp { - fn cast>(self) -> ast::Arg4Setp { - ast::Arg4Setp { - dst1: self.dst1, - dst2: self.dst2, - src1: self.src1, - src2: self.src2, - } - } - fn map>( self, visitor: &mut V, @@ -7054,7 +6693,7 @@ impl ast::Arg4Setp { ) }) .transpose()?; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -7062,7 +6701,7 @@ impl ast::Arg4Setp { }, t, )?; - let src2 = visitor.operand( + let src2 = visitor.src_operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -7079,22 +6718,12 @@ impl ast::Arg4Setp { } } -impl ast::Arg5 { - fn cast>(self) -> ast::Arg5 { - ast::Arg5 { - dst1: self.dst1, - dst2: self.dst2, - src1: self.src1, - src2: self.src2, - src3: self.src3, - } - } - +impl ast::Arg5Setp { fn map>( self, visitor: &mut V, t: &ast::Type, - ) -> Result, TranslateError> { + ) -> Result, TranslateError> { let dst1 = visitor.id( ArgumentDescriptor { op: self.dst1, @@ -7116,7 +6745,7 @@ impl ast::Arg5 { ) }) .transpose()?; - let src1 = visitor.operand( + let src1 = visitor.src_operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -7124,7 +6753,7 @@ impl ast::Arg5 { }, t, )?; - let src2 = visitor.operand( + let src2 = visitor.src_operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -7132,7 +6761,7 @@ impl ast::Arg5 { }, t, )?; - let src3 = visitor.operand( + let src3 = visitor.src_operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -7140,7 +6769,7 @@ impl ast::Arg5 { }, &ast::Type::Scalar(ast::ScalarType::Pred), )?; - Ok(ast::Arg5 { + Ok(ast::Arg5Setp { dst1, dst2, src1, @@ -7166,14 +6795,33 @@ impl ast::Type { } } -impl ast::CallOperand { +impl ast::SrcOperand { fn map_variable Result>( self, f: &mut F, - ) -> Result, TranslateError> { + ) -> Result, TranslateError> { + Ok(match self { + ast::SrcOperand::Reg(reg) => ast::SrcOperand::Reg(f(reg)?), + ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset(f(reg)?, offset), + ast::SrcOperand::Imm(x) => ast::SrcOperand::Imm(x), + ast::SrcOperand::VecIndex(reg, idx) => ast::SrcOperand::VecIndex(f(reg)?, idx), + }) + } +} + +impl ast::DstOperand { + fn to_src_operand(self) -> ast::SrcOperand { match self { - ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(f(id)?)), - ast::CallOperand::Imm(x) => Ok(ast::CallOperand::Imm(x)), + ast::DstOperand::Reg(reg) => ast::SrcOperand::Reg(reg), + ast::DstOperand::VecMember(reg, idx) => ast::SrcOperand::VecIndex(reg, idx), + } + } +} +impl ast::DstOperand { + fn unwrap_reg(&self) -> Result { + match self { + ast::DstOperand::Reg(reg) => Ok(*reg), + ast::DstOperand::VecMember(..) => Err(TranslateError::Unreachable), } } } @@ -7389,20 +7037,12 @@ impl From for ast::VariableType { } } -impl ast::Operand { +impl ast::SrcOperand { fn underlying(&self) -> Option<&T> { match self { - ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r), - ast::Operand::Imm(_) => None, - } - } -} - -impl ast::OperandOrVector { - fn single_underlying(&self) -> Option<&T> { - match self { - ast::OperandOrVector::Reg(r) | ast::OperandOrVector::RegOffset(r, _) => Some(r), - ast::OperandOrVector::Imm(_) | ast::OperandOrVector::Vec(_) => None, + ast::SrcOperand::Reg(r) | ast::SrcOperand::RegOffset(r, _) => Some(r), + ast::SrcOperand::Imm(_) => None, + ast::SrcOperand::VecIndex(reg, _) => Some(reg), } } }