diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index f8a2c3b..4267682 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -173,6 +173,16 @@ impl Context { check!(sys::zeContextCreate(drv.0, &ctx_desc, &mut result)); Ok(Context(result)) } + + pub unsafe fn mem_free(&mut self, ptr: *mut c_void) -> Result<()> { + check! { + sys::zeMemFree( + self.0, + ptr, + ) + }; + Ok(()) + } } impl Drop for Context { @@ -239,7 +249,7 @@ pub struct Module(sys::ze_module_handle_t); impl Module { // HACK ALERT - // We use OpenCL for now to do SPIR-V linking, because Level0 + // We use OpenCL for now to do SPIR-V linking, because Level0 // does not allow linking. Don't let presence of zeModuleDynamicLink fool // you, it's not currently possible to create non-compiled modules. // zeModuleCreate always compiles (builds and links). diff --git a/notcuda/build.rs b/notcuda/build.rs new file mode 100644 index 0000000..3b8999f --- /dev/null +++ b/notcuda/build.rs @@ -0,0 +1,27 @@ +// HACK ALERT +// This buidl script has been copy-pasted from cl-sys to avoid CUDA libraries +// overriding path to OpenCL + + fn main() { + if cfg!(windows) { + let known_sdk = [ + // E.g. "c:\Program Files (x86)\Intel\OpenCL SDK\lib\x86\" + ("INTELOCLSDKROOT", "x64", "x86"), + // E.g. "c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\Win32\" + ("CUDA_PATH", "x64", "Win32"), + // E.g. "C:\Program Files (x86)\AMD APP SDK\3.0\lib\x86\" + ("AMDAPPSDKROOT", "x86_64", "x86"), + ]; + + for info in known_sdk.iter() { + if let Ok(sdk) = std::env::var(info.0) { + let mut path = std::path::PathBuf::from(sdk); + path.push("lib"); + path.push(if cfg!(target_arch="x86_64") { info.1 } else { info.2 }); + println!("cargo:rustc-link-search=native={}", path.display()); + } + } + + println!("cargo:rustc-link-search=native=C:\\Program Files (x86)\\OCL_SDK_Light\\lib\\x86_64"); + } +} \ No newline at end of file diff --git a/notcuda/src/cuda.rs b/notcuda/src/cuda.rs index a18ebf9..335da4a 100644 --- a/notcuda/src/cuda.rs +++ b/notcuda/src/cuda.rs @@ -2210,12 +2210,12 @@ pub extern "C" fn cuDriverGetVersion(driverVersion: *mut ::std::os::raw::c_int) #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuDeviceGet(device: *mut CUdevice, ordinal: ::std::os::raw::c_int) -> CUresult { - r#impl::device::get(device.decuda(), ordinal) + r#impl::device::get(device.decuda(), ordinal).encuda() } #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuDeviceGetCount(count: *mut ::std::os::raw::c_int) -> CUresult { - r#impl::device::get_count(count) + r#impl::device::get_count(count).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2314,7 +2314,6 @@ pub extern "C" fn cuDevicePrimaryCtxReset(dev: CUdevice) -> CUresult { cuDevicePrimaryCtxReset_v2(dev) } - #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuDevicePrimaryCtxReset_v2(dev: CUdevice) -> CUresult { r#impl::unimplemented() @@ -2331,7 +2330,7 @@ pub extern "C" fn cuCtxCreate_v2( #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuCtxDestroy_v2(ctx: CUcontext) -> CUresult { - r#impl::context::destroy_v2(ctx.decuda()) + r#impl::context::destroy_v2(ctx.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2356,7 +2355,7 @@ pub extern "C" fn cuCtxGetCurrent(pctx: *mut CUcontext) -> CUresult { #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuCtxGetDevice(device: *mut CUdevice) -> CUresult { - r#impl::context::get_device(device.decuda()) + r#impl::context::get_device(device.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2404,7 +2403,7 @@ pub extern "C" fn cuCtxGetApiVersion( ctx: CUcontext, version: *mut ::std::os::raw::c_uint, ) -> CUresult { - r#impl::context::get_api_version(ctx.decuda(), version) + r#impl::context::get_api_version(ctx.decuda(), version).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2422,12 +2421,12 @@ pub extern "C" fn cuCtxResetPersistingL2Cache() -> CUresult { #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuCtxAttach(pctx: *mut CUcontext, flags: ::std::os::raw::c_uint) -> CUresult { - r#impl::unimplemented() + r#impl::context::attach(pctx.decuda(), flags).encuda() } #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuCtxDetach(ctx: CUcontext) -> CUresult { - r#impl::unimplemented() + r#impl::context::detach(ctx.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2443,7 +2442,7 @@ pub extern "C" fn cuModuleLoadData( module: *mut CUmodule, image: *const ::std::os::raw::c_void, ) -> CUresult { - r#impl::unimplemented() + r#impl::module::load_data(module.decuda(), image).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2564,7 +2563,7 @@ pub extern "C" fn cuMemGetInfo_v2(free: *mut usize, total: *mut usize) -> CUresu #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuMemAlloc_v2(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult { - r#impl::memory::alloc_v2(dptr.decuda(), bytesize) + r#impl::memory::alloc_v2(dptr.decuda(), bytesize).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -3281,7 +3280,7 @@ pub extern "C" fn cuStreamCreate( phStream: *mut CUstream, Flags: ::std::os::raw::c_uint, ) -> CUresult { - r#impl::unimplemented() + r#impl::stream::create(phStream.decuda(), Flags).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -3311,7 +3310,7 @@ pub extern "C" fn cuStreamGetFlags( #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuStreamGetCtx(hStream: CUstream, pctx: *mut CUcontext) -> CUresult { - r#impl::unimplemented() + r#impl::stream::get_ctx(hStream.decuda(), pctx.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -3390,7 +3389,7 @@ pub extern "C" fn cuStreamSynchronize(hStream: CUstream) -> CUresult { #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuStreamDestroy_v2(hStream: CUstream) -> CUresult { - r#impl::unimplemented() + r#impl::stream::destroy_v2(hStream.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] diff --git a/notcuda/src/impl/context.rs b/notcuda/src/impl/context.rs index 91d4460..9689ecf 100644 --- a/notcuda/src/impl/context.rs +++ b/notcuda/src/impl/context.rs @@ -1,18 +1,15 @@ -use super::CUresult; -use super::{device, HasLivenessCookie, LiveCheck}; +use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck}; +use super::{CUresult, GlobalState}; use crate::{cuda::CUcontext, cuda_impl}; use l0::sys::ze_result_t; -use std::mem::{self, ManuallyDrop}; +use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32}; use std::{ - cell::RefCell, - num::NonZeroU32, - os::raw::c_uint, - ptr, - sync::{atomic::AtomicU32, Mutex}, + collections::HashSet, + mem::{self}, }; thread_local! { - pub static CONTEXT_STACK: RefCell> = RefCell::new(Vec::new()); + pub static CONTEXT_STACK: RefCell> = RefCell::new(Vec::new()); } pub type Context = LiveCheck; @@ -23,6 +20,17 @@ impl HasLivenessCookie for ContextData { #[cfg(target_pointer_width = "32")] const COOKIE: usize = 0x0b643ffb; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT; + + fn try_drop(&mut self) -> Result<(), CUresult> { + for stream in self.streams.iter() { + let stream = unsafe { &mut **stream }; + stream.context = ptr::null_mut(); + Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?; + } + Ok(()) + } } enum ContextRefCount { @@ -67,26 +75,16 @@ impl ContextRefCount { } } } - - fn is_primary(&self) -> bool { - match self { - ContextRefCount::Primary => true, - ContextRefCount::NonPrimary(_) => false, - } - } } pub struct ContextData { pub flags: AtomicU32, - pub device_index: device::Index, // This pointer is null only for a moment when constructing primary context - pub device: *const Mutex, - // The split between mutable / non-mutable is mainly to avoid recursive locking in cuDevicePrimaryCtxGetState - pub mutable: Mutex, -} - -pub struct ContextDataMutable { + pub device: *mut device::Device, ref_count: ContextRefCount, + pub default_stream: StreamData, + pub streams: HashSet<*mut StreamData>, + // All the fields below are here to support internal CUDA driver API pub cuda_manager: *mut cuda_impl::rt::ContextStateManager, pub cuda_state: *mut cuda_impl::rt::ContextState, pub cuda_dtor_cb: Option< @@ -100,63 +98,75 @@ pub struct ContextDataMutable { impl ContextData { pub fn new( + l0_ctx: &mut l0::Context, + l0_dev: &l0::Device, flags: c_uint, is_primary: bool, - dev_index: device::Index, - dev: *const Mutex, - ) -> Self { - ContextData { + dev: *mut device::Device, + ) -> Result { + let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?; + Ok(ContextData { flags: AtomicU32::new(flags), - device_index: dev_index, device: dev, - mutable: Mutex::new(ContextDataMutable { - ref_count: ContextRefCount::new(is_primary), - cuda_manager: ptr::null_mut(), - cuda_state: ptr::null_mut(), - cuda_dtor_cb: None, - }), - } + ref_count: ContextRefCount::new(is_primary), + default_stream, + streams: HashSet::new(), + cuda_manager: ptr::null_mut(), + cuda_state: ptr::null_mut(), + cuda_dtor_cb: None, + }) } } -pub fn create_v2(pctx: *mut *mut Context, flags: u32, dev_idx: device::Index) -> CUresult { +impl Context { + pub fn late_init(&mut self) { + let ctx_data = self.as_option_mut().unwrap(); + ctx_data.default_stream.context = ctx_data as *mut _; + } +} + +pub fn create_v2( + pctx: *mut *mut Context, + flags: u32, + dev_idx: device::Index, +) -> Result<(), CUresult> { if pctx == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let dev = device::get_device_ref(dev_idx); - let dev = match dev { - Ok(d) => d, - Err(e) => return e, - }; - let mut ctx = Box::new(LiveCheck::new(ContextData::new(flags, false, dev_idx, dev))); - let ctx_ref = ctx.as_mut() as *mut Context; + let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| { + let dev_ptr = dev as *mut _; + let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( + &mut dev.l0_context, + &dev.base, + flags, + false, + dev_ptr as *mut _, + )?)); + ctx_box.late_init(); + Ok::<_, CUresult>(ctx_box) + })??; + let ctx_ref = ctx_box.as_mut() as *mut Context; unsafe { *pctx = ctx_ref }; - mem::forget(ctx); + mem::forget(ctx_box); CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref)); - CUresult::CUDA_SUCCESS + Ok(()) } -pub fn destroy_v2(ctx: *mut Context) -> CUresult { +pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> { if ctx == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } CONTEXT_STACK.with(|stack| { let mut stack = stack.borrow_mut(); let should_pop = match stack.last() { - Some(active_ctx) => *active_ctx == (ctx as *const _), + Some(active_ctx) => *active_ctx == (ctx as *mut _), None => false, }; if should_pop { stack.pop(); } }); - let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(ctx) }); - if !ctx_box.try_drop() { - CUresult::CUDA_ERROR_INVALID_CONTEXT - } else { - unsafe { ManuallyDrop::drop(&mut ctx_box) }; - CUresult::CUDA_SUCCESS - } + GlobalState::lock(|_| Context::destroy_impl(ctx))? } pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult { @@ -172,17 +182,6 @@ pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult { CUresult::CUDA_SUCCESS } -pub fn with_current R, R>(f: F) -> Result { - CONTEXT_STACK.with(|stack| { - stack - .borrow() - .last() - .and_then(|c| unsafe { &**c }.as_ref()) - .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) - .map(f) - }) -} - pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> { if pctx == ptr::null_mut() { return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT); @@ -205,37 +204,53 @@ pub fn set_current(ctx: *mut Context) -> CUresult { } } -pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> CUresult { - let _ctx = match unsafe { ctx.as_mut() } { - None => return CUresult::CUDA_ERROR_INVALID_VALUE, - Some(ctx) => match ctx.as_mut() { - None => return CUresult::CUDA_ERROR_INVALID_CONTEXT, - Some(ctx) => ctx, - }, - }; +pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> { + if ctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| { + unsafe { &*ctx }.as_result()?; + Ok::<_, CUresult>(()) + })??; //TODO: query device for properties roughly matching CUDA API version unsafe { *version = 1100 }; - CUresult::CUDA_SUCCESS + Ok(()) } -pub fn get_device(dev: *mut device::Index) -> CUresult { - let dev_idx = with_current(|ctx| ctx.device_index); - match dev_idx { - Ok(idx) => { - unsafe { *dev = idx } - CUresult::CUDA_SUCCESS - } - Err(err) => err, +pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> { + let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?; + unsafe { *dev = dev_idx }; + Ok(()) +} + +pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } + let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| { + let ctx = unchecked_ctx.as_result_mut()?; + ctx.ref_count.incr()?; + Ok::<_, CUresult>(unchecked_ctx as *mut _) + })??; + unsafe { *pctx = ctx }; + Ok(()) +} + +pub fn detach(pctx: *mut Context) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock_current_context_unchecked(|unchecked_ctx| { + let ctx = unchecked_ctx.as_result_mut()?; + if ctx.ref_count.decr() { + Context::destroy_impl(unchecked_ctx)?; + } + Ok::<_, CUresult>(()) + })? } #[cfg(test)] -pub fn is_context_stack_empty() -> bool { - CONTEXT_STACK.with(|stack| stack.borrow().is_empty()) -} - -#[cfg(test)] -mod tests { +mod test { use super::super::test::CudaDriverFns; use super::super::CUresult; use std::{ffi::c_void, ptr}; diff --git a/notcuda/src/impl/device.rs b/notcuda/src/impl/device.rs index d4859d3..b8d263d 100644 --- a/notcuda/src/impl/device.rs +++ b/notcuda/src/impl/device.rs @@ -1,24 +1,21 @@ -use super::{context, transmute_lifetime, CUresult, Error}; +use super::{context, CUresult, GlobalState}; use crate::cuda; use cuda::{CUdevice_attribute, CUuuid_st}; use std::{ cmp, mem, os::raw::{c_char, c_int}, ptr, - sync::{ - atomic::{AtomicU32, Ordering}, - Mutex, MutexGuard, - }, + sync::atomic::{AtomicU32, Ordering}, }; const PROJECT_URL_SUFFIX: &'static str = " [github.com/vosen/notCUDA]"; -static mut DEVICES: Option>> = None; #[repr(transparent)] -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Eq, PartialEq, Hash)] pub struct Index(pub c_int); pub struct Device { + pub index: Index, pub base: l0::Device, pub default_queue: l0::CommandQueue, pub l0_context: l0::Context, @@ -33,17 +30,19 @@ unsafe impl Send for Device {} impl Device { // Unsafe because it does not fully initalize primary_context - unsafe fn new(drv: &l0::Driver, d: l0::Device, idx: usize) -> l0::Result { + unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result { let mut ctx = l0::Context::new(drv)?; - let queue = l0::CommandQueue::new(&mut ctx, &d)?; + let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?; let primary_context = context::Context::new(context::ContextData::new( + &mut ctx, + &l0_dev, 0, true, - Index(idx as c_int), - ptr::null(), - )); + ptr::null_mut(), + )?); Ok(Self { - base: d, + index: Index(idx as c_int), + base: l0_dev, default_queue: queue, l0_context: ctx, primary_context: primary_context, @@ -93,83 +92,53 @@ impl Device { Err(e) => Err(e), } } + + pub fn late_init(&mut self) { + self.primary_context.as_option_mut().unwrap().device = self as *mut _; + } } -pub fn init(driver: &l0::Driver) -> l0::Result<()> { +pub fn init(driver: &l0::Driver) -> Result, CUresult> { let ze_devices = driver.devices()?; let mut devices = ze_devices .into_iter() .enumerate() - .map(|(idx, d)| unsafe { Device::new(driver, d, idx) }.map(Mutex::new)) + .map(|(idx, d)| unsafe { Device::new(driver, d, idx) }) .collect::, _>>()?; - for d in devices.iter_mut() { - d.get_mut() - .unwrap() - .primary_context - .as_mut() - .unwrap() - .device = d; + for dev in devices.iter_mut() { + dev.late_init(); + dev.primary_context.late_init(); } - unsafe { DEVICES = Some(devices) }; + Ok(devices) +} + +pub fn get_count(count: *mut c_int) -> Result<(), CUresult> { + let len = GlobalState::lock(|state| state.devices.len())?; + unsafe { *count = len as c_int }; Ok(()) } -fn devices() -> Result<&'static Vec>, CUresult> { - match unsafe { &DEVICES } { - Some(devs) => Ok(devs), - None => Err(CUresult::CUDA_ERROR_NOT_INITIALIZED), - } -} - -pub fn get_device_ref(Index(dev_idx): Index) -> Result<&'static Mutex, CUresult> { - let devs = devices()?; - if dev_idx < 0 || dev_idx >= devs.len() as c_int { - return Err(CUresult::CUDA_ERROR_INVALID_DEVICE); - } - Ok(&devs[dev_idx as usize]) -} - -pub fn get_device(dev_idx: Index) -> Result, CUresult> { - let dev = get_device_ref(dev_idx)?; - dev.lock().map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) -} - -pub fn get_count(count: *mut c_int) -> CUresult { - let len = devices().map(|d| d.len()); - match len { - Ok(len) => { - unsafe { *count = len as c_int }; - CUresult::CUDA_SUCCESS - } - Err(e) => e, - } -} - -pub fn get(device: *mut Index, ordinal: c_int) -> CUresult { +pub fn get(device: *mut Index, ordinal: c_int) -> Result<(), CUresult> { if device == ptr::null_mut() || ordinal < 0 { - return CUresult::CUDA_ERROR_INVALID_VALUE; + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let len = devices().map(|d| d.len()); - match len { - Ok(len) if ordinal < (len as i32) => { - unsafe { *device = Index(ordinal) }; - CUresult::CUDA_SUCCESS - } - Ok(_) => CUresult::CUDA_ERROR_INVALID_VALUE, - Err(e) => e, + let len = GlobalState::lock(|state| state.devices.len())?; + if ordinal < (len as i32) { + unsafe { *device = Index(ordinal) }; + Ok(()) + } else { + Err(CUresult::CUDA_ERROR_INVALID_VALUE) } } -pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult> { +pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUresult> { if name == ptr::null_mut() || len < 0 { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - // This is safe because devices are 'static - let name_ptr = { - let mut dev = get_device(dev)?; - let props = dev.get_properties().map_err(Into::::into)?; - props.name.as_ptr() - }; + let name_ptr = GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_properties()?; + Ok::<_, l0::sys::ze_result_t>(props.name.as_ptr()) + })??; let name_len = (0..256) .position(|i| unsafe { *name_ptr.add(i) } == 0) .unwrap_or(256); @@ -189,20 +158,14 @@ pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult> Ok(()) } -pub fn total_mem_v2(bytes: *mut usize, dev: Index) -> Result<(), CUresult> { +pub fn total_mem_v2(bytes: *mut usize, dev_idx: Index) -> Result<(), CUresult> { if bytes == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - // This is safe because devices are 'static - let mem_props = { - let mut dev = get_device(dev)?; - unsafe { - transmute_lifetime( - dev.get_memory_properties() - .map_err(Into::::into)?, - ) - } - }; + let mem_props = GlobalState::lock_device(dev_idx, |dev| { + let mem_props = dev.get_memory_properties()?; + Ok::<_, l0::sys::ze_result_t>(mem_props) + })??; let max_mem = mem_props .iter() .map(|p| p.totalSize) @@ -228,56 +191,101 @@ impl CUdevice_attribute { } } -pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Result<(), Error> { +pub fn get_attribute( + pi: *mut i32, + attrib: CUdevice_attribute, + dev_idx: Index, +) -> Result<(), CUresult> { if pi == ptr::null_mut() { - return Err(Error::Cuda(CUresult::CUDA_ERROR_INVALID_VALUE)); + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } if let Some(value) = attrib.get_static_value() { unsafe { *pi = value }; return Ok(()); } - let mut dev = get_device(dev).map_err(Error::Cuda)?; let value = match attrib { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => { - dev.get_properties().map_err(Error::L0)?.maxHardwareContexts as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_properties()?; + Ok::<_, l0::sys::ze_result_t>(props.maxHardwareContexts as i32) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT => { - let props = dev.get_properties().map_err(Error::L0)?; - (props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_properties()?; + Ok::<_, l0::sys::ze_result_t>( + (props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32, + ) + })?? + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => { + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_image_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::min( + props.maxImageDims1D, + c_int::max_value() as u32, + ) as c_int) + })?? } - CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => cmp::min( - dev.get_image_properties() - .map_err(Error::L0)? - .maxImageDims1D, - c_int::max_value() as u32, - ) as c_int, CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupCountX) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::max( + i32::max_value() as u32, + props.maxGroupCountX, + ) as i32) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupCountY) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::max( + i32::max_value() as u32, + props.maxGroupCountY, + ) as i32) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupCountZ) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::max( + i32::max_value() as u32, + props.maxGroupCountZ, + ) as i32) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>( + cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32, + ) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>( + cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32, + ) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>( + cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32, + ) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxTotalGroupSize) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::max( + i32::max_value() as u32, + props.maxTotalGroupSize, + ) as i32) + })?? } _ => { // TODO: support more attributes for CUDA runtime @@ -293,14 +301,11 @@ pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Re Ok(()) } -pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> { - let ze_uuid = { - get_device(dev) - .map_err(Error::Cuda)? - .get_properties() - .map_err(Error::L0)? - .uuid - }; +pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> { + let ze_uuid = GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_properties()?; + Ok::<_, l0::sys::ze_result_t>(props.uuid) + })??; unsafe { *uuid = CUuuid_st { bytes: mem::transmute(ze_uuid.id), @@ -309,53 +314,39 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> { Ok(()) } -pub fn with_current_exclusive R, R>(f: F) -> Result { - let dev = super::context::with_current(|ctx| ctx.device); - dev.and_then(|dev| { - unsafe { &*dev } - .try_lock() - .map(|mut dev| f(&mut dev)) - .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) - }) -} - -pub fn with_exclusive R, R>(dev: Index, f: F) -> Result { - let dev = get_device_ref(dev)?; - dev.try_lock() - .map(|mut dev| f(&mut dev)) - .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) -} - pub fn primary_ctx_get_state( - idx: Index, + dev_idx: Index, flags: *mut u32, active: *mut i32, ) -> Result<(), CUresult> { - let (ctx_ptr, flags_ptr) = with_exclusive(idx, |dev| { + let (is_active, flags_value) = GlobalState::lock_device(dev_idx, |dev| { // This is safe because primary context can't be dropped - let ctx_ptr = &dev.primary_context as *const _; + let ctx_ptr = &mut dev.primary_context as *mut _; let flags_ptr = (&unsafe { dev.primary_context.as_ref_unchecked() }.flags) as *const AtomicU32; - (ctx_ptr, flags_ptr) - })?; - let is_active = context::CONTEXT_STACK - .with(|stack| stack.borrow().last().map(|x| *x)) - .map(|current| current == ctx_ptr) - .unwrap_or(false); - let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed); - unsafe { *flags = flags_value }; + let is_active = context::CONTEXT_STACK + .with(|stack| stack.borrow().last().map(|x| *x)) + .map(|current| current == ctx_ptr) + .unwrap_or(false); + let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed); + Ok::<_, l0::sys::ze_result_t>((is_active, flags_value)) + })??; unsafe { *active = if is_active { 1 } else { 0 } }; + unsafe { *flags = flags_value }; Ok(()) } -pub fn primary_ctx_retain(pctx: *mut *mut context::Context, dev: Index) -> Result<(), CUresult> { - let ctx_ptr = with_exclusive(dev, |dev| &mut dev.primary_context as *mut _)?; +pub fn primary_ctx_retain( + pctx: *mut *mut context::Context, + dev_idx: Index, +) -> Result<(), CUresult> { + let ctx_ptr = GlobalState::lock_device(dev_idx, |dev| &mut dev.primary_context as *mut _)?; unsafe { *pctx = ctx_ptr }; Ok(()) } #[cfg(test)] -mod tests { +mod test { use super::super::test::CudaDriverFns; use super::super::CUresult; diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs index 562af37..ae9f6e3 100644 --- a/notcuda/src/impl/export_table.rs +++ b/notcuda/src/impl/export_table.rs @@ -4,7 +4,7 @@ use crate::{ cuda_impl, }; -use super::{context, device, module, Decuda, Encuda}; +use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState}; use std::mem; use std::os::raw::{c_uint, c_ulong, c_ushort}; use std::{ @@ -110,17 +110,8 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [ VTableEntry { ptr: ptr::null() }, ]; -unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult { - cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda() -} - -fn cudart_interface_fn1_impl( - pctx: *mut *mut context::Context, - dev: device::Index, -) -> Result<(), CUresult> { - let ctx_ptr = device::with_exclusive(dev, |d| &mut d.primary_context as *mut context::Context)?; - unsafe { *pctx = ctx_ptr }; - Ok(()) +unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult { + super::unimplemented() } /* @@ -200,7 +191,7 @@ unsafe extern "C" fn get_module_from_cubin( ptr1: *mut c_void, ptr2: *mut c_void, ) -> CUresult { - // Not sure what those twoparameters are actually used for, + // Not sure what those two parameters are actually used for, // they are somehow involved in __cudaRegisterHostVar if ptr1 != ptr::null_mut() || ptr2 != ptr::null_mut() { return CUresult::CUDA_ERROR_NOT_SUPPORTED; @@ -234,10 +225,13 @@ unsafe extern "C" fn get_module_from_cubin( }, Err(_) => continue, }; - let module = module::ModuleData::compile_spirv(kernel_text_string); + let module = module::SpirvModule::new(kernel_text_string); match module { Ok(module) => { - *result = Box::into_raw(Box::new(module)); + match module::load_data_impl(result, module) { + Ok(()) => {} + Err(err) => return err, + } return CUresult::CUDA_SUCCESS; } Err(_) => continue, @@ -309,7 +303,7 @@ unsafe extern "C" fn context_local_storage_ctor( } fn context_local_storage_ctor_impl( - mut cu_ctx: *mut context::Context, + cu_ctx: *mut context::Context, mgr: *mut cuda_impl::rt::ContextStateManager, ctx_state: *mut cuda_impl::rt::ContextState, dtor_cb: Option< @@ -320,26 +314,11 @@ fn context_local_storage_ctor_impl( ), >, ) -> Result<(), CUresult> { - if cu_ctx == ptr::null_mut() { - context::get_current(&mut cu_ctx)?; - } - if cu_ctx == ptr::null_mut() { - return Err(CUresult::CUDA_ERROR_INVALID_VALUE); - } - unsafe { &*cu_ctx } - .as_ref() - .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) - .and_then(|ctx| { - ctx.mutable - .try_lock() - .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) - .map(|mut mutable| { - mutable.cuda_manager = mgr; - mutable.cuda_state = ctx_state; - mutable.cuda_dtor_cb = dtor_cb; - }) - })?; - Ok(()) + lock_context(cu_ctx, |ctx: &mut ContextData| { + ctx.cuda_manager = mgr; + ctx.cuda_state = ctx_state; + ctx.cuda_dtor_cb = dtor_cb; + }) } // some kind of dtor @@ -357,24 +336,10 @@ unsafe extern "C" fn context_local_storage_get_state( fn context_local_storage_get_state_impl( ctx_state: *mut *mut cuda_impl::rt::ContextState, - mut cu_ctx: *mut context::Context, + cu_ctx: *mut context::Context, _: *mut cuda_impl::rt::ContextStateManager, ) -> Result<(), CUresult> { - if cu_ctx == ptr::null_mut() { - context::get_current(&mut cu_ctx)?; - } - if cu_ctx == ptr::null_mut() { - return Err(CUresult::CUDA_ERROR_INVALID_VALUE); - } - let cuda_state = unsafe { &*cu_ctx } - .as_ref() - .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) - .and_then(|ctx| { - ctx.mutable - .try_lock() - .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) - .map(|mutable| mutable.cuda_state) - })?; + let cuda_state = lock_context(cu_ctx, |ctx: &mut ContextData| ctx.cuda_state)?; if cuda_state == ptr::null_mut() { Err(CUresult::CUDA_ERROR_INVALID_VALUE) } else { @@ -382,3 +347,17 @@ fn context_local_storage_get_state_impl( Ok(()) } } + +fn lock_context( + cu_ctx: *mut context::Context, + fn_impl: impl FnOnce(&mut ContextData) -> T, +) -> Result { + if cu_ctx == ptr::null_mut() { + GlobalState::lock_current_context(fn_impl) + } else { + GlobalState::lock(|_| { + let ctx = unsafe { &mut *cu_ctx }.as_result_mut()?; + Ok(fn_impl(ctx)) + })? + } +} diff --git a/notcuda/src/impl/function.rs b/notcuda/src/impl/function.rs index 0ab3bea..394f806 100644 --- a/notcuda/src/impl/function.rs +++ b/notcuda/src/impl/function.rs @@ -1,11 +1,28 @@ use ::std::os::raw::{c_uint, c_void}; use std::ptr; -use super::{device, stream::Stream, CUresult}; +use super::{CUresult, GlobalState, HasLivenessCookie, LiveCheck, stream::Stream}; -pub struct Function { +pub type Function = LiveCheck; + +impl HasLivenessCookie for FunctionData { + #[cfg(target_pointer_width = "64")] + const COOKIE: usize = 0x5e2ab14d5840678e; + + #[cfg(target_pointer_width = "32")] + const COOKIE: usize = 0x33e6a1e6; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE; + + fn try_drop(&mut self) -> Result<(), CUresult> { + Ok(()) + } +} + +pub struct FunctionData { pub base: l0::Kernel<'static>, pub arg_size: Vec, + pub use_shared_mem: bool, } pub fn launch_kernel( @@ -17,36 +34,43 @@ pub fn launch_kernel( block_dim_y: c_uint, block_dim_z: c_uint, shared_mem_bytes: c_uint, - strean: *mut Stream, + hstream: *mut Stream, kernel_params: *mut *mut c_void, extra: *mut *mut c_void, ) -> Result<(), CUresult> { if f == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - if shared_mem_bytes != 0 || strean != ptr::null_mut() || extra != ptr::null_mut() { + if extra != ptr::null_mut() { return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED); } - let func = unsafe { &*f }; - for (i, arg_size) in func.arg_size.iter().copied().enumerate() { - unsafe { - func.base - .set_arg_raw(i as u32, arg_size, *kernel_params.add(i))? - }; - } - unsafe { &*f } - .base - .set_group_size(block_dim_x, block_dim_y, block_dim_z)?; - device::with_current_exclusive(|dev| { - let mut cmd_list = l0::CommandList::new(&mut dev.l0_context, &dev.base)?; + GlobalState::lock_stream(hstream, |stream| { + let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?; + for (i, arg_size) in func.arg_size.iter().enumerate() { + unsafe { + func.base + .set_arg_raw(i as u32, *arg_size, *kernel_params.add(i))? + }; + } + if func.use_shared_mem { + unsafe { + func.base.set_arg_raw( + func.arg_size.len() as u32, + shared_mem_bytes as usize, + ptr::null(), + )? + }; + } + func.base + .set_group_size(block_dim_x, block_dim_y, block_dim_z)?; + let mut cmd_list = stream.command_list()?; cmd_list.append_launch_kernel( - &unsafe { &*f }.base, + &mut func.base, &[grid_dim_x, grid_dim_y, grid_dim_z], None, &mut [], )?; - dev.default_queue.execute(cmd_list)?; - l0::Result::Ok(()) - })??; - Ok(()) + stream.queue.execute(cmd_list)?; + Ok(()) + })? } diff --git a/notcuda/src/impl/memory.rs b/notcuda/src/impl/memory.rs index 439b26f..62dc1cc 100644 --- a/notcuda/src/impl/memory.rs +++ b/notcuda/src/impl/memory.rs @@ -1,57 +1,34 @@ -use super::CUresult; +use super::{stream, CUresult, GlobalState}; use std::ffi::c_void; -pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult { - let alloc_result = super::device::with_current_exclusive(|dev| unsafe { - dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) - }); - match alloc_result { - Ok(Ok(alloc)) => { - unsafe { *dptr = alloc }; - CUresult::CUDA_SUCCESS - } - Ok(Err(e)) => e.into(), - Err(e) => e, - } +pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> { + let ptr = GlobalState::lock_current_context(|ctx| { + let dev = unsafe { &mut *ctx.device }; + Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?) + })??; + unsafe { *dptr = ptr }; + Ok(()) } -pub fn copy_v2( - dst: *mut c_void, - src: *const c_void, - bytesize: usize, -) -> Result, CUresult> { - super::device::with_current_exclusive(|dev| unsafe { - memcpy_impl( - &mut dev.l0_context, - dst, - src, - bytesize, - &dev.base, - &mut dev.default_queue, - ) +pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> { + GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { + let mut cmd_list = stream.command_list()?; + unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?; + stream.queue.execute(cmd_list)?; + Ok::<_, CUresult>(()) + })? +} + +pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> { + GlobalState::lock_current_context(|ctx| { + let dev = unsafe { &mut *ctx.device }; + Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?) }) -} - -unsafe fn memcpy_impl( - ctx: &mut l0::Context, - dst: *mut c_void, - src: *const c_void, - bytes_count: usize, - dev: &l0::Device, - queue: &mut l0::CommandQueue, -) -> l0::Result<()> { - let mut cmd_list = l0::CommandList::new(ctx, &dev)?; - cmd_list.append_memory_copy_unsafe(dst, src, bytes_count, None, &mut [])?; - queue.execute(cmd_list)?; - Ok(()) -} - -pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> { - Ok(()) + .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)? } #[cfg(test)] -mod tests { +mod test { use super::super::test::CudaDriverFns; use super::super::CUresult; use std::ptr; @@ -82,4 +59,20 @@ mod tests { assert_ne!(mem, ptr::null_mut()); assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); } + + cuda_driver_test!(free_without_ctx); + + fn free_without_ctx() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut mem = ptr::null_mut(); + assert_eq!( + T::cuMemAlloc_v2(&mut mem, std::mem::size_of::()), + CUresult::CUDA_SUCCESS + ); + assert_ne!(mem, ptr::null_mut()); + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuMemFree_v2(mem), CUresult::CUDA_ERROR_INVALID_VALUE); + } } diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs index 5a72ce4..770a32b 100644 --- a/notcuda/src/impl/mod.rs +++ b/notcuda/src/impl/mod.rs @@ -1,5 +1,15 @@ -use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st}; -use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex}; +use crate::{ + cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st}, + r#impl::device::Device, +}; +use std::{ + ffi::c_void, + mem::{self, ManuallyDrop}, + os::raw::c_int, + ptr, + sync::Mutex, + sync::TryLockError, +}; #[cfg(test)] #[macro_use] @@ -7,9 +17,9 @@ pub mod test; pub mod context; pub mod device; pub mod export_table; +pub mod function; pub mod memory; pub mod module; -pub mod function; pub mod stream; #[cfg(debug_assertions)] @@ -22,8 +32,11 @@ pub fn unimplemented() -> CUresult { CUresult::CUDA_ERROR_NOT_SUPPORTED } -pub trait HasLivenessCookie { +pub trait HasLivenessCookie: Sized { const COOKIE: usize; + const LIVENESS_FAIL: CUresult; + + fn try_drop(&mut self) -> Result<(), CUresult>; } // This struct is a best-effort check if wrapped value has been dropped, @@ -42,19 +55,23 @@ impl LiveCheck { } } + fn destroy_impl(this: *mut Self) -> Result<(), CUresult> { + let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) }); + ctx_box.try_drop()?; + unsafe { ManuallyDrop::drop(&mut ctx_box) }; + Ok(()) + } + + unsafe fn ptr_from_inner(this: *mut T) -> *mut Self { + let outer_ptr = (this as *mut u8).sub(mem::size_of::()); + outer_ptr as *mut Self + } + pub unsafe fn as_ref_unchecked(&self) -> &T { &self.data } - pub fn as_ref(&self) -> Option<&T> { - if self.cookie == T::COOKIE { - Some(&self.data) - } else { - None - } - } - - pub fn as_mut(&mut self) -> Option<&mut T> { + pub fn as_option_mut(&mut self) -> Option<&mut T> { if self.cookie == T::COOKIE { Some(&mut self.data) } else { @@ -62,14 +79,31 @@ impl LiveCheck { } } + pub fn as_result(&self) -> Result<&T, CUresult> { + if self.cookie == T::COOKIE { + Ok(&self.data) + } else { + Err(T::LIVENESS_FAIL) + } + } + + pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> { + if self.cookie == T::COOKIE { + Ok(&mut self.data) + } else { + Err(T::LIVENESS_FAIL) + } + } + #[must_use] - pub fn try_drop(&mut self) -> bool { + pub fn try_drop(&mut self) -> Result<(), CUresult> { if self.cookie == T::COOKIE { self.cookie = 0; + self.data.try_drop()?; unsafe { ManuallyDrop::drop(&mut self.data) }; - return true; + return Ok(()); } - false + Err(T::LIVENESS_FAIL) } } @@ -121,6 +155,12 @@ impl From for CUresult { } } +impl From> for CUresult { + fn from(_: TryLockError) -> Self { + CUresult::CUDA_ERROR_ILLEGAL_STATE + } +} + pub trait Encuda { type To: Sized; fn encuda(self: Self) -> Self::To; @@ -157,58 +197,103 @@ impl, T2: Encuda> Encuda for Result Self::To { - match self { - Error::L0(e) => e.into(), - Error::Cuda(e) => e, - } - } -} - lazy_static! { static ref GLOBAL_STATE: Mutex> = Mutex::new(None); } struct GlobalState { - driver: l0::Driver, + devices: Vec, } unsafe impl Send for GlobalState {} +impl GlobalState { + fn lock(f: impl FnOnce(&mut GlobalState) -> T) -> Result { + let mut mutex = GLOBAL_STATE + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + let global_state = mutex.as_mut().ok_or(CUresult::CUDA_ERROR_ILLEGAL_STATE)?; + Ok(f(global_state)) + } + + fn lock_device( + device::Index(dev_idx): device::Index, + f: impl FnOnce(&'static mut device::Device) -> T, + ) -> Result { + if dev_idx < 0 { + return Err(CUresult::CUDA_ERROR_INVALID_DEVICE); + } + Self::lock(|global_state| { + if dev_idx >= global_state.devices.len() as c_int { + Err(CUresult::CUDA_ERROR_INVALID_DEVICE) + } else { + Ok(f(unsafe { + transmute_lifetime_mut(&mut global_state.devices[dev_idx as usize]) + })) + } + })? + } + + fn lock_current_context R, R>( + f: F, + ) -> Result { + Self::lock_current_context_unchecked(|ctx| Ok(f(ctx.as_result_mut()?)))? + } + + fn lock_current_context_unchecked R, R>( + f: F, + ) -> Result { + context::CONTEXT_STACK.with(|stack| { + stack + .borrow_mut() + .last_mut() + .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) + .map(|ctx| GlobalState::lock(|_| f(unsafe { &mut **ctx })))? + }) + } + + fn lock_stream( + stream: *mut stream::Stream, + f: impl FnOnce(&mut stream::StreamData) -> T, + ) -> Result { + if stream == ptr::null_mut() + || stream == stream::CU_STREAM_LEGACY + || stream == stream::CU_STREAM_PER_THREAD + { + Self::lock_current_context(|ctx| Ok(f(&mut ctx.default_stream)))? + } else { + Self::lock(|_| { + let stream = unsafe { &mut *stream }.as_result_mut()?; + Ok(f(stream)) + })? + } + } +} + // TODO: implement fn is_intel_gpu_driver(_: &l0::Driver) -> bool { true } -pub fn init() -> l0::Result<()> { +pub fn init() -> Result<(), CUresult> { let mut global_state = GLOBAL_STATE .lock() - .map_err(|_| l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?; + .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; if global_state.is_some() { return Ok(()); } l0::init()?; let drivers = l0::Driver::get()?; - let driver = match drivers.into_iter().find(is_intel_gpu_driver) { - None => return Err(l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), - Some(driver) => { - device::init(&driver)?; - driver - } + let devices = match drivers.into_iter().find(is_intel_gpu_driver) { + None => return Err(CUresult::CUDA_ERROR_UNKNOWN), + Some(driver) => device::init(&driver)?, }; - *global_state = Some(GlobalState { driver }); + *global_state = Some(GlobalState { devices }); drop(global_state); Ok(()) } -unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T { +unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T { mem::transmute(t) } diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index 35436c3..4422107 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -1,79 +1,90 @@ use std::{ - collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex, + collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, + os::raw::c_char, ptr, slice, }; -use super::{function::Function, transmute_lifetime, CUresult}; +use super::{ + device, function::Function, function::FunctionData, CUresult, GlobalState, HasLivenessCookie, + LiveCheck, +}; use ptx; -pub type Module = Mutex; +pub type Module = LiveCheck; + +impl HasLivenessCookie for ModuleData { + #[cfg(target_pointer_width = "64")] + const COOKIE: usize = 0xf1313bd46505f98a; + + #[cfg(target_pointer_width = "32")] + const COOKIE: usize = 0xbdbe3f15; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE; + + fn try_drop(&mut self) -> Result<(), CUresult> { + Ok(()) + } +} pub struct ModuleData { - base: l0::Module, - arg_lens: HashMap>, + pub spirv: SpirvModule, + // This should be a Vec<>, but I'm feeling lazy + pub device_binaries: HashMap, } -pub enum ModuleCompileError<'a> { - Parse( - Vec, - Option, ptx::ast::PtxError>>, - ), - Compile(ptx::TranslateError), - L0(l0::sys::ze_result_t), - CUDA(CUresult), +pub struct SpirvModule { + pub binaries: Vec, + pub kernel_info: HashMap, + pub should_link_ptx_impl: Option<&'static [u8]>, + pub build_options: CString, } -impl<'a> ModuleCompileError<'a> { - pub fn get_build_log(&self) {} +pub struct CompiledModule { + pub base: l0::Module, + pub kernels: HashMap>, } -impl<'a> From for ModuleCompileError<'a> { - fn from(err: ptx::TranslateError) -> Self { - ModuleCompileError::Compile(err) +impl From> for CUresult { + fn from(_: ptx::ParseError) -> Self { + CUresult::CUDA_ERROR_INVALID_PTX } } -impl<'a> From for ModuleCompileError<'a> { - fn from(err: l0::sys::ze_result_t) -> Self { - ModuleCompileError::L0(err) +impl From for CUresult { + fn from(_: ptx::TranslateError) -> Self { + CUresult::CUDA_ERROR_INVALID_PTX } } -impl<'a> From for ModuleCompileError<'a> { - fn from(err: CUresult) -> Self { - ModuleCompileError::CUDA(err) +impl SpirvModule { + pub fn new_raw<'a>(text: *const c_char) -> Result { + let u8_text = unsafe { CStr::from_ptr(text) }; + let ptx_text = u8_text + .to_str() + .map_err(|_| CUresult::CUDA_ERROR_INVALID_PTX)?; + Self::new(ptx_text) } -} -impl ModuleData { - pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result> { + pub fn new<'a>(ptx_text: &str) -> Result { let mut errors = Vec::new(); - let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text); - let ast = match ast { - Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))), - Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)), - Ok(ast) => ast, - }; - let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?; + let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?; + let spirv_module = ptx::to_spirv_module(ast)?; + Ok(SpirvModule { + binaries: spirv_module.assemble(), + kernel_info: spirv_module.kernel_info, + should_link_ptx_impl: spirv_module.should_link_ptx_impl, + build_options: spirv_module.build_options, + }) + } + + pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result { let byte_il = unsafe { - slice::from_raw_parts::( - spirv.as_ptr() as *const _, - spirv.len() * mem::size_of::(), + slice::from_raw_parts( + self.binaries.as_ptr() as *const u8, + self.binaries.len() * mem::size_of::(), ) }; - let module = super::device::with_current_exclusive(|dev| { - l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None) - }); - match module { - Ok((Ok(module), _)) => Ok(Mutex::new(Self { - base: module, - arg_lens: all_arg_lens - .into_iter() - .map(|(k, v)| (CString::new(k).unwrap(), v)) - .collect(), - })), - Ok((Err(err), _)) => Err(ModuleCompileError::from(err)), - Err(err) => Err(ModuleCompileError::from(err)), - } + let l0_module = l0::Module::build_spirv(ctx, dev, byte_il, None).0?; + Ok(l0_module) } } @@ -85,34 +96,75 @@ pub fn get_function( if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let name = unsafe { CStr::from_ptr(name) }; - let (mut kernel, args_len) = unsafe { &*hmod } - .try_lock() - .map(|module| { - Result::<_, CUresult>::Ok(( - l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?, - module - .arg_lens - .get(name) - .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)? - .clone(), - )) - }) - .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??; - kernel.set_indirect_access( - l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE - | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST - | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED, - )?; - unsafe { - *hfunc = Box::into_raw(Box::new(Function { - base: kernel, - arg_size: args_len, - })) - }; + let name = unsafe { CStr::from_ptr(name) }.to_owned(); + let function: *mut Function = GlobalState::lock_current_context(|ctx| { + let module = unsafe { &mut *hmod }.as_result_mut()?; + let device = unsafe { &mut *ctx.device }; + let compiled_module = match module.device_binaries.entry(device.index) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let new_module = CompiledModule { + base: module.spirv.compile(&mut device.l0_context, &device.base)?, + kernels: HashMap::new(), + }; + entry.insert(new_module) + } + }; + //let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) }; + let kernel = match compiled_module.kernels.entry(name) { + hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(), + hash_map::Entry::Vacant(entry) => { + let kernel_info = module + .spirv + .kernel_info + .get(unsafe { + std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes()) + }) + .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?; + let kernel = + l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?; + entry.insert(Box::new(Function::new(FunctionData { + base: kernel, + arg_size: kernel_info.arguments_sizes.clone(), + use_shared_mem: kernel_info.uses_shared_mem, + }))) + } + }; + Ok::<_, CUresult>(kernel as *mut _) + })??; + unsafe { *hfunc = function }; Ok(()) } -pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> { +pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<(), CUresult> { + let spirv_data = SpirvModule::new_raw(image as *const _)?; + load_data_impl(pmod, spirv_data) +} + +pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> { + let module = GlobalState::lock_current_context(|ctx| { + let device = unsafe { &mut *ctx.device }; + let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?; + let mut device_binaries = HashMap::new(); + let compiled_module = CompiledModule { + base: l0_module, + kernels: HashMap::new(), + }; + device_binaries.insert(device.index, compiled_module); + let module_data = ModuleData { + spirv: spirv_data, + device_binaries, + }; + Ok::<_, CUresult>(module_data) + })??; + let module_ptr = Box::into_raw(Box::new(Module::new(module))); + unsafe { *pmod = module_ptr }; Ok(()) } + +pub(crate) fn unload(module: *mut Module) -> Result<(), CUresult> { + if module == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| Module::destroy_impl(module))? +} diff --git a/notcuda/src/impl/stream.rs b/notcuda/src/impl/stream.rs index 1844677..e212dfc 100644 --- a/notcuda/src/impl/stream.rs +++ b/notcuda/src/impl/stream.rs @@ -1,36 +1,114 @@ -use std::cell::RefCell; +use super::{ + context::{Context, ContextData}, + CUresult, GlobalState, +}; +use std::{mem, ptr}; -use device::Device; +use super::{HasLivenessCookie, LiveCheck}; -use super::device; +pub type Stream = LiveCheck; -pub struct Stream { - dev: *mut Device, -} +pub const CU_STREAM_LEGACY: *mut Stream = 1 as *mut _; +pub const CU_STREAM_PER_THREAD: *mut Stream = 2 as *mut _; -pub struct DefaultStream { - streams: Vec>, -} +impl HasLivenessCookie for StreamData { + #[cfg(target_pointer_width = "64")] + const COOKIE: usize = 0x512097354de18d35; -impl DefaultStream { - fn new() -> Self { - DefaultStream { - streams: Vec::new(), + #[cfg(target_pointer_width = "32")] + const COOKIE: usize = 0x77d5cc0b; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE; + + fn try_drop(&mut self) -> Result<(), CUresult> { + if self.context != ptr::null_mut() { + let context = unsafe { &mut *self.context }; + if !context.streams.remove(&(self as *mut _)) { + return Err(CUresult::CUDA_ERROR_UNKNOWN); + } } + Ok(()) } } -thread_local! { - pub static DEFAULT_STREAM: RefCell = RefCell::new(DefaultStream::new()); +pub struct StreamData { + pub context: *mut ContextData, + pub queue: l0::CommandQueue, +} + +impl StreamData { + pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result { + Ok(StreamData { + context: ptr::null_mut(), + queue: l0::CommandQueue::new(ctx, dev)?, + }) + } + pub fn new(ctx: &mut ContextData) -> Result { + let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context; + let l0_dev = &unsafe { &*ctx.device }.base; + Ok(StreamData { + context: ctx as *mut _, + queue: l0::CommandQueue::new(l0_ctx, l0_dev)?, + }) + } + + pub fn command_list(&self) -> Result { + let ctx = unsafe { &mut *self.context }; + let dev = unsafe { &mut *ctx.device }; + l0::CommandList::new(&mut dev.l0_context, &dev.base) + } +} + +impl Drop for StreamData { + fn drop(&mut self) { + if self.context == ptr::null_mut() { + return; + } + unsafe { (&mut *self.context).streams.remove(&(&mut *self as *mut _)) }; + } +} + +pub(crate) fn get_ctx(hstream: *mut Stream, pctx: *mut *mut Context) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + let ctx_ptr = GlobalState::lock_stream(hstream, |stream| stream.context)?; + if ctx_ptr == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED); + } + unsafe { *pctx = Context::ptr_from_inner(ctx_ptr) }; + Ok(()) +} + +pub(crate) fn create(phstream: *mut *mut Stream, _flags: u32) -> Result<(), CUresult> { + let stream_ptr = GlobalState::lock_current_context(|ctx| { + let mut stream_box = Box::new(Stream::new(StreamData::new(ctx)?)); + let stream_ptr = stream_box.as_mut().as_option_mut().unwrap() as *mut _; + if !ctx.streams.insert(stream_ptr) { + return Err(CUresult::CUDA_ERROR_UNKNOWN); + } + mem::forget(stream_box); + Ok::<_, CUresult>(stream_ptr) + })??; + unsafe { *phstream = Stream::ptr_from_inner(stream_ptr) }; + Ok(()) +} + +pub(crate) fn destroy_v2(pstream: *mut Stream) -> Result<(), CUresult> { + if pstream == ptr::null_mut() || pstream == CU_STREAM_LEGACY || pstream == CU_STREAM_PER_THREAD + { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| Stream::destroy_impl(pstream))? } #[cfg(test)] -mod tests { +mod test { use crate::cuda::CUstream; use super::super::test::CudaDriverFns; use super::super::CUresult; - use std::ptr; + use std::{ptr, thread}; const CU_STREAM_LEGACY: CUstream = 1 as *mut _; const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _; @@ -65,5 +143,100 @@ mod tests { CUresult::CUDA_SUCCESS ); assert_eq!(ctx2, stream_ctx2); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx1), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(stream_context_destroyed); + + fn stream_context_destroyed() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut stream = ptr::null_mut(); + assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS); + let mut stream_ctx1 = ptr::null_mut(); + assert_eq!( + T::cuStreamGetCtx(stream, &mut stream_ctx1), + CUresult::CUDA_SUCCESS + ); + assert_eq!(stream_ctx1, ctx); + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + let mut stream_ctx2 = ptr::null_mut(); + // When a context gets destroyed, its streams are also destroyed + let cuda_result = T::cuStreamGetCtx(stream, &mut stream_ctx2); + assert!( + cuda_result == CUresult::CUDA_ERROR_INVALID_HANDLE + || cuda_result == CUresult::CUDA_ERROR_INVALID_CONTEXT + || cuda_result == CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED + ); + assert_eq!( + T::cuStreamDestroy_v2(stream), + CUresult::CUDA_ERROR_INVALID_HANDLE + ); + // Check if creating another context is possible + let mut ctx2 = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(stream_moves_context_to_another_thread); + + fn stream_moves_context_to_another_thread() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut stream = ptr::null_mut(); + assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS); + let mut stream_ctx1 = ptr::null_mut(); + assert_eq!( + T::cuStreamGetCtx(stream, &mut stream_ctx1), + CUresult::CUDA_SUCCESS + ); + assert_eq!(stream_ctx1, ctx); + let stream_ptr = stream as usize; + let stream_ctx_on_thread = thread::spawn(move || { + let mut stream_ctx2 = ptr::null_mut(); + assert_eq!( + T::cuStreamGetCtx(stream_ptr as *mut _, &mut stream_ctx2), + CUresult::CUDA_SUCCESS + ); + stream_ctx2 as usize + }) + .join() + .unwrap(); + assert_eq!(stream_ctx1, stream_ctx_on_thread as *mut _); + // Cleanup + assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(can_destroy_stream); + + fn can_destroy_stream() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut stream = ptr::null_mut(); + assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(cant_destroy_default_stream); + + fn cant_destroy_default_stream() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + assert_ne!( + T::cuStreamDestroy_v2(super::CU_STREAM_LEGACY as *mut _), + CUresult::CUDA_SUCCESS + ); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); } } diff --git a/notcuda/src/impl/test.rs b/notcuda/src/impl/test.rs index dbd2eff..b6ed926 100644 --- a/notcuda/src/impl/test.rs +++ b/notcuda/src/impl/test.rs @@ -1,8 +1,12 @@ #![allow(non_snake_case)] -use crate::{cuda::CUstream, r#impl as notcuda}; -use crate::r#impl::CUresult; -use crate::{cuda::CUuuid, r#impl::Encuda}; +use crate::cuda as notcuda; +use crate::cuda::CUstream; +use crate::cuda::CUuuid; +use crate::{ + cuda::{CUdevice, CUdeviceptr}, + r#impl::CUresult, +}; use ::std::{ ffi::c_void, os::raw::{c_int, c_uint}, @@ -37,48 +41,63 @@ pub trait CudaDriverFns { fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult; fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult; fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult; + fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult; + fn cuMemFree_v2(mem: *mut c_void) -> CUresult; + fn cuStreamDestroy_v2(stream: CUstream) -> CUresult; } pub struct NotCuda(); impl CudaDriverFns for NotCuda { fn cuInit(_flags: c_uint) -> CUresult { - crate::cuda::cuInit(_flags as _) + notcuda::cuInit(_flags as _) } fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult { - notcuda::context::create_v2(pctx as *mut _, flags, notcuda::device::Index(dev)).encuda() + notcuda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev)) } fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult { - notcuda::context::destroy_v2(ctx as *mut _) + notcuda::cuCtxDestroy_v2(ctx as *mut _) } fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult { - notcuda::context::pop_current_v2(pctx as *mut _) + notcuda::cuCtxPopCurrent_v2(pctx as *mut _) } fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult { - notcuda::context::get_api_version(ctx as *mut _, version) + notcuda::cuCtxGetApiVersion(ctx as *mut _, version) } fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult { - notcuda::context::get_current(pctx as *mut _).encuda() + notcuda::cuCtxGetCurrent(pctx as *mut _) } fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult { - notcuda::memory::alloc_v2(dptr as *mut _, bytesize) + notcuda::cuMemAlloc_v2(dptr as *mut _, bytesize) } fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult { - notcuda::device::get_uuid(uuid, notcuda::device::Index(dev)).encuda() + notcuda::cuDeviceGetUuid(uuid, CUdevice(dev)) } fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult { - notcuda::device::primary_ctx_get_state(notcuda::device::Index(dev), flags, active).encuda() + notcuda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active) } fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { - crate::cuda::cuStreamGetCtx(hStream, pctx as _) + notcuda::cuStreamGetCtx(hStream, pctx as _) + } + + fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult { + notcuda::cuStreamCreate(stream, flags) + } + + fn cuMemFree_v2(dptr: *mut c_void) -> CUresult { + notcuda::cuMemFree_v2(CUdeviceptr(dptr as _)) + } + + fn cuStreamDestroy_v2(stream: CUstream) -> CUresult { + notcuda::cuStreamDestroy_v2(stream) } } @@ -123,4 +142,16 @@ impl CudaDriverFns for Cuda { fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) } } + + fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult { + unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) } + } + + fn cuMemFree_v2(mem: *mut c_void) -> CUresult { + unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) } + } + + fn cuStreamDestroy_v2(stream: CUstream) -> CUresult { + unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) } + } } diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 1aac8ab..591428f 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -34,8 +34,9 @@ pub use crate::ptx::ModuleParser; pub use lalrpop_util::lexer::Token; pub use lalrpop_util::ParseError; pub use rspirv::dr::Error as SpirvError; -pub use translate::TranslateError as TranslateError; -pub use translate::to_spirv; +pub use translate::to_spirv_module; +pub use translate::KernelInfo; +pub use translate::TranslateError; pub(crate) fn without_none(x: Vec>) -> Vec { x.into_iter().filter_map(|x| x).collect() diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index 0339141..0785f3e 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -12,7 +12,7 @@ fn parse_and_assert(s: &str) { fn compile_and_assert(s: &str) -> Result<(), TranslateError> { let mut errors = Vec::new(); let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); - crate::to_spirv(ast)?; + crate::to_spirv_module(ast)?; Ok(()) } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c0e15f2..3d0f476 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast; use half::f16; use rspirv::{binary::Disassemble, dr}; -use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem}; +use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem}; use std::{ collections::{hash_map, HashMap, HashSet}, convert::TryInto, @@ -450,6 +450,11 @@ pub struct Module { pub should_link_ptx_impl: Option<&'static [u8]>, pub build_options: CString, } +impl Module { + pub fn assemble(&self) -> Vec { + self.spirv.assemble() + } +} pub struct KernelInfo { pub arguments_sizes: Vec, @@ -1046,8 +1051,12 @@ fn emit_function_header<'a>( kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { if let MethodName::Kernel(name) = func_decl.name { - let args_lens = func_decl - .input + let input_args = if !func_decl.uses_shared_mem { + func_decl.input.as_slice() + } else { + &func_decl.input[0..func_decl.input.len() - 1] + }; + let args_lens = input_args .iter() .map(|param| param.v_type.size_of()) .collect(); @@ -1135,21 +1144,6 @@ fn emit_function_header<'a>( Ok(()) } -pub fn to_spirv<'a>( - ast: ast::Module<'a>, -) -> Result<(Option<&'static [u8]>, Vec, HashMap>), TranslateError> { - let module = to_spirv_module(ast)?; - Ok(( - module.should_link_ptx_impl, - module.spirv.assemble(), - module - .kernel_info - .into_iter() - .map(|(k, v)| (k, v.arguments_sizes)) - .collect(), - )) -} - fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::GenericPointer); builder.capability(spirv::Capability::Linkage);