diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 88adfe6..f4cd0ae 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -737,129 +737,122 @@ impl<'a> CommandList<'a> { Ok(unsafe { Self::from_ffi(result) }) } - pub fn append_memory_copy<'event, T: 'a, Dst: Into>, Src: Into>>( - &'a self, + pub unsafe fn append_memory_copy< + 'dep, + T: 'a + 'dep + Copy + Sized, + Dst: Into>, + Src: Into>, + >( + &self, dst: Dst, src: Src, - signal: Option<&Event<'event>>, - wait: &[Event<'event>], - ) -> Result<()> - where - 'event: 'a, - { + signal: Option<&Event<'dep>>, + wait: &[&'dep Event<'dep>], + ) -> Result<()> { let dst = dst.into(); let src = src.into(); let elements = std::cmp::min(dst.len(), src.len()); let length = elements * mem::size_of::(); - unsafe { - self.append_memory_copy_unsafe(dst.as_mut_ptr(), src.as_ptr(), length, signal, wait) - } + self.append_memory_copy_raw(dst.as_mut_ptr(), src.as_ptr(), length, signal, wait) } - pub unsafe fn append_memory_copy_unsafe( + pub unsafe fn append_memory_copy_raw( &self, dst: *mut c_void, src: *const c_void, length: usize, signal: Option<&Event>, - wait: &[Event], + wait: &[&Event], ) -> Result<()> { - let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut()); - let (wait_len, wait_ptr) = Event::raw_slice(wait); - check!(sys::zeCommandListAppendMemoryCopy( - self.as_ffi(), - dst, - src, - length, - signal_event, - wait_len, - wait_ptr - )); - Ok(()) + let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi()); + Event::with_raw_slice(wait, |wait_len, wait_ptr| { + check!(sys::zeCommandListAppendMemoryCopy( + self.as_ffi(), + dst, + src, + length, + signal_event, + wait_len, + wait_ptr + )); + Ok(()) + }) } - pub fn append_memory_fill<'event, T: 'a, Dst: Into>>( + pub unsafe fn append_memory_fill<'dep, T: Copy + Sized + 'dep, Dst: Into>>( &'a self, dst: Dst, - pattern: u8, - signal: Option<&Event<'event>>, - wait: &[Event<'event>], - ) -> Result<()> - where - 'event: 'a, - { + pattern: &T, + signal: Option<&Event<'dep>>, + wait: &[&'dep Event<'dep>], + ) -> Result<()> { let dst = dst.into(); - let raw_pattern = &pattern as *const u8 as *const _; - let signal_event = signal - .map(|e| unsafe { e.as_ffi() }) - .unwrap_or(ptr::null_mut()); - let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) }; - let byte_len = dst.len() * mem::size_of::(); - check!(sys::zeCommandListAppendMemoryFill( - self.as_ffi(), - dst.as_mut_ptr(), - raw_pattern, - mem::size_of::(), - byte_len, - signal_event, - wait_len, - wait_ptr - )); - Ok(()) + let raw_pattern = pattern as *const _ as *const _; + let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi()); + Event::with_raw_slice(wait, |wait_len, wait_ptr| { + check!(sys::zeCommandListAppendMemoryFill( + self.as_ffi(), + dst.as_mut_ptr(), + raw_pattern, + mem::size_of::(), + dst.len() * mem::size_of::(), + signal_event, + wait_len, + wait_ptr + )); + Ok(()) + }) } - pub unsafe fn append_memory_fill_unsafe( + pub unsafe fn append_memory_fill_raw( &self, dst: *mut c_void, - pattern: &T, - byte_size: usize, + pattern: *mut c_void, + pattern_size: usize, + size: usize, signal: Option<&Event>, - wait: &[Event], + wait: &[&Event], ) -> Result<()> { - let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut()); - let (wait_len, wait_ptr) = Event::raw_slice(wait); - check!(sys::zeCommandListAppendMemoryFill( - self.as_ffi(), - dst, - pattern as *const T as *const _, - mem::size_of::(), - byte_size, - signal_event, - wait_len, - wait_ptr - )); - Ok(()) + let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi()); + Event::with_raw_slice(wait, |wait_len, wait_ptr| { + check!(sys::zeCommandListAppendMemoryFill( + self.as_ffi(), + dst, + pattern, + pattern_size, + size, + signal_event, + wait_len, + wait_ptr + )); + Ok(()) + }) } - pub fn append_launch_kernel<'event, 'kernel>( - &'a self, - kernel: &'kernel Kernel, + pub unsafe fn append_launch_kernel( + &self, + kernel: &Kernel, group_count: &[u32; 3], - signal: Option<&Event<'event>>, - wait: &[Event<'event>], - ) -> Result<()> - where - 'event: 'a, - 'kernel: 'a, - { + signal: Option<&Event>, + wait: &[&Event], + ) -> Result<()> { let gr_count = sys::ze_group_count_t { groupCountX: group_count[0], groupCountY: group_count[1], groupCountZ: group_count[2], }; - let signal_event = signal - .map(|e| unsafe { e.as_ffi() }) - .unwrap_or(ptr::null_mut()); - let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) }; - check!(sys::zeCommandListAppendLaunchKernel( - self.as_ffi(), - kernel.as_ffi(), - &gr_count, - signal_event, - wait_len, - wait_ptr, - )); - Ok(()) + let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi()); + Event::with_raw_slice(wait, |wait_len, wait_ptr| { + check!(sys::zeCommandListAppendLaunchKernel( + self.as_ffi(), + kernel.as_ffi(), + &gr_count, + signal_event, + wait_len, + wait_ptr, + )); + Ok(()) + }) } pub fn close(&self) -> Result<()> { @@ -875,17 +868,86 @@ impl<'a> Drop for CommandList<'a> { } } +pub struct CommandListBuilder<'a>(CommandList<'a>); + +unsafe impl<'a> Send for CommandListBuilder<'a> {} + +impl<'a> CommandListBuilder<'a> { + pub fn new(ctx: &'a Context, dev: Device) -> Result { + Ok(CommandListBuilder(CommandList::new(ctx, dev)?)) + } + + pub fn append_memory_copy< + 'dep, + 'result, + T: 'dep + Copy + Sized, + Dst: Into>, + Src: Into>, + >( + self, + dst: Dst, + src: Src, + signal: Option<&'dep Event<'dep>>, + wait: &[&'dep Event<'dep>], + ) -> Result> + where + 'a: 'result, + 'dep: 'result, + { + unsafe { self.0.append_memory_copy(dst, src, signal, wait) }?; + Ok(self) + } + + pub fn append_memory_fill<'dep, 'result, T: 'dep + Copy + Sized, Dst: Into>>( + self, + dst: Dst, + pattern: &T, + signal: Option<&Event<'dep>>, + wait: &[&'dep Event<'dep>], + ) -> Result> + where + 'a: 'result, + 'dep: 'result, + { + unsafe { self.0.append_memory_fill(dst, pattern, signal, wait) }?; + Ok(self) + } + + pub fn append_launch_kernel<'dep, 'result>( + self, + kernel: &'dep Kernel, + group_count: &[u32; 3], + signal: Option<&Event<'dep>>, + wait: &[&'dep Event<'dep>], + ) -> Result> + where + 'a: 'result, + 'dep: 'result, + { + unsafe { + self.0 + .append_launch_kernel(kernel, group_count, signal, wait) + }?; + Ok(self) + } + + pub fn execute(self, q: &'a CommandQueue<'a>) -> Result> { + self.0.close()?; + q.execute_and_synchronize(self.0) + } +} + #[derive(Copy, Clone)] -pub struct Slice<'a, T> { +pub struct Slice<'a, T: Copy + Sized> { ptr: *mut c_void, len: usize, marker: PhantomData<&'a T>, } -unsafe impl<'a, T> Send for Slice<'a, T> {} -unsafe impl<'a, T> Sync for Slice<'a, T> {} +unsafe impl<'a, T: Copy + Sized> Send for Slice<'a, T> {} +unsafe impl<'a, T: Copy + Sized> Sync for Slice<'a, T> {} -impl<'a, T> Slice<'a, T> { +impl<'a, T: Copy + Sized> Slice<'a, T> { pub unsafe fn new(ptr: *mut c_void, len: usize) -> Self { Self { ptr, @@ -907,7 +969,7 @@ impl<'a, T> Slice<'a, T> { } } -impl<'a, T> From<&'a [T]> for Slice<'a, T> { +impl<'a, T: Copy + Sized> From<&'a [T]> for Slice<'a, T> { fn from(s: &'a [T]) -> Self { Slice { ptr: s.as_ptr() as *mut _, @@ -917,7 +979,7 @@ impl<'a, T> From<&'a [T]> for Slice<'a, T> { } } -impl<'a, T: Copy> From<&'a DeviceBuffer<'a, T>> for Slice<'a, T> { +impl<'a, T: Copy + Sized> From<&'a DeviceBuffer<'a, T>> for Slice<'a, T> { fn from(b: &'a DeviceBuffer<'a, T>) -> Self { Slice { ptr: b.ptr, @@ -996,13 +1058,21 @@ impl<'a> Event<'a> { Ok(unsafe { Self::from_ffi(result) }) } - unsafe fn raw_slice(e: &[Event]) -> (u32, *mut sys::ze_event_handle_t) { - let ptr = if e.len() == 0 { - ptr::null() - } else { - e.as_ptr() + unsafe fn with_raw_slice<'x, T>( + events: &[&Event<'x>], + f: impl FnOnce(u32, *mut sys::ze_event_handle_t) -> T, + ) -> T { + let (ptr, ev_vec) = match events { + [] => (ptr::null_mut(), None), + [e] => (&e.0 as *const _ as *mut _, None), + _ => { + let mut ev_vec = events.iter().map(|e| e.as_ffi()).collect::>(); + (ev_vec.as_mut_ptr(), Some(ev_vec)) + } }; - (e.len() as u32, ptr as *mut sys::ze_event_handle_t) + let result = f(events.len() as u32, ptr); + drop(ev_vec); + result } } @@ -1042,7 +1112,7 @@ impl<'a> Kernel<'a> { Ok(()) } - pub fn set_arg_buffer>>( + pub fn set_arg_buffer>>( &self, index: u32, buff: Buff, diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 94114db..c9ed9b1 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -202,7 +202,7 @@ impl error::Error for DisplayError {} fn test_ptx_assert< 'a, Input: From + Debug + Copy + PartialEq, - Output: From + Debug + Copy + PartialEq, + Output: From + Debug + Copy + PartialEq + Default, >( name: &str, ptx_text: &'a str, @@ -220,7 +220,7 @@ fn test_ptx_assert< Ok(()) } -fn run_spirv + Copy + Debug, Output: From + Copy + Debug>( +fn run_spirv + Copy + Debug, Output: From + Copy + Debug + Default>( name: &CStr, module: translate::Module, input: &[Input], @@ -286,19 +286,19 @@ fn run_spirv + Copy + Debug, Output: From + Copy + Debug>( let ev1 = ze::Event::new(&event_pool, 1)?; let ev2 = ze::Event::new(&event_pool, 2)?; { - let cmd_list = ze::CommandList::new(&ctx, dev)?; - let init_evs = [ev0, ev1]; - cmd_list.append_memory_copy(&inp_b, input, Some(&init_evs[0]), &[])?; - cmd_list.append_memory_fill(&out_b, 0, Some(&init_evs[1]), &[])?; + let init_evs = [&ev0, &ev1]; kernel.set_group_size(1, 1, 1)?; kernel.set_arg_buffer(0, &inp_b)?; kernel.set_arg_buffer(1, &out_b)?; if use_shared_mem { unsafe { kernel.set_arg_raw(2, 128, ptr::null())? }; } - cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &init_evs)?; - cmd_list.append_memory_copy(&*result, &out_b, None, &[ev2])?; - queue.execute_and_synchronize(cmd_list)?; + ze::CommandListBuilder::new(&ctx, dev)? + .append_memory_copy(&inp_b, input, Some(&init_evs[0]), &[])? + .append_memory_fill(&out_b, &Output::default(), Some(&init_evs[1]), &[])? + .append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &init_evs)? + .append_memory_copy(&*result, &out_b, None, &[&ev2])? + .execute(&queue)?; } } Ok(result) diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs index e236160..bcb7bd6 100644 --- a/zluda/src/impl/function.rs +++ b/zluda/src/impl/function.rs @@ -145,12 +145,14 @@ pub fn launch_kernel( .set_group_size(block_dim_x, block_dim_y, block_dim_z)?; func.legacy_args.reset(); let cmd_list = stream.command_list()?; - cmd_list.append_launch_kernel( - &mut func.base, - &[grid_dim_x, grid_dim_y, grid_dim_z], - None, - &mut [], - )?; + unsafe { + cmd_list.append_launch_kernel( + &mut func.base, + &[grid_dim_x, grid_dim_y, grid_dim_z], + None, + &mut [], + )?; + } stream.queue.execute_and_synchronize(cmd_list)?; Ok(()) })? diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 5db6472..2a6236f 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -13,7 +13,7 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { let cmd_list = stream.command_list()?; - unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut [])? }; + unsafe { cmd_list.append_memory_copy_raw(dst, src, bytesize, None, &mut [])? }; stream.queue.execute_and_synchronize(cmd_list)?; Ok::<_, CUresult>(()) })? @@ -27,22 +27,36 @@ pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> { .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)? } -pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> { +pub(crate) fn set_d32_v2(dst: *mut c_void, mut ui: u32, n: usize) -> Result<(), CUresult> { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { let cmd_list = stream.command_list()?; unsafe { - cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::() * n, None, &mut []) + cmd_list.append_memory_fill_raw( + dst, + &mut ui as *mut _ as *mut _, + mem::size_of::(), + mem::size_of::() * n, + None, + &mut [], + ) }?; stream.queue.execute_and_synchronize(cmd_list)?; Ok::<_, CUresult>(()) })? } -pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> { +pub(crate) fn set_d8_v2(dst: *mut c_void, mut uc: u8, n: usize) -> Result<(), CUresult> { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { let cmd_list = stream.command_list()?; unsafe { - cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::() * n, None, &mut []) + cmd_list.append_memory_fill_raw( + dst, + &mut uc as *mut _ as *mut _, + mem::size_of::(), + mem::size_of::() * n, + None, + &mut [], + ) }?; stream.queue.execute_and_synchronize(cmd_list)?; Ok::<_, CUresult>(())