diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 4f120de..eb3887b 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -743,7 +743,7 @@ impl<'a> Kernel<'a> { check!(sys::zeKernelSetArgumentValue( self.0, index, - mem::size_of::(), + mem::size_of::<*const ()>(), &ptr as *const _ as *const _, )); Ok(()) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index ec49925..a2c6d66 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -268,6 +268,7 @@ pub struct Arg5 { pub src3: P::Operand, } +#[derive(Copy, Clone)] pub enum Operand { Reg(ID), RegOffset(ID, i32), @@ -353,6 +354,7 @@ pub struct MulFloatDesc { pub saturate: bool, } +#[derive(PartialEq, Eq, Copy, Clone)] pub enum RoundingMode { NearestEven, Zero, diff --git a/ptx/src/test/spirv_run/cvt_sat_s_u.ptx b/ptx/src/test/spirv_run/cvt_sat_s_u.ptx new file mode 100644 index 0000000..ef0a10f --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_sat_s_u.ptx @@ -0,0 +1,24 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry cvt_sat_s_u( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp; + .reg .u32 temp2; + .reg .s32 temp3; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 temp, [in_addr]; + cvt.sat.u32.s32 temp2, temp; + cvt.s32.u32 temp3, temp2; + st.s32 [out_addr], temp3; + ret; +} diff --git a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt new file mode 100644 index 0000000..afd2864 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt @@ -0,0 +1,43 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %5 "cvt_sat_s_u" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %4 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %5 = OpFunction %void None %4 + %6 = OpFunctionParameter %ulong + %7 = OpFunctionParameter %ulong + %23 = OpLabel + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_ulong Function + %10 = OpVariable %_ptr_Function_uint Function + %11 = OpVariable %_ptr_Function_uint Function + %12 = OpVariable %_ptr_Function_uint Function + OpStore %8 %6 + OpStore %9 %7 + %14 = OpLoad %ulong %8 + %21 = OpConvertUToPtr %_ptr_Generic_uint %14 + %13 = OpLoad %uint %21 + OpStore %10 %13 + %16 = OpLoad %uint %10 + %15 = OpSatConvertSToU %uint %16 + OpStore %11 %15 + %18 = OpLoad %uint %11 + %17 = OpBitcast %uint %18 + OpStore %12 %17 + %19 = OpLoad %ulong %9 + %20 = OpLoad %uint %12 + %22 = OpConvertUToPtr %_ptr_Generic_uint %19 + OpStore %22 %20 + OpReturn + OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 14a48be..e1e5c32 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,7 +1,7 @@ use crate::ptx; use crate::translate; use rspirv::{ - binary::Assemble, + binary::{Assemble, Disassemble}, dr::{Block, Function, Instruction, Loader, Operand}, }; use spirv_headers::Word; @@ -48,6 +48,7 @@ test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); test_ptx!(bra, [10u64], [11u64]); test_ptx!(not, [0u64], [u64::max_value()]); test_ptx!(shl, [11u64], [44u64]); +test_ptx!(cvt_sat_s_u, [0i32], [0i32]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/not.spvtxt b/ptx/src/test/spirv_run/not.spvtxt index 518e995..84482d9 100644 --- a/ptx/src/test/spirv_run/not.spvtxt +++ b/ptx/src/test/spirv_run/not.spvtxt @@ -12,7 +12,6 @@ %4 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_0 = OpTypeInt 64 0 %5 = OpFunction %void None %4 %6 = OpFunctionParameter %ulong %7 = OpFunctionParameter %ulong @@ -27,8 +26,8 @@ %18 = OpConvertUToPtr %_ptr_Generic_ulong %13 %12 = OpLoad %ulong %18 OpStore %10 %12 - %15 = OpLoad %ulong_0 %10 - %14 = OpNot %ulong_0 %15 + %15 = OpLoad %ulong %10 + %14 = OpNot %ulong %15 OpStore %11 %14 %16 = OpLoad %ulong %9 %17 = OpLoad %ulong %11 @@ -36,4 +35,3 @@ OpStore %19 %17 OpReturn OpFunctionEnd - \ No newline at end of file diff --git a/ptx/src/test/spirv_run/setp.spvtxt b/ptx/src/test/spirv_run/setp.spvtxt index 22e7b54..064cd97 100644 --- a/ptx/src/test/spirv_run/setp.spvtxt +++ b/ptx/src/test/spirv_run/setp.spvtxt @@ -20,7 +20,7 @@ %5 = OpFunction %void None %4 %6 = OpFunctionParameter %ulong %7 = OpFunctionParameter %ulong - %38 = OpLabel + %39 = OpLabel %8 = OpVariable %_ptr_Function_ulong Function %9 = OpVariable %_ptr_Function_ulong Function %10 = OpVariable %_ptr_Function_ulong Function @@ -34,9 +34,10 @@ %18 = OpLoad %ulong %35 OpStore %10 %18 %21 = OpLoad %ulong %8 - %32 = OpIAdd %ulong %21 %ulong_8 - %36 = OpConvertUToPtr %_ptr_Generic_ulong %32 - %20 = OpLoad %ulong %36 + %36 = OpCopyObject %ulong %21 + %32 = OpIAdd %ulong %36 %ulong_8 + %37 = OpConvertUToPtr %_ptr_Generic_ulong %32 + %20 = OpLoad %ulong %37 OpStore %11 %20 %23 = OpLoad %ulong %10 %24 = OpLoad %ulong %11 @@ -58,8 +59,7 @@ %17 = OpLabel %29 = OpLoad %ulong %9 %30 = OpLoad %ulong %12 - %37 = OpConvertUToPtr %_ptr_Generic_ulong %29 - OpStore %37 %30 + %38 = OpConvertUToPtr %_ptr_Generic_ulong %29 + OpStore %38 %30 OpReturn OpFunctionEnd - \ No newline at end of file diff --git a/ptx/src/test/spirv_run/shl.spvtxt b/ptx/src/test/spirv_run/shl.spvtxt index 131bd9e..3e57fc3 100644 --- a/ptx/src/test/spirv_run/shl.spvtxt +++ b/ptx/src/test/spirv_run/shl.spvtxt @@ -12,7 +12,6 @@ %4 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_0 = OpTypeInt 64 0 %uint = OpTypeInt 32 0 %uint_2 = OpConstant %uint 2 %5 = OpFunction %void None %4 @@ -29,8 +28,8 @@ %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 %12 = OpLoad %ulong %19 OpStore %10 %12 - %15 = OpLoad %ulong_0 %10 - %14 = OpShiftLeftLogical %ulong_0 %15 %uint_2 + %15 = OpLoad %ulong %10 + %14 = OpShiftLeftLogical %ulong %15 %uint_2 OpStore %11 %14 %16 = OpLoad %ulong %9 %17 = OpLoad %ulong %11 @@ -38,4 +37,3 @@ OpStore %20 %17 OpReturn OpFunctionEnd - \ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 9e51046..511ef72 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -7,25 +7,85 @@ use rspirv::binary::Assemble; #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { - Base(ast::ScalarType), - Extended(ast::ExtendedScalarType), - Pointer(ast::Type, spirv::StorageClass), + Base(SpirvScalarKey), + Pointer(SpirvScalarKey, spirv::StorageClass), +} + +impl SpirvType { + fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self { + let key = match t { + ast::Type::Scalar(typ) => SpirvScalarKey::from(typ), + ast::Type::ExtendedScalar(typ) => SpirvScalarKey::from(typ), + }; + SpirvType::Pointer(key, sc) + } } impl From for SpirvType { fn from(t: ast::Type) -> Self { match t { - ast::Type::Scalar(t) => SpirvType::Base(t), - ast::Type::ExtendedScalar(t) => SpirvType::Extended(t), + ast::Type::Scalar(t) => SpirvType::Base(t.into()), + ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()), } } } +impl From for SpirvType { + fn from(t: ast::ScalarType) -> Self { + SpirvType::Base(t.into()) + } +} + struct TypeWordMap { void: spirv::Word, complex: HashMap, } +// SPIR-V integer type definitions are signless, more below: +// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers +// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +enum SpirvScalarKey { + B8, + B16, + B32, + B64, + F16, + F32, + F64, + Pred, + F16x2, +} + +impl From for SpirvScalarKey { + fn from(t: ast::ScalarType) -> Self { + match t { + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8, + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { + SpirvScalarKey::B16 + } + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => { + SpirvScalarKey::B32 + } + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => { + SpirvScalarKey::B64 + } + ast::ScalarType::F16 => SpirvScalarKey::F16, + ast::ScalarType::F32 => SpirvScalarKey::F32, + ast::ScalarType::F64 => SpirvScalarKey::F64, + } + } +} + +impl From for SpirvScalarKey { + fn from(t: ast::ExtendedScalarType) -> Self { + match t { + ast::ExtendedScalarType::Pred => SpirvScalarKey::Pred, + ast::ExtendedScalarType::F16x2 => SpirvScalarKey::F16x2, + } + } +} + impl TypeWordMap { fn new(b: &mut dr::Builder) -> TypeWordMap { let void = b.type_void(); @@ -40,21 +100,24 @@ impl TypeWordMap { } fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word { + let key: SpirvScalarKey = t.into(); + self.get_or_add_spirv_scalar(b, key) + } + + fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> spirv::Word { *self .complex - .entry(SpirvType::Base(t)) - .or_insert_with(|| match t { - ast::ScalarType::B8 | ast::ScalarType::U8 => b.type_int(8, 0), - ast::ScalarType::B16 | ast::ScalarType::U16 => b.type_int(16, 0), - ast::ScalarType::B32 | ast::ScalarType::U32 => b.type_int(32, 0), - ast::ScalarType::B64 | ast::ScalarType::U64 => b.type_int(64, 0), - ast::ScalarType::S8 => b.type_int(8, 1), - ast::ScalarType::S16 => b.type_int(16, 1), - ast::ScalarType::S32 => b.type_int(32, 1), - ast::ScalarType::S64 => b.type_int(64, 1), - ast::ScalarType::F16 => b.type_float(16), - ast::ScalarType::F32 => b.type_float(32), - ast::ScalarType::F64 => b.type_float(64), + .entry(SpirvType::Base(key)) + .or_insert_with(|| match key { + SpirvScalarKey::B8 => b.type_int(8, 0), + SpirvScalarKey::B16 => b.type_int(16, 0), + SpirvScalarKey::B32 => b.type_int(32, 0), + SpirvScalarKey::B64 => b.type_int(64, 0), + SpirvScalarKey::F16 => b.type_float(16), + SpirvScalarKey::F32 => b.type_float(32), + SpirvScalarKey::F64 => b.type_float(64), + SpirvScalarKey::Pred => b.type_bool(), + SpirvScalarKey::F16x2 => todo!(), }) } @@ -63,24 +126,15 @@ impl TypeWordMap { b: &mut dr::Builder, t: ast::ExtendedScalarType, ) -> spirv::Word { - *self - .complex - .entry(SpirvType::Extended(t)) - .or_insert_with(|| match t { - ast::ExtendedScalarType::Pred => b.type_bool(), - ast::ExtendedScalarType::F16x2 => todo!(), - }) + let key: SpirvScalarKey = t.into(); + self.get_or_add_spirv_scalar(b, key) } fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { match t { - SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar), - SpirvType::Extended(ext) => self.get_or_add_extended(b, ext), - SpirvType::Pointer(typ, storage) => { - let base = match typ { - ast::Type::Scalar(scalar) => self.get_or_add_scalar(b, scalar), - ast::Type::ExtendedScalar(ext) => self.get_or_add_extended(b, ext), - }; + SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), + SpirvType::Pointer(typ, mut storage) => { + let base = self.get_or_add_spirv_scalar(b, typ); *self .complex .entry(t) @@ -102,7 +156,7 @@ impl TypeWordMap { pub fn to_spirv_module(ast: ast::Module) -> Result { let mut builder = dr::Builder::new(); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module - builder.set_version(1, 0); + builder.set_version(1, 3); emit_capabilities(&mut builder); emit_extensions(&mut builder); let opencl_id = emit_opencl_import(&mut builder); @@ -277,24 +331,25 @@ fn insert_mem_ssa_statements( } inst => { let mut post_statements = Vec::new(); - let inst = inst.visit_variable(&mut |id, is_dst, id_type| { - let id_type = match id_type { - Some(t) => t, - None => return id, + let inst = inst.visit_variable(&mut |desc| { + let id_type = match (desc.typ, desc.is_pointer) { + (Some(t), false) => t, + (Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64), + (None, _) => return desc.op, }; let generated_id = id_def.new_id(Some(id_type)); - if !is_dst { + if !desc.is_dst { result.push(Statement::LoadVar( Arg2 { dst: generated_id, - src: id, + src: desc.op, }, id_type, )); } else { post_statements.push(Statement::StoreVar( Arg2St { - src1: id, + src1: desc.op, src2: generated_id, }, id_type, @@ -365,15 +420,15 @@ impl<'a> FlattenArguments<'a> { } impl<'a> ArgumentMapVisitor for FlattenArguments<'a> { - fn dst_variable(&mut self, x: spirv::Word, _: Option) -> spirv::Word { - x + fn dst_variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + desc.op } - fn src_operand(&mut self, op: ast::Operand, t: Option) -> spirv::Word { - match op { + fn src_operand(&mut self, desc: ArgumentDescriptor>) -> spirv::Word { + match desc.op { ast::Operand::Reg(r) => r, ast::Operand::Imm(x) => { - if let Some(typ) = t { + if let Some(typ) = desc.typ { let scalar_t = if let ast::Type::Scalar(scalar) = typ { scalar } else { @@ -391,7 +446,7 @@ impl<'a> ArgumentMapVisitor for FlattenA } } ast::Operand::RegOffset(reg, offset) => { - if let Some(typ) = t { + if let Some(typ) = desc.typ { let scalar_t = if let ast::Type::Scalar(scalar) = typ { scalar } else { @@ -403,7 +458,7 @@ impl<'a> ArgumentMapVisitor for FlattenA typ: scalar_t, value: offset as i128, })); - let result_id = self.id_def.new_id(t); + let result_id = self.id_def.new_id(desc.typ); let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); self.func.push(Statement::Instruction( ast::Instruction::::Add( @@ -428,11 +483,10 @@ impl<'a> ArgumentMapVisitor for FlattenA fn src_mov_operand( &mut self, - op: ast::MovOperand, - t: Option, + desc: ArgumentDescriptor>, ) -> spirv::Word { - match op { - ast::MovOperand::Op(opr) => self.src_operand(opr, t), + match &desc.op { + ast::MovOperand::Op(opr) => self.src_operand(desc.new_op(*opr)), ast::MovOperand::Vec(_, _) => todo!(), } } @@ -517,7 +571,7 @@ fn get_function_type( map: &mut TypeWordMap, args: &[ast::Argument], ) -> spirv::Word { - map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::Base(arg.a_type))) + map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type))) } fn emit_function_args( @@ -565,7 +619,7 @@ fn emit_function_body_ops( Statement::Variable(id, typ, ss) => { let type_id = map.get_or_add( builder, - SpirvType::Pointer(*typ, spirv::StorageClass::Function), + SpirvType::new_pointer(*typ, spirv::StorageClass::Function), ); if *ss != ast::StateSpace::Reg { todo!() @@ -672,7 +726,10 @@ fn emit_function_body_ops( let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?; } - _ => todo!(), + ast::Instruction::Cvt(dets, arg) => { + emit_cvt(builder, map, opencl, dets, arg)?; + } + ast::Instruction::SetpBool(_, _) => todo!(), }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(*typ)); @@ -686,6 +743,133 @@ fn emit_function_body_ops( Ok(()) } +fn emit_cvt( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + dets: &ast::CvtDetails, + arg: &ast::Arg2, +) -> Result<(), dr::Error> { + match dets { + ast::CvtDetails::FloatFromFloat(desc) => { + if desc.dst == desc.src { + return Ok(()); + } + if desc.saturate || desc.flush_to_zero { + todo!() + } + let dest_t: ast::Type = desc.dst.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + builder.f_convert(result_type, Some(arg.dst), arg.src)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); + } + ast::CvtDetails::FloatFromInt(desc) => { + if desc.saturate || desc.flush_to_zero { + todo!() + } + let dest_t: ast::Type = desc.dst.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + if desc.src.is_signed() { + builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?; + } else { + builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?; + } + emit_rounding_decoration(builder, arg.dst, desc.rounding); + } + ast::CvtDetails::IntFromFloat(desc) => { + if desc.flush_to_zero { + todo!() + } + let dest_t: ast::ScalarType = desc.dst.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + if desc.dst.is_signed() { + builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?; + } else { + builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?; + } + emit_rounding_decoration(builder, arg.dst, desc.rounding); + emit_saturating_decoration(builder, arg.dst, desc.saturate); + } + ast::CvtDetails::IntFromInt(desc) => { + if desc.dst == desc.src { + return Ok(()); + } + let dest_t: ast::ScalarType = desc.dst.into(); + let src_t: ast::ScalarType = desc.src.into(); + // first do shortening/widening + let src = if desc.dst.width() != desc.src.width() { + let new_dst = if dest_t.kind() == src_t.kind() { + arg.dst + } else { + builder.id() + }; + let cv = ImplicitConversion { + src: arg.src, + dst: new_dst, + from: ast::Type::Scalar(src_t), + to: ast::Type::Scalar(ast::ScalarType::from_parts( + dest_t.width(), + src_t.kind(), + )), + kind: ConversionKind::Default, + }; + emit_implicit_conversion(builder, map, &cv)?; + new_dst + } else { + arg.src + }; + if dest_t.kind() == src_t.kind() { + return Ok(()); + } + // now do actual conversion + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + if desc.saturate { + if desc.dst.is_signed() { + builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?; + } else { + builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?; + } + } else { + builder.bitcast(result_type, Some(arg.dst), src)?; + } + } + _ => todo!(), + } + Ok(()) +} + +fn emit_saturating_decoration(builder: &mut dr::Builder, dst: u32, saturate: bool) { + if saturate { + builder.decorate(dst, spirv::Decoration::SaturatedConversion, []); + } +} + +fn emit_rounding_decoration( + builder: &mut dr::Builder, + dst: spirv::Word, + rounding: Option, +) { + if let Some(rounding) = rounding { + builder.decorate( + dst, + spirv::Decoration::FPRoundingMode, + [rounding.to_spirv()], + ); + } +} + +impl ast::RoundingMode { + fn to_spirv(self) -> rspirv::dr::Operand { + let mode = match self { + ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, + ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, + ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, + ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, + }; + rspirv::dr::Operand::FPRoundingMode(mode) + } +} + fn emit_setp( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -695,7 +879,7 @@ fn emit_setp( if setp.flush_to_zero { todo!() } - let result_type = map.get_or_add(builder, SpirvType::Extended(ast::ExtendedScalarType::Pred)); + let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)); let result_id = Some(arg.dst1); let operand_1 = arg.src1; let operand_2 = arg.src2; @@ -768,7 +952,7 @@ fn emit_mul_int( desc: &ast::MulIntDesc, arg: &ast::Arg3, ) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::Base(desc.typ.into())); + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); match desc.control { ast::MulIntControl::Low => { builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; @@ -798,7 +982,7 @@ fn emit_add_int( ctr: &ast::AddIntDesc, arg: &ast::Arg3, ) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::Base(ctr.typ.into())); + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(ctr.typ))); builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; Ok(()) } @@ -817,7 +1001,7 @@ fn emit_implicit_conversion( let dst_type = map.get_or_add( builder, SpirvType::Pointer( - ast::Type::Scalar(to_type), + SpirvScalarKey::from(to_type), spirv_headers::StorageClass::Generic, ), ); @@ -826,14 +1010,12 @@ fn emit_implicit_conversion( ConversionKind::Default => { if from_type.width() == to_type.width() { let dst_type = map.get_or_add_scalar(builder, to_type); - if from_type.kind() == ScalarKind::Unsigned && to_type.kind() == ScalarKind::Byte - || from_type.kind() == ScalarKind::Byte - && to_type.kind() == ScalarKind::Unsigned - { + if from_type.kind() != ScalarKind::Float && to_type.kind() != ScalarKind::Float { // It is noop, but another instruction expects result of this conversion builder.copy_object(dst_type, Some(cv.dst), cv.src)?; + } else { + builder.bitcast(dst_type, Some(cv.dst), cv.src)?; } - builder.bitcast(dst_type, Some(cv.dst), cv.src)?; } else { let as_unsigned_type = map.get_or_add_scalar( builder, @@ -1057,23 +1239,23 @@ impl ast::ArgParams for ExpandedArgParams { } trait ArgumentMapVisitor { - fn dst_variable(&mut self, v: T::ID, typ: Option) -> U::ID; - fn src_operand(&mut self, o: T::Operand, typ: Option) -> U::Operand; - fn src_mov_operand(&mut self, o: T::MovOperand, typ: Option) -> U::MovOperand; + fn dst_variable(&mut self, desc: ArgumentDescriptor) -> U::ID; + fn src_operand(&mut self, desc: ArgumentDescriptor) -> U::Operand; + fn src_mov_operand(&mut self, desc: ArgumentDescriptor) -> U::MovOperand; } impl ArgumentMapVisitor for T where - T: FnMut(spirv::Word, bool, Option) -> spirv::Word, + T: FnMut(ArgumentDescriptor) -> spirv::Word, { - fn dst_variable(&mut self, x: spirv::Word, t: Option) -> spirv::Word { - self(x, t.is_some(), t) + fn dst_variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + self(desc) } - fn src_operand(&mut self, x: spirv::Word, t: Option) -> spirv::Word { - self(x, false, t) + fn src_operand(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + self(desc) } - fn src_mov_operand(&mut self, x: spirv::Word, t: Option) -> spirv::Word { - self(x, false, t) + fn src_mov_operand(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + self(desc) } } @@ -1081,16 +1263,15 @@ impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> fo where T: FnMut(&str) -> spirv::Word, { - fn dst_variable(&mut self, x: &str, _: Option) -> spirv::Word { - self(x) + fn dst_variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word { + self(desc.op) } fn src_operand( &mut self, - x: ast::Operand<&str>, - _: Option, + desc: ArgumentDescriptor>, ) -> ast::Operand { - match x { + match desc.op { ast::Operand::Reg(id) => ast::Operand::Reg(self(id)), ast::Operand::Imm(imm) => ast::Operand::Imm(imm), ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id), imm), @@ -1099,16 +1280,33 @@ where fn src_mov_operand( &mut self, - x: ast::MovOperand<&str>, - t: Option, + desc: ArgumentDescriptor>, ) -> ast::MovOperand { - match x { - ast::MovOperand::Op(op) => ast::MovOperand::Op(self.src_operand(op, t)), + match desc.op { + ast::MovOperand::Op(op) => ast::MovOperand::Op(self.src_operand(desc.new_op(op))), ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2), } } } +struct ArgumentDescriptor { + op: T, + is_dst: bool, + typ: Option, + is_pointer: bool, +} + +impl ArgumentDescriptor { + fn new_op(&self, u: U) -> ArgumentDescriptor { + ArgumentDescriptor { + op: u, + is_dst: self.is_dst, + typ: self.typ, + is_pointer: self.is_pointer, + } + } +} + impl ast::Instruction { fn map>( self, @@ -1117,7 +1315,7 @@ impl ast::Instruction { match self { ast::Instruction::Ld(d, a) => { let inst_type = d.typ; - ast::Instruction::Ld(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) + ast::Instruction::Ld(d, a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)))) } ast::Instruction::Mov(d, a) => { let inst_type = d.typ; @@ -1142,7 +1340,22 @@ impl ast::Instruction { ast::Instruction::Not(t, a) => { ast::Instruction::Not(t, a.map(visitor, Some(t.to_type()))) } - ast::Instruction::Cvt(_, _) => todo!(), + ast::Instruction::Cvt(d, a) => { + let (dst_t, src_t) = match &d { + ast::CvtDetails::FloatFromFloat(desc) => (desc.dst.into(), desc.src.into()), + ast::CvtDetails::FloatFromInt(desc) => { + (desc.dst.into(), ast::Type::Scalar(desc.src.into())) + } + ast::CvtDetails::IntFromFloat(desc) => { + (ast::Type::Scalar(desc.dst.into()), desc.src.into()) + } + ast::CvtDetails::IntFromInt(desc) => ( + ast::Type::Scalar(desc.dst.into()), + ast::Type::Scalar(desc.src.into()), + ), + }; + ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t)) + } ast::Instruction::Shl(t, a) => { ast::Instruction::Shl(t, a.map_shift(visitor, Some(t.to_type()))) } @@ -1157,7 +1370,7 @@ impl ast::Instruction { } impl ast::Instruction { - fn visit_variable) -> spirv::Word>( + fn visit_variable) -> spirv::Word>( self, f: &mut F, ) -> ast::Instruction { @@ -1167,34 +1380,34 @@ impl ast::Instruction { impl ArgumentMapVisitor for T where - T: FnMut(spirv::Word, bool, Option) -> spirv::Word, + T: FnMut(ArgumentDescriptor) -> spirv::Word, { - fn dst_variable(&mut self, x: spirv::Word, t: Option) -> spirv::Word { - self(x, t.is_some(), t) + fn dst_variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + self(desc) } fn src_operand( &mut self, - x: ast::Operand, - t: Option, + desc: ArgumentDescriptor>, ) -> ast::Operand { - match x { - ast::Operand::Reg(id) => ast::Operand::Reg(self(id, false, t)), + match desc.op { + ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id))), ast::Operand::Imm(imm) => ast::Operand::Imm(imm), - ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id, false, t), imm), + ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(desc.new_op(id)), imm), } } fn src_mov_operand( &mut self, - x: ast::MovOperand, - t: Option, + desc: ArgumentDescriptor>, ) -> ast::MovOperand { - match x { + match desc.op { ast::MovOperand::Op(op) => ast::MovOperand::Op(ArgumentMapVisitor::< NormalizedArgParams, NormalizedArgParams, - >::src_operand(self, op, t)), + >::src_operand( + self, desc.new_op(op) + )), ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2), } } @@ -1202,8 +1415,8 @@ where fn reduced_visitor<'a>( f: &'a mut impl FnMut(spirv::Word) -> spirv::Word, -) -> impl FnMut(spirv::Word, bool, Option) -> spirv::Word + 'a { - move |id, _, _| f(id) +) -> impl FnMut(ArgumentDescriptor) -> spirv::Word + 'a { + move |desc| f(desc.op) } impl ast::Instruction { @@ -1212,7 +1425,7 @@ impl ast::Instruction { self.map(&mut visitor) } - fn visit_variable_extended) -> spirv::Word>( + fn visit_variable_extended) -> spirv::Word>( self, f: &mut F, ) -> Self { @@ -1326,7 +1539,12 @@ impl ast::Arg1 { t: Option, ) -> ast::Arg1 { ast::Arg1 { - src: visitor.dst_variable(self.src, t), + src: visitor.dst_variable(ArgumentDescriptor { + op: self.src, + typ: t, + is_dst: false, + is_pointer: false, + }), } } } @@ -1338,8 +1556,61 @@ impl ast::Arg2 { t: Option, ) -> ast::Arg2 { ast::Arg2 { - dst: visitor.dst_variable(self.dst, t), - src: visitor.src_operand(self.src, t), + dst: visitor.dst_variable(ArgumentDescriptor { + op: self.dst, + typ: t, + is_dst: true, + is_pointer: false, + }), + src: visitor.src_operand(ArgumentDescriptor { + op: self.src, + typ: t, + is_dst: false, + is_pointer: false, + }), + } + } + + fn map_ld>( + self, + visitor: &mut V, + t: Option, + ) -> ast::Arg2 { + ast::Arg2 { + dst: visitor.dst_variable(ArgumentDescriptor { + op: self.dst, + typ: t, + is_dst: true, + is_pointer: false, + }), + src: visitor.src_operand(ArgumentDescriptor { + op: self.src, + typ: t, + is_dst: false, + is_pointer: true, + }), + } + } + + fn map_cvt>( + self, + visitor: &mut V, + dst_t: ast::Type, + src_t: ast::Type, + ) -> ast::Arg2 { + ast::Arg2 { + dst: visitor.dst_variable(ArgumentDescriptor { + op: self.dst, + typ: Some(dst_t), + is_dst: true, + is_pointer: false, + }), + src: visitor.src_operand(ArgumentDescriptor { + op: self.src, + typ: Some(src_t), + is_dst: false, + is_pointer: false, + }), } } } @@ -1351,8 +1622,18 @@ impl ast::Arg2St { t: Option, ) -> ast::Arg2St { ast::Arg2St { - src1: visitor.src_operand(self.src1, t), - src2: visitor.src_operand(self.src2, t), + src1: visitor.src_operand(ArgumentDescriptor { + op: self.src1, + typ: t, + is_dst: false, + is_pointer: true, + }), + src2: visitor.src_operand(ArgumentDescriptor { + op: self.src2, + typ: t, + is_dst: false, + is_pointer: false, + }), } } } @@ -1364,8 +1645,18 @@ impl ast::Arg2Mov { t: Option, ) -> ast::Arg2Mov { ast::Arg2Mov { - dst: visitor.dst_variable(self.dst, t), - src: visitor.src_mov_operand(self.src, t), + dst: visitor.dst_variable(ArgumentDescriptor { + op: self.dst, + typ: t, + is_dst: true, + is_pointer: false, + }), + src: visitor.src_mov_operand(ArgumentDescriptor { + op: self.src, + typ: t, + is_dst: false, + is_pointer: false, + }), } } } @@ -1377,9 +1668,24 @@ impl ast::Arg3 { t: Option, ) -> ast::Arg3 { ast::Arg3 { - dst: visitor.dst_variable(self.dst, t), - src1: visitor.src_operand(self.src1, t), - src2: visitor.src_operand(self.src2, t), + dst: visitor.dst_variable(ArgumentDescriptor { + op: self.dst, + typ: t, + is_dst: true, + is_pointer: false, + }), + src1: visitor.src_operand(ArgumentDescriptor { + op: self.src1, + typ: t, + is_dst: false, + is_pointer: false, + }), + src2: visitor.src_operand(ArgumentDescriptor { + op: self.src2, + typ: t, + is_dst: false, + is_pointer: false, + }), } } @@ -1389,9 +1695,24 @@ impl ast::Arg3 { t: Option, ) -> ast::Arg3 { ast::Arg3 { - dst: visitor.dst_variable(self.dst, t), - src1: visitor.src_operand(self.src1, t), - src2: visitor.src_operand(self.src2, Some(ast::Type::Scalar(ast::ScalarType::U32))), + dst: visitor.dst_variable(ArgumentDescriptor { + op: self.dst, + typ: t, + is_dst: true, + is_pointer: false, + }), + src1: visitor.src_operand(ArgumentDescriptor { + op: self.src1, + typ: t, + is_dst: false, + is_pointer: false, + }), + src2: visitor.src_operand(ArgumentDescriptor { + op: self.src2, + typ: Some(ast::Type::Scalar(ast::ScalarType::U32)), + is_dst: false, + is_pointer: false, + }), } } } @@ -1403,18 +1724,32 @@ impl ast::Arg4 { t: Option, ) -> ast::Arg4 { ast::Arg4 { - dst1: visitor.dst_variable( - self.dst1, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ), - dst2: self.dst2.map(|dst2| { - visitor.dst_variable( - dst2, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ) + dst1: visitor.dst_variable(ArgumentDescriptor { + op: self.dst1, + typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + is_dst: true, + is_pointer: false, + }), + dst2: self.dst2.map(|dst2| { + visitor.dst_variable(ArgumentDescriptor { + op: dst2, + typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + is_dst: true, + is_pointer: false, + }) + }), + src1: visitor.src_operand(ArgumentDescriptor { + op: self.src1, + typ: t, + is_dst: false, + is_pointer: false, + }), + src2: visitor.src_operand(ArgumentDescriptor { + op: self.src2, + typ: t, + is_dst: false, + is_pointer: false, }), - src1: visitor.src_operand(self.src1, t), - src2: visitor.src_operand(self.src2, t), } } } @@ -1426,22 +1761,38 @@ impl ast::Arg5 { t: Option, ) -> ast::Arg5 { ast::Arg5 { - dst1: visitor.dst_variable( - self.dst1, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ), - dst2: self.dst2.map(|dst2| { - visitor.dst_variable( - dst2, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ) + dst1: visitor.dst_variable(ArgumentDescriptor { + op: self.dst1, + typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + is_dst: true, + is_pointer: false, + }), + dst2: self.dst2.map(|dst2| { + visitor.dst_variable(ArgumentDescriptor { + op: dst2, + typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + is_dst: true, + is_pointer: false, + }) + }), + src1: visitor.src_operand(ArgumentDescriptor { + op: self.src1, + typ: t, + is_dst: false, + is_pointer: false, + }), + src2: visitor.src_operand(ArgumentDescriptor { + op: self.src2, + typ: t, + is_dst: false, + is_pointer: false, + }), + src3: visitor.src_operand(ArgumentDescriptor { + op: self.src3, + typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + is_dst: false, + is_pointer: false, }), - src1: visitor.src_operand(self.src1, t), - src2: visitor.src_operand(self.src2, t), - src3: visitor.src_operand( - self.src3, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ), } } } @@ -1851,34 +2202,34 @@ fn insert_implicit_bitcasts( instr: ast::Instruction, ) { let mut dst_coercion = None; - let instr = instr.visit_variable_extended(&mut |mut id, is_dst, id_type| { - let id_type_from_instr = match id_type { + let instr = instr.visit_variable_extended(&mut |mut desc| { + let id_type_from_instr = match desc.typ { Some(t) => t, - None => return id, + None => return desc.op, }; - let id_actual_type = id_def.get_type(id); - if should_bitcast(id_type_from_instr, id_def.get_type(id)) { - if is_dst { + let id_actual_type = id_def.get_type(desc.op); + if should_bitcast(id_type_from_instr, id_def.get_type(desc.op)) { + if desc.is_dst { dst_coercion = Some(get_conversion_dst( id_def, - &mut id, + &mut desc.op, id_type_from_instr, id_actual_type, ConversionKind::Default, )); - id + desc.op } else { insert_conversion_src( func, id_def, - id, + desc.op, id_actual_type, id_type_from_instr, ConversionKind::Default, ) } } else { - id + desc.op } }); func.push(Statement::Instruction(instr));