Add more missing host-side code

This commit is contained in:
Andrzej Janik 2024-11-25 06:17:14 +01:00
parent c461cefd7d
commit 502b0c957e
12 changed files with 348 additions and 371 deletions

View file

@ -150,6 +150,10 @@ impl VisitMut for FixFnSignatures {
}
}
const MODULES: &[&str] = &[
"context", "device", "driver", "function", "link", "memory", "module", "pointer",
];
#[proc_macro]
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
let mut path = parse_macro_input!(tokens as syn::Path);
@ -161,8 +165,9 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.0
.ident
.to_string();
let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
let fn_path = join(segments);
let fn_path = join(segments, !already_has_module);
quote! {
#path #fn_path
}
@ -181,23 +186,16 @@ fn split(fn_: &str) -> Vec<String> {
result
}
fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> {
fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
fn full_form(segment: &str) -> Option<&[&str]> {
Some(match segment {
"ctx" => &["context"],
"func" => &["function"],
"mem" => &["memory"],
"memcpy" => &["memory", "copy"],
_ => return None,
})
}
const MODULES: &[&str] = &[
"context",
"device",
"function",
"link",
"memory",
"module",
"pointer"
];
let mut normalized: Vec<&str> = Vec::new();
for segment in fn_.iter() {
match full_form(segment) {
@ -205,18 +203,20 @@ fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> {
None => normalized.push(&*segment),
}
}
if !find_module {
return [Ident::new(&normalized.join("_"), Span::call_site())]
.into_iter()
.collect();
}
if !MODULES.contains(&normalized[0]) {
let mut globalized = vec!["global"];
let mut globalized = vec!["driver"];
globalized.extend(normalized);
normalized = globalized;
}
let (module, path) = normalized.split_first().unwrap();
let path = path.join("_");
let mut result = Punctuated::new();
result.extend(
[module, &&*path]
.into_iter()
.map(|s| Ident::new(s, Span::call_site())),
);
result
[module, &&*path]
.into_iter()
.map(|s| Ident::new(s, Span::call_site()))
.collect()
}

View file

@ -20,6 +20,7 @@ num_enum = "0.4"
lz4-sys = "1.9"
tempfile = "3"
paste = "1.0"
rustc-hash = "1.1"
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["heapapi", "std"] }

View file

