diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index e49e489..3ad61e5 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,6 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::{convert::From, mem, num::ParseFloatError, rc::Rc, str::FromStr}; +use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; #[derive(Debug, thiserror::Error)] diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ea6451e..c6b7f01 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -74,12 +74,6 @@ impl SpirvType { } } -impl ast::Type { - fn param_pointer_to(self, space: ast::StateSpace) -> Result { - Ok(self) - } -} - impl From for SpirvType { fn from(t: ast::ScalarType) -> Self { SpirvType::Base(t.into()) @@ -636,7 +630,7 @@ fn get_kernels_call_map<'input>( for statement in statements { match statement { Statement::Call(call) => { - multi_hash_map_append(&mut directly_called_by, call_key, call.func); + multi_hash_map_append(&mut directly_called_by, call_key, call.name); } _ => {} } @@ -872,8 +866,8 @@ fn replace_uses_of_shared_memory<'a>( to_type: ast::Type::Pointer((*scalar_type).into()), to_space: ast::StateSpace::Shared, kind: ConversionKind::PtrToPtr { spirv_ptr: true }, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, + src_ + dst_ })); replacement_id } else { @@ -1172,48 +1166,19 @@ fn translate_directive<'input>( d: ast::Directive<'input, ast::ParsedArgParams<'input>>, ) -> Result>, TranslateError> { Ok(match d { - ast::Directive::Variable(v) => Some(Directive::Variable(translate_variable(id_defs, v)?)), + ast::Directive::Variable(var) => Some(Directive::Variable(ast::Variable { + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), + array_init: var.array_init, + })), ast::Directive::Method(f) => { translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method) } }) } -fn translate_variable<'a>( - id_defs: &mut GlobalStringIdResolver<'a>, - var: ast::Variable<&'a str>, -) -> Result, TranslateError> { - let (space, var_type) = (var.state_space, var.v_type.clone()); - let mut is_variable = false; - let var_type = match space { - ast::StateSpace::Reg => { - is_variable = true; - var_type - } - ast::StateSpace::Const => var_type.param_pointer_to(ast::StateSpace::Const)?, - ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?, - ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?, - ast::StateSpace::Shared => { - // If it's a pointer it will be translated to a method parameter later - if let ast::Type::Pointer(..) = var_type { - is_variable = true; - var_type - } else { - var_type.param_pointer_to(ast::StateSpace::Shared)? - } - } - ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?, - ast::StateSpace::Generic | ast::StateSpace::Sreg => return Err(error_unreachable()), - }; - Ok(ast::Variable { - align: var.align, - v_type: var.v_type, - state_space: var.state_space, - name: id_defs.get_or_add_def_typed(var.name, var_type, var.state_space, is_variable), - array_init: var.array_init, - }) -} - fn translate_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, ptx_impl_imports: &mut HashMap>, @@ -1247,29 +1212,6 @@ fn translate_function<'a>( } } -fn expand_kernel_params<'a, 'b>( - fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>, -) -> Result>, TranslateError> { - args.map(|a| { - Ok(ast::Variable { - name: fn_resolver.add_def( - a.name, - Some(( - ast::Type::from(a.v_type.clone()).param_pointer_to(ast::StateSpace::Param)?, - a.state_space, - )), - false, - ), - v_type: a.v_type.clone(), - state_space: a.state_space, - align: a.align, - array_init: Vec::new(), - }) - }) - .collect::>() -} - fn rename_fn_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, args: &'b [ast::Variable<&'a str>], @@ -1293,6 +1235,7 @@ fn to_ssa<'input, 'b>( f_body: Option>>>, tuning: Vec, ) -> Result, TranslateError> { + deparamize_function_decl(&func_decl)?; let f_body = match f_body { Some(vec) => vec, None => { @@ -1337,6 +1280,30 @@ fn to_ssa<'input, 'b>( }) } +fn deparamize_function_decl( + func_decl_rc: &Rc>>, +) -> Result<(), TranslateError> { + let mut func_decl = func_decl_rc.borrow_mut(); + match func_decl.name { + ast::MethodName::Func(..) => { + for decl in func_decl.input_arguments.iter_mut() { + if decl.state_space == ast::StateSpace::Param { + decl.state_space = ast::StateSpace::Reg; + let baseline_type = match decl.v_type { + ast::Type::Scalar(t) => t, + ast::Type::Vector(t, _) => t, // TODO: write a test for this + ast::Type::Array(t, _) => t, // TODO: write a test for this + ast::Type::Pointer(_, _) => return Err(error_unreachable()), + }; + decl.v_type = ast::Type::Pointer(baseline_type, ast::StateSpace::Param); + } + } + } + ast::MethodName::Kernel(..) => {} + }; + Ok(()) +} + fn fix_special_registers( typed_statements: Vec, numeric_id_defs: &mut NumericIdResolver, @@ -1394,8 +1361,6 @@ fn fix_special_registers( to_type: ast::Type::Scalar(ast::ScalarType::U32), to_space: ast::StateSpace::Sreg, kind: ConversionKind::Default, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, })); } } @@ -1566,45 +1531,21 @@ fn convert_to_typed_statements( ast::Instruction::Call(call) => { // TODO: error out if lengths don't match let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow(); - let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments); - let in_args = to_resolved_fn_args(call.param_list, &*fn_def.input_arguments); - let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args - .into_iter() - .partition(|(_, _, space)| *space == ast::StateSpace::Param); - let normalized_input_args = out_params - .into_iter() - .map(|(id, typ, space)| (ast::Operand::Reg(id), typ, space)) - .chain(in_args.into_iter()) - .collect(); + let return_arguments = + to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments); + let input_arguments = + to_resolved_fn_args(call.param_list, &*fn_def.input_arguments); let resolved_call = ResolvedCall { uniform: call.uniform, - ret_params: out_non_params, - func: call.func, - param_list: normalized_input_args, + return_arguments, + name: call.func, + input_arguments, }; let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let reresolved_call = resolved_call.visit(&mut visitor)?; visitor.func.push(reresolved_call); visitor.func.extend(visitor.post_stmts); } - ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => { - if let Some(src_id) = src.underlying() { - let (typ, _, _) = id_defs.get_typed(*src_id)?; - let take_address = match typ { - ast::Type::Scalar(..) => false, - ast::Type::Vector(..) => false, - ast::Type::Array(..) => true, - ast::Type::Pointer(..) => true, - }; - d.src_is_address = take_address; - } - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let instruction = Statement::Instruction( - ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?, - ); - visitor.func.push(instruction); - visitor.func.extend(visitor.post_stmts); - } inst => { let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction(inst.map(&mut visitor)?); @@ -1639,7 +1580,6 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn convert_vector( &mut self, is_dst: bool, - vector_sema: ArgumentSemantics, typ: &ast::Type, state_space: ast::StateSpace, idx: Vec, @@ -1657,7 +1597,6 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { typ: scalar_t, packed: temp_vec, unpacked: idx, - vector_sema, }); if is_dst { self.post_stmts = Some(statement); @@ -1690,13 +1629,9 @@ impl<'a, 'b> ArgumentMapVisitor ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), ast::Operand::Imm(x) => TypedOperand::Imm(x), ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector( - desc.is_dst, - desc.sema, - typ, - state_space, - vec, - )?), + ast::Operand::VecPack(vec) => { + TypedOperand::Reg(self.convert_vector(desc.is_dst, typ, state_space, vec)?) + } }) } } @@ -1770,9 +1705,9 @@ fn to_ptx_impl_atomic_call( }; Statement::Call(ResolvedCall { uniform: false, - func: fn_id, - ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], - param_list: vec![ + name: fn_id, + return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], + input_arguments: vec![ ( arg.src1, ast::Type::Pointer(typ, ptr_space), @@ -1859,9 +1794,9 @@ fn to_ptx_impl_bfe_call( }; Statement::Call(ResolvedCall { uniform: false, - func: fn_id, - ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], - param_list: vec![ + name: fn_id, + return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], + input_arguments: vec![ ( arg.src1, ast::Type::Scalar(typ.into()), @@ -1958,9 +1893,9 @@ fn to_ptx_impl_bfi_call( }; Statement::Call(ResolvedCall { uniform: false, - func: fn_id, - ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], - param_list: vec![ + name: fn_id, + return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], + input_arguments: vec![ ( arg.src1, ast::Type::Scalar(typ.into()), @@ -2217,14 +2152,14 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn symbol( &mut self, desc: ArgumentDescriptor<(spirv::Word, Option)>, - expected_type: Option<(&ast::Type, ast::StateSpace)>, + expected: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { let symbol = desc.op.0; - if expected_type.is_none() { + if expected.is_none() { return Ok(symbol); }; - let (mut var_type, _, is_variable) = self.id_def.get_typed(symbol)?; - if !is_variable { + let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; + if !var_space.is_compatible(ast::StateSpace::Reg) || !is_variable { return Ok(symbol); }; let member_index = match desc.op.1 { @@ -2579,28 +2514,10 @@ fn insert_implicit_conversions( for s in func.into_iter() { match s { Statement::Call(call) => { - insert_implicit_conversions_impl(&mut result, id_def, call, should_bitcast_wrapper)? + insert_implicit_conversions_impl(&mut result, id_def, call)?; } Statement::Instruction(inst) => { - let mut default_conversion_fn = - should_bitcast_wrapper as for<'a> fn(&'a _, _, &'a _, _) -> _; - let mut state_space = None; - if let ast::Instruction::Ld(d, _) = &inst { - state_space = Some(d.state_space); - } - if let ast::Instruction::St(d, _) = &inst { - state_space = Some(d.state_space); - } - if let ast::Instruction::Atom(d, _) = &inst { - state_space = Some(d.space); - } - if let ast::Instruction::AtomCas(d, _) = &inst { - state_space = Some(d.space); - } - if let ast::Instruction::Mov(..) = &inst { - default_conversion_fn = should_bitcast_packed; - } - insert_implicit_conversions_impl(&mut result, id_def, inst, default_conversion_fn)?; + insert_implicit_conversions_impl(&mut result, id_def, inst)?; } Statement::PtrAccess(PtrAccess { underlying_type, @@ -2613,7 +2530,7 @@ fn insert_implicit_conversions( desc: ArgumentDescriptor { op: ptr_src, is_dst: false, - sema: ArgumentSemantics::PhysicalPointer, + non_default_implicit_conversion: None, }, typ: &ast::Type::Pointer(underlying_type, state_space), state_space: new_todo!(), @@ -2627,19 +2544,11 @@ fn insert_implicit_conversions( }) }, }; - insert_implicit_conversions_impl( - &mut result, - id_def, - visit_desc, - bitcast_physical_pointer, - )?; + insert_implicit_conversions_impl(&mut result, id_def, visit_desc)?; + } + Statement::RepackVector(repack) => { + insert_implicit_conversions_impl(&mut result, id_def, repack)?; } - Statement::RepackVector(repack) => insert_implicit_conversions_impl( - &mut result, - id_def, - repack, - should_bitcast_wrapper, - )?, s @ Statement::Conditional(_) | s @ Statement::Conversion(_) | s @ Statement::Label(_) @@ -2657,12 +2566,6 @@ fn insert_implicit_conversions_impl( func: &mut Vec, id_def: &mut MutableNumericIdResolver, stmt: impl Visitable, - default_conversion_fn: for<'a> fn( - &'a ast::Type, - ast::StateSpace, - &'a ast::Type, - ast::StateSpace, - ) -> Result, TranslateError>, ) -> Result<(), TranslateError> { let mut post_conv = Vec::new(); let statement = @@ -2673,27 +2576,13 @@ fn insert_implicit_conversions_impl( Some(t) => t, }; let (operand_type, operand_space) = id_def.get_typed(desc.op)?; - let mut conversion_fn = default_conversion_fn; - match desc.sema { - ArgumentSemantics::Default => {} - ArgumentSemantics::DefaultRelaxed => { - if desc.is_dst { - conversion_fn = should_convert_relaxed_dst_wrapper; - } else { - conversion_fn = should_convert_relaxed_src_wrapper; - } - } - ArgumentSemantics::PhysicalPointer => { - conversion_fn = bitcast_physical_pointer; - } - ArgumentSemantics::RegisterPointer => { - conversion_fn = bitcast_register_pointer; - } - ArgumentSemantics::Address => { - conversion_fn = force_bitcast_ptr_to_bit; - } - }; - match conversion_fn(&operand_type, operand_space, instr_type, instruction_space)? { + let conversion_fn = desc + .non_default_implicit_conversion + .unwrap_or(default_implicit_conversion); + match conversion_fn( + (operand_space, &operand_type), + (instruction_space, instr_type), + )? { Some(conv_kind) => { let conv_output = if desc.is_dst { &mut post_conv @@ -2721,8 +2610,6 @@ fn insert_implicit_conversions_impl( to_type, to_space, kind: conv_kind, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, })); result } @@ -2774,7 +2661,7 @@ fn emit_function_body_ops( match s { Statement::Label(_) => (), Statement::Call(call) => { - let (result_type, result_id) = match &*call.ret_params { + let (result_type, result_id) = match &*call.return_arguments { [(id, typ, space)] => ( map.get_or_add(builder, SpirvType::new(typ.clone())), Some(*id), @@ -2783,11 +2670,11 @@ fn emit_function_body_ops( _ => todo!(), }; let arg_list = call - .param_list + .input_arguments .iter() .map(|(id, _, _)| *id) .collect::>(); - builder.function_call(result_type, result_id, call.func, arg_list)?; + builder.function_call(result_type, result_id, call.name, arg_list)?; } Statement::Variable(var) => { emit_variable(builder, map, var)?; @@ -3863,8 +3750,6 @@ fn emit_cvt( )), to_space: ast::StateSpace::Reg, kind: ConversionKind::Default, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, }; emit_implicit_conversion(builder, map, &cv)?; new_dst @@ -4218,19 +4103,19 @@ fn emit_implicit_conversion( ) -> Result<(), TranslateError> { let from_parts = cv.from_type.to_parts(); let to_parts = cv.to_type.to_parts(); - match (from_parts.kind, to_parts.kind, cv.kind) { - (_, _, ConversionKind::PtrToBit(typ)) => { + match (from_parts.kind, to_parts.kind, &cv.kind) { + (_, _, &ConversionKind::PtrToBit(typ)) => { let dst_type = map.get_or_add_scalar(builder, typ.into()); builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } - (_, _, ConversionKind::BitToPtr) => { + (_, _, &ConversionKind::BitToPtr) => { let dst_type = map.get_or_add( builder, SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } - (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { if from_parts.width == to_parts.width { let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); if from_parts.scalar_kind != ast::ScalarKind::Float @@ -4282,35 +4167,29 @@ fn emit_implicit_conversion( to_type: cv.to_type.clone(), to_space: new_todo!(), kind: ConversionKind::Default, - src_sema: cv.src_sema, - dst_sema: cv.dst_sema, }, )?; } } } - (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => { + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => { let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.s_convert(result_type, Some(cv.dst), cv.src)?; } - (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) - | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) - | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { + (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default) + | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default) + | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => { let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } - (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { - let result_type = if spirv_ptr { - map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - spirv::StorageClass::Function, - ), - ) - } else { - map.get_or_add(builder, SpirvType::new(cv.to_type.clone())) - }; + (_, _, &ConversionKind::PtrToPtr) => { + let result_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + cv.to_space.to_spirv(), + ), + ); builder.bitcast(result_type, Some(cv.dst), cv.src)?; } _ => unreachable!(), @@ -4417,38 +4296,12 @@ fn expand_map_variables<'a, 'b>( i.map_variable(&mut |id| id_defs.get_id(id))?, ))), ast::Statement::Variable(var) => { - let mut var_type = ast::Type::from(var.var.v_type.clone()); - let mut is_variable = false; - var_type = match var.var.state_space { - ast::StateSpace::Reg => { - is_variable = true; - var_type - } - ast::StateSpace::Shared => { - // If it's a pointer it will be translated to a method parameter later - if let ast::Type::Pointer(..) = var_type { - is_variable = true; - var_type - } else { - var_type.param_pointer_to(ast::StateSpace::Shared)? - } - } - ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?, - ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?, - ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?, - ast::StateSpace::Const => new_todo!(), - ast::StateSpace::Generic => new_todo!(), - ast::StateSpace::Sreg => new_todo!(), - }; + let var_type = var.var.v_type.clone(); match var.count { Some(count) => { - for new_id in id_defs.add_defs( - var.var.name, - count, - var_type, - var.var.state_space, - is_variable, - ) { + for new_id in + id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true) + { result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4459,11 +4312,8 @@ fn expand_map_variables<'a, 'b>( } } None => { - let new_id = id_defs.add_def( - var.var.name, - Some((var_type, var.var.state_space)), - is_variable, - ); + let new_id = + id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true); result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4839,7 +4689,7 @@ fn convert_to_stateful_memory_access_postprocess( ast::StateSpace::Global, ), kind: ConversionKind::BitToPtr(ast::StateSpace::Global), - src_sema: ArgumentSemantics::Default, + src_ dst_sema: arg_desc.sema, })); converting_id @@ -4854,7 +4704,7 @@ fn convert_to_stateful_memory_access_postprocess( to_type: old_type, kind: ConversionKind::PtrToBit(ast::ScalarType::U64), src_sema: arg_desc.sema, - dst_sema: ArgumentSemantics::Default, + dst_ })); converting_id } @@ -4881,7 +4731,7 @@ fn convert_to_stateful_memory_access_postprocess( to_type: old_type_clone, kind: ConversionKind::PtrToPtr { spirv_ptr: false }, src_sema: arg_desc.sema, - dst_sema: ArgumentSemantics::Default, + dst_ })); converting_id } @@ -4889,7 +4739,6 @@ fn convert_to_stateful_memory_access_postprocess( }, }) } -*/ fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { match arg.dst { @@ -4917,6 +4766,7 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { _ => false, } } +*/ #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { @@ -5368,7 +5218,7 @@ impl ExpandedStatement { Statement::StoreVar(details) } Statement::Call(mut call) => { - for (id, _, space) in call.ret_params.iter_mut() { + for (id, _, space) in call.return_arguments.iter_mut() { let is_dst = match space { ast::StateSpace::Reg => true, ast::StateSpace::Param => false, @@ -5377,8 +5227,8 @@ impl ExpandedStatement { }; *id = f(*id, is_dst); } - call.func = f(call.func, false); - for (id, _, _) in call.param_list.iter_mut() { + call.name = f(call.name, false); + for (id, _, _) in call.input_arguments.iter_mut() { *id = f(*id, false); } Statement::Call(call) @@ -5461,7 +5311,6 @@ struct RepackVectorDetails { typ: ast::ScalarType, packed: spirv::Word, unpacked: Vec, - vector_sema: ArgumentSemantics, } impl RepackVectorDetails { @@ -5477,7 +5326,8 @@ impl RepackVectorDetails { ArgumentDescriptor { op: self.packed, is_dst: !self.is_extract, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Vector(self.typ, self.unpacked.len() as u8), @@ -5486,7 +5336,6 @@ impl RepackVectorDetails { )?; let scalar_type = self.typ; let is_extract = self.is_extract; - let vector_sema = self.vector_sema; let vector = self .unpacked .into_iter() @@ -5495,7 +5344,7 @@ impl RepackVectorDetails { ArgumentDescriptor { op: id, is_dst: is_extract, - sema: vector_sema, + non_default_implicit_conversion: None, }, Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), ) @@ -5506,7 +5355,6 @@ impl RepackVectorDetails { typ: self.typ, packed: scalar, unpacked: vector, - vector_sema, }) } } @@ -5524,18 +5372,18 @@ impl, U: ArgParamsEx> Visitab struct ResolvedCall { pub uniform: bool, - pub ret_params: Vec<(P::Id, ast::Type, ast::StateSpace)>, - pub func: P::Id, - pub param_list: Vec<(P::Operand, ast::Type, ast::StateSpace)>, + pub return_arguments: Vec<(P::Id, ast::Type, ast::StateSpace)>, + pub name: P::Id, + pub input_arguments: Vec<(P::Operand, ast::Type, ast::StateSpace)>, } impl ResolvedCall { fn cast>(self) -> ResolvedCall { ResolvedCall { uniform: self.uniform, - ret_params: self.ret_params, - func: self.func, - param_list: self.param_list, + return_arguments: self.return_arguments, + name: self.name, + input_arguments: self.input_arguments, } } } @@ -5546,14 +5394,14 @@ impl> ResolvedCall { visitor: &mut V, ) -> Result, TranslateError> { let ret_params = self - .ret_params + .return_arguments .into_iter() .map::, _>(|(id, typ, space)| { let new_id = visitor.id( ArgumentDescriptor { op: id, is_dst: space != ast::StateSpace::Param, - sema: space.semantics(), + non_default_implicit_conversion: None, }, Some((&typ, space)), )?; @@ -5562,21 +5410,22 @@ impl> ResolvedCall { .collect::, _>>()?; let func = visitor.id( ArgumentDescriptor { - op: self.func, + op: self.name, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, None, )?; let param_list = self - .param_list + .input_arguments .into_iter() .map::, _>(|(id, typ, space)| { let new_id = visitor.operand( ArgumentDescriptor { op: id, is_dst: false, - sema: space.semantics(), + non_default_implicit_conversion: None, }, &typ, space, @@ -5586,9 +5435,9 @@ impl> ResolvedCall { .collect::, _>>()?; Ok(ResolvedCall { uniform: self.uniform, - ret_params, - func, - param_list, + return_arguments: ret_params, + name: func, + input_arguments: param_list, }) } } @@ -5623,7 +5472,7 @@ impl> PtrAccess

