diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 097e19c..b509dfe 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -339,6 +339,7 @@ pub enum Instruction { Cvt(CvtDetails, Arg2

), Cvta(CvtaDetails, Arg2

), Shl(ShlType, Arg3

), + Shr(ShrType, Arg3

), St(StData, Arg2St

), Ret(RetData), Call(CallInst

), @@ -762,6 +763,18 @@ pub enum ShlType { B64, } +sub_scalar_type!(ShrType { + B16, + B32, + B64, + U16, + U32, + U64, + S16, + S32, + S64, +}); + pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index ba3fc2b..debdae7 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -439,6 +439,7 @@ Instruction: ast::Instruction> = { InstBra, InstCvt, InstShl, + InstShr, InstSt, InstRet, InstCvta, @@ -918,6 +919,23 @@ ShlType: ast::ShlType = { ".b64" => ast::ShlType::B64, }; +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr +InstShr: ast::Instruction> = { + "shr" => ast::Instruction::Shr(t, a) +}; + +ShrType: ast::ShrType = { + ".b16" => ast::ShrType::B16, + ".b32" => ast::ShrType::B32, + ".b64" => ast::ShrType::B64, + ".u16" => ast::ShrType::U16, + ".u32" => ast::ShrType::U32, + ".u64" => ast::ShrType::U64, + ".s16" => ast::ShrType::S16, + ".s32" => ast::ShrType::S32, + ".s64" => ast::ShrType::S64, +}; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction> = { diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 5a16755..6f516fd 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -68,6 +68,8 @@ test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]); test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]); test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]); test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); +test_ptx!(shr, [-2i32], [-1i32]); + struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/shr.ptx b/ptx/src/test/spirv_run/shr.ptx new file mode 100644 index 0000000..0a12fa7 --- /dev/null +++ b/ptx/src/test/spirv_run/shr.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry shr( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 temp, [in_addr]; + shr.s32 temp, temp, 1; + st.s32 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/shr.spvtxt b/ptx/src/test/spirv_run/shr.spvtxt new file mode 100644 index 0000000..417839d --- /dev/null +++ b/ptx/src/test/spirv_run/shr.spvtxt @@ -0,0 +1,50 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %24 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "shr" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %27 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %uint_1 = OpConstant %uint 1 + %1 = OpFunction %void None %27 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %22 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %2 %7 + OpStore %3 %8 + %10 = OpLoad %ulong %2 + %9 = OpCopyObject %ulong %10 + OpStore %4 %9 + %12 = OpLoad %ulong %3 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %4 + %20 = OpConvertUToPtr %_ptr_Generic_uint %14 + %13 = OpLoad %uint %20 + OpStore %6 %13 + %16 = OpLoad %uint %6 + %15 = OpShiftRightArithmetic %uint %16 %uint_1 + OpStore %6 %15 + %17 = OpLoad %ulong %5 + %18 = OpLoad %uint %6 + %21 = OpConvertUToPtr %_ptr_Generic_uint %17 + OpStore %21 %18 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 37cef00..fe6a7dc 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -589,6 +589,9 @@ fn convert_to_typed_statements( ast::Instruction::Mad(d, a) => { result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast()))) } + ast::Instruction::Shr(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast()))) + } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -1555,6 +1558,14 @@ fn emit_function_body_ops( let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?; } + ast::Instruction::Shr(t, a) => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); + if t.signed() { + builder.shift_right_arithmetic(result_type, Some(a.dst), a.src1, a.src2)?; + } else { + builder.shift_right_logical(result_type, Some(a.dst), a.src1, a.src2)?; + } + } ast::Instruction::Cvt(dets, arg) => { emit_cvt(builder, map, dets, arg)?; } @@ -2874,6 +2885,9 @@ impl ast::Instruction { ast::Instruction::Shl(t, a) => { ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())?) } + ast::Instruction::Shr(t, a) => { + ast::Instruction::Shr(t, a.map_shift(visitor, ast::Type::Scalar(t.into()))?) + } ast::Instruction::St(d, a) => { let inst_type = d.typ; let is_param = d.state_space == ast::StStateSpace::Param @@ -3094,6 +3108,7 @@ impl ast::Instruction { | ast::Instruction::Cvt(_, _) | ast::Instruction::Cvta(_, _) | ast::Instruction::Shl(_, _) + | ast::Instruction::Shr(_, _) | ast::Instruction::St(_, _) | ast::Instruction::Ret(_) | ast::Instruction::Abs(_, _) @@ -4009,6 +4024,15 @@ impl ast::ShlType { } } +impl ast::ShrType { + fn signed(&self) -> bool { + match self { + ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true, + _ => false, + } + } +} + impl ast::AddDetails { fn get_type(&self) -> ast::Type { match self {