Implement cuGetProcAddress and cuGetProcAddress_v2 (#377)

This commit is contained in:
Violet 2025-06-10 16:07:35 -07:00 committed by GitHub
commit 62f3e63355
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 72 additions and 3 deletions

View file

@ -5,8 +5,9 @@ use cuda_types::cuda::*;
use hip_runtime_sys::*;
use std::{
ffi::{CStr, CString},
mem, slice,
mem, ptr, slice,
sync::OnceLock,
usize,
};
pub(crate) struct GlobalState {
@ -79,3 +80,51 @@ pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
*version = cuda_types::cuda::CUDA_VERSION as i32;
Ok(())
}
pub(crate) unsafe fn get_proc_address(
symbol: &CStr,
pfn: &mut *mut ::core::ffi::c_void,
cuda_version: ::core::ffi::c_int,
flags: cuda_types::cuda::cuuint64_t,
) -> CUresult {
get_proc_address_v2(symbol, pfn, cuda_version, flags, None)
}
pub(crate) unsafe fn get_proc_address_v2(
symbol: &CStr,
pfn: &mut *mut ::core::ffi::c_void,
cuda_version: ::core::ffi::c_int,
flags: cuda_types::cuda::cuuint64_t,
symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>,
) -> CUresult {
// This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point.
fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void {
use crate::*;
include!("../../../zluda_bindgen/src/process_table.rs")
}
let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version);
match fn_ptr as usize {
0 => {
if let Some(symbol_status) = symbol_status {
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND;
}
*pfn = ptr::null_mut();
CUresult::ERROR_NOT_FOUND
}
usize::MAX => {
if let Some(symbol_status) = symbol_status {
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT;
}
*pfn = ptr::null_mut();
CUresult::ERROR_NOT_FOUND
}
_ => {
if let Some(symbol_status) = symbol_status {
*symbol_status =
cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS;
}
*pfn = fn_ptr;
Ok(())
}
}
}

View file

@ -1,6 +1,6 @@
use cuda_types::cuda::*;
use hip_runtime_sys::*;
use std::mem::{self, ManuallyDrop, MaybeUninit};
use std::{ffi::CStr, mem::{self, ManuallyDrop, MaybeUninit}, ptr};
pub(super) mod context;
pub(super) mod device;
@ -42,6 +42,12 @@ macro_rules! from_cuda_nop {
}
}
}
impl<'a> FromCuda<'a, *mut $type_> for Option<&'a mut $type_> {
fn from_cuda(x: &'a *mut $type_) -> Result<Self, CUerror> {
Ok(unsafe { x.as_mut() })
}
}
)*
};
}
@ -111,9 +117,11 @@ from_cuda_nop!(
u8,
i32,
u32,
u64,
usize,
cuda_types::cuda::CUdevprop,
CUdevice_attribute
CUdevice_attribute,
CUdriverProcAddressQueryResult
);
from_cuda_transmute!(
CUuuid => hipUUID,
@ -136,6 +144,16 @@ impl<'a> FromCuda<'a, CUlimit> for hipLimit_t {
}
}
impl<'a> FromCuda<'a, *const ::core::ffi::c_char> for &CStr {
fn from_cuda(s: &'a *const ::core::ffi::c_char) -> Result<Self, CUerror> {
if *s != ptr::null() {
Ok(unsafe { CStr::from_ptr(*s) })
} else {
Err(CUerror::INVALID_VALUE)
}
}
}
pub(crate) trait ZludaObject: Sized + Send + Sync {
const COOKIE: usize;
const LIVENESS_FAIL: CUerror = cuda_types::cuda::CUerror::INVALID_VALUE;

View file

@ -62,6 +62,8 @@ cuda_base::cuda_function_declarations!(
cuDeviceTotalMem_v2,
cuDriverGetVersion,
cuFuncGetAttribute,
cuGetProcAddress,
cuGetProcAddress_v2,
cuInit,
cuMemAlloc_v2,
cuMemFree_v2,