From 6480cccc4fb129eb0f9bfd0a0ade6895d04ff55e Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 25 Oct 2020 11:21:51 +0100 Subject: [PATCH] Implement rcp instruction --- ptx/src/ast.rs | 7 ++++ ptx/src/ptx.lalrpop | 29 ++++++++++++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/rcp.ptx | 21 ++++++++++ ptx/src/test/spirv_run/rcp.spvtxt | 51 ++++++++++++++++++++++++ ptx/src/translate.rs | 64 +++++++++++++++++++++++++------ 6 files changed, 162 insertions(+), 11 deletions(-) create mode 100644 ptx/src/test/spirv_run/rcp.ptx create mode 100644 ptx/src/test/spirv_run/rcp.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 1cbe721..f7cdcc3 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -510,6 +510,7 @@ pub enum Instruction { Sub(ArithDetails, Arg3

), Min(MinMaxDetails, Arg3

), Max(MinMaxDetails, Arg3

), + Rcp(RcpDetails, Arg2

), } #[derive(Copy, Clone)] @@ -520,6 +521,12 @@ pub struct AbsDetails { pub flush_to_zero: bool, pub typ: ScalarType, } +#[derive(Copy, Clone)] +pub struct RcpDetails { + pub rounding: Option, + pub flush_to_zero: bool, + pub is_f64: bool, +} pub struct CallInst { pub uniform: bool, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index c29d16b..a132705 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -35,6 +35,7 @@ match { ".address_size", ".align", ".and", + ".approx", ".b16", ".b32", ".b64", @@ -134,6 +135,7 @@ match { "mul", "not", "or", + "rcp", "ret", "setp", "shl", @@ -166,6 +168,7 @@ ExtendedID : &'input str = { "mul", "not", "or", + "rcp", "ret", "setp", "shl", @@ -542,6 +545,7 @@ Instruction: ast::Instruction> = { InstSub, InstMin, InstMax, + InstRcp }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -1119,6 +1123,31 @@ OrType: ast::OrType = { ".b64" => ast::OrType::B64, } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp +InstRcp: ast::Instruction> = { + "rcp" ".f32" => { + let details = ast::RcpDetails { + rounding, + flush_to_zero: ftz.is_some(), + is_f64: false, + }; + ast::Instruction::Rcp(details, a) + }, + "rcp" ".f64" => { + let details = ast::RcpDetails { + rounding: Some(rn), + flush_to_zero: false, + is_f64: true, + }; + ast::Instruction::Rcp(details, a) + } +}; + +RcpRoundingMode: Option = { + ".approx" => None, + => Some(r) +}; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 3a8acb1..b4ae149 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -80,6 +80,7 @@ test_ptx!(max, [555i32, 444i32], [555i32]); test_ptx!(global_array, [0xDEADu32], [1u32]); test_ptx!(extern_shared, [127u64], [127u64]); test_ptx!(extern_shared_call, [121u64], [123u64]); +test_ptx!(rcp, [2f32], [0.5f32]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/rcp.ptx b/ptx/src/test/spirv_run/rcp.ptx new file mode 100644 index 0000000..eb02d7e --- /dev/null +++ b/ptx/src/test/spirv_run/rcp.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry rcp( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 temp, [in_addr]; + rcp.approx.f32 temp, temp; + st.f32 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/rcp.spvtxt b/ptx/src/test/spirv_run/rcp.spvtxt new file mode 100644 index 0000000..08b3e6e --- /dev/null +++ b/ptx/src/test/spirv_run/rcp.spvtxt @@ -0,0 +1,51 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %23 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "rcp" + OpDecorate %15 FPFastMathMode AllowRecip + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %26 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Generic_float = OpTypePointer Generic %float + %float_1 = OpConstant %float 1 + %1 = OpFunction %void None %26 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %21 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_float Function + OpStore %2 %7 + OpStore %3 %8 + %10 = OpLoad %ulong %2 + %9 = OpCopyObject %ulong %10 + OpStore %4 %9 + %12 = OpLoad %ulong %3 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %4 + %19 = OpConvertUToPtr %_ptr_Generic_float %14 + %13 = OpLoad %float %19 + OpStore %6 %13 + %16 = OpLoad %float %6 + %15 = OpFDiv %float %float_1 %16 + OpStore %6 %15 + %17 = OpLoad %ulong %5 + %18 = OpLoad %float %6 + %20 = OpConvertUToPtr %_ptr_Generic_float %17 + OpStore %20 %18 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ab7187f..cccf6ad 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1144,6 +1144,9 @@ fn convert_to_typed_statements( ast::Instruction::Max(d, a) => { result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast()))) } + ast::Instruction::Rcp(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Rcp(d, a.cast()))) + } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -2179,6 +2182,9 @@ fn emit_function_body_ops( ast::Instruction::Max(d, a) => { emit_max(builder, map, opencl, d, a)?; } + ast::Instruction::Rcp(d, a) => { + emit_rcp(builder, map, d, a)?; + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); @@ -2209,6 +2215,40 @@ fn emit_function_body_ops( Ok(()) } +fn emit_rcp( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + desc: &ast::RcpDetails, + a: &ast::Arg2, +) -> Result<(), TranslateError> { + if desc.flush_to_zero { + todo!() + } + let (instr_type, constant) = if desc.is_f64 { + (ast::ScalarType::F64, vec_repr(1.0f64)) + } else { + (ast::ScalarType::F32, vec_repr(1.0f32)) + }; + let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?; + let result_type = map.get_or_add_scalar(builder, instr_type); + builder.f_div(result_type, Some(a.dst), one, a.src)?; + emit_rounding_decoration(builder, a.dst, desc.rounding); + builder.decorate( + a.dst, + spirv::Decoration::FPFastMathMode, + &[dr::Operand::FPFastMathMode( + spirv::FPFastMathMode::ALLOW_RECIP, + )], + ); + Ok(()) +} + +fn vec_repr(t: T) -> Vec { + let mut result = vec![0; mem::size_of::()]; + unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) }; + result +} + fn emit_variable( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -3735,7 +3775,7 @@ impl ast::Instruction { ) -> Result, TranslateError> { Ok(match self { ast::Instruction::Abs(d, arg) => { - ast::Instruction::Abs(d, arg.map(visitor, false, &ast::Type::Scalar(d.typ))?) + ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?) } // Call instruction is converted to a call statement early on ast::Instruction::Call(_) => return Err(TranslateError::Unreachable), @@ -3766,9 +3806,7 @@ impl ast::Instruction { let inst_type = d.typ; ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) } - ast::Instruction::Not(t, a) => { - ast::Instruction::Not(t, a.map(visitor, false, &t.to_type())?) - } + ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?), ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { ast::CvtDetails::FloatFromFloat(desc) => ( @@ -3806,7 +3844,7 @@ impl ast::Instruction { ast::Instruction::Ret(d) => ast::Instruction::Ret(d), ast::Instruction::Cvta(d, a) => { let inst_type = ast::Type::Scalar(ast::ScalarType::B64); - ast::Instruction::Cvta(d, a.map(visitor, false, &inst_type)?) + ast::Instruction::Cvta(d, a.map(visitor, &inst_type)?) } ast::Instruction::Mad(d, a) => { let inst_type = d.get_type(); @@ -3829,6 +3867,14 @@ impl ast::Instruction { let typ = d.get_type(); ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?) } + ast::Instruction::Rcp(d, a) => { + let typ = ast::Type::Scalar(if d.is_f64 { + ast::ScalarType::F64 + } else { + ast::ScalarType::F32 + }); + ast::Instruction::Rcp(d, a.map(visitor, &typ)?) + } }) } } @@ -4072,6 +4118,7 @@ impl ast::Instruction { | ast::Instruction::Sub(_, _) | ast::Instruction::Min(_, _) | ast::Instruction::Max(_, _) + | ast::Instruction::Rcp(_, _) | ast::Instruction::Mad(_, _) => None, } } @@ -4289,7 +4336,6 @@ impl ast::Arg2 { fn map>( self, visitor: &mut V, - src_is_addr: bool, t: &ast::Type, ) -> Result, TranslateError> { let new_dst = visitor.id( @@ -4304,11 +4350,7 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - sema: if src_is_addr { - ArgumentSemantics::Address - } else { - ArgumentSemantics::Default - }, + sema: ArgumentSemantics::Default, }, t, )?;