{ ArgumentDescriptor { op: self.dst, is_dst: true, - sema, + non_default_implicit_conversion: None, }, Some((&ptr_type, self.state_space)), )?; @@ -5631,7 +5480,7 @@ impl> PtrAccess

{ ArgumentDescriptor { op: self.ptr_src, is_dst: false, - sema, + non_default_implicit_conversion: None, }, Some((&ptr_type, self.state_space)), )?; @@ -5639,7 +5488,8 @@ impl> PtrAccess

{ ArgumentDescriptor { op: self.offset_src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::S64), self.state_space, @@ -5816,7 +5666,12 @@ where pub struct ArgumentDescriptor { op: Op, is_dst: bool, - sema: ArgumentSemantics, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, } pub struct PtrAccess { @@ -5846,7 +5701,7 @@ impl ArgumentDescriptor { ArgumentDescriptor { op: u, is_dst: self.is_dst, - sema: self.sema, + non_default_implicit_conversion: None, } } } @@ -6085,7 +5940,7 @@ impl ImplicitConversion { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: self.dst_sema, + non_default_implicit_conversion: None, }, Some((&self.to_type, self.to_space)), )?; @@ -6093,7 +5948,7 @@ impl ImplicitConversion { ArgumentDescriptor { op: self.src, is_dst: false, - sema: self.src_sema, + non_default_implicit_conversion: None, }, Some((&self.from_type, self.from_space)), )?; @@ -6396,18 +6251,16 @@ struct ImplicitConversion { from_space: ast::StateSpace, to_space: ast::StateSpace, kind: ConversionKind, - src_sema: ArgumentSemantics, - dst_sema: ArgumentSemantics, } -#[derive(PartialEq, Copy, Clone)] +#[derive(PartialEq, Clone)] enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types SignExtend, BitToPtr, PtrToBit(ast::ScalarType), - PtrToPtr { spirv_ptr: bool }, + PtrToPtr, } impl ast::PredAt { @@ -6461,7 +6314,8 @@ impl ast::Arg1 { ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, )?; @@ -6478,7 +6332,8 @@ impl ast::Arg1Bar { ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), ast::StateSpace::Reg, @@ -6497,7 +6352,8 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6506,7 +6362,8 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6527,7 +6384,8 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, dst_t, ast::StateSpace::Reg, @@ -6536,7 +6394,8 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, src_t, ast::StateSpace::Reg, @@ -6555,7 +6414,7 @@ impl ast::Arg2Ld { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::DefaultRelaxed, + non_default_implicit_conversion: None, }, &ast::Type::from(details.typ.clone()), ast::StateSpace::Reg, @@ -6566,11 +6425,7 @@ impl ast::Arg2Ld { ArgumentDescriptor { op: self.src, is_dst: false, - sema: if is_logical_ptr { - ArgumentSemantics::RegisterPointer - } else { - ArgumentSemantics::PhysicalPointer - }, + non_default_implicit_conversion: None, }, &details.typ, details.state_space, @@ -6591,11 +6446,7 @@ impl ast::Arg2St { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: if is_logical_ptr { - ArgumentSemantics::RegisterPointer - } else { - ArgumentSemantics::PhysicalPointer - }, + non_default_implicit_conversion: None, }, &details.typ, details.state_space, @@ -6604,7 +6455,7 @@ impl ast::Arg2St { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::DefaultRelaxed, + non_default_implicit_conversion: None, }, &details.typ.clone().into(), ast::StateSpace::Reg, @@ -6623,7 +6474,8 @@ impl ast::Arg2Mov { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &details.typ.clone().into(), ast::StateSpace::Reg, @@ -6632,11 +6484,7 @@ impl ast::Arg2Mov { ArgumentDescriptor { op: self.src, is_dst: false, - sema: if details.src_is_address { - ArgumentSemantics::Address - } else { - ArgumentSemantics::Default - }, + non_default_implicit_conversion: None, }, &details.typ.clone().into(), ast::StateSpace::Reg, @@ -6661,7 +6509,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, wide_type.as_ref().unwrap_or(typ), ast::StateSpace::Reg, @@ -6670,7 +6519,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, typ, ast::StateSpace::Reg, @@ -6679,7 +6529,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, typ, ast::StateSpace::Reg, @@ -6696,7 +6547,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6705,7 +6557,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6714,7 +6567,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), ast::StateSpace::Reg, @@ -6733,7 +6587,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6742,7 +6597,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::PhysicalPointer, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), state_space, @@ -6751,7 +6606,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6776,7 +6632,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, wide_type.as_ref().unwrap_or(t), ast::StateSpace::Reg, @@ -6785,7 +6642,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6794,7 +6652,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6803,7 +6662,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6825,7 +6684,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), ast::StateSpace::Reg, @@ -6834,7 +6693,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), ast::StateSpace::Reg, @@ -6843,7 +6702,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), ast::StateSpace::Reg, @@ -6852,7 +6711,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::Pred), ast::StateSpace::Reg, @@ -6876,7 +6735,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6885,7 +6745,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::PhysicalPointer, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6894,7 +6754,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6903,7 +6764,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6925,7 +6787,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, typ, ast::StateSpace::Reg, @@ -6934,7 +6797,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, typ, ast::StateSpace::Reg, @@ -6944,7 +6808,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &u32_type, ast::StateSpace::Reg, @@ -6953,7 +6818,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &u32_type, ast::StateSpace::Reg, @@ -6977,7 +6843,8 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.dst1, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Scalar(ast::ScalarType::Pred), @@ -6991,7 +6858,8 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: dst2, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Scalar(ast::ScalarType::Pred), @@ -7004,7 +6872,8 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -7013,7 +6882,8 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -7037,7 +6907,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, base_type, ast::StateSpace::Reg, @@ -7046,7 +6917,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, base_type, ast::StateSpace::Reg, @@ -7055,7 +6927,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, base_type, ast::StateSpace::Reg, @@ -7064,7 +6937,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), ast::StateSpace::Reg, @@ -7073,7 +6947,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src4, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), ast::StateSpace::Reg, @@ -7098,7 +6973,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.dst1, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Scalar(ast::ScalarType::Pred), @@ -7112,7 +6988,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: dst2, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Scalar(ast::ScalarType::Pred), @@ -7125,7 +7002,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -7134,7 +7012,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -7143,7 +7022,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::Pred), ast::StateSpace::Reg, @@ -7287,6 +7167,12 @@ impl ast::StateSpace { ast::StateSpace::Sreg => spirv::StorageClass::Input, } } + + fn is_compatible(self, other: ast::StateSpace) -> bool { + self == other + || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg + || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg + } } impl ast::Operand { @@ -7342,54 +7228,89 @@ impl ast::StateSpace { } } -fn bitcast_register_pointer( - operand_type: &ast::Type, - operand_space: ast::StateSpace, - instr_type: &ast::Type, - instruction_space: ast::StateSpace, +fn default_implicit_conversion( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - bitcast_physical_pointer(operand_type, operand_space, instr_type, instruction_space) + if !instruction_space.is_compatible(operand_space) { + default_implicit_conversion_space( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) + } else if instruction_type != operand_type { + default_implicit_conversion_type(instruction_space, operand_type, instruction_type) + } else { + Ok(None) + } } -fn bitcast_physical_pointer( - operand_type: &ast::Type, - operand_space: ast::StateSpace, - instruction_type: &ast::Type, - instruction_space: ast::StateSpace, +// Space is different +fn default_implicit_conversion_space( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if operand_space == instruction_space { - if operand_type != instruction_type { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } else { - Ok(None) - } - } else { - match operand_space { - ast::StateSpace::Reg | ast::StateSpace::Sreg => match instruction_space { - ast::StateSpace::Generic - | ast::StateSpace::Global - | ast::StateSpace::Shared - | ast::StateSpace::Local => Ok(Some(ConversionKind::BitToPtr)), + if operand_space.is_compatible(ast::StateSpace::Reg) { + match operand_type { + ast::Type::Pointer(operand_ptr_type, operand_ptr_space) + if *operand_ptr_space == instruction_space => + { + if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + // TODO: 32 bit + ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { + ast::StateSpace::Global + | ast::StateSpace::Generic + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(TranslateError::MismatchedType), + }, + ast::Type::Scalar(ast::ScalarType::B32) + | ast::Type::Scalar(ast::ScalarType::U32) + | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { + ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + Ok(Some(ConversionKind::BitToPtr)) + } _ => Err(TranslateError::MismatchedType), }, _ => Err(TranslateError::MismatchedType), } + } else if instruction_space.is_compatible(ast::StateSpace::Reg) { + if let ast::Type::Pointer(instr_ptr_type, instr_ptr_space) = instruction_type { + if operand_space != *instr_ptr_space { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } else { + Err(TranslateError::MismatchedType) + } + } else { + Err(TranslateError::MismatchedType) } } -fn force_bitcast_ptr_to_bit( - _: &ast::Type, - _: ast::StateSpace, - instr_type: &ast::Type, - _: ast::StateSpace, +// Space is same, but type is different +fn default_implicit_conversion_type( + space: ast::StateSpace, + operand_type: &ast::Type, + instruction_type: &ast::Type, ) -> Result, TranslateError> { - // TODO: verify this on f32, u16 and the like - if let ast::Type::Scalar(scalar_t) = instr_type { - if let Ok(int_type) = (*scalar_t).try_into() { - return Ok(Some(ConversionKind::PtrToBit(int_type))); + if space.is_compatible(ast::StateSpace::Reg) { + if should_bitcast(instruction_type, operand_type) { + Ok(Some(ConversionKind::Default)) + } else { + Err(TranslateError::MismatchedType) } + } else { + Ok(Some(ConversionKind::PtrToPtr)) } - Err(TranslateError::MismatchedType) } fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { @@ -7421,22 +7342,26 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { } } -fn should_bitcast_packed( - operand: &ast::Type, - operand_space: ast::StateSpace, - instruction: &ast::Type, - instruction_space: ast::StateSpace, +fn implicit_conversion_mov( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = - (operand, instruction) - { - if scalar.kind() == ast::ScalarKind::Bit - && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + // instruction_space is always reg + if operand_space.is_compatible(ast::StateSpace::Reg) { + if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = + (operand_type, instruction_type) { - return Ok(Some(ConversionKind::Default)); + if scalar.kind() == ast::ScalarKind::Bit + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } } } - should_bitcast_wrapper(operand, operand_space, instruction, instruction_space) + default_implicit_conversion( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) } fn should_bitcast_wrapper(