diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index f4cd0ae..703f2ce 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -600,6 +600,45 @@ impl<'a> Module<'a> { (Ok(unsafe { Self::from_ffi(result) }), log) } } + + pub fn get_global_pointer(&self, global_name: &CStr) -> Result<(usize, *mut c_void)> { + let slice = global_name.to_bytes_with_nul(); + let mut result_size = 0; + let mut result_ptr = ptr::null_mut(); + check!(sys::zeModuleGetGlobalPointer( + self.as_ffi(), + slice.as_ptr() as *const _, + &mut result_size, + &mut result_ptr, + )); + Ok((result_size, result_ptr)) + } + + pub fn dynamic_link(modules: &[&Module]) -> Result<()> { + unsafe { + Self::with_raw_slice(modules, |num, ptr| { + check!(sys::zeModuleDynamicLink(num, ptr, ptr::null_mut())); + Ok(()) + }) + } + } + + unsafe fn with_raw_slice<'x, T>( + modules: &[&Module<'x>], + f: impl FnOnce(u32, *mut sys::ze_module_handle_t) -> T, + ) -> T { + let (ptr, mod_vec) = match modules { + [] => (ptr::null_mut(), None), + [e] => (&e.0 as *const _ as *mut _, None), + _ => { + let mut ev_vec = modules.iter().map(|e| e.as_ffi()).collect::>(); + (ev_vec.as_mut_ptr(), Some(ev_vec)) + } + }; + let result = f(modules.len() as u32, ptr); + drop(mod_vec); + result + } } impl<'a> Drop for Module<'a> { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index e39280a..7170950 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -553,9 +553,9 @@ fn emit_denorm_build_string( } } if flush_over_preserve > 0 { - CString::new("-cl-denorms-are-zero").unwrap() + CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap() } else { - CString::default() + CString::new("-ze-take-global-address").unwrap() } } @@ -4973,7 +4973,9 @@ impl PtxSpecialRegister { PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { spirv::BuiltIn::LocalInvocationId } - PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => spirv::BuiltIn::WorkgroupSize, + PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => { + spirv::BuiltIn::EnqueuedWorkgroupSize + } PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => spirv::BuiltIn::WorkgroupId, PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => { spirv::BuiltIn::NumWorkgroups