diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 91386cd..ef44c9d 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -919,6 +919,19 @@ impl<'a> CommandList<'a> { }) } + pub unsafe fn append_barrier(&self, signal: Option<&Event>, wait: &[&Event]) -> Result<()> { + let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi()); + Event::with_raw_slice(wait, |wait_len, wait_ptr| { + check!(sys::zeCommandListAppendBarrier( + self.as_ffi(), + signal_event, + wait_len, + wait_ptr + )); + Ok(()) + }) + } + pub fn close(&self) -> Result<()> { check!(sys::zeCommandListClose(self.as_ffi())); Ok(()) @@ -1068,11 +1081,16 @@ impl<'a> EventPool<'a> { Self(NonNull::new_unchecked(x), PhantomData) } - pub fn new(ctx: &'a Context, count: u32, devs: Option<&[Device]>) -> Result { + pub fn new( + ctx: &'a Context, + flags: sys::ze_event_pool_flags_t, + count: u32, + devs: Option<&[Device]>, + ) -> Result { let desc = sys::ze_event_pool_desc_t { stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_EVENT_POOL_DESC, pNext: ptr::null(), - flags: sys::ze_event_pool_flags_t(0), + flags: flags, count: count, }; let (dev_len, dev_ptr) = devs.map_or((0, ptr::null_mut()), |devs| { @@ -1109,13 +1127,18 @@ impl<'a> Event<'a> { Self(NonNull::new_unchecked(x), PhantomData) } - pub fn new(pool: &'a EventPool<'a>, index: u32) -> Result { + pub fn new( + pool: &'a EventPool<'a>, + index: u32, + signal: sys::ze_event_scope_flags_t, + wait: sys::ze_event_scope_flags_t, + ) -> Result { let desc = sys::ze_event_desc_t { stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_EVENT_DESC, pNext: ptr::null(), index: index, - signal: sys::ze_event_scope_flags_t(0), - wait: sys::ze_event_scope_flags_t(0), + signal, + wait, }; let mut result = ptr::null_mut(); check!(sys::zeEventCreate(pool.as_ffi(), &desc, &mut result)); diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 4f8b252..226043f 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -283,10 +283,26 @@ fn run_spirv + Copy + Debug, Output: From + Copy + Debug + D )?; let inp_b = ze::DeviceBuffer::::new(&ctx, dev, cmp::max(input.len(), 1))?; let out_b = ze::DeviceBuffer::::new(&ctx, dev, cmp::max(output.len(), 1))?; - let event_pool = ze::EventPool::new(&ctx, 3, Some(&[dev]))?; - let ev0 = ze::Event::new(&event_pool, 0)?; - let ev1 = ze::Event::new(&event_pool, 1)?; - let ev2 = ze::Event::new(&event_pool, 2)?; + let event_pool = + ze::EventPool::new(&ctx, ze::sys::ze_event_pool_flags_t(0), 3, Some(&[dev]))?; + let ev0 = ze::Event::new( + &event_pool, + 0, + ze::sys::ze_event_scope_flags_t(0), + ze::sys::ze_event_scope_flags_t(0), + )?; + let ev1 = ze::Event::new( + &event_pool, + 1, + ze::sys::ze_event_scope_flags_t(0), + ze::sys::ze_event_scope_flags_t(0), + )?; + let ev2 = ze::Event::new( + &event_pool, + 2, + ze::sys::ze_event_scope_flags_t(0), + ze::sys::ze_event_scope_flags_t(0), + )?; { let init_evs = [&ev0, &ev1]; kernel.set_group_size(1, 1, 1)?; diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 9ea0874..8d7a465 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -102,9 +102,10 @@ impl ContextData { l0_dev: l0::Device, flags: c_uint, is_primary: bool, + host_event: (l0::Event<'static>, u64), dev: *mut device::Device, ) -> Result { - let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?; + let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev, host_event)?; Ok(ContextData { flags: AtomicU32::new(flags), device: dev, @@ -136,10 +137,11 @@ pub fn create_v2( let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| { let dev_ptr = dev as *mut _; let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( - &mut dev.l0_context, + &dev.l0_context, dev.base, flags, false, + dev.host_event_pool.get(dev.base, &dev.l0_context)?, dev_ptr as *mut _, )?)); ctx_box.late_init(); diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 0594252..e686f27 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -21,7 +21,8 @@ pub struct Device { pub default_queue: l0::CommandQueue<'static>, pub l0_context: l0::Context, pub primary_context: context::Context, - pub event_pool: DynamicEventPool, + pub device_event_pool: DynamicEventPool, + pub host_event_pool: DynamicEventPool, properties: Option>, image_properties: Option>, memory_properties: Option>, @@ -36,21 +37,36 @@ impl Device { unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result { let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?; let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?; + let mut host_event_pool = DynamicEventPool::new( + l0_dev, + transmute_lifetime(&ctx), + l0::sys::ze_event_pool_flags_t::ZE_EVENT_POOL_FLAG_HOST_VISIBLE, + l0::sys::ze_event_scope_flags_t::ZE_EVENT_SCOPE_FLAG_HOST, + )?; + let host_event = + transmute_lifetime_mut(&mut host_event_pool).get(l0_dev, transmute_lifetime(&ctx))?; let primary_context = context::Context::new(context::ContextData::new( - mem::transmute(&ctx), + transmute_lifetime(&ctx), l0_dev, 0, true, + host_event, ptr::null_mut(), )?); - let event_pool = DynamicEventPool::new(l0_dev, transmute_lifetime(&ctx))?; + let device_event_pool = DynamicEventPool::new( + l0_dev, + transmute_lifetime(&ctx), + l0::sys::ze_event_pool_flags_t(0), + l0::sys::ze_event_scope_flags_t(0), + )?; Ok(Self { index: Index(idx as c_int), base: l0_dev, default_queue: queue, l0_context: ctx, primary_context: primary_context, - event_pool, + device_event_pool, + host_event_pool, properties: None, image_properties: None, memory_properties: None, @@ -400,14 +416,23 @@ pub(crate) fn primary_ctx_release_v2(_dev_idx: Index) -> CUresult { pub struct DynamicEventPool { count: usize, + pool_flags: l0::sys::ze_event_pool_flags_t, + signal_flags: l0::sys::ze_event_scope_flags_t, events: Vec, } impl DynamicEventPool { - fn new(dev: l0::Device, ctx: &'static l0::Context) -> l0::Result { + fn new( + dev: l0::Device, + ctx: &'static l0::Context, + pool_flags: l0::sys::ze_event_pool_flags_t, + signal_flags: l0::sys::ze_event_scope_flags_t, + ) -> l0::Result { Ok(DynamicEventPool { count: 0, - events: vec![DynamicEventPoolEntry::new(dev, ctx)?], + pool_flags, + signal_flags, + events: vec![DynamicEventPoolEntry::new(dev, ctx, pool_flags)?], }) } @@ -420,14 +445,17 @@ impl DynamicEventPool { let events = unsafe { transmute_lifetime_mut(&mut self.events) }; let (global_idx, (ev, local_idx)) = { for (idx, entry) in self.events.iter_mut().enumerate() { - if let Some((ev, local_idx)) = entry.get()? { + if let Some((ev, local_idx)) = entry.get(self.signal_flags)? { let marker = (idx << 32) as u64 | local_idx as u64; return Ok((ev, marker)); } } - events.push(DynamicEventPoolEntry::new(dev, ctx)?); + events.push(DynamicEventPoolEntry::new(dev, ctx, self.pool_flags)?); let global_idx = (events.len() - 1) as u64; - (global_idx, events.last_mut().unwrap().get()?.unwrap()) + ( + global_idx, + events.last_mut().unwrap().get(self.signal_flags)?.unwrap(), + ) }; let marker = (global_idx << 32) | local_idx as u64; Ok((ev, marker)) @@ -452,10 +480,15 @@ struct DynamicEventPoolEntry { } impl DynamicEventPoolEntry { - fn new(dev: l0::Device, ctx: &'static l0::Context) -> l0::Result { + fn new( + dev: l0::Device, + ctx: &'static l0::Context, + flags: l0::sys::ze_event_pool_flags_t, + ) -> l0::Result { Ok(DynamicEventPoolEntry { event_pool: l0::EventPool::new( ctx, + flags, DYNAMIC_EVENT_POOL_ENTRY_SIZE as u32, Some(&[dev]), )?, @@ -463,7 +496,10 @@ impl DynamicEventPoolEntry { }) } - fn get(&'static mut self) -> l0::Result, u32)>> { + fn get( + &'static mut self, + signal: l0::sys::ze_event_scope_flags_t, + ) -> l0::Result, u32)>> { for (idx, value) in self.bit_map.iter_mut().enumerate() { let shift = first_index_of_zero_u64(*value); if shift == 64 { @@ -471,7 +507,12 @@ impl DynamicEventPoolEntry { } *value = *value | (1u64 << shift); let entry_index = (idx as u32 * 64u32) + shift; - let event = l0::Event::new(&self.event_pool, entry_index)?; + let event = l0::Event::new( + &self.event_pool, + entry_index, + signal, + l0::sys::ze_event_scope_flags_t(0), + )?; return Ok(Some((event, entry_index))); } Ok(None) diff --git a/zluda/src/impl/stream.rs b/zluda/src/impl/stream.rs index 59f8778..1de422b 100644 --- a/zluda/src/impl/stream.rs +++ b/zluda/src/impl/stream.rs @@ -38,28 +38,35 @@ pub struct StreamData { pub busy_events: VecDeque<(l0::Event<'static>, u64)>, // This could be a Vec, but I'd rather reuse earliest enqueued event not the one recently enqueued pub free_events: VecDeque<(l0::Event<'static>, u64)>, + pub synchronization_event: (l0::Event<'static>, u64), } impl StreamData { pub fn new_unitialized( ctx: &'static l0::Context, device: l0::Device, + host_event: (l0::Event<'static>, u64), ) -> Result { Ok(StreamData { context: ptr::null_mut(), cmd_list: l0::CommandList::new_immediate(ctx, device)?, busy_events: VecDeque::new(), free_events: VecDeque::new(), + synchronization_event: host_event, }) } pub fn new(ctx: &mut ContextData) -> Result { let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context; let device = unsafe { &*ctx.device }.base; + let synchronization_event = unsafe { &mut *ctx.device } + .host_event_pool + .get(device, l0_ctx)?; Ok(StreamData { context: ctx as *mut _, cmd_list: l0::CommandList::new_immediate(l0_ctx, device)?, busy_events: VecDeque::new(), free_events: VecDeque::new(), + synchronization_event, }) } @@ -98,9 +105,17 @@ impl StreamData { } pub fn synchronize(&mut self) -> l0::Result<()> { - if let Some((ev, _)) = self.busy_events.back() { - ev.host_synchronize(u64::MAX)?; - } + let empty = []; + let busy_event_arr = self.busy_events.back().map(|(ev, _)| [ev]); + let wait_events = busy_event_arr.as_ref().map_or(&empty[..], |arr| &arr[..]); + unsafe { + self.cmd_list + .append_barrier(Some(&self.synchronization_event.0), wait_events)? + }; + self.synchronization_event + .0 + .host_synchronize(u64::max_value())?; + self.synchronization_event.0.host_reset()?; self.reuse_all_finished_events()?; Ok(()) } @@ -114,7 +129,7 @@ impl StreamData { .pop_front() .map(|x| Ok(x)) .unwrap_or_else(|| { - let event_pool = unsafe { &mut (*(*self.context).device).event_pool }; + let event_pool = unsafe { &mut (*(*self.context).device).device_event_pool }; event_pool.get(l0_dev, l0_ctx) }) } @@ -126,8 +141,8 @@ impl Drop for StreamData { return; } for (_, marker) in self.busy_events.iter().chain(self.free_events.iter()) { - let event_pool = unsafe { &mut (*(*self.context).device).event_pool }; - event_pool.mark_as_free(*marker); + let device_event_pool = unsafe { &mut (*(*self.context).device).device_event_pool }; + device_event_pool.mark_as_free(*marker); } unsafe { (&mut *self.context).streams.remove(&(&mut *self as *mut _)) }; }