@ -1,4 +1,46 @@
use super::{driver, FromCuda, ZludaObject};
use cuda_types::*;
use hip_runtime_sys::*;
use rustc_hash::FxHashSet;
use std::{cell::RefCell, ptr, sync::Mutex};
thread_local! {
pub(crate) static CONTEXT_STACK: RefCell<Vec<(CUcontext, hipDevice_t)>> = RefCell::new(Vec::new());
}
pub(crate) struct Context {
pub(crate) device: hipDevice_t,
pub(crate) mutable: Mutex<OwnedByContext>,
}
pub(crate) struct OwnedByContext {
pub(crate) ref_count: usize, // only used by primary context
pub(crate) _memory: FxHashSet<hipDeviceptr_t>,
pub(crate) _streams: FxHashSet<hipStream_t>,
pub(crate) _modules: FxHashSet<CUmodule>,
}
impl ZludaObject for Context {
const COOKIE: usize = 0x5f867c6d9cb73315;
type CudaHandle = CUcontext;
fn drop_checked(&mut self) -> CUresult {
Ok(())
}
}
pub(crate) fn new(device: hipDevice_t) -> Context {
Context {
device,
mutable: Mutex::new(OwnedByContext {
ref_count: 0,
_memory: FxHashSet::default(),
_streams: FxHashSet::default(),
_modules: FxHashSet::default(),
}),
}
}
pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: hipLimit_t) -> hipError_t {
unsafe { hipDeviceGetLimit(pvalue, limit) }
@ -11,3 +53,41 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> hipError_t {
pub(crate) fn synchronize() -> hipError_t {
unsafe { hipDeviceSynchronize() }
}
pub(crate) fn get_primary(hip_dev: hipDevice_t) -> Result<(&'static Context, CUcontext), CUerror> {
let dev = driver::device(hip_dev)?;
Ok(dev.primary_context())
}
pub(crate) fn set_current(raw_ctx: CUcontext) -> CUresult {
let new_device = if raw_ctx.0 == ptr::null_mut() {
CONTEXT_STACK.with(|stack| {
let mut stack = stack.borrow_mut();
if let Some((_, old_device)) = stack.pop() {
if let Some((_, new_device)) = stack.last() {
if old_device != *new_device {
return Some(*new_device);
}
}
}
None
})
} else {
let ctx: &Context = FromCuda::from_cuda(&raw_ctx)?;
let device = ctx.device;
CONTEXT_STACK.with(move |stack| {
let mut stack = stack.borrow_mut();
let last_device = stack.last().map(|(_, dev)| *dev);
stack.push((raw_ctx, device));
match last_device {
None => Some(device),
Some(last_device) if last_device != device => Some(device),
_ => None,
}
})
};
if let Some(dev) = new_device {
unsafe { hipSetDevice(dev)? };
}
Ok(())
}

View file

@ -2,6 +2,8 @@ use cuda_types::*;
use hip_runtime_sys::*;
use std::{mem, ptr};
use super::context;
const PROJECT_SUFFIX: &[u8] = b" [ZLUDA]\0";
pub const COMPUTE_CAPABILITY_MAJOR: i32 = 8;
pub const COMPUTE_CAPABILITY_MINOR: i32 = 8;
@ -307,3 +309,31 @@ pub(crate) fn get_count(count: &mut ::core::ffi::c_int) -> hipError_t {
fn clamp_usize(x: usize) -> i32 {
usize::min(x, i32::MAX as usize) as i32
}
pub(crate) fn primary_context_retain(
pctx: &mut CUcontext,
hip_dev: hipDevice_t,
) -> Result<(), CUerror> {
let (ctx, raw_ctx) = context::get_primary(hip_dev)?;
{
let mut mutable_ctx = ctx.mutable.lock().map_err(|_| CUerror::UNKNOWN)?;
mutable_ctx.ref_count += 1;
}
*pctx = raw_ctx;
Ok(())
}
pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> Result<(), CUerror> {
let (ctx, _) = context::get_primary(hip_dev)?;
{
let mut mutable_ctx = ctx.mutable.lock().map_err(|_| CUerror::UNKNOWN)?;
if mutable_ctx.ref_count == 0 {
return Err(CUerror::INVALID_CONTEXT);
}
mutable_ctx.ref_count -= 1;
if mutable_ctx.ref_count == 0 {
// TODO: drop all children
}
}
Ok(())
}

79
zluda/src/impl/driver.rs Normal file
View file

@ -0,0 +1,79 @@
use cuda_types::*;
use hip_runtime_sys::*;
use std::{
ffi::{CStr, CString},
mem, slice,
sync::OnceLock,
};
use crate::r#impl::context;
use super::LiveCheck;
pub(crate) struct GlobalState {
pub devices: Vec<Device>,
}
pub(crate) struct Device {
pub(crate) _comgr_isa: CString,
primary_context: LiveCheck<context::Context>,
}
impl Device {
pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
unsafe {
(
self.primary_context.data.assume_init_ref(),
self.primary_context.as_handle(),
)
}
}
}
pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
global_state()?
.devices
.get(dev as usize)
.ok_or(CUerror::INVALID_DEVICE)
}
pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
}
GLOBAL_STATE
.get_or_init(|| {
let mut device_count = 0;
unsafe { hipGetDeviceCount(&mut device_count) }?;
Ok(GlobalState {
devices: (0..device_count)
.map(|i| {
let mut props = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
Ok::<_, CUerror>(Device {
_comgr_isa: CStr::from_bytes_until_nul(cast_slice(
&props.gcnArchName[..],
))
.map_err(|_| CUerror::UNKNOWN)?
.to_owned(),
primary_context: LiveCheck::new(context::new(i)),
})
})
.collect::<Result<Vec<_>, _>>()?,
})
})
.as_ref()
.map_err(|e| *e)
}
pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
unsafe { hipInit(flags) }?;
global_state()?;
Ok(())
}
pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
*version = cuda_types::CUDA_VERSION as i32;
Ok(())
}

