diff --git a/Cargo.lock b/Cargo.lock index 3f883dd..cabec9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2565,7 +2565,6 @@ dependencies = [ "ptx_parser", "quick-error", "rustc-hash 2.0.0", - "serde", "smallvec", "strum 0.26.3", "strum_macros 0.26.4", diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index 9c5671b..8ba0a87 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -237,8 +237,8 @@ pub fn compile_bitcode( ] .into_iter(); let opt_options = if cfg!(debug_assertions) { - //[c"-g", c"-mllvm", c"-print-before-all", c"", c""] - [c"-g", c"", c"", c"", c""] + //[c"-g", c"-mamdgpu-precise-memory-op", c"-mllvm", c"-print-before-all", c""] + [c"-g", c"-mamdgpu-precise-memory-op", c"", c"", c""] } else { [ c"-g0", diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index c9a5a6b..7ee6e43 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -22,7 +22,6 @@ microlp = "0.2.11" int-enum = "1.1" unwrap_or = "1.0.1" smallvec = "1.15.1" -serde = { version = "1.0.219", features = ["derive"] } [dev-dependencies] hip_runtime-sys = { path = "../ext/hip_runtime-sys" } diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index c811a53..baaff6a 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -153,6 +153,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { for (i, param) in method.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; let name = self.resolver.get_or_add(param.name); + if let Some(align) = param.align { + unsafe { LLVMSetParamAlignment(value, align) }; + } unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; self.resolver.register(param.name, value); if method.is_kernel { diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 0b9ef79..4f87dc3 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -51,7 +51,7 @@ quick_error! { } /// GPU attributes needed at compile time. -#[derive(serde::Serialize)] +#[derive(Copy, Clone)] pub struct Attributes { /// Clock frequency in kHz. pub clock_rate: u32, diff --git a/zluda/src/impl/hipfix.rs b/zluda/src/impl/hipfix.rs new file mode 100644 index 0000000..f957849 --- /dev/null +++ b/zluda/src/impl/hipfix.rs @@ -0,0 +1,12 @@ +// There's a bug in hipDrvPointerGetAttributes where it returns +// HIP_ERROR_INVALID_VALUE if the pointer is null. It works correctly for any +// other invalid pointer +pub(crate) fn get_attributes( + ptr: hip_runtime_sys::hipDeviceptr_t, +) -> hip_runtime_sys::hipDeviceptr_t { + if ptr.0.is_null() { + hip_runtime_sys::hipDeviceptr_t(usize::MAX as _) + } else { + ptr + } +} diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index f73a972..60ecb80 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -7,6 +7,7 @@ pub(super) mod driver; pub(super) mod event; pub(super) mod function; pub(super) mod graph; +pub(super) mod hipfix; pub(super) mod kernel; pub(super) mod library; pub(super) mod memory; diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 506f824..669fe1f 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -61,8 +61,9 @@ pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result Result() -> Result { let hip_dev = super::context::get_current_device()?; let mut props = unsafe { mem::zeroed() }; @@ -100,7 +107,7 @@ fn get_cache_key<'a, 'b>( global_state: &'static driver::GlobalState, isa: &'a str, text: &str, - attributes: &ptx::Attributes, + attributes: &impl serde::Serialize, ) -> Option> { // Serialization here is deterministic. When marking a type with // #[derive(serde::Serialize)] the derived implementation will just write @@ -129,7 +136,7 @@ fn load_cached_binary( fn compile_from_ptx_and_cache( comgr: &comgr::Comgr, gcn_arch: &str, - attributes: ptx::Attributes, + attributes: ExtraCacheAttributes, text: &str, cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>, ) -> Result, CUerror> { @@ -138,7 +145,14 @@ fn compile_from_ptx_and_cache( } else { ptx_parser::parse_module_unchecked(text) }; - let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?; + let llvm_module = ptx::to_llvm_module( + ast, + ptx::Attributes { + clock_rate: attributes.clock_rate, + }, + |_| {}, + ) + .map_err(|_| CUerror::UNKNOWN)?; let elf_module = comgr::compile_bitcode( comgr, gcn_arch, diff --git a/zluda/src/impl/pointer.rs b/zluda/src/impl/pointer.rs index 8eda15e..6541fce 100644 --- a/zluda/src/impl/pointer.rs +++ b/zluda/src/impl/pointer.rs @@ -2,7 +2,7 @@ use cuda_types::cuda::*; use hip_runtime_sys::*; use std::{ffi::c_void, ptr}; -use crate::r#impl::driver; +use crate::r#impl::{driver, hipfix}; // TODO: handlehipMemoryTypeUnregistered fn to_cu_memory_type(cu: hipMemoryType) -> Result { @@ -59,7 +59,12 @@ pub(crate) unsafe fn get_attributes( data: &mut *mut ::core::ffi::c_void, ptr: hipDeviceptr_t, ) -> CUresult { - hipDrvPointerGetAttributes(num_attributes, attributes, data, ptr)?; + hipDrvPointerGetAttributes( + num_attributes, + attributes, + data, + hipfix::get_attributes(ptr), + )?; let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize); let data = std::slice::from_raw_parts_mut(data, num_attributes as usize); for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) { @@ -88,7 +93,7 @@ mod tests { use crate::tests::CudaApi; use cuda_macros::test_cuda; use cuda_types::cuda::*; - use std::{ffi::c_void, mem, ptr}; + use std::{ffi::c_void, i32, mem, ptr, usize}; #[test_cuda] pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) { @@ -162,4 +167,47 @@ mod tests { ); assert_eq!(context, CUcontext(ptr::null_mut())); } + + #[test_cuda] + pub unsafe fn null_ptr_attributes_success(api: impl CudaApi) { + api.cuInit(0); + api.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0); + let mut context = CUcontext(1 as _); + let mut mem_type = mem::transmute::<_, CUmemorytype>(u32::MAX); + let mut dev_ptr = mem::transmute::<_, *mut c_void>(usize::MAX); + let mut host_ptr = mem::transmute::<_, *mut c_void>(usize::MAX); + let mut is_managed = true; + let mut ordinal = i32::MAX; + let mut attrs = [ + CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + ]; + let mut values = [ + std::ptr::from_mut(&mut context).cast::(), + std::ptr::from_mut(&mut mem_type).cast::(), + std::ptr::from_mut(&mut dev_ptr).cast::(), + std::ptr::from_mut(&mut host_ptr).cast::(), + std::ptr::from_mut(&mut is_managed).cast::(), + std::ptr::from_mut(&mut ordinal).cast::(), + ]; + assert_eq!( + CUresult::SUCCESS, + api.cuPointerGetAttributes_unchecked( + attrs.len() as u32, + attrs.as_mut_ptr(), + values.as_mut_ptr(), + CUdeviceptr_v2(ptr::null_mut()) + ) + ); + assert_eq!(context, CUcontext(ptr::null_mut())); + assert_eq!(mem_type, CUmemorytype(0)); + assert_eq!(dev_ptr, ptr::null_mut()); + assert_eq!(host_ptr, ptr::null_mut()); + assert_eq!(is_managed, false); + assert_eq!(ordinal, -2); + } }