From d7bf1acf84faa8f6cb1d5edb6c4d9eb0f05a5ae0 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 5 Nov 2020 22:10:06 +0100 Subject: [PATCH] Implement instructions clz, brev, popc --- ptx/src/ast.rs | 3 ++ ptx/src/ptx.lalrpop | 34 ++++++++++++--- ptx/src/test/spirv_run/brev.ptx | 21 ++++++++++ ptx/src/test/spirv_run/brev.spvtxt | 47 +++++++++++++++++++++ ptx/src/test/spirv_run/clz.ptx | 21 ++++++++++ ptx/src/test/spirv_run/clz.spvtxt | 47 +++++++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 9 +++- ptx/src/test/spirv_run/popc.ptx | 21 ++++++++++ ptx/src/test/spirv_run/popc.spvtxt | 47 +++++++++++++++++++++ ptx/src/translate.rs | 66 +++++++++++++++++++++++++++++- 10 files changed, 308 insertions(+), 8 deletions(-) create mode 100644 ptx/src/test/spirv_run/brev.ptx create mode 100644 ptx/src/test/spirv_run/brev.spvtxt create mode 100644 ptx/src/test/spirv_run/clz.ptx create mode 100644 ptx/src/test/spirv_run/clz.spvtxt create mode 100644 ptx/src/test/spirv_run/popc.ptx create mode 100644 ptx/src/test/spirv_run/popc.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 653060b..b6ac3db 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -587,6 +587,9 @@ pub enum Instruction { Cos { flush_to_zero: bool, arg: Arg2

}, Lg2 { flush_to_zero: bool, arg: Arg2

}, Ex2 { flush_to_zero: bool, arg: Arg2

}, + Clz { typ: BitType, arg: Arg2

}, + Brev { typ: BitType, arg: Arg2

}, + Popc { typ: BitType, arg: Arg2

}, } #[derive(Copy, Clone)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 31c2356..cd1c642 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -143,7 +143,9 @@ match { "bar", "barrier", "bra", + "brev", "call", + "clz", "cos", "cvt", "cvta", @@ -162,6 +164,7 @@ match { "neg", "not", "or", + "popc", "rcp", "ret", "rsqrt", @@ -190,7 +193,9 @@ ExtendedID : &'input str = { "bar", "barrier", "bra", + "brev", "call", + "clz", "cos", "cvt", "cvta", @@ -209,6 +214,7 @@ ExtendedID : &'input str = { "neg", "not", "or", + "popc", "rcp", "ret", "rsqrt", @@ -699,6 +705,9 @@ Instruction: ast::Instruction> = { InstCos, InstLg2, InstEx2, + InstClz, + InstBrev, + InstPopc, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -1395,7 +1404,7 @@ InstBar: ast::Instruction> = { // * Operation .dec requires .u32 type for instuction // Otherwise as documented InstAtom: ast::Instruction> = { - "atom" => { + "atom" => { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), @@ -1459,7 +1468,7 @@ InstAtom: ast::Instruction> = { } InstAtomCas: ast::Instruction> = { - "atom" ".cas" => { + "atom" ".cas" => { let details = ast::AtomCasDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), @@ -1501,7 +1510,7 @@ AtomSIntOp: ast::AtomSIntOp = { ".max" => ast::AtomSIntOp::Max, } -AtomBitType: ast::BitType = { +BitType: ast::BitType = { ".b32" => ast::BitType::B32, ".b64" => ast::BitType::B64, } @@ -1640,6 +1649,21 @@ InstEx2: ast::Instruction> = { }, } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-clz +InstClz: ast::Instruction> = { + "clz" => ast::Instruction::Clz{ <> } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-brev +InstBrev: ast::Instruction> = { + "brev" => ast::Instruction::Brev{ <> } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-popc +InstPopc: ast::Instruction> = { + "popc" => ast::Instruction::Popc{ <> } +} + NegTypeFtz: ast::ScalarType = { ".f16" => ast::ScalarType::F16, ".f16x2" => ast::ScalarType::F16x2, @@ -1858,7 +1882,7 @@ Section = { }; SectionDwarfLines: () = { - BitType Comma, + AnyBitType Comma, ".b32" SectionLabel, ".b64" SectionLabel, ".b32" SectionLabel "+" U32Num, @@ -1870,7 +1894,7 @@ SectionLabel = { DotID }; -BitType = { +AnyBitType = { ".b8", ".b16", ".b32", ".b64" }; diff --git a/ptx/src/test/spirv_run/brev.ptx b/ptx/src/test/spirv_run/brev.ptx new file mode 100644 index 0000000..1d9dd75 --- /dev/null +++ b/ptx/src/test/spirv_run/brev.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry brev( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 temp, [in_addr]; + brev.b32 temp, temp; + st.b32 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/brev.spvtxt b/ptx/src/test/spirv_run/brev.spvtxt new file mode 100644 index 0000000..df5df53 --- /dev/null +++ b/ptx/src/test/spirv_run/brev.spvtxt @@ -0,0 +1,47 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %21 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "brev" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %24 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %1 = OpFunction %void None %24 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %19 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %2 %7 + OpStore %3 %8 + %9 = OpLoad %ulong %2 + OpStore %4 %9 + %10 = OpLoad %ulong %3 + OpStore %5 %10 + %12 = OpLoad %ulong %4 + %17 = OpConvertUToPtr %_ptr_Generic_uint %12 + %11 = OpLoad %uint %17 + OpStore %6 %11 + %14 = OpLoad %uint %6 + %13 = OpBitReverse %uint %14 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %16 = OpLoad %uint %6 + %18 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %18 %16 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/clz.ptx b/ptx/src/test/spirv_run/clz.ptx new file mode 100644 index 0000000..b475b90 --- /dev/null +++ b/ptx/src/test/spirv_run/clz.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry clz( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 temp, [in_addr]; + clz.b32 temp, temp; + st.b32 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/clz.spvtxt b/ptx/src/test/spirv_run/clz.spvtxt new file mode 100644 index 0000000..5d1ebc8 --- /dev/null +++ b/ptx/src/test/spirv_run/clz.spvtxt @@ -0,0 +1,47 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %21 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "clz" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %24 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %1 = OpFunction %void None %24 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %19 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %2 %7 + OpStore %3 %8 + %9 = OpLoad %ulong %2 + OpStore %4 %9 + %10 = OpLoad %ulong %3 + OpStore %5 %10 + %12 = OpLoad %ulong %4 + %17 = OpConvertUToPtr %_ptr_Generic_uint %12 + %11 = OpLoad %uint %17 + OpStore %6 %11 + %14 = OpLoad %uint %6 + %13 = OpExtInst %uint %21 clz %14 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %16 = OpLoad %uint %6 + %18 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %18 %16 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 163caac..a7ef75b 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -104,11 +104,18 @@ test_ptx!(div_approx, [1f32, 2f32], [0.5f32]); test_ptx!(sqrt, [0.25f32], [0.5f32]); test_ptx!(rsqrt, [0.25f64], [2f64]); test_ptx!(neg, [181i32], [-181i32]); -test_ptx!(sin, [std::f32::consts::PI/2f32], [1f32]); +test_ptx!(sin, [std::f32::consts::PI / 2f32], [1f32]); test_ptx!(cos, [std::f32::consts::PI], [-1f32]); test_ptx!(lg2, [512f32], [9f32]); test_ptx!(ex2, [10f32], [1024f32]); test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]); +test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]); +test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]); +test_ptx!( + brev, + [0b11000111_01011100_10101110_11111011u32], + [0b11011111_01110101_00111010_11100011u32] +); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/popc.ptx b/ptx/src/test/spirv_run/popc.ptx new file mode 100644 index 0000000..7106422 --- /dev/null +++ b/ptx/src/test/spirv_run/popc.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry popc( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 temp, [in_addr]; + popc.b32 temp, temp; + st.b32 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/popc.spvtxt b/ptx/src/test/spirv_run/popc.spvtxt new file mode 100644 index 0000000..bb4968f --- /dev/null +++ b/ptx/src/test/spirv_run/popc.spvtxt @@ -0,0 +1,47 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %21 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "popc" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %24 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %1 = OpFunction %void None %24 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %19 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %2 %7 + OpStore %3 %8 + %9 = OpLoad %ulong %2 + OpStore %4 %9 + %10 = OpLoad %ulong %3 + OpStore %5 %10 + %12 = OpLoad %ulong %4 + %17 = OpConvertUToPtr %_ptr_Generic_uint %12 + %11 = OpLoad %uint %17 + OpStore %6 %11 + %14 = OpLoad %uint %6 + %13 = OpBitCount %uint %14 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %16 = OpLoad %uint %6 + %18 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %18 %16 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 9519951..23a63be 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1573,6 +1573,24 @@ fn convert_to_typed_statements( arg: arg.cast(), })) } + ast::Instruction::Clz { typ, arg } => { + result.push(Statement::Instruction(ast::Instruction::Clz { + typ, + arg: arg.cast(), + })) + } + ast::Instruction::Brev { typ, arg } => { + result.push(Statement::Instruction(ast::Instruction::Brev { + typ, + arg: arg.cast(), + })) + } + ast::Instruction::Popc { typ, arg } => { + result.push(Statement::Instruction(ast::Instruction::Popc { + typ, + arg: arg.cast(), + })) + } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -2997,6 +3015,24 @@ fn emit_function_body_ops( [arg.src], )?; } + ast::Instruction::Clz { typ, arg } => { + let result_type = map.get_or_add_scalar(builder, (*typ).into()); + builder.ext_inst( + result_type, + Some(arg.dst), + opencl, + spirv::CLOp::clz as u32, + [arg.src], + )?; + } + ast::Instruction::Brev { typ, arg } => { + let result_type = map.get_or_add_scalar(builder, (*typ).into()); + builder.bit_reverse(result_type, Some(arg.dst), arg.src)?; + } + ast::Instruction::Popc { typ, arg } => { + let result_type = map.get_or_add_scalar(builder, (*typ).into()); + builder.bit_count(result_type, Some(arg.dst), arg.src)?; + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); @@ -4881,7 +4917,7 @@ impl ast::Instruction { ast::Type::Scalar(desc.src.into()), ), }; - ast::Instruction::Cvt(d, a.map_cvt(visitor, &dst_t, &src_t)?) + ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?) } ast::Instruction::Shl(t, a) => { ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?) @@ -4980,6 +5016,29 @@ impl ast::Instruction { arg: arg.map(visitor, &typ)?, } } + ast::Instruction::Clz { typ, arg } => { + let dst_type = ast::Type::Scalar(ast::ScalarType::B32); + let src_type = ast::Type::Scalar(typ.into()); + ast::Instruction::Clz { + typ, + arg: arg.map_different_types(visitor, &dst_type, &src_type)?, + } + } + ast::Instruction::Brev { typ, arg } => { + let full_type = ast::Type::Scalar(typ.into()); + ast::Instruction::Brev { + typ, + arg: arg.map(visitor, &full_type)?, + } + } + ast::Instruction::Popc { typ, arg } => { + let dst_type = ast::Type::Scalar(ast::ScalarType::B32); + let src_type = ast::Type::Scalar(typ.into()); + ast::Instruction::Popc { + typ, + arg: arg.map_different_types(visitor, &dst_type, &src_type)?, + } + } }) } } @@ -5289,6 +5348,9 @@ impl ast::Instruction { ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None, ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None, ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None, + ast::Instruction::Clz { .. } => None, + ast::Instruction::Brev { .. } => None, + ast::Instruction::Popc { .. } => None, ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _) | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _) | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _) @@ -5567,7 +5629,7 @@ impl ast::Arg2 { }) } - fn map_cvt>( + fn map_different_types>( self, visitor: &mut V, dst_t: &ast::Type,