From 6c811a55d25ab4623ad9089fce9b66a5151a4cdd Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 12 Sep 2025 22:52:33 +0200 Subject: [PATCH] Random fixes (#504) This is a collection of random changes coming from the workload I'm working on. The most important change is better support for `.params`: PTX uses .param namespace both for some local variables and kernel args. This is a problem for us because those are different address spaces on AMDGPU. So far we've made an effort to convert to local and const namespaces whenever possible, but this commit tries to handle more patterns, which are impossible to track precisely, by converting to generic space. --- compiler/src/main.rs | 32 ++++--- cuda_macros/src/lib.rs | 2 +- ext/rocm_smi-sys/build.rs | 7 ++ ptx/src/pass/insert_explicit_load_store.rs | 40 ++++++++ ptx/src/pass/insert_implicit_conversions2.rs | 11 ++- ptx/src/pass/llvm/emit.rs | 4 +- ptx/src/pass/llvm/mod.rs | 16 ++-- ptx/src/pass/mod.rs | 19 +++- ptx/src/test/ll/param_is_addressable.ll | 34 +++++++ ptx/src/test/mod.rs | 2 +- ptx/src/test/spirv_run/mod.rs | 3 + .../test/spirv_run/param_is_addressable.ptx | 22 +++++ ptx_parser/src/lib.rs | 2 +- ptxas/src/main.rs | 1 + zluda/lib/OpenCL.lib | Bin 28824 -> 0 bytes zluda/src/impl/device.rs | 4 + zluda/src/impl/driver.rs | 89 +++++++++++++++++- zluda/src/impl/library.rs | 40 +++----- zluda/src/impl/memory.rs | 8 ++ zluda/src/impl/module.rs | 2 +- zluda/src/lib.rs | 3 + zluda_common/src/lib.rs | 5 +- 22 files changed, 282 insertions(+), 64 deletions(-) create mode 100644 ext/rocm_smi-sys/build.rs create mode 100644 ptx/src/test/ll/param_is_addressable.ll create mode 100644 ptx/src/test/spirv_run/param_is_addressable.ptx delete mode 100644 zluda/lib/OpenCL.lib diff --git a/compiler/src/main.rs b/compiler/src/main.rs index 5effaaf..9d1a5d1 100644 --- a/compiler/src/main.rs +++ b/compiler/src/main.rs @@ -6,6 +6,7 @@ use std::io::{self, Write}; use std::path::{Path, PathBuf}; use std::process::ExitCode; use std::str; +use std::time::Instant; use std::{env, mem}; mod error; @@ -57,19 +58,12 @@ fn main_core() -> Result<(), CompilerError> { let arch: String = match opts.arch { Some(s) => s, - None => { - (|| { - let runtime = hip::Runtime::load()?; - runtime.init()?; - get_gpu_arch(&runtime) - })() - .unwrap_or_else(|_| DEFAULT_ARCH.to_owned()) - /* - get_gpu_arch(&mut dev_props) - .map(String::from) - .unwrap_or(DEFAULT_ARCH.to_owned()) - */ - } + None => (|| { + let runtime = hip::Runtime::load()?; + runtime.init()?; + get_gpu_arch(&runtime) + })() + .unwrap_or_else(|_| DEFAULT_ARCH.to_owned()), }; let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; @@ -83,6 +77,7 @@ fn main_core() -> Result<(), CompilerError> { write_to_file(bytes, &output_path).unwrap(); }; + let mut start = Instant::now(); comgr::compile_bitcode( &comgr, &arch, @@ -92,17 +87,22 @@ fn main_core() -> Result<(), CompilerError> { Some(&comgr_hook), ) .map_err(CompilerError::from)?; + report_pass_time("compile_bitcode", &mut start); Ok(()) } fn ptx_to_llvm(ptx: &str) -> Result { let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; + let mut start = Instant::now(); let module = ptx::to_llvm_module( ast, ptx::Attributes { clock_rate: 2124000, }, + |pass| { + report_pass_time(pass, &mut start); + }, ) .map_err(CompilerError::from)?; let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec(); @@ -117,6 +117,12 @@ fn ptx_to_llvm(ptx: &str) -> Result { }) } +fn report_pass_time(pass: &str, start: &mut Instant) { + let duration = start.elapsed(); + println!("Pass {:?} took {:?}", pass, duration); + *start = Instant::now(); +} + #[derive(Debug)] struct LLVMArtifacts { bitcode: Vec, diff --git a/cuda_macros/src/lib.rs b/cuda_macros/src/lib.rs index 483ba75..1554f8f 100644 --- a/cuda_macros/src/lib.rs +++ b/cuda_macros/src/lib.rs @@ -313,7 +313,7 @@ fn join( pub fn test_cuda(_attr: TokenStream, item: TokenStream) -> TokenStream { let fn_ = parse_macro_input!(item as syn::ItemFn); let cuda_fn = format_ident!("{}{}", fn_.sig.ident, "_nvidia"); - let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_amdgpu"); + let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_zluda"); let fn_name = fn_.sig.ident.clone(); quote! { #[test] diff --git a/ext/rocm_smi-sys/build.rs b/ext/rocm_smi-sys/build.rs new file mode 100644 index 0000000..49060c3 --- /dev/null +++ b/ext/rocm_smi-sys/build.rs @@ -0,0 +1,7 @@ +use std::env::VarError; + +fn main() -> Result<(), VarError> { + println!("cargo:rustc-link-lib=dylib=rocm_smi64"); + println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); + Ok(()) +} diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 32597c5..2805dfa 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -114,6 +114,13 @@ fn run_statement<'a, 'input>( result.push(Statement::Instruction(instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction)); } + Statement::Instruction(ast::Instruction::Mov { data, arguments }) => { + let instruction = visitor.visit_mov(data, arguments); + let instruction = ast::visit_map(instruction, visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(Statement::Instruction(instruction)); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } Statement::PtrAccess(ptr_access) => { let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); let statement = statement.visit_map(visitor)?; @@ -293,6 +300,39 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { }) } + fn visit_mov( + &mut self, + data: ptx_parser::MovDetails, + mut arguments: ptx_parser::MovArgs, + ) -> ast::Instruction { + if let Some(remap) = self.variables.get(&arguments.src) { + match remap { + RemapAction::PreLdPostSt { .. } => {} + RemapAction::LDStSpaceChange { + name, + new_space, + old_space, + } => { + let generic_var = self + .resolver + .register_unnamed(Some((data.typ.clone(), ast::StateSpace::Reg))); + self.pre.push(ast::Instruction::Cvta { + data: ast::CvtaDetails { + state_space: *new_space, + direction: ast::CvtaDirection::ExplicitToGeneric, + }, + arguments: ast::CvtaArgs { + dst: generic_var, + src: *name, + }, + }); + arguments.src = generic_var; + } + } + } + ast::Instruction::Mov { data, arguments } + } + fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { let old_space = match var.state_space { space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs index b2d3161..b1e473b 100644 --- a/ptx/src/pass/insert_implicit_conversions2.rs +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -152,11 +152,11 @@ fn is_addressable(this: ast::StateSpace) -> bool { | ast::StateSpace::Generic | ast::StateSpace::Global | ast::StateSpace::Local - | ast::StateSpace::Shared => true, + | ast::StateSpace::Shared + | ast::StateSpace::ParamEntry => true, ast::StateSpace::Param | ast::StateSpace::Reg => false, ast::StateSpace::SharedCluster | ast::StateSpace::SharedCta - | ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => todo!(), } } @@ -180,13 +180,14 @@ fn default_implicit_conversion_space( | ast::StateSpace::Generic | ast::StateSpace::Const | ast::StateSpace::Local - | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + | ast::StateSpace::Shared + | ast::StateSpace::Param => Ok(Some(ConversionKind::BitToPtr)), _ => Err(error_mismatched_type()), }, ast::Type::Scalar(ast::ScalarType::B32) | ast::Type::Scalar(ast::ScalarType::U32) | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { - ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + ast::StateSpace::Local | ast::StateSpace::Shared => { Ok(Some(ConversionKind::BitToPtr)) } _ => Err(error_mismatched_type()), @@ -220,7 +221,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool { ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Local - | ptx_parser::StateSpace::SharedCta + | ast::StateSpace::SharedCta | ast::StateSpace::SharedCluster | ast::StateSpace::Shared => true, ast::StateSpace::Reg diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 976be8a..8bcf9e1 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -703,7 +703,7 @@ impl<'a> MethodEmitContext<'a> { }); Ok(()) } - _ => todo!(), + _ => return Err(error_todo()), } } @@ -2406,7 +2406,7 @@ impl<'a> MethodEmitContext<'a> { (control >> 12) & 0b1111, ]; if components.iter().any(|&c| c > 7) { - return Err(TranslateError::Todo("".to_string())); + return Err(error_todo()); } let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index 40781fc..cd1814d 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -169,8 +169,9 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, - ast::ScalarType::U16x2 => todo!(), - ast::ScalarType::S16x2 => todo!(), + ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => unsafe { + LLVMVectorType(LLVMInt16TypeInContext(context), 2) + }, ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) }, ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) }, } @@ -180,14 +181,17 @@ fn get_state_space(space: ast::StateSpace) -> Result { match space { ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), - ast::StateSpace::Param => Err(TranslateError::Todo("".to_string())), + // This is dodgy, we try our best to convert all .param into either + // .param::entry or .local, but we can't always succeed. + // In those cases we convert .param into generic address space + ast::StateSpace::Param => Ok(GENERIC_ADDRESS_SPACE), ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), - ast::StateSpace::ParamFunc => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::ParamFunc => Err(error_todo()), ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE), - ast::StateSpace::SharedCta => Err(TranslateError::Todo("".to_string())), - ast::StateSpace::SharedCluster => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::SharedCta => Err(error_todo()), + ast::StateSpace::SharedCluster => Err(error_todo()), } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index f743fad..0b9ef79 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -60,31 +60,48 @@ pub struct Attributes { pub fn to_llvm_module<'input>( ast: ast::Module<'input>, attributes: Attributes, + mut on_pass_end: impl FnMut(&str), ) -> Result { let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; + on_pass_end("normalize_identifiers2"); let directives = replace_known_functions::run(&mut flat_resolver, directives); + on_pass_end("replace_known_functions"); let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; + on_pass_end("normalize_predicates2"); let directives = resolve_function_pointers::run(directives)?; + on_pass_end("resolve_function_pointers"); let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?; + on_pass_end("fix_special_registers"); let directives = expand_operands::run(&mut flat_resolver, directives)?; + on_pass_end("expand_operands"); let directives = insert_post_saturation::run(&mut flat_resolver, directives)?; + on_pass_end("insert_post_saturation"); let directives = deparamize_functions::run(&mut flat_resolver, directives)?; + on_pass_end("deparamize_functions"); let directives = replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?; + on_pass_end("replace_instructions_with_functions_fp_required"); let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?; + on_pass_end("normalize_basic_blocks"); let directives = remove_unreachable_basic_blocks::run(directives)?; + on_pass_end("remove_unreachable_basic_blocks"); let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?; + on_pass_end("instruction_mode_to_global_mode"); let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; + on_pass_end("insert_explicit_load_store"); let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; + on_pass_end("insert_implicit_conversions2"); let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?; + on_pass_end("replace_instructions_with_functions"); let directives = hoist_globals::run(directives)?; - + on_pass_end("hoist_globals"); let context = llvm::Context::new(); let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?; let attributes_ir = llvm::attributes::run(&context, attributes)?; + on_pass_end("emit_llvm"); Ok(Module { llvm_ir, attributes_ir, diff --git a/ptx/src/test/ll/param_is_addressable.ll b/ptx/src/test/ll/param_is_addressable.ll new file mode 100644 index 0000000..6f75d5e --- /dev/null +++ b/ptx/src/test/ll/param_is_addressable.ll @@ -0,0 +1,34 @@ +define amdgpu_kernel void @param_is_addressable(ptr addrspace(4) byref(i64) %"33", ptr addrspace(4) byref(i64) %"34") #0 { + %"35" = alloca i64, align 8, addrspace(5) + %"36" = alloca i64, align 8, addrspace(5) + %"37" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"32" + +"32": ; preds = %1 + %"38" = load i64, ptr addrspace(4) %"33", align 8 + store i64 %"38", ptr addrspace(5) %"35", align 8 + %"39" = load i64, ptr addrspace(4) %"34", align 8 + store i64 %"39", ptr addrspace(5) %"36", align 8 + %"49" = ptrtoint ptr addrspace(4) %"33" to i64 + %2 = inttoptr i64 %"49" to ptr addrspace(4) + %"40" = addrspacecast ptr addrspace(4) %2 to ptr + store ptr %"40", ptr addrspace(5) %"37", align 8 + %"43" = load i64, ptr addrspace(5) %"37", align 8 + %"50" = inttoptr i64 %"43" to ptr + %"42" = load i64, ptr %"50", align 8 + store i64 %"42", ptr addrspace(5) %"37", align 8 + %"45" = load i64, ptr addrspace(5) %"37", align 8 + %"46" = load i64, ptr addrspace(5) %"35", align 8 + %"51" = sub i64 %"45", %"46" + store i64 %"51", ptr addrspace(5) %"37", align 8 + %"47" = load i64, ptr addrspace(5) %"36", align 8 + %"48" = load i64, ptr addrspace(5) %"37", align 8 + %"53" = inttoptr i64 %"47" to ptr + store i64 %"48", ptr %"53", align 8 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index f746d63..49c31dd 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -38,7 +38,7 @@ fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> { let attributes = pass::Attributes { clock_rate: 2124000, }; - crate::to_llvm_module(ast, attributes)?; + crate::to_llvm_module(ast, attributes, |_| {})?; Ok(()) } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 345b112..0a39523 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -339,6 +339,7 @@ test_ptx!( [0x8e2da590u32, 0xedeaee14, 0x248a9f70], [613065134u32] ); +test_ptx!(param_is_addressable, [0xDEAD], [0u64]); test_ptx!(assertfail); // TODO: not yet supported @@ -556,6 +557,7 @@ fn test_hip_assert< pass::Attributes { clock_rate: 2124000, }, + |_| {}, ) .unwrap(); let name = CString::new(name)?; @@ -576,6 +578,7 @@ fn test_llvm_assert( pass::Attributes { clock_rate: 2124000, }, + |_| {}, ) .unwrap(); let actual_ll = llvm_ir.llvm_ir.print_module_to_string(); diff --git a/ptx/src/test/spirv_run/param_is_addressable.ptx b/ptx/src/test/spirv_run/param_is_addressable.ptx new file mode 100644 index 0000000..8d394b3 --- /dev/null +++ b/ptx/src/test/spirv_run/param_is_addressable.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry param_is_addressable( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + mov.b64 temp, input; + ld.param.b64 temp, [temp]; + sub.u64 temp, temp, in_addr; + st.u64 [out_addr], temp; + ret; +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 98c1351..4253ae6 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -892,7 +892,7 @@ fn method_parameter<'a, 'input: 'a>( ( vector_prefix, scalar_type, - opt((Token::DotPtr, Token::DotGlobal)), + opt((Token::DotPtr, opt(Token::DotGlobal))), opt(align.verify(|x| x.count_ones() == 1)), ident, ), diff --git a/ptxas/src/main.rs b/ptxas/src/main.rs index d71ca6b..0ffe841 100644 --- a/ptxas/src/main.rs +++ b/ptxas/src/main.rs @@ -42,6 +42,7 @@ fn main() { ptx::Attributes { clock_rate: clock_rate as u32, }, + |_| {}, ) .unwrap(); let elf_binary = comgr::compile_bitcode( diff --git a/zluda/lib/OpenCL.lib b/zluda/lib/OpenCL.lib deleted file mode 100644 index 2b766ee858f3f474258b211c632aef8bc5368416..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 28824 zcmY$iNi0gvu;bEKKm~?o2A1ZQCYI)gsNx1tu9=0Y351!z#lXPulYzlZjDf+zhJnGd zh=IZC6a-tFFff4eJO&2qc?=9TTu@xXz+eNyZx|SC-Y_uO#z65c1_s+(3=DQ|P`rzQ z!EP4=gS{3MFJNG>U%S1I1ey7@W2+ zFgVLW@gxQYXAu6yz~KCgfx)E!ir+CXxV&Rva7}>XYYYso*BBVwoS=9Q1B2Th1_pO6 zD4xN<;68(a!GjHoYZw?jK==g%gU1U72G0a2zQDlXd4Yk!%LR&eF)(=TVqoyrfZ}-! z4Bqn?7<~AkxQ2nj2ZUcSF!;P;VDOEA;v)X7{WB5cm@MQ7zqDiU$eX7~t^*#g`Zu5-%|@BzZvbCI*J2O$-dlDo{Lufgu@$|1dBl|6yQANrU1` z3=Aok7#LC=pm+@fL+Tm^hBOH%?qFa@1K}?W3~65&7}DdQ_yhw(`UwVx3?nFBz`&5P zfPo>C2Z}2g7&1Zl9s@(>JqCs>7bxDvz>u|zfg#%jiq|nPWUpgj$dQ5KHU@?q5dOu$ zkn@XyAvX((pD{4xK4W0Wi-Y1j3=Da97#Q+>pm+xZL;em1h5{WBj*s{B4e$>NiFa`g zb`J6k2=Nb!_YZLObM|3~kM|4p@xhd5sD~(kNQZcY`gzB@IEFYf#K&jm7Q`p#IF)AR zqy!Y@rxzvWGN5r`YMgWO^HRN2i}F%)7?N|4d2ki^xdoXysR(nNQ3c@Yic%9xQk_cE z(o%~+Mj>;tsR_?4$p|RQFGwva$xJOq(}5<5s^2+3H#adaC9pKLG!?E7MG&`oG=r0K zoWbry){UkF-K@Nl)QS=$Yax8Bs@#h5b3-Z%kR1tD4Hrk%1rm1mK?^|?L2T-SQu9($ zi;_^o9ij_G3Y*@L)QXbQB82-O+F+u1bQrn7LjXmOkqbQaK>`*<))w3}CZMXoX_BW8Hv2t6 zMHr++f}{x~lY{e1i;@vBjVXg_S8!r(K@L*!0^?&<1@b$V0tTuFO%~0B(j>HO7z`J| zQ0bCdmYIxHf(4f*!GuuNh8Cw5xt68oAxaN~2u!7GULmMhb4)HQ%`8fF_wi54N=+^S z#R@JNG@VX~MMat5AcAUu31C+T3Wo5+%o3l>Vvw^?^`J37aAsDy_RR2z~QPSruF z$w(FI%sRXGDhN91N>qQdsn2(E|~}r8XoVo0q_P;8JLM-7<4>&@~~6VNne> z4&f7!Vu%o$THnM1bb}E>m}((rzyruPu>c~1p|U6&IeGXdg85igp~OFu$sipla_G7< z(~A;IQhif%F_Nz@h7_7!KbN4yymUm91RCF{!svPuOESwawIGY4=?F?qOhFG^Br%+- z!KoA;5(ph&aV)wZQGlWWB7~tfCpEDcTYe8pMUz3(864*8R+I`VLZKlA6TnpGn^=HI zcEMr3AU+mVskxveiK-GRhE;WP0jSzR*zTK}3lhSj79RJgh9iVA^@Qf-qPZF(f~GPg zu{axHD42()0;~<5A`xoQ#L;wtibl75NU4of2%`$4>nX}aY6nBZB^*f*w|a0gf~OWF zJzxoJ`XH$ZRSQH6rrIqtFS9sKK1i3BjnU|K&keq`}6h#NP zS&OU~A%LU~QY|{>rFi5QmxN^Ig3BTI)DkS>7`i-NK;aKp4B}#{fx8xBBG^Gls=$>m zSP<+b1Q%URKv8~LW=>{aI+|isab#T;C8>GEnfZBcrFqFEnfZB+DJey%#ZV`^;*ur8 z1h@R6fSkmVwEQAaAI?3s1eamR3X#l(H6ai_K@mYw32pczD@7JUQi~RW2riNua56-g zh2SEofiyZDi;w~uRR}{Z!a#^ZFdsuzaA}fzQGRIwSQZ@F5Y;H+7`npqi?Xq5LzO_% z2QN?%7NH6usRgy)5z62^WEIee1|@E=qhT=*62PVop5EaKK_bX1AsqzpNC1-E=%Sc9 zFcgCYkkmn&(O_Y)Tag8j)s+?${ z=OFW7Dxg&(a(Ts&oP$jiLkGBcLRJhBKvM@PMG(rMd^AAL(OTq{zOp@k4!3tS9IH9R@s3O{@bkj#N57igwK4F()?n7V@tQj;?i zb26(EL2U)+{Ji3lMDVacaB2zE5FAR7%z_Ud1eYY1l!BTX!Ko$af-v==c{%xsDbS%? z(D*Ex0IE7zb%LQDLjtA`wMCqqgUTm;2FwgT16Ietz#wMEz+hv+z))n#z;Mcnfx*O@ zfnlCC0|S>014D@o1H&5|28I}028LU<3=D2|3=F&M7#Otd85kDWGcX7_Ffg<^Ffjab zU|>jbWMH`9$iQIZ#K5q{iGe}RnSo)FGXujfX9k7>7Y2rRE({C_t_%#F=Z2TA)Hu*6ysQ5E5O!H@8U<+Vis0v_UcoD$B5Esb6a4L|2 z!7PY@VObCZgGev~Lt8Kd!#`XbWO5i7+Hx2ee&sMQWaTn2 zJj-QZh|6POxRb}g;FHh5up^&=L8kyb^~u1|V9sE{V98*`V9j8|V9Q{~V9(&d;K<;_ z;LPB{;L6~};LhN|;K|^{;LYH};LG60;Li}i5Xcb35X=z55Xun75Y7<65Xlh55X}(7 z5X%t95YLdnkjRk4kj#+6kjjw8kj{|7kjaq6kj;?8kjs$Akk3%SP{>fkP|Q%mP|8ro zP|i@nP{~lmP|Z-oP|Hv^LVeB5K+aqktQ-QjbfHBkVk!b@k_)B~TC`y(1rN(Zhlb#4 zkpw~GUeJ0QG;M{d1f@pBEk(c#a6OK2JlJyQe6T2JIt4tVglafQ7(R1^O^WE5AZ(hU zeL#Y9JE(fW?R_*~fWL4GVYf|C^y!}e%eL4%)Qt>|tgJot`g6468J==$ITqOjp}GzG*DhNEdj3|OOk z4k3drB0#cO!vxw4h7BndqZ^Cf0fHtaf`iX!?gptRVpJJZ7kG3R-BCmg52Klc+L}d| zBz|ZYn@-el#X3j}(?-HDERq&TuN!656+;$nfC-`n+^s@V3hw1$6(oIJ6xmoDqoc$O zXd)To?gOo-&|OGia1zM~NZkP&jl_^ZHyEOvlp#eVBjMdmbW;fp3?i8U>y2R;LFD)z zl2OpA2Hhm$hv;BtK}!UzEg`rZxGe+EmQZ0xqC*!2CpvUNNXVm$!h;lD0$RhP3q$>k zBn%nJLRwvf@D)S~-Qgfn40nSB(VYzv!*Df75Sybx;+Sp*iK06h=hzd>;oyNGgomKb zh)TGh!SY~VVHp=eG6eU)5t1UXr{O~)ND>gAgNA{y2@*BDgQO82WcZy74MAw?!saZf z3etyKV2&Ye5Cx_SwKhRpn1{LuFF6OTp8{K#1nr@qa?u)!FiDiWh)ocuK5)W-ujxVW zgTRbJ>3yIIqVzjZrJ$uX+;oU3P*J3nT`1z{Yq{XM!L=xoE-)WmH4cA5QYYLf2p6db z0dqH~WQMDPrX7%x&|U*v8Km?D%b@lbu&BVkWD43xfT@S&dNe-B255!`bw%N_@TE#f zB9KA~G)M^PS|EA>@VR4*_CH((*78Re0674Zm{AuGAzJN-S$2dNxOoXT2q6w@;iE`_ zV-eL9un@de4_6IpwZroSw3P`G11m?EfM~1h=l?Jg_9Zz((i;B^zuDE0EgKa5K=?L7+;&LIT=yPR?-yFK7VINg|hg z;HC{!6kM($1R$9MrG*Su2oVRhjgh&aDONBaT-_l|0drAX!VqOR+P9E(6~s*EgN=d3 zEy6<77B5&4qOA+&g2M)(0aCKT+P7d85FvCWC~aCKb*S^;U-LgcpJj z_dY@h>KlXr)ZJhKP$Le*B@jVyqJ%3#atBBn;S3ND$rT_jsv|%m$Zi1f5l(=&7{L>l zuz5<*>>`#{VsZ}N_90jrcgqkg4R$!Z@dXxwxE|CRMCBs21Cb=)lX+k*S zNur7oRER-@fq~%+0|UbYr~oqqNUy;Gs5&#KFar+*1H(Tgbs%vE289Dm3=9l{44fb? z0~f+m>I}jH>415eS3{#7Qd^3}a@{9A+N)$ALG!=Y2{k);_C=|#pkUKz_ zfq{jAfq|JpfPsO570g0d!oc8=oS$2umzgrPh>`LCe~>zmK_DF<8ss)62B=Sr92gkD zuFZgnfc*xtgBd1b0#XCg30Gr^5CMtHg9s=Fna;pq0p)=xCWZh82D0WYsFg~LOe`T= z1{FZ5gbAEVb}%q7K-|i}fGw4zfSitmky42GUrZ{qhG6Sg1fE==5yFkM$;2C9#?8FnB zoJjrySq7h{MzRpWBrboL;>e3?3=9lh81|tILLsaoIj@)+V~<5e1_lOh3|pZiSxB~! zUgqG7OCDS{8oAKSPF`Gg8dGMc363-iD#iJb>;(5`De@h@oXC%E6NMp(&n^LUyC@7n zQ&{s4tsGYex6%>09vp(W3pSEV8+O z*kuMkQ3kc$t;xW^AcbicC>>#M&XJs!;N4HOoC4~vNh8|~pPZn?pQbokt#S+u3^J&; zLB`RFkrNKWX5w-uKEKH#*#;jwLbi(V_zNZsd_| zh0O&aS%zQ|7oPa~912+M0wq(>*cy1631JyhC=+L|DURF$8rxAswihugg=7cmVT!Ny zri9%-@I()i7fDY&IN}l1GFHZJCv4^o$yx-H$dJX8A63w81vf#E1D0gl@c2;`t8MW4 zLWEUF4j|Hxcv6oVvaO(LN6d~n=`o2XOx4ltqdZLU*rtJO8)62Pg0RF_qH3bqhT%1m zLlR#Ntc7YDXrdP>Wss5Q@wCIWG3~-tw2|ybd_7_v6no*5xUj{g6vQL!oH_-xT3?BQ zfk7A9R@Aw0N^Hc_2iHTf5qX*&$wJaY7Eiv{N460(El-VnkE2x%8ox2XZXe~f2AU3-P_2+84SjHBHo3|^H08oz`r>q4@EWPjiradHIv1KS1*N^Hbe=Qts3 z1W&|(mo*_-hF}tzD)6)$oMCoBMnh0mx*_Z#y-dM3`s;#fCukKNl2r&MalXS>qq?Hm z1X@>yFrT`z6W5wWN^CaAF-r(q4&h12X5`&K zNRB`-iA+a0${El&r5DT%um&f3^palU;3zpkmgjYzM5|0E#W1uy@Rts~iIIOc^wbC`+9wn6s+(8{hPWV@gx2_<&n8`DUJ z+6CI=fxLE+68rG=wNg;*!^mPJrx1KC$y6k}FtZrRsRK``n1*B%Mh2s*pVEov81-u^%$r6(Nf^YRmF2oibd%Y;J)eJ`nYB4Y{z%r;1Y0@_$dfla1(`iMms zHX-fVqr^@ec>}aErx?{v_;x@f>qyQeu+t^c=dE=Z7#K=m_Q4BH>ZVwHD|Jg@c0-F$ zSnC;f&?`lL#u2Ze9S~(`_TtzfiDW&3NnAd~*M=#F*$!Vo2i<&$u!Q8?iLakf0k#ja hjT63olLEW&<(Nt&yD+wlQeY>(Hd+ CUresult { Ok(()) } +pub(crate) fn primary_context_release_v2(hip_dev: hipDevice_t) -> CUresult { + primary_context_release(hip_dev) +} + pub(crate) fn primary_context_reset(hip_dev: hipDevice_t) -> CUresult { let (ctx, _) = get_primary_context(hip_dev)?; ctx.with_state_mut(|state| state.reset())?; diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index dbddb1a..737f5c3 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -166,7 +166,9 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { None => return CUresult::ERROR_INVALID_VALUE, }; - device::primary_context_retain(pctx, hip_dev) + let (_, cu_ctx) = device::get_primary_context(hip_dev)?; + *pctx = cu_ctx; + Ok(()) } unsafe extern "system" fn get_module_from_cubin_ext1( @@ -362,8 +364,8 @@ fn get_device_hash_info() -> Result, CUerror> { (0..device_count) .map(|dev| { - let mut guid = CUuuid_st { bytes: [0; 16] }; - unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? }; + let mut guid = unsafe { mem::zeroed() }; + device::get_uuid_v2(&mut guid, dev)?; let mut pci_domain = 0; device::get_attribute( @@ -387,7 +389,7 @@ fn get_device_hash_info() -> Result, CUerror> { )?; Ok(::dark_api::DeviceHashinfo { - guid, + guid: unsafe { mem::transmute(guid) }, pci_domain, pci_bus, pci_device, @@ -525,8 +527,57 @@ pub(crate) unsafe fn launch_kernel_ex( Ok(()) } +pub(crate) unsafe fn get_error_string( + error: cuda_types::cuda::CUresult, + error_string: &mut *const ::core::ffi::c_char, +) -> CUresult { + *error_string = match error { + CUresult::SUCCESS => c"no error".as_ptr(), + CUresult::ERROR_INVALID_VALUE => c"invalid value".as_ptr(), + CUresult::ERROR_OUT_OF_MEMORY => c"out of memory".as_ptr(), + CUresult::ERROR_NOT_INITIALIZED => c"driver not initialized".as_ptr(), + CUresult::ERROR_DEINITIALIZED => c"driver deinitialized".as_ptr(), + CUresult::ERROR_NO_DEVICE => c"no CUDA-capable device is detected".as_ptr(), + CUresult::ERROR_INVALID_DEVICE => c"invalid device".as_ptr(), + CUresult::ERROR_INVALID_IMAGE => c"invalid kernel image".as_ptr(), + CUresult::ERROR_INVALID_CONTEXT => c"invalid context".as_ptr(), + CUresult::ERROR_CONTEXT_ALREADY_CURRENT => c"context already current".as_ptr(), + CUresult::ERROR_MAP_FAILED => c"map failed".as_ptr(), + CUresult::ERROR_UNMAP_FAILED => c"unmap failed".as_ptr(), + CUresult::ERROR_ARRAY_IS_MAPPED => c"array is mapped".as_ptr(), + CUresult::ERROR_ALREADY_MAPPED => c"already mapped".as_ptr(), + CUresult::ERROR_NO_BINARY_FOR_GPU => c"no binary for GPU".as_ptr(), + CUresult::ERROR_ALREADY_ACQUIRED => c"already acquired".as_ptr(), + CUresult::ERROR_NOT_MAPPED => c"not mapped".as_ptr(), + CUresult::ERROR_NOT_SUPPORTED => c"operation not supported".as_ptr(), + CUresult::ERROR_INVALID_SOURCE => c"invalid source".as_ptr(), + CUresult::ERROR_FILE_NOT_FOUND => c"file not found".as_ptr(), + CUresult::ERROR_INVALID_HANDLE => c"invalid handle".as_ptr(), + CUresult::ERROR_NOT_READY => c"not ready".as_ptr(), + CUresult::ERROR_ILLEGAL_ADDRESS => c"illegal address".as_ptr(), + CUresult::ERROR_LAUNCH_OUT_OF_RESOURCES => c"launch out of resources".as_ptr(), + CUresult::ERROR_LAUNCH_TIMEOUT => c"launch timeout".as_ptr(), + CUresult::ERROR_LAUNCH_INCOMPATIBLE_TEXTURING => c"launch incompatible texturing".as_ptr(), + CUresult::ERROR_PEER_ACCESS_ALREADY_ENABLED => c"peer access already enabled".as_ptr(), + CUresult::ERROR_PEER_ACCESS_NOT_ENABLED => c"peer access not enabled".as_ptr(), + CUresult::ERROR_PRIMARY_CONTEXT_ACTIVE => c"primary context active".as_ptr(), + CUresult::ERROR_CONTEXT_IS_DESTROYED => c"context is destroyed".as_ptr(), + CUresult::ERROR_ASSERT => c"device-side assert triggered".as_ptr(), + CUresult::ERROR_TOO_MANY_PEERS => c"too many peers".as_ptr(), + CUresult::ERROR_HOST_MEMORY_ALREADY_REGISTERED => { + c"host memory already registered".as_ptr() + } + CUresult::ERROR_HOST_MEMORY_NOT_REGISTERED => c"host memory not registered".as_ptr(), + CUresult::ERROR_UNKNOWN => c"unknown error".as_ptr(), + _ => c"error".as_ptr(), + }; + Ok(()) +} + #[cfg(test)] mod tests { + use std::i32; + use crate::r#impl::driver::AllocationInfo; use crate::tests::CudaApi; use cuda_macros::test_cuda; @@ -571,4 +622,34 @@ mod tests { } assert_eq!(alloc_info.get_offset_and_info(0x2000 + 8), None); } + + #[test_cuda] + fn primary_context_is_inactive_on_init(api: impl CudaApi) { + api.cuInit(0); + let mut flags = u32::MAX; + let mut active = i32::MAX; + api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active); + assert_eq!(flags, 0); + assert_eq!(active, 0); + } + + #[test_cuda] + unsafe fn cudart_interface_fn2_creates_inactive_primary_ctx(api: impl CudaApi) { + api.cuInit(0); + let mut table_ptr = std::ptr::null(); + api.cuGetExportTable(&mut table_ptr, &dark_api::cuda::CudartInterface::GUID); + let cuda_rt_iface = dark_api::cuda::CudartInterface::new(table_ptr); + let mut dark_ctx = std::mem::zeroed(); + cuda_rt_iface + .cudart_interface_fn2(&mut dark_ctx, 0) + .unwrap(); + let mut flags = u32::MAX; + let mut active = i32::MAX; + api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active); + assert_eq!(flags, 0); + assert_eq!(active, 0); + let mut primary_ctx = std::mem::zeroed(); + api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0); + assert_eq!(dark_ctx.0, primary_ctx.0); + } } diff --git a/zluda/src/impl/library.rs b/zluda/src/impl/library.rs index 6ca3745..2fcc56a 100644 --- a/zluda/src/impl/library.rs +++ b/zluda/src/impl/library.rs @@ -55,12 +55,11 @@ pub(crate) unsafe fn load_data( _jit_options_values: Option<&mut *mut ::core::ffi::c_void>, _num_jit_options: ::core::ffi::c_uint, library_options: Option<&mut CUlibraryOption>, - library_option_values: Option<&mut *mut ::core::ffi::c_void>, + _library_option_values: Option<&mut *mut ::core::ffi::c_void>, num_library_options: ::core::ffi::c_uint, ) -> CUresult { let global_state = driver::global_state()?; - let options = - LibraryOptions::load(library_options, library_option_values, num_library_options)?; + let options = LibraryOptions::load(library_options, num_library_options)?; let library = Library { data: LibraryData::new(code as *mut c_void, options.preserve_binary)?, modules: vec![OnceLock::new(); global_state.devices.len()], @@ -76,7 +75,6 @@ struct LibraryOptions { impl LibraryOptions { unsafe fn load( library_options: Option<&mut CUlibraryOption>, - library_option_values: Option<&mut *mut ::core::ffi::c_void>, num_library_options: ::core::ffi::c_uint, ) -> Result { if num_library_options == 0 { @@ -84,28 +82,17 @@ impl LibraryOptions { preserve_binary: false, }); } - let (library_options, library_option_values) = - match (library_options, library_option_values) { - (Some(library_options), Some(library_option_values)) => { - let library_options = - std::slice::from_raw_parts(library_options, num_library_options as usize); - let library_option_values = std::slice::from_raw_parts( - library_option_values, - num_library_options as usize, - ); - (library_options, library_option_values) - } - _ => return Err(CUerror::INVALID_VALUE), - }; + let library_options = match library_options { + Some(library_options) => { + std::slice::from_raw_parts(library_options, num_library_options as usize) + } + _ => return Err(CUerror::INVALID_VALUE), + }; let mut preserve_binary = false; - for (option, value) in library_options - .iter() - .copied() - .zip(library_option_values.iter()) - { + for option in library_options.iter().copied() { match option { CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => { - preserve_binary = *(value.cast::()); + preserve_binary = true; } _ => return Err(CUerror::NOT_SUPPORTED), } @@ -156,10 +143,7 @@ mod tests { use crate::tests::CudaApi; use cuda_macros::test_cuda; use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts}; - use std::{ - ffi::{c_void, CStr}, - mem, ptr, - }; + use std::{ffi::CStr, mem, ptr}; #[test_cuda] unsafe fn library_loads_without_context(api: impl CudaApi) { @@ -191,7 +175,7 @@ mod tests { ptr::null_mut(), 0, [CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(), - [(&true as *const bool) as *mut c_void].as_mut_ptr(), + ptr::null_mut(), 1, ); assert_eq!( diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 9cf83b8..70395ed 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -153,3 +153,11 @@ pub(crate) unsafe fn set_d8_async( ) -> hipError_t { hipMemsetD8Async(dst_device, uc, n, stream) } + +pub(crate) fn get_allocation_granularity( + _granularity: &mut usize, + _property: &cuda_types::cuda::CUmemAllocationProp, + _option: cuda_types::cuda::CUmemAllocationGranularity_flags, +) -> CUresult { + CUresult::ERROR_NOT_SUPPORTED +} diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index ae9dcf1..506f824 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -138,7 +138,7 @@ fn compile_from_ptx_and_cache( } else { ptx_parser::parse_module_unchecked(text) }; - let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?; + let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?; let elf_module = comgr::compile_bitcode( comgr, gcn_arch, diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index 8d06f10..14b7bad 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -93,6 +93,7 @@ cuda_macros::cuda_function_declarations!( cuDeviceGetUuid_v2, cuDevicePrimaryCtxGetState, cuDevicePrimaryCtxRelease, + cuDevicePrimaryCtxRelease_v2, cuDevicePrimaryCtxReset, cuDevicePrimaryCtxRetain, cuDeviceTotalMem_v2, @@ -104,6 +105,7 @@ cuda_macros::cuda_function_declarations!( cuEventSynchronize, cuFuncGetAttribute, cuFuncSetAttribute, + cuGetErrorString, cuGetExportTable, cuGetProcAddress, cuGetProcAddress_v2, @@ -126,6 +128,7 @@ cuda_macros::cuda_function_declarations!( cuMemFreeHost, cuMemFree_v2, cuMemGetAddressRange_v2, + cuMemGetAllocationGranularity, cuMemGetInfo_v2, cuMemHostAlloc, cuMemRetainAllocationHandle, diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs index c857f6b..95ec415 100644 --- a/zluda_common/src/lib.rs +++ b/zluda_common/src/lib.rs @@ -165,7 +165,10 @@ from_cuda_nop!( nvmlDevice_t, nvmlFieldValue_t, nvmlGpuFabricInfo_t, - cublasLtHandle_t + cublasLtHandle_t, + CUmemAllocationGranularity_flags, + CUmemAllocationProp, + CUresult ); from_cuda_transmute!( CUuuid => hipUUID,