From 8ee46c8fe121868fbd452c69e08bdfa705c60443 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 1 Aug 2020 00:51:18 +0200 Subject: [PATCH] Implement negation --- ptx/src/ast.rs | 10 ++++++-- ptx/src/ptx.lalrpop | 9 ++++--- ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/not.ptx | 22 +++++++++++++++++ ptx/src/test/spirv_run/not.spvtxt | 39 +++++++++++++++++++++++++++++++ ptx/src/translate.rs | 28 ++++++++++++++++++---- 6 files changed, 100 insertions(+), 9 deletions(-) create mode 100644 ptx/src/test/spirv_run/not.ptx create mode 100644 ptx/src/test/spirv_run/not.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index bbc5815..158ec8d 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -176,7 +176,7 @@ pub enum Instruction { Add(AddDetails, Arg3

), Setp(SetpData, Arg4

), SetpBool(SetpBoolData, Arg5

), - Not(NotData, Arg2

), + Not(NotType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtData, Arg2

), Shl(ShlData, Arg3

), @@ -386,7 +386,13 @@ pub struct SetpBoolData { pub bool_op: SetpBoolPostOp, } -pub struct NotData {} +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum NotType { + Pred, + B16, + B32, + B64, +} pub struct BraData { pub uniform: bool, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index af26765..d525fbe 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -557,11 +557,14 @@ SetpType: ast::ScalarType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not InstNot: ast::Instruction> = { - "not" NotType => ast::Instruction::Not(ast::NotData{}, a) + "not" => ast::Instruction::Not(t, a) }; -NotType = { - ".pred", ".b16", ".b32", ".b64" +NotType: ast::NotType = { + ".pred" => ast::NotType::Pred, + ".b16" => ast::NotType::B16, + ".b32" => ast::NotType::B32, + ".b64" => ast::NotType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index c90e487..b4414d9 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -46,6 +46,7 @@ test_ptx!(mul_hi, [u64::max_value()], [1u64]); test_ptx!(add, [1u64], [2u64]); test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); test_ptx!(bra, [10u64], [11u64]); +test_ptx!(not, [0u64], [u64::max_value()]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/not.ptx b/ptx/src/test/spirv_run/not.ptx new file mode 100644 index 0000000..6182134 --- /dev/null +++ b/ptx/src/test/spirv_run/not.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry not( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + not.b64 temp2, temp; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/not.spvtxt b/ptx/src/test/spirv_run/not.spvtxt new file mode 100644 index 0000000..518e995 --- /dev/null +++ b/ptx/src/test/spirv_run/not.spvtxt @@ -0,0 +1,39 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %5 "not" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %4 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_0 = OpTypeInt 64 0 + %5 = OpFunction %void None %4 + %6 = OpFunctionParameter %ulong + %7 = OpFunctionParameter %ulong + %20 = OpLabel + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_ulong Function + %10 = OpVariable %_ptr_Function_ulong Function + %11 = OpVariable %_ptr_Function_ulong Function + OpStore %8 %6 + OpStore %9 %7 + %13 = OpLoad %ulong %8 + %18 = OpConvertUToPtr %_ptr_Generic_ulong %13 + %12 = OpLoad %ulong %18 + OpStore %10 %12 + %15 = OpLoad %ulong_0 %10 + %14 = OpNot %ulong_0 %15 + OpStore %11 %14 + %16 = OpLoad %ulong %9 + %17 = OpLoad %ulong %11 + %19 = OpConvertUToPtr %_ptr_Generic_ulong %16 + OpStore %19 %17 + OpReturn + OpFunctionEnd + \ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c40e554..a6e627f 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -659,6 +659,15 @@ fn emit_function_body_ops( } emit_setp(builder, map, setp, arg)?; } + ast::Instruction::Not(t, a) => { + let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); + let result_id = Some(a.dst); + let operand = a.src; + match t { + ast::NotType::Pred => builder.logical_not(result_type, result_id, operand), + _ => builder.not(result_type, result_id, operand), + }?; + } _ => todo!(), }, Statement::LoadVar(arg, typ) => { @@ -887,9 +896,7 @@ fn expand_map_variables<'a>( s: ast::Statement>, ) { match s { - ast::Statement::Label(name) => { - result.push(ast::Statement::Label(id_defs.get_id(name))) - } + ast::Statement::Label(name) => result.push(ast::Statement::Label(id_defs.get_id(name))), ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction( p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))), i.map_variable(&mut |id| id_defs.get_id(id)), @@ -1128,7 +1135,9 @@ impl ast::Instruction { let inst_type = d.typ; ast::Instruction::SetpBool(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) } - ast::Instruction::Not(_, _) => todo!(), + ast::Instruction::Not(t, a) => { + ast::Instruction::Not(t, a.map(visitor, Some(t.to_type()))) + } ast::Instruction::Cvt(_, _) => todo!(), ast::Instruction::Shl(_, _) => todo!(), ast::Instruction::St(d, a) => { @@ -1513,6 +1522,17 @@ impl ast::ScalarType { } } +impl ast::NotType { + fn to_type(self) -> ast::Type { + match self { + ast::NotType::Pred => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred), + ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16), + ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32), + ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64), + } + } +} + impl ast::AddDetails { fn get_type(&self) -> ast::Type { match self {