From d47cd1e133995a08af15edd23c476ebf6d5cabf8 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 5 Aug 2020 23:50:20 +0200 Subject: [PATCH] Add support for cvta and global ld/st --- ptx/src/ast.rs | 23 +++++++++++++++- ptx/src/ptx.lalrpop | 36 +++++++++++++++++++++++++ ptx/src/test/spirv_run/cvta.ptx | 23 ++++++++++++++++ ptx/src/test/spirv_run/cvta.spvtxt | 42 +++++++++++++++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 3 ++- ptx/src/translate.rs | 43 +++++++++++++++++++++++------- 6 files changed, 158 insertions(+), 12 deletions(-) create mode 100644 ptx/src/test/spirv_run/cvta.ptx create mode 100644 ptx/src/test/spirv_run/cvta.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index a2c6d66..ed58d42 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -207,6 +207,7 @@ pub enum Instruction { Not(NotType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtDetails, Arg2

), + Cvta(CvtaDetails, Arg2

), Shl(ShlType, Arg3

), St(StData, Arg2St

), Ret(RetData), @@ -308,7 +309,7 @@ pub enum LdScope { Sys, } -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum LdStateSpace { Generic, Const, @@ -511,6 +512,26 @@ impl CvtDetails { } } +pub struct CvtaDetails { + pub to: CvtaStateSpace, + pub from: CvtaStateSpace, + pub size: CvtaSize, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum CvtaStateSpace { + Generic, + Const, + Global, + Local, + Shared, +} + +pub enum CvtaSize { + U32, + U64, +} + #[derive(PartialEq, Eq, Copy, Clone)] pub enum ShlType { B16, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 5f97e6c..66e831e 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -90,6 +90,7 @@ match { ".sreg", ".sys", ".target", + ".to", ".u16", ".u32", ".u64", @@ -110,6 +111,7 @@ match { "add", "bra", "cvt", + "cvta", "debug", "ld", "map_f64_to_f32", @@ -136,6 +138,7 @@ ExtendedID : &'input str = { "add", "bra", "cvt", + "cvta", "debug", "ld", "map_f64_to_f32", @@ -322,6 +325,7 @@ Instruction: ast::Instruction> = { InstShl, InstSt, InstRet, + InstCvta, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -783,6 +787,38 @@ InstRet: ast::Instruction> = { "ret" => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() }) }; +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta +InstCvta: ast::Instruction> = { + "cvta" => { + ast::Instruction::Cvta(ast::CvtaDetails { + to: to, + from: ast::CvtaStateSpace::Generic, + size: s + }, + a) + }, + "cvta" ".to" => { + ast::Instruction::Cvta(ast::CvtaDetails { + to: ast::CvtaStateSpace::Generic, + from: from, + size: s + }, + a) + } +} + +CvtaStateSpace: ast::CvtaStateSpace = { + ".const" => ast::CvtaStateSpace::Const, + ".global" => ast::CvtaStateSpace::Global, + ".local" => ast::CvtaStateSpace::Local, + ".shared" => ast::CvtaStateSpace::Shared, +} + +CvtaSize: ast::CvtaSize = { + ".u32" => ast::CvtaSize::U32, + ".u64" => ast::CvtaSize::U64, +} + Operand: ast::Operand<&'input str> = { => ast::Operand::Reg(r), "+" => { diff --git a/ptx/src/test/spirv_run/cvta.ptx b/ptx/src/test/spirv_run/cvta.ptx new file mode 100644 index 0000000..c24c959 --- /dev/null +++ b/ptx/src/test/spirv_run/cvta.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry cvta( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + cvta.to.global.u64 in_addr, in_addr; + cvta.to.global.u64 out_addr, out_addr; + + ld.global.f32 temp, [in_addr]; + st.global.f32 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt new file mode 100644 index 0000000..1aa7425 --- /dev/null +++ b/ptx/src/test/spirv_run/cvta.spvtxt @@ -0,0 +1,42 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %5 "cvta" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %4 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float + %5 = OpFunction %void None %4 + %6 = OpFunctionParameter %ulong + %7 = OpFunctionParameter %ulong + %21 = OpLabel + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_ulong Function + %10 = OpVariable %_ptr_Function_float Function + OpStore %8 %6 + OpStore %9 %7 + %12 = OpLoad %ulong %8 + %11 = OpCopyObject %ulong %12 + OpStore %8 %11 + %14 = OpLoad %ulong %9 + %13 = OpCopyObject %ulong %14 + OpStore %9 %13 + %16 = OpLoad %ulong %8 + %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16 + %15 = OpLoad %float %19 + OpStore %10 %15 + %17 = OpLoad %ulong %9 + %18 = OpLoad %float %10 + %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17 + OpStore %20 %18 + OpReturn + OpFunctionEnd + \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index e1e5c32..c159280 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -48,7 +48,8 @@ test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); test_ptx!(bra, [10u64], [11u64]); test_ptx!(not, [0u64], [u64::max_value()]); test_ptx!(shl, [11u64], [44u64]); -test_ptx!(cvt_sat_s_u, [0i32], [0i32]); +test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); +test_ptx!(cvta, [3.0f32], [3.0f32]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 511ef72..ebce1dd 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -671,7 +671,7 @@ fn emit_function_body_ops( } let result_type = map.get_or_add_scalar(builder, data.typ); match data.state_space { - ast::LdStateSpace::Generic => { + ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { builder.load(result_type, Some(arg.dst), arg.src, None, [])?; } ast::LdStateSpace::Param => { @@ -683,7 +683,8 @@ fn emit_function_body_ops( ast::Instruction::St(data, arg) => { if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() - || data.state_space != ast::StStateSpace::Generic + || (data.state_space != ast::StStateSpace::Generic + && data.state_space != ast::StStateSpace::Global) { todo!() } @@ -729,6 +730,13 @@ fn emit_function_body_ops( ast::Instruction::Cvt(dets, arg) => { emit_cvt(builder, map, opencl, dets, arg)?; } + ast::Instruction::Cvta(_, arg) => { + // This would be only meaningful if const/slm/global pointers + // had a different format than generic pointers, but they don't pretty much by ptx definition + // Honestly, I have no idea why this instruction exists and is emitted by the compiler + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64); + builder.copy_object(result_type, Some(arg.dst), arg.src)?; + } ast::Instruction::SetpBool(_, _) => todo!(), }, Statement::LoadVar(arg, typ) => { @@ -997,13 +1005,10 @@ fn emit_implicit_conversion( _ => todo!(), }; match cv.kind { - ConversionKind::Ptr => { + ConversionKind::Ptr(space) => { let dst_type = map.get_or_add( builder, - SpirvType::Pointer( - SpirvScalarKey::from(to_type), - spirv_headers::StorageClass::Generic, - ), + SpirvType::Pointer(SpirvScalarKey::from(to_type), space.to_spirv()), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } @@ -1365,6 +1370,10 @@ impl ast::Instruction { } ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), + ast::Instruction::Cvta(d, a) => { + let inst_type = ast::Type::Scalar(ast::ScalarType::B64); + ast::Instruction::Cvta(d, a.map(visitor, Some(inst_type))) + } } } } @@ -1443,6 +1452,7 @@ impl ast::Instruction { | ast::Instruction::SetpBool(_, _) | ast::Instruction::Not(_, _) | ast::Instruction::Cvt(_, _) + | ast::Instruction::Cvta(_, _) | ast::Instruction::Shl(_, _) | ast::Instruction::St(_, _) | ast::Instruction::Ret(_) => None, @@ -1498,7 +1508,7 @@ enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types SignExtend, - Ptr, + Ptr(ast::LdStateSpace), } impl ImplicitConversion { @@ -1944,6 +1954,19 @@ impl ast::IntType { } } +impl ast::LdStateSpace { + fn to_spirv(self) -> spirv::StorageClass { + match self { + ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant, + ast::LdStateSpace::Generic => spirv::StorageClass::Generic, + ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::LdStateSpace::Local => spirv::StorageClass::Function, + ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup, + ast::LdStateSpace::Param => unreachable!(), + } + } +} + fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { match (instr, operand) { (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { @@ -1980,7 +2003,7 @@ fn insert_implicit_conversions_ld_src( src, should_convert_ld_param_src, ), - ast::LdStateSpace::Generic => { + ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts( mem::size_of::() as u8, ScalarKind::Byte, @@ -1998,7 +2021,7 @@ fn insert_implicit_conversions_ld_src( new_src, new_src_type, instr_type, - ConversionKind::Ptr, + ConversionKind::Ptr(state_space), ) } _ => todo!(),