Implement shr

This commit is contained in:
Andrzej Janik 2020-10-01 18:13:09 +02:00
parent 3e92921275
commit 96a342e33f
6 changed files with 128 additions and 0 deletions

View file

@ -339,6 +339,7 @@ pub enum Instruction<P: ArgParams> {
Cvt(CvtDetails, Arg2<P>),
Cvta(CvtaDetails, Arg2<P>),
Shl(ShlType, Arg3<P>),
Shr(ShrType, Arg3<P>),
St(StData, Arg2St<P>),
Ret(RetData),
Call(CallInst<P>),
@ -762,6 +763,18 @@ pub enum ShlType {
B64,
}
sub_scalar_type!(ShrType {
B16,
B32,
B64,
U16,
U32,
U64,
S16,
S32,
S64,
});
pub struct StData {
pub qualifier: LdStQualifier,
pub state_space: StStateSpace,

View file

@ -439,6 +439,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstBra,
InstCvt,
InstShl,
InstShr,
InstSt,
InstRet,
InstCvta,
@ -918,6 +919,23 @@ ShlType: ast::ShlType = {
".b64" => ast::ShlType::B64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr
InstShr: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shr" <t:ShrType> <a:Arg3> => ast::Instruction::Shr(t, a)
};
ShrType: ast::ShrType = {
".b16" => ast::ShrType::B16,
".b32" => ast::ShrType::B32,
".b64" => ast::ShrType::B64,
".u16" => ast::ShrType::U16,
".u32" => ast::ShrType::U32,
".u64" => ast::ShrType::U64,
".s16" => ast::ShrType::S16,
".s32" => ast::ShrType::S32,
".s64" => ast::ShrType::S64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {

View file

@ -68,6 +68,8 @@ test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -0,0 +1,21 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry shr(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.s32 temp, [in_addr];
shr.s32 temp, temp, 1;
st.s32 [out_addr], temp;
ret;
}

View file

@ -0,0 +1,50 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "shr"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%27 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%uint_1 = OpConstant %uint 1
%1 = OpFunction %void None %27
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
%22 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
OpStore %2 %7
OpStore %3 %8
%10 = OpLoad %ulong %2
%9 = OpCopyObject %ulong %10
OpStore %4 %9
%12 = OpLoad %ulong %3
%11 = OpCopyObject %ulong %12
OpStore %5 %11
%14 = OpLoad %ulong %4
%20 = OpConvertUToPtr %_ptr_Generic_uint %14
%13 = OpLoad %uint %20
OpStore %6 %13
%16 = OpLoad %uint %6
%15 = OpShiftRightArithmetic %uint %16 %uint_1
OpStore %6 %15
%17 = OpLoad %ulong %5
%18 = OpLoad %uint %6
%21 = OpConvertUToPtr %_ptr_Generic_uint %17
OpStore %21 %18
OpReturn
OpFunctionEnd

View file

@ -589,6 +589,9 @@ fn convert_to_typed_statements(
ast::Instruction::Mad(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast())))
}
ast::Instruction::Shr(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast())))
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@ -1555,6 +1558,14 @@ fn emit_function_body_ops(
let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?;
}
ast::Instruction::Shr(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
if t.signed() {
builder.shift_right_arithmetic(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.shift_right_logical(result_type, Some(a.dst), a.src1, a.src2)?;
}
}
ast::Instruction::Cvt(dets, arg) => {
emit_cvt(builder, map, dets, arg)?;
}
@ -2874,6 +2885,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())?)
}
ast::Instruction::Shr(t, a) => {
ast::Instruction::Shr(t, a.map_shift(visitor, ast::Type::Scalar(t.into()))?)
}
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
let is_param = d.state_space == ast::StStateSpace::Param
@ -3094,6 +3108,7 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Cvt(_, _)
| ast::Instruction::Cvta(_, _)
| ast::Instruction::Shl(_, _)
| ast::Instruction::Shr(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::Ret(_)
| ast::Instruction::Abs(_, _)
@ -4009,6 +4024,15 @@ impl ast::ShlType {
}
}
impl ast::ShrType {
fn signed(&self) -> bool {
match self {
ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
_ => false,
}
}
}
impl ast::AddDetails {
fn get_type(&self) -> ast::Type {
match self {