More compiler fixes

This commit is contained in:
Andrzej Janik 2025-09-16 16:47:22 +00:00
commit ff41e760d1
9 changed files with 88 additions and 12 deletions

1
Cargo.lock generated
View file

@ -2565,7 +2565,6 @@ dependencies = [
"ptx_parser", "ptx_parser",
"quick-error", "quick-error",
"rustc-hash 2.0.0", "rustc-hash 2.0.0",
"serde",
"smallvec", "smallvec",
"strum 0.26.3", "strum 0.26.3",
"strum_macros 0.26.4", "strum_macros 0.26.4",

View file

@ -237,8 +237,8 @@ pub fn compile_bitcode(
] ]
.into_iter(); .into_iter();
let opt_options = if cfg!(debug_assertions) { let opt_options = if cfg!(debug_assertions) {
//[c"-g", c"-mllvm", c"-print-before-all", c"", c""] //[c"-g", c"-mamdgpu-precise-memory-op", c"-mllvm", c"-print-before-all", c""]
[c"-g", c"", c"", c"", c""] [c"-g", c"-mamdgpu-precise-memory-op", c"", c"", c""]
} else { } else {
[ [
c"-g0", c"-g0",

View file

@ -22,7 +22,6 @@ microlp = "0.2.11"
int-enum = "1.1" int-enum = "1.1"
unwrap_or = "1.0.1" unwrap_or = "1.0.1"
smallvec = "1.15.1" smallvec = "1.15.1"
serde = { version = "1.0.219", features = ["derive"] }
[dev-dependencies] [dev-dependencies]
hip_runtime-sys = { path = "../ext/hip_runtime-sys" } hip_runtime-sys = { path = "../ext/hip_runtime-sys" }

View file

@ -153,6 +153,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
for (i, param) in method.input_arguments.iter().enumerate() { for (i, param) in method.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) }; let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name); let name = self.resolver.get_or_add(param.name);
if let Some(align) = param.align {
unsafe { LLVMSetParamAlignment(value, align) };
}
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value); self.resolver.register(param.name, value);
if method.is_kernel { if method.is_kernel {

View file

@ -51,7 +51,7 @@ quick_error! {
} }
/// GPU attributes needed at compile time. /// GPU attributes needed at compile time.
#[derive(serde::Serialize)] #[derive(Copy, Clone)]
pub struct Attributes { pub struct Attributes {
/// Clock frequency in kHz. /// Clock frequency in kHz.
pub clock_rate: u32, pub clock_rate: u32,

12
zluda/src/impl/hipfix.rs Normal file
View file

@ -0,0 +1,12 @@
// Workaround for a bug in hipDrvPointerGetAttributes: the call fails with
// HIP_ERROR_INVALID_VALUE when handed a null pointer, although it behaves
// correctly for every other invalid pointer.
//
// This helper maps a null device pointer to a different invalid address
// (all bits set) and passes any non-null pointer through untouched.
pub(crate) fn get_attributes(
    ptr: hip_runtime_sys::hipDeviceptr_t,
) -> hip_runtime_sys::hipDeviceptr_t {
    // Common case first: nothing to patch up.
    if !ptr.0.is_null() {
        return ptr;
    }
    hip_runtime_sys::hipDeviceptr_t(usize::MAX as _)
}

View file

@ -7,6 +7,7 @@ pub(super) mod driver;
pub(super) mod event; pub(super) mod event;
pub(super) mod function; pub(super) mod function;
pub(super) mod graph; pub(super) mod graph;
pub(super) mod hipfix;
pub(super) mod kernel; pub(super) mod kernel;
pub(super) mod library; pub(super) mod library;
pub(super) mod memory; pub(super) mod memory;

View file

@ -61,8 +61,9 @@ pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CU
let text = get_ptx(library)?; let text = get_ptx(library)?;
let hip_properties = get_hip_properties()?; let hip_properties = get_hip_properties()?;
let gcn_arch = get_gcn_arch(&hip_properties)?; let gcn_arch = get_gcn_arch(&hip_properties)?;
let attributes = ptx::Attributes { let attributes = ExtraCacheAttributes {
clock_rate: hip_properties.clockRate as u32, clock_rate: hip_properties.clockRate as u32,
is_debug: cfg!(debug_assertions),
}; };
let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| { let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| {
let cache = zluda_cache::ModuleCache::open(p)?; let cache = zluda_cache::ModuleCache::open(p)?;
@ -84,6 +85,12 @@ pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CU
Ok(hip_module) Ok(hip_module)
} }
/// Compilation inputs gathered in `load_hip_module` before a PTX module is
/// compiled.
///
/// NOTE(review): presumably this is what gets serialized into the module cache
/// key through `get_cache_key` (which accepts any `serde::Serialize`) — confirm
/// at the call site. If so, field order is part of the serialized form and
/// should not be changed casually.
#[derive(serde::Serialize)]
struct ExtraCacheAttributes {
/// Set from `cfg!(debug_assertions)` by `load_hip_module`.
is_debug: bool,
/// Device clock rate taken from the HIP device properties (`clockRate`);
/// forwarded to `ptx::Attributes::clock_rate` when compiling.
clock_rate: u32,
}
fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> { fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> {
let hip_dev = super::context::get_current_device()?; let hip_dev = super::context::get_current_device()?;
let mut props = unsafe { mem::zeroed() }; let mut props = unsafe { mem::zeroed() };
@ -100,7 +107,7 @@ fn get_cache_key<'a, 'b>(
global_state: &'static driver::GlobalState, global_state: &'static driver::GlobalState,
isa: &'a str, isa: &'a str,
text: &str, text: &str,
attributes: &ptx::Attributes, attributes: &impl serde::Serialize,
) -> Option<zluda_cache::ModuleKey<'a>> { ) -> Option<zluda_cache::ModuleKey<'a>> {
// Serialization here is deterministic. When marking a type with // Serialization here is deterministic. When marking a type with
// #[derive(serde::Serialize)] the derived implementation will just write // #[derive(serde::Serialize)] the derived implementation will just write
@ -129,7 +136,7 @@ fn load_cached_binary(
fn compile_from_ptx_and_cache( fn compile_from_ptx_and_cache(
comgr: &comgr::Comgr, comgr: &comgr::Comgr,
gcn_arch: &str, gcn_arch: &str,
attributes: ptx::Attributes, attributes: ExtraCacheAttributes,
text: &str, text: &str,
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>, cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
) -> Result<Vec<u8>, CUerror> { ) -> Result<Vec<u8>, CUerror> {
@ -138,7 +145,14 @@ fn compile_from_ptx_and_cache(
} else { } else {
ptx_parser::parse_module_unchecked(text) ptx_parser::parse_module_unchecked(text)
}; };
let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?; let llvm_module = ptx::to_llvm_module(
ast,
ptx::Attributes {
clock_rate: attributes.clock_rate,
},
|_| {},
)
.map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode( let elf_module = comgr::compile_bitcode(
comgr, comgr,
gcn_arch, gcn_arch,

View file

@ -2,7 +2,7 @@ use cuda_types::cuda::*;
use hip_runtime_sys::*; use hip_runtime_sys::*;
use std::{ffi::c_void, ptr}; use std::{ffi::c_void, ptr};
use crate::r#impl::driver; use crate::r#impl::{driver, hipfix};
// TODO: handle hipMemoryTypeUnregistered // TODO: handle hipMemoryTypeUnregistered
fn to_cu_memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> { fn to_cu_memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
@ -59,7 +59,12 @@ pub(crate) unsafe fn get_attributes(
data: &mut *mut ::core::ffi::c_void, data: &mut *mut ::core::ffi::c_void,
ptr: hipDeviceptr_t, ptr: hipDeviceptr_t,
) -> CUresult { ) -> CUresult {
hipDrvPointerGetAttributes(num_attributes, attributes, data, ptr)?; hipDrvPointerGetAttributes(
num_attributes,
attributes,
data,
hipfix::get_attributes(ptr),
)?;
let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize); let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize);
let data = std::slice::from_raw_parts_mut(data, num_attributes as usize); let data = std::slice::from_raw_parts_mut(data, num_attributes as usize);
for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) { for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) {
@ -88,7 +93,7 @@ mod tests {
use crate::tests::CudaApi; use crate::tests::CudaApi;
use cuda_macros::test_cuda; use cuda_macros::test_cuda;
use cuda_types::cuda::*; use cuda_types::cuda::*;
use std::{ffi::c_void, mem, ptr}; use std::{ffi::c_void, i32, mem, ptr, usize};
#[test_cuda] #[test_cuda]
pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) { pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) {
@ -162,4 +167,47 @@ mod tests {
); );
assert_eq!(context, CUcontext(ptr::null_mut())); assert_eq!(context, CUcontext(ptr::null_mut()));
} }
// Querying several attributes of a null device pointer is expected to succeed
// and to overwrite every output slot with its "no allocation" value.
#[test_cuda]
pub unsafe fn null_ptr_attributes_success(api: impl CudaApi) {
    api.cuInit(0);
    api.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0);

    // Seed each output with a value that differs from the expected result, so
    // the assertions below only pass if the call really wrote to the slot.
    let mut ctx_out = CUcontext(1 as _);
    let mut mem_type_out = mem::transmute::<_, CUmemorytype>(u32::MAX);
    let mut device_ptr_out = usize::MAX as *mut c_void;
    let mut host_ptr_out = usize::MAX as *mut c_void;
    let mut managed_out = true;
    let mut ordinal_out = i32::MAX;

    let mut queried = [
        CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT,
        CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
        CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER,
        CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER,
        CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED,
        CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
    ];
    // One destination per queried attribute, in the same order as `queried`.
    let mut outputs: [*mut c_void; 6] = [
        ptr::from_mut(&mut ctx_out).cast(),
        ptr::from_mut(&mut mem_type_out).cast(),
        ptr::from_mut(&mut device_ptr_out).cast(),
        ptr::from_mut(&mut host_ptr_out).cast(),
        ptr::from_mut(&mut managed_out).cast(),
        ptr::from_mut(&mut ordinal_out).cast(),
    ];

    let status = api.cuPointerGetAttributes_unchecked(
        queried.len() as u32,
        queried.as_mut_ptr(),
        outputs.as_mut_ptr(),
        CUdeviceptr_v2(ptr::null_mut()),
    );
    assert_eq!(CUresult::SUCCESS, status);

    assert_eq!(ctx_out, CUcontext(ptr::null_mut()));
    assert_eq!(mem_type_out, CUmemorytype(0));
    assert_eq!(device_ptr_out, ptr::null_mut());
    assert_eq!(host_ptr_out, ptr::null_mut());
    assert_eq!(managed_out, false);
    assert_eq!(ordinal_out, -2);
}
} }