Prepare level zero and our compiler for global addressing

This commit is contained in:
Andrzej Janik 2021-06-20 12:13:40 +02:00
parent 2fc7af0434
commit e018de83ae
2 changed files with 44 additions and 3 deletions

View file

@ -600,6 +600,45 @@ impl<'a> Module<'a> {
(Ok(unsafe { Self::from_ffi(result) }), log)
}
}
pub fn get_global_pointer(&self, global_name: &CStr) -> Result<(usize, *mut c_void)> {
let slice = global_name.to_bytes_with_nul();
let mut result_size = 0;
let mut result_ptr = ptr::null_mut();
check!(sys::zeModuleGetGlobalPointer(
self.as_ffi(),
slice.as_ptr() as *const _,
&mut result_size,
&mut result_ptr,
));
Ok((result_size, result_ptr))
}
pub fn dynamic_link(modules: &[&Module]) -> Result<()> {
unsafe {
Self::with_raw_slice(modules, |num, ptr| {
check!(sys::zeModuleDynamicLink(num, ptr, ptr::null_mut()));
Ok(())
})
}
}
unsafe fn with_raw_slice<'x, T>(
modules: &[&Module<'x>],
f: impl FnOnce(u32, *mut sys::ze_module_handle_t) -> T,
) -> T {
let (ptr, mod_vec) = match modules {
[] => (ptr::null_mut(), None),
[e] => (&e.0 as *const _ as *mut _, None),
_ => {
let mut ev_vec = modules.iter().map(|e| e.as_ffi()).collect::<Vec<_>>();
(ev_vec.as_mut_ptr(), Some(ev_vec))
}
};
let result = f(modules.len() as u32, ptr);
drop(mod_vec);
result
}
}
impl<'a> Drop for Module<'a> {

View file

@ -553,9 +553,9 @@ fn emit_denorm_build_string(
}
}
if flush_over_preserve > 0 {
CString::new("-cl-denorms-are-zero").unwrap()
CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap()
} else {
CString::default()
CString::new("-ze-take-global-address").unwrap()
}
}
@ -4973,7 +4973,9 @@ impl PtxSpecialRegister {
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
spirv::BuiltIn::LocalInvocationId
}
PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => spirv::BuiltIn::WorkgroupSize,
PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => {
spirv::BuiltIn::EnqueuedWorkgroupSize
}
PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => spirv::BuiltIn::WorkgroupId,
PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => {
spirv::BuiltIn::NumWorkgroups