From ca0d8ec666e499ec1a71132757acba407c3ba53b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 16 Sep 2021 01:25:09 +0200 Subject: [PATCH] Add missing vray instructions --- ptx/src/ast.rs | 3 + ptx/src/ptx.lalrpop | 43 +++++++- ptx/src/test/spirv_run/activemask.ptx | 18 ++++ ptx/src/test/spirv_run/activemask.spvtxt | 45 ++++++++ ptx/src/test/spirv_run/membar.ptx | 21 ++++ ptx/src/test/spirv_run/membar.spvtxt | 49 +++++++++ ptx/src/test/spirv_run/mod.rs | 3 + ptx/src/test/spirv_run/prmt.ptx | 23 ++++ ptx/src/test/spirv_run/prmt.spvtxt | 67 ++++++++++++ ptx/src/translate.rs | 131 ++++++++++++++++++++++- 10 files changed, 399 insertions(+), 4 deletions(-) create mode 100644 ptx/src/test/spirv_run/activemask.ptx create mode 100644 ptx/src/test/spirv_run/activemask.spvtxt create mode 100644 ptx/src/test/spirv_run/membar.ptx create mode 100644 ptx/src/test/spirv_run/membar.spvtxt create mode 100644 ptx/src/test/spirv_run/prmt.ptx create mode 100644 ptx/src/test/spirv_run/prmt.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 36e7191..a8309b0 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -287,6 +287,9 @@ pub enum Instruction { Bfe { typ: ScalarType, arg: Arg4

}, Bfi { typ: ScalarType, arg: Arg5

}, Rem { typ: ScalarType, arg: Arg3

}, + Prmt { control: u16, arg: Arg3

}, + Activemask { arg: Arg1

}, + Membar { level: MemScope }, } #[derive(Copy, Clone)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 0bc7655..fa3cfec 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -70,6 +70,7 @@ match { ".func", ".ge", ".geu", + ".gl", ".global", ".gpu", ".gt", @@ -142,6 +143,7 @@ match { } else { // IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID "abs", + "activemask", "add", "and", "atom", @@ -165,6 +167,7 @@ match { "mad", "map_f64_to_f32", "max", + "membar", "min", "mov", "mul", @@ -172,6 +175,7 @@ match { "not", "or", "popc", + "prmt", "rcp", "rem", "ret", @@ -196,6 +200,7 @@ match { ExtendedID : &'input str = { "abs", + "activemask", "add", "and", "atom", @@ -219,6 +224,7 @@ ExtendedID : &'input str = { "mad", "map_f64_to_f32", "max", + "membar", "min", "mov", "mul", @@ -226,6 +232,7 @@ ExtendedID : &'input str = { "not", "or", "popc", + "prmt", "rcp", "rem", "ret", @@ -292,6 +299,16 @@ U8Num: u8 = { } } +U16Num: u16 = { + =>? { + let (text, radix, _) = x; + match u16::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + } +} + U32Num: u32 = { =>? { let (text, radix, _) = x; @@ -761,6 +778,9 @@ Instruction: ast::Instruction> = { InstRem, InstBfe, InstBfi, + InstPrmt, + InstActivemask, + InstMembar, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -821,6 +841,12 @@ MemScope: ast::MemScope = { ".sys" => ast::MemScope::Sys }; +MembarLevel: ast::MemScope = { + ".cta" => ast::MemScope::Cta, + ".gl" => ast::MemScope::Gpu, + ".sys" => ast::MemScope::Sys +}; + LdNonGlobalStateSpace: ast::StateSpace = { ".const" => ast::StateSpace::Const, ".local" => ast::StateSpace::Local, @@ -1445,8 +1471,9 @@ SelpType: ast::ScalarType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar InstBar: ast::Instruction> = { + "bar" ".sync" => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a), + "barrier" ".sync" => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a), "barrier" ".sync" ".aligned" => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a), - "bar" ".sync" => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom @@ -1731,11 +1758,25 @@ InstBfi: ast::Instruction> = { "bfi" => ast::Instruction::Bfi{ <> } } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt +InstPrmt: ast::Instruction> = { + "prmt" ".b32" "," => ast::Instruction::Prmt{ <> } +} + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem InstRem: ast::Instruction> = { "rem" => ast::Instruction::Rem{ <> } } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-activemask +InstActivemask: ast::Instruction> = { + "activemask" ".b32" => ast::Instruction::Activemask{ <> } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar +InstMembar: ast::Instruction> = { + "membar" => ast::Instruction::Membar{ <> } +} NegTypeFtz: ast::ScalarType = { ".f16" => ast::ScalarType::F16, diff --git a/ptx/src/test/spirv_run/activemask.ptx b/ptx/src/test/spirv_run/activemask.ptx new file mode 100644 index 0000000..c352bb2 --- /dev/null +++ b/ptx/src/test/spirv_run/activemask.ptx @@ -0,0 +1,18 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry activemask( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 out_addr; + .reg .b32 temp; + + ld.param.u64 out_addr, [output]; + + activemask.b32 temp; + st.u32 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/activemask.spvtxt b/ptx/src/test/spirv_run/activemask.spvtxt new file mode 100644 index 0000000..c4ad55d --- /dev/null +++ b/ptx/src/test/spirv_run/activemask.spvtxt @@ -0,0 +1,45 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %16 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "activemask" + OpExecutionMode %1 ContractionOff + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %19 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint + %v4uint = OpTypeVector %uint 4 + %bool = OpTypeBool + %true = OpConstantTrue %bool +%_ptr_Generic_uint = OpTypePointer Generic %uint + %1 = OpFunction %void None %19 + %6 = OpFunctionParameter %ulong + %7 = OpFunctionParameter %ulong + %14 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_uint Function + OpStore %2 %6 + OpStore %3 %7 + %8 = OpLoad %ulong %3 Aligned 8 + OpStore %4 %8 + %26 = OpSubgroupBallotKHR %v4uint %true + %9 = OpCompositeExtract %uint %26 0 + OpStore %5 %9 + %10 = OpLoad %ulong %4 + %11 = OpLoad %uint %5 + %12 = OpConvertUToPtr %_ptr_Generic_uint %10 + %13 = OpCopyObject %uint %11 + OpStore %12 %13 Aligned 4 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/membar.ptx b/ptx/src/test/spirv_run/membar.ptx new file mode 100644 index 0000000..01aa9f2 --- /dev/null +++ b/ptx/src/test/spirv_run/membar.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry membar( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp, [in_addr]; + membar.sys; + st.s32 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/membar.spvtxt b/ptx/src/test/spirv_run/membar.spvtxt new file mode 100644 index 0000000..d808cf3 --- /dev/null +++ b/ptx/src/test/spirv_run/membar.spvtxt @@ -0,0 +1,49 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %20 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "membar" + OpExecutionMode %1 ContractionOff + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %23 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %uint_0 = OpConstant %uint 0 + %uint_784 = OpConstant %uint 784 + %1 = OpFunction %void None %23 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %18 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %2 %7 + OpStore %3 %8 + %9 = OpLoad %ulong %2 Aligned 8 + OpStore %4 %9 + %10 = OpLoad %ulong %3 Aligned 8 + OpStore %5 %10 + %12 = OpLoad %ulong %4 + %16 = OpConvertUToPtr %_ptr_Generic_uint %12 + %15 = OpLoad %uint %16 Aligned 4 + %11 = OpCopyObject %uint %15 + OpStore %6 %11 + OpMemoryBarrier %uint_0 %uint_784 + %13 = OpLoad %ulong %5 + %14 = OpLoad %uint %6 + %17 = OpConvertUToPtr %_ptr_Generic_uint %13 + OpStore %17 %14 Aligned 4 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 97cfbb5..f6b556e 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -206,6 +206,9 @@ test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]); test_ptx!(const, [0u16], [10u16, 20, 30, 40]); test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]); test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]); +test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); +test_ptx!(activemask, [0u32], [1u32]); +test_ptx!(membar, [152731u32], [152731u32]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/prmt.ptx b/ptx/src/test/spirv_run/prmt.ptx new file mode 100644 index 0000000..ba339e8 --- /dev/null +++ b/ptx/src/test/spirv_run/prmt.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry prmt( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp1; + .reg .u32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp1, [in_addr]; + ld.u32 temp2, [in_addr+4]; + prmt.b32 temp2, temp1, temp2, 30212; + st.u32 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/prmt.spvtxt b/ptx/src/test/spirv_run/prmt.spvtxt new file mode 100644 index 0000000..060f534 --- /dev/null +++ b/ptx/src/test/spirv_run/prmt.spvtxt @@ -0,0 +1,67 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %31 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "prmt" + OpExecutionMode %1 ContractionOff + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %34 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar + %v4uchar = OpTypeVector %uchar 4 + %1 = OpFunction %void None %34 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %29 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %10 = OpLoad %ulong %2 Aligned 8 + OpStore %4 %10 + %11 = OpLoad %ulong %3 Aligned 8 + OpStore %5 %11 + %13 = OpLoad %ulong %4 + %23 = OpConvertUToPtr %_ptr_Generic_uint %13 + %12 = OpLoad %uint %23 Aligned 4 + OpStore %6 %12 + %15 = OpLoad %ulong %4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %41 = OpBitcast %_ptr_Generic_uchar %24 + %42 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %41 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %42 + %14 = OpLoad %uint %22 Aligned 4 + OpStore %7 %14 + %17 = OpLoad %uint %6 + %18 = OpLoad %uint %7 + %26 = OpCopyObject %uint %17 + %27 = OpCopyObject %uint %18 + %44 = OpBitcast %v4uchar %26 + %45 = OpBitcast %v4uchar %27 + %46 = OpVectorShuffle %v4uchar %44 %45 4 0 6 7 + %25 = OpBitcast %uint %46 + %16 = OpCopyObject %uint %25 + OpStore %7 %16 + %19 = OpLoad %ulong %5 + %20 = OpLoad %uint %7 + %28 = OpConvertUToPtr %_ptr_Generic_uint %19 + OpStore %28 %20 Aligned 4 + OpReturn + OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a41179d..e015062 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2992,6 +2992,76 @@ fn emit_function_body_ops<'input>( let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; } + ast::Instruction::Prmt { control, arg } => { + let control = *control as u32; + let components = [ + (control >> 0) & 0b1111, + (control >> 4) & 0b1111, + (control >> 8) & 0b1111, + (control >> 12) & 0b1111, + ]; + if components.iter().any(|&c| c > 7) { + return Err(TranslateError::Todo); + } + let vec4_b8_type = + map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4)); + let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); + let src1_vector = builder.bitcast(vec4_b8_type, None, arg.src1)?; + let src2_vector = builder.bitcast(vec4_b8_type, None, arg.src2)?; + let dst_vector = builder.vector_shuffle( + vec4_b8_type, + None, + src1_vector, + src2_vector, + components, + )?; + builder.bitcast(b32_type, Some(arg.dst), dst_vector)?; + } + ast::Instruction::Activemask { arg } => { + let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); + let vec4_b32_type = + map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B32, 4)); + let pred_true = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::Pred), + &[1], + )?; + let dst_vector = builder.subgroup_ballot_khr(vec4_b32_type, None, pred_true)?; + builder.composite_extract(b32_type, Some(arg.src), dst_vector, [0])?; + } + ast::Instruction::Membar { level } => { + let (scope, semantics) = match level { + ast::MemScope::Cta => ( + spirv::Scope::Workgroup, + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + ast::MemScope::Gpu => ( + spirv::Scope::Device, + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + ast::MemScope::Sys => ( + spirv::Scope::CrossDevice, + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + }; + let spirv_scope = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope as u32), + )?; + let spirv_semantics = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics), + )?; + builder.memory_barrier(spirv_scope, spirv_semantics)?; + } }, Statement::LoadVar(details) => { emit_load_var(builder, map, details)?; @@ -4172,7 +4242,6 @@ fn normalize_identifiers<'input, 'b>( match s { ast::Statement::Label(id) => { id_defs.add_def(*id, None, false); - eprintln!("{}", id); } _ => (), } @@ -5800,7 +5869,7 @@ impl ast::Instruction { let new_args = a.map(visitor, &d)?; ast::Instruction::St(d, new_args) } - ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?), + ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, false, None)?), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), ast::Instruction::Cvta(d, a) => { let inst_type = ast::Type::Scalar(ast::ScalarType::B64); @@ -5942,6 +6011,21 @@ impl ast::Instruction { arg: arg.map_non_shift(visitor, &full_type, false)?, } } + ast::Instruction::Prmt { control, arg } => ast::Instruction::Prmt { + control, + arg: arg.map_prmt(visitor)?, + }, + ast::Instruction::Activemask { arg } => ast::Instruction::Activemask { + arg: arg.map( + visitor, + true, + Some(( + &ast::Type::Scalar(ast::ScalarType::B32), + ast::StateSpace::Reg, + )), + )?, + }, + ast::Instruction::Membar { level } => ast::Instruction::Membar { level }, }) } } @@ -6202,6 +6286,9 @@ impl ast::Instruction { ast::Instruction::Bfe { .. } => None, ast::Instruction::Bfi { .. } => None, ast::Instruction::Rem { .. } => None, + ast::Instruction::Prmt { .. } => None, + ast::Instruction::Activemask { .. } => None, + ast::Instruction::Membar { .. } => None, ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _) | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _) | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _) @@ -6339,12 +6426,13 @@ impl ast::Arg1 { fn map>( self, visitor: &mut V, + is_dst: bool, t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result, TranslateError> { let new_src = visitor.id( ArgumentDescriptor { op: self.src, - is_dst: false, + is_dst, is_memory_access: false, non_default_implicit_conversion: None, }, @@ -6685,6 +6773,43 @@ impl ast::Arg3 { )?; Ok(ast::Arg3 { dst, src1, src2 }) } + + fn map_prmt>( + self, + visitor: &mut V, + ) -> Result, TranslateError> { + let dst = visitor.operand( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + is_memory_access: false, + non_default_implicit_conversion: None, + }, + &ast::Type::Scalar(ast::ScalarType::B32), + ast::StateSpace::Reg, + )?; + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + is_memory_access: false, + non_default_implicit_conversion: None, + }, + &ast::Type::Scalar(ast::ScalarType::B32), + ast::StateSpace::Reg, + )?; + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + is_memory_access: false, + non_default_implicit_conversion: None, + }, + &ast::Type::Scalar(ast::ScalarType::B32), + ast::StateSpace::Reg, + )?; + Ok(ast::Arg3 { dst, src1, src2 }) + } } impl ast::Arg4 {