Add missing vray instructions

This commit is contained in:
Andrzej Janik 2021-09-16 01:25:09 +02:00
parent 467782b1d0
commit ca0d8ec666
10 changed files with 399 additions and 4 deletions

View file

@ -287,6 +287,9 @@ pub enum Instruction<P: ArgParams> {
Bfe { typ: ScalarType, arg: Arg4<P> },
Bfi { typ: ScalarType, arg: Arg5<P> },
Rem { typ: ScalarType, arg: Arg3<P> },
Prmt { control: u16, arg: Arg3<P> },
Activemask { arg: Arg1<P> },
Membar { level: MemScope },
}
#[derive(Copy, Clone)]

View file

@ -70,6 +70,7 @@ match {
".func",
".ge",
".geu",
".gl",
".global",
".gpu",
".gt",
@ -142,6 +143,7 @@ match {
} else {
// IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID
"abs",
"activemask",
"add",
"and",
"atom",
@ -165,6 +167,7 @@ match {
"mad",
"map_f64_to_f32",
"max",
"membar",
"min",
"mov",
"mul",
@ -172,6 +175,7 @@ match {
"not",
"or",
"popc",
"prmt",
"rcp",
"rem",
"ret",
@ -196,6 +200,7 @@ match {
ExtendedID : &'input str = {
"abs",
"activemask",
"add",
"and",
"atom",
@ -219,6 +224,7 @@ ExtendedID : &'input str = {
"mad",
"map_f64_to_f32",
"max",
"membar",
"min",
"mov",
"mul",
@ -226,6 +232,7 @@ ExtendedID : &'input str = {
"not",
"or",
"popc",
"prmt",
"rcp",
"rem",
"ret",
@ -292,6 +299,16 @@ U8Num: u8 = {
}
}
// Converts a numeric literal token into a `u16`, using the radix carried by
// the token. A conversion failure is reported as a user parse error wrapping
// `ast::PtxError`.
U16Num: u16 = {
    <token:NumToken> =>? {
        let (digits, base, _) = token;
        u16::from_str_radix(digits, base)
            .map_err(|cause| ParseError::User { error: ast::PtxError::from(cause) })
    }
}
U32Num: u32 = {
<x:NumToken> =>? {
let (text, radix, _) = x;
@ -761,6 +778,9 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstRem,
InstBfe,
InstBfi,
InstPrmt,
InstActivemask,
InstMembar,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@ -821,6 +841,12 @@ MemScope: ast::MemScope = {
".sys" => ast::MemScope::Sys
};
// Level qualifier accepted after `membar`. The three tokens `.cta`, `.gl` and
// `.sys` are mapped onto the `Cta`, `Gpu` and `Sys` variants of
// `ast::MemScope` respectively (note that `.gl` becomes `Gpu`).
MembarLevel: ast::MemScope = {
".cta" => ast::MemScope::Cta,
".gl" => ast::MemScope::Gpu,
".sys" => ast::MemScope::Sys
};
LdNonGlobalStateSpace: ast::StateSpace = {
".const" => ast::StateSpace::Const,
".local" => ast::StateSpace::Local,
@ -1445,8 +1471,9 @@ SelpType: ast::ScalarType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar
InstBar: ast::Instruction<ast::ParsedArgParams<'input>> = {
"bar" ".sync" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a),
"barrier" ".sync" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a),
"barrier" ".sync" ".aligned" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a),
"bar" ".sync" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
@ -1731,11 +1758,25 @@ InstBfi: ast::Instruction<ast::ParsedArgParams<'input>> = {
"bfi" <typ:BitType> <arg:Arg5> => ast::Instruction::Bfi{ <> }
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
// Matches `prmt.b32` followed by an `Arg3` operand group, a comma and a
// numeric literal read through `U16Num`. Only the `.b32` spelling with a
// literal control value is matched by this rule; the literal is stored in the
// `control` field of `ast::Instruction::Prmt`.
InstPrmt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"prmt" ".b32" <arg:Arg3> "," <control:U16Num> => ast::Instruction::Prmt{ <> }
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem
// Matches `rem` followed by a type taken from `IntType` and an `Arg3` operand
// group, producing `ast::Instruction::Rem` with the `typ` and `arg` fields.
InstRem: ast::Instruction<ast::ParsedArgParams<'input>> = {
"rem" <typ:IntType> <arg:Arg3> => ast::Instruction::Rem{ <> }
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-activemask
// Matches `activemask.b32` followed by a single `Arg1` operand, producing
// `ast::Instruction::Activemask`.
InstActivemask: ast::Instruction<ast::ParsedArgParams<'input>> = {
"activemask" ".b32" <arg:Arg1> => ast::Instruction::Activemask{ <> }
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar
// Matches `membar` followed by a `MembarLevel` qualifier; the instruction
// takes no operands and only records the parsed level.
InstMembar: ast::Instruction<ast::ParsedArgParams<'input>> = {
"membar" <level:MembarLevel> => ast::Instruction::Membar{ <> }
}
NegTypeFtz: ast::ScalarType = {
".f16" => ast::ScalarType::F16,

View file

@ -0,0 +1,18 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel for `activemask`: writes the value produced by `activemask.b32`
// to the address passed in `output`. The `input` parameter is declared but
// never read by this kernel.
.visible .entry activemask(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .b32 temp;
ld.param.u64 out_addr, [output];
// instruction under test
activemask.b32 temp;
st.u32 [out_addr], temp;
ret;
}

View file

@ -0,0 +1,45 @@
; Expected SPIR-V for the `activemask` PTX test kernel.
; NOTE(review): the instruction stream below uses OpSubgroupBallotKHR, but no
; OpCapability/OpExtension line for it is declared in this module - confirm
; whether the consumer of this fixture requires one.
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%16 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "activemask"
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%19 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%v4uint = OpTypeVector %uint 4
%bool = OpTypeBool
%true = OpConstantTrue %bool
%_ptr_Generic_uint = OpTypePointer Generic %uint
%1 = OpFunction %void None %19
%6 = OpFunctionParameter %ulong
%7 = OpFunctionParameter %ulong
%14 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_uint Function
OpStore %2 %6
OpStore %3 %7
%8 = OpLoad %ulong %3 Aligned 8
OpStore %4 %8
; activemask.b32: ballot over a constant-true predicate, then component 0 of
; the 4 x uint result is taken as the 32-bit mask
%26 = OpSubgroupBallotKHR %v4uint %true
%9 = OpCompositeExtract %uint %26 0
OpStore %5 %9
%10 = OpLoad %ulong %4
%11 = OpLoad %uint %5
%12 = OpConvertUToPtr %_ptr_Generic_uint %10
%13 = OpCopyObject %uint %11
OpStore %12 %13 Aligned 4
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,21 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel for `membar`: loads one 32-bit value from the address in
// `input`, executes `membar.sys`, then stores the same value to the address
// in `output`.
.visible .entry membar(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u32 temp, [in_addr];
// instruction under test
membar.sys;
st.s32 [out_addr], temp;
ret;
}

View file

@ -0,0 +1,49 @@
; Expected SPIR-V for the `membar` PTX test kernel.
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%20 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "membar"
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%23 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
; operands of the barrier below: scope constant 0 and semantics constant 784 (0x310)
%uint_0 = OpConstant %uint 0
%uint_784 = OpConstant %uint 784
%1 = OpFunction %void None %23
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
%18 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
OpStore %2 %7
OpStore %3 %8
%9 = OpLoad %ulong %2 Aligned 8
OpStore %4 %9
%10 = OpLoad %ulong %3 Aligned 8
OpStore %5 %10
%12 = OpLoad %ulong %4
%16 = OpConvertUToPtr %_ptr_Generic_uint %12
%15 = OpLoad %uint %16 Aligned 4
%11 = OpCopyObject %uint %15
OpStore %6 %11
; membar.sys: emitted between the load above and the store below
OpMemoryBarrier %uint_0 %uint_784
%13 = OpLoad %ulong %5
%14 = OpLoad %uint %6
%17 = OpConvertUToPtr %_ptr_Generic_uint %13
OpStore %17 %14 Aligned 4
OpReturn
OpFunctionEnd

View file

@ -206,6 +206,9 @@ test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]);
test_ptx!(const, [0u16], [10u16, 20, 30, 40]);
test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]);
test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
// NOTE(review): the arguments appear to be (kernel name, input buffer,
// expected output buffer) - confirm against the `test_ptx!` definition.
// prmt: the kernel uses selector 30212 (0x7604) on the two input words.
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
// activemask: the kernel ignores its input and stores the mask it reads.
test_ptx!(activemask, [0u32], [1u32]);
// membar: the kernel copies one word across a `membar.sys`.
test_ptx!(membar, [152731u32], [152731u32]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel for `prmt`: loads two consecutive 32-bit words from `input`,
// permutes their bytes with a literal selector and stores the result to
// `output`.
.visible .entry prmt(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u32 temp1;
.reg .u32 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u32 temp1, [in_addr];
ld.u32 temp2, [in_addr+4];
// instruction under test; 30212 == 0x7604
prmt.b32 temp2, temp1, temp2, 30212;
st.u32 [out_addr], temp2;
ret;
}

View file

@ -0,0 +1,67 @@
; Expected SPIR-V for the `prmt` PTX test kernel.
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%31 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "prmt"
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%34 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
%uchar = OpTypeInt 8 0
%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%v4uchar = OpTypeVector %uchar 4
%1 = OpFunction %void None %34
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%29 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %8
OpStore %3 %9
%10 = OpLoad %ulong %2 Aligned 8
OpStore %4 %10
%11 = OpLoad %ulong %3 Aligned 8
OpStore %5 %11
%13 = OpLoad %ulong %4
%23 = OpConvertUToPtr %_ptr_Generic_uint %13
%12 = OpLoad %uint %23 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
%24 = OpConvertUToPtr %_ptr_Generic_uint %15
%41 = OpBitcast %_ptr_Generic_uchar %24
%42 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %41 %ulong_4
%22 = OpBitcast %_ptr_Generic_uint %42
%14 = OpLoad %uint %22 Aligned 4
OpStore %7 %14
%17 = OpLoad %uint %6
%18 = OpLoad %uint %7
%26 = OpCopyObject %uint %17
%27 = OpCopyObject %uint %18
; prmt.b32: both words are reinterpreted as 4 x uchar and shuffled; the
; component list 4 0 6 7 is the nibbles of selector 0x7604, lowest nibble first
%44 = OpBitcast %v4uchar %26
%45 = OpBitcast %v4uchar %27
%46 = OpVectorShuffle %v4uchar %44 %45 4 0 6 7
%25 = OpBitcast %uint %46
%16 = OpCopyObject %uint %25
OpStore %7 %16
%19 = OpLoad %ulong %5
%20 = OpLoad %uint %7
%28 = OpConvertUToPtr %_ptr_Generic_uint %19
OpStore %28 %20 Aligned 4
OpReturn
OpFunctionEnd

View file

@ -2992,6 +2992,76 @@ fn emit_function_body_ops<'input>(
let result_type = map.get_or_add_scalar(builder, (*typ).into());
builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
}
ast::Instruction::Prmt { control, arg } => {
    // Split the 16-bit selector into four 4-bit fields, lowest nibble first;
    // they are passed on as the component list of the vector shuffle below.
    let selector = u32::from(*control);
    let nibble = |shift: u32| (selector >> shift) & 0xF;
    let components = [nibble(0), nibble(4), nibble(8), nibble(12)];
    // Only field values 0..=7 are translated; anything larger is reported
    // as not yet supported.
    if !components.iter().all(|&index| index <= 7) {
        return Err(TranslateError::Todo);
    }
    let byte_vector_type = map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4));
    let word_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
    // View both 32-bit sources as vectors of four bytes, shuffle them, and
    // reinterpret the shuffled vector as the 32-bit destination.
    let first_bytes = builder.bitcast(byte_vector_type, None, arg.src1)?;
    let second_bytes = builder.bitcast(byte_vector_type, None, arg.src2)?;
    let shuffled = builder.vector_shuffle(
        byte_vector_type,
        None,
        first_bytes,
        second_bytes,
        components,
    )?;
    builder.bitcast(word_type, Some(arg.dst), shuffled)?;
}
ast::Instruction::Activemask { arg } => {
    // The mask is produced by a subgroup ballot over a constant-true
    // predicate; only component 0 of the 4 x b32 ballot result is kept.
    let word_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
    let ballot_type = map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B32, 4));
    let always_true = map.get_or_add_constant(
        builder,
        &ast::Type::Scalar(ast::ScalarType::Pred),
        &[1],
    )?;
    let ballot = builder.subgroup_ballot_khr(ballot_type, None, always_true)?;
    // The instruction's single operand (`arg.src`) is used as the result id.
    builder.composite_extract(word_type, Some(arg.src), ballot, [0])?;
}
ast::Instruction::Membar { level } => {
    // Only the barrier scope depends on the membar level.
    let scope = match level {
        ast::MemScope::Cta => spirv::Scope::Workgroup,
        ast::MemScope::Gpu => spirv::Scope::Device,
        ast::MemScope::Sys => spirv::Scope::CrossDevice,
    };
    // Every level uses the same semantics mask.
    let semantics = spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
        | spirv::MemorySemantics::WORKGROUP_MEMORY
        | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT;
    // Both operands are passed to the barrier as u32 constants; the scope
    // constant is registered before the semantics constant.
    let constant_type = ast::Type::Scalar(ast::ScalarType::U32);
    let spirv_scope =
        map.get_or_add_constant(builder, &constant_type, &vec_repr(scope as u32))?;
    let spirv_semantics =
        map.get_or_add_constant(builder, &constant_type, &vec_repr(semantics))?;
    builder.memory_barrier(spirv_scope, spirv_semantics)?;
}
},
Statement::LoadVar(details) => {
emit_load_var(builder, map, details)?;
@ -4172,7 +4242,6 @@ fn normalize_identifiers<'input, 'b>(
match s {
ast::Statement::Label(id) => {
id_defs.add_def(*id, None, false);
eprintln!("{}", id);
}
_ => (),
}
@ -5800,7 +5869,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let new_args = a.map(visitor, &d)?;
ast::Instruction::St(d, new_args)
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, false, None)?),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
@ -5942,6 +6011,21 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
arg: arg.map_non_shift(visitor, &full_type, false)?,
}
}
// Prmt: the literal selector is carried over unchanged; the three operands
// are visited through `map_prmt`.
ast::Instruction::Prmt { control, arg } => ast::Instruction::Prmt {
control,
arg: arg.map_prmt(visitor)?,
},
// Activemask: the single operand is visited as a destination (`true`) and
// described as a B32 scalar in register space.
ast::Instruction::Activemask { arg } => ast::Instruction::Activemask {
arg: arg.map(
visitor,
true,
Some((
&ast::Type::Scalar(ast::ScalarType::B32),
ast::StateSpace::Reg,
)),
)?,
},
// Membar: no operands to visit, the level is passed through as is.
ast::Instruction::Membar { level } => ast::Instruction::Membar { level },
})
}
}
@ -6202,6 +6286,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Bfe { .. } => None,
ast::Instruction::Bfi { .. } => None,
ast::Instruction::Rem { .. } => None,
ast::Instruction::Prmt { .. } => None,
ast::Instruction::Activemask { .. } => None,
ast::Instruction::Membar { .. } => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@ -6339,12 +6426,13 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
is_dst: bool,
t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
is_dst,
is_memory_access: false,
non_default_implicit_conversion: None,
},
@ -6685,6 +6773,43 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
/// Maps the three operands of a `prmt` instruction through `visitor`.
///
/// `dst` is visited first as a destination, followed by `src1` and `src2` as
/// sources. Every operand is described as a `B32` scalar in register space,
/// with no memory access and no non-default implicit conversion.
///
/// Returns the first error produced by the visitor, if any.
fn map_prmt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
    self,
    visitor: &mut V,
) -> Result<ast::Arg3<U>, TranslateError> {
    let operand_type = ast::Type::Scalar(ast::ScalarType::B32);
    // All three operands share one descriptor shape; only the operand itself
    // and the destination flag differ between them.
    let mut visit = |op, is_dst| {
        visitor.operand(
            ArgumentDescriptor {
                op,
                is_dst,
                is_memory_access: false,
                non_default_implicit_conversion: None,
            },
            &operand_type,
            ast::StateSpace::Reg,
        )
    };
    let dst = visit(self.dst, true)?;
    let src1 = visit(self.src1, false)?;
    let src2 = visit(self.src2, false)?;
    Ok(ast::Arg3 { dst, src1, src2 })
}
}
impl<T: ArgParamsEx> ast::Arg4<T> {