diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index f27a127..bdec0fb 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -79,6 +79,10 @@ impl ActionInfo { unsafe { amd_comgr_action_info_set_isa_name(self.get(), full_isa.as_ptr().cast()) } } + fn set_language(&self, language: amd_comgr_language_t) -> Result<(), amd_comgr_status_s> { + unsafe { amd_comgr_action_info_set_language(self.get(), language) } + } + fn get(&self) -> amd_comgr_action_info_t { self.0 } @@ -90,36 +94,56 @@ impl Drop for ActionInfo { } } -pub fn compile_bitcode(gcn_arch: &CStr, buffer: &[u8]) -> Result, amd_comgr_status_s> { +pub fn compile_bitcode( + gcn_arch: &CStr, + main_buffer: &[u8], + ptx_impl: &[u8], +) -> Result, amd_comgr_status_s> { use amd_comgr_sys::*; let bitcode_data_set = DataSet::new()?; - let bitcode_data = Data::new( + let main_bitcode_data = Data::new( amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, c"zluda.bc", - buffer, + main_buffer, + )?; + bitcode_data_set.add(&main_bitcode_data)?; + let stdlib_bitcode_data = Data::new( + amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, + c"ptx_impl.bc", + ptx_impl, + )?; + bitcode_data_set.add(&stdlib_bitcode_data)?; + let lang_action_info = ActionInfo::new()?; + lang_action_info.set_isa_name(gcn_arch)?; + lang_action_info.set_language(amd_comgr_language_t::AMD_COMGR_LANGUAGE_LLVM_IR)?; + let linked_data_set = do_action( + &bitcode_data_set, + &lang_action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, )?; - bitcode_data_set.add(&bitcode_data)?; - let reloc_data_set = DataSet::new()?; let action_info = ActionInfo::new()?; action_info.set_isa_name(gcn_arch)?; - unsafe { - amd_comgr_do_action( - amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, - action_info.get(), - bitcode_data_set.get(), - reloc_data_set.get(), - ) - }?; - let exec_data_set = DataSet::new()?; - unsafe { - amd_comgr_do_action( - amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, - action_info.get(), - reloc_data_set.get(), - exec_data_set.get(), - ) - }?; + let reloc_data_set = do_action( + &linked_data_set, + &action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, + )?; + let exec_data_set = do_action( + &reloc_data_set, + &action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, + )?; let executable = exec_data_set.get_data(amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_EXECUTABLE, 0)?; executable.copy_content() } + +fn do_action( + data_set: &DataSet, + action: &ActionInfo, + kind: amd_comgr_action_kind_t, +) -> Result { + let result = DataSet::new()?; + unsafe { amd_comgr_do_action(kind, action.get(), data_set.get(), result.get()) }?; + Ok(result) +} diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 2d194c4..cbbf2dc 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp new file mode 100644 index 0000000..937bda1 --- /dev/null +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -0,0 +1,18 @@ +// Every time this file changes it must te rebuilt, you need llvm-17: +// /opt/rocm/llvm/bin/clang -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && llvm-dis-17 zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | llvm-as-17 - -o zluda_ptx_impl.bc && llvm-dis-17 zluda_ptx_impl.bc + +#include +#include + +#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_ ## NAME + +extern "C" { + uint32_t FUNC(activemask)() { + return __builtin_amdgcn_read_exec_lo(); + } + + size_t __ockl_get_local_size(uint32_t) __device__; + uint32_t FUNC(sreg_ntid)(uint8_t member) { + return (uint32_t)__ockl_get_local_size(member); + } +} diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index 04c8831..6e0beab 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -94,7 +94,7 @@ fn run_method<'input>( .body .map(|statements| { for statement in statements { - run_statement(&remap_returns, &mut body, statement)?; + run_statement(resolver, &remap_returns, &mut body, statement)?; } Ok::<_, TranslateError>(body) }) @@ -110,6 +110,7 @@ fn run_method<'input>( } fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, result: &mut Vec, SpirvWord>>, statement: Statement, SpirvWord>, @@ -133,6 +134,66 @@ fn run_statement<'input>( } result.push(statement); } + Statement::Instruction(ast::Instruction::Call { + mut data, + mut arguments, + }) => { + let mut post_st = Vec::new(); + for ((type_, space), ident) in data + .input_arguments + .iter_mut() + .zip(arguments.input_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + result.push(Statement::Instruction(ast::Instruction::Ld { + data: ast::LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::LdCacheOperator::Cached, + typ: type_.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: *ident, + src: old_name, + }, + })); + } + } + for ((type_, space), ident) in data + .return_arguments + .iter_mut() + .zip(arguments.return_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + post_st.push(Statement::Instruction(ast::Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::StCacheOperator::Writethrough, + typ: type_.clone(), + }, + arguments: ast::StArgs { + src1: old_name, + src2: *ident, + }, + })); + } + } + result.push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + result.extend(post_st.into_iter()); + } statement => { result.push(statement); } diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index 97f6356..3553139 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -31,10 +31,10 @@ pub(super) fn run<'a, 'input>( sreg_to_function, result: Vec::new(), }; - directives - .into_iter() - .map(|directive| run_directive(&mut visitor, directive)) - .collect::, _>>() + for directive in directives.into_iter() { + result.push(run_directive(&mut visitor, directive)?); + } + Ok(result) } fn run_directive<'a, 'input>( diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs index 753172a..718c052 100644 --- a/ptx/src/pass/hoist_globals.rs +++ b/ptx/src/pass/hoist_globals.rs @@ -5,7 +5,7 @@ pub(super) fn run<'input>( ) -> Result, SpirvWord>>, TranslateError> { let mut result = Vec::with_capacity(directives.len()); for mut directive in directives.into_iter() { - run_directive(&mut result, &mut directive); + run_directive(&mut result, &mut directive)?; result.push(directive); } Ok(result) diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 0e233ed..7ba9ed0 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -39,9 +39,8 @@ mod normalize_predicates; mod normalize_predicates2; mod resolve_function_pointers; -static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); -static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); -const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; +static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); +const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result { let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1)); @@ -220,6 +219,12 @@ pub struct Module { pub kernel_info: HashMap, } +impl Module { + pub fn linked_bitcode(&self) -> &[u8] { + ZLUDA_PTX_IMPL + } +} + struct GlobalStringIdResolver<'input> { current_id: SpirvWord, variables: HashMap, SpirvWord>, @@ -1975,7 +1980,7 @@ impl SpecialRegistersMap2 { let name = ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None)); let return_type = sreg.get_function_return_type(); - let input_type = sreg.get_function_return_type(); + let input_type = sreg.get_function_input_type(); ( sreg, ast::MethodDeclaration { @@ -1988,14 +1993,17 @@ impl SpecialRegistersMap2 { array_init: Vec::new(), }], name: name, - input_arguments: vec![ast::Variable { - align: None, - v_type: input_type.into(), - state_space: ast::StateSpace::Reg, - name: resolver - .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), - }], + input_arguments: input_type + .into_iter() + .map(|type_| ast::Variable { + align: None, + v_type: type_.into(), + state_space: ast::StateSpace::Reg, + name: resolver + .register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }) + .collect::>(), shared_mem: None, }, ) diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index e15d6ea..60f5052 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -326,6 +326,7 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def let elf_module = comgr::compile_bitcode( unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, &*module.llvm_ir, + module.linked_bitcode(), ) .unwrap(); let mut module = ptr::null_mut();