Refactor host code to use one big lock

This commit is contained in:
Andrzej Janik 2020-11-11 22:35:34 +01:00
commit a2e77fe961
15 changed files with 914 additions and 540 deletions

View file

@ -173,6 +173,16 @@ impl Context {
check!(sys::zeContextCreate(drv.0, &ctx_desc, &mut result)); check!(sys::zeContextCreate(drv.0, &ctx_desc, &mut result));
Ok(Context(result)) Ok(Context(result))
} }
pub unsafe fn mem_free(&mut self, ptr: *mut c_void) -> Result<()> {
check! {
sys::zeMemFree(
self.0,
ptr,
)
};
Ok(())
}
} }
impl Drop for Context { impl Drop for Context {

27
notcuda/build.rs Normal file
View file

@ -0,0 +1,27 @@
// HACK ALERT
// This buidl script has been copy-pasted from cl-sys to avoid CUDA libraries
// overriding path to OpenCL
fn main() {
if cfg!(windows) {
let known_sdk = [
// E.g. "c:\Program Files (x86)\Intel\OpenCL SDK\lib\x86\"
("INTELOCLSDKROOT", "x64", "x86"),
// E.g. "c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\Win32\"
("CUDA_PATH", "x64", "Win32"),
// E.g. "C:\Program Files (x86)\AMD APP SDK\3.0\lib\x86\"
("AMDAPPSDKROOT", "x86_64", "x86"),
];
for info in known_sdk.iter() {
if let Ok(sdk) = std::env::var(info.0) {
let mut path = std::path::PathBuf::from(sdk);
path.push("lib");
path.push(if cfg!(target_arch="x86_64") { info.1 } else { info.2 });
println!("cargo:rustc-link-search=native={}", path.display());
}
}
println!("cargo:rustc-link-search=native=C:\\Program Files (x86)\\OCL_SDK_Light\\lib\\x86_64");
}
}

View file

@ -2210,12 +2210,12 @@ pub extern "C" fn cuDriverGetVersion(driverVersion: *mut ::std::os::raw::c_int)
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDeviceGet(device: *mut CUdevice, ordinal: ::std::os::raw::c_int) -> CUresult { pub extern "C" fn cuDeviceGet(device: *mut CUdevice, ordinal: ::std::os::raw::c_int) -> CUresult {
r#impl::device::get(device.decuda(), ordinal) r#impl::device::get(device.decuda(), ordinal).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDeviceGetCount(count: *mut ::std::os::raw::c_int) -> CUresult { pub extern "C" fn cuDeviceGetCount(count: *mut ::std::os::raw::c_int) -> CUresult {
r#impl::device::get_count(count) r#impl::device::get_count(count).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -2314,7 +2314,6 @@ pub extern "C" fn cuDevicePrimaryCtxReset(dev: CUdevice) -> CUresult {
cuDevicePrimaryCtxReset_v2(dev) cuDevicePrimaryCtxReset_v2(dev)
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDevicePrimaryCtxReset_v2(dev: CUdevice) -> CUresult { pub extern "C" fn cuDevicePrimaryCtxReset_v2(dev: CUdevice) -> CUresult {
r#impl::unimplemented() r#impl::unimplemented()
@ -2331,7 +2330,7 @@ pub extern "C" fn cuCtxCreate_v2(
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxDestroy_v2(ctx: CUcontext) -> CUresult { pub extern "C" fn cuCtxDestroy_v2(ctx: CUcontext) -> CUresult {
r#impl::context::destroy_v2(ctx.decuda()) r#impl::context::destroy_v2(ctx.decuda()).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -2356,7 +2355,7 @@ pub extern "C" fn cuCtxGetCurrent(pctx: *mut CUcontext) -> CUresult {
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxGetDevice(device: *mut CUdevice) -> CUresult { pub extern "C" fn cuCtxGetDevice(device: *mut CUdevice) -> CUresult {
r#impl::context::get_device(device.decuda()) r#impl::context::get_device(device.decuda()).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -2404,7 +2403,7 @@ pub extern "C" fn cuCtxGetApiVersion(
ctx: CUcontext, ctx: CUcontext,
version: *mut ::std::os::raw::c_uint, version: *mut ::std::os::raw::c_uint,
) -> CUresult { ) -> CUresult {
r#impl::context::get_api_version(ctx.decuda(), version) r#impl::context::get_api_version(ctx.decuda(), version).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -2422,12 +2421,12 @@ pub extern "C" fn cuCtxResetPersistingL2Cache() -> CUresult {
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxAttach(pctx: *mut CUcontext, flags: ::std::os::raw::c_uint) -> CUresult { pub extern "C" fn cuCtxAttach(pctx: *mut CUcontext, flags: ::std::os::raw::c_uint) -> CUresult {
r#impl::unimplemented() r#impl::context::attach(pctx.decuda(), flags).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxDetach(ctx: CUcontext) -> CUresult { pub extern "C" fn cuCtxDetach(ctx: CUcontext) -> CUresult {
r#impl::unimplemented() r#impl::context::detach(ctx.decuda()).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -2443,7 +2442,7 @@ pub extern "C" fn cuModuleLoadData(
module: *mut CUmodule, module: *mut CUmodule,
image: *const ::std::os::raw::c_void, image: *const ::std::os::raw::c_void,
) -> CUresult { ) -> CUresult {
r#impl::unimplemented() r#impl::module::load_data(module.decuda(), image).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -2564,7 +2563,7 @@ pub extern "C" fn cuMemGetInfo_v2(free: *mut usize, total: *mut usize) -> CUresu
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuMemAlloc_v2(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult { pub extern "C" fn cuMemAlloc_v2(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult {
r#impl::memory::alloc_v2(dptr.decuda(), bytesize) r#impl::memory::alloc_v2(dptr.decuda(), bytesize).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -3281,7 +3280,7 @@ pub extern "C" fn cuStreamCreate(
phStream: *mut CUstream, phStream: *mut CUstream,
Flags: ::std::os::raw::c_uint, Flags: ::std::os::raw::c_uint,
) -> CUresult { ) -> CUresult {
r#impl::unimplemented() r#impl::stream::create(phStream.decuda(), Flags).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -3311,7 +3310,7 @@ pub extern "C" fn cuStreamGetFlags(
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuStreamGetCtx(hStream: CUstream, pctx: *mut CUcontext) -> CUresult { pub extern "C" fn cuStreamGetCtx(hStream: CUstream, pctx: *mut CUcontext) -> CUresult {
r#impl::unimplemented() r#impl::stream::get_ctx(hStream.decuda(), pctx.decuda()).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -3390,7 +3389,7 @@ pub extern "C" fn cuStreamSynchronize(hStream: CUstream) -> CUresult {
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuStreamDestroy_v2(hStream: CUstream) -> CUresult { pub extern "C" fn cuStreamDestroy_v2(hStream: CUstream) -> CUresult {
r#impl::unimplemented() r#impl::stream::destroy_v2(hStream.decuda()).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]

View file

@ -1,18 +1,15 @@
use super::CUresult; use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck};
use super::{device, HasLivenessCookie, LiveCheck}; use super::{CUresult, GlobalState};
use crate::{cuda::CUcontext, cuda_impl}; use crate::{cuda::CUcontext, cuda_impl};
use l0::sys::ze_result_t; use l0::sys::ze_result_t;
use std::mem::{self, ManuallyDrop}; use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32};
use std::{ use std::{
cell::RefCell, collections::HashSet,
num::NonZeroU32, mem::{self},
os::raw::c_uint,
ptr,
sync::{atomic::AtomicU32, Mutex},
}; };
thread_local! { thread_local! {
pub static CONTEXT_STACK: RefCell<Vec<*const Context>> = RefCell::new(Vec::new()); pub static CONTEXT_STACK: RefCell<Vec<*mut Context>> = RefCell::new(Vec::new());
} }
pub type Context = LiveCheck<ContextData>; pub type Context = LiveCheck<ContextData>;
@ -23,6 +20,17 @@ impl HasLivenessCookie for ContextData {
#[cfg(target_pointer_width = "32")] #[cfg(target_pointer_width = "32")]
const COOKIE: usize = 0x0b643ffb; const COOKIE: usize = 0x0b643ffb;
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT;
fn try_drop(&mut self) -> Result<(), CUresult> {
for stream in self.streams.iter() {
let stream = unsafe { &mut **stream };
stream.context = ptr::null_mut();
Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?;
}
Ok(())
}
} }
enum ContextRefCount { enum ContextRefCount {
@ -67,26 +75,16 @@ impl ContextRefCount {
} }
} }
} }
fn is_primary(&self) -> bool {
match self {
ContextRefCount::Primary => true,
ContextRefCount::NonPrimary(_) => false,
}
}
} }
pub struct ContextData { pub struct ContextData {
pub flags: AtomicU32, pub flags: AtomicU32,
pub device_index: device::Index,
// This pointer is null only for a moment when constructing primary context // This pointer is null only for a moment when constructing primary context
pub device: *const Mutex<device::Device>, pub device: *mut device::Device,
// The split between mutable / non-mutable is mainly to avoid recursive locking in cuDevicePrimaryCtxGetState
pub mutable: Mutex<ContextDataMutable>,
}
pub struct ContextDataMutable {
ref_count: ContextRefCount, ref_count: ContextRefCount,
pub default_stream: StreamData,
pub streams: HashSet<*mut StreamData>,
// All the fields below are here to support internal CUDA driver API
pub cuda_manager: *mut cuda_impl::rt::ContextStateManager, pub cuda_manager: *mut cuda_impl::rt::ContextStateManager,
pub cuda_state: *mut cuda_impl::rt::ContextState, pub cuda_state: *mut cuda_impl::rt::ContextState,
pub cuda_dtor_cb: Option< pub cuda_dtor_cb: Option<
@ -100,63 +98,75 @@ pub struct ContextDataMutable {
impl ContextData { impl ContextData {
pub fn new( pub fn new(
l0_ctx: &mut l0::Context,
l0_dev: &l0::Device,
flags: c_uint, flags: c_uint,
is_primary: bool, is_primary: bool,
dev_index: device::Index, dev: *mut device::Device,
dev: *const Mutex<device::Device>, ) -> Result<Self, CUresult> {
) -> Self { let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?;
ContextData { Ok(ContextData {
flags: AtomicU32::new(flags), flags: AtomicU32::new(flags),
device_index: dev_index,
device: dev, device: dev,
mutable: Mutex::new(ContextDataMutable {
ref_count: ContextRefCount::new(is_primary), ref_count: ContextRefCount::new(is_primary),
default_stream,
streams: HashSet::new(),
cuda_manager: ptr::null_mut(), cuda_manager: ptr::null_mut(),
cuda_state: ptr::null_mut(), cuda_state: ptr::null_mut(),
cuda_dtor_cb: None, cuda_dtor_cb: None,
}), })
}
} }
} }
pub fn create_v2(pctx: *mut *mut Context, flags: u32, dev_idx: device::Index) -> CUresult { impl Context {
pub fn late_init(&mut self) {
let ctx_data = self.as_option_mut().unwrap();
ctx_data.default_stream.context = ctx_data as *mut _;
}
}
pub fn create_v2(
pctx: *mut *mut Context,
flags: u32,
dev_idx: device::Index,
) -> Result<(), CUresult> {
if pctx == ptr::null_mut() { if pctx == ptr::null_mut() {
return CUresult::CUDA_ERROR_INVALID_VALUE; return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
let dev = device::get_device_ref(dev_idx); let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| {
let dev = match dev { let dev_ptr = dev as *mut _;
Ok(d) => d, let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
Err(e) => return e, &mut dev.l0_context,
}; &dev.base,
let mut ctx = Box::new(LiveCheck::new(ContextData::new(flags, false, dev_idx, dev))); flags,
let ctx_ref = ctx.as_mut() as *mut Context; false,
dev_ptr as *mut _,
)?));
ctx_box.late_init();
Ok::<_, CUresult>(ctx_box)
})??;
let ctx_ref = ctx_box.as_mut() as *mut Context;
unsafe { *pctx = ctx_ref }; unsafe { *pctx = ctx_ref };
mem::forget(ctx); mem::forget(ctx_box);
CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref)); CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref));
CUresult::CUDA_SUCCESS Ok(())
} }
pub fn destroy_v2(ctx: *mut Context) -> CUresult { pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> {
if ctx == ptr::null_mut() { if ctx == ptr::null_mut() {
return CUresult::CUDA_ERROR_INVALID_VALUE; return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
CONTEXT_STACK.with(|stack| { CONTEXT_STACK.with(|stack| {
let mut stack = stack.borrow_mut(); let mut stack = stack.borrow_mut();
let should_pop = match stack.last() { let should_pop = match stack.last() {
Some(active_ctx) => *active_ctx == (ctx as *const _), Some(active_ctx) => *active_ctx == (ctx as *mut _),
None => false, None => false,
}; };
if should_pop { if should_pop {
stack.pop(); stack.pop();
} }
}); });
let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(ctx) }); GlobalState::lock(|_| Context::destroy_impl(ctx))?
if !ctx_box.try_drop() {
CUresult::CUDA_ERROR_INVALID_CONTEXT
} else {
unsafe { ManuallyDrop::drop(&mut ctx_box) };
CUresult::CUDA_SUCCESS
}
} }
pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult { pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
@ -172,17 +182,6 @@ pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
CUresult::CUDA_SUCCESS CUresult::CUDA_SUCCESS
} }
pub fn with_current<F: FnOnce(&ContextData) -> R, R>(f: F) -> Result<R, CUresult> {
CONTEXT_STACK.with(|stack| {
stack
.borrow()
.last()
.and_then(|c| unsafe { &**c }.as_ref())
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
.map(f)
})
}
pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> { pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> {
if pctx == ptr::null_mut() { if pctx == ptr::null_mut() {
return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT); return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT);
@ -205,37 +204,53 @@ pub fn set_current(ctx: *mut Context) -> CUresult {
} }
} }
pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> CUresult { pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
let _ctx = match unsafe { ctx.as_mut() } { if ctx == ptr::null_mut() {
None => return CUresult::CUDA_ERROR_INVALID_VALUE, return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
Some(ctx) => match ctx.as_mut() { }
None => return CUresult::CUDA_ERROR_INVALID_CONTEXT, GlobalState::lock(|_| {
Some(ctx) => ctx, unsafe { &*ctx }.as_result()?;
}, Ok::<_, CUresult>(())
}; })??;
//TODO: query device for properties roughly matching CUDA API version //TODO: query device for properties roughly matching CUDA API version
unsafe { *version = 1100 }; unsafe { *version = 1100 };
CUresult::CUDA_SUCCESS Ok(())
} }
pub fn get_device(dev: *mut device::Index) -> CUresult { pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> {
let dev_idx = with_current(|ctx| ctx.device_index); let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?;
match dev_idx { unsafe { *dev = dev_idx };
Ok(idx) => { Ok(())
unsafe { *dev = idx }
CUresult::CUDA_SUCCESS
} }
Err(err) => err,
pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> {
if pctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
let ctx = unchecked_ctx.as_result_mut()?;
ctx.ref_count.incr()?;
Ok::<_, CUresult>(unchecked_ctx as *mut _)
})??;
unsafe { *pctx = ctx };
Ok(())
}
pub fn detach(pctx: *mut Context) -> Result<(), CUresult> {
if pctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
let ctx = unchecked_ctx.as_result_mut()?;
if ctx.ref_count.decr() {
Context::destroy_impl(unchecked_ctx)?;
}
Ok::<_, CUresult>(())
})?
} }
#[cfg(test)] #[cfg(test)]
pub fn is_context_stack_empty() -> bool { mod test {
CONTEXT_STACK.with(|stack| stack.borrow().is_empty())
}
#[cfg(test)]
mod tests {
use super::super::test::CudaDriverFns; use super::super::test::CudaDriverFns;
use super::super::CUresult; use super::super::CUresult;
use std::{ffi::c_void, ptr}; use std::{ffi::c_void, ptr};

View file

@ -1,24 +1,21 @@
use super::{context, transmute_lifetime, CUresult, Error}; use super::{context, CUresult, GlobalState};
use crate::cuda; use crate::cuda;
use cuda::{CUdevice_attribute, CUuuid_st}; use cuda::{CUdevice_attribute, CUuuid_st};
use std::{ use std::{
cmp, mem, cmp, mem,
os::raw::{c_char, c_int}, os::raw::{c_char, c_int},
ptr, ptr,
sync::{ sync::atomic::{AtomicU32, Ordering},
atomic::{AtomicU32, Ordering},
Mutex, MutexGuard,
},
}; };
const PROJECT_URL_SUFFIX: &'static str = " [github.com/vosen/notCUDA]"; const PROJECT_URL_SUFFIX: &'static str = " [github.com/vosen/notCUDA]";
static mut DEVICES: Option<Vec<Mutex<Device>>> = None;
#[repr(transparent)] #[repr(transparent)]
#[derive(Clone, Copy)] #[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub struct Index(pub c_int); pub struct Index(pub c_int);
pub struct Device { pub struct Device {
pub index: Index,
pub base: l0::Device, pub base: l0::Device,
pub default_queue: l0::CommandQueue, pub default_queue: l0::CommandQueue,
pub l0_context: l0::Context, pub l0_context: l0::Context,
@ -33,17 +30,19 @@ unsafe impl Send for Device {}
impl Device { impl Device {
// Unsafe because it does not fully initalize primary_context // Unsafe because it does not fully initalize primary_context
unsafe fn new(drv: &l0::Driver, d: l0::Device, idx: usize) -> l0::Result<Self> { unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
let mut ctx = l0::Context::new(drv)?; let mut ctx = l0::Context::new(drv)?;
let queue = l0::CommandQueue::new(&mut ctx, &d)?; let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?;
let primary_context = context::Context::new(context::ContextData::new( let primary_context = context::Context::new(context::ContextData::new(
&mut ctx,
&l0_dev,
0, 0,
true, true,
Index(idx as c_int), ptr::null_mut(),
ptr::null(), )?);
));
Ok(Self { Ok(Self {
base: d, index: Index(idx as c_int),
base: l0_dev,
default_queue: queue, default_queue: queue,
l0_context: ctx, l0_context: ctx,
primary_context: primary_context, primary_context: primary_context,
@ -93,83 +92,53 @@ impl Device {
Err(e) => Err(e), Err(e) => Err(e),
} }
} }
pub fn late_init(&mut self) {
self.primary_context.as_option_mut().unwrap().device = self as *mut _;
}
} }
pub fn init(driver: &l0::Driver) -> l0::Result<()> { pub fn init(driver: &l0::Driver) -> Result<Vec<Device>, CUresult> {
let ze_devices = driver.devices()?; let ze_devices = driver.devices()?;
let mut devices = ze_devices let mut devices = ze_devices
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(|(idx, d)| unsafe { Device::new(driver, d, idx) }.map(Mutex::new)) .map(|(idx, d)| unsafe { Device::new(driver, d, idx) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
for d in devices.iter_mut() { for dev in devices.iter_mut() {
d.get_mut() dev.late_init();
.unwrap() dev.primary_context.late_init();
.primary_context
.as_mut()
.unwrap()
.device = d;
} }
unsafe { DEVICES = Some(devices) }; Ok(devices)
}
pub fn get_count(count: *mut c_int) -> Result<(), CUresult> {
let len = GlobalState::lock(|state| state.devices.len())?;
unsafe { *count = len as c_int };
Ok(()) Ok(())
} }
fn devices() -> Result<&'static Vec<Mutex<Device>>, CUresult> { pub fn get(device: *mut Index, ordinal: c_int) -> Result<(), CUresult> {
match unsafe { &DEVICES } {
Some(devs) => Ok(devs),
None => Err(CUresult::CUDA_ERROR_NOT_INITIALIZED),
}
}
pub fn get_device_ref(Index(dev_idx): Index) -> Result<&'static Mutex<Device>, CUresult> {
let devs = devices()?;
if dev_idx < 0 || dev_idx >= devs.len() as c_int {
return Err(CUresult::CUDA_ERROR_INVALID_DEVICE);
}
Ok(&devs[dev_idx as usize])
}
pub fn get_device(dev_idx: Index) -> Result<MutexGuard<'static, Device>, CUresult> {
let dev = get_device_ref(dev_idx)?;
dev.lock().map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
}
pub fn get_count(count: *mut c_int) -> CUresult {
let len = devices().map(|d| d.len());
match len {
Ok(len) => {
unsafe { *count = len as c_int };
CUresult::CUDA_SUCCESS
}
Err(e) => e,
}
}
pub fn get(device: *mut Index, ordinal: c_int) -> CUresult {
if device == ptr::null_mut() || ordinal < 0 { if device == ptr::null_mut() || ordinal < 0 {
return CUresult::CUDA_ERROR_INVALID_VALUE; return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
let len = devices().map(|d| d.len()); let len = GlobalState::lock(|state| state.devices.len())?;
match len { if ordinal < (len as i32) {
Ok(len) if ordinal < (len as i32) => {
unsafe { *device = Index(ordinal) }; unsafe { *device = Index(ordinal) };
CUresult::CUDA_SUCCESS Ok(())
} } else {
Ok(_) => CUresult::CUDA_ERROR_INVALID_VALUE, Err(CUresult::CUDA_ERROR_INVALID_VALUE)
Err(e) => e,
} }
} }
pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult> { pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUresult> {
if name == ptr::null_mut() || len < 0 { if name == ptr::null_mut() || len < 0 {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE); return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
// This is safe because devices are 'static let name_ptr = GlobalState::lock_device(dev_idx, |dev| {
let name_ptr = { let props = dev.get_properties()?;
let mut dev = get_device(dev)?; Ok::<_, l0::sys::ze_result_t>(props.name.as_ptr())
let props = dev.get_properties().map_err(Into::<CUresult>::into)?; })??;
props.name.as_ptr()
};
let name_len = (0..256) let name_len = (0..256)
.position(|i| unsafe { *name_ptr.add(i) } == 0) .position(|i| unsafe { *name_ptr.add(i) } == 0)
.unwrap_or(256); .unwrap_or(256);
@ -189,20 +158,14 @@ pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult>
Ok(()) Ok(())
} }
pub fn total_mem_v2(bytes: *mut usize, dev: Index) -> Result<(), CUresult> { pub fn total_mem_v2(bytes: *mut usize, dev_idx: Index) -> Result<(), CUresult> {
if bytes == ptr::null_mut() { if bytes == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE); return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
// This is safe because devices are 'static let mem_props = GlobalState::lock_device(dev_idx, |dev| {
let mem_props = { let mem_props = dev.get_memory_properties()?;
let mut dev = get_device(dev)?; Ok::<_, l0::sys::ze_result_t>(mem_props)
unsafe { })??;
transmute_lifetime(
dev.get_memory_properties()
.map_err(Into::<CUresult>::into)?,
)
}
};
let max_mem = mem_props let max_mem = mem_props
.iter() .iter()
.map(|p| p.totalSize) .map(|p| p.totalSize)
@ -228,56 +191,101 @@ impl CUdevice_attribute {
} }
} }
pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Result<(), Error> { pub fn get_attribute(
pi: *mut i32,
attrib: CUdevice_attribute,
dev_idx: Index,
) -> Result<(), CUresult> {
if pi == ptr::null_mut() { if pi == ptr::null_mut() {
return Err(Error::Cuda(CUresult::CUDA_ERROR_INVALID_VALUE)); return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
if let Some(value) = attrib.get_static_value() { if let Some(value) = attrib.get_static_value() {
unsafe { *pi = value }; unsafe { *pi = value };
return Ok(()); return Ok(());
} }
let mut dev = get_device(dev).map_err(Error::Cuda)?;
let value = match attrib { let value = match attrib {
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => {
dev.get_properties().map_err(Error::L0)?.maxHardwareContexts as i32 GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_properties()?;
Ok::<_, l0::sys::ze_result_t>(props.maxHardwareContexts as i32)
})??
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT => {
let props = dev.get_properties().map_err(Error::L0)?; GlobalState::lock_device(dev_idx, |dev| {
(props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32 let props = dev.get_properties()?;
Ok::<_, l0::sys::ze_result_t>(
(props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32,
)
})??
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => cmp::min( CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => {
dev.get_image_properties() GlobalState::lock_device(dev_idx, |dev| {
.map_err(Error::L0)? let props = dev.get_image_properties()?;
.maxImageDims1D, Ok::<_, l0::sys::ze_result_t>(cmp::min(
props.maxImageDims1D,
c_int::max_value() as u32, c_int::max_value() as u32,
) as c_int, ) as c_int)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X => {
let props = dev.get_compute_properties().map_err(Error::L0)?; GlobalState::lock_device(dev_idx, |dev| {
cmp::max(i32::max_value() as u32, props.maxGroupCountX) as i32 let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(cmp::max(
i32::max_value() as u32,
props.maxGroupCountX,
) as i32)
})??
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y => {
let props = dev.get_compute_properties().map_err(Error::L0)?; GlobalState::lock_device(dev_idx, |dev| {
cmp::max(i32::max_value() as u32, props.maxGroupCountY) as i32 let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(cmp::max(
i32::max_value() as u32,
props.maxGroupCountY,
) as i32)
})??
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z => {
let props = dev.get_compute_properties().map_err(Error::L0)?; GlobalState::lock_device(dev_idx, |dev| {
cmp::max(i32::max_value() as u32, props.maxGroupCountZ) as i32 let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(cmp::max(
i32::max_value() as u32,
props.maxGroupCountZ,
) as i32)
})??
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X => {
let props = dev.get_compute_properties().map_err(Error::L0)?; GlobalState::lock_device(dev_idx, |dev| {
cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32 let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(
cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32,
)
})??
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y => {
let props = dev.get_compute_properties().map_err(Error::L0)?; GlobalState::lock_device(dev_idx, |dev| {
cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32 let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(
cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32,
)
})??
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z => {
let props = dev.get_compute_properties().map_err(Error::L0)?; GlobalState::lock_device(dev_idx, |dev| {
cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32 let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(
cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32,
)
})??
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
let props = dev.get_compute_properties().map_err(Error::L0)?; GlobalState::lock_device(dev_idx, |dev| {
cmp::max(i32::max_value() as u32, props.maxTotalGroupSize) as i32 let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(cmp::max(
i32::max_value() as u32,
props.maxTotalGroupSize,
) as i32)
})??
} }
_ => { _ => {
// TODO: support more attributes for CUDA runtime // TODO: support more attributes for CUDA runtime
@ -293,14 +301,11 @@ pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Re
Ok(()) Ok(())
} }
pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> { pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> {
let ze_uuid = { let ze_uuid = GlobalState::lock_device(dev_idx, |dev| {
get_device(dev) let props = dev.get_properties()?;
.map_err(Error::Cuda)? Ok::<_, l0::sys::ze_result_t>(props.uuid)
.get_properties() })??;
.map_err(Error::L0)?
.uuid
};
unsafe { unsafe {
*uuid = CUuuid_st { *uuid = CUuuid_st {
bytes: mem::transmute(ze_uuid.id), bytes: mem::transmute(ze_uuid.id),
@ -309,53 +314,39 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> {
Ok(()) Ok(())
} }
pub fn with_current_exclusive<F: FnOnce(&mut Device) -> R, R>(f: F) -> Result<R, CUresult> {
let dev = super::context::with_current(|ctx| ctx.device);
dev.and_then(|dev| {
unsafe { &*dev }
.try_lock()
.map(|mut dev| f(&mut dev))
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
})
}
pub fn with_exclusive<F: FnOnce(&mut Device) -> R, R>(dev: Index, f: F) -> Result<R, CUresult> {
let dev = get_device_ref(dev)?;
dev.try_lock()
.map(|mut dev| f(&mut dev))
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
}
pub fn primary_ctx_get_state( pub fn primary_ctx_get_state(
idx: Index, dev_idx: Index,
flags: *mut u32, flags: *mut u32,
active: *mut i32, active: *mut i32,
) -> Result<(), CUresult> { ) -> Result<(), CUresult> {
let (ctx_ptr, flags_ptr) = with_exclusive(idx, |dev| { let (is_active, flags_value) = GlobalState::lock_device(dev_idx, |dev| {
// This is safe because primary context can't be dropped // This is safe because primary context can't be dropped
let ctx_ptr = &dev.primary_context as *const _; let ctx_ptr = &mut dev.primary_context as *mut _;
let flags_ptr = let flags_ptr =
(&unsafe { dev.primary_context.as_ref_unchecked() }.flags) as *const AtomicU32; (&unsafe { dev.primary_context.as_ref_unchecked() }.flags) as *const AtomicU32;
(ctx_ptr, flags_ptr)
})?;
let is_active = context::CONTEXT_STACK let is_active = context::CONTEXT_STACK
.with(|stack| stack.borrow().last().map(|x| *x)) .with(|stack| stack.borrow().last().map(|x| *x))
.map(|current| current == ctx_ptr) .map(|current| current == ctx_ptr)
.unwrap_or(false); .unwrap_or(false);
let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed); let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed);
unsafe { *flags = flags_value }; Ok::<_, l0::sys::ze_result_t>((is_active, flags_value))
})??;
unsafe { *active = if is_active { 1 } else { 0 } }; unsafe { *active = if is_active { 1 } else { 0 } };
unsafe { *flags = flags_value };
Ok(()) Ok(())
} }
pub fn primary_ctx_retain(pctx: *mut *mut context::Context, dev: Index) -> Result<(), CUresult> { pub fn primary_ctx_retain(
let ctx_ptr = with_exclusive(dev, |dev| &mut dev.primary_context as *mut _)?; pctx: *mut *mut context::Context,
dev_idx: Index,
) -> Result<(), CUresult> {
let ctx_ptr = GlobalState::lock_device(dev_idx, |dev| &mut dev.primary_context as *mut _)?;
unsafe { *pctx = ctx_ptr }; unsafe { *pctx = ctx_ptr };
Ok(()) Ok(())
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod test {
use super::super::test::CudaDriverFns; use super::super::test::CudaDriverFns;
use super::super::CUresult; use super::super::CUresult;

View file

@ -4,7 +4,7 @@ use crate::{
cuda_impl, cuda_impl,
}; };
use super::{context, device, module, Decuda, Encuda}; use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
use std::mem; use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort}; use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{ use std::{
@ -110,17 +110,8 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [
VTableEntry { ptr: ptr::null() }, VTableEntry { ptr: ptr::null() },
]; ];
unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult { unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda() super::unimplemented()
}
fn cudart_interface_fn1_impl(
pctx: *mut *mut context::Context,
dev: device::Index,
) -> Result<(), CUresult> {
let ctx_ptr = device::with_exclusive(dev, |d| &mut d.primary_context as *mut context::Context)?;
unsafe { *pctx = ctx_ptr };
Ok(())
} }
/* /*
@ -234,10 +225,13 @@ unsafe extern "C" fn get_module_from_cubin(
}, },
Err(_) => continue, Err(_) => continue,
}; };
let module = module::ModuleData::compile_spirv(kernel_text_string); let module = module::SpirvModule::new(kernel_text_string);
match module { match module {
Ok(module) => { Ok(module) => {
*result = Box::into_raw(Box::new(module)); match module::load_data_impl(result, module) {
Ok(()) => {}
Err(err) => return err,
}
return CUresult::CUDA_SUCCESS; return CUresult::CUDA_SUCCESS;
} }
Err(_) => continue, Err(_) => continue,
@ -309,7 +303,7 @@ unsafe extern "C" fn context_local_storage_ctor(
} }
fn context_local_storage_ctor_impl( fn context_local_storage_ctor_impl(
mut cu_ctx: *mut context::Context, cu_ctx: *mut context::Context,
mgr: *mut cuda_impl::rt::ContextStateManager, mgr: *mut cuda_impl::rt::ContextStateManager,
ctx_state: *mut cuda_impl::rt::ContextState, ctx_state: *mut cuda_impl::rt::ContextState,
dtor_cb: Option< dtor_cb: Option<
@ -320,26 +314,11 @@ fn context_local_storage_ctor_impl(
), ),
>, >,
) -> Result<(), CUresult> { ) -> Result<(), CUresult> {
if cu_ctx == ptr::null_mut() { lock_context(cu_ctx, |ctx: &mut ContextData| {
context::get_current(&mut cu_ctx)?; ctx.cuda_manager = mgr;
} ctx.cuda_state = ctx_state;
if cu_ctx == ptr::null_mut() { ctx.cuda_dtor_cb = dtor_cb;
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
unsafe { &*cu_ctx }
.as_ref()
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
.and_then(|ctx| {
ctx.mutable
.try_lock()
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
.map(|mut mutable| {
mutable.cuda_manager = mgr;
mutable.cuda_state = ctx_state;
mutable.cuda_dtor_cb = dtor_cb;
}) })
})?;
Ok(())
} }
// some kind of dtor // some kind of dtor
@ -357,24 +336,10 @@ unsafe extern "C" fn context_local_storage_get_state(
fn context_local_storage_get_state_impl( fn context_local_storage_get_state_impl(
ctx_state: *mut *mut cuda_impl::rt::ContextState, ctx_state: *mut *mut cuda_impl::rt::ContextState,
mut cu_ctx: *mut context::Context, cu_ctx: *mut context::Context,
_: *mut cuda_impl::rt::ContextStateManager, _: *mut cuda_impl::rt::ContextStateManager,
) -> Result<(), CUresult> { ) -> Result<(), CUresult> {
if cu_ctx == ptr::null_mut() { let cuda_state = lock_context(cu_ctx, |ctx: &mut ContextData| ctx.cuda_state)?;
context::get_current(&mut cu_ctx)?;
}
if cu_ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let cuda_state = unsafe { &*cu_ctx }
.as_ref()
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
.and_then(|ctx| {
ctx.mutable
.try_lock()
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
.map(|mutable| mutable.cuda_state)
})?;
if cuda_state == ptr::null_mut() { if cuda_state == ptr::null_mut() {
Err(CUresult::CUDA_ERROR_INVALID_VALUE) Err(CUresult::CUDA_ERROR_INVALID_VALUE)
} else { } else {
@ -382,3 +347,17 @@ fn context_local_storage_get_state_impl(
Ok(()) Ok(())
} }
} }
fn lock_context<T>(
cu_ctx: *mut context::Context,
fn_impl: impl FnOnce(&mut ContextData) -> T,
) -> Result<T, CUresult> {
if cu_ctx == ptr::null_mut() {
GlobalState::lock_current_context(fn_impl)
} else {
GlobalState::lock(|_| {
let ctx = unsafe { &mut *cu_ctx }.as_result_mut()?;
Ok(fn_impl(ctx))
})?
}
}

View file

@ -1,11 +1,28 @@
use ::std::os::raw::{c_uint, c_void}; use ::std::os::raw::{c_uint, c_void};
use std::ptr; use std::ptr;
use super::{device, stream::Stream, CUresult}; use super::{CUresult, GlobalState, HasLivenessCookie, LiveCheck, stream::Stream};
pub struct Function { pub type Function = LiveCheck<FunctionData>;
impl HasLivenessCookie for FunctionData {
#[cfg(target_pointer_width = "64")]
const COOKIE: usize = 0x5e2ab14d5840678e;
#[cfg(target_pointer_width = "32")]
const COOKIE: usize = 0x33e6a1e6;
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
fn try_drop(&mut self) -> Result<(), CUresult> {
Ok(())
}
}
pub struct FunctionData {
pub base: l0::Kernel<'static>, pub base: l0::Kernel<'static>,
pub arg_size: Vec<usize>, pub arg_size: Vec<usize>,
pub use_shared_mem: bool,
} }
pub fn launch_kernel( pub fn launch_kernel(
@ -17,36 +34,43 @@ pub fn launch_kernel(
block_dim_y: c_uint, block_dim_y: c_uint,
block_dim_z: c_uint, block_dim_z: c_uint,
shared_mem_bytes: c_uint, shared_mem_bytes: c_uint,
strean: *mut Stream, hstream: *mut Stream,
kernel_params: *mut *mut c_void, kernel_params: *mut *mut c_void,
extra: *mut *mut c_void, extra: *mut *mut c_void,
) -> Result<(), CUresult> { ) -> Result<(), CUresult> {
if f == ptr::null_mut() { if f == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE); return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
if shared_mem_bytes != 0 || strean != ptr::null_mut() || extra != ptr::null_mut() { if extra != ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED); return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED);
} }
let func = unsafe { &*f }; GlobalState::lock_stream(hstream, |stream| {
for (i, arg_size) in func.arg_size.iter().copied().enumerate() { let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?;
for (i, arg_size) in func.arg_size.iter().enumerate() {
unsafe { unsafe {
func.base func.base
.set_arg_raw(i as u32, arg_size, *kernel_params.add(i))? .set_arg_raw(i as u32, *arg_size, *kernel_params.add(i))?
}; };
} }
unsafe { &*f } if func.use_shared_mem {
.base unsafe {
func.base.set_arg_raw(
func.arg_size.len() as u32,
shared_mem_bytes as usize,
ptr::null(),
)?
};
}
func.base
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?; .set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
device::with_current_exclusive(|dev| { let mut cmd_list = stream.command_list()?;
let mut cmd_list = l0::CommandList::new(&mut dev.l0_context, &dev.base)?;
cmd_list.append_launch_kernel( cmd_list.append_launch_kernel(
&unsafe { &*f }.base, &mut func.base,
&[grid_dim_x, grid_dim_y, grid_dim_z], &[grid_dim_x, grid_dim_y, grid_dim_z],
None, None,
&mut [], &mut [],
)?; )?;
dev.default_queue.execute(cmd_list)?; stream.queue.execute(cmd_list)?;
l0::Result::Ok(())
})??;
Ok(()) Ok(())
})?
} }

View file

@ -1,57 +1,34 @@
use super::CUresult; use super::{stream, CUresult, GlobalState};
use std::ffi::c_void; use std::ffi::c_void;
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult { pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
let alloc_result = super::device::with_current_exclusive(|dev| unsafe { let ptr = GlobalState::lock_current_context(|ctx| {
dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) let dev = unsafe { &mut *ctx.device };
}); Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
match alloc_result { })??;
Ok(Ok(alloc)) => { unsafe { *dptr = ptr };
unsafe { *dptr = alloc }; Ok(())
CUresult::CUDA_SUCCESS
}
Ok(Err(e)) => e.into(),
Err(e) => e,
}
} }
pub fn copy_v2( pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
dst: *mut c_void, GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
src: *const c_void, let mut cmd_list = stream.command_list()?;
bytesize: usize, unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
) -> Result<Result<(), l0::sys::ze_result_t>, CUresult> { stream.queue.execute(cmd_list)?;
super::device::with_current_exclusive(|dev| unsafe { Ok::<_, CUresult>(())
memcpy_impl( })?
&mut dev.l0_context, }
dst,
src, pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
bytesize, GlobalState::lock_current_context(|ctx| {
&dev.base, let dev = unsafe { &mut *ctx.device };
&mut dev.default_queue, Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
)
}) })
} .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
unsafe fn memcpy_impl(
ctx: &mut l0::Context,
dst: *mut c_void,
src: *const c_void,
bytes_count: usize,
dev: &l0::Device,
queue: &mut l0::CommandQueue,
) -> l0::Result<()> {
let mut cmd_list = l0::CommandList::new(ctx, &dev)?;
cmd_list.append_memory_copy_unsafe(dst, src, bytes_count, None, &mut [])?;
queue.execute(cmd_list)?;
Ok(())
}
pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> {
Ok(())
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod test {
use super::super::test::CudaDriverFns; use super::super::test::CudaDriverFns;
use super::super::CUresult; use super::super::CUresult;
use std::ptr; use std::ptr;
@ -82,4 +59,20 @@ mod tests {
assert_ne!(mem, ptr::null_mut()); assert_ne!(mem, ptr::null_mut());
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
} }
cuda_driver_test!(free_without_ctx);
fn free_without_ctx<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
let mut mem = ptr::null_mut();
assert_eq!(
T::cuMemAlloc_v2(&mut mem, std::mem::size_of::<usize>()),
CUresult::CUDA_SUCCESS
);
assert_ne!(mem, ptr::null_mut());
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
assert_eq!(T::cuMemFree_v2(mem), CUresult::CUDA_ERROR_INVALID_VALUE);
}
} }

View file

@ -1,5 +1,15 @@
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st}; use crate::{
use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex}; cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st},
r#impl::device::Device,
};
use std::{
ffi::c_void,
mem::{self, ManuallyDrop},
os::raw::c_int,
ptr,
sync::Mutex,
sync::TryLockError,
};
#[cfg(test)] #[cfg(test)]
#[macro_use] #[macro_use]
@ -7,9 +17,9 @@ pub mod test;
pub mod context; pub mod context;
pub mod device; pub mod device;
pub mod export_table; pub mod export_table;
pub mod function;
pub mod memory; pub mod memory;
pub mod module; pub mod module;
pub mod function;
pub mod stream; pub mod stream;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -22,8 +32,11 @@ pub fn unimplemented() -> CUresult {
CUresult::CUDA_ERROR_NOT_SUPPORTED CUresult::CUDA_ERROR_NOT_SUPPORTED
} }
pub trait HasLivenessCookie { pub trait HasLivenessCookie: Sized {
const COOKIE: usize; const COOKIE: usize;
const LIVENESS_FAIL: CUresult;
fn try_drop(&mut self) -> Result<(), CUresult>;
} }
// This struct is a best-effort check if wrapped value has been dropped, // This struct is a best-effort check if wrapped value has been dropped,
@ -42,19 +55,23 @@ impl<T: HasLivenessCookie> LiveCheck<T> {
} }
} }
fn destroy_impl(this: *mut Self) -> Result<(), CUresult> {
let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) });
ctx_box.try_drop()?;
unsafe { ManuallyDrop::drop(&mut ctx_box) };
Ok(())
}
unsafe fn ptr_from_inner(this: *mut T) -> *mut Self {
let outer_ptr = (this as *mut u8).sub(mem::size_of::<usize>());
outer_ptr as *mut Self
}
pub unsafe fn as_ref_unchecked(&self) -> &T { pub unsafe fn as_ref_unchecked(&self) -> &T {
&self.data &self.data
} }
pub fn as_ref(&self) -> Option<&T> { pub fn as_option_mut(&mut self) -> Option<&mut T> {
if self.cookie == T::COOKIE {
Some(&self.data)
} else {
None
}
}
pub fn as_mut(&mut self) -> Option<&mut T> {
if self.cookie == T::COOKIE { if self.cookie == T::COOKIE {
Some(&mut self.data) Some(&mut self.data)
} else { } else {
@ -62,14 +79,31 @@ impl<T: HasLivenessCookie> LiveCheck<T> {
} }
} }
pub fn as_result(&self) -> Result<&T, CUresult> {
if self.cookie == T::COOKIE {
Ok(&self.data)
} else {
Err(T::LIVENESS_FAIL)
}
}
pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> {
if self.cookie == T::COOKIE {
Ok(&mut self.data)
} else {
Err(T::LIVENESS_FAIL)
}
}
#[must_use] #[must_use]
pub fn try_drop(&mut self) -> bool { pub fn try_drop(&mut self) -> Result<(), CUresult> {
if self.cookie == T::COOKIE { if self.cookie == T::COOKIE {
self.cookie = 0; self.cookie = 0;
self.data.try_drop()?;
unsafe { ManuallyDrop::drop(&mut self.data) }; unsafe { ManuallyDrop::drop(&mut self.data) };
return true; return Ok(());
} }
false Err(T::LIVENESS_FAIL)
} }
} }
@ -121,6 +155,12 @@ impl From<l0::sys::ze_result_t> for CUresult {
} }
} }
impl<T> From<TryLockError<T>> for CUresult {
fn from(_: TryLockError<T>) -> Self {
CUresult::CUDA_ERROR_ILLEGAL_STATE
}
}
pub trait Encuda { pub trait Encuda {
type To: Sized; type To: Sized;
fn encuda(self: Self) -> Self::To; fn encuda(self: Self) -> Self::To;
@ -157,58 +197,103 @@ impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1,
} }
} }
pub enum Error {
L0(l0::sys::ze_result_t),
Cuda(CUresult),
}
impl Encuda for Error {
type To = CUresult;
fn encuda(self: Self) -> Self::To {
match self {
Error::L0(e) => e.into(),
Error::Cuda(e) => e,
}
}
}
lazy_static! { lazy_static! {
static ref GLOBAL_STATE: Mutex<Option<GlobalState>> = Mutex::new(None); static ref GLOBAL_STATE: Mutex<Option<GlobalState>> = Mutex::new(None);
} }
struct GlobalState { struct GlobalState {
driver: l0::Driver, devices: Vec<Device>,
} }
unsafe impl Send for GlobalState {} unsafe impl Send for GlobalState {}
impl GlobalState {
fn lock<T>(f: impl FnOnce(&mut GlobalState) -> T) -> Result<T, CUresult> {
let mut mutex = GLOBAL_STATE
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let global_state = mutex.as_mut().ok_or(CUresult::CUDA_ERROR_ILLEGAL_STATE)?;
Ok(f(global_state))
}
fn lock_device<T>(
device::Index(dev_idx): device::Index,
f: impl FnOnce(&'static mut device::Device) -> T,
) -> Result<T, CUresult> {
if dev_idx < 0 {
return Err(CUresult::CUDA_ERROR_INVALID_DEVICE);
}
Self::lock(|global_state| {
if dev_idx >= global_state.devices.len() as c_int {
Err(CUresult::CUDA_ERROR_INVALID_DEVICE)
} else {
Ok(f(unsafe {
transmute_lifetime_mut(&mut global_state.devices[dev_idx as usize])
}))
}
})?
}
fn lock_current_context<F: FnOnce(&mut context::ContextData) -> R, R>(
f: F,
) -> Result<R, CUresult> {
Self::lock_current_context_unchecked(|ctx| Ok(f(ctx.as_result_mut()?)))?
}
fn lock_current_context_unchecked<F: FnOnce(&mut context::Context) -> R, R>(
f: F,
) -> Result<R, CUresult> {
context::CONTEXT_STACK.with(|stack| {
stack
.borrow_mut()
.last_mut()
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
.map(|ctx| GlobalState::lock(|_| f(unsafe { &mut **ctx })))?
})
}
fn lock_stream<T>(
stream: *mut stream::Stream,
f: impl FnOnce(&mut stream::StreamData) -> T,
) -> Result<T, CUresult> {
if stream == ptr::null_mut()
|| stream == stream::CU_STREAM_LEGACY
|| stream == stream::CU_STREAM_PER_THREAD
{
Self::lock_current_context(|ctx| Ok(f(&mut ctx.default_stream)))?
} else {
Self::lock(|_| {
let stream = unsafe { &mut *stream }.as_result_mut()?;
Ok(f(stream))
})?
}
}
}
// TODO: implement // TODO: implement
fn is_intel_gpu_driver(_: &l0::Driver) -> bool { fn is_intel_gpu_driver(_: &l0::Driver) -> bool {
true true
} }
pub fn init() -> l0::Result<()> { pub fn init() -> Result<(), CUresult> {
let mut global_state = GLOBAL_STATE let mut global_state = GLOBAL_STATE
.lock() .lock()
.map_err(|_| l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?; .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
if global_state.is_some() { if global_state.is_some() {
return Ok(()); return Ok(());
} }
l0::init()?; l0::init()?;
let drivers = l0::Driver::get()?; let drivers = l0::Driver::get()?;
let driver = match drivers.into_iter().find(is_intel_gpu_driver) { let devices = match drivers.into_iter().find(is_intel_gpu_driver) {
None => return Err(l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), None => return Err(CUresult::CUDA_ERROR_UNKNOWN),
Some(driver) => { Some(driver) => device::init(&driver)?,
device::init(&driver)?;
driver
}
}; };
*global_state = Some(GlobalState { driver }); *global_state = Some(GlobalState { devices });
drop(global_state); drop(global_state);
Ok(()) Ok(())
} }
unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T { unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T {
mem::transmute(t) mem::transmute(t)
} }

View file

@ -1,79 +1,90 @@
use std::{ use std::{
collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex, collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem,
os::raw::c_char, ptr, slice,
}; };
use super::{function::Function, transmute_lifetime, CUresult}; use super::{
device, function::Function, function::FunctionData, CUresult, GlobalState, HasLivenessCookie,
LiveCheck,
};
use ptx; use ptx;
pub type Module = Mutex<ModuleData>; pub type Module = LiveCheck<ModuleData>;
pub struct ModuleData { impl HasLivenessCookie for ModuleData {
base: l0::Module, #[cfg(target_pointer_width = "64")]
arg_lens: HashMap<CString, Vec<usize>>, const COOKIE: usize = 0xf1313bd46505f98a;
}
pub enum ModuleCompileError<'a> { #[cfg(target_pointer_width = "32")]
Parse( const COOKIE: usize = 0xbdbe3f15;
Vec<ptx::ast::PtxError>,
Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
),
Compile(ptx::TranslateError),
L0(l0::sys::ze_result_t),
CUDA(CUresult),
}
impl<'a> ModuleCompileError<'a> { const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
pub fn get_build_log(&self) {}
}
impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> { fn try_drop(&mut self) -> Result<(), CUresult> {
fn from(err: ptx::TranslateError) -> Self { Ok(())
ModuleCompileError::Compile(err)
} }
} }
impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> { pub struct ModuleData {
fn from(err: l0::sys::ze_result_t) -> Self { pub spirv: SpirvModule,
ModuleCompileError::L0(err) // This should be a Vec<>, but I'm feeling lazy
pub device_binaries: HashMap<device::Index, CompiledModule>,
}
pub struct SpirvModule {
pub binaries: Vec<u32>,
pub kernel_info: HashMap<String, ptx::KernelInfo>,
pub should_link_ptx_impl: Option<&'static [u8]>,
pub build_options: CString,
}
pub struct CompiledModule {
pub base: l0::Module,
pub kernels: HashMap<CString, Box<Function>>,
}
impl<L, T, E> From<ptx::ParseError<L, T, E>> for CUresult {
fn from(_: ptx::ParseError<L, T, E>) -> Self {
CUresult::CUDA_ERROR_INVALID_PTX
} }
} }
impl<'a> From<CUresult> for ModuleCompileError<'a> { impl From<ptx::TranslateError> for CUresult {
fn from(err: CUresult) -> Self { fn from(_: ptx::TranslateError) -> Self {
ModuleCompileError::CUDA(err) CUresult::CUDA_ERROR_INVALID_PTX
} }
} }
impl ModuleData { impl SpirvModule {
pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> { pub fn new_raw<'a>(text: *const c_char) -> Result<Self, CUresult> {
let u8_text = unsafe { CStr::from_ptr(text) };
let ptx_text = u8_text
.to_str()
.map_err(|_| CUresult::CUDA_ERROR_INVALID_PTX)?;
Self::new(ptx_text)
}
pub fn new<'a>(ptx_text: &str) -> Result<Self, CUresult> {
let mut errors = Vec::new(); let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text); let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
let ast = match ast { let spirv_module = ptx::to_spirv_module(ast)?;
Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))), Ok(SpirvModule {
Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)), binaries: spirv_module.assemble(),
Ok(ast) => ast, kernel_info: spirv_module.kernel_info,
}; should_link_ptx_impl: spirv_module.should_link_ptx_impl,
let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?; build_options: spirv_module.build_options,
})
}
pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> {
let byte_il = unsafe { let byte_il = unsafe {
slice::from_raw_parts::<u8>( slice::from_raw_parts(
spirv.as_ptr() as *const _, self.binaries.as_ptr() as *const u8,
spirv.len() * mem::size_of::<u32>(), self.binaries.len() * mem::size_of::<u32>(),
) )
}; };
let module = super::device::with_current_exclusive(|dev| { let l0_module = l0::Module::build_spirv(ctx, dev, byte_il, None).0?;
l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None) Ok(l0_module)
});
match module {
Ok((Ok(module), _)) => Ok(Mutex::new(Self {
base: module,
arg_lens: all_arg_lens
.into_iter()
.map(|(k, v)| (CString::new(k).unwrap(), v))
.collect(),
})),
Ok((Err(err), _)) => Err(ModuleCompileError::from(err)),
Err(err) => Err(ModuleCompileError::from(err)),
}
} }
} }
@ -85,34 +96,75 @@ pub fn get_function(
if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() { if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE); return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
let name = unsafe { CStr::from_ptr(name) }; let name = unsafe { CStr::from_ptr(name) }.to_owned();
let (mut kernel, args_len) = unsafe { &*hmod } let function: *mut Function = GlobalState::lock_current_context(|ctx| {
.try_lock() let module = unsafe { &mut *hmod }.as_result_mut()?;
.map(|module| { let device = unsafe { &mut *ctx.device };
Result::<_, CUresult>::Ok(( let compiled_module = match module.device_binaries.entry(device.index) {
l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?, hash_map::Entry::Occupied(entry) => entry.into_mut(),
module hash_map::Entry::Vacant(entry) => {
.arg_lens let new_module = CompiledModule {
.get(name) base: module.spirv.compile(&mut device.l0_context, &device.base)?,
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)? kernels: HashMap::new(),
.clone(),
))
})
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??;
kernel.set_indirect_access(
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED,
)?;
unsafe {
*hfunc = Box::into_raw(Box::new(Function {
base: kernel,
arg_size: args_len,
}))
}; };
entry.insert(new_module)
}
};
//let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) };
let kernel = match compiled_module.kernels.entry(name) {
hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(),
hash_map::Entry::Vacant(entry) => {
let kernel_info = module
.spirv
.kernel_info
.get(unsafe {
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
})
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
let kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
entry.insert(Box::new(Function::new(FunctionData {
base: kernel,
arg_size: kernel_info.arguments_sizes.clone(),
use_shared_mem: kernel_info.uses_shared_mem,
})))
}
};
Ok::<_, CUresult>(kernel as *mut _)
})??;
unsafe { *hfunc = function };
Ok(()) Ok(())
} }
pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> { pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<(), CUresult> {
let spirv_data = SpirvModule::new_raw(image as *const _)?;
load_data_impl(pmod, spirv_data)
}
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device };
let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?;
let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule {
base: l0_module,
kernels: HashMap::new(),
};
device_binaries.insert(device.index, compiled_module);
let module_data = ModuleData {
spirv: spirv_data,
device_binaries,
};
Ok::<_, CUresult>(module_data)
})??;
let module_ptr = Box::into_raw(Box::new(Module::new(module)));
unsafe { *pmod = module_ptr };
Ok(()) Ok(())
} }
pub(crate) fn unload(module: *mut Module) -> Result<(), CUresult> {
if module == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
GlobalState::lock(|_| Module::destroy_impl(module))?
}

View file

@ -1,36 +1,114 @@
use std::cell::RefCell; use super::{
context::{Context, ContextData},
CUresult, GlobalState,
};
use std::{mem, ptr};
use device::Device; use super::{HasLivenessCookie, LiveCheck};
use super::device; pub type Stream = LiveCheck<StreamData>;
pub struct Stream { pub const CU_STREAM_LEGACY: *mut Stream = 1 as *mut _;
dev: *mut Device, pub const CU_STREAM_PER_THREAD: *mut Stream = 2 as *mut _;
}
pub struct DefaultStream { impl HasLivenessCookie for StreamData {
streams: Vec<Option<Stream>>, #[cfg(target_pointer_width = "64")]
} const COOKIE: usize = 0x512097354de18d35;
impl DefaultStream { #[cfg(target_pointer_width = "32")]
fn new() -> Self { const COOKIE: usize = 0x77d5cc0b;
DefaultStream {
streams: Vec::new(), const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
fn try_drop(&mut self) -> Result<(), CUresult> {
if self.context != ptr::null_mut() {
let context = unsafe { &mut *self.context };
if !context.streams.remove(&(self as *mut _)) {
return Err(CUresult::CUDA_ERROR_UNKNOWN);
} }
} }
Ok(())
}
} }
thread_local! { pub struct StreamData {
pub static DEFAULT_STREAM: RefCell<DefaultStream> = RefCell::new(DefaultStream::new()); pub context: *mut ContextData,
pub queue: l0::CommandQueue,
}
impl StreamData {
pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> {
Ok(StreamData {
context: ptr::null_mut(),
queue: l0::CommandQueue::new(ctx, dev)?,
})
}
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
let l0_dev = &unsafe { &*ctx.device }.base;
Ok(StreamData {
context: ctx as *mut _,
queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
})
}
pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> {
let ctx = unsafe { &mut *self.context };
let dev = unsafe { &mut *ctx.device };
l0::CommandList::new(&mut dev.l0_context, &dev.base)
}
}
impl Drop for StreamData {
fn drop(&mut self) {
if self.context == ptr::null_mut() {
return;
}
unsafe { (&mut *self.context).streams.remove(&(&mut *self as *mut _)) };
}
}
pub(crate) fn get_ctx(hstream: *mut Stream, pctx: *mut *mut Context) -> Result<(), CUresult> {
if pctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let ctx_ptr = GlobalState::lock_stream(hstream, |stream| stream.context)?;
if ctx_ptr == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED);
}
unsafe { *pctx = Context::ptr_from_inner(ctx_ptr) };
Ok(())
}
pub(crate) fn create(phstream: *mut *mut Stream, _flags: u32) -> Result<(), CUresult> {
let stream_ptr = GlobalState::lock_current_context(|ctx| {
let mut stream_box = Box::new(Stream::new(StreamData::new(ctx)?));
let stream_ptr = stream_box.as_mut().as_option_mut().unwrap() as *mut _;
if !ctx.streams.insert(stream_ptr) {
return Err(CUresult::CUDA_ERROR_UNKNOWN);
}
mem::forget(stream_box);
Ok::<_, CUresult>(stream_ptr)
})??;
unsafe { *phstream = Stream::ptr_from_inner(stream_ptr) };
Ok(())
}
pub(crate) fn destroy_v2(pstream: *mut Stream) -> Result<(), CUresult> {
if pstream == ptr::null_mut() || pstream == CU_STREAM_LEGACY || pstream == CU_STREAM_PER_THREAD
{
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
GlobalState::lock(|_| Stream::destroy_impl(pstream))?
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod test {
use crate::cuda::CUstream; use crate::cuda::CUstream;
use super::super::test::CudaDriverFns; use super::super::test::CudaDriverFns;
use super::super::CUresult; use super::super::CUresult;
use std::ptr; use std::{ptr, thread};
const CU_STREAM_LEGACY: CUstream = 1 as *mut _; const CU_STREAM_LEGACY: CUstream = 1 as *mut _;
const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _; const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _;
@ -65,5 +143,100 @@ mod tests {
CUresult::CUDA_SUCCESS CUresult::CUDA_SUCCESS
); );
assert_eq!(ctx2, stream_ctx2); assert_eq!(ctx2, stream_ctx2);
// Cleanup
assert_eq!(T::cuCtxDestroy_v2(ctx1), CUresult::CUDA_SUCCESS);
assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS);
}
cuda_driver_test!(stream_context_destroyed);
fn stream_context_destroyed<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
let mut stream = ptr::null_mut();
assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
let mut stream_ctx1 = ptr::null_mut();
assert_eq!(
T::cuStreamGetCtx(stream, &mut stream_ctx1),
CUresult::CUDA_SUCCESS
);
assert_eq!(stream_ctx1, ctx);
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
let mut stream_ctx2 = ptr::null_mut();
// When a context gets destroyed, its streams are also destroyed
let cuda_result = T::cuStreamGetCtx(stream, &mut stream_ctx2);
assert!(
cuda_result == CUresult::CUDA_ERROR_INVALID_HANDLE
|| cuda_result == CUresult::CUDA_ERROR_INVALID_CONTEXT
|| cuda_result == CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED
);
assert_eq!(
T::cuStreamDestroy_v2(stream),
CUresult::CUDA_ERROR_INVALID_HANDLE
);
// Check if creating another context is possible
let mut ctx2 = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS);
// Cleanup
assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS);
}
cuda_driver_test!(stream_moves_context_to_another_thread);
fn stream_moves_context_to_another_thread<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
let mut stream = ptr::null_mut();
assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
let mut stream_ctx1 = ptr::null_mut();
assert_eq!(
T::cuStreamGetCtx(stream, &mut stream_ctx1),
CUresult::CUDA_SUCCESS
);
assert_eq!(stream_ctx1, ctx);
let stream_ptr = stream as usize;
let stream_ctx_on_thread = thread::spawn(move || {
let mut stream_ctx2 = ptr::null_mut();
assert_eq!(
T::cuStreamGetCtx(stream_ptr as *mut _, &mut stream_ctx2),
CUresult::CUDA_SUCCESS
);
stream_ctx2 as usize
})
.join()
.unwrap();
assert_eq!(stream_ctx1, stream_ctx_on_thread as *mut _);
// Cleanup
assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS);
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
}
cuda_driver_test!(can_destroy_stream);
fn can_destroy_stream<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
let mut stream = ptr::null_mut();
assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS);
// Cleanup
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
}
cuda_driver_test!(cant_destroy_default_stream);
fn cant_destroy_default_stream<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
assert_ne!(
T::cuStreamDestroy_v2(super::CU_STREAM_LEGACY as *mut _),
CUresult::CUDA_SUCCESS
);
// Cleanup
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
} }
} }

View file

@ -1,8 +1,12 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
use crate::{cuda::CUstream, r#impl as notcuda}; use crate::cuda as notcuda;
use crate::r#impl::CUresult; use crate::cuda::CUstream;
use crate::{cuda::CUuuid, r#impl::Encuda}; use crate::cuda::CUuuid;
use crate::{
cuda::{CUdevice, CUdeviceptr},
r#impl::CUresult,
};
use ::std::{ use ::std::{
ffi::c_void, ffi::c_void,
os::raw::{c_int, c_uint}, os::raw::{c_int, c_uint},
@ -37,48 +41,63 @@ pub trait CudaDriverFns {
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult; fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult; fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult; fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult;
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult;
fn cuMemFree_v2(mem: *mut c_void) -> CUresult;
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult;
} }
pub struct NotCuda(); pub struct NotCuda();
impl CudaDriverFns for NotCuda { impl CudaDriverFns for NotCuda {
fn cuInit(_flags: c_uint) -> CUresult { fn cuInit(_flags: c_uint) -> CUresult {
crate::cuda::cuInit(_flags as _) notcuda::cuInit(_flags as _)
} }
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult { fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
notcuda::context::create_v2(pctx as *mut _, flags, notcuda::device::Index(dev)).encuda() notcuda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev))
} }
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult { fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
notcuda::context::destroy_v2(ctx as *mut _) notcuda::cuCtxDestroy_v2(ctx as *mut _)
} }
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult { fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
notcuda::context::pop_current_v2(pctx as *mut _) notcuda::cuCtxPopCurrent_v2(pctx as *mut _)
} }
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult { fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
notcuda::context::get_api_version(ctx as *mut _, version) notcuda::cuCtxGetApiVersion(ctx as *mut _, version)
} }
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult { fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
notcuda::context::get_current(pctx as *mut _).encuda() notcuda::cuCtxGetCurrent(pctx as *mut _)
} }
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult { fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
notcuda::memory::alloc_v2(dptr as *mut _, bytesize) notcuda::cuMemAlloc_v2(dptr as *mut _, bytesize)
} }
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult { fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
notcuda::device::get_uuid(uuid, notcuda::device::Index(dev)).encuda() notcuda::cuDeviceGetUuid(uuid, CUdevice(dev))
} }
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult { fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
notcuda::device::primary_ctx_get_state(notcuda::device::Index(dev), flags, active).encuda() notcuda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active)
} }
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
crate::cuda::cuStreamGetCtx(hStream, pctx as _) notcuda::cuStreamGetCtx(hStream, pctx as _)
}
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
notcuda::cuStreamCreate(stream, flags)
}
fn cuMemFree_v2(dptr: *mut c_void) -> CUresult {
notcuda::cuMemFree_v2(CUdeviceptr(dptr as _))
}
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
notcuda::cuStreamDestroy_v2(stream)
} }
} }
@ -123,4 +142,16 @@ impl CudaDriverFns for Cuda {
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) } unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) }
} }
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) }
}
fn cuMemFree_v2(mem: *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) }
}
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) }
}
} }

View file

@ -34,8 +34,9 @@ pub use crate::ptx::ModuleParser;
pub use lalrpop_util::lexer::Token; pub use lalrpop_util::lexer::Token;
pub use lalrpop_util::ParseError; pub use lalrpop_util::ParseError;
pub use rspirv::dr::Error as SpirvError; pub use rspirv::dr::Error as SpirvError;
pub use translate::TranslateError as TranslateError; pub use translate::to_spirv_module;
pub use translate::to_spirv; pub use translate::KernelInfo;
pub use translate::TranslateError;
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> { pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
x.into_iter().filter_map(|x| x).collect() x.into_iter().filter_map(|x| x).collect()

View file

@ -12,7 +12,7 @@ fn parse_and_assert(s: &str) {
fn compile_and_assert(s: &str) -> Result<(), TranslateError> { fn compile_and_assert(s: &str) -> Result<(), TranslateError> {
let mut errors = Vec::new(); let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
crate::to_spirv(ast)?; crate::to_spirv_module(ast)?;
Ok(()) Ok(())
} }

View file

@ -1,7 +1,7 @@
use crate::ast; use crate::ast;
use half::f16; use half::f16;
use rspirv::{binary::Disassemble, dr}; use rspirv::{binary::Disassemble, dr};
use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem}; use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
use std::{ use std::{
collections::{hash_map, HashMap, HashSet}, collections::{hash_map, HashMap, HashSet},
convert::TryInto, convert::TryInto,
@ -450,6 +450,11 @@ pub struct Module {
pub should_link_ptx_impl: Option<&'static [u8]>, pub should_link_ptx_impl: Option<&'static [u8]>,
pub build_options: CString, pub build_options: CString,
} }
impl Module {
pub fn assemble(&self) -> Vec<u32> {
self.spirv.assemble()
}
}
pub struct KernelInfo { pub struct KernelInfo {
pub arguments_sizes: Vec<usize>, pub arguments_sizes: Vec<usize>,
@ -1046,8 +1051,12 @@ fn emit_function_header<'a>(
kernel_info: &mut HashMap<String, KernelInfo>, kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
if let MethodName::Kernel(name) = func_decl.name { if let MethodName::Kernel(name) = func_decl.name {
let args_lens = func_decl let input_args = if !func_decl.uses_shared_mem {
.input func_decl.input.as_slice()
} else {
&func_decl.input[0..func_decl.input.len() - 1]
};
let args_lens = input_args
.iter() .iter()
.map(|param| param.v_type.size_of()) .map(|param| param.v_type.size_of())
.collect(); .collect();
@ -1135,21 +1144,6 @@ fn emit_function_header<'a>(
Ok(()) Ok(())
} }
pub fn to_spirv<'a>(
ast: ast::Module<'a>,
) -> Result<(Option<&'static [u8]>, Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
let module = to_spirv_module(ast)?;
Ok((
module.should_link_ptx_impl,
module.spirv.assemble(),
module
.kernel_info
.into_iter()
.map(|(k, v)| (k, v.arguments_sizes))
.collect(),
))
}
fn emit_capabilities(builder: &mut dr::Builder) { fn emit_capabilities(builder: &mut dr::Builder) {
builder.capability(spirv::Capability::GenericPointer); builder.capability(spirv::Capability::GenericPointer);
builder.capability(spirv::Capability::Linkage); builder.capability(spirv::Capability::Linkage);