mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Add more missing host-side code
This commit is contained in:
parent
c461cefd7d
commit
502b0c957e
12 changed files with 348 additions and 371 deletions
|
@ -150,6 +150,10 @@ impl VisitMut for FixFnSignatures {
|
|||
}
|
||||
}
|
||||
|
||||
/// Module names recognised while normalizing a CUDA function path.
///
/// `cuda_normalize_fn` tests the last segment of the incoming path against
/// this list, and `join` tests the first normalized segment against it.
const MODULES: &[&str] = &[
    "context",
    "device",
    "driver",
    "function",
    "link",
    "memory",
    "module",
    "pointer",
];
|
||||
|
||||
#[proc_macro]
|
||||
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||
let mut path = parse_macro_input!(tokens as syn::Path);
|
||||
|
@ -161,8 +165,9 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
|
|||
.0
|
||||
.ident
|
||||
.to_string();
|
||||
let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
|
||||
let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
|
||||
let fn_path = join(segments);
|
||||
let fn_path = join(segments, !already_has_module);
|
||||
quote! {
|
||||
#path #fn_path
|
||||
}
|
||||
|
@ -181,23 +186,16 @@ fn split(fn_: &str) -> Vec<String> {
|
|||
result
|
||||
}
|
||||
|
||||
fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> {
|
||||
fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
|
||||
fn full_form(segment: &str) -> Option<&[&str]> {
|
||||
Some(match segment {
|
||||
"ctx" => &["context"],
|
||||
"func" => &["function"],
|
||||
"mem" => &["memory"],
|
||||
"memcpy" => &["memory", "copy"],
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
const MODULES: &[&str] = &[
|
||||
"context",
|
||||
"device",
|
||||
"function",
|
||||
"link",
|
||||
"memory",
|
||||
"module",
|
||||
"pointer"
|
||||
];
|
||||
let mut normalized: Vec<&str> = Vec::new();
|
||||
for segment in fn_.iter() {
|
||||
match full_form(segment) {
|
||||
|
@ -205,18 +203,20 @@ fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> {
|
|||
None => normalized.push(&*segment),
|
||||
}
|
||||
}
|
||||
if !find_module {
|
||||
return [Ident::new(&normalized.join("_"), Span::call_site())]
|
||||
.into_iter()
|
||||
.collect();
|
||||
}
|
||||
if !MODULES.contains(&normalized[0]) {
|
||||
let mut globalized = vec!["global"];
|
||||
let mut globalized = vec!["driver"];
|
||||
globalized.extend(normalized);
|
||||
normalized = globalized;
|
||||
}
|
||||
let (module, path) = normalized.split_first().unwrap();
|
||||
let path = path.join("_");
|
||||
let mut result = Punctuated::new();
|
||||
result.extend(
|
||||
[module, &&*path]
|
||||
.into_iter()
|
||||
.map(|s| Ident::new(s, Span::call_site())),
|
||||
);
|
||||
result
|
||||
[module, &&*path]
|
||||
.into_iter()
|
||||
.map(|s| Ident::new(s, Span::call_site()))
|
||||
.collect()
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ num_enum = "0.4"
|
|||
lz4-sys = "1.9"
|
||||
tempfile = "3"
|
||||
paste = "1.0"
|
||||
rustc-hash = "1.1"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
winapi = { version = "0.3", features = ["heapapi", "std"] }
|
||||
|
|
|
@ -1,4 +1,46 @@
|
|||
use super::{driver, FromCuda, ZludaObject};
|
||||
use cuda_types::*;
|
||||
use hip_runtime_sys::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{cell::RefCell, ptr, sync::Mutex};
|
||||
|
||||
thread_local! {
|
||||
pub(crate) static CONTEXT_STACK: RefCell<Vec<(CUcontext, hipDevice_t)>> = RefCell::new(Vec::new());
|
||||
}
|
||||
|
||||
pub(crate) struct Context {
|
||||
pub(crate) device: hipDevice_t,
|
||||
pub(crate) mutable: Mutex<OwnedByContext>,
|
||||
}
|
||||
|
||||
pub(crate) struct OwnedByContext {
|
||||
pub(crate) ref_count: usize, // only used by primary context
|
||||
pub(crate) _memory: FxHashSet<hipDeviceptr_t>,
|
||||
pub(crate) _streams: FxHashSet<hipStream_t>,
|
||||
pub(crate) _modules: FxHashSet<CUmodule>,
|
||||
}
|
||||
|
||||
impl ZludaObject for Context {
|
||||
const COOKIE: usize = 0x5f867c6d9cb73315;
|
||||
|
||||
type CudaHandle = CUcontext;
|
||||
|
||||
fn drop_checked(&mut self) -> CUresult {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new(device: hipDevice_t) -> Context {
|
||||
Context {
|
||||
device,
|
||||
mutable: Mutex::new(OwnedByContext {
|
||||
ref_count: 0,
|
||||
_memory: FxHashSet::default(),
|
||||
_streams: FxHashSet::default(),
|
||||
_modules: FxHashSet::default(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: hipLimit_t) -> hipError_t {
|
||||
unsafe { hipDeviceGetLimit(pvalue, limit) }
|
||||
|
@ -11,3 +53,41 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> hipError_t {
|
|||
pub(crate) fn synchronize() -> hipError_t {
|
||||
unsafe { hipDeviceSynchronize() }
|
||||
}
|
||||
|
||||
pub(crate) fn get_primary(hip_dev: hipDevice_t) -> Result<(&'static Context, CUcontext), CUerror> {
|
||||
let dev = driver::device(hip_dev)?;
|
||||
Ok(dev.primary_context())
|
||||
}
|
||||
|
||||
pub(crate) fn set_current(raw_ctx: CUcontext) -> CUresult {
|
||||
let new_device = if raw_ctx.0 == ptr::null_mut() {
|
||||
CONTEXT_STACK.with(|stack| {
|
||||
let mut stack = stack.borrow_mut();
|
||||
if let Some((_, old_device)) = stack.pop() {
|
||||
if let Some((_, new_device)) = stack.last() {
|
||||
if old_device != *new_device {
|
||||
return Some(*new_device);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
} else {
|
||||
let ctx: &Context = FromCuda::from_cuda(&raw_ctx)?;
|
||||
let device = ctx.device;
|
||||
CONTEXT_STACK.with(move |stack| {
|
||||
let mut stack = stack.borrow_mut();
|
||||
let last_device = stack.last().map(|(_, dev)| *dev);
|
||||
stack.push((raw_ctx, device));
|
||||
match last_device {
|
||||
None => Some(device),
|
||||
Some(last_device) if last_device != device => Some(device),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
};
|
||||
if let Some(dev) = new_device {
|
||||
unsafe { hipSetDevice(dev)? };
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ use cuda_types::*;
|
|||
use hip_runtime_sys::*;
|
||||
use std::{mem, ptr};
|
||||
|
||||
use super::context;
|
||||
|
||||
const PROJECT_SUFFIX: &[u8] = b" [ZLUDA]\0";
|
||||
pub const COMPUTE_CAPABILITY_MAJOR: i32 = 8;
|
||||
pub const COMPUTE_CAPABILITY_MINOR: i32 = 8;
|
||||
|
@ -307,3 +309,31 @@ pub(crate) fn get_count(count: &mut ::core::ffi::c_int) -> hipError_t {
|
|||
/// Converts `x` to `i32`, saturating at `i32::MAX` when it does not fit.
fn clamp_usize(x: usize) -> i32 {
    // Checked conversion instead of an `as` cast: there is no truncating cast
    // whose safety depends on a preceding clamp. The trait is named by full
    // path so no extra import is needed on any edition.
    <i32 as ::core::convert::TryFrom<usize>>::try_from(x).unwrap_or(i32::MAX)
}
|
||||
|
||||
pub(crate) fn primary_context_retain(
|
||||
pctx: &mut CUcontext,
|
||||
hip_dev: hipDevice_t,
|
||||
) -> Result<(), CUerror> {
|
||||
let (ctx, raw_ctx) = context::get_primary(hip_dev)?;
|
||||
{
|
||||
let mut mutable_ctx = ctx.mutable.lock().map_err(|_| CUerror::UNKNOWN)?;
|
||||
mutable_ctx.ref_count += 1;
|
||||
}
|
||||
*pctx = raw_ctx;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> Result<(), CUerror> {
|
||||
let (ctx, _) = context::get_primary(hip_dev)?;
|
||||
{
|
||||
let mut mutable_ctx = ctx.mutable.lock().map_err(|_| CUerror::UNKNOWN)?;
|
||||
if mutable_ctx.ref_count == 0 {
|
||||
return Err(CUerror::INVALID_CONTEXT);
|
||||
}
|
||||
mutable_ctx.ref_count -= 1;
|
||||
if mutable_ctx.ref_count == 0 {
|
||||
// TODO: drop all children
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
79
zluda/src/impl/driver.rs
Normal file
79
zluda/src/impl/driver.rs
Normal file
|
@ -0,0 +1,79 @@
|
|||
use cuda_types::*;
|
||||
use hip_runtime_sys::*;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
mem, slice,
|
||||
sync::OnceLock,
|
||||
};
|
||||
|
||||
use crate::r#impl::context;
|
||||
|
||||
use super::LiveCheck;
|
||||
|
||||
pub(crate) struct GlobalState {
|
||||
pub devices: Vec<Device>,
|
||||
}
|
||||
|
||||
pub(crate) struct Device {
|
||||
pub(crate) _comgr_isa: CString,
|
||||
primary_context: LiveCheck<context::Context>,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
|
||||
unsafe {
|
||||
(
|
||||
self.primary_context.data.assume_init_ref(),
|
||||
self.primary_context.as_handle(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
|
||||
global_state()?
|
||||
.devices
|
||||
.get(dev as usize)
|
||||
.ok_or(CUerror::INVALID_DEVICE)
|
||||
}
|
||||
|
||||
pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
|
||||
static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
|
||||
fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
|
||||
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
|
||||
}
|
||||
GLOBAL_STATE
|
||||
.get_or_init(|| {
|
||||
let mut device_count = 0;
|
||||
unsafe { hipGetDeviceCount(&mut device_count) }?;
|
||||
Ok(GlobalState {
|
||||
devices: (0..device_count)
|
||||
.map(|i| {
|
||||
let mut props = unsafe { mem::zeroed() };
|
||||
unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
|
||||
Ok::<_, CUerror>(Device {
|
||||
_comgr_isa: CStr::from_bytes_until_nul(cast_slice(
|
||||
&props.gcnArchName[..],
|
||||
))
|
||||
.map_err(|_| CUerror::UNKNOWN)?
|
||||
.to_owned(),
|
||||
primary_context: LiveCheck::new(context::new(i)),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
})
|
||||
})
|
||||
.as_ref()
|
||||
.map_err(|e| *e)
|
||||
}
|
||||
|
||||
pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
|
||||
unsafe { hipInit(flags) }?;
|
||||
global_state()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
|
||||
*version = cuda_types::CUDA_VERSION as i32;
|
||||
Ok(())
|
||||
}
|
|
@ -1,26 +1,46 @@
|
|||
use hip_runtime_sys::{hipError_t, hipFuncAttribute, hipFuncGetAttribute, hipFuncGetAttributes, hipFunction_attribute, hipLaunchKernel, hipModuleLaunchKernel};
|
||||
|
||||
use super::{CUresult, HasLivenessCookie, LiveCheck};
|
||||
use crate::cuda::{CUfunction, CUfunction_attribute, CUstream};
|
||||
use ::std::os::raw::{c_uint, c_void};
|
||||
use std::{mem, ptr};
|
||||
use hip_runtime_sys::*;
|
||||
|
||||
pub(crate) fn get_attribute(
|
||||
pi: *mut i32,
|
||||
cu_attrib: CUfunction_attribute,
|
||||
func: CUfunction,
|
||||
pi: &mut i32,
|
||||
cu_attrib: hipFunction_attribute,
|
||||
func: hipFunction_t,
|
||||
) -> hipError_t {
|
||||
if pi == ptr::null_mut() || func == ptr::null_mut() {
|
||||
return hipError_t::hipErrorInvalidValue;
|
||||
// TODO: implement HIP_FUNC_ATTRIBUTE_PTX_VERSION
|
||||
// TODO: implement HIP_FUNC_ATTRIBUTE_BINARY_VERSION
|
||||
unsafe { hipFuncGetAttribute(pi, cu_attrib, func) }?;
|
||||
if cu_attrib == hipFunction_attribute::HIP_FUNC_ATTRIBUTE_NUM_REGS {
|
||||
*pi = (*pi).max(1);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn launch_kernel(
|
||||
f: hipFunction_t,
|
||||
grid_dim_x: ::core::ffi::c_uint,
|
||||
grid_dim_y: ::core::ffi::c_uint,
|
||||
grid_dim_z: ::core::ffi::c_uint,
|
||||
block_dim_x: ::core::ffi::c_uint,
|
||||
block_dim_y: ::core::ffi::c_uint,
|
||||
block_dim_z: ::core::ffi::c_uint,
|
||||
shared_mem_bytes: ::core::ffi::c_uint,
|
||||
stream: hipStream_t,
|
||||
kernel_params: *mut *mut ::core::ffi::c_void,
|
||||
extra: *mut *mut ::core::ffi::c_void,
|
||||
) -> hipError_t {
|
||||
// TODO: fix constants in extra
|
||||
unsafe {
|
||||
hipModuleLaunchKernel(
|
||||
f,
|
||||
grid_dim_x,
|
||||
grid_dim_y,
|
||||
grid_dim_z,
|
||||
block_dim_x,
|
||||
block_dim_y,
|
||||
block_dim_z,
|
||||
shared_mem_bytes,
|
||||
stream,
|
||||
kernel_params,
|
||||
extra,
|
||||
)
|
||||
}
|
||||
let attrib = match cu_attrib {
|
||||
CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
|
||||
hipFunction_attribute::HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
|
||||
}
|
||||
CUfunction_attribute::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES => {
|
||||
hipFunction_attribute::HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
|
||||
}
|
||||
_ => return hipError_t::hipErrorInvalidValue,
|
||||
};
|
||||
unsafe { hipFuncGetAttribute(pi, attrib, func as _) }
|
||||
}
|
||||
|
|
|
@ -1,86 +0,0 @@
|
|||
use std::{
|
||||
ffi::{c_void, CStr},
|
||||
mem, ptr, slice,
|
||||
};
|
||||
|
||||
use hip_runtime_sys::{hipCtxGetDevice, hipError_t, hipGetDeviceProperties};
|
||||
|
||||
use crate::{
|
||||
cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult},
|
||||
hip_call,
|
||||
};
|
||||
|
||||
use super::module::{self, SpirvModule};
|
||||
|
||||
struct LinkState {
|
||||
modules: Vec<SpirvModule>,
|
||||
result: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn create(
|
||||
num_options: u32,
|
||||
options: *mut CUjit_option,
|
||||
option_values: *mut *mut c_void,
|
||||
state_out: *mut CUlinkState,
|
||||
) -> CUresult {
|
||||
if state_out == ptr::null_mut() {
|
||||
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||
}
|
||||
let state = Box::new(LinkState {
|
||||
modules: Vec::new(),
|
||||
result: None,
|
||||
});
|
||||
*state_out = mem::transmute(state);
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn add_data(
|
||||
state: CUlinkState,
|
||||
type_: CUjitInputType,
|
||||
data: *mut c_void,
|
||||
size: usize,
|
||||
name: *const i8,
|
||||
num_options: u32,
|
||||
options: *mut CUjit_option,
|
||||
option_values: *mut *mut c_void,
|
||||
) -> Result<(), hipError_t> {
|
||||
if state == ptr::null_mut() {
|
||||
return Err(hipError_t::hipErrorInvalidValue);
|
||||
}
|
||||
let state: *mut LinkState = mem::transmute(state);
|
||||
let state = &mut *state;
|
||||
// V-RAY specific hack
|
||||
if state.modules.len() == 2 {
|
||||
return Err(hipError_t::hipSuccess);
|
||||
}
|
||||
let spirv_data = SpirvModule::new_raw(data as *const _)?;
|
||||
state.modules.push(spirv_data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn complete(
|
||||
state: CUlinkState,
|
||||
cubin_out: *mut *mut c_void,
|
||||
size_out: *mut usize,
|
||||
) -> Result<(), hipError_t> {
|
||||
let mut dev = 0;
|
||||
hip_call! { hipCtxGetDevice(&mut dev) };
|
||||
let mut props = unsafe { mem::zeroed() };
|
||||
hip_call! { hipGetDeviceProperties(&mut props, dev) };
|
||||
let state: &mut LinkState = mem::transmute(state);
|
||||
let spirv_bins = state.modules.iter().map(|m| &m.binaries[..]);
|
||||
let should_link_ptx_impl = state.modules.iter().find_map(|m| m.should_link_ptx_impl);
|
||||
let mut arch_binary = module::compile_amd(&props, spirv_bins, should_link_ptx_impl)
|
||||
.map_err(|_| hipError_t::hipErrorUnknown)?;
|
||||
let ptr = arch_binary.as_mut_ptr();
|
||||
let size = arch_binary.len();
|
||||
state.result = Some(arch_binary);
|
||||
*cubin_out = ptr as _;
|
||||
*size_out = size;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn destroy(state: CUlinkState) -> CUresult {
|
||||
let state: Box<LinkState> = mem::transmute(state);
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
|
@ -1,55 +1,25 @@
|
|||
use hip_runtime_sys::{
|
||||
hipDrvMemcpy3D, hipError_t, hipMemcpy3D, hipMemcpy3DParms, hipMemoryType, hipPitchedPtr,
|
||||
hipPos, HIP_MEMCPY3D,
|
||||
};
|
||||
use std::ptr;
|
||||
use hip_runtime_sys::*;
|
||||
|
||||
use crate::{
|
||||
cuda::{CUDA_MEMCPY3D_st, CUdeviceptr, CUmemorytype, CUresult},
|
||||
hip_call,
|
||||
};
|
||||
|
||||
// TODO change HIP impl to 64 bits
|
||||
pub(crate) unsafe fn copy_3d(cu_copy: *const CUDA_MEMCPY3D_st) -> Result<(), hipError_t> {
|
||||
if cu_copy == ptr::null() {
|
||||
return Err(hipError_t::hipErrorInvalidValue);
|
||||
}
|
||||
let cu_copy = *cu_copy;
|
||||
let hip_copy = HIP_MEMCPY3D {
|
||||
srcXInBytes: cu_copy.srcXInBytes as u32,
|
||||
srcY: cu_copy.srcY as u32,
|
||||
srcZ: cu_copy.srcZ as u32,
|
||||
srcLOD: cu_copy.srcLOD as u32,
|
||||
srcMemoryType: memory_type(cu_copy.srcMemoryType)?,
|
||||
srcHost: cu_copy.srcHost,
|
||||
srcDevice: cu_copy.srcDevice.0 as _,
|
||||
srcArray: cu_copy.srcArray as _,
|
||||
srcPitch: cu_copy.srcPitch as u32,
|
||||
srcHeight: cu_copy.srcHeight as u32,
|
||||
dstXInBytes: cu_copy.dstXInBytes as u32,
|
||||
dstY: cu_copy.dstY as u32,
|
||||
dstZ: cu_copy.dstZ as u32,
|
||||
dstLOD: cu_copy.dstLOD as u32,
|
||||
dstMemoryType: memory_type(cu_copy.dstMemoryType)?,
|
||||
dstHost: cu_copy.dstHost,
|
||||
dstDevice: cu_copy.dstDevice.0 as _,
|
||||
dstArray: cu_copy.dstArray as _,
|
||||
dstPitch: cu_copy.dstPitch as u32,
|
||||
dstHeight: cu_copy.dstHeight as u32,
|
||||
WidthInBytes: cu_copy.WidthInBytes as u32,
|
||||
Height: cu_copy.Height as u32,
|
||||
Depth: cu_copy.Depth as u32,
|
||||
};
|
||||
hip_call! { hipDrvMemcpy3D(&hip_copy) };
|
||||
Ok(())
|
||||
pub(crate) fn alloc_v2(dptr: *mut hipDeviceptr_t, bytesize: usize) -> hipError_t {
|
||||
unsafe { hipMalloc(dptr.cast(), bytesize) }
|
||||
}
|
||||
|
||||
pub(crate) fn memory_type(cu: CUmemorytype) -> Result<hipMemoryType, hipError_t> {
|
||||
match cu {
|
||||
CUmemorytype::CU_MEMORYTYPE_HOST => Ok(hipMemoryType::hipMemoryTypeHost),
|
||||
CUmemorytype::CU_MEMORYTYPE_DEVICE => Ok(hipMemoryType::hipMemoryTypeDevice),
|
||||
CUmemorytype::CU_MEMORYTYPE_ARRAY => Ok(hipMemoryType::hipMemoryTypeArray),
|
||||
CUmemorytype::CU_MEMORYTYPE_UNIFIED => Ok(hipMemoryType::hipMemoryTypeUnified),
|
||||
_ => Err(hipError_t::hipErrorInvalidValue),
|
||||
}
|
||||
pub(crate) fn free_v2(dptr: hipDeviceptr_t) -> hipError_t {
|
||||
unsafe { hipFree(dptr.0) }
|
||||
}
|
||||
|
||||
pub(crate) fn copy_dto_h_v2(
|
||||
dst_host: *mut ::core::ffi::c_void,
|
||||
src_device: hipDeviceptr_t,
|
||||
byte_count: usize,
|
||||
) -> hipError_t {
|
||||
unsafe { hipMemcpyDtoH(dst_host, src_device, byte_count) }
|
||||
}
|
||||
|
||||
pub(crate) fn copy_hto_d_v2(
|
||||
dst_device: hipDeviceptr_t,
|
||||
src_host: *const ::core::ffi::c_void,
|
||||
byte_count: usize,
|
||||
) -> hipError_t {
|
||||
unsafe { hipMemcpyHtoD(dst_device, src_host.cast_mut(), byte_count) }
|
||||
}
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
use cuda_types::*;
|
||||
use hip_runtime_sys::*;
|
||||
use std::mem::{self, ManuallyDrop};
|
||||
use std::mem::{self, ManuallyDrop, MaybeUninit};
|
||||
|
||||
pub(super) mod context;
|
||||
pub(super) mod device;
|
||||
pub(super) mod driver;
|
||||
pub(super) mod function;
|
||||
pub(super) mod memory;
|
||||
pub(super) mod module;
|
||||
pub(super) mod pointer;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
pub(crate) fn unimplemented() -> CUresult {
|
||||
|
@ -97,9 +101,12 @@ macro_rules! from_cuda_object {
|
|||
|
||||
from_cuda_nop!(
|
||||
*mut i8,
|
||||
*mut i32,
|
||||
*mut usize,
|
||||
*const std::ffi::c_void,
|
||||
*const ::core::ffi::c_void,
|
||||
*const ::core::ffi::c_char,
|
||||
*mut ::core::ffi::c_void,
|
||||
*mut *mut ::core::ffi::c_void,
|
||||
i32,
|
||||
u32,
|
||||
usize,
|
||||
|
@ -107,11 +114,14 @@ from_cuda_nop!(
|
|||
CUdevice_attribute
|
||||
);
|
||||
from_cuda_transmute!(
|
||||
CUdevice => hipDevice_t,
|
||||
CUuuid => hipUUID,
|
||||
CUfunction => hipFunction_t
|
||||
CUfunction => hipFunction_t,
|
||||
CUfunction_attribute => hipFunction_attribute,
|
||||
CUstream => hipStream_t,
|
||||
CUpointer_attribute => hipPointer_attribute,
|
||||
CUdeviceptr_v2 => hipDeviceptr_t
|
||||
);
|
||||
from_cuda_object!(module::Module);
|
||||
from_cuda_object!(module::Module, context::Context);
|
||||
|
||||
impl<'a> FromCuda<'a, CUlimit> for hipLimit_t {
|
||||
fn from_cuda(limit: &'a CUlimit) -> Result<Self, CUerror> {
|
||||
|
@ -140,20 +150,28 @@ pub(crate) trait ZludaObject: Sized + Send + Sync {
|
|||
#[repr(C)]
|
||||
pub(crate) struct LiveCheck<T: ZludaObject> {
|
||||
cookie: usize,
|
||||
data: ManuallyDrop<T>,
|
||||
data: MaybeUninit<T>,
|
||||
}
|
||||
|
||||
impl<T: ZludaObject> LiveCheck<T> {
|
||||
fn wrap(data: T) -> *mut Self {
|
||||
Box::into_raw(Box::new(LiveCheck {
|
||||
fn new(data: T) -> Self {
|
||||
LiveCheck {
|
||||
cookie: T::COOKIE,
|
||||
data: ManuallyDrop::new(data),
|
||||
}))
|
||||
data: MaybeUninit::new(data),
|
||||
}
|
||||
}
|
||||
|
||||
fn as_handle(&self) -> T::CudaHandle {
|
||||
unsafe { mem::transmute_copy(self) }
|
||||
}
|
||||
|
||||
fn wrap(data: T) -> *mut Self {
|
||||
Box::into_raw(Box::new(Self::new(data)))
|
||||
}
|
||||
|
||||
fn as_result(&self) -> Result<&T, CUerror> {
|
||||
if self.cookie == T::COOKIE {
|
||||
Ok(&self.data)
|
||||
Ok(unsafe { self.data.assume_init_ref() })
|
||||
} else {
|
||||
Err(T::LIVENESS_FAIL)
|
||||
}
|
||||
|
@ -167,8 +185,8 @@ impl<T: ZludaObject> LiveCheck<T> {
|
|||
fn drop_checked(&mut self) -> Result<Result<(), CUerror>, CUerror> {
|
||||
if self.cookie == T::COOKIE {
|
||||
self.cookie = 0;
|
||||
let result = self.data.drop_checked();
|
||||
unsafe { ManuallyDrop::drop(&mut self.data) };
|
||||
let result = unsafe { self.data.assume_init_mut().drop_checked() };
|
||||
unsafe { MaybeUninit::assume_init_drop(&mut self.data) };
|
||||
Ok(result)
|
||||
} else {
|
||||
Err(T::LIVENESS_FAIL)
|
||||
|
@ -189,7 +207,3 @@ pub fn drop_checked<T: ZludaObject>(handle: T::CudaHandle) -> Result<(), CUerror
|
|||
unsafe { ManuallyDrop::drop(&mut wrapped_object) };
|
||||
underlying_error
|
||||
}
|
||||
|
||||
pub(crate) fn init(flags: ::core::ffi::c_uint) -> hipError_t {
|
||||
unsafe { hipInit(flags) }
|
||||
}
|
||||
|
|
|
@ -6,28 +6,27 @@ pub(crate) unsafe fn get_attribute(
|
|||
data: *mut c_void,
|
||||
attribute: hipPointer_attribute,
|
||||
ptr: hipDeviceptr_t,
|
||||
) -> CUresult {
|
||||
) -> hipError_t {
|
||||
if data == ptr::null_mut() {
|
||||
return CUresult::ERROR_INVALID_VALUE;
|
||||
return hipError_t::ErrorInvalidValue;
|
||||
}
|
||||
// TODO: implement by getting device ordinal & allocation start,
|
||||
// then go through every context for that device
|
||||
if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT {
|
||||
return CUresult::ERROR_NOT_SUPPORTED;
|
||||
match attribute {
|
||||
// TODO: implement by getting device ordinal & allocation start,
|
||||
// then go through every context for that device
|
||||
hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT => hipError_t::ErrorNotSupported,
|
||||
hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE => {
|
||||
let mut hip_result = hipMemoryType(0);
|
||||
hipPointerGetAttribute(
|
||||
(&mut hip_result as *mut hipMemoryType).cast::<c_void>(),
|
||||
attribute,
|
||||
ptr,
|
||||
)?;
|
||||
let cuda_result = memory_type(hip_result)?;
|
||||
unsafe { *(data.cast()) = cuda_result };
|
||||
Ok(())
|
||||
}
|
||||
_ => unsafe { hipPointerGetAttribute(data, attribute, ptr) },
|
||||
}
|
||||
if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE {
|
||||
let mut hip_result = hipMemoryType(0);
|
||||
hipPointerGetAttribute(
|
||||
(&mut hip_result as *mut hipMemoryType).cast::<c_void>(),
|
||||
attribute,
|
||||
ptr,
|
||||
)?;
|
||||
let cuda_result = memory_type(hip_result)?;
|
||||
*(data as _) = cuda_result;
|
||||
} else {
|
||||
hipPointerGetAttribute(data, attribute, ptr)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
|
||||
|
@ -36,6 +35,6 @@ fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
|
|||
hipMemoryType::hipMemoryTypeDevice => Ok(CUmemorytype::CU_MEMORYTYPE_DEVICE),
|
||||
hipMemoryType::hipMemoryTypeArray => Ok(CUmemorytype::CU_MEMORYTYPE_ARRAY),
|
||||
hipMemoryType::hipMemoryTypeUnified => Ok(CUmemorytype::CU_MEMORYTYPE_UNIFIED),
|
||||
_ => Err(hipErrorCode_t::hipErrorInvalidValue),
|
||||
_ => Err(hipErrorCode_t::InvalidValue),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,157 +0,0 @@
|
|||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::cuda as zluda;
|
||||
use crate::cuda::CUstream;
|
||||
use crate::cuda::CUuuid;
|
||||
use crate::{
|
||||
cuda::{CUdevice, CUdeviceptr},
|
||||
r#impl::CUresult,
|
||||
};
|
||||
use ::std::{
|
||||
ffi::c_void,
|
||||
os::raw::{c_int, c_uint},
|
||||
};
|
||||
use cuda_driver_sys as cuda;
|
||||
|
||||
/// Expands a generic test function `f` into two `#[test]` functions:
/// `f_zluda`, which calls `f::<Zluda>()`, and `f_cuda`, which calls
/// `f::<Cuda>()`.
#[macro_export]
macro_rules! cuda_driver_test {
    ($test_fn:ident) => {
        paste! {
            #[test]
            fn [<$test_fn _zluda>]() {
                $test_fn::<crate::r#impl::test::Zluda>()
            }

            #[test]
            fn [<$test_fn _cuda>]() {
                $test_fn::<crate::r#impl::test::Cuda>()
            }
        }
    };
}
|
||||
|
||||
pub trait CudaDriverFns {
|
||||
fn cuInit(flags: c_uint) -> CUresult;
|
||||
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult;
|
||||
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult;
|
||||
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult;
|
||||
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult;
|
||||
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult;
|
||||
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult;
|
||||
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;
|
||||
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;
|
||||
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult;
|
||||
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult;
|
||||
fn cuMemFree_v2(mem: *mut c_void) -> CUresult;
|
||||
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult;
|
||||
}
|
||||
|
||||
pub struct Zluda();
|
||||
|
||||
impl CudaDriverFns for Zluda {
|
||||
fn cuInit(_flags: c_uint) -> CUresult {
|
||||
zluda::cuInit(_flags as _)
|
||||
}
|
||||
|
||||
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
|
||||
zluda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev))
|
||||
}
|
||||
|
||||
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
|
||||
zluda::cuCtxDestroy_v2(ctx as *mut _)
|
||||
}
|
||||
|
||||
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
|
||||
zluda::cuCtxPopCurrent_v2(pctx as *mut _)
|
||||
}
|
||||
|
||||
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
|
||||
zluda::cuCtxGetApiVersion(ctx as *mut _, version)
|
||||
}
|
||||
|
||||
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
|
||||
zluda::cuCtxGetCurrent(pctx as *mut _)
|
||||
}
|
||||
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
|
||||
zluda::cuMemAlloc_v2(dptr as *mut _, bytesize)
|
||||
}
|
||||
|
||||
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
|
||||
zluda::cuDeviceGetUuid(uuid, CUdevice(dev))
|
||||
}
|
||||
|
||||
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
|
||||
zluda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active)
|
||||
}
|
||||
|
||||
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
|
||||
zluda::cuStreamGetCtx(hStream, pctx as _)
|
||||
}
|
||||
|
||||
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
|
||||
zluda::cuStreamCreate(stream, flags)
|
||||
}
|
||||
|
||||
fn cuMemFree_v2(dptr: *mut c_void) -> CUresult {
|
||||
zluda::cuMemFree_v2(CUdeviceptr(dptr as _))
|
||||
}
|
||||
|
||||
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
|
||||
zluda::cuStreamDestroy_v2(stream)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Cuda();
|
||||
|
||||
impl CudaDriverFns for Cuda {
|
||||
fn cuInit(flags: c_uint) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuInit(flags) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuCtxCreate_v2(pctx as *mut _, flags, dev) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuCtxDestroy_v2(ctx as *mut _) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuCtxPopCurrent_v2(pctx as *mut _) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuCtxGetApiVersion(ctx as *mut _, version) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuCtxGetCurrent(pctx as *mut _) as c_uint) }
|
||||
}
|
||||
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuMemAlloc_v2(dptr as *mut _, bytesize) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuDeviceGetUuid(uuid as *mut _, dev) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuDevicePrimaryCtxGetState(dev, flags, active) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuMemFree_v2(mem: *mut c_void) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) }
|
||||
}
|
||||
}
|
|
@ -27,10 +27,25 @@ macro_rules! implemented {
|
|||
};
|
||||
}
|
||||
|
||||
/// Generates the exported wrapper for each listed declaration.
///
/// Unlike a path built from the function name alone, the target is always
/// looked up under `crate::r#impl::function` (via `cuda_normalize_fn!`).
/// Every argument is converted with `FromCuda::from_cuda`; a conversion or
/// call error is returned through `?`, otherwise the wrapper returns `Ok(())`.
macro_rules! implemented_in_function {
    ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path;)*) => {
        $(
            // Exported unmangled except in test builds.
            #[cfg_attr(not(test), no_mangle)]
            #[allow(improper_ctypes, improper_ctypes_definitions)]
            pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
                cuda_base::cuda_normalize_fn!( crate::r#impl::function::$fn_name ) (
                    $(crate::r#impl::FromCuda::from_cuda(&$arg_id)?),*
                )?;
                Ok(())
            }
        )*
    };
}
|
||||
|
||||
cuda_base::cuda_function_declarations!(
|
||||
unimplemented,
|
||||
implemented <= [
|
||||
cuCtxGetLimit,
|
||||
cuCtxSetCurrent,
|
||||
cuCtxSetLimit,
|
||||
cuCtxSynchronize,
|
||||
cuDeviceComputeCapability,
|
||||
|
@ -39,13 +54,25 @@ cuda_base::cuda_function_declarations!(
|
|||
cuDeviceGetCount,
|
||||
cuDeviceGetLuid,
|
||||
cuDeviceGetName,
|
||||
cuDevicePrimaryCtxRelease,
|
||||
cuDevicePrimaryCtxRetain,
|
||||
cuDeviceGetProperties,
|
||||
cuDeviceGetUuid,
|
||||
cuDeviceGetUuid_v2,
|
||||
cuDeviceTotalMem_v2,
|
||||
cuDriverGetVersion,
|
||||
cuFuncGetAttribute,
|
||||
cuInit,
|
||||
cuMemAlloc_v2,
|
||||
cuMemFree_v2,
|
||||
cuMemcpyDtoH_v2,
|
||||
cuMemcpyHtoD_v2,
|
||||
cuModuleGetFunction,
|
||||
cuModuleLoadData,
|
||||
cuModuleUnload,
|
||||
cuPointerGetAttribute,
|
||||
],
|
||||
implemented_in_function <= [
|
||||
cuLaunchKernel,
|
||||
]
|
||||
);
|
Loading…
Add table
Reference in a new issue