From e40785aa7491de16c65de7aa599105102ffa7355 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 27 May 2021 02:05:17 +0200 Subject: [PATCH] Refactor L0 bindings --- level_zero/src/ze.rs | 836 ++++++++++++++++++++-------------- ptx/src/test/spirv_run/mod.rs | 63 ++- zluda/src/impl/context.rs | 6 +- zluda/src/impl/device.rs | 38 +- zluda/src/impl/function.rs | 4 +- zluda/src/impl/memory.rs | 18 +- zluda/src/impl/module.rs | 20 +- zluda/src/impl/stream.rs | 8 +- zluda_ml/src/impl.rs | 3 +- 9 files changed, 577 insertions(+), 419 deletions(-) diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index d2b1115..88adfe6 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -1,11 +1,37 @@ +use sys::zeFenceDestroy; + use crate::sys; use std::{ ffi::{c_void, CStr, CString}, fmt::Debug, marker::PhantomData, - mem, ptr, + mem, + ptr::{self, NonNull}, }; +/* + This module is not a user-friendly, safe binding. The problem is tracking + object lifetimes. E.g. kernel object cannot outlive module object. + While Rust is relatively good at it, it's tricky to translate it to a safe + API in a way that we can mix and match them, but here's I'd sketch it: + - There's no &mut references: all API operations copy data in and out + - All baseline objects are Send, but not Sync + - There are some problems with using "naked" Rc and Arc: + - We should not allow users to create Rc by themselves without including + parent pointer + - We should not allow DerefMut in Mutex and moving out of it + - Objects are wrapped in Rc> and Arc>, parent + pointer is part of ZeCell/ZeMutex: + - Then e.g. zeKernelCreate is mapped three times: + - unsafe Module(&self) -> Kernel + - Module(&Rc>) -> Rc> + - Module(&Arc>) -> Arc + - You create ZeCell by moving Module and Rc + - Pro: Rc and Arc are allowed to be self receivers + - Open question: should some operations take the parent mutex? If so, should + it be done recursively? +*/ + macro_rules! check { ($expr:expr) => { #[allow(unused_unsafe)] @@ -39,102 +65,155 @@ pub fn init() -> Result<()> { } } +// Mutability: no (list of allocations is under a mutex) +// Lifetime: 'static #[repr(transparent)] -pub struct Driver(sys::ze_driver_handle_t); +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct Driver(NonNull); unsafe impl Send for Driver {} unsafe impl Sync for Driver {} impl Driver { - pub unsafe fn as_ffi(&self) -> sys::ze_driver_handle_t { - self.0 + pub unsafe fn as_ffi(self) -> sys::ze_driver_handle_t { + self.0.as_ptr() } pub unsafe fn from_ffi(x: sys::ze_driver_handle_t) -> Self { - Self(x) + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x)) } pub fn get() -> Result> { let mut len = 0; let mut temp = ptr::null_mut(); check!(sys::zeDriverGet(&mut len, &mut temp)); - let mut result = (0..len) - .map(|_| Driver(ptr::null_mut())) - .collect::>(); + let mut result = Vec::with_capacity(len as usize); check!(sys::zeDriverGet(&mut len, result.as_mut_ptr() as *mut _)); - Ok(result) - } - - pub fn devices(&self) -> Result> { - let mut len = 0; - let mut temp = ptr::null_mut(); - check!(sys::zeDeviceGet(self.0, &mut len, &mut temp)); - let mut result = (0..len) - .map(|_| Device(ptr::null_mut())) - .collect::>(); - check!(sys::zeDeviceGet( - self.0, - &mut len, - result.as_mut_ptr() as *mut _ - )); - if (len as usize) < result.len() { - result.truncate(len as usize); + unsafe { + result.set_len(len as usize); } Ok(result) } - pub fn get_properties(&self) -> Result { - let mut result = unsafe { mem::zeroed::() }; - check!(sys::zeDriverGetProperties(self.0, &mut result)); + pub fn devices(self) -> Result> { + let mut len = 0; + let mut temp = ptr::null_mut(); + check!(sys::zeDeviceGet(self.as_ffi(), &mut len, &mut temp)); + let mut result = Vec::with_capacity(len as usize); + check!(sys::zeDeviceGet( + self.as_ffi(), + &mut len, + result.as_mut_ptr() as *mut _ + )); + unsafe { + result.set_len(len as usize); + } Ok(result) } + + pub fn get_properties(self, props: &mut sys::ze_driver_properties_t) -> Result<()> { + check!(sys::zeDriverGetProperties(self.as_ffi(), props)); + Ok(()) + } } +// Mutability: no (list of peer allocations under a mutex) +// Lifetime: 'static #[repr(transparent)] -pub struct Device(sys::ze_device_handle_t); +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct Device(NonNull); + +unsafe impl Send for Device {} +unsafe impl Sync for Device {} impl Device { - pub unsafe fn as_ffi(&self) -> sys::ze_device_handle_t { - self.0 + pub unsafe fn as_ffi(self) -> sys::ze_device_handle_t { + self.0.as_ptr() } pub unsafe fn from_ffi(x: sys::ze_device_handle_t) -> Self { - Self(x) + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x)) } - pub fn get_properties(&self) -> Result> { - let mut props = Box::new(unsafe { mem::zeroed::() }); - check! { sys::zeDeviceGetProperties(self.0, props.as_mut()) }; - Ok(props) + pub fn get_properties(self, props: &mut sys::ze_device_properties_t) -> Result<()> { + check! { sys::zeDeviceGetProperties(self.as_ffi(), props) }; + Ok(()) } - pub fn get_image_properties(&self) -> Result> { - let mut props = Box::new(unsafe { mem::zeroed::() }); - check! { sys::zeDeviceGetImageProperties(self.0, props.as_mut()) }; - Ok(props) + pub fn get_image_properties(self, props: &mut sys::ze_device_image_properties_t) -> Result<()> { + check! { sys::zeDeviceGetImageProperties(self.as_ffi(), props) }; + Ok(()) } - pub fn get_memory_properties(&self) -> Result> { + pub fn get_memory_properties(self) -> Result> { let mut count = 0u32; - check! { sys::zeDeviceGetMemoryProperties(self.0, &mut count, ptr::null_mut()) }; + check! { sys::zeDeviceGetMemoryProperties(self.as_ffi(), &mut count, ptr::null_mut()) }; if count == 0 { return Ok(Vec::new()); } let mut props = vec![unsafe { mem::zeroed::() }; count as usize]; - check! { sys::zeDeviceGetMemoryProperties(self.0, &mut count, props.as_mut_ptr()) }; + check! { sys::zeDeviceGetMemoryProperties(self.as_ffi(), &mut count, props.as_mut_ptr()) }; Ok(props) } - pub fn get_compute_properties(&self) -> Result> { - let mut props = Box::new(unsafe { mem::zeroed::() }); - check! { sys::zeDeviceGetComputeProperties(self.0, props.as_mut()) }; - Ok(props) + pub fn get_compute_properties( + self, + props: &mut sys::ze_device_compute_properties_t, + ) -> Result<()> { + check! { sys::zeDeviceGetComputeProperties(self.as_ffi(), props) }; + Ok(()) + } +} + +// Mutability: no +#[repr(transparent)] +pub struct Context(NonNull); + +unsafe impl Send for Context {} +unsafe impl Sync for Context {} + +impl Context { + pub unsafe fn as_ffi(&self) -> sys::ze_context_handle_t { + self.0.as_ptr() + } + pub unsafe fn from_ffi(x: sys::ze_context_handle_t) -> Self { + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x)) } - pub unsafe fn mem_alloc_device( - &mut self, - ctx: &mut Context, + pub fn new(drv: Driver, devices: Option<&[Device]>) -> Result { + let ctx_desc = sys::ze_context_desc_t { + stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_CONTEXT_DESC, + pNext: ptr::null(), + flags: sys::ze_context_flags_t(0), + }; + let mut result = ptr::null_mut(); + let (dev_ptr, dev_len) = match devices { + None => (ptr::null(), 0), + Some(devs) => (devs.as_ptr(), devs.len()), + }; + check!(sys::zeContextCreateEx( + drv.as_ffi(), + &ctx_desc, + dev_len as u32, + dev_ptr as *mut _, + &mut result + )); + Ok(unsafe { Self::from_ffi(result) }) + } + + pub fn mem_alloc_device( + &self, size: usize, alignment: usize, + device: Device, ) -> Result<*mut c_void> { let descr = sys::ze_device_mem_alloc_desc_t { stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, @@ -143,47 +222,24 @@ impl Device { ordinal: 0, }; let mut result = ptr::null_mut(); - // TODO: check current context for the device check! { sys::zeMemAllocDevice( - ctx.0, + self.as_ffi(), &descr, size, alignment, - self.0, + device.as_ffi(), &mut result, ) }; Ok(result) } -} -#[repr(transparent)] -pub struct Context(sys::ze_context_handle_t); - -impl Context { - pub unsafe fn as_ffi(&self) -> sys::ze_context_handle_t { - self.0 - } - pub unsafe fn from_ffi(x: sys::ze_context_handle_t) -> Self { - Self(x) - } - - pub fn new(drv: &Driver) -> Result { - let ctx_desc = sys::ze_context_desc_t { - stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_CONTEXT_DESC, - pNext: ptr::null(), - flags: sys::ze_context_flags_t(0), - }; - let mut result = ptr::null_mut(); - check!(sys::zeContextCreate(drv.0, &ctx_desc, &mut result)); - Ok(Context(result)) - } - - pub unsafe fn mem_free(&mut self, ptr: *mut c_void) -> Result<()> { + // This operation is safe because Level Zero impl tracks allocations + pub fn mem_free(&self, ptr: *mut c_void) -> Result<()> { check! { sys::zeMemFree( - self.0, + self.as_ffi(), ptr, ) }; @@ -194,22 +250,32 @@ impl Context { impl Drop for Context { #[allow(unused_must_use)] fn drop(&mut self) { - check_panic! { sys::zeContextDestroy(self.0) }; + check_panic! { sys::zeContextDestroy(self.as_ffi()) }; } } +// Mutability: yes (residency container and others) +// Lifetime parent: Context #[repr(transparent)] -pub struct CommandQueue(sys::ze_command_queue_handle_t); +pub struct CommandQueue<'a>( + NonNull, + PhantomData<&'a ()>, +); -impl CommandQueue { +unsafe impl<'a> Send for CommandQueue<'a> {} + +impl<'a> CommandQueue<'a> { pub unsafe fn as_ffi(&self) -> sys::ze_command_queue_handle_t { - self.0 + self.0.as_ptr() } pub unsafe fn from_ffi(x: sys::ze_command_queue_handle_t) -> Self { - Self(x) + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x), PhantomData) } - pub fn new(ctx: &mut Context, d: &Device) -> Result { + pub fn new(ctx: &'a Context, d: Device) -> Result { let que_desc = sys::ze_command_queue_desc_t { stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, pNext: ptr::null(), @@ -221,48 +287,138 @@ impl CommandQueue { }; let mut result = ptr::null_mut(); check!(sys::zeCommandQueueCreate( - ctx.0, - d.0, + ctx.as_ffi(), + d.as_ffi(), &que_desc, &mut result )); - Ok(CommandQueue(result)) + Ok(unsafe { Self::from_ffi(result) }) } - pub fn execute<'a>(&'a self, cmd: CommandList) -> Result> { - check!(sys::zeCommandListClose(cmd.0)); - let result = FenceGuard::new(self, cmd.0)?; - let mut raw_cmd = cmd.0; - mem::forget(cmd); + pub fn execute_and_synchronize<'cmd_list>( + &'a self, + cmd: CommandList<'cmd_list>, + ) -> Result> + where + 'a: 'cmd_list, + { + let fence_guard = FenceGuard::new(self, cmd)?; + unsafe { self.execute(&fence_guard.1, Some(&fence_guard.0))? }; + Ok(fence_guard) + } + + pub unsafe fn execute<'cmd_list, 'fence>( + &self, + cmd: &CommandList<'cmd_list>, + fence: Option<&Fence<'fence>>, + ) -> Result<()> + where + 'cmd_list: 'fence, + 'a: 'cmd_list, + { + let fence_ptr = fence.map_or(ptr::null_mut(), |f| f.as_ffi()); check!(sys::zeCommandQueueExecuteCommandLists( - self.0, + self.as_ffi(), 1, - &mut raw_cmd, - result.0 + &mut cmd.as_ffi(), + fence_ptr )); - Ok(result) + Ok(()) } } -impl Drop for CommandQueue { +impl<'a> Drop for CommandQueue<'a> { #[allow(unused_must_use)] fn drop(&mut self) { - check_panic! { sys::zeCommandQueueDestroy(self.0) }; + check_panic! { sys::zeCommandQueueDestroy(self.as_ffi()) }; } } -pub struct Module(sys::ze_module_handle_t); +pub struct FenceGuard<'a>(Fence<'a>, CommandList<'a>); + +impl<'a> FenceGuard<'a> { + fn new(q: &'a CommandQueue, cmd_list: CommandList<'a>) -> Result { + Ok(FenceGuard(Fence::new(q)?, cmd_list)) + } +} + +impl<'a> Drop for FenceGuard<'a> { + #[allow(unused_must_use)] + fn drop(&mut self) { + if let Err(e) = self.0.host_synchronize() { + panic!(e) + } + } +} + +// Mutability: yes (reset) +// Lifetime parent: queue +#[repr(transparent)] +pub struct Fence<'a>(NonNull, PhantomData<&'a ()>); + +unsafe impl<'a> Send for Fence<'a> {} + +impl<'a> Fence<'a> { + pub unsafe fn as_ffi(&self) -> sys::ze_fence_handle_t { + self.0.as_ptr() + } + pub unsafe fn from_ffi(x: sys::ze_fence_handle_t) -> Self { + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x), PhantomData) + } + + pub fn new(queue: &'a CommandQueue) -> Result { + let desc = sys::_ze_fence_desc_t { + stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_FENCE_DESC, + pNext: ptr::null(), + flags: sys::ze_fence_flags_t(0), + }; + let mut result = ptr::null_mut(); + check!(sys::zeFenceCreate(queue.as_ffi(), &desc, &mut result)); + Ok(unsafe { Self::from_ffi(result) }) + } + + pub fn host_synchronize(&self) -> Result<()> { + check!(sys::zeFenceHostSynchronize(self.as_ffi(), u64::max_value())); + Ok(()) + } +} + +impl<'a> Drop for Fence<'a> { + fn drop(&mut self) { + check_panic! { zeFenceDestroy(self.as_ffi()) }; + } +} + +// Mutability: yes (building, linking) +// Lifetime parent: Context +#[repr(transparent)] +pub struct Module<'a>(NonNull, PhantomData<&'a ()>); + +unsafe impl<'a> Send for Module<'a> {} + +impl<'a> Module<'a> { + pub unsafe fn as_ffi(&self) -> sys::ze_module_handle_t { + self.0.as_ptr() + } + pub unsafe fn from_ffi(x: sys::ze_module_handle_t) -> Self { + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x), PhantomData) + } -impl Module { // HACK ALERT // We use OpenCL for now to do SPIR-V linking, because Level0 // does not allow linking. Don't let presence of zeModuleDynamicLink fool // you, it's not currently possible to create non-compiled modules. // zeModuleCreate always compiles (builds and links). - pub fn build_link_spirv<'a>( - ctx: &mut Context, - d: &Device, - binaries: &[&'a [u8]], + pub fn build_link_spirv<'buffers>( + ctx: &'a Context, + d: Device, + binaries: &[&'buffers [u8]], opts: Option<&CStr>, ) -> (Result, Option) { let ocl_program = match Self::build_link_spirv_impl(binaries, opts) { @@ -283,8 +439,8 @@ impl Module { } } - fn build_link_spirv_impl<'a>( - binaries: &[&'a [u8]], + fn build_link_spirv_impl<'buffers>( + binaries: &[&'buffers [u8]], opts: Option<&CStr>, ) -> ocl_core::Result { let platforms = ocl_core::get_platform_ids()?; @@ -348,8 +504,8 @@ impl Module { } pub fn build_spirv( - ctx: &mut Context, - d: &Device, + ctx: &'a Context, + d: Device, bin: &[u8], opts: Option<&CStr>, ) -> Result { @@ -357,8 +513,8 @@ impl Module { } pub fn build_spirv_logged( - ctx: &mut Context, - d: &Device, + ctx: &'a Context, + d: Device, bin: &[u8], opts: Option<&CStr>, ) -> (Result, BuildLog) { @@ -366,17 +522,17 @@ impl Module { } pub fn build_native_logged( - ctx: &mut Context, - d: &Device, + ctx: &'a Context, + d: Device, bin: &[u8], ) -> (Result, BuildLog) { Module::new_logged(ctx, false, d, bin, None) } fn new( - ctx: &mut Context, + ctx: &'a Context, spirv: bool, - d: &Device, + d: Device, bin: &[u8], opts: Option<&CStr>, ) -> Result { @@ -394,18 +550,22 @@ impl Module { pConstants: ptr::null(), }; let mut result: sys::ze_module_handle_t = ptr::null_mut(); - let err = unsafe { sys::zeModuleCreate(ctx.0, d.0, &desc, &mut result, ptr::null_mut()) }; - if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS { - Result::Err(err) - } else { - Ok(Module(result)) - } + check! { + sys::zeModuleCreate( + ctx.as_ffi(), + d.as_ffi(), + &desc, + &mut result, + ptr::null_mut(), + ) + }; + Ok(unsafe { Self::from_ffi(result) }) } fn new_logged( - ctx: &mut Context, + ctx: &'a Context, spirv: bool, - d: &Device, + d: Device, bin: &[u8], opts: Option<&CStr>, ) -> (Result, BuildLog) { @@ -424,74 +584,83 @@ impl Module { }; let mut result: sys::ze_module_handle_t = ptr::null_mut(); let mut log_handle = ptr::null_mut(); - let err = unsafe { sys::zeModuleCreate(ctx.0, d.0, &desc, &mut result, &mut log_handle) }; - let log = BuildLog(log_handle); - if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS { + let err = unsafe { + sys::zeModuleCreate( + ctx.as_ffi(), + d.as_ffi(), + &desc, + &mut result, + &mut log_handle, + ) + }; + let log = unsafe { BuildLog::from_ffi(log_handle) }; + if err != sys::ze_result_t::ZE_RESULT_SUCCESS { (Result::Err(err), log) } else { - (Ok(Module(result)), log) + (Ok(unsafe { Self::from_ffi(result) }), log) } } } -impl Drop for Module { +impl<'a> Drop for Module<'a> { #[allow(unused_must_use)] fn drop(&mut self) { - check_panic! { sys::zeModuleDestroy(self.0) }; + check_panic! { sys::zeModuleDestroy(self.as_ffi()) }; } } -pub struct BuildLog(sys::ze_module_build_log_handle_t); +// Mutability: none +// Lifetime parent: none, but need to destroy +pub struct BuildLog(NonNull); + +unsafe impl Sync for BuildLog {} +unsafe impl Send for BuildLog {} impl BuildLog { pub unsafe fn as_ffi(&self) -> sys::ze_module_build_log_handle_t { - self.0 + self.0.as_ptr() } pub unsafe fn from_ffi(x: sys::ze_module_build_log_handle_t) -> Self { - Self(x) + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x)) } - pub fn get_cstring(&self) -> Result { + pub fn to_cstring(&self) -> Result { let mut size = 0; - check! { sys::zeModuleBuildLogGetString(self.0, &mut size, ptr::null_mut()) }; + check! { sys::zeModuleBuildLogGetString(self.as_ffi(), &mut size, ptr::null_mut()) }; let mut str_vec = vec![0u8; size]; - check! { sys::zeModuleBuildLogGetString(self.0, &mut size, str_vec.as_mut_ptr() as *mut i8) }; - str_vec.pop(); - Ok(CString::new(str_vec).map_err(|_| sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?) + check! { sys::zeModuleBuildLogGetString(self.as_ffi(), &mut size, str_vec.as_mut_ptr() as *mut i8) }; + str_vec.push(0); + Ok(unsafe { CString::from_vec_unchecked(str_vec) }) } } impl Drop for BuildLog { fn drop(&mut self) { - check_panic!(sys::zeModuleBuildLogDestroy(self.0)); + check_panic!(sys::zeModuleBuildLogDestroy(self.as_ffi())); } } -pub trait SafeRepr {} -impl SafeRepr for u8 {} -impl SafeRepr for i8 {} -impl SafeRepr for u16 {} -impl SafeRepr for i16 {} -impl SafeRepr for u32 {} -impl SafeRepr for i32 {} -impl SafeRepr for u64 {} -impl SafeRepr for i64 {} -impl SafeRepr for f32 {} -impl SafeRepr for f64 {} - -pub struct DeviceBuffer { +// Mutability: none +// Lifetime parent: Context +pub struct DeviceBuffer<'a, T: Copy> { ptr: *mut c_void, ctx: sys::ze_context_handle_t, len: usize, - marker: PhantomData, + marker: PhantomData<&'a T>, } -impl DeviceBuffer { - pub unsafe fn as_ffi(&self) -> *mut c_void { - self.ptr +unsafe impl<'a, T: Copy> Sync for DeviceBuffer<'a, T> {} +unsafe impl<'a, T: Copy> Send for DeviceBuffer<'a, T> {} + +impl<'a, T: Copy> DeviceBuffer<'a, T> { + pub unsafe fn as_ffi(&self) -> (sys::ze_context_handle_t, *mut c_void, usize) { + (self.ctx, self.ptr, self.len) } pub unsafe fn from_ffi(ctx: sys::ze_context_handle_t, ptr: *mut c_void, len: usize) -> Self { - let marker = PhantomData::; + let marker = PhantomData::<&'a T>; Self { ptr, ctx, @@ -500,7 +669,7 @@ impl DeviceBuffer { } } - pub fn new(ctx: &mut Context, dev: &Device, len: usize) -> Result { + pub fn new(ctx: &'a Context, dev: Device, len: usize) -> Result { let desc = sys::_ze_device_mem_alloc_desc_t { stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, pNext: ptr::null(), @@ -509,39 +678,49 @@ impl DeviceBuffer { }; let mut result = ptr::null_mut(); check!(sys::zeMemAllocDevice( - ctx.0, + ctx.as_ffi(), &desc, len * mem::size_of::(), mem::align_of::(), - dev.0, + dev.as_ffi(), &mut result )); - Ok(unsafe { Self::from_ffi(ctx.0, result, len) }) + Ok(unsafe { Self::from_ffi(ctx.as_ffi(), result, len) }) } pub fn len(&self) -> usize { self.len } + + pub fn data(&self) -> *mut c_void { + self.ptr + } } -impl Drop for DeviceBuffer { - #[allow(unused_must_use)] +impl<'a, T: Copy> Drop for DeviceBuffer<'a, T> { fn drop(&mut self) { check_panic! { sys::zeMemFree(self.ctx, self.ptr) }; } } -pub struct CommandList<'a>(sys::ze_command_list_handle_t, PhantomData<&'a ()>); +// Mutability: yes (appends) +// Lifetime parent: Context +pub struct CommandList<'a>(NonNull, PhantomData<&'a ()>); + +unsafe impl<'a> Send for CommandList<'a> {} impl<'a> CommandList<'a> { pub unsafe fn as_ffi(&self) -> sys::ze_command_list_handle_t { - self.0 + self.0.as_ptr() } pub unsafe fn from_ffi(x: sys::ze_command_list_handle_t) -> Self { - Self(x, PhantomData) + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x), PhantomData) } - pub fn new(ctx: &mut Context, dev: &Device) -> Result { + pub fn new(ctx: &'a Context, dev: Device) -> Result { let desc = sys::ze_command_list_desc_t { stype: sys::_ze_structure_type_t::ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC, commandQueueGroupOrdinal: 0, @@ -549,40 +728,46 @@ impl<'a> CommandList<'a> { flags: sys::ze_command_list_flags_t(0), }; let mut result: sys::ze_command_list_handle_t = ptr::null_mut(); - check!(sys::zeCommandListCreate(ctx.0, dev.0, &desc, &mut result)); - Ok(Self(result, PhantomData)) + check!(sys::zeCommandListCreate( + ctx.as_ffi(), + dev.as_ffi(), + &desc, + &mut result + )); + Ok(unsafe { Self::from_ffi(result) }) } - pub fn append_memory_copy< - T: 'a, - Dst: Into>, - Src: Into>, - >( - &mut self, + pub fn append_memory_copy<'event, T: 'a, Dst: Into>, Src: Into>>( + &'a self, dst: Dst, src: Src, - signal: Option<&mut Event<'a>>, - wait: &mut [Event<'a>], - ) -> Result<()> { + signal: Option<&Event<'event>>, + wait: &[Event<'event>], + ) -> Result<()> + where + 'event: 'a, + { let dst = dst.into(); let src = src.into(); let elements = std::cmp::min(dst.len(), src.len()); let length = elements * mem::size_of::(); - unsafe { self.append_memory_copy_unsafe(dst.get(), src.get(), length, signal, wait) } + unsafe { + self.append_memory_copy_unsafe(dst.as_mut_ptr(), src.as_ptr(), length, signal, wait) + } } pub unsafe fn append_memory_copy_unsafe( - &mut self, + &self, dst: *mut c_void, src: *const c_void, length: usize, - signal: Option<&mut Event<'a>>, - wait: &mut [Event<'a>], + signal: Option<&Event>, + wait: &[Event], ) -> Result<()> { - let signal_event = signal.map(|e| e.0).unwrap_or(ptr::null_mut()); + let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut()); let (wait_len, wait_ptr) = Event::raw_slice(wait); check!(sys::zeCommandListAppendMemoryCopy( - self.0, + self.as_ffi(), dst, src, length, @@ -593,20 +778,26 @@ impl<'a> CommandList<'a> { Ok(()) } - pub fn append_memory_fill( - &mut self, - dst: BufferPtrMut<'a, T>, + pub fn append_memory_fill<'event, T: 'a, Dst: Into>>( + &'a self, + dst: Dst, pattern: u8, - signal: Option<&mut Event<'a>>, - wait: &mut [Event<'a>], - ) -> Result<()> { + signal: Option<&Event<'event>>, + wait: &[Event<'event>], + ) -> Result<()> + where + 'event: 'a, + { + let dst = dst.into(); let raw_pattern = &pattern as *const u8 as *const _; - let signal_event = signal.map(|e| e.0).unwrap_or(ptr::null_mut()); + let signal_event = signal + .map(|e| unsafe { e.as_ffi() }) + .unwrap_or(ptr::null_mut()); let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) }; let byte_len = dst.len() * mem::size_of::(); check!(sys::zeCommandListAppendMemoryFill( - self.0, - dst.get(), + self.as_ffi(), + dst.as_mut_ptr(), raw_pattern, mem::size_of::(), byte_len, @@ -618,17 +809,17 @@ impl<'a> CommandList<'a> { } pub unsafe fn append_memory_fill_unsafe( - &mut self, + &self, dst: *mut c_void, pattern: &T, byte_size: usize, - signal: Option<&mut Event<'a>>, - wait: &mut [Event<'a>], + signal: Option<&Event>, + wait: &[Event], ) -> Result<()> { - let signal_event = signal.map(|e| e.0).unwrap_or(ptr::null_mut()); + let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut()); let (wait_len, wait_ptr) = Event::raw_slice(wait); check!(sys::zeCommandListAppendMemoryFill( - self.0, + self.as_ffi(), dst, pattern as *const T as *const _, mem::size_of::(), @@ -640,23 +831,29 @@ impl<'a> CommandList<'a> { Ok(()) } - pub fn append_launch_kernel( - &mut self, - kernel: &'a Kernel, + pub fn append_launch_kernel<'event, 'kernel>( + &'a self, + kernel: &'kernel Kernel, group_count: &[u32; 3], - signal: Option<&mut Event<'a>>, - wait: &mut [Event<'a>], - ) -> Result<()> { + signal: Option<&Event<'event>>, + wait: &[Event<'event>], + ) -> Result<()> + where + 'event: 'a, + 'kernel: 'a, + { let gr_count = sys::ze_group_count_t { groupCountX: group_count[0], groupCountY: group_count[1], groupCountZ: group_count[2], }; - let signal_event = signal.map(|e| e.0).unwrap_or(ptr::null_mut()); + let signal_event = signal + .map(|e| unsafe { e.as_ffi() }) + .unwrap_or(ptr::null_mut()); let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) }; check!(sys::zeCommandListAppendLaunchKernel( - self.0, - kernel.0, + self.as_ffi(), + kernel.as_ffi(), &gr_count, signal_event, wait_len, @@ -664,176 +861,129 @@ impl<'a> CommandList<'a> { )); Ok(()) } + + pub fn close(&self) -> Result<()> { + check!(sys::zeCommandListClose(self.as_ffi())); + Ok(()) + } } impl<'a> Drop for CommandList<'a> { #[allow(unused_must_use)] fn drop(&mut self) { - check_panic! { sys::zeCommandListDestroy(self.0) }; - } -} - -pub struct FenceGuard<'a>( - sys::ze_fence_handle_t, - sys::ze_command_list_handle_t, - PhantomData<&'a ()>, -); - -impl<'a> FenceGuard<'a> { - fn new(q: &'a CommandQueue, cmd_list: sys::ze_command_list_handle_t) -> Result { - let desc = sys::_ze_fence_desc_t { - stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_FENCE_DESC, - pNext: ptr::null(), - flags: sys::ze_fence_flags_t(0), - }; - let mut result = ptr::null_mut(); - check!(sys::zeFenceCreate(q.0, &desc, &mut result)); - Ok(FenceGuard(result, cmd_list, PhantomData)) - } -} - -impl<'a> Drop for FenceGuard<'a> { - #[allow(unused_must_use)] - fn drop(&mut self) { - check_panic! { sys::zeFenceHostSynchronize(self.0, u64::max_value()) }; - check_panic! { sys::zeFenceDestroy(self.0) }; - check_panic! { sys::zeCommandListDestroy(self.1) }; + check_panic! { sys::zeCommandListDestroy(self.as_ffi()) }; } } #[derive(Copy, Clone)] -pub struct BufferPtr<'a, T> { - ptr: *const c_void, - marker: PhantomData<&'a T>, - elems: usize, -} - -impl<'a, T> BufferPtr<'a, T> { - pub unsafe fn get(self) -> *const c_void { - return self.ptr; - } - - pub fn len(&self) -> usize { - self.elems - } -} - -impl<'a, T> From<&'a [T]> for BufferPtr<'a, T> { - fn from(s: &'a [T]) -> Self { - BufferPtr { - ptr: s.as_ptr() as *const _, - marker: PhantomData, - elems: s.len(), - } - } -} - -impl<'a, T: SafeRepr> From<&'a DeviceBuffer> for BufferPtr<'a, T> { - fn from(b: &'a DeviceBuffer) -> Self { - BufferPtr { - ptr: b.ptr as *const _, - marker: PhantomData, - elems: b.len(), - } - } -} - -#[derive(Copy, Clone)] -pub struct BufferPtrMut<'a, T> { +pub struct Slice<'a, T> { ptr: *mut c_void, - marker: PhantomData<&'a mut T>, - elems: usize, + len: usize, + marker: PhantomData<&'a T>, } -impl<'a, T> BufferPtrMut<'a, T> { - pub unsafe fn get(self) -> *mut c_void { - return self.ptr; +unsafe impl<'a, T> Send for Slice<'a, T> {} +unsafe impl<'a, T> Sync for Slice<'a, T> {} + +impl<'a, T> Slice<'a, T> { + pub unsafe fn new(ptr: *mut c_void, len: usize) -> Self { + Self { + ptr, + len, + marker: PhantomData, + } + } + + pub fn as_ptr(&self) -> *const c_void { + self.ptr + } + + pub fn as_mut_ptr(&self) -> *mut c_void { + self.ptr } pub fn len(&self) -> usize { - self.elems + self.len } } -impl<'a, T> From<&'a mut [T]> for BufferPtrMut<'a, T> { - fn from(s: &'a mut [T]) -> Self { - BufferPtrMut { - ptr: s.as_mut_ptr() as *mut _, +impl<'a, T> From<&'a [T]> for Slice<'a, T> { + fn from(s: &'a [T]) -> Self { + Slice { + ptr: s.as_ptr() as *mut _, + len: s.len(), marker: PhantomData, - elems: s.len(), } } } -impl<'a, T: SafeRepr> From<&'a mut DeviceBuffer> for BufferPtrMut<'a, T> { - fn from(b: &'a mut DeviceBuffer) -> Self { - BufferPtrMut { - ptr: b.ptr as *mut _, - marker: PhantomData, - elems: b.len(), - } - } -} - -impl<'a, T: SafeRepr> From> for BufferPtr<'a, T> { - fn from(b: BufferPtrMut<'a, T>) -> Self { - BufferPtr { +impl<'a, T: Copy> From<&'a DeviceBuffer<'a, T>> for Slice<'a, T> { + fn from(b: &'a DeviceBuffer<'a, T>) -> Self { + Slice { ptr: b.ptr, + len: b.len, marker: PhantomData, - elems: b.len(), } } } -pub struct EventPool<'a>(sys::ze_event_pool_handle_t, PhantomData<&'a ()>); + +// Mutability: yes (appends) +// Lifetime parent: Context +pub struct EventPool<'a>(NonNull, PhantomData<&'a ()>); impl<'a> EventPool<'a> { pub unsafe fn as_ffi(&self) -> sys::ze_event_pool_handle_t { - self.0 + self.0.as_ptr() } pub unsafe fn from_ffi(x: sys::ze_event_pool_handle_t) -> Self { - Self(x, PhantomData) + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x), PhantomData) } - pub fn new(ctx: &mut Context, count: u32, dev: Option<&[&'a Device]>) -> Result { + + pub fn new(ctx: &'a Context, count: u32, devs: Option<&[Device]>) -> Result { let desc = sys::ze_event_pool_desc_t { stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_EVENT_POOL_DESC, pNext: ptr::null(), flags: sys::ze_event_pool_flags_t(0), count: count, }; - let mut dev = dev.map(|d| d.iter().map(|d| d.0).collect::>()); - let dev_len = dev.as_ref().map_or(0, |d| d.len() as u32); - let dev_ptr = dev.as_mut().map_or(ptr::null_mut(), |d| d.as_mut_ptr()); + let (dev_len, dev_ptr) = devs.map_or((0, ptr::null_mut()), |devs| { + (devs.len(), devs.as_ptr() as *mut _) + }); let mut result = ptr::null_mut(); check!(sys::zeEventPoolCreate( - ctx.0, + ctx.as_ffi(), &desc, - dev_len, + dev_len as u32, dev_ptr, &mut result )); - Ok(Self(result, PhantomData)) + Ok(unsafe { Self::from_ffi(result) }) } } impl<'a> Drop for EventPool<'a> { - #[allow(unused_must_use)] fn drop(&mut self) { - check_panic! { sys::zeEventPoolDestroy(self.0) }; + check_panic! { sys::zeEventPoolDestroy(self.as_ffi()) }; } } -pub struct Event<'a>(sys::ze_event_handle_t, PhantomData<&'a ()>); +pub struct Event<'a>(NonNull, PhantomData<&'a ()>); impl<'a> Event<'a> { pub unsafe fn as_ffi(&self) -> sys::ze_event_handle_t { - self.0 + self.0.as_ptr() } - pub unsafe fn from_ffi(x: sys::ze_event_handle_t) -> Self { - Self(x, PhantomData) + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x), PhantomData) } - pub fn new(pool: &'a EventPool, index: u32) -> Result { + pub fn new(pool: &'a EventPool<'a>, index: u32) -> Result { let desc = sys::ze_event_desc_t { stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_EVENT_DESC, pNext: ptr::null(), @@ -842,36 +992,37 @@ impl<'a> Event<'a> { wait: sys::ze_event_scope_flags_t(0), }; let mut result = ptr::null_mut(); - check!(sys::zeEventCreate(pool.0, &desc, &mut result)); - Ok(Self(result, PhantomData)) + check!(sys::zeEventCreate(pool.as_ffi(), &desc, &mut result)); + Ok(unsafe { Self::from_ffi(result) }) } - unsafe fn raw_slice(e: &mut [Event]) -> (u32, *mut sys::ze_event_handle_t) { + unsafe fn raw_slice(e: &[Event]) -> (u32, *mut sys::ze_event_handle_t) { let ptr = if e.len() == 0 { - ptr::null_mut() + ptr::null() } else { - e.as_mut_ptr() + e.as_ptr() }; (e.len() as u32, ptr as *mut sys::ze_event_handle_t) } } impl<'a> Drop for Event<'a> { - #[allow(unused_must_use)] fn drop(&mut self) { - check_panic! { sys::zeEventDestroy(self.0) }; + check_panic! { sys::zeEventDestroy(self.as_ffi()) }; } } -pub struct Kernel<'a>(sys::ze_kernel_handle_t, PhantomData<&'a ()>); +pub struct Kernel<'a>(NonNull, PhantomData<&'a ()>); impl<'a> Kernel<'a> { pub unsafe fn as_ffi(&self) -> sys::ze_kernel_handle_t { - self.0 + self.0.as_ptr() } - pub unsafe fn from_ffi(x: sys::ze_kernel_handle_t) -> Self { - Self(x, PhantomData) + if x == ptr::null_mut() { + panic!("FFI handle can't be zero") + } + Self(NonNull::new_unchecked(x), PhantomData) } pub fn new_resident(module: &'a Module, name: &CStr) -> Result { @@ -882,26 +1033,23 @@ impl<'a> Kernel<'a> { pKernelName: name.as_ptr() as *const _, }; let mut result = ptr::null_mut(); - check!(sys::zeKernelCreate(module.0, &desc, &mut result)); - Ok(Self(result, PhantomData)) + check!(sys::zeKernelCreate(module.as_ffi(), &desc, &mut result)); + Ok(unsafe { Self::from_ffi(result) }) } - pub fn set_indirect_access( - &mut self, - flags: sys::ze_kernel_indirect_access_flags_t, - ) -> Result<()> { - check!(sys::zeKernelSetIndirectAccess(self.0, flags)); + pub fn set_indirect_access(&self, flags: sys::ze_kernel_indirect_access_flags_t) -> Result<()> { + check!(sys::zeKernelSetIndirectAccess(self.as_ffi(), flags)); Ok(()) } - pub fn set_arg_buffer>>( + pub fn set_arg_buffer>>( &self, index: u32, buff: Buff, ) -> Result<()> { - let ptr = unsafe { buff.into().get() }; + let ptr = buff.into().as_mut_ptr(); check!(sys::zeKernelSetArgumentValue( - self.0, + self.as_ffi(), index, mem::size_of::<*const ()>(), &ptr as *const _ as *const _, @@ -911,7 +1059,7 @@ impl<'a> Kernel<'a> { pub fn set_arg_scalar(&self, index: u32, value: &T) -> Result<()> { check!(sys::zeKernelSetArgumentValue( - self.0, + self.as_ffi(), index, mem::size_of::(), value as *const T as *const _, @@ -920,18 +1068,26 @@ impl<'a> Kernel<'a> { } pub unsafe fn set_arg_raw(&self, index: u32, size: usize, value: *const c_void) -> Result<()> { - check!(sys::zeKernelSetArgumentValue(self.0, index, size, value)); + check!(sys::zeKernelSetArgumentValue( + self.as_ffi(), + index, + size, + value + )); Ok(()) } pub fn set_group_size(&self, x: u32, y: u32, z: u32) -> Result<()> { - check!(sys::zeKernelSetGroupSize(self.0, x, y, z)); + check!(sys::zeKernelSetGroupSize(self.as_ffi(), x, y, z)); Ok(()) } pub fn get_properties(&self) -> Result> { let mut props = Box::new(unsafe { mem::zeroed::() }); - check!(sys::zeKernelGetProperties(self.0, props.as_mut() as *mut _)); + check!(sys::zeKernelGetProperties( + self.as_ffi(), + props.as_mut() as *mut _ + )); Ok(props) } } @@ -939,7 +1095,7 @@ impl<'a> Kernel<'a> { impl<'a> Drop for Kernel<'a> { #[allow(unused_must_use)] fn drop(&mut self) { - check_panic! { sys::zeKernelDestroy(self.0) }; + check_panic! { sys::zeKernelDestroy(self.as_ffi()) }; } } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 14d3284..94114db 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -201,8 +201,8 @@ impl error::Error for DisplayError {} fn test_ptx_assert< 'a, - Input: From + ze::SafeRepr + Debug + Copy + PartialEq, - Output: From + ze::SafeRepr + Debug + Copy + PartialEq, + Input: From + Debug + Copy + PartialEq, + Output: From + Debug + Copy + PartialEq, >( name: &str, ptx_text: &'a str, @@ -220,10 +220,7 @@ fn test_ptx_assert< Ok(()) } -fn run_spirv< - Input: From + ze::SafeRepr + Copy + Debug, - Output: From + ze::SafeRepr + Copy + Debug, ->( +fn run_spirv + Copy + Debug, Output: From + Copy + Debug>( name: &CStr, module: translate::Module, input: &[Input], @@ -242,25 +239,25 @@ fn run_spirv< .get(name.to_str().unwrap()) .map(|info| info.uses_shared_mem) .unwrap_or(false); - let mut result = vec![0u8.into(); output.len()]; + let result = vec![0u8.into(); output.len()]; { let mut drivers = ze::Driver::get()?; let drv = drivers.drain(0..1).next().unwrap(); - let mut ctx = ze::Context::new(&drv)?; let mut devices = drv.devices()?; let dev = devices.drain(0..1).next().unwrap(); - let queue = ze::CommandQueue::new(&mut ctx, &dev)?; + let ctx = ze::Context::new(drv, None)?; + let queue = ze::CommandQueue::new(&ctx, dev)?; let (module, maybe_log) = match module.should_link_ptx_impl { Some(ptx_impl) => ze::Module::build_link_spirv( - &mut ctx, - &dev, + &ctx, + dev, &[ptx_impl, byte_il], Some(module.build_options.as_c_str()), ), None => { let (module, log) = ze::Module::build_spirv_logged( - &mut ctx, - &dev, + &ctx, + dev, byte_il, Some(module.build_options.as_c_str()), ); @@ -271,38 +268,38 @@ fn run_spirv< Ok(m) => m, Err(err) => { let raw_err_string = maybe_log - .map(|log| log.get_cstring()) + .map(|log| log.to_cstring()) .transpose()? .unwrap_or(CString::default()); let err_string = raw_err_string.to_string_lossy(); panic!("{:?}\n{}", err, err_string); } }; - let mut kernel = ze::Kernel::new_resident(&module, name)?; + let kernel = ze::Kernel::new_resident(&module, name)?; kernel.set_indirect_access( ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE, )?; - let mut inp_b = ze::DeviceBuffer::::new(&mut ctx, &dev, cmp::max(input.len(), 1))?; - let mut out_b = ze::DeviceBuffer::::new(&mut ctx, &dev, cmp::max(output.len(), 1))?; - let inp_b_ptr_mut: ze::BufferPtrMut = (&mut inp_b).into(); - let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?; + let inp_b = ze::DeviceBuffer::::new(&ctx, dev, cmp::max(input.len(), 1))?; + let out_b = ze::DeviceBuffer::::new(&ctx, dev, cmp::max(output.len(), 1))?; + let event_pool = ze::EventPool::new(&ctx, 3, Some(&[dev]))?; let ev0 = ze::Event::new(&event_pool, 0)?; let ev1 = ze::Event::new(&event_pool, 1)?; - let mut ev2 = ze::Event::new(&event_pool, 2)?; - let mut cmd_list = ze::CommandList::new(&mut ctx, &dev)?; - let out_b_ptr_mut: ze::BufferPtrMut = (&mut out_b).into(); - let mut init_evs = [ev0, ev1]; - cmd_list.append_memory_copy(inp_b_ptr_mut, input, Some(&mut init_evs[0]), &mut [])?; - cmd_list.append_memory_fill(out_b_ptr_mut, 0, Some(&mut init_evs[1]), &mut [])?; - kernel.set_group_size(1, 1, 1)?; - kernel.set_arg_buffer(0, inp_b_ptr_mut)?; - kernel.set_arg_buffer(1, out_b_ptr_mut)?; - if use_shared_mem { - unsafe { kernel.set_arg_raw(2, 128, ptr::null())? }; + let ev2 = ze::Event::new(&event_pool, 2)?; + { + let cmd_list = ze::CommandList::new(&ctx, dev)?; + let init_evs = [ev0, ev1]; + cmd_list.append_memory_copy(&inp_b, input, Some(&init_evs[0]), &[])?; + cmd_list.append_memory_fill(&out_b, 0, Some(&init_evs[1]), &[])?; + kernel.set_group_size(1, 1, 1)?; + kernel.set_arg_buffer(0, &inp_b)?; + kernel.set_arg_buffer(1, &out_b)?; + if use_shared_mem { + unsafe { kernel.set_arg_raw(2, 128, ptr::null())? }; + } + cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &init_evs)?; + cmd_list.append_memory_copy(&*result, &out_b, None, &[ev2])?; + queue.execute_and_synchronize(cmd_list)?; } - cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?; - cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?; - queue.execute(cmd_list)?; } Ok(result) } diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 2d72460..5ef427e 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -98,8 +98,8 @@ pub struct ContextData { impl ContextData { pub fn new( - l0_ctx: &mut l0::Context, - l0_dev: &l0::Device, + l0_ctx: &'static l0::Context, + l0_dev: l0::Device, flags: c_uint, is_primary: bool, dev: *mut device::Device, @@ -137,7 +137,7 @@ pub fn create_v2( let dev_ptr = dev as *mut _; let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( &mut dev.l0_context, - &dev.base, + dev.base, flags, false, dev_ptr as *mut _, diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 29cac2d..63bf39f 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -18,7 +18,7 @@ pub struct Index(pub c_int); pub struct Device { pub index: Index, pub base: l0::Device, - pub default_queue: l0::CommandQueue, + pub default_queue: l0::CommandQueue<'static>, pub l0_context: l0::Context, pub primary_context: context::Context, properties: Option>, @@ -31,12 +31,13 @@ unsafe impl Send for Device {} impl Device { // Unsafe because it does not fully initalize primary_context + // and we transmute lifetimes left and right unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result { - let mut ctx = l0::Context::new(drv)?; - let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?; + let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?; + let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?; let primary_context = context::Context::new(context::ContextData::new( - &mut ctx, - &l0_dev, + mem::transmute(&ctx), + l0_dev, 0, true, ptr::null_mut(), @@ -58,20 +59,18 @@ impl Device { if let Some(ref prop) = self.properties { return Ok(prop); } - match self.base.get_properties() { - Ok(prop) => Ok(self.properties.get_or_insert(prop)), - Err(e) => Err(e), - } + let mut props = Default::default(); + self.base.get_properties(&mut props)?; + Ok(self.properties.get_or_insert(Box::new(props))) } fn get_image_properties(&mut self) -> l0::Result<&l0::sys::ze_device_image_properties_t> { if let Some(ref prop) = self.image_properties { return Ok(prop); } - match self.base.get_image_properties() { - Ok(prop) => Ok(self.image_properties.get_or_insert(prop)), - Err(e) => Err(e), - } + let mut props = Default::default(); + self.base.get_image_properties(&mut props)?; + Ok(self.image_properties.get_or_insert(Box::new(props))) } fn get_memory_properties(&mut self) -> l0::Result<&[l0::sys::ze_device_memory_properties_t]> { @@ -88,10 +87,9 @@ impl Device { if let Some(ref prop) = self.compute_properties { return Ok(prop); } - match self.base.get_compute_properties() { - Ok(prop) => Ok(self.compute_properties.get_or_insert(prop)), - Err(e) => Err(e), - } + let mut props = Default::default(); + self.base.get_compute_properties(&mut props)?; + Ok(self.compute_properties.get_or_insert(Box::new(props))) } pub fn late_init(&mut self) { @@ -351,7 +349,11 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> { } // TODO: add support if Level 0 exposes it -pub fn get_luid(luid: *mut c_char, dev_node_mask: *mut c_uint, _dev_idx: Index) -> Result<(), CUresult> { +pub fn get_luid( + luid: *mut c_char, + dev_node_mask: *mut c_uint, + _dev_idx: Index, +) -> Result<(), CUresult> { unsafe { ptr::write_bytes(luid, 0u8, 8) }; unsafe { *dev_node_mask = 0 }; Ok(()) diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs index 11f15e6..e236160 100644 --- a/zluda/src/impl/function.rs +++ b/zluda/src/impl/function.rs @@ -144,14 +144,14 @@ pub fn launch_kernel( func.base .set_group_size(block_dim_x, block_dim_y, block_dim_z)?; func.legacy_args.reset(); - let mut cmd_list = stream.command_list()?; + let cmd_list = stream.command_list()?; cmd_list.append_launch_kernel( &mut func.base, &[grid_dim_x, grid_dim_y, grid_dim_z], None, &mut [], )?; - stream.queue.execute(cmd_list)?; + stream.queue.execute_and_synchronize(cmd_list)?; Ok(()) })? } diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index f33a08c..5db6472 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -4,7 +4,7 @@ use std::{ffi::c_void, mem}; pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> { let ptr = GlobalState::lock_current_context(|ctx| { let dev = unsafe { &mut *ctx.device }; - Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?) + Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?) })??; unsafe { *dptr = ptr }; Ok(()) @@ -12,9 +12,9 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { - let mut cmd_list = stream.command_list()?; - unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?; - stream.queue.execute(cmd_list)?; + let cmd_list = stream.command_list()?; + unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut [])? }; + stream.queue.execute_and_synchronize(cmd_list)?; Ok::<_, CUresult>(()) })? } @@ -22,29 +22,29 @@ pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result< pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> { GlobalState::lock_current_context(|ctx| { let dev = unsafe { &mut *ctx.device }; - Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?) + Ok::<_, CUresult>(dev.l0_context.mem_free(ptr)?) }) .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)? } pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { - let mut cmd_list = stream.command_list()?; + let cmd_list = stream.command_list()?; unsafe { cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::() * n, None, &mut []) }?; - stream.queue.execute(cmd_list)?; + stream.queue.execute_and_synchronize(cmd_list)?; Ok::<_, CUresult>(()) })? } pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { - let mut cmd_list = stream.command_list()?; + let cmd_list = stream.command_list()?; unsafe { cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::() * n, None, &mut []) }?; - stream.queue.execute(cmd_list)?; + stream.queue.execute_and_synchronize(cmd_list)?; Ok::<_, CUresult>(()) })? } diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 98580f8..6268904 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -41,7 +41,7 @@ pub struct SpirvModule { } pub struct CompiledModule { - pub base: l0::Module, + pub base: l0::Module<'static>, pub kernels: HashMap>, } @@ -78,7 +78,11 @@ impl SpirvModule { }) } - pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result { + pub fn compile<'a>( + &self, + ctx: &'a l0::Context, + dev: l0::Device, + ) -> Result, CUresult> { let byte_il = unsafe { slice::from_raw_parts( self.binaries.as_ptr() as *const u8, @@ -86,13 +90,11 @@ impl SpirvModule { ) }; let l0_module = match self.should_link_ptx_impl { - None => { - l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())) - } + None => l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())), Some(ptx_impl) => { l0::Module::build_link_spirv( ctx, - &dev, + dev, &[ptx_impl, byte_il], Some(self.build_options.as_c_str()), ) @@ -119,7 +121,7 @@ pub fn get_function( hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Vacant(entry) => { let new_module = CompiledModule { - base: module.spirv.compile(&mut device.l0_context, &device.base)?, + base: module.spirv.compile(&mut device.l0_context, device.base)?, kernels: HashMap::new(), }; entry.insert(new_module) @@ -135,7 +137,7 @@ pub fn get_function( std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes()) }) .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?; - let mut kernel = + let kernel = l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?; kernel.set_indirect_access( l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE @@ -165,7 +167,7 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result< pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> { let module = GlobalState::lock_current_context(|ctx| { let device = unsafe { &mut *ctx.device }; - let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?; + let l0_module = spirv_data.compile(&device.l0_context, device.base)?; let mut device_binaries = HashMap::new(); let compiled_module = CompiledModule { base: l0_module, diff --git a/zluda/src/impl/stream.rs b/zluda/src/impl/stream.rs index e212dfc..0fafe92 100644 --- a/zluda/src/impl/stream.rs +++ b/zluda/src/impl/stream.rs @@ -33,11 +33,11 @@ impl HasLivenessCookie for StreamData { pub struct StreamData { pub context: *mut ContextData, - pub queue: l0::CommandQueue, + pub queue: l0::CommandQueue<'static>, } impl StreamData { - pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result { + pub fn new_unitialized(ctx: &'static l0::Context, dev: l0::Device) -> Result { Ok(StreamData { context: ptr::null_mut(), queue: l0::CommandQueue::new(ctx, dev)?, @@ -45,7 +45,7 @@ impl StreamData { } pub fn new(ctx: &mut ContextData) -> Result { let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context; - let l0_dev = &unsafe { &*ctx.device }.base; + let l0_dev = unsafe { &*ctx.device }.base; Ok(StreamData { context: ctx as *mut _, queue: l0::CommandQueue::new(l0_ctx, l0_dev)?, @@ -55,7 +55,7 @@ impl StreamData { pub fn command_list(&self) -> Result { let ctx = unsafe { &mut *self.context }; let dev = unsafe { &mut *ctx.device }; - l0::CommandList::new(&mut dev.l0_context, &dev.base) + l0::CommandList::new(&mut dev.l0_context, dev.base) } } diff --git a/zluda_ml/src/impl.rs b/zluda_ml/src/impl.rs index 75f3ca2..1068b00 100644 --- a/zluda_ml/src/impl.rs +++ b/zluda_ml/src/impl.rs @@ -127,7 +127,8 @@ pub(crate) fn system_get_driver_version( len: 0, }; for d in drivers { - let props = d.get_properties()?; + let mut props = Default::default(); + d.get_properties(&mut props)?; let driver_version = props.driverVersion; write!(&mut output_write, "{}", driver_version) .map_err(|_| nvmlReturn_t::NVML_ERROR_UNKNOWN)?;