mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-09 09:39:57 +00:00
Clean up L0 use
This commit is contained in:
parent
88756f569b
commit
4b894abd96
4 changed files with 359 additions and 244 deletions
|
@ -1,6 +1,7 @@
|
|||
use crate::sys;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::{
|
||||
ffi::c_void,
|
||||
ffi::{c_void, CStr},
|
||||
fmt::{Debug, Display},
|
||||
marker::PhantomData,
|
||||
mem, ptr,
|
||||
|
@ -283,7 +284,10 @@ impl CommandQueue {
|
|||
let mut raw_cmd = cmd.0;
|
||||
mem::forget(cmd);
|
||||
check!(sys::zeCommandQueueExecuteCommandLists(
|
||||
self.0, 1, &mut raw_cmd, result.0
|
||||
self.0,
|
||||
1,
|
||||
&mut raw_cmd,
|
||||
result.0
|
||||
));
|
||||
Ok(result)
|
||||
}
|
||||
|
@ -360,6 +364,7 @@ impl SafeRepr for f64 {}
|
|||
pub struct DeviceBuffer<T: SafeRepr> {
|
||||
ptr: *mut c_void,
|
||||
driver: sys::ze_driver_handle_t,
|
||||
len: usize,
|
||||
marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
|
@ -367,11 +372,12 @@ impl<T: SafeRepr> DeviceBuffer<T> {
|
|||
pub unsafe fn as_ffi(&self) -> *mut c_void {
|
||||
self.ptr
|
||||
}
|
||||
pub unsafe fn from_ffi(driver: sys::ze_driver_handle_t, ptr: *mut c_void) -> Self {
|
||||
pub unsafe fn from_ffi(driver: sys::ze_driver_handle_t, ptr: *mut c_void, len: usize) -> Self {
|
||||
let marker = PhantomData::<T>;
|
||||
Self {
|
||||
ptr,
|
||||
driver,
|
||||
len,
|
||||
marker,
|
||||
}
|
||||
}
|
||||
|
@ -392,7 +398,11 @@ impl<T: SafeRepr> DeviceBuffer<T> {
|
|||
dev.0,
|
||||
&mut result
|
||||
));
|
||||
Ok(unsafe { Self::from_ffi(drv.0, result) })
|
||||
Ok(unsafe { Self::from_ffi(drv.0, result, len) })
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -403,28 +413,98 @@ impl<T: SafeRepr> Drop for DeviceBuffer<T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct CommandList(sys::ze_command_list_handle_t);
|
||||
pub struct CommandList<'a>(sys::ze_command_list_handle_t, PhantomData<&'a ()>);
|
||||
|
||||
impl CommandList {
|
||||
impl<'a> CommandList<'a> {
|
||||
pub unsafe fn as_ffi(&self) -> sys::ze_command_list_handle_t {
|
||||
self.0
|
||||
}
|
||||
pub unsafe fn from_ffi(x: sys::ze_command_list_handle_t) -> Self {
|
||||
Self(x)
|
||||
Self(x, PhantomData)
|
||||
}
|
||||
|
||||
pub fn new(dev: &Device) -> Result<Self> {
|
||||
let desc = sys::_ze_command_list_desc_t {
|
||||
let desc = sys::ze_command_list_desc_t {
|
||||
version: sys::ze_command_list_desc_version_t::ZE_COMMAND_LIST_DESC_VERSION_CURRENT,
|
||||
flags: sys::ze_command_list_flag_t::ZE_COMMAND_LIST_FLAG_NONE,
|
||||
};
|
||||
let mut result: sys::ze_command_list_handle_t = ptr::null_mut();
|
||||
check!(sys::zeCommandListCreate(dev.0, &desc, &mut result));
|
||||
Ok(Self(result))
|
||||
Ok(Self(result, PhantomData))
|
||||
}
|
||||
|
||||
pub fn append_memory_copy<
|
||||
T: 'a,
|
||||
Dst: Into<BufferPtrMut<'a, T>>,
|
||||
Src: Into<BufferPtr<'a, T>>,
|
||||
>(
|
||||
&mut self,
|
||||
dst: Dst,
|
||||
src: Src,
|
||||
length: Option<usize>,
|
||||
signal: Option<&Event<'a>>,
|
||||
) -> Result<()> {
|
||||
let dst = dst.into();
|
||||
let src = src.into();
|
||||
let elements = length.unwrap_or(std::cmp::max(dst.len(), src.len()));
|
||||
let event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
|
||||
check!(sys::zeCommandListAppendMemoryCopy(
|
||||
self.0,
|
||||
dst.get(),
|
||||
src.get(),
|
||||
elements * std::mem::size_of::<T>(),
|
||||
event,
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn append_memory_fill<T>(
|
||||
&mut self,
|
||||
dst: BufferPtrMut<'a, T>,
|
||||
pattern: T,
|
||||
signal: Option<&Event<'a>>,
|
||||
) -> Result<()> {
|
||||
let raw_pattern = &pattern as *const T as *const _;
|
||||
let event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
|
||||
let byte_len = dst.len() * mem::size_of::<T>();
|
||||
check!(sys::zeCommandListAppendMemoryFill(
|
||||
self.0,
|
||||
dst.get(),
|
||||
raw_pattern,
|
||||
mem::size_of::<T>(),
|
||||
byte_len,
|
||||
event,
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn append_launch_kernel(
|
||||
&mut self,
|
||||
kernel: &'a Kernel,
|
||||
group_count: &[u32; 3],
|
||||
signal: Option<&Event<'a>>,
|
||||
wait: &[&Event<'a>],
|
||||
) -> Result<()> {
|
||||
let gr_count = sys::ze_group_count_t {
|
||||
groupCountX: group_count[0],
|
||||
groupCountY: group_count[1],
|
||||
groupCountZ: group_count[2],
|
||||
};
|
||||
let event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
|
||||
let mut wait_ptrs = wait.iter().map(|e| e.0).collect::<Vec<_>>();
|
||||
check!(sys::zeCommandListAppendLaunchKernel(
|
||||
self.0,
|
||||
kernel.0,
|
||||
&gr_count,
|
||||
event,
|
||||
wait.len() as u32,
|
||||
wait_ptrs.as_mut_ptr(),
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CommandList {
|
||||
impl<'a> Drop for CommandList<'a> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeCommandListDestroy(self.0) };
|
||||
|
@ -457,3 +537,214 @@ impl<'a> Drop for FenceGuard<'a> {
|
|||
unsafe { sys::zeCommandListDestroy(self.1) };
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct BufferPtr<'a, T> {
|
||||
ptr: *const c_void,
|
||||
marker: PhantomData<&'a T>,
|
||||
elems: usize,
|
||||
}
|
||||
|
||||
impl<'a, T> BufferPtr<'a, T> {
|
||||
pub unsafe fn get(self) -> *const c_void {
|
||||
return self.ptr;
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.elems
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a [T]> for BufferPtr<'a, T> {
|
||||
fn from(s: &'a [T]) -> Self {
|
||||
BufferPtr {
|
||||
ptr: s.as_ptr() as *const _,
|
||||
marker: PhantomData,
|
||||
elems: s.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: SafeRepr> From<&'a DeviceBuffer<T>> for BufferPtr<'a, T> {
|
||||
fn from(b: &'a DeviceBuffer<T>) -> Self {
|
||||
BufferPtr {
|
||||
ptr: b.ptr as *const _,
|
||||
marker: PhantomData,
|
||||
elems: b.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct BufferPtrMut<'a, T> {
|
||||
ptr: *mut c_void,
|
||||
marker: PhantomData<&'a mut T>,
|
||||
elems: usize,
|
||||
}
|
||||
|
||||
impl<'a, T> BufferPtrMut<'a, T> {
|
||||
pub unsafe fn get(self) -> *mut c_void {
|
||||
return self.ptr;
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.elems
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a mut [T]> for BufferPtrMut<'a, T> {
|
||||
fn from(s: &'a mut [T]) -> Self {
|
||||
BufferPtrMut {
|
||||
ptr: s.as_mut_ptr() as *mut _,
|
||||
marker: PhantomData,
|
||||
elems: s.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: SafeRepr> From<&'a mut DeviceBuffer<T>> for BufferPtrMut<'a, T> {
|
||||
fn from(b: &'a mut DeviceBuffer<T>) -> Self {
|
||||
BufferPtrMut {
|
||||
ptr: b.ptr as *mut _,
|
||||
marker: PhantomData,
|
||||
elems: b.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: SafeRepr> From<BufferPtrMut<'a, T>> for BufferPtr<'a, T> {
|
||||
fn from(b: BufferPtrMut<'a, T>) -> Self {
|
||||
BufferPtr {
|
||||
ptr: b.ptr,
|
||||
marker: PhantomData,
|
||||
elems: b.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub struct EventPool<'a>(sys::ze_event_pool_handle_t, PhantomData<&'a ()>);
|
||||
|
||||
impl<'a> EventPool<'a> {
|
||||
pub unsafe fn as_ffi(&self) -> sys::ze_event_pool_handle_t {
|
||||
self.0
|
||||
}
|
||||
pub unsafe fn from_ffi(x: sys::ze_event_pool_handle_t) -> Self {
|
||||
Self(x, PhantomData)
|
||||
}
|
||||
pub fn new(driver: &Driver, count: u32, dev: Option<&[&'a Device]>) -> Result<Self> {
|
||||
let desc = sys::ze_event_pool_desc_t {
|
||||
version: sys::ze_event_pool_desc_version_t::ZE_EVENT_POOL_DESC_VERSION_CURRENT,
|
||||
flags: sys::ze_event_pool_flag_t::ZE_EVENT_POOL_FLAG_DEFAULT,
|
||||
count: count,
|
||||
};
|
||||
let mut dev = dev.map(|d| d.iter().map(|d| d.0).collect::<Vec<_>>());
|
||||
let dev_len = dev.as_ref().map_or(0, |d| d.len() as u32);
|
||||
let dev_ptr = dev.as_mut().map_or(ptr::null_mut(), |d| d.as_mut_ptr());
|
||||
let mut result = ptr::null_mut();
|
||||
check!(sys::zeEventPoolCreate(
|
||||
driver.0,
|
||||
&desc,
|
||||
dev_len,
|
||||
dev_ptr,
|
||||
&mut result
|
||||
));
|
||||
Ok(Self(result, PhantomData))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for EventPool<'a> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeEventPoolDestroy(self.0) };
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Event<'a>(sys::ze_event_handle_t, PhantomData<&'a ()>);
|
||||
|
||||
impl<'a> Event<'a> {
|
||||
pub unsafe fn as_ffi(&self) -> sys::ze_event_handle_t {
|
||||
self.0
|
||||
}
|
||||
|
||||
pub unsafe fn from_ffi(x: sys::ze_event_handle_t) -> Self {
|
||||
Self(x, PhantomData)
|
||||
}
|
||||
|
||||
pub fn new(pool: &'a EventPool, index: u32) -> Result<Self> {
|
||||
let desc = sys::ze_event_desc_t {
|
||||
version: sys::ze_event_desc_version_t::ZE_EVENT_DESC_VERSION_CURRENT,
|
||||
index: index,
|
||||
signal: sys::ze_event_scope_flag_t::ZE_EVENT_SCOPE_FLAG_NONE,
|
||||
wait: sys::ze_event_scope_flag_t::ZE_EVENT_SCOPE_FLAG_NONE,
|
||||
};
|
||||
let mut result = ptr::null_mut();
|
||||
check!(sys::zeEventCreate(pool.0, &desc, &mut result));
|
||||
Ok(Self(result, PhantomData))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for Event<'a> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeEventDestroy(self.0) };
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Kernel<'a>(sys::ze_kernel_handle_t, PhantomData<&'a ()>);
|
||||
|
||||
impl<'a> Kernel<'a> {
|
||||
pub unsafe fn as_ffi(&self) -> sys::ze_kernel_handle_t {
|
||||
self.0
|
||||
}
|
||||
|
||||
pub unsafe fn from_ffi(x: sys::ze_kernel_handle_t) -> Self {
|
||||
Self(x, PhantomData)
|
||||
}
|
||||
|
||||
pub fn new(module: &'a Module, name: &CStr) -> Result<Self> {
|
||||
let desc = sys::ze_kernel_desc_t {
|
||||
version: sys::ze_kernel_desc_version_t::ZE_KERNEL_DESC_VERSION_CURRENT,
|
||||
flags: sys::ze_kernel_flag_t::ZE_KERNEL_FLAG_NONE,
|
||||
pKernelName: name.as_ptr() as *const _,
|
||||
};
|
||||
let mut result = ptr::null_mut();
|
||||
check!(sys::zeKernelCreate(module.0, &desc, &mut result));
|
||||
Ok(Self(result, PhantomData))
|
||||
}
|
||||
|
||||
pub fn set_arg_buffer<T: 'a, Buff: Into<BufferPtr<'a, T>>>(
|
||||
&self,
|
||||
index: u32,
|
||||
buff: Buff,
|
||||
) -> Result<()> {
|
||||
let ptr = unsafe { buff.into().get() };
|
||||
check!(sys::zeKernelSetArgumentValue(
|
||||
self.0,
|
||||
index,
|
||||
mem::size_of::<T>(),
|
||||
&ptr as *const _ as *const _,
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_arg_scalar<T: Copy>(&self, index: u32, value: &T) -> Result<()> {
|
||||
check!(sys::zeKernelSetArgumentValue(
|
||||
self.0,
|
||||
index,
|
||||
mem::size_of::<T>(),
|
||||
value as *const T as *const _,
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_group_size(&self, x: u32, y: u32, z: u32) -> Result<()> {
|
||||
check!(sys::zeKernelSetGroupSize(self.0, x, y, z));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for Kernel<'a> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeKernelDestroy(self.0) };
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue