diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 413d3f5..aba6bda 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -614,15 +614,12 @@ pub struct CallInst { pub uniform: bool, pub ret_params: Vec, pub func: P::Id, - pub param_list: Vec, + pub param_list: Vec, } pub trait ArgParams { type Id; - type DstOperand; - type SrcOperand; - type DstOperandVec; - type SrcOperandVec; + type Operand; } pub struct ParsedArgParams<'a> { @@ -631,10 +628,7 @@ pub struct ParsedArgParams<'a> { impl<'a> ArgParams for ParsedArgParams<'a> { type Id = &'a str; - type DstOperand = DstOperand<&'a str>; - type SrcOperand = SrcOperand<&'a str>; - type DstOperandVec = DstOperandVec<&'a str>; - type SrcOperandVec = SrcOperandVec<&'a str>; + type Operand = Operand<&'a str>; } pub struct Arg1 { @@ -642,54 +636,54 @@ pub struct Arg1 { } pub struct Arg1Bar { - pub src: P::SrcOperand, + pub src: P::Operand, } pub struct Arg2 { - pub dst: P::DstOperand, - pub src: P::SrcOperand, + pub dst: P::Operand, + pub src: P::Operand, } pub struct Arg2Ld { - pub dst: P::DstOperandVec, - pub src: P::SrcOperand, + pub dst: P::Operand, + pub src: P::Operand, } pub struct Arg2St { - pub src1: P::SrcOperand, - pub src2: P::SrcOperandVec, + pub src1: P::Operand, + pub src2: P::Operand, } pub struct Arg2Mov { - pub dst: P::DstOperandVec, - pub src: P::SrcOperandVec, + pub dst: P::Operand, + pub src: P::Operand, } pub struct Arg3 { - pub dst: P::DstOperand, - pub src1: P::SrcOperand, - pub src2: P::SrcOperand, + pub dst: P::Operand, + pub src1: P::Operand, + pub src2: P::Operand, } pub struct Arg4 { - pub dst: P::DstOperand, - pub src1: P::SrcOperand, - pub src2: P::SrcOperand, - pub src3: P::SrcOperand, + pub dst: P::Operand, + pub src1: P::Operand, + pub src2: P::Operand, + pub src3: P::Operand, } pub struct Arg4Setp { pub dst1: P::Id, pub dst2: Option, - pub src1: P::SrcOperand, - pub src2: P::SrcOperand, + pub src1: P::Operand, + pub src2: P::Operand, } pub struct Arg5Setp { pub dst1: P::Id, pub dst2: Option, - pub src1: P::SrcOperand, - pub src2: P::SrcOperand, - pub src3: P::SrcOperand, + pub src1: P::Operand, + pub src2: P::Operand, + pub src3: P::Operand, } #[derive(Copy, Clone)] @@ -700,30 +694,13 @@ pub enum ImmediateValue { F64(f64), } -#[derive(Copy, Clone)] -pub enum DstOperand { - Reg(ID), - VecMember(ID, u8), -} - #[derive(Clone)] -pub enum DstOperandVec { - Normal(DstOperand), - Vector(Vec), -} - -#[derive(Copy, Clone)] -pub enum SrcOperand { +pub enum Operand { Reg(Id), RegOffset(Id, i32), Imm(ImmediateValue), VecMember(Id, u8), -} - -#[derive(Clone)] -pub enum SrcOperandVec { - Normal(SrcOperand), - Vector(Vec), + VecPack(Vec), } pub enum VectorPrefix { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 6d9f93d..fd2a3f1 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -1177,7 +1177,7 @@ InstSt: ast::Instruction> = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#using-addresses-arrays-and-vectors -MemoryOperand: ast::SrcOperand<&'input str> = { +MemoryOperand: ast::Operand<&'input str> = { "[" "]" => o } @@ -1734,15 +1734,15 @@ ArithFloatMustRound: ast::ArithFloat = { }, } -Operand: ast::SrcOperand<&'input str> = { - => ast::SrcOperand::Reg(r), - "+" => ast::SrcOperand::RegOffset(r, offset), - => ast::SrcOperand::Imm(x) +Operand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), + "+" => ast::Operand::RegOffset(r, offset), + => ast::Operand::Imm(x) }; -CallOperand: ast::SrcOperand<&'input str> = { - => ast::SrcOperand::Reg(r), - => ast::SrcOperand::Imm(x) +CallOperand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), + => ast::Operand::Imm(x) }; // TODO: start parsing whole constants sub-language: @@ -1838,44 +1838,44 @@ Arg5Setp: ast::Arg5Setp> = { "," "," "," "!"? => ast::Arg5Setp{<>} }; -ArgCall: (Vec<&'input str>, &'input str, Vec>) = { +ArgCall: (Vec<&'input str>, &'input str, Vec>) = { "(" > ")" "," "," "(" > ")" => { (ret_params, func, param_list) }, "," "(" > ")" => (Vec::new(), func, param_list), - => (Vec::new(), func, Vec::>::new()), + => (Vec::new(), func, Vec::>::new()), }; OptionalDst: &'input str = { "|" => dst2 } -SrcOperand: ast::SrcOperand<&'input str> = { - => ast::SrcOperand::Reg(r), - "+" => ast::SrcOperand::RegOffset(r, offset), - => ast::SrcOperand::Imm(x), +SrcOperand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), + "+" => ast::Operand::RegOffset(r, offset), + => ast::Operand::Imm(x), => { let (reg, idx) = mem_op; - ast::SrcOperand::VecMember(reg, idx) + ast::Operand::VecMember(reg, idx) } } -SrcOperandVec: ast::SrcOperandVec<&'input str> = { - => ast::SrcOperandVec::Normal(normal), - => ast::SrcOperandVec::Vector(vec), +SrcOperandVec: ast::Operand<&'input str> = { + => normal, + => ast::Operand::VecPack(vec), } -DstOperand: ast::DstOperand<&'input str> = { - => ast::DstOperand::Reg(r), +DstOperand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), => { let (reg, idx) = mem_op; - ast::DstOperand::VecMember(reg, idx) + ast::Operand::VecMember(reg, idx) } } -DstOperandVec: ast::DstOperandVec<&'input str> = { - => ast::DstOperandVec::Normal(normal), - => ast::DstOperandVec::Vector(vec), +DstOperandVec: ast::Operand<&'input str> = { + => normal, + => ast::Operand::VecPack(vec), } VectorPrefix: u8 = { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ca64e60..20578eb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1457,7 +1457,7 @@ fn convert_to_typed_statements( .partition(|(_, arg_type)| arg_type.is_param()); let normalized_input_args = out_params .into_iter() - .map(|(id, typ)| (ast::SrcOperand::Reg(id), typ)) + .map(|(id, typ)| (ast::Operand::Reg(id), typ)) .chain(in_args.into_iter()) .collect(); let resolved_call = ResolvedCall { @@ -1466,15 +1466,12 @@ fn convert_to_typed_statements( func: call.func, param_list: normalized_input_args, }; - result.push(Statement::Call(resolved_call)); + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let reresolved_call = resolved_call.visit(&mut visitor)?; + visitor.func.push(reresolved_call); + visitor.func.extend(visitor.post_stmts); } - ast::Instruction::Mov( - mut d, - ast::Arg2Mov { - dst, - src: ast::SrcOperandVec::Normal(src), - }, - ) => { + ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => { if let Some(src_id) = src.underlying() { let (typ, _) = id_defs.get_typed(*src_id)?; let take_address = match typ { @@ -1487,14 +1484,7 @@ fn convert_to_typed_statements( } let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction( - ast::Instruction::Mov( - d, - ast::Arg2Mov { - dst, - src: ast::SrcOperandVec::Normal(src), - }, - ) - .map(&mut visitor)?, + ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?, ); visitor.func.push(instruction); visitor.func.extend(visitor.post_stmts); @@ -1570,52 +1560,20 @@ impl<'a, 'b> ArgumentMapVisitor Ok(desc.op) } - fn dst_operand( + fn operand( &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - Ok(desc.op) - } - - fn src_operand( - &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - Ok(desc.op) - } - - fn dst_operand_vec( - &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::DstOperandVec::Normal(op) => self.dst_operand(desc.new_op(op), typ), - ast::DstOperandVec::Vector(vec) => Ok(ast::DstOperand::Reg(self.convert_vector( - desc.is_dst, - desc.sema, - typ, - vec, - )?)), - } - } - - fn src_operand_vec( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::SrcOperandVec::Normal(op) => self.src_operand(desc.new_op(op), typ), - ast::SrcOperandVec::Vector(vec) => Ok(ast::SrcOperand::Reg(self.convert_vector( - desc.is_dst, - desc.sema, - typ, - vec, - )?)), - } + ) -> Result { + Ok(match desc.op { + ast::Operand::Reg(reg) => TypedOperand::Reg(reg), + ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), + ast::Operand::Imm(x) => TypedOperand::Imm(x), + ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), + ast::Operand::VecPack(vec) => { + TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?) + } + }) } } @@ -2145,72 +2103,21 @@ impl<'a, 'input> ArgumentMapVisitor self.symbol(desc.new_op((desc.op, None)), typ) } - fn dst_operand( + fn operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor, typ: &ast::Type, - ) -> Result, TranslateError> { + ) -> Result { Ok(match desc.op { - ast::DstOperand::Reg(reg) => { - ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) + TypedOperand::Reg(reg) => { + TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) } - ast::DstOperand::VecMember(symbol, index) => { - ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) + TypedOperand::RegOffset(reg, offset) => { + TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset) } - }) - } - - fn src_operand( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::SrcOperand::Reg(reg) => { - ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) - } - ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset( - self.symbol(desc.new_op((reg, None)), Some(typ))?, - offset, - ), - op @ ast::SrcOperand::Imm(..) => op, - ast::SrcOperand::VecMember(symbol, index) => { - ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) - } - }) - } - - fn dst_operand_vec( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::DstOperand::Reg(reg) => { - ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) - } - ast::DstOperand::VecMember(symbol, index) => { - ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) - } - }) - } - - fn src_operand_vec( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::SrcOperand::Reg(reg) => { - ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) - } - ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset( - self.symbol(desc.new_op((reg, None)), Some(typ))?, - offset, - ), - op @ ast::SrcOperand::Imm(..) => op, - ast::SrcOperand::VecMember(symbol, index) => { - ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) + op @ TypedOperand::Imm(..) => op, + TypedOperand::VecMember(symbol, index) => { + TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) } }) } @@ -2436,55 +2343,18 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr self.reg(desc, t) } - fn dst_operand( + fn operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { match desc.op { - ast::DstOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)), - ast::DstOperand::VecMember(..) => Err(error_unreachable()), - } - } - - fn src_operand( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result { - match desc.op { - ast::SrcOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::SrcOperand::Imm(x) => self.immediate(desc.new_op(x), typ), - ast::SrcOperand::RegOffset(reg, offset) => { + TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), + TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ), + TypedOperand::RegOffset(reg, offset) => { self.reg_offset(desc.new_op((reg, offset)), typ) } - ast::SrcOperand::VecMember(..) => Err(error_unreachable()), - } - } - - fn dst_operand_vec( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result { - match desc.op { - ast::DstOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)), - ast::DstOperand::VecMember(..) => Err(error_unreachable()), - } - } - - fn src_operand_vec( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result { - match desc.op { - ast::SrcOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::SrcOperand::Imm(x) => self.immediate(desc.new_op(x), typ), - ast::SrcOperand::RegOffset(reg, offset) => { - self.reg_offset(desc.new_op((reg, offset)), typ) - } - ast::SrcOperand::VecMember(..) => Err(error_unreachable()), + TypedOperand::VecMember(..) => Err(error_unreachable()), } } } @@ -4346,7 +4216,9 @@ fn convert_to_stateful_memory_access<'a>( }, arg, )) => { - if let (ast::DstOperand::Reg(dst), Some(src)) = (arg.dst, arg.src.underlying()) { + if let (TypedOperand::Reg(dst), Some(src)) = + (arg.dst, arg.src.upcast().underlying()) + { if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) { stateful_markers.push((dst, *src)); } @@ -4376,7 +4248,9 @@ fn convert_to_stateful_memory_access<'a>( }, arg, )) => { - if let (ast::DstOperand::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) { + if let (TypedOperand::Reg(dst), Some(src)) = + (&arg.dst, arg.src.upcast().underlying()) + { if func_args_64bit.contains(src) { multi_hash_map_append(&mut stateful_init_reg, *dst, *src); } @@ -4425,14 +4299,14 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) => { - if let (ast::DstOperand::Reg(dst), Some(src1)) = - (arg.dst, arg.src1.underlying()) + if let (TypedOperand::Reg(dst), Some(src1)) = + (arg.dst, arg.src1.upcast().underlying()) { if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) { regs_ptr_new.insert(dst); } - } else if let (ast::DstOperand::Reg(dst), Some(src2)) = - (arg.dst, arg.src2.underlying()) + } else if let (TypedOperand::Reg(dst), Some(src2)) = + (arg.dst, arg.src2.upcast().underlying()) { if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) { regs_ptr_new.insert(dst); @@ -4486,7 +4360,7 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) if is_add_ptr_direct(&remapped_ids, &arg) => { - let (ptr, offset) = match arg.src1.underlying() { + let (ptr, offset) = match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => { (remapped_ids.get(src1).unwrap(), arg.src2) } @@ -4495,7 +4369,7 @@ fn convert_to_stateful_memory_access<'a>( } _ => return Err(error_unreachable()), }; - let dst = arg.dst.unwrap_reg()?; + let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, @@ -4515,7 +4389,7 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) if is_add_ptr_direct(&remapped_ids, &arg) => { - let (ptr, offset) = match arg.src1.underlying() { + let (ptr, offset) = match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => { (remapped_ids.get(src1).unwrap(), arg.src2) } @@ -4533,16 +4407,16 @@ fn convert_to_stateful_memory_access<'a>( }, ast::Arg2 { src: offset, - dst: ast::DstOperand::Reg(offset_neg), + dst: TypedOperand::Reg(offset_neg), }, ))); - let dst = arg.dst.unwrap_reg()?; + let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, - offset_src: ast::SrcOperand::Reg(offset_neg), + offset_src: TypedOperand::Reg(offset_neg), })) } Statement::Instruction(inst) => { @@ -4697,12 +4571,14 @@ fn convert_to_stateful_memory_access_postprocess( fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { match arg.dst { - ast::DstOperand::VecMember(..) => return false, - ast::DstOperand::Reg(dst) => { + TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { + return false + } + TypedOperand::Reg(dst) => { if !remapped_ids.contains_key(&dst) { return false; } - match arg.src1.underlying() { + match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => true, Some(src2) if remapped_ids.contains_key(src2) => true, _ => false, @@ -5219,11 +5095,11 @@ struct ResolvedCall { pub uniform: bool, pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, pub func: P::Id, - pub param_list: Vec<(P::SrcOperand, ast::FnArgumentType)>, + pub param_list: Vec<(P::Operand, ast::FnArgumentType)>, } impl ResolvedCall { - fn cast>(self) -> ResolvedCall { + fn cast>(self) -> ResolvedCall { ResolvedCall { uniform: self.uniform, ret_params: self.ret_params, @@ -5265,7 +5141,7 @@ impl> ResolvedCall { .param_list .into_iter() .map::, _>(|(id, typ)| { - let new_id = visitor.src_operand( + let new_id = visitor.operand( ArgumentDescriptor { op: id, is_dst: false, @@ -5327,7 +5203,7 @@ impl> PtrAccess

{ }, Some(&ptr_type), )?; - let new_constant_src = visitor.src_operand( + let new_constant_src = visitor.operand( ArgumentDescriptor { op: self.offset_src, is_dst: false, @@ -5376,10 +5252,7 @@ enum NormalizedArgParams {} impl ast::ArgParams for NormalizedArgParams { type Id = spirv::Word; - type DstOperand = ast::DstOperand; - type SrcOperand = ast::SrcOperand; - type DstOperandVec = ast::DstOperandVec; - type SrcOperandVec = ast::SrcOperandVec; + type Operand = ast::Operand; } impl ArgParamsEx for NormalizedArgParams { @@ -5405,10 +5278,7 @@ enum TypedArgParams {} impl ast::ArgParams for TypedArgParams { type Id = spirv::Word; - type DstOperand = ast::DstOperand; - type SrcOperand = ast::SrcOperand; - type DstOperandVec = ast::DstOperand; - type SrcOperandVec = ast::SrcOperand; + type Operand = TypedOperand; } impl ArgParamsEx for TypedArgParams { @@ -5420,6 +5290,25 @@ impl ArgParamsEx for TypedArgParams { } } +#[derive(Copy, Clone)] +enum TypedOperand { + Reg(spirv::Word), + RegOffset(spirv::Word, i32), + Imm(ast::ImmediateValue), + VecMember(spirv::Word, u8), +} + +impl TypedOperand { + fn upcast(self) -> ast::Operand { + match self { + TypedOperand::Reg(reg) => ast::Operand::Reg(reg), + TypedOperand::RegOffset(reg, idx) => ast::Operand::RegOffset(reg, idx), + TypedOperand::Imm(x) => ast::Operand::Imm(x), + TypedOperand::VecMember(vec, idx) => ast::Operand::VecMember(vec, idx), + } + } +} + type TypedStatement = Statement, TypedArgParams>; enum ExpandedArgParams {} @@ -5427,10 +5316,7 @@ type ExpandedStatement = Statement, Expanded impl ast::ArgParams for ExpandedArgParams { type Id = spirv::Word; - type DstOperand = spirv::Word; - type SrcOperand = spirv::Word; - type DstOperandVec = spirv::Word; - type SrcOperandVec = spirv::Word; + type Operand = spirv::Word; } impl ArgParamsEx for ExpandedArgParams { @@ -5461,26 +5347,11 @@ pub trait ArgumentMapVisitor { desc: ArgumentDescriptor, typ: Option<&ast::Type>, ) -> Result; - fn dst_operand( + fn operand( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, typ: &ast::Type, - ) -> Result; - fn src_operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result; - fn dst_operand_vec( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result; - fn src_operand_vec( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result; + ) -> Result; } impl ArgumentMapVisitor for T @@ -5498,31 +5369,7 @@ where self(desc, t) } - fn dst_operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result { - self(desc, Some(typ)) - } - - fn src_operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result { - self(desc, Some(typ)) - } - - fn dst_operand_vec( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result { - self(desc, Some(typ)) - } - - fn src_operand_vec( + fn operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, @@ -5543,57 +5390,17 @@ where self(desc.op) } - fn dst_operand( + fn operand( &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(id)?), - ast::DstOperand::VecMember(id, member) => ast::DstOperand::VecMember(self(id)?, member), - }) - } - - fn src_operand( - &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(id)?), - ast::SrcOperand::RegOffset(id, imm) => ast::SrcOperand::RegOffset(self(id)?, imm), - ast::SrcOperand::Imm(imm) => ast::SrcOperand::Imm(imm), - ast::SrcOperand::VecMember(id, member) => ast::SrcOperand::VecMember(self(id)?, member), - }) - } - - fn dst_operand_vec( - &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: &ast::Type, - ) -> Result, TranslateError> { - Ok(match &desc.op { - ast::DstOperandVec::Normal(inner_op) => { - ast::DstOperandVec::Normal(self.dst_operand(desc.new_op(inner_op.clone()), typ)?) - } - ast::DstOperandVec::Vector(ids) => ast::DstOperandVec::Vector( - ids.into_iter() - .map(|id| self.id(desc.new_op(id), Some(typ))) - .collect::, _>>()?, - ), - }) - } - - fn src_operand_vec( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - Ok(match &desc.op { - ast::SrcOperandVec::Normal(inner_op) => { - ast::SrcOperandVec::Normal(self.src_operand(desc.new_op(inner_op.clone()), typ)?) - } - ast::SrcOperandVec::Vector(ids) => ast::SrcOperandVec::Vector( + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?), + ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id)?, imm), + ast::Operand::Imm(imm) => ast::Operand::Imm(imm), + ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member), + ast::Operand::VecPack(ref ids) => ast::Operand::VecPack( ids.into_iter() .map(|id| self.id(desc.new_op(id), Some(typ))) .collect::, _>>()?, @@ -5613,7 +5420,7 @@ pub struct PtrAccess { state_space: ast::LdStateSpace, dst: spirv::Word, ptr_src: spirv::Word, - offset_src: P::SrcOperand, + offset_src: P::Operand, } #[derive(Copy, Clone, PartialEq, Eq, Debug)] @@ -5913,82 +5720,24 @@ where self(desc, t) } - fn dst_operand( + fn operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor, typ: &ast::Type, - ) -> Result, TranslateError> { + ) -> Result { Ok(match desc.op { - ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(desc.new_op(id), Some(typ))?), - ast::DstOperand::VecMember(reg, index) => { + TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?), + TypedOperand::Imm(imm) => TypedOperand::Imm(imm), + TypedOperand::RegOffset(id, imm) => { + TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) + } + TypedOperand::VecMember(reg, index) => { let scalar_type = match typ { ast::Type::Scalar(scalar_t) => *scalar_t, _ => return Err(error_unreachable()), }; let vec_type = ast::Type::Vector(scalar_type, index + 1); - ast::DstOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) - } - }) - } - - fn src_operand( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(desc.new_op(id), Some(typ))?), - ast::SrcOperand::Imm(imm) => ast::SrcOperand::Imm(imm), - ast::SrcOperand::RegOffset(id, imm) => { - ast::SrcOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) - } - ast::SrcOperand::VecMember(reg, index) => { - let scalar_type = match typ { - ast::Type::Scalar(scalar_t) => *scalar_t, - _ => return Err(error_unreachable()), - }; - let vec_type = ast::Type::Vector(scalar_type, index + 1); - ast::SrcOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) - } - }) - } - - fn dst_operand_vec( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(desc.new_op(id), Some(typ))?), - ast::DstOperand::VecMember(reg, index) => { - let scalar_type = match typ { - ast::Type::Scalar(scalar_t) => *scalar_t, - _ => return Err(error_unreachable()), - }; - let vec_type = ast::Type::Vector(scalar_type, index + 1); - ast::DstOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) - } - }) - } - - fn src_operand_vec( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(desc.new_op(id), Some(typ))?), - ast::SrcOperand::Imm(imm) => ast::SrcOperand::Imm(imm), - ast::SrcOperand::RegOffset(id, imm) => { - ast::SrcOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) - } - ast::SrcOperand::VecMember(reg, index) => { - let scalar_type = match typ { - ast::Type::Scalar(scalar_t) => *scalar_t, - _ => return Err(error_unreachable()), - }; - let vec_type = ast::Type::Vector(scalar_type, index + 1); - ast::SrcOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) + TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) } }) } @@ -6364,7 +6113,7 @@ impl ast::Arg1Bar { self, visitor: &mut V, ) -> Result, TranslateError> { - let new_src = visitor.src_operand( + let new_src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6382,7 +6131,7 @@ impl ast::Arg2 { visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { - let new_dst = visitor.dst_operand( + let new_dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6390,7 +6139,7 @@ impl ast::Arg2 { }, t, )?; - let new_src = visitor.src_operand( + let new_src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6410,7 +6159,7 @@ impl ast::Arg2 { dst_t: &ast::Type, src_t: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.dst_operand( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6418,7 +6167,7 @@ impl ast::Arg2 { }, dst_t, )?; - let src = visitor.src_operand( + let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6436,7 +6185,7 @@ impl ast::Arg2Ld { visitor: &mut V, details: &ast::LdDetails, ) -> Result, TranslateError> { - let dst = visitor.dst_operand_vec( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6446,7 +6195,7 @@ impl ast::Arg2Ld { )?; let is_logical_ptr = details.state_space == ast::LdStateSpace::Param || details.state_space == ast::LdStateSpace::Local; - let src = visitor.src_operand( + let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6473,7 +6222,7 @@ impl ast::Arg2St { ) -> Result, TranslateError> { let is_logical_ptr = details.state_space == ast::StStateSpace::Param || details.state_space == ast::StStateSpace::Local; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6488,7 +6237,7 @@ impl ast::Arg2St { details.state_space.to_ld_ss(), ), )?; - let src2 = visitor.src_operand_vec( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6506,7 +6255,7 @@ impl ast::Arg2Mov { visitor: &mut V, details: &ast::MovDetails, ) -> Result, TranslateError> { - let dst = visitor.dst_operand_vec( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6514,7 +6263,7 @@ impl ast::Arg2Mov { }, &details.typ.clone().into(), )?; - let src = visitor.src_operand_vec( + let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6542,7 +6291,7 @@ impl ast::Arg3 { } else { None }; - let dst = visitor.dst_operand( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6550,7 +6299,7 @@ impl ast::Arg3 { }, wide_type.as_ref().unwrap_or(typ), )?; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6558,7 +6307,7 @@ impl ast::Arg3 { }, typ, )?; - let src2 = visitor.src_operand( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6574,7 +6323,7 @@ impl ast::Arg3 { visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.dst_operand( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6582,7 +6331,7 @@ impl ast::Arg3 { }, t, )?; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6590,7 +6339,7 @@ impl ast::Arg3 { }, t, )?; - let src2 = visitor.src_operand( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6608,7 +6357,7 @@ impl ast::Arg3 { state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); - let dst = visitor.dst_operand( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6616,7 +6365,7 @@ impl ast::Arg3 { }, &ast::Type::Scalar(scalar_type), )?; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6627,7 +6376,7 @@ impl ast::Arg3 { state_space.to_ld_ss(), ), )?; - let src2 = visitor.src_operand( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6651,7 +6400,7 @@ impl ast::Arg4 { } else { None }; - let dst = visitor.dst_operand( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6659,7 +6408,7 @@ impl ast::Arg4 { }, wide_type.as_ref().unwrap_or(t), )?; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6667,7 +6416,7 @@ impl ast::Arg4 { }, t, )?; - let src2 = visitor.src_operand( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6675,7 +6424,7 @@ impl ast::Arg4 { }, t, )?; - let src3 = visitor.src_operand( + let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -6696,7 +6445,7 @@ impl ast::Arg4 { visitor: &mut V, t: ast::SelpType, ) -> Result, TranslateError> { - let dst = visitor.dst_operand( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6704,7 +6453,7 @@ impl ast::Arg4 { }, &ast::Type::Scalar(t.into()), )?; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6712,7 +6461,7 @@ impl ast::Arg4 { }, &ast::Type::Scalar(t.into()), )?; - let src2 = visitor.src_operand( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6720,7 +6469,7 @@ impl ast::Arg4 { }, &ast::Type::Scalar(t.into()), )?; - let src3 = visitor.src_operand( + let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -6743,7 +6492,7 @@ impl ast::Arg4 { state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); - let dst = visitor.dst_operand( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6751,7 +6500,7 @@ impl ast::Arg4 { }, &ast::Type::Scalar(scalar_type), )?; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6762,7 +6511,7 @@ impl ast::Arg4 { state_space.to_ld_ss(), ), )?; - let src2 = visitor.src_operand( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6770,7 +6519,7 @@ impl ast::Arg4 { }, &ast::Type::Scalar(scalar_type), )?; - let src3 = visitor.src_operand( + let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -6791,7 +6540,7 @@ impl ast::Arg4 { visitor: &mut V, typ: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.dst_operand( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6799,7 +6548,7 @@ impl ast::Arg4 { }, typ, )?; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6808,7 +6557,7 @@ impl ast::Arg4 { typ, )?; let u32_type = ast::Type::Scalar(ast::ScalarType::U32); - let src2 = visitor.src_operand( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6816,7 +6565,7 @@ impl ast::Arg4 { }, &u32_type, )?; - let src3 = visitor.src_operand( + let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -6860,7 +6609,7 @@ impl ast::Arg4Setp { ) }) .transpose()?; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6868,7 +6617,7 @@ impl ast::Arg4Setp { }, t, )?; - let src2 = visitor.src_operand( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6912,7 +6661,7 @@ impl ast::Arg5Setp { ) }) .transpose()?; - let src1 = visitor.src_operand( + let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, @@ -6920,7 +6669,7 @@ impl ast::Arg5Setp { }, t, )?; - let src2 = visitor.src_operand( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6928,7 +6677,7 @@ impl ast::Arg5Setp { }, t, )?; - let src3 = visitor.src_operand( + let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, @@ -6946,34 +6695,28 @@ impl ast::Arg5Setp { } } -impl ast::SrcOperand { +impl ast::Operand { fn map_variable Result>( self, f: &mut F, - ) -> Result, TranslateError> { + ) -> Result, TranslateError> { Ok(match self { - ast::SrcOperand::Reg(reg) => ast::SrcOperand::Reg(f(reg)?), - ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset(f(reg)?, offset), - ast::SrcOperand::Imm(x) => ast::SrcOperand::Imm(x), - ast::SrcOperand::VecMember(reg, idx) => ast::SrcOperand::VecMember(f(reg)?, idx), + ast::Operand::Reg(reg) => ast::Operand::Reg(f(reg)?), + ast::Operand::RegOffset(reg, offset) => ast::Operand::RegOffset(f(reg)?, offset), + ast::Operand::Imm(x) => ast::Operand::Imm(x), + ast::Operand::VecMember(reg, idx) => ast::Operand::VecMember(f(reg)?, idx), + ast::Operand::VecPack(vec) => { + ast::Operand::VecPack(vec.into_iter().map(f).collect::>()?) + } }) } } -impl ast::DstOperand { - fn to_src_operand(self) -> ast::SrcOperand { - match self { - ast::DstOperand::Reg(reg) => ast::SrcOperand::Reg(reg), - ast::DstOperand::VecMember(reg, idx) => ast::SrcOperand::VecMember(reg, idx), - } - } -} - -impl ast::DstOperand { +impl ast::Operand { fn unwrap_reg(&self) -> Result { match self { - ast::DstOperand::Reg(reg) => Ok(*reg), - ast::DstOperand::VecMember(..) => Err(error_unreachable()), + ast::Operand::Reg(reg) => Ok(*reg), + _ => Err(error_unreachable()), } } } @@ -7189,12 +6932,13 @@ impl From for ast::VariableType { } } -impl ast::SrcOperand { +impl ast::Operand { fn underlying(&self) -> Option<&T> { match self { - ast::SrcOperand::Reg(r) | ast::SrcOperand::RegOffset(r, _) => Some(r), - ast::SrcOperand::Imm(_) => None, - ast::SrcOperand::VecMember(reg, _) => Some(reg), + ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r), + ast::Operand::Imm(_) => None, + ast::Operand::VecMember(reg, _) => Some(reg), + ast::Operand::VecPack(..) => None, } } }