View file

@ -1,26 +1,46 @@
use hip_runtime_sys::{hipError_t, hipFuncAttribute, hipFuncGetAttribute, hipFuncGetAttributes, hipFunction_attribute, hipLaunchKernel, hipModuleLaunchKernel};
use super::{CUresult, HasLivenessCookie, LiveCheck};
use crate::cuda::{CUfunction, CUfunction_attribute, CUstream};
use ::std::os::raw::{c_uint, c_void};
use std::{mem, ptr};
use hip_runtime_sys::*;
pub(crate) fn get_attribute(
pi: *mut i32,
cu_attrib: CUfunction_attribute,
func: CUfunction,
pi: &mut i32,
cu_attrib: hipFunction_attribute,
func: hipFunction_t,
) -> hipError_t {
if pi == ptr::null_mut() || func == ptr::null_mut() {
return hipError_t::hipErrorInvalidValue;
// TODO: implement HIP_FUNC_ATTRIBUTE_PTX_VERSION
// TODO: implement HIP_FUNC_ATTRIBUTE_BINARY_VERSION
unsafe { hipFuncGetAttribute(pi, cu_attrib, func) }?;
if cu_attrib == hipFunction_attribute::HIP_FUNC_ATTRIBUTE_NUM_REGS {
*pi = (*pi).max(1);
}
Ok(())
}
pub(crate) fn launch_kernel(
f: hipFunction_t,
grid_dim_x: ::core::ffi::c_uint,
grid_dim_y: ::core::ffi::c_uint,
grid_dim_z: ::core::ffi::c_uint,
block_dim_x: ::core::ffi::c_uint,
block_dim_y: ::core::ffi::c_uint,
block_dim_z: ::core::ffi::c_uint,
shared_mem_bytes: ::core::ffi::c_uint,
stream: hipStream_t,
kernel_params: *mut *mut ::core::ffi::c_void,
extra: *mut *mut ::core::ffi::c_void,
) -> hipError_t {
// TODO: fix constants in extra
unsafe {
hipModuleLaunchKernel(
f,
grid_dim_x,
grid_dim_y,
grid_dim_z,
block_dim_x,
block_dim_y,
block_dim_z,
shared_mem_bytes,
stream,
kernel_params,
extra,
)
}
let attrib = match cu_attrib {
CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
hipFunction_attribute::HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
}
CUfunction_attribute::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES => {
hipFunction_attribute::HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
}
_ => return hipError_t::hipErrorInvalidValue,
};
unsafe { hipFuncGetAttribute(pi, attrib, func as _) }
}

View file

