From b4de21fbc5eaf33540f1121bfe7c6ba0acaff6c9 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 2 Aug 2021 01:04:05 +0200 Subject: [PATCH] Use calls to OpenCL builtins when translating sregs, do SPIRV->LLVM conversion on every build --- ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/ntid.spvtxt | 40 ++++--- ptx/src/test/spirv_run/vector4.ptx | 22 ++++ ptx/src/test/spirv_run/vector4.spvtxt | 99 ++++++++++++++++ ptx/src/translate.rs | 162 ++++++++++++++++---------- zluda/Cargo.toml | 1 + zluda/src/impl/module.rs | 34 ++++++ 7 files changed, 278 insertions(+), 81 deletions(-) create mode 100644 ptx/src/test/spirv_run/vector4.ptx create mode 100644 ptx/src/test/spirv_run/vector4.spvtxt diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 226043f..d5bc8dd 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -61,6 +61,7 @@ test_ptx!(block, [1u64], [2u64]); test_ptx!(local_align, [1u64], [1u64]); test_ptx!(call, [1u64], [2u64]); test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); +test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]); test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); test_ptx!(ntid, [3u32], [4u32]); test_ptx!(reg_local, [12u64], [13u64]); diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt index 7b5a630..e5f343c 100644 --- a/ptx/src/test/spirv_run/ntid.spvtxt +++ b/ptx/src/test/spirv_run/ntid.spvtxt @@ -7,24 +7,27 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %28 = OpExtInstImport "OpenCL.std" + %31 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize - OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + OpEntryPoint Kernel %1 "ntid" + OpExecutionMode %1 ContractionOff + OpDecorate %24 LinkageAttributes "get_local_size" Import %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %v3ulong = OpTypeVector %ulong 3 -%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong -%gl_WorkGroupSize = OpVariable %_ptr_Input_v3ulong Input - %33 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 + %35 = OpTypeFunction %ulong %uint + %36 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %33 + %uint_0 = OpConstant %uint 0 + %24 = OpFunction %ulong None %35 + %26 = OpFunctionParameter %uint + OpFunctionEnd + %1 = OpFunction %void None %36 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong - %26 = OpLabel + %29 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -38,13 +41,12 @@ %12 = OpLoad %ulong %3 Aligned 8 OpStore %5 %12 %14 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %14 - %13 = OpLoad %uint %24 Aligned 4 + %27 = OpConvertUToPtr %_ptr_Generic_uint %14 + %13 = OpLoad %uint %27 Aligned 4 OpStore %6 %13 - %38 = OpLoad %v3ulong %gl_WorkGroupSize - %23 = OpCompositeExtract %ulong %38 0 - %39 = OpBitcast %ulong %23 - %16 = OpUConvert %uint %39 + %23 = OpFunctionCall %ulong %24 %uint_0 + %40 = OpBitcast %ulong %23 + %16 = OpUConvert %uint %40 %15 = OpCopyObject %uint %16 OpStore %7 %15 %18 = OpLoad %uint %6 @@ -53,7 +55,7 @@ OpStore %6 %17 %20 = OpLoad %ulong %5 %21 = OpLoad %uint %6 - %25 = OpConvertUToPtr %_ptr_Generic_uint %20 - OpStore %25 %21 Aligned 4 + %28 = OpConvertUToPtr %_ptr_Generic_uint %20 + OpStore %28 %21 Aligned 4 OpReturn - OpFunctionEnd + OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/vector4.ptx b/ptx/src/test/spirv_run/vector4.ptx new file mode 100644 index 0000000..d010b70 --- /dev/null +++ b/ptx/src/test/spirv_run/vector4.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_60 +.address_size 64 + +.visible .entry vector4( + .param .u64 input_p, + .param .u64 output_p +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .v4 .u32 temp; + .reg .u32 temp_scalar; + + ld.param.u64 in_addr, [input_p]; + ld.param.u64 out_addr, [output_p]; + + ld.v4.u32 temp, [in_addr]; + mov.b32 temp_scalar, temp.w; + st.u32 [out_addr], temp_scalar; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/vector4.spvtxt b/ptx/src/test/spirv_run/vector4.spvtxt new file mode 100644 index 0000000..8253bf9 --- /dev/null +++ b/ptx/src/test/spirv_run/vector4.spvtxt @@ -0,0 +1,99 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %51 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %25 "vector" + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %55 = OpTypeFunction %v2uint %v2uint +%_ptr_Function_v2uint = OpTypePointer Function %v2uint +%_ptr_Function_uint = OpTypePointer Function %uint + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %ulong = OpTypeInt 64 0 + %67 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint + %1 = OpFunction %v2uint None %55 + %7 = OpFunctionParameter %v2uint + %24 = OpLabel + %3 = OpVariable %_ptr_Function_v2uint Function + %2 = OpVariable %_ptr_Function_v2uint Function + %4 = OpVariable %_ptr_Function_v2uint Function + %5 = OpVariable %_ptr_Function_uint Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %3 %7 + %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0 + %9 = OpLoad %uint %59 + %8 = OpCopyObject %uint %9 + OpStore %5 %8 + %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1 + %11 = OpLoad %uint %61 + %10 = OpCopyObject %uint %11 + OpStore %6 %10 + %13 = OpLoad %uint %5 + %14 = OpLoad %uint %6 + %12 = OpIAdd %uint %13 %14 + OpStore %6 %12 + %16 = OpLoad %uint %6 + %15 = OpCopyObject %uint %16 + %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %62 %15 + %18 = OpLoad %uint %6 + %17 = OpCopyObject %uint %18 + %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + OpStore %63 %17 + %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + %20 = OpLoad %uint %64 + %19 = OpCopyObject %uint %20 + %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %65 %19 + %22 = OpLoad %v2uint %4 + %21 = OpCopyObject %v2uint %22 + OpStore %2 %21 + %23 = OpLoad %v2uint %2 + OpReturnValue %23 + OpFunctionEnd + %25 = OpFunction %void None %67 + %34 = OpFunctionParameter %ulong + %35 = OpFunctionParameter %ulong + %49 = OpLabel + %26 = OpVariable %_ptr_Function_ulong Function + %27 = OpVariable %_ptr_Function_ulong Function + %28 = OpVariable %_ptr_Function_ulong Function + %29 = OpVariable %_ptr_Function_ulong Function + %30 = OpVariable %_ptr_Function_v2uint Function + %31 = OpVariable %_ptr_Function_uint Function + %32 = OpVariable %_ptr_Function_uint Function + %33 = OpVariable %_ptr_Function_ulong Function + OpStore %26 %34 + OpStore %27 %35 + %36 = OpLoad %ulong %26 Aligned 8 + OpStore %28 %36 + %37 = OpLoad %ulong %27 Aligned 8 + OpStore %29 %37 + %39 = OpLoad %ulong %28 + %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39 + %38 = OpLoad %v2uint %46 Aligned 8 + OpStore %30 %38 + %41 = OpLoad %v2uint %30 + %40 = OpFunctionCall %v2uint %1 %41 + OpStore %30 %40 + %43 = OpLoad %v2uint %30 + %47 = OpBitcast %ulong %43 + %42 = OpCopyObject %ulong %47 + OpStore %33 %42 + %44 = OpLoad %ulong %29 + %45 = OpLoad %v2uint %30 + %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44 + OpStore %48 %45 Aligned 8 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 5fea075..6c2c594 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -448,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result( &mut numeric_id_defs, &mut (*func_decl).borrow_mut(), )?; - let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?; + let ssa_statements = + fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = @@ -1269,6 +1270,7 @@ fn to_ssa<'input, 'b>( } fn fix_special_registers( + ptx_impl_imports: &mut HashMap, typed_statements: Vec, numeric_id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { @@ -1276,7 +1278,6 @@ fn fix_special_registers( for s in typed_statements { match s { Statement::LoadVar( - mut details @ LoadVarDetails { @@ -1285,48 +1286,53 @@ fn fix_special_registers( }, ) => { let index = details.member_index.unwrap().0; - if index == 3 { - result.push(Statement::Constant(ConstantDefinition { - dst: details.arg.dst, - typ: ast::ScalarType::U32, - value: ast::ImmediateValue::U64(0), - })); - } else { - let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src) - { - Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg), - None => None, - }; - let (sreg_src, scalar_typ, vector_width) = match sreg_and_type { - Some(sreg_and_type) => sreg_and_type, - None => { - result.push(Statement::LoadVar(details)); - continue; - } - }; - let temp_id = numeric_id_defs - .register_intermediate(Some((details.typ.clone(), details.state_space))); - let real_dst = details.arg.dst; - details.arg.dst = temp_id; - result.push(Statement::LoadVar(LoadVarDetails { - arg: Arg2 { - src: sreg_src, - dst: temp_id, - }, - state_space: ast::StateSpace::Sreg, - typ: ast::Type::Scalar(scalar_typ), - member_index: Some((index, Some(vector_width))), - })); - result.push(Statement::Conversion(ImplicitConversion { - src: temp_id, - dst: real_dst, - from_type: ast::Type::Scalar(scalar_typ), - from_space: ast::StateSpace::Sreg, - to_type: ast::Type::Scalar(ast::ScalarType::U32), - to_space: ast::StateSpace::Sreg, - kind: ConversionKind::Default, - })); - } + let sreg = numeric_id_defs + .special_registers + .get(details.arg.src) + .ok_or_else(|| error_unreachable())?; + let (ocl_name, ocl_type) = sreg.get_opencl_fn_type(); + let index_constant = numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, + ))); + result.push(Statement::Constant(ConstantDefinition { + dst: index_constant, + typ: ast::ScalarType::U32, + value: ast::ImmediateValue::U64(index as u64), + })); + let fn_result = numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(ocl_type), + ast::StateSpace::Reg, + ))); + let return_arguments = + vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)]; + let input_arguments = vec![( + TypedOperand::Reg(index_constant), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, + )]; + let fn_call = register_external_fn_call( + numeric_id_defs, + ptx_impl_imports, + ocl_name.to_string(), + return_arguments.iter().map(|(_, typ, space)| (typ, *space)), + input_arguments.iter().map(|(_, typ, space)| (typ, *space)), + )?; + result.push(Statement::Call(ResolvedCall { + uniform: false, + return_arguments, + name: fn_call, + input_arguments, + })); + result.push(Statement::Conversion(ImplicitConversion { + src: fn_result, + dst: details.arg.dst, + from_type: ast::Type::Scalar(ocl_type), + from_space: ast::StateSpace::Reg, + to_type: ast::Type::Scalar(ast::ScalarType::U32), + to_space: ast::StateSpace::Reg, + kind: ConversionKind::Default, + })); } s => result.push(s), } @@ -1721,8 +1727,8 @@ fn instruction_to_fn_call( id_defs, ptx_impl_imports, fn_name, - return_arguments, - input_arguments, + return_arguments.iter().map(|(_, typ, state)| (typ, *state)), + input_arguments.iter().map(|(_, typ, state)| (typ, *state)), )?; Ok(Statement::Call(ResolvedCall { uniform: false, @@ -1732,12 +1738,12 @@ fn instruction_to_fn_call( })) } -fn register_external_fn_call( +fn register_external_fn_call<'a>( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, name: String, - return_arguments: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], - input_arguments: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], + return_arguments: impl Iterator, + input_arguments: impl Iterator, ) -> Result { match ptx_impl_imports.entry(name) { hash_map::Entry::Vacant(entry) => { @@ -1770,19 +1776,18 @@ fn register_external_fn_call( } } -fn fn_arguments_to_variables( +fn fn_arguments_to_variables<'a>( id_defs: &mut NumericIdResolver, - args: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], + args: impl Iterator, ) -> Vec> { - args.iter() - .map(|(_, typ, space)| ast::Variable { - align: None, - v_type: typ.clone(), - state_space: *space, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }) - .collect::>() + args.map(|(typ, space)| ast::Variable { + align: None, + v_type: typ.clone(), + state_space: space, + name: id_defs.register_intermediate(None), + array_init: Vec::new(), + }) + .collect::>() } fn arguments_to_resolved_arguments( @@ -2226,7 +2231,7 @@ fn expand_arguments<'a, 'b>( Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), - Statement::Constant(_) => return Err(error_unreachable()), + Statement::Constant(c) => result.push(Statement::Constant(c)), } } Ok(result) @@ -4686,6 +4691,19 @@ impl PtxSpecialRegister { } } + fn get_scalar_type(self) -> ast::ScalarType { + match self { + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => ast::ScalarType::U32, + PtxSpecialRegister::Tid64 + | PtxSpecialRegister::Ntid64 + | PtxSpecialRegister::Ctaid64 + | PtxSpecialRegister::Nctaid64 => ast::ScalarType::U64, + } + } + fn get_builtin(self) -> spirv::BuiltIn { match self { PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { @@ -4701,6 +4719,23 @@ impl PtxSpecialRegister { } } + fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) { + match self { + PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { + ("_Z12get_local_idj", ast::ScalarType::U64) + } + PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => { + ("_Z14get_local_sizej", ast::ScalarType::U64) + } + PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => { + ("_Z12get_group_idj", ast::ScalarType::U64) + } + PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => { + ("_Z14get_num_groupsj", ast::ScalarType::U64) + } + } + } + fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> { match self { PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)), @@ -4743,6 +4778,8 @@ impl SpecialRegistersMap { } fn interface(&self) -> Vec { + return Vec::new(); + /* self.reg_to_id .iter() .filter_map(|(sreg, id)| { @@ -4753,6 +4790,7 @@ impl SpecialRegistersMap { } }) .collect::>() + */ } fn get(&self, id: spirv::Word) -> Option { diff --git a/zluda/Cargo.toml b/zluda/Cargo.toml index cfb1a50..b54fd1d 100644 --- a/zluda/Cargo.toml +++ b/zluda/Cargo.toml @@ -14,6 +14,7 @@ level_zero-sys = { path = "../level_zero-sys" } lazy_static = "1.4" num_enum = "0.4" lz4-sys = "1.9" +tempfile = "3" [dependencies.ocl-core] version = "0.11" diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index a1fa9dd..4a338de 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -4,8 +4,10 @@ use std::{ ffi::c_void, ffi::CStr, ffi::CString, + io::{self, Write}, mem, os::raw::{c_char, c_int, c_uint}, + process::{Command, Stdio}, ptr, slice, }; @@ -20,6 +22,7 @@ use super::{ CUresult, GlobalState, HasLivenessCookie, LiveCheck, }; use ptx; +use tempfile::NamedTempFile; pub type Module = LiveCheck; @@ -88,6 +91,36 @@ impl SpirvModule { }) } + const LLVM_SPIRV: &'static str = "/home/vosen/amd/llvm-project/build/bin/llvm-spirv"; + const AMDGPU: &'static str = "/opt/amdgpu-pro/"; + const AMDGPU_BITCODE: [&'static str; 8] = [ + "opencl", + "ocml", + "ockl", + "oclc_correctly_rounded_sqrt_off", + "oclc_daz_opt_on", + "oclc_finite_only_off", + "oclc_unsafe_math_off", + "oclc_wavefrontsize64_off", + ]; + const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_"; + const AMDGPU_DEVICE: &'static str = "gfx1010"; + + fn compile_amd(spirv_il: &[u8]) -> io::Result<()> { + let dir = tempfile::tempdir()?; + let mut spirv = NamedTempFile::new_in(&dir)?; + let llvm = NamedTempFile::new_in(&dir)?; + spirv.write_all(spirv_il)?; + let mut cmd = Command::new(Self::LLVM_SPIRV) + .arg("-r") + .arg("-o") + .arg(llvm.path()) + .arg(spirv.path()) + .status()?; + assert!(cmd.success()); + Ok(()) + } + pub fn compile<'a>( &self, ctx: &ocl_core::Context, @@ -99,6 +132,7 @@ impl SpirvModule { self.binaries.len() * mem::size_of::(), ) }; + Self::compile_amd(byte_il).unwrap(); let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?; let main_module = match self.should_link_ptx_impl { None => {