diff --git a/compiler/src/main.rs b/compiler/src/main.rs index 5effaaf..9d1a5d1 100644 --- a/compiler/src/main.rs +++ b/compiler/src/main.rs @@ -6,6 +6,7 @@ use std::io::{self, Write}; use std::path::{Path, PathBuf}; use std::process::ExitCode; use std::str; +use std::time::Instant; use std::{env, mem}; mod error; @@ -57,19 +58,12 @@ fn main_core() -> Result<(), CompilerError> { let arch: String = match opts.arch { Some(s) => s, - None => { - (|| { - let runtime = hip::Runtime::load()?; - runtime.init()?; - get_gpu_arch(&runtime) - })() - .unwrap_or_else(|_| DEFAULT_ARCH.to_owned()) - /* - get_gpu_arch(&mut dev_props) - .map(String::from) - .unwrap_or(DEFAULT_ARCH.to_owned()) - */ - } + None => (|| { + let runtime = hip::Runtime::load()?; + runtime.init()?; + get_gpu_arch(&runtime) + })() + .unwrap_or_else(|_| DEFAULT_ARCH.to_owned()), }; let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; @@ -83,6 +77,7 @@ fn main_core() -> Result<(), CompilerError> { write_to_file(bytes, &output_path).unwrap(); }; + let mut start = Instant::now(); comgr::compile_bitcode( &comgr, &arch, @@ -92,17 +87,22 @@ fn main_core() -> Result<(), CompilerError> { Some(&comgr_hook), ) .map_err(CompilerError::from)?; + report_pass_time("compile_bitcode", &mut start); Ok(()) } fn ptx_to_llvm(ptx: &str) -> Result { let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; + let mut start = Instant::now(); let module = ptx::to_llvm_module( ast, ptx::Attributes { clock_rate: 2124000, }, + |pass| { + report_pass_time(pass, &mut start); + }, ) .map_err(CompilerError::from)?; let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec(); @@ -117,6 +117,12 @@ fn ptx_to_llvm(ptx: &str) -> Result { }) } +fn report_pass_time(pass: &str, start: &mut Instant) { + let duration = start.elapsed(); + println!("Pass {:?} took {:?}", pass, duration); + *start = Instant::now(); +} + #[derive(Debug)] struct LLVMArtifacts { bitcode: Vec, diff --git a/cuda_macros/src/lib.rs b/cuda_macros/src/lib.rs index 483ba75..1554f8f 100644 --- a/cuda_macros/src/lib.rs +++ b/cuda_macros/src/lib.rs @@ -313,7 +313,7 @@ fn join( pub fn test_cuda(_attr: TokenStream, item: TokenStream) -> TokenStream { let fn_ = parse_macro_input!(item as syn::ItemFn); let cuda_fn = format_ident!("{}{}", fn_.sig.ident, "_nvidia"); - let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_amdgpu"); + let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_zluda"); let fn_name = fn_.sig.ident.clone(); quote! { #[test] diff --git a/ext/rocm_smi-sys/build.rs b/ext/rocm_smi-sys/build.rs new file mode 100644 index 0000000..49060c3 --- /dev/null +++ b/ext/rocm_smi-sys/build.rs @@ -0,0 +1,7 @@ +use std::env::VarError; + +fn main() -> Result<(), VarError> { + println!("cargo:rustc-link-lib=dylib=rocm_smi64"); + println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); + Ok(()) +} diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 32597c5..2805dfa 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -114,6 +114,13 @@ fn run_statement<'a, 'input>( result.push(Statement::Instruction(instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction)); } + Statement::Instruction(ast::Instruction::Mov { data, arguments }) => { + let instruction = visitor.visit_mov(data, arguments); + let instruction = ast::visit_map(instruction, visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(Statement::Instruction(instruction)); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } Statement::PtrAccess(ptr_access) => { let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); let statement = statement.visit_map(visitor)?; @@ -293,6 +300,39 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { }) } + fn visit_mov( + &mut self, + data: ptx_parser::MovDetails, + mut arguments: ptx_parser::MovArgs, + ) -> ast::Instruction { + if let Some(remap) = self.variables.get(&arguments.src) { + match remap { + RemapAction::PreLdPostSt { .. } => {} + RemapAction::LDStSpaceChange { + name, + new_space, + old_space, + } => { + let generic_var = self + .resolver + .register_unnamed(Some((data.typ.clone(), ast::StateSpace::Reg))); + self.pre.push(ast::Instruction::Cvta { + data: ast::CvtaDetails { + state_space: *new_space, + direction: ast::CvtaDirection::ExplicitToGeneric, + }, + arguments: ast::CvtaArgs { + dst: generic_var, + src: *name, + }, + }); + arguments.src = generic_var; + } + } + } + ast::Instruction::Mov { data, arguments } + } + fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { let old_space = match var.state_space { space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs index b2d3161..b1e473b 100644 --- a/ptx/src/pass/insert_implicit_conversions2.rs +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -152,11 +152,11 @@ fn is_addressable(this: ast::StateSpace) -> bool { | ast::StateSpace::Generic | ast::StateSpace::Global | ast::StateSpace::Local - | ast::StateSpace::Shared => true, + | ast::StateSpace::Shared + | ast::StateSpace::ParamEntry => true, ast::StateSpace::Param | ast::StateSpace::Reg => false, ast::StateSpace::SharedCluster | ast::StateSpace::SharedCta - | ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => todo!(), } } @@ -180,13 +180,14 @@ fn default_implicit_conversion_space( | ast::StateSpace::Generic | ast::StateSpace::Const | ast::StateSpace::Local - | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + | ast::StateSpace::Shared + | ast::StateSpace::Param => Ok(Some(ConversionKind::BitToPtr)), _ => Err(error_mismatched_type()), }, ast::Type::Scalar(ast::ScalarType::B32) | ast::Type::Scalar(ast::ScalarType::U32) | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { - ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + ast::StateSpace::Local | ast::StateSpace::Shared => { Ok(Some(ConversionKind::BitToPtr)) } _ => Err(error_mismatched_type()), @@ -220,7 +221,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool { ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Local - | ptx_parser::StateSpace::SharedCta + | ast::StateSpace::SharedCta | ast::StateSpace::SharedCluster | ast::StateSpace::Shared => true, ast::StateSpace::Reg diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 6ac676e..c811a53 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -704,7 +704,7 @@ impl<'a> MethodEmitContext<'a> { }); Ok(()) } - _ => todo!(), + _ => return Err(error_todo()), } } @@ -2455,7 +2455,7 @@ impl<'a> MethodEmitContext<'a> { (control >> 12) & 0b1111, ]; if components.iter().any(|&c| c > 7) { - return Err(TranslateError::Todo("".to_string())); + return Err(error_todo()); } let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index 40781fc..cd1814d 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -169,8 +169,9 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, - ast::ScalarType::U16x2 => todo!(), - ast::ScalarType::S16x2 => todo!(), + ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => unsafe { + LLVMVectorType(LLVMInt16TypeInContext(context), 2) + }, ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) }, ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) }, } @@ -180,14 +181,17 @@ fn get_state_space(space: ast::StateSpace) -> Result { match space { ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), - ast::StateSpace::Param => Err(TranslateError::Todo("".to_string())), + // This is dodgy, we try our best to convert all .param into either + // .param::entry or .local, but we can't always succeed. + // In those cases we convert .param into generic address space + ast::StateSpace::Param => Ok(GENERIC_ADDRESS_SPACE), ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), - ast::StateSpace::ParamFunc => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::ParamFunc => Err(error_todo()), ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE), - ast::StateSpace::SharedCta => Err(TranslateError::Todo("".to_string())), - ast::StateSpace::SharedCluster => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::SharedCta => Err(error_todo()), + ast::StateSpace::SharedCluster => Err(error_todo()), } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index f743fad..0b9ef79 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -60,31 +60,48 @@ pub struct Attributes { pub fn to_llvm_module<'input>( ast: ast::Module<'input>, attributes: Attributes, + mut on_pass_end: impl FnMut(&str), ) -> Result { let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; + on_pass_end("normalize_identifiers2"); let directives = replace_known_functions::run(&mut flat_resolver, directives); + on_pass_end("replace_known_functions"); let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; + on_pass_end("normalize_predicates2"); let directives = resolve_function_pointers::run(directives)?; + on_pass_end("resolve_function_pointers"); let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?; + on_pass_end("fix_special_registers"); let directives = expand_operands::run(&mut flat_resolver, directives)?; + on_pass_end("expand_operands"); let directives = insert_post_saturation::run(&mut flat_resolver, directives)?; + on_pass_end("insert_post_saturation"); let directives = deparamize_functions::run(&mut flat_resolver, directives)?; + on_pass_end("deparamize_functions"); let directives = replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?; + on_pass_end("replace_instructions_with_functions_fp_required"); let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?; + on_pass_end("normalize_basic_blocks"); let directives = remove_unreachable_basic_blocks::run(directives)?; + on_pass_end("remove_unreachable_basic_blocks"); let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?; + on_pass_end("instruction_mode_to_global_mode"); let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; + on_pass_end("insert_explicit_load_store"); let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; + on_pass_end("insert_implicit_conversions2"); let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?; + on_pass_end("replace_instructions_with_functions"); let directives = hoist_globals::run(directives)?; - + on_pass_end("hoist_globals"); let context = llvm::Context::new(); let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?; let attributes_ir = llvm::attributes::run(&context, attributes)?; + on_pass_end("emit_llvm"); Ok(Module { llvm_ir, attributes_ir, diff --git a/ptx/src/test/ll/param_is_addressable.ll b/ptx/src/test/ll/param_is_addressable.ll new file mode 100644 index 0000000..6f75d5e --- /dev/null +++ b/ptx/src/test/ll/param_is_addressable.ll @@ -0,0 +1,34 @@ +define amdgpu_kernel void @param_is_addressable(ptr addrspace(4) byref(i64) %"33", ptr addrspace(4) byref(i64) %"34") #0 { + %"35" = alloca i64, align 8, addrspace(5) + %"36" = alloca i64, align 8, addrspace(5) + %"37" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"32" + +"32": ; preds = %1 + %"38" = load i64, ptr addrspace(4) %"33", align 8 + store i64 %"38", ptr addrspace(5) %"35", align 8 + %"39" = load i64, ptr addrspace(4) %"34", align 8 + store i64 %"39", ptr addrspace(5) %"36", align 8 + %"49" = ptrtoint ptr addrspace(4) %"33" to i64 + %2 = inttoptr i64 %"49" to ptr addrspace(4) + %"40" = addrspacecast ptr addrspace(4) %2 to ptr + store ptr %"40", ptr addrspace(5) %"37", align 8 + %"43" = load i64, ptr addrspace(5) %"37", align 8 + %"50" = inttoptr i64 %"43" to ptr + %"42" = load i64, ptr %"50", align 8 + store i64 %"42", ptr addrspace(5) %"37", align 8 + %"45" = load i64, ptr addrspace(5) %"37", align 8 + %"46" = load i64, ptr addrspace(5) %"35", align 8 + %"51" = sub i64 %"45", %"46" + store i64 %"51", ptr addrspace(5) %"37", align 8 + %"47" = load i64, ptr addrspace(5) %"36", align 8 + %"48" = load i64, ptr addrspace(5) %"37", align 8 + %"53" = inttoptr i64 %"47" to ptr + store i64 %"48", ptr %"53", align 8 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index f746d63..49c31dd 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -38,7 +38,7 @@ fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> { let attributes = pass::Attributes { clock_rate: 2124000, }; - crate::to_llvm_module(ast, attributes)?; + crate::to_llvm_module(ast, attributes, |_| {})?; Ok(()) } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index a6e62c2..46bdd0b 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -345,6 +345,7 @@ test_ptx!( [0x8e2da590u32, 0xedeaee14, 0x248a9f70], [613065134u32] ); +test_ptx!(param_is_addressable, [0xDEAD], [0u64]); test_ptx!(assertfail); // TODO: not yet supported @@ -562,6 +563,7 @@ fn test_hip_assert< pass::Attributes { clock_rate: 2124000, }, + |_| {}, ) .unwrap(); let name = CString::new(name)?; @@ -582,6 +584,7 @@ fn test_llvm_assert( pass::Attributes { clock_rate: 2124000, }, + |_| {}, ) .unwrap(); let actual_ll = llvm_ir.llvm_ir.print_module_to_string(); diff --git a/ptx/src/test/spirv_run/param_is_addressable.ptx b/ptx/src/test/spirv_run/param_is_addressable.ptx new file mode 100644 index 0000000..8d394b3 --- /dev/null +++ b/ptx/src/test/spirv_run/param_is_addressable.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry param_is_addressable( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + mov.b64 temp, input; + ld.param.b64 temp, [temp]; + sub.u64 temp, temp, in_addr; + st.u64 [out_addr], temp; + ret; +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 98c1351..4253ae6 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -892,7 +892,7 @@ fn method_parameter<'a, 'input: 'a>( ( vector_prefix, scalar_type, - opt((Token::DotPtr, Token::DotGlobal)), + opt((Token::DotPtr, opt(Token::DotGlobal))), opt(align.verify(|x| x.count_ones() == 1)), ident, ), diff --git a/ptxas/src/main.rs b/ptxas/src/main.rs index d71ca6b..0ffe841 100644 --- a/ptxas/src/main.rs +++ b/ptxas/src/main.rs @@ -42,6 +42,7 @@ fn main() { ptx::Attributes { clock_rate: clock_rate as u32, }, + |_| {}, ) .unwrap(); let elf_binary = comgr::compile_bitcode( diff --git a/zluda/lib/OpenCL.lib b/zluda/lib/OpenCL.lib deleted file mode 100644 index 2b766ee..0000000 Binary files a/zluda/lib/OpenCL.lib and /dev/null differ diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 03b2dcc..6816994 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -494,6 +494,10 @@ pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> CUresult { Ok(()) } +pub(crate) fn primary_context_release_v2(hip_dev: hipDevice_t) -> CUresult { + primary_context_release(hip_dev) +} + pub(crate) fn primary_context_reset(hip_dev: hipDevice_t) -> CUresult { let (ctx, _) = get_primary_context(hip_dev)?; ctx.with_state_mut(|state| state.reset())?; diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index dbddb1a..737f5c3 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -166,7 +166,9 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { None => return CUresult::ERROR_INVALID_VALUE, }; - device::primary_context_retain(pctx, hip_dev) + let (_, cu_ctx) = device::get_primary_context(hip_dev)?; + *pctx = cu_ctx; + Ok(()) } unsafe extern "system" fn get_module_from_cubin_ext1( @@ -362,8 +364,8 @@ fn get_device_hash_info() -> Result, CUerror> { (0..device_count) .map(|dev| { - let mut guid = CUuuid_st { bytes: [0; 16] }; - unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? }; + let mut guid = unsafe { mem::zeroed() }; + device::get_uuid_v2(&mut guid, dev)?; let mut pci_domain = 0; device::get_attribute( @@ -387,7 +389,7 @@ fn get_device_hash_info() -> Result, CUerror> { )?; Ok(::dark_api::DeviceHashinfo { - guid, + guid: unsafe { mem::transmute(guid) }, pci_domain, pci_bus, pci_device, @@ -525,8 +527,57 @@ pub(crate) unsafe fn launch_kernel_ex( Ok(()) } +pub(crate) unsafe fn get_error_string( + error: cuda_types::cuda::CUresult, + error_string: &mut *const ::core::ffi::c_char, +) -> CUresult { + *error_string = match error { + CUresult::SUCCESS => c"no error".as_ptr(), + CUresult::ERROR_INVALID_VALUE => c"invalid value".as_ptr(), + CUresult::ERROR_OUT_OF_MEMORY => c"out of memory".as_ptr(), + CUresult::ERROR_NOT_INITIALIZED => c"driver not initialized".as_ptr(), + CUresult::ERROR_DEINITIALIZED => c"driver deinitialized".as_ptr(), + CUresult::ERROR_NO_DEVICE => c"no CUDA-capable device is detected".as_ptr(), + CUresult::ERROR_INVALID_DEVICE => c"invalid device".as_ptr(), + CUresult::ERROR_INVALID_IMAGE => c"invalid kernel image".as_ptr(), + CUresult::ERROR_INVALID_CONTEXT => c"invalid context".as_ptr(), + CUresult::ERROR_CONTEXT_ALREADY_CURRENT => c"context already current".as_ptr(), + CUresult::ERROR_MAP_FAILED => c"map failed".as_ptr(), + CUresult::ERROR_UNMAP_FAILED => c"unmap failed".as_ptr(), + CUresult::ERROR_ARRAY_IS_MAPPED => c"array is mapped".as_ptr(), + CUresult::ERROR_ALREADY_MAPPED => c"already mapped".as_ptr(), + CUresult::ERROR_NO_BINARY_FOR_GPU => c"no binary for GPU".as_ptr(), + CUresult::ERROR_ALREADY_ACQUIRED => c"already acquired".as_ptr(), + CUresult::ERROR_NOT_MAPPED => c"not mapped".as_ptr(), + CUresult::ERROR_NOT_SUPPORTED => c"operation not supported".as_ptr(), + CUresult::ERROR_INVALID_SOURCE => c"invalid source".as_ptr(), + CUresult::ERROR_FILE_NOT_FOUND => c"file not found".as_ptr(), + CUresult::ERROR_INVALID_HANDLE => c"invalid handle".as_ptr(), + CUresult::ERROR_NOT_READY => c"not ready".as_ptr(), + CUresult::ERROR_ILLEGAL_ADDRESS => c"illegal address".as_ptr(), + CUresult::ERROR_LAUNCH_OUT_OF_RESOURCES => c"launch out of resources".as_ptr(), + CUresult::ERROR_LAUNCH_TIMEOUT => c"launch timeout".as_ptr(), + CUresult::ERROR_LAUNCH_INCOMPATIBLE_TEXTURING => c"launch incompatible texturing".as_ptr(), + CUresult::ERROR_PEER_ACCESS_ALREADY_ENABLED => c"peer access already enabled".as_ptr(), + CUresult::ERROR_PEER_ACCESS_NOT_ENABLED => c"peer access not enabled".as_ptr(), + CUresult::ERROR_PRIMARY_CONTEXT_ACTIVE => c"primary context active".as_ptr(), + CUresult::ERROR_CONTEXT_IS_DESTROYED => c"context is destroyed".as_ptr(), + CUresult::ERROR_ASSERT => c"device-side assert triggered".as_ptr(), + CUresult::ERROR_TOO_MANY_PEERS => c"too many peers".as_ptr(), + CUresult::ERROR_HOST_MEMORY_ALREADY_REGISTERED => { + c"host memory already registered".as_ptr() + } + CUresult::ERROR_HOST_MEMORY_NOT_REGISTERED => c"host memory not registered".as_ptr(), + CUresult::ERROR_UNKNOWN => c"unknown error".as_ptr(), + _ => c"error".as_ptr(), + }; + Ok(()) +} + #[cfg(test)] mod tests { + use std::i32; + use crate::r#impl::driver::AllocationInfo; use crate::tests::CudaApi; use cuda_macros::test_cuda; @@ -571,4 +622,34 @@ mod tests { } assert_eq!(alloc_info.get_offset_and_info(0x2000 + 8), None); } + + #[test_cuda] + fn primary_context_is_inactive_on_init(api: impl CudaApi) { + api.cuInit(0); + let mut flags = u32::MAX; + let mut active = i32::MAX; + api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active); + assert_eq!(flags, 0); + assert_eq!(active, 0); + } + + #[test_cuda] + unsafe fn cudart_interface_fn2_creates_inactive_primary_ctx(api: impl CudaApi) { + api.cuInit(0); + let mut table_ptr = std::ptr::null(); + api.cuGetExportTable(&mut table_ptr, &dark_api::cuda::CudartInterface::GUID); + let cuda_rt_iface = dark_api::cuda::CudartInterface::new(table_ptr); + let mut dark_ctx = std::mem::zeroed(); + cuda_rt_iface + .cudart_interface_fn2(&mut dark_ctx, 0) + .unwrap(); + let mut flags = u32::MAX; + let mut active = i32::MAX; + api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active); + assert_eq!(flags, 0); + assert_eq!(active, 0); + let mut primary_ctx = std::mem::zeroed(); + api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0); + assert_eq!(dark_ctx.0, primary_ctx.0); + } } diff --git a/zluda/src/impl/library.rs b/zluda/src/impl/library.rs index 6ca3745..2fcc56a 100644 --- a/zluda/src/impl/library.rs +++ b/zluda/src/impl/library.rs @@ -55,12 +55,11 @@ pub(crate) unsafe fn load_data( _jit_options_values: Option<&mut *mut ::core::ffi::c_void>, _num_jit_options: ::core::ffi::c_uint, library_options: Option<&mut CUlibraryOption>, - library_option_values: Option<&mut *mut ::core::ffi::c_void>, + _library_option_values: Option<&mut *mut ::core::ffi::c_void>, num_library_options: ::core::ffi::c_uint, ) -> CUresult { let global_state = driver::global_state()?; - let options = - LibraryOptions::load(library_options, library_option_values, num_library_options)?; + let options = LibraryOptions::load(library_options, num_library_options)?; let library = Library { data: LibraryData::new(code as *mut c_void, options.preserve_binary)?, modules: vec![OnceLock::new(); global_state.devices.len()], @@ -76,7 +75,6 @@ struct LibraryOptions { impl LibraryOptions { unsafe fn load( library_options: Option<&mut CUlibraryOption>, - library_option_values: Option<&mut *mut ::core::ffi::c_void>, num_library_options: ::core::ffi::c_uint, ) -> Result { if num_library_options == 0 { @@ -84,28 +82,17 @@ impl LibraryOptions { preserve_binary: false, }); } - let (library_options, library_option_values) = - match (library_options, library_option_values) { - (Some(library_options), Some(library_option_values)) => { - let library_options = - std::slice::from_raw_parts(library_options, num_library_options as usize); - let library_option_values = std::slice::from_raw_parts( - library_option_values, - num_library_options as usize, - ); - (library_options, library_option_values) - } - _ => return Err(CUerror::INVALID_VALUE), - }; + let library_options = match library_options { + Some(library_options) => { + std::slice::from_raw_parts(library_options, num_library_options as usize) + } + _ => return Err(CUerror::INVALID_VALUE), + }; let mut preserve_binary = false; - for (option, value) in library_options - .iter() - .copied() - .zip(library_option_values.iter()) - { + for option in library_options.iter().copied() { match option { CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => { - preserve_binary = *(value.cast::()); + preserve_binary = true; } _ => return Err(CUerror::NOT_SUPPORTED), } @@ -156,10 +143,7 @@ mod tests { use crate::tests::CudaApi; use cuda_macros::test_cuda; use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts}; - use std::{ - ffi::{c_void, CStr}, - mem, ptr, - }; + use std::{ffi::CStr, mem, ptr}; #[test_cuda] unsafe fn library_loads_without_context(api: impl CudaApi) { @@ -191,7 +175,7 @@ mod tests { ptr::null_mut(), 0, [CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(), - [(&true as *const bool) as *mut c_void].as_mut_ptr(), + ptr::null_mut(), 1, ); assert_eq!( diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 9cf83b8..70395ed 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -153,3 +153,11 @@ pub(crate) unsafe fn set_d8_async( ) -> hipError_t { hipMemsetD8Async(dst_device, uc, n, stream) } + +pub(crate) fn get_allocation_granularity( + _granularity: &mut usize, + _property: &cuda_types::cuda::CUmemAllocationProp, + _option: cuda_types::cuda::CUmemAllocationGranularity_flags, +) -> CUresult { + CUresult::ERROR_NOT_SUPPORTED +} diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index ae9dcf1..506f824 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -138,7 +138,7 @@ fn compile_from_ptx_and_cache( } else { ptx_parser::parse_module_unchecked(text) }; - let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?; + let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?; let elf_module = comgr::compile_bitcode( comgr, gcn_arch, diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index 8d06f10..14b7bad 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -93,6 +93,7 @@ cuda_macros::cuda_function_declarations!( cuDeviceGetUuid_v2, cuDevicePrimaryCtxGetState, cuDevicePrimaryCtxRelease, + cuDevicePrimaryCtxRelease_v2, cuDevicePrimaryCtxReset, cuDevicePrimaryCtxRetain, cuDeviceTotalMem_v2, @@ -104,6 +105,7 @@ cuda_macros::cuda_function_declarations!( cuEventSynchronize, cuFuncGetAttribute, cuFuncSetAttribute, + cuGetErrorString, cuGetExportTable, cuGetProcAddress, cuGetProcAddress_v2, @@ -126,6 +128,7 @@ cuda_macros::cuda_function_declarations!( cuMemFreeHost, cuMemFree_v2, cuMemGetAddressRange_v2, + cuMemGetAllocationGranularity, cuMemGetInfo_v2, cuMemHostAlloc, cuMemRetainAllocationHandle, diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs index c857f6b..95ec415 100644 --- a/zluda_common/src/lib.rs +++ b/zluda_common/src/lib.rs @@ -165,7 +165,10 @@ from_cuda_nop!( nvmlDevice_t, nvmlFieldValue_t, nvmlGpuFabricInfo_t, - cublasLtHandle_t + cublasLtHandle_t, + CUmemAllocationGranularity_flags, + CUmemAllocationProp, + CUresult ); from_cuda_transmute!( CUuuid => hipUUID,