@ -1,86 +0,0 @@
use std::{
ffi::{c_void, CStr},
mem, ptr, slice,
};
use hip_runtime_sys::{hipCtxGetDevice, hipError_t, hipGetDeviceProperties};
use crate::{
cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult},
hip_call,
};
use super::module::{self, SpirvModule};
struct LinkState {
modules: Vec<SpirvModule>,
result: Option<Vec<u8>>,
}
pub(crate) unsafe fn create(
num_options: u32,
options: *mut CUjit_option,
option_values: *mut *mut c_void,
state_out: *mut CUlinkState,
) -> CUresult {
if state_out == ptr::null_mut() {
return CUresult::CUDA_ERROR_INVALID_VALUE;
}
let state = Box::new(LinkState {
modules: Vec::new(),
result: None,
});
*state_out = mem::transmute(state);
CUresult::CUDA_SUCCESS
}
pub(crate) unsafe fn add_data(
state: CUlinkState,
type_: CUjitInputType,
data: *mut c_void,
size: usize,
name: *const i8,
num_options: u32,
options: *mut CUjit_option,
option_values: *mut *mut c_void,
) -> Result<(), hipError_t> {
if state == ptr::null_mut() {
return Err(hipError_t::hipErrorInvalidValue);
}
let state: *mut LinkState = mem::transmute(state);
let state = &mut *state;
// V-RAY specific hack
if state.modules.len() == 2 {
return Err(hipError_t::hipSuccess);
}
let spirv_data = SpirvModule::new_raw(data as *const _)?;
state.modules.push(spirv_data);
Ok(())
}
pub(crate) unsafe fn complete(
state: CUlinkState,
cubin_out: *mut *mut c_void,
size_out: *mut usize,
) -> Result<(), hipError_t> {
let mut dev = 0;
hip_call! { hipCtxGetDevice(&mut dev) };
let mut props = unsafe { mem::zeroed() };
hip_call! { hipGetDeviceProperties(&mut props, dev) };
let state: &mut LinkState = mem::transmute(state);
let spirv_bins = state.modules.iter().map(|m| &m.binaries[..]);
let should_link_ptx_impl = state.modules.iter().find_map(|m| m.should_link_ptx_impl);
let mut arch_binary = module::compile_amd(&props, spirv_bins, should_link_ptx_impl)
.map_err(|_| hipError_t::hipErrorUnknown)?;
let ptr = arch_binary.as_mut_ptr();
let size = arch_binary.len();
state.result = Some(arch_binary);
*cubin_out = ptr as _;
*size_out = size;
Ok(())
}
pub(crate) unsafe fn destroy(state: CUlinkState) -> CUresult {
let state: Box<LinkState> = mem::transmute(state);
CUresult::CUDA_SUCCESS
}

View file

