diff --git a/Cargo.toml b/Cargo.toml index ed5d1f1..42be95a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,5 +11,5 @@ members = [ ] [patch.crates-io] -rspirv = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' } -spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' } \ No newline at end of file +rspirv = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' } +spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' } \ No newline at end of file diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 42d60cb..96ab9d0 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -10,7 +10,7 @@ edition = "2018" lalrpop-util = "0.19" regex = "1" rspirv = "0.6" -spirv_headers = "1.4" +spirv_headers = "~1.4.2" quick-error = "1.2" bit-vec = "0.6" half ="1.6" diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 7edfa70..097e19c 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -463,14 +463,14 @@ pub enum CallOperand { pub enum IdOrVector { Reg(ID), - Vec(Vec) + Vec(Vec), } pub enum OperandOrVector { Reg(ID), RegOffset(ID, i32), Imm(u32), - Vec(Vec) + Vec(Vec), } impl From> for OperandOrVector { @@ -536,6 +536,8 @@ pub struct MovDetails { // two fields below are in use by member moves pub dst_width: u8, pub src_width: u8, + // This is in use by auto-generated movs + pub relaxed_src2_conv: bool, } impl MovDetails { @@ -544,7 +546,8 @@ impl MovDetails { typ, src_is_address: false, dst_width: 0, - src_width: 0 + src_width: 0, + relaxed_src2_conv: false, } } } @@ -560,7 +563,7 @@ pub struct MulIntDesc { pub control: MulIntControl, } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum MulIntControl { Low, High, diff --git a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt index 249af90..d4d9499 100644 --- a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt +++ b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt @@ -2,8 +2,10 @@ OpCapability Linkage OpCapability Addresses OpCapability Kernel - OpCapability Int64 OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 OpCapability Float64 %23 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL @@ -33,17 +35,17 @@ %11 = OpCopyObject %ulong %12 OpStore %5 %11 %14 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14 - %18 = OpLoad %float %17 - %31 = OpBitcast %uint %18 + %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14 + %17 = OpLoad %float %18 + %31 = OpBitcast %uint %17 %13 = OpUConvert %ulong %31 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %ulong %6 + %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15 %32 = OpBitcast %ulong %16 %33 = OpUConvert %uint %32 - %19 = OpBitcast %float %33 - %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15 - OpStore %20 %19 + %20 = OpBitcast %float %33 + OpStore %19 %20 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt index 274612c..8ac0459 100644 --- a/ptx/src/test/spirv_run/mul_wide.spvtxt +++ b/ptx/src/test/spirv_run/mul_wide.spvtxt @@ -2,8 +2,10 @@ OpCapability Linkage OpCapability Addresses OpCapability Kernel - OpCapability Int64 OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 OpCapability Float64 %32 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL @@ -57,8 +59,8 @@ OpStore %8 %19 %22 = OpLoad %ulong %5 %23 = OpLoad %ulong %8 - %28 = OpCopyObject %ulong %23 - %29 = OpConvertUToPtr %_ptr_Generic_ulong %22 - OpStore %29 %28 + %28 = OpConvertUToPtr %_ptr_Generic_ulong %22 + %29 = OpCopyObject %ulong %23 + OpStore %28 %29 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector_extract.ptx b/ptx/src/test/spirv_run/vector_extract.ptx index 8624f8a..111f7c0 100644 --- a/ptx/src/test/spirv_run/vector_extract.ptx +++ b/ptx/src/test/spirv_run/vector_extract.ptx @@ -15,6 +15,9 @@ .reg .u16 temp4; .reg .v4.u16 foo; + ld.param.u64 in_addr, [input_p]; + ld.param.u64 out_addr, [output_p]; + ld.global.v4.u8 {temp1, temp2, temp3, temp4}, [in_addr]; mov.v4.u16 foo, {temp2, temp3, temp4, temp1}; mov.v4.u16 {temp3, temp4, temp1, temp2}, foo; diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt index ff0ee97..45df3a8 100644 --- a/ptx/src/test/spirv_run/vector_extract.spvtxt +++ b/ptx/src/test/spirv_run/vector_extract.spvtxt @@ -2,96 +2,123 @@ OpCapability Linkage OpCapability Addresses OpCapability Kernel - OpCapability Int64 OpCapability Int8 - %60 = OpExtInstImport "OpenCL.std" + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %75 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %31 "vector" + OpEntryPoint Kernel %1 "vector_extract" %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %v2uint = OpTypeVector %uint 2 - %64 = OpTypeFunction %v2uint %v2uint -%_ptr_Function_v2uint = OpTypePointer Function %v2uint -%_ptr_Function_uint = OpTypePointer Function %uint %ulong = OpTypeInt 64 0 - %68 = OpTypeFunction %void %ulong %ulong + %78 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint - %1 = OpFunction %v2uint None %64 - %7 = OpFunctionParameter %v2uint - %30 = OpLabel - %3 = OpVariable %_ptr_Function_v2uint Function - %2 = OpVariable %_ptr_Function_v2uint Function - %4 = OpVariable %_ptr_Function_v2uint Function - %5 = OpVariable %_ptr_Function_uint Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %3 %7 - %9 = OpLoad %v2uint %3 - %27 = OpCompositeExtract %uint %9 0 - %8 = OpCopyObject %uint %27 - OpStore %5 %8 - %11 = OpLoad %v2uint %3 - %28 = OpCompositeExtract %uint %11 1 - %10 = OpCopyObject %uint %28 - OpStore %6 %10 - %13 = OpLoad %uint %5 - %14 = OpLoad %uint %6 - %12 = OpIAdd %uint %13 %14 - OpStore %6 %12 - %16 = OpLoad %v2uint %4 - %17 = OpLoad %uint %6 - %15 = OpCompositeInsert %v2uint %17 %16 0 - OpStore %4 %15 - %19 = OpLoad %v2uint %4 - %20 = OpLoad %uint %6 - %18 = OpCompositeInsert %v2uint %20 %19 1 - OpStore %4 %18 - %22 = OpLoad %v2uint %4 - %23 = OpLoad %v2uint %4 - %29 = OpCompositeExtract %uint %23 1 - %21 = OpCompositeInsert %v2uint %29 %22 0 - OpStore %4 %21 - %25 = OpLoad %v2uint %4 - %24 = OpCopyObject %v2uint %25 - OpStore %2 %24 - %26 = OpLoad %v2uint %2 - OpReturnValue %26 - OpFunctionEnd - %31 = OpFunction %void None %68 - %40 = OpFunctionParameter %ulong - %41 = OpFunctionParameter %ulong - %58 = OpLabel - %32 = OpVariable %_ptr_Function_ulong Function - %33 = OpVariable %_ptr_Function_ulong Function - %34 = OpVariable %_ptr_Function_ulong Function - %35 = OpVariable %_ptr_Function_ulong Function - %36 = OpVariable %_ptr_Function_v2uint Function - %37 = OpVariable %_ptr_Function_uint Function - %38 = OpVariable %_ptr_Function_uint Function - %39 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %40 - OpStore %33 %41 - %43 = OpLoad %ulong %32 - %42 = OpCopyObject %ulong %43 - OpStore %34 %42 - %45 = OpLoad %ulong %33 - %44 = OpCopyObject %ulong %45 - OpStore %35 %44 - %47 = OpLoad %ulong %34 - %54 = OpConvertUToPtr %_ptr_Generic_v2uint %47 - %46 = OpLoad %v2uint %54 - OpStore %36 %46 - %49 = OpLoad %v2uint %36 - %48 = OpFunctionCall %v2uint %1 %49 - OpStore %36 %48 - %51 = OpLoad %v2uint %36 - %55 = OpBitcast %ulong %51 - %56 = OpCopyObject %ulong %55 - %50 = OpCopyObject %ulong %56 - OpStore %39 %50 - %52 = OpLoad %ulong %35 - %53 = OpLoad %v2uint %36 - %57 = OpConvertUToPtr %_ptr_Generic_v2uint %52 - OpStore %57 %53 + %ushort = OpTypeInt 16 0 +%_ptr_Function_ushort = OpTypePointer Function %ushort + %v4ushort = OpTypeVector %ushort 4 +%_ptr_Function_v4ushort = OpTypePointer Function %v4ushort + %uchar = OpTypeInt 8 0 + %v4uchar = OpTypeVector %uchar 4 +%_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar + %1 = OpFunction %void None %78 + %11 = OpFunctionParameter %ulong + %12 = OpFunctionParameter %ulong + %73 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ushort Function + %7 = OpVariable %_ptr_Function_ushort Function + %8 = OpVariable %_ptr_Function_ushort Function + %9 = OpVariable %_ptr_Function_ushort Function + %10 = OpVariable %_ptr_Function_v4ushort Function + OpStore %2 %11 + OpStore %3 %12 + %14 = OpLoad %ulong %2 + %13 = OpCopyObject %ulong %14 + OpStore %4 %13 + %16 = OpLoad %ulong %3 + %15 = OpCopyObject %ulong %16 + OpStore %5 %15 + %21 = OpLoad %ulong %4 + %63 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21 + %45 = OpLoad %v4uchar %63 + %64 = OpCompositeExtract %uchar %45 0 + %87 = OpBitcast %uchar %64 + %17 = OpUConvert %ushort %87 + %65 = OpCompositeExtract %uchar %45 1 + %88 = OpBitcast %uchar %65 + %18 = OpUConvert %ushort %88 + %66 = OpCompositeExtract %uchar %45 2 + %89 = OpBitcast %uchar %66 + %19 = OpUConvert %ushort %89 + %67 = OpCompositeExtract %uchar %45 3 + %90 = OpBitcast %uchar %67 + %20 = OpUConvert %ushort %90 + OpStore %6 %17 + OpStore %7 %18 + OpStore %8 %19 + OpStore %9 %20 + %23 = OpLoad %ushort %7 + %24 = OpLoad %ushort %8 + %25 = OpLoad %ushort %9 + %26 = OpLoad %ushort %6 + %46 = OpUndef %v4ushort + %47 = OpCompositeInsert %v4ushort %23 %46 0 + %48 = OpCompositeInsert %v4ushort %24 %47 1 + %49 = OpCompositeInsert %v4ushort %25 %48 2 + %50 = OpCompositeInsert %v4ushort %26 %49 3 + %22 = OpCopyObject %v4ushort %50 + OpStore %10 %22 + %31 = OpLoad %v4ushort %10 + %51 = OpCopyObject %v4ushort %31 + %27 = OpCompositeExtract %ushort %51 0 + %28 = OpCompositeExtract %ushort %51 1 + %29 = OpCompositeExtract %ushort %51 2 + %30 = OpCompositeExtract %ushort %51 3 + OpStore %8 %27 + OpStore %9 %28 + OpStore %6 %29 + OpStore %7 %30 + %36 = OpLoad %ushort %8 + %37 = OpLoad %ushort %9 + %38 = OpLoad %ushort %6 + %39 = OpLoad %ushort %7 + %53 = OpUndef %v4ushort + %54 = OpCompositeInsert %v4ushort %36 %53 0 + %55 = OpCompositeInsert %v4ushort %37 %54 1 + %56 = OpCompositeInsert %v4ushort %38 %55 2 + %57 = OpCompositeInsert %v4ushort %39 %56 3 + %52 = OpCopyObject %v4ushort %57 + %32 = OpCompositeExtract %ushort %52 0 + %33 = OpCompositeExtract %ushort %52 1 + %34 = OpCompositeExtract %ushort %52 2 + %35 = OpCompositeExtract %ushort %52 3 + OpStore %9 %32 + OpStore %6 %33 + OpStore %7 %34 + OpStore %8 %35 + %40 = OpLoad %ulong %5 + %41 = OpLoad %ushort %6 + %42 = OpLoad %ushort %7 + %43 = OpLoad %ushort %8 + %44 = OpLoad %ushort %9 + %58 = OpUndef %v4uchar + %91 = OpBitcast %ushort %41 + %68 = OpUConvert %uchar %91 + %59 = OpCompositeInsert %v4uchar %68 %58 0 + %92 = OpBitcast %ushort %42 + %69 = OpUConvert %uchar %92 + %60 = OpCompositeInsert %v4uchar %69 %59 1 + %93 = OpBitcast %ushort %43 + %70 = OpUConvert %uchar %93 + %61 = OpCompositeInsert %v4uchar %70 %60 2 + %94 = OpBitcast %ushort %44 + %71 = OpUConvert %uchar %94 + %62 = OpCompositeInsert %v4uchar %71 %61 3 + %72 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %40 + OpStore %72 %62 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 981da86..37cef00 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -843,6 +843,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( (_, ArgumentSemantics::Address) => return Ok(desc.op), (t, ArgumentSemantics::RegisterPointer) | (t, ArgumentSemantics::Default) + | (t, ArgumentSemantics::DefaultRelaxed) | (t, ArgumentSemantics::PhysicalPointer) => t, }; let generated_id = id_def.new_id(id_type); @@ -933,17 +934,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn insert_composite_read( func: &mut Vec, id_def: &mut MutableNumericIdResolver<'a>, - (scalar_type, vec_len): (ast::ScalarType, u8), + typ: (ast::ScalarType, u8), scalar_dst: Option, + scalar_sema_override: Option, composite_src: (spirv::Word, u8), ) -> spirv::Word { - let new_id = - scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Vector(scalar_type, vec_len))); + let new_id = scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Scalar(typ.0))); func.push(Statement::Composite(CompositeRead { - typ: scalar_type, + typ: typ.0, dst: new_id, + dst_semantics_override: scalar_sema_override, src_composite: composite_src.0, src_index: composite_src.1 as u32, + src_len: typ.1 as u32, })); new_id } @@ -963,7 +966,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ) -> Result { let (reg, offset) = desc.op; match desc.sema { - ArgumentSemantics::Default => { + ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => { let scalar_t = if let ast::Type::Scalar(scalar) = typ { scalar } else { @@ -1049,18 +1052,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn member_src( &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, - (scalar_type, vec_len): (ast::ScalarType, u8), + typ: (ast::ScalarType, u8), ) -> Result { if desc.is_dst { return Err(TranslateError::Unreachable); } - let new_id = self.id_def.new_id(ast::Type::Vector(scalar_type, vec_len)); - self.func.push(Statement::Composite(CompositeRead { - typ: scalar_type, - dst: new_id, - src_composite: desc.op.0, - src_index: desc.op.1 as u32, - })); + let new_id = Self::insert_composite_read( + self.func, + self.id_def, + typ, + None, + Some(desc.sema), + desc.op, + ); Ok(new_id) } @@ -1077,10 +1081,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { let newer_id = self.id_def.new_id(typ); self.func.push(Statement::Instruction(ast::Instruction::Mov( ast::MovDetails { - typ: typ, + typ: ast::Type::Scalar(scalar_type), src_is_address: false, - dst_width: 0, - src_width: vec_len, + dst_width: vec_len, + src_width: 0, + relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed, }, ast::Arg2Mov::Member(ast::Arg2MovMember::Dst( (newer_id, idx as u8), @@ -1099,6 +1104,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.id_def, (scalar_type, vec_len), Some(*id), + Some(desc.sema), (new_id, idx as u8), ); } @@ -1144,9 +1150,9 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn src_member_operand( &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, - (scalar_type, vec_len): (ast::ScalarType, u8), + typ: (ast::ScalarType, u8), ) -> Result { - self.member_src(desc, (scalar_type, vec_len)) + self.member_src(desc, typ) } fn id_or_vector( @@ -1195,123 +1201,41 @@ fn insert_implicit_conversions( let mut result = Vec::with_capacity(func.len()); for s in func.into_iter() { match s { - Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call)?, - Statement::Instruction(inst) => match inst { - ast::Instruction::Ld(ld, arg) => { - let pre_conv = get_implicit_conversions_ld_src( - id_def, - ld.typ, - ld.state_space, - arg.src, - false, - )?; - let post_conv = get_implicit_conversions_ld_dst( - id_def, - ld.typ, - arg.dst, - should_convert_relaxed_dst, - false, - )?; - insert_with_conversions( - &mut result, - id_def, - arg, - pre_conv.into_iter(), - iter::empty(), - post_conv.into_iter().collect(), - |arg| &mut arg.src, - |arg| &mut arg.dst, - |arg| ast::Instruction::Ld(ld, arg), - ) + Statement::Call(call) => insert_implicit_conversions_impl( + &mut result, + id_def, + call, + should_bitcast_wrapper, + None, + )?, + Statement::Instruction(inst) => { + let mut default_conversion_fn = should_bitcast_wrapper + as fn(_, _, _) -> Result, TranslateError>; + let mut state_space = None; + if let ast::Instruction::Ld(d, _) = &inst { + state_space = Some(d.state_space); } - ast::Instruction::St(st, arg) => { - let pre_conv = get_implicit_conversions_ld_dst( - id_def, - st.typ, - arg.src2, - should_convert_relaxed_src, - true, - )?; - let post_conv = get_implicit_conversions_ld_src( - id_def, - st.typ, - st.state_space.to_ld_ss(), - arg.src1, - true, - )?; - let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param - || st.state_space == ast::StStateSpace::Local - { - (Vec::new(), post_conv) - } else { - (post_conv, Vec::new()) - }; - insert_with_conversions( - &mut result, - id_def, - arg, - pre_conv.into_iter(), - pre_conv_dest.into_iter(), - post_conv, - |arg| &mut arg.src2, - |arg| &mut arg.src1, - |arg| ast::Instruction::St(st, arg), - ) + if let ast::Instruction::St(d, _) = &inst { + state_space = Some(d.state_space.to_ld_ss()); } - ast::Instruction::Mov(d, ast::Arg2Mov::Normal(mut arg)) => { - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2 - // TODO: handle the case of mixed vector/scalar implicit conversions - let inst_typ_is_bit = match d.typ { - ast::Type::Scalar(t) => ast::ScalarType::from(t).kind() == ScalarKind::Bit, - ast::Type::Vector(_, _) => false, - ast::Type::Array(_, _) => false, - }; - let mut did_vector_implicit = false; - let mut post_conv = None; - if inst_typ_is_bit { - let src_type = id_def.get_typed(arg.src)?; - if let ast::Type::Vector(_, _) = src_type { - arg.src = insert_conversion_src( - &mut result, - id_def, - arg.src, - src_type, - d.typ.into(), - ConversionKind::Default, - ); - did_vector_implicit = true; - } - let dst_type = id_def.get_typed(arg.dst)?; - if let ast::Type::Vector(_, _) = dst_type { - post_conv = Some(get_conversion_dst( - id_def, - &mut arg.dst, - d.typ.into(), - dst_type, - ConversionKind::Default, - )); - did_vector_implicit = true; - } - } - if did_vector_implicit { - result.push(Statement::Instruction(ast::Instruction::Mov( - d, - ast::Arg2Mov::Normal(arg), - ))); - } else { - insert_implicit_bitcasts( - &mut result, - id_def, - ast::Instruction::Mov(d, ast::Arg2Mov::Normal(arg)), - )?; - } - if let Some(post_conv) = post_conv { - result.push(post_conv); - } + if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst { + default_conversion_fn = should_bitcast_packed; } - inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst)?, - }, - Statement::Composite(c) => insert_implicit_bitcasts(&mut result, id_def, c)?, + insert_implicit_conversions_impl( + &mut result, + id_def, + inst, + default_conversion_fn, + state_space, + )?; + } + Statement::Composite(composite) => insert_implicit_conversions_impl( + &mut result, + id_def, + composite, + should_bitcast_wrapper, + None, + )?, s @ Statement::Conditional(_) | s @ Statement::Label(_) | s @ Statement::Constant(_) @@ -1326,6 +1250,77 @@ fn insert_implicit_conversions( Ok(result) } +fn insert_implicit_conversions_impl( + func: &mut Vec, + id_def: &mut MutableNumericIdResolver, + stmt: impl VisitVariableExpanded, + default_conversion_fn: fn( + ast::Type, + ast::Type, + Option, + ) -> Result, TranslateError>, + state_space: Option, +) -> Result<(), TranslateError> { + let mut post_conv = Vec::new(); + let statement = stmt.visit_variable_extended(&mut |desc, typ| { + let instr_type = match typ { + None => return Ok(desc.op), + Some(t) => t, + }; + let operand_type = id_def.get_typed(desc.op)?; + let mut conversion_fn = default_conversion_fn; + match desc.sema { + ArgumentSemantics::Default => {} + ArgumentSemantics::DefaultRelaxed => { + if desc.is_dst { + conversion_fn = should_convert_relaxed_dst_wrapper; + } else { + conversion_fn = should_convert_relaxed_src_wrapper; + } + } + ArgumentSemantics::PhysicalPointer => { + conversion_fn = bitcast_physical_pointer; + } + ArgumentSemantics::RegisterPointer => { + conversion_fn = force_bitcast; + } + ArgumentSemantics::Address => { + conversion_fn = force_bitcast_ptr_to_bit; + } + }; + match conversion_fn(operand_type, instr_type, state_space)? { + Some(conv_kind) => { + let conv_output = if desc.is_dst { + &mut post_conv + } else { + &mut *func + }; + let mut from = instr_type; + let mut to = operand_type; + let mut src = id_def.new_id(instr_type); + let mut dst = desc.op; + let result = Ok(src); + if !desc.is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from, &mut to); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from, + to, + kind: conv_kind, + })); + result + } + None => Ok(desc.op), + } + })?; + func.push(statement); + func.append(&mut post_conv); + Ok(()) +} + fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1505,7 +1500,11 @@ fn emit_function_body_ops( composite_src, scalar_src, )) => { - let result_type = map.get_or_add(builder, SpirvType::from(d.typ)); + let scalar_type = d.typ.get_scalar()?; + let result_type = map.get_or_add( + builder, + SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)), + ); let result_id = Some(dst.0); builder.composite_insert( result_type, @@ -1545,8 +1544,8 @@ fn emit_function_body_ops( // Obviously, old and buggy one is used for compiling L0 SPIRV // https://github.com/intel/intel-graphics-compiler/issues/148 let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred); - let const_true = builder.constant_true(type_pred); - let const_false = builder.constant_false(type_pred); + let const_true = builder.constant_true(type_pred, None); + let const_false = builder.constant_false(type_pred, None); builder.select(result_type, result_id, operand, const_false, const_true) } _ => builder.not(result_type, result_id, operand), @@ -2700,12 +2699,9 @@ where fn src_member_operand( &mut self, desc: ArgumentDescriptor, - (scalar_type, vec_len): (ast::ScalarType, u8), + (scalar_type, _): (ast::ScalarType, u8), ) -> Result { - self( - desc.new_op(desc.op), - Some(ast::Type::Vector(scalar_type.into(), vec_len)), - ) + self(desc.new_op(desc.op), Some(ast::Type::Scalar(scalar_type))) } } @@ -2793,6 +2789,8 @@ pub struct ArgumentDescriptor { pub enum ArgumentSemantics { // normal register access Default, + // normal register access with relaxed conversion rules (ld/st) + DefaultRelaxed, // st/ld global PhysicalPointer, // st/ld .param, .local @@ -2834,11 +2832,12 @@ impl ast::Instruction { } ast::Instruction::Mul(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type)?) + let is_wide = d.is_wide(); + ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type, is_wide)?) } ast::Instruction::Add(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type)?) + ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type, false)?) } ast::Instruction::Setp(d, a) => { let inst_type = d.typ; @@ -2889,7 +2888,8 @@ impl ast::Instruction { } ast::Instruction::Mad(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Mad(d, a.map(visitor, inst_type)?) + let is_wide = d.is_wide(); + ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?) } }) } @@ -3004,6 +3004,27 @@ where } impl ast::Type { + fn widen(self) -> Result { + match self { + ast::Type::Scalar(scalar) => { + let kind = scalar.kind(); + let width = scalar.width(); + if (kind != ScalarKind::Signed + && kind != ScalarKind::Unsigned + && kind != ScalarKind::Bit) + || (width == 8) + { + return Err(TranslateError::MismatchedType); + } + Ok(ast::Type::Scalar(ast::ScalarType::from_parts( + width * 2, + kind, + ))) + } + _ => Err(TranslateError::Unreachable), + } + } + fn to_parts(self) -> TypeParts { match self { ast::Type::Scalar(scalar) => TypeParts { @@ -3102,8 +3123,10 @@ type Arg2St = ast::Arg2St; struct CompositeRead { pub typ: ast::ScalarType, pub dst: spirv::Word, + pub dst_semantics_override: Option, pub src_composite: spirv::Word, pub src_index: u32, + pub src_len: u32, } impl VisitVariableExpanded for CompositeRead { @@ -3116,12 +3139,15 @@ impl VisitVariableExpanded for CompositeRead { self, f: &mut F, ) -> Result { + let dst_sema = self + .dst_semantics_override + .unwrap_or(ArgumentSemantics::Default); Ok(Statement::Composite(CompositeRead { dst: f( ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + sema: dst_sema, }, Some(ast::Type::Scalar(self.typ)), )?, @@ -3131,7 +3157,7 @@ impl VisitVariableExpanded for CompositeRead { is_dst: false, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(self.typ)), + Some(ast::Type::Vector(self.typ, self.src_len as u8)), )?, ..self })) @@ -3328,7 +3354,7 @@ impl ast::Arg2Ld { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + sema: ArgumentSemantics::DefaultRelaxed, }, t.into(), )?; @@ -3380,7 +3406,7 @@ impl ast::Arg2St { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + sema: ArgumentSemantics::DefaultRelaxed, }, t, )?; @@ -3429,9 +3455,9 @@ impl ast::Arg2MovNormal

{ op: self.src, is_dst: false, sema: if details.src_is_address { - ArgumentSemantics::RegisterPointer + ArgumentSemantics::Address } else { - ArgumentSemantics::PhysicalPointer + ArgumentSemantics::Default }, }, details.typ.into(), @@ -3476,13 +3502,14 @@ impl ast::Arg2MovMember { ) -> Result, TranslateError> { match self { ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => { + let scalar_type = details.typ.get_scalar()?; let dst = visitor.id( ArgumentDescriptor { op: dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(details.typ.into()), + Some(ast::Type::Vector(scalar_type, details.dst_width)), )?; let src1 = visitor.id( ArgumentDescriptor { @@ -3490,7 +3517,7 @@ impl ast::Arg2MovMember { is_dst: false, sema: ArgumentSemantics::Default, }, - Some(details.typ.into()), + Some(ast::Type::Vector(scalar_type, details.dst_width)), )?; let src2 = visitor.id( ArgumentDescriptor { @@ -3498,6 +3525,8 @@ impl ast::Arg2MovMember { is_dst: false, sema: if details.src_is_address { ArgumentSemantics::Address + } else if details.relaxed_src2_conv { + ArgumentSemantics::DefaultRelaxed } else { ArgumentSemantics::Default }, @@ -3527,13 +3556,14 @@ impl ast::Arg2MovMember { Ok(ast::Arg2MovMember::Src(dst, src)) } ast::Arg2MovMember::Both((dst, len), composite_src, src) => { + let scalar_type = details.typ.get_scalar()?; let dst = visitor.id( ArgumentDescriptor { op: dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(details.typ.into()), + Some(ast::Type::Vector(scalar_type, details.dst_width)), )?; let composite_src = visitor.id( ArgumentDescriptor { @@ -3541,16 +3571,19 @@ impl ast::Arg2MovMember { is_dst: false, sema: ArgumentSemantics::Default, }, - Some(details.typ.into()), + Some(ast::Type::Vector(scalar_type, details.dst_width)), )?; - let scalar_typ = details.typ.get_scalar()?; let src = visitor.src_member_operand( ArgumentDescriptor { op: src, is_dst: false, - sema: ArgumentSemantics::Default, + sema: if details.relaxed_src2_conv { + ArgumentSemantics::DefaultRelaxed + } else { + ArgumentSemantics::Default + }, }, - (scalar_typ.into(), details.src_width), + (scalar_type.into(), details.src_width), )?; Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src)) } @@ -3570,7 +3603,8 @@ impl ast::Arg3 { fn map_non_shift>( self, visitor: &mut V, - t: ast::Type, + typ: ast::Type, + is_wide: bool, ) -> Result, TranslateError> { let dst = visitor.id( ArgumentDescriptor { @@ -3578,7 +3612,7 @@ impl ast::Arg3 { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(t), + Some(if is_wide { typ.widen()? } else { typ }), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -3586,7 +3620,7 @@ impl ast::Arg3 { is_dst: false, sema: ArgumentSemantics::Default, }, - t, + typ, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -3594,7 +3628,7 @@ impl ast::Arg3 { is_dst: false, sema: ArgumentSemantics::Default, }, - t, + typ, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -3646,6 +3680,7 @@ impl ast::Arg4 { self, visitor: &mut V, t: ast::Type, + is_wide: bool, ) -> Result, TranslateError> { let dst = visitor.id( ArgumentDescriptor { @@ -3653,7 +3688,7 @@ impl ast::Arg4 { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(t), + Some(if is_wide { t.widen()? } else { t }), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -4050,6 +4085,54 @@ impl ast::OperandOrVector { } } +impl ast::MulDetails { + fn is_wide(&self) -> bool { + match self { + ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide, + ast::MulDetails::Float(_) => false, + } + } +} + +fn force_bitcast( + operand: ast::Type, + instr: ast::Type, + _: Option, +) -> Result, TranslateError> { + if instr != operand { + Ok(Some(ConversionKind::Default)) + } else { + Ok(None) + } +} + +fn bitcast_physical_pointer( + operand_type: ast::Type, + _: ast::Type, + ss: Option, +) -> Result, TranslateError> { + match operand_type { + ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => { + if let Some(space) = ss { + Ok(Some(ConversionKind::BitToPtr(space))) + } else { + Err(TranslateError::Unreachable) + } + } + _ => Err(TranslateError::MismatchedType), + } +} + +fn force_bitcast_ptr_to_bit( + _: ast::Type, + _: ast::Type, + _: Option, +) -> Result, TranslateError> { + Ok(Some(ConversionKind::PtrToBit)) +} + fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { match (instr, operand) { (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { @@ -4077,187 +4160,50 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { } } -fn insert_with_conversions ast::Instruction>( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver, - mut instr: T, - pre_conv_src: impl ExactSizeIterator, - pre_conv_dst: impl ExactSizeIterator, - mut post_conv: Vec, - mut src: impl FnMut(&mut T) -> &mut spirv::Word, - mut dst: impl FnMut(&mut T) -> &mut spirv::Word, - to_inst: ToInstruction, -) { - insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src); - insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst); - if post_conv.len() > 0 { - let new_id = id_def.new_id(post_conv[0].from); - post_conv[0].src = new_id; - post_conv.last_mut().unwrap().dst = *dst(&mut instr); - *dst(&mut instr) = new_id; - } - func.push(Statement::Instruction(to_inst(instr))); - for conv in post_conv { - func.push(Statement::Conversion(conv)); +fn should_bitcast_packed( + operand: ast::Type, + instr: ast::Type, + ss: Option, +) -> Result, TranslateError> { + if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = + (operand, instr) + { + if scalar.kind() == ScalarKind::Bit + && scalar.width() == (vec_underlying_type.width() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } } + should_bitcast_wrapper(operand, instr, ss) } -fn insert_with_conversions_pre_conv( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver, - mut instr: &mut T, - pre_conv: impl ExactSizeIterator, - src: &mut impl FnMut(&mut T) -> &mut spirv::Word, -) { - let pre_conv_len = pre_conv.len(); - for (i, mut conv) in pre_conv.enumerate() { - let original_src = src(&mut instr); - if i == 0 { - conv.src = *original_src; - } - if i == pre_conv_len - 1 { - let new_id = id_def.new_id(conv.to); - conv.dst = new_id; - *original_src = new_id; - } - func.push(Statement::Conversion(conv)); +fn should_bitcast_wrapper( + operand: ast::Type, + instr: ast::Type, + _: Option, +) -> Result, TranslateError> { + if instr == operand { + return Ok(None); } -} - -fn get_implicit_conversions_ld_dst< - ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option, ->( - id_def: &mut MutableNumericIdResolver, - instr_type: ast::Type, - dst: spirv::Word, - should_convert: ShouldConvert, - in_reverse: bool, -) -> Result, TranslateError> { - let dst_type = id_def.get_typed(dst)?; - if let Some(conv) = should_convert(dst_type, instr_type) { - Ok(Some(ImplicitConversion { - src: u32::max_value(), - dst: u32::max_value(), - from: if !in_reverse { instr_type } else { dst_type }, - to: if !in_reverse { dst_type } else { instr_type }, - kind: conv, - })) + if should_bitcast(instr, operand) { + Ok(Some(ConversionKind::Default)) } else { - Ok(None) + Err(TranslateError::MismatchedType) } } -fn get_implicit_conversions_ld_src( - id_def: &mut MutableNumericIdResolver, - instr_type: ast::Type, - state_space: ast::LdStateSpace, - src: spirv::Word, - in_reverse_param_local: bool, -) -> Result, TranslateError> { - let src_type = id_def.get_typed(src)?; - match state_space { - ast::LdStateSpace::Param | ast::LdStateSpace::Local => { - if src_type != instr_type { - Ok(vec![ - ImplicitConversion { - src: u32::max_value(), - dst: u32::max_value(), - from: if !in_reverse_param_local { - src_type - } else { - instr_type - }, - to: if !in_reverse_param_local { - instr_type - } else { - src_type - }, - kind: ConversionKind::Default, - }; - 1 - ]) - } else { - Ok(Vec::new()) - } - } - ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { - let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts( - mem::size_of::() as u8, - ScalarKind::Bit, - )); - let mut result = Vec::new(); - // HACK ALERT - // IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an - // additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier - // TODO: error out if the src is not B64/U64/S64 - if let ast::Type::Scalar(scalar_src_type) = src_type { - if scalar_src_type.kind() == ScalarKind::Signed { - result.push(ImplicitConversion { - src: u32::max_value(), - dst: u32::max_value(), - from: src_type, - to: new_src_type, - kind: ConversionKind::Default, - }); - } - } - result.push(ImplicitConversion { - src: u32::max_value(), - dst: u32::max_value(), - from: src_type, - to: instr_type, - kind: ConversionKind::BitToPtr(state_space), - }); - if result.len() == 2 { - let new_id = id_def.new_id(new_src_type); - result[0].dst = new_id; - result[1].src = new_id; - result[1].from = new_src_type; - } - Ok(result) - } - _ => Err(TranslateError::Todo), - } -} - -#[must_use] -fn insert_conversion_src( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver, - src: spirv::Word, +fn should_convert_relaxed_src_wrapper( src_type: ast::Type, instr_type: ast::Type, - conv: ConversionKind, -) -> spirv::Word { - let temp_src = id_def.new_id(instr_type); - func.push(Statement::Conversion(ImplicitConversion { - src: src, - dst: temp_src, - from: src_type, - to: instr_type, - kind: conv, - })); - temp_src -} - -#[must_use] -fn get_conversion_dst( - id_def: &mut MutableNumericIdResolver, - dst: &mut spirv::Word, - instr_type: ast::Type, - dst_type: ast::Type, - kind: ConversionKind, -) -> ExpandedStatement { - let original_dst = *dst; - let temp_dst = id_def.new_id(instr_type); - *dst = temp_dst; - Statement::Conversion(ImplicitConversion { - src: temp_dst, - dst: original_dst, - from: instr_type, - to: dst_type, - kind: kind, - }) + _: Option, +) -> Result, TranslateError> { + if src_type == instr_type { + return Ok(None); + } + match should_convert_relaxed_src(src_type, instr_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands @@ -4302,6 +4248,20 @@ fn should_convert_relaxed_src( } } +fn should_convert_relaxed_dst_wrapper( + dst_type: ast::Type, + instr_type: ast::Type, + _: Option, +) -> Result, TranslateError> { + if dst_type == instr_type { + return Ok(None); + } + match should_convert_relaxed_dst(dst_type, instr_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } +} + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands fn should_convert_relaxed_dst( dst_type: ast::Type, @@ -4357,55 +4317,6 @@ fn should_convert_relaxed_dst( } } -fn insert_implicit_bitcasts( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver, - stmt: impl VisitVariableExpanded, -) -> Result<(), TranslateError> { - let mut dst_coercion = None; - let instr = stmt.visit_variable_extended(&mut |mut desc, typ| { - let id_type_from_instr = match typ { - Some(t) => t, - None => return Ok(desc.op), - }; - let id_actual_type = id_def.get_typed(desc.op)?; - let conv_kind = if desc.sema == ArgumentSemantics::Address { - Some(ConversionKind::PtrToBit) - } else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) { - Some(ConversionKind::Default) - } else { - None - }; - if let Some(conv_kind) = conv_kind { - if desc.is_dst { - dst_coercion = Some(get_conversion_dst( - id_def, - &mut desc.op, - id_type_from_instr, - id_actual_type, - conv_kind, - )); - Ok(desc.op) - } else { - Ok(insert_conversion_src( - func, - id_def, - desc.op, - id_actual_type, - id_type_from_instr, - conv_kind, - )) - } - } else { - Ok(desc.op) - } - })?; - func.push(instr); - if let Some(cond) = dst_coercion { - func.push(cond); - } - Ok(()) -} impl<'a> ast::MethodDecl<'a, &'a str> { fn name(&self) -> &'a str { match self {