Fix level zero bindings

This commit is contained in:
Andrzej Janik 2021-05-28 00:14:45 +02:00
commit 2fc7af0434
4 changed files with 209 additions and 123 deletions

View file

@ -737,35 +737,35 @@ impl<'a> CommandList<'a> {
Ok(unsafe { Self::from_ffi(result) }) Ok(unsafe { Self::from_ffi(result) })
} }
pub fn append_memory_copy<'event, T: 'a, Dst: Into<Slice<'a, T>>, Src: Into<Slice<'a, T>>>( pub unsafe fn append_memory_copy<
&'a self, 'dep,
T: 'a + 'dep + Copy + Sized,
Dst: Into<Slice<'dep, T>>,
Src: Into<Slice<'dep, T>>,
>(
&self,
dst: Dst, dst: Dst,
src: Src, src: Src,
signal: Option<&Event<'event>>, signal: Option<&Event<'dep>>,
wait: &[Event<'event>], wait: &[&'dep Event<'dep>],
) -> Result<()> ) -> Result<()> {
where
'event: 'a,
{
let dst = dst.into(); let dst = dst.into();
let src = src.into(); let src = src.into();
let elements = std::cmp::min(dst.len(), src.len()); let elements = std::cmp::min(dst.len(), src.len());
let length = elements * mem::size_of::<T>(); let length = elements * mem::size_of::<T>();
unsafe { self.append_memory_copy_raw(dst.as_mut_ptr(), src.as_ptr(), length, signal, wait)
self.append_memory_copy_unsafe(dst.as_mut_ptr(), src.as_ptr(), length, signal, wait)
}
} }
pub unsafe fn append_memory_copy_unsafe( pub unsafe fn append_memory_copy_raw(
&self, &self,
dst: *mut c_void, dst: *mut c_void,
src: *const c_void, src: *const c_void,
length: usize, length: usize,
signal: Option<&Event>, signal: Option<&Event>,
wait: &[Event], wait: &[&Event],
) -> Result<()> { ) -> Result<()> {
let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut()); let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi());
let (wait_len, wait_ptr) = Event::raw_slice(wait); Event::with_raw_slice(wait, |wait_len, wait_ptr| {
check!(sys::zeCommandListAppendMemoryCopy( check!(sys::zeCommandListAppendMemoryCopy(
self.as_ffi(), self.as_ffi(),
dst, dst,
@ -776,81 +776,73 @@ impl<'a> CommandList<'a> {
wait_ptr wait_ptr
)); ));
Ok(()) Ok(())
})
} }
pub fn append_memory_fill<'event, T: 'a, Dst: Into<Slice<'a, T>>>( pub unsafe fn append_memory_fill<'dep, T: Copy + Sized + 'dep, Dst: Into<Slice<'dep, T>>>(
&'a self, &'a self,
dst: Dst, dst: Dst,
pattern: u8, pattern: &T,
signal: Option<&Event<'event>>, signal: Option<&Event<'dep>>,
wait: &[Event<'event>], wait: &[&'dep Event<'dep>],
) -> Result<()> ) -> Result<()> {
where
'event: 'a,
{
let dst = dst.into(); let dst = dst.into();
let raw_pattern = &pattern as *const u8 as *const _; let raw_pattern = pattern as *const _ as *const _;
let signal_event = signal let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi());
.map(|e| unsafe { e.as_ffi() }) Event::with_raw_slice(wait, |wait_len, wait_ptr| {
.unwrap_or(ptr::null_mut());
let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) };
let byte_len = dst.len() * mem::size_of::<T>();
check!(sys::zeCommandListAppendMemoryFill( check!(sys::zeCommandListAppendMemoryFill(
self.as_ffi(), self.as_ffi(),
dst.as_mut_ptr(), dst.as_mut_ptr(),
raw_pattern, raw_pattern,
mem::size_of::<u8>(), mem::size_of::<T>(),
byte_len, dst.len() * mem::size_of::<T>(),
signal_event, signal_event,
wait_len, wait_len,
wait_ptr wait_ptr
)); ));
Ok(()) Ok(())
})
} }
pub unsafe fn append_memory_fill_unsafe<T: Copy + Sized>( pub unsafe fn append_memory_fill_raw(
&self, &self,
dst: *mut c_void, dst: *mut c_void,
pattern: &T, pattern: *mut c_void,
byte_size: usize, pattern_size: usize,
size: usize,
signal: Option<&Event>, signal: Option<&Event>,
wait: &[Event], wait: &[&Event],
) -> Result<()> { ) -> Result<()> {
let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut()); let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi());
let (wait_len, wait_ptr) = Event::raw_slice(wait); Event::with_raw_slice(wait, |wait_len, wait_ptr| {
check!(sys::zeCommandListAppendMemoryFill( check!(sys::zeCommandListAppendMemoryFill(
self.as_ffi(), self.as_ffi(),
dst, dst,
pattern as *const T as *const _, pattern,
mem::size_of::<T>(), pattern_size,
byte_size, size,
signal_event, signal_event,
wait_len, wait_len,
wait_ptr wait_ptr
)); ));
Ok(()) Ok(())
})
} }
pub fn append_launch_kernel<'event, 'kernel>( pub unsafe fn append_launch_kernel(
&'a self, &self,
kernel: &'kernel Kernel, kernel: &Kernel,
group_count: &[u32; 3], group_count: &[u32; 3],
signal: Option<&Event<'event>>, signal: Option<&Event>,
wait: &[Event<'event>], wait: &[&Event],
) -> Result<()> ) -> Result<()> {
where
'event: 'a,
'kernel: 'a,
{
let gr_count = sys::ze_group_count_t { let gr_count = sys::ze_group_count_t {
groupCountX: group_count[0], groupCountX: group_count[0],
groupCountY: group_count[1], groupCountY: group_count[1],
groupCountZ: group_count[2], groupCountZ: group_count[2],
}; };
let signal_event = signal let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi());
.map(|e| unsafe { e.as_ffi() }) Event::with_raw_slice(wait, |wait_len, wait_ptr| {
.unwrap_or(ptr::null_mut());
let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) };
check!(sys::zeCommandListAppendLaunchKernel( check!(sys::zeCommandListAppendLaunchKernel(
self.as_ffi(), self.as_ffi(),
kernel.as_ffi(), kernel.as_ffi(),
@ -860,6 +852,7 @@ impl<'a> CommandList<'a> {
wait_ptr, wait_ptr,
)); ));
Ok(()) Ok(())
})
} }
pub fn close(&self) -> Result<()> { pub fn close(&self) -> Result<()> {
@ -875,17 +868,86 @@ impl<'a> Drop for CommandList<'a> {
} }
} }
pub struct CommandListBuilder<'a>(CommandList<'a>);
unsafe impl<'a> Send for CommandListBuilder<'a> {}
impl<'a> CommandListBuilder<'a> {
pub fn new(ctx: &'a Context, dev: Device) -> Result<Self> {
Ok(CommandListBuilder(CommandList::new(ctx, dev)?))
}
pub fn append_memory_copy<
'dep,
'result,
T: 'dep + Copy + Sized,
Dst: Into<Slice<'dep, T>>,
Src: Into<Slice<'dep, T>>,
>(
self,
dst: Dst,
src: Src,
signal: Option<&'dep Event<'dep>>,
wait: &[&'dep Event<'dep>],
) -> Result<CommandListBuilder<'result>>
where
'a: 'result,
'dep: 'result,
{
unsafe { self.0.append_memory_copy(dst, src, signal, wait) }?;
Ok(self)
}
pub fn append_memory_fill<'dep, 'result, T: 'dep + Copy + Sized, Dst: Into<Slice<'dep, T>>>(
self,
dst: Dst,
pattern: &T,
signal: Option<&Event<'dep>>,
wait: &[&'dep Event<'dep>],
) -> Result<CommandListBuilder<'result>>
where
'a: 'result,
'dep: 'result,
{
unsafe { self.0.append_memory_fill(dst, pattern, signal, wait) }?;
Ok(self)
}
pub fn append_launch_kernel<'dep, 'result>(
self,
kernel: &'dep Kernel,
group_count: &[u32; 3],
signal: Option<&Event<'dep>>,
wait: &[&'dep Event<'dep>],
) -> Result<CommandListBuilder<'result>>
where
'a: 'result,
'dep: 'result,
{
unsafe {
self.0
.append_launch_kernel(kernel, group_count, signal, wait)
}?;
Ok(self)
}
pub fn execute(self, q: &'a CommandQueue<'a>) -> Result<FenceGuard<'a>> {
self.0.close()?;
q.execute_and_synchronize(self.0)
}
}
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct Slice<'a, T> { pub struct Slice<'a, T: Copy + Sized> {
ptr: *mut c_void, ptr: *mut c_void,
len: usize, len: usize,
marker: PhantomData<&'a T>, marker: PhantomData<&'a T>,
} }
unsafe impl<'a, T> Send for Slice<'a, T> {} unsafe impl<'a, T: Copy + Sized> Send for Slice<'a, T> {}
unsafe impl<'a, T> Sync for Slice<'a, T> {} unsafe impl<'a, T: Copy + Sized> Sync for Slice<'a, T> {}
impl<'a, T> Slice<'a, T> { impl<'a, T: Copy + Sized> Slice<'a, T> {
pub unsafe fn new(ptr: *mut c_void, len: usize) -> Self { pub unsafe fn new(ptr: *mut c_void, len: usize) -> Self {
Self { Self {
ptr, ptr,
@ -907,7 +969,7 @@ impl<'a, T> Slice<'a, T> {
} }
} }
impl<'a, T> From<&'a [T]> for Slice<'a, T> { impl<'a, T: Copy + Sized> From<&'a [T]> for Slice<'a, T> {
fn from(s: &'a [T]) -> Self { fn from(s: &'a [T]) -> Self {
Slice { Slice {
ptr: s.as_ptr() as *mut _, ptr: s.as_ptr() as *mut _,
@ -917,7 +979,7 @@ impl<'a, T> From<&'a [T]> for Slice<'a, T> {
} }
} }
impl<'a, T: Copy> From<&'a DeviceBuffer<'a, T>> for Slice<'a, T> { impl<'a, T: Copy + Sized> From<&'a DeviceBuffer<'a, T>> for Slice<'a, T> {
fn from(b: &'a DeviceBuffer<'a, T>) -> Self { fn from(b: &'a DeviceBuffer<'a, T>) -> Self {
Slice { Slice {
ptr: b.ptr, ptr: b.ptr,
@ -996,13 +1058,21 @@ impl<'a> Event<'a> {
Ok(unsafe { Self::from_ffi(result) }) Ok(unsafe { Self::from_ffi(result) })
} }
unsafe fn raw_slice(e: &[Event]) -> (u32, *mut sys::ze_event_handle_t) { unsafe fn with_raw_slice<'x, T>(
let ptr = if e.len() == 0 { events: &[&Event<'x>],
ptr::null() f: impl FnOnce(u32, *mut sys::ze_event_handle_t) -> T,
} else { ) -> T {
e.as_ptr() let (ptr, ev_vec) = match events {
[] => (ptr::null_mut(), None),
[e] => (&e.0 as *const _ as *mut _, None),
_ => {
let mut ev_vec = events.iter().map(|e| e.as_ffi()).collect::<Vec<_>>();
(ev_vec.as_mut_ptr(), Some(ev_vec))
}
}; };
(e.len() as u32, ptr as *mut sys::ze_event_handle_t) let result = f(events.len() as u32, ptr);
drop(ev_vec);
result
} }
} }
@ -1042,7 +1112,7 @@ impl<'a> Kernel<'a> {
Ok(()) Ok(())
} }
pub fn set_arg_buffer<T: 'a, Buff: Into<Slice<'a, T>>>( pub fn set_arg_buffer<T: 'a + Copy + Sized, Buff: Into<Slice<'a, T>>>(
&self, &self,
index: u32, index: u32,
buff: Buff, buff: Buff,

View file

@ -202,7 +202,7 @@ impl<T: Debug> error::Error for DisplayError<T> {}
fn test_ptx_assert< fn test_ptx_assert<
'a, 'a,
Input: From<u8> + Debug + Copy + PartialEq, Input: From<u8> + Debug + Copy + PartialEq,
Output: From<u8> + Debug + Copy + PartialEq, Output: From<u8> + Debug + Copy + PartialEq + Default,
>( >(
name: &str, name: &str,
ptx_text: &'a str, ptx_text: &'a str,
@ -220,7 +220,7 @@ fn test_ptx_assert<
Ok(()) Ok(())
} }
fn run_spirv<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug>( fn run_spirv<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Default>(
name: &CStr, name: &CStr,
module: translate::Module, module: translate::Module,
input: &[Input], input: &[Input],
@ -286,19 +286,19 @@ fn run_spirv<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug>(
let ev1 = ze::Event::new(&event_pool, 1)?; let ev1 = ze::Event::new(&event_pool, 1)?;
let ev2 = ze::Event::new(&event_pool, 2)?; let ev2 = ze::Event::new(&event_pool, 2)?;
{ {
let cmd_list = ze::CommandList::new(&ctx, dev)?; let init_evs = [&ev0, &ev1];
let init_evs = [ev0, ev1];
cmd_list.append_memory_copy(&inp_b, input, Some(&init_evs[0]), &[])?;
cmd_list.append_memory_fill(&out_b, 0, Some(&init_evs[1]), &[])?;
kernel.set_group_size(1, 1, 1)?; kernel.set_group_size(1, 1, 1)?;
kernel.set_arg_buffer(0, &inp_b)?; kernel.set_arg_buffer(0, &inp_b)?;
kernel.set_arg_buffer(1, &out_b)?; kernel.set_arg_buffer(1, &out_b)?;
if use_shared_mem { if use_shared_mem {
unsafe { kernel.set_arg_raw(2, 128, ptr::null())? }; unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
} }
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &init_evs)?; ze::CommandListBuilder::new(&ctx, dev)?
cmd_list.append_memory_copy(&*result, &out_b, None, &[ev2])?; .append_memory_copy(&inp_b, input, Some(&init_evs[0]), &[])?
queue.execute_and_synchronize(cmd_list)?; .append_memory_fill(&out_b, &Output::default(), Some(&init_evs[1]), &[])?
.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &init_evs)?
.append_memory_copy(&*result, &out_b, None, &[&ev2])?
.execute(&queue)?;
} }
} }
Ok(result) Ok(result)

View file

@ -145,12 +145,14 @@ pub fn launch_kernel(
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?; .set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
func.legacy_args.reset(); func.legacy_args.reset();
let cmd_list = stream.command_list()?; let cmd_list = stream.command_list()?;
unsafe {
cmd_list.append_launch_kernel( cmd_list.append_launch_kernel(
&mut func.base, &mut func.base,
&[grid_dim_x, grid_dim_y, grid_dim_z], &[grid_dim_x, grid_dim_y, grid_dim_z],
None, None,
&mut [], &mut [],
)?; )?;
}
stream.queue.execute_and_synchronize(cmd_list)?; stream.queue.execute_and_synchronize(cmd_list)?;
Ok(()) Ok(())
})? })?

View file

@ -13,7 +13,7 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult>
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> { pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let cmd_list = stream.command_list()?; let cmd_list = stream.command_list()?;
unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut [])? }; unsafe { cmd_list.append_memory_copy_raw(dst, src, bytesize, None, &mut [])? };
stream.queue.execute_and_synchronize(cmd_list)?; stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(()) Ok::<_, CUresult>(())
})? })?
@ -27,22 +27,36 @@ pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)? .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
} }
pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> { pub(crate) fn set_d32_v2(dst: *mut c_void, mut ui: u32, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let cmd_list = stream.command_list()?; let cmd_list = stream.command_list()?;
unsafe { unsafe {
cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::<u32>() * n, None, &mut []) cmd_list.append_memory_fill_raw(
dst,
&mut ui as *mut _ as *mut _,
mem::size_of::<u32>(),
mem::size_of::<u32>() * n,
None,
&mut [],
)
}?; }?;
stream.queue.execute_and_synchronize(cmd_list)?; stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(()) Ok::<_, CUresult>(())
})? })?
} }
pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> { pub(crate) fn set_d8_v2(dst: *mut c_void, mut uc: u8, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let cmd_list = stream.command_list()?; let cmd_list = stream.command_list()?;
unsafe { unsafe {
cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::<u8>() * n, None, &mut []) cmd_list.append_memory_fill_raw(
dst,
&mut uc as *mut _ as *mut _,
mem::size_of::<u8>(),
mem::size_of::<u8>() * n,
None,
&mut [],
)
}?; }?;
stream.queue.execute_and_synchronize(cmd_list)?; stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(()) Ok::<_, CUresult>(())