@ -1,55 +1,25 @@
use hip_runtime_sys::{
hipDrvMemcpy3D, hipError_t, hipMemcpy3D, hipMemcpy3DParms, hipMemoryType, hipPitchedPtr,
hipPos, HIP_MEMCPY3D,
};
use std::ptr;
use hip_runtime_sys::*;
use crate::{
cuda::{CUDA_MEMCPY3D_st, CUdeviceptr, CUmemorytype, CUresult},
hip_call,
};
// TODO change HIP impl to 64 bits
pub(crate) unsafe fn copy_3d(cu_copy: *const CUDA_MEMCPY3D_st) -> Result<(), hipError_t> {
if cu_copy == ptr::null() {
return Err(hipError_t::hipErrorInvalidValue);
}
let cu_copy = *cu_copy;
let hip_copy = HIP_MEMCPY3D {
srcXInBytes: cu_copy.srcXInBytes as u32,
srcY: cu_copy.srcY as u32,
srcZ: cu_copy.srcZ as u32,
srcLOD: cu_copy.srcLOD as u32,
srcMemoryType: memory_type(cu_copy.srcMemoryType)?,
srcHost: cu_copy.srcHost,
srcDevice: cu_copy.srcDevice.0 as _,
srcArray: cu_copy.srcArray as _,
srcPitch: cu_copy.srcPitch as u32,
srcHeight: cu_copy.srcHeight as u32,
dstXInBytes: cu_copy.dstXInBytes as u32,
dstY: cu_copy.dstY as u32,
dstZ: cu_copy.dstZ as u32,
dstLOD: cu_copy.dstLOD as u32,
dstMemoryType: memory_type(cu_copy.dstMemoryType)?,
dstHost: cu_copy.dstHost,
dstDevice: cu_copy.dstDevice.0 as _,
dstArray: cu_copy.dstArray as _,
dstPitch: cu_copy.dstPitch as u32,
dstHeight: cu_copy.dstHeight as u32,
WidthInBytes: cu_copy.WidthInBytes as u32,
Height: cu_copy.Height as u32,
Depth: cu_copy.Depth as u32,
};
hip_call! { hipDrvMemcpy3D(&hip_copy) };
Ok(())
pub(crate) fn alloc_v2(dptr: *mut hipDeviceptr_t, bytesize: usize) -> hipError_t {
unsafe { hipMalloc(dptr.cast(), bytesize) }
}
pub(crate) fn memory_type(cu: CUmemorytype) -> Result<hipMemoryType, hipError_t> {
match cu {
CUmemorytype::CU_MEMORYTYPE_HOST => Ok(hipMemoryType::hipMemoryTypeHost),
CUmemorytype::CU_MEMORYTYPE_DEVICE => Ok(hipMemoryType::hipMemoryTypeDevice),
CUmemorytype::CU_MEMORYTYPE_ARRAY => Ok(hipMemoryType::hipMemoryTypeArray),
CUmemorytype::CU_MEMORYTYPE_UNIFIED => Ok(hipMemoryType::hipMemoryTypeUnified),
_ => Err(hipError_t::hipErrorInvalidValue),
}
pub(crate) fn free_v2(dptr: hipDeviceptr_t) -> hipError_t {
unsafe { hipFree(dptr.0) }
}
pub(crate) fn copy_dto_h_v2(
dst_host: *mut ::core::ffi::c_void,
src_device: hipDeviceptr_t,
byte_count: usize,
) -> hipError_t {
unsafe { hipMemcpyDtoH(dst_host, src_device, byte_count) }
}
pub(crate) fn copy_hto_d_v2(
dst_device: hipDeviceptr_t,
src_host: *const ::core::ffi::c_void,
byte_count: usize,
) -> hipError_t {
unsafe { hipMemcpyHtoD(dst_device, src_host.cast_mut(), byte_count) }
}

View file

@ -1,10 +1,14 @@
use cuda_types::*;
use hip_runtime_sys::*;
use std::mem::{self, ManuallyDrop};
use std::mem::{self, ManuallyDrop, MaybeUninit};
pub(super) mod context;
pub(super) mod device;
pub(super) mod driver;
pub(super) mod function;
pub(super) mod memory;
pub(super) mod module;
pub(super) mod pointer;
#[cfg(debug_assertions)]
pub(crate) fn unimplemented() -> CUresult {
@ -97,9 +101,12 @@ macro_rules! from_cuda_object {
from_cuda_nop!(
*mut i8,
*mut i32,
*mut usize,
*const std::ffi::c_void,
*const ::core::ffi::c_void,
*const ::core::ffi::c_char,
*mut ::core::ffi::c_void,
*mut *mut ::core::ffi::c_void,
i32,
u32,
usize,
@ -107,11 +114,14 @@ from_cuda_nop!(
CUdevice_attribute
);
from_cuda_transmute!(
CUdevice => hipDevice_t,
CUuuid => hipUUID,
CUfunction => hipFunction_t
CUfunction => hipFunction_t,
CUfunction_attribute => hipFunction_attribute,
CUstream => hipStream_t,
CUpointer_attribute => hipPointer_attribute,
CUdeviceptr_v2 => hipDeviceptr_t
);
from_cuda_object!(module::Module);
from_cuda_object!(module::Module, context::Context);
impl<'a> FromCuda<'a, CUlimit> for hipLimit_t {
fn from_cuda(limit: &'a CUlimit) -> Result<Self, CUerror> {
@ -140,20 +150,28 @@ pub(crate) trait ZludaObject: Sized + Send + Sync {
#[repr(C)]
pub(crate) struct LiveCheck<T: ZludaObject> {
cookie: usize,
data: ManuallyDrop<T>,
data: MaybeUninit<T>,
}
impl<T: ZludaObject> LiveCheck<T> {
fn wrap(data: T) -> *mut Self {
Box::into_raw(Box::new(LiveCheck {
fn new(data: T) -> Self {
LiveCheck {
cookie: T::COOKIE,
data: ManuallyDrop::new(data),
}))
data: MaybeUninit::new(data),
}
}
fn as_handle(&self) -> T::CudaHandle {
unsafe { mem::transmute_copy(self) }
}
fn wrap(data: T) -> *mut Self {
Box::into_raw(Box::new(Self::new(data)))
}
fn as_result(&self) -> Result<&T, CUerror> {
if self.cookie == T::COOKIE {
Ok(&self.data)
Ok(unsafe { self.data.assume_init_ref() })
} else {
Err(T::LIVENESS_FAIL)
}
@ -167,8 +185,8 @@ impl<T: ZludaObject> LiveCheck<T> {
fn drop_checked(&mut self) -> Result<Result<(), CUerror>, CUerror> {
if self.cookie == T::COOKIE {
self.cookie = 0;
let result = self.data.drop_checked();
unsafe { ManuallyDrop::drop(&mut self.data) };
let result = unsafe { self.data.assume_init_mut().drop_checked() };
unsafe { MaybeUninit::assume_init_drop(&mut self.data) };
Ok(result)
} else {
Err(T::LIVENESS_FAIL)
@ -189,7 +207,3 @@ pub fn drop_checked<T: ZludaObject>(handle: T::CudaHandle) -> Result<(), CUerror
unsafe { ManuallyDrop::drop(&mut wrapped_object) };
underlying_error
}
pub(crate) fn init(flags: ::core::ffi::c_uint) -> hipError_t {
unsafe { hipInit(flags) }
}

View file

@ -6,28 +6,27 @@ pub(crate) unsafe fn get_attribute(
data: *mut c_void,
attribute: hipPointer_attribute,
ptr: hipDeviceptr_t,
) -> CUresult {
) -> hipError_t {
if data == ptr::null_mut() {
return CUresult::ERROR_INVALID_VALUE;
return hipError_t::ErrorInvalidValue;
}
// TODO: implement by getting device ordinal & allocation start,
// then go through every context for that device
if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT {
return CUresult::ERROR_NOT_SUPPORTED;
match attribute {
// TODO: implement by getting device ordinal & allocation start,
// then go through every context for that device
hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT => hipError_t::ErrorNotSupported,
hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE => {
let mut hip_result = hipMemoryType(0);
hipPointerGetAttribute(
(&mut hip_result as *mut hipMemoryType).cast::<c_void>(),
attribute,
ptr,
)?;
let cuda_result = memory_type(hip_result)?;
unsafe { *(data.cast()) = cuda_result };
Ok(())
}
_ => unsafe { hipPointerGetAttribute(data, attribute, ptr) },
}
if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE {
let mut hip_result = hipMemoryType(0);
hipPointerGetAttribute(
(&mut hip_result as *mut hipMemoryType).cast::<c_void>(),
attribute,
ptr,
)?;
let cuda_result = memory_type(hip_result)?;
*(data as _) = cuda_result;
} else {
hipPointerGetAttribute(data, attribute, ptr)?;
}
Ok(())
}
fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
@ -36,6 +35,6 @@ fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
hipMemoryType::hipMemoryTypeDevice => Ok(CUmemorytype::CU_MEMORYTYPE_DEVICE),
hipMemoryType::hipMemoryTypeArray => Ok(CUmemorytype::CU_MEMORYTYPE_ARRAY),
hipMemoryType::hipMemoryTypeUnified => Ok(CUmemorytype::CU_MEMORYTYPE_UNIFIED),
_ => Err(hipErrorCode_t::hipErrorInvalidValue),
_ => Err(hipErrorCode_t::InvalidValue),
}
}

View file

@ -1,157 +0,0 @@
#![allow(non_snake_case)]
use crate::cuda as zluda;
use crate::cuda::CUstream;
use crate::cuda::CUuuid;
use crate::{
cuda::{CUdevice, CUdeviceptr},
r#impl::CUresult,
};
use ::std::{
ffi::c_void,
os::raw::{c_int, c_uint},
};
use cuda_driver_sys as cuda;
#[macro_export]
macro_rules! cuda_driver_test {
($func:ident) => {
paste! {
#[test]
fn [<$func _zluda>]() {
$func::<crate::r#impl::test::Zluda>()
}
#[test]
fn [<$func _cuda>]() {
$func::<crate::r#impl::test::Cuda>()
}
}
};
}
pub trait CudaDriverFns {
fn cuInit(flags: c_uint) -> CUresult;
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult;
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult;
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult;
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult;
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult;
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult;
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult;
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult;
fn cuMemFree_v2(mem: *mut c_void) -> CUresult;
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult;
}
pub struct Zluda();
impl CudaDriverFns for Zluda {
fn cuInit(_flags: c_uint) -> CUresult {
zluda::cuInit(_flags as _)
}
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
zluda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev))
}
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
zluda::cuCtxDestroy_v2(ctx as *mut _)
}
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
zluda::cuCtxPopCurrent_v2(pctx as *mut _)
}
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
zluda::cuCtxGetApiVersion(ctx as *mut _, version)
}
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
zluda::cuCtxGetCurrent(pctx as *mut _)
}
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
zluda::cuMemAlloc_v2(dptr as *mut _, bytesize)
}
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
zluda::cuDeviceGetUuid(uuid, CUdevice(dev))
}
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
zluda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active)
}
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
zluda::cuStreamGetCtx(hStream, pctx as _)
}
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
zluda::cuStreamCreate(stream, flags)
}
fn cuMemFree_v2(dptr: *mut c_void) -> CUresult {
zluda::cuMemFree_v2(CUdeviceptr(dptr as _))
}
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
zluda::cuStreamDestroy_v2(stream)
}
}
pub struct Cuda();
impl CudaDriverFns for Cuda {
fn cuInit(flags: c_uint) -> CUresult {
unsafe { CUresult(cuda::cuInit(flags) as c_uint) }
}
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
unsafe { CUresult(cuda::cuCtxCreate_v2(pctx as *mut _, flags, dev) as c_uint) }
}
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuCtxDestroy_v2(ctx as *mut _) as c_uint) }
}
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuCtxPopCurrent_v2(pctx as *mut _) as c_uint) }
}
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
unsafe { CUresult(cuda::cuCtxGetApiVersion(ctx as *mut _, version) as c_uint) }
}
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuCtxGetCurrent(pctx as *mut _) as c_uint) }
}
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
unsafe { CUresult(cuda::cuMemAlloc_v2(dptr as *mut _, bytesize) as c_uint) }
}
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
unsafe { CUresult(cuda::cuDeviceGetUuid(uuid as *mut _, dev) as c_uint) }
}
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
unsafe { CUresult(cuda::cuDevicePrimaryCtxGetState(dev, flags, active) as c_uint) }
}
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) }
}
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) }
}
fn cuMemFree_v2(mem: *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) }
}
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) }
}
}

View file

@ -27,10 +27,25 @@ macro_rules! implemented {
};
}
macro_rules! implemented_in_function {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path;)*) => {
$(
#[cfg_attr(not(test), no_mangle)]
#[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
cuda_base::cuda_normalize_fn!( crate::r#impl::function::$fn_name ) ($(crate::r#impl::FromCuda::from_cuda(&$arg_id)?),*)?;
Ok(())
}
)*
};
}
cuda_base::cuda_function_declarations!(
unimplemented,
implemented <= [
cuCtxGetLimit,
cuCtxSetCurrent,
cuCtxSetLimit,
cuCtxSynchronize,
cuDeviceComputeCapability,
@ -39,13 +54,25 @@ cuda_base::cuda_function_declarations!(
cuDeviceGetCount,
cuDeviceGetLuid,
cuDeviceGetName,
cuDevicePrimaryCtxRelease,
cuDevicePrimaryCtxRetain,
cuDeviceGetProperties,
cuDeviceGetUuid,
cuDeviceGetUuid_v2,
cuDeviceTotalMem_v2,
cuDriverGetVersion,
cuFuncGetAttribute,
cuInit,
cuMemAlloc_v2,
cuMemFree_v2,
cuMemcpyDtoH_v2,
cuMemcpyHtoD_v2,
cuModuleGetFunction,
cuModuleLoadData,
cuModuleUnload,
cuPointerGetAttribute,
],
implemented_in_function <= [
cuLaunchKernel,
]
);