diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index f4502af..1266ea4 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -201,6 +201,20 @@ sub_enum!(LdStScalarType { F64, }); +sub_enum!(SelpType { + B16, + B32, + B64, + U16, + U32, + U64, + S16, + S32, + S64, + F32, + F64, +}); + pub trait UnwrapWithVec { fn unwrap_with(self, errs: &mut Vec) -> To; } @@ -512,6 +526,7 @@ pub enum Instruction { Max(MinMaxDetails, Arg3

), Rcp(RcpDetails, Arg2

), And(OrAndType, Arg3

), + Selp(SelpType, Arg4

), } #[derive(Copy, Clone)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 7414443..025f0be 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -142,6 +142,7 @@ match { "or", "rcp", "ret", + "selp", "setp", "shl", "shr", @@ -176,6 +177,7 @@ ExtendedID : &'input str = { "or", "rcp", "ret", + "selp", "setp", "shl", "shr", @@ -614,7 +616,8 @@ Instruction: ast::Instruction> = { InstSub, InstMin, InstMax, - InstRcp + InstRcp, + InstSelp }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -1271,6 +1274,25 @@ MinMaxDetails: ast::MinMaxDetails = { ) } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp +InstSelp: ast::Instruction> = { + "selp" => ast::Instruction::Selp(t, a), +}; + +SelpType: ast::SelpType = { + ".b16" => ast::SelpType::B16, + ".b32" => ast::SelpType::B32, + ".b64" => ast::SelpType::B64, + ".u16" => ast::SelpType::U16, + ".u32" => ast::SelpType::U32, + ".u64" => ast::SelpType::U64, + ".s16" => ast::SelpType::S16, + ".s32" => ast::SelpType::S32, + ".s64" => ast::SelpType::S64, + ".f32" => ast::SelpType::F32, + ".f64" => ast::SelpType::F64, +}; + ArithDetails: ast::ArithDetails = { => ast::ArithDetails::Unsigned(t), => ast::ArithDetails::Signed(ast::ArithSInt { diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index dfdec72..f336055 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -90,6 +90,7 @@ test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], test_ptx!(constant_f32, [10f32], [5f32]); test_ptx!(constant_negative, [-101i32], [101i32]); test_ptx!(and, [6u32, 3u32], [2u32]); +test_ptx!(selp, [100u16, 200u16], [200u16]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/selp.ptx b/ptx/src/test/spirv_run/selp.ptx new file mode 100644 index 0000000..79171dc --- /dev/null +++ b/ptx/src/test/spirv_run/selp.ptx 
@@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry selp( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u16 temp1; + .reg .u16 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u16 temp1, [in_addr]; + ld.u16 temp2, [in_addr + 2]; + selp.u16 temp1, temp1, temp2, 0; + st.u16 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt new file mode 100644 index 0000000..dffd9af --- /dev/null +++ b/ptx/src/test/spirv_run/selp.spvtxt @@ -0,0 +1,65 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 40 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +OpCapability FunctionFloatControlINTEL +OpExtension "SPV_INTEL_float_controls2" +%31 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "selp" +%32 = OpTypeVoid +%33 = OpTypeInt 64 0 +%34 = OpTypeFunction %32 %33 %33 +%35 = OpTypePointer Function %33 +%36 = OpTypeInt 16 0 +%37 = OpTypePointer Function %36 +%38 = OpTypePointer Generic %36 +%23 = OpConstant %33 2 +%39 = OpTypeBool +%25 = OpConstantFalse %39 +%1 = OpFunction %32 None %34 +%8 = OpFunctionParameter %33 +%9 = OpFunctionParameter %33 +%29 = OpLabel +%2 = OpVariable %35 Function +%3 = OpVariable %35 Function +%4 = OpVariable %35 Function +%5 = OpVariable %35 Function +%6 = OpVariable %37 Function +%7 = OpVariable %37 Function +OpStore %2 %8 +OpStore %3 %9 +%11 = OpLoad %33 %2 +%10 = OpCopyObject %33 %11 +OpStore %4 %10 +%13 = OpLoad %33 %3 +%12 = OpCopyObject %33 %13 +OpStore %5 %12 +%15 = OpLoad %33 %4 +%26 = OpConvertUToPtr %38 %15 +%14 = OpLoad %36 %26 +OpStore %6 %14 +%17 = OpLoad %33 %4 +%24 = OpIAdd %33 %17 %23 +%27 = OpConvertUToPtr %38 %24 +%16 = OpLoad %36 %27 +OpStore %7 %16 +%19 = OpLoad %36 %6 
+%20 = OpLoad %36 %7 +%18 = OpSelect %36 %25 %19 %20 +OpStore %6 %18 +%21 = OpLoad %33 %5 +%22 = OpLoad %36 %6 +%28 = OpConvertUToPtr %38 %21 +OpStore %28 %22 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c699cc4..9d73742 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1266,6 +1266,9 @@ fn convert_to_typed_statements( ast::Instruction::And(d, a) => { result.push(Statement::Instruction(ast::Instruction::And(d, a.cast()))) } + ast::Instruction::Selp(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast()))) + } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -2159,6 +2162,22 @@ fn emit_function_body_ops( (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { builder.constant_f64(typ_id, Some(cnst.dst), value); } + (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => { + let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred); + if value == 0 { + builder.constant_false(bool_type, Some(cnst.dst)); + } else { + builder.constant_true(bool_type, Some(cnst.dst)); + } + } + (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => { + let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred); + if value == 0 { + builder.constant_false(bool_type, Some(cnst.dst)); + } else { + builder.constant_true(bool_type, Some(cnst.dst)); + } + } _ => return Err(TranslateError::MismatchedType), } } @@ -2362,6 +2381,10 @@ fn emit_function_body_ops( builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?; } } + ast::Instruction::Selp(t, a) => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); + builder.select(result_type, Some(a.dst), a.src3, a.src1, a.src2)?; + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); @@ -4056,6 +4079,7 @@ impl ast::Instruction { t, 
a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, ), + ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?), }) } } @@ -4301,6 +4325,7 @@ impl ast::Instruction { | ast::Instruction::Max(_, _) | ast::Instruction::Rcp(_, _) | ast::Instruction::And(_, _) + | ast::Instruction::Selp(_, _) | ast::Instruction::Mad(_, _) => None, } } @@ -4321,6 +4346,7 @@ impl ast::Instruction { ast::Instruction::Or(_, _) => None, ast::Instruction::And(_, _) => None, ast::Instruction::Cvta(_, _) => None, + ast::Instruction::Selp(_, _) => None, ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None, ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None, ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None, @@ -5047,6 +5073,51 @@ impl ast::Arg4 { src3, }) } + + fn map_selp>( + self, + visitor: &mut V, + t: ast::SelpType, + ) -> Result, TranslateError> { + let dst = visitor.id( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(&ast::Type::Scalar(t.into())), + )?; + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + &ast::Type::Scalar(t.into()), + )?; + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + &ast::Type::Scalar(t.into()), + )?; + let src3 = visitor.operand( + ArgumentDescriptor { + op: self.src3, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + &ast::Type::Scalar(ast::ScalarType::Pred), + )?; + Ok(ast::Arg4 { + dst, + src1, + src2, + src3, + }) + } } impl ast::Arg4Setp {