Refactor L0 bindings

This commit is contained in:
Andrzej Janik 2021-05-27 02:05:17 +02:00
parent 58a7fe53c6
commit e40785aa74
9 changed files with 577 additions and 419 deletions

File diff suppressed because it is too large Load diff

View file

@ -201,8 +201,8 @@ impl<T: Debug> error::Error for DisplayError<T> {}
fn test_ptx_assert<
'a,
Input: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq,
Output: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq,
Input: From<u8> + Debug + Copy + PartialEq,
Output: From<u8> + Debug + Copy + PartialEq,
>(
name: &str,
ptx_text: &'a str,
@ -220,10 +220,7 @@ fn test_ptx_assert<
Ok(())
}
fn run_spirv<
Input: From<u8> + ze::SafeRepr + Copy + Debug,
Output: From<u8> + ze::SafeRepr + Copy + Debug,
>(
fn run_spirv<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug>(
name: &CStr,
module: translate::Module,
input: &[Input],
@ -242,25 +239,25 @@ fn run_spirv<
.get(name.to_str().unwrap())
.map(|info| info.uses_shared_mem)
.unwrap_or(false);
let mut result = vec![0u8.into(); output.len()];
let result = vec![0u8.into(); output.len()];
{
let mut drivers = ze::Driver::get()?;
let drv = drivers.drain(0..1).next().unwrap();
let mut ctx = ze::Context::new(&drv)?;
let mut devices = drv.devices()?;
let dev = devices.drain(0..1).next().unwrap();
let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
let ctx = ze::Context::new(drv, None)?;
let queue = ze::CommandQueue::new(&ctx, dev)?;
let (module, maybe_log) = match module.should_link_ptx_impl {
Some(ptx_impl) => ze::Module::build_link_spirv(
&mut ctx,
&dev,
&ctx,
dev,
&[ptx_impl, byte_il],
Some(module.build_options.as_c_str()),
),
None => {
let (module, log) = ze::Module::build_spirv_logged(
&mut ctx,
&dev,
&ctx,
dev,
byte_il,
Some(module.build_options.as_c_str()),
);
@ -271,38 +268,38 @@ fn run_spirv<
Ok(m) => m,
Err(err) => {
let raw_err_string = maybe_log
.map(|log| log.get_cstring())
.map(|log| log.to_cstring())
.transpose()?
.unwrap_or(CString::default());
let err_string = raw_err_string.to_string_lossy();
panic!("{:?}\n{}", err, err_string);
}
};
let mut kernel = ze::Kernel::new_resident(&module, name)?;
let kernel = ze::Kernel::new_resident(&module, name)?;
kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
)?;
let mut inp_b = ze::DeviceBuffer::<Input>::new(&mut ctx, &dev, cmp::max(input.len(), 1))?;
let mut out_b = ze::DeviceBuffer::<Output>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?;
let inp_b_ptr_mut: ze::BufferPtrMut<Input> = (&mut inp_b).into();
let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
let inp_b = ze::DeviceBuffer::<Input>::new(&ctx, dev, cmp::max(input.len(), 1))?;
let out_b = ze::DeviceBuffer::<Output>::new(&ctx, dev, cmp::max(output.len(), 1))?;
let event_pool = ze::EventPool::new(&ctx, 3, Some(&[dev]))?;
let ev0 = ze::Event::new(&event_pool, 0)?;
let ev1 = ze::Event::new(&event_pool, 1)?;
let mut ev2 = ze::Event::new(&event_pool, 2)?;
let mut cmd_list = ze::CommandList::new(&mut ctx, &dev)?;
let out_b_ptr_mut: ze::BufferPtrMut<Output> = (&mut out_b).into();
let mut init_evs = [ev0, ev1];
cmd_list.append_memory_copy(inp_b_ptr_mut, input, Some(&mut init_evs[0]), &mut [])?;
cmd_list.append_memory_fill(out_b_ptr_mut, 0, Some(&mut init_evs[1]), &mut [])?;
kernel.set_group_size(1, 1, 1)?;
kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
kernel.set_arg_buffer(1, out_b_ptr_mut)?;
if use_shared_mem {
unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
let ev2 = ze::Event::new(&event_pool, 2)?;
{
let cmd_list = ze::CommandList::new(&ctx, dev)?;
let init_evs = [ev0, ev1];
cmd_list.append_memory_copy(&inp_b, input, Some(&init_evs[0]), &[])?;
cmd_list.append_memory_fill(&out_b, 0, Some(&init_evs[1]), &[])?;
kernel.set_group_size(1, 1, 1)?;
kernel.set_arg_buffer(0, &inp_b)?;
kernel.set_arg_buffer(1, &out_b)?;
if use_shared_mem {
unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
}
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &init_evs)?;
cmd_list.append_memory_copy(&*result, &out_b, None, &[ev2])?;
queue.execute_and_synchronize(cmd_list)?;
}
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?;
cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?;
queue.execute(cmd_list)?;
}
Ok(result)
}

View file

@ -98,8 +98,8 @@ pub struct ContextData {
impl ContextData {
pub fn new(
l0_ctx: &mut l0::Context,
l0_dev: &l0::Device,
l0_ctx: &'static l0::Context,
l0_dev: l0::Device,
flags: c_uint,
is_primary: bool,
dev: *mut device::Device,
@ -137,7 +137,7 @@ pub fn create_v2(
let dev_ptr = dev as *mut _;
let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
&mut dev.l0_context,
&dev.base,
dev.base,
flags,
false,
dev_ptr as *mut _,

View file

@ -18,7 +18,7 @@ pub struct Index(pub c_int);
pub struct Device {
pub index: Index,
pub base: l0::Device,
pub default_queue: l0::CommandQueue,
pub default_queue: l0::CommandQueue<'static>,
pub l0_context: l0::Context,
pub primary_context: context::Context,
properties: Option<Box<l0::sys::ze_device_properties_t>>,
@ -31,12 +31,13 @@ unsafe impl Send for Device {}
impl Device {
// Unsafe because it does not fully initialize primary_context
// and we transmute lifetimes left and right
unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
let mut ctx = l0::Context::new(drv)?;
let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?;
let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?;
let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?;
let primary_context = context::Context::new(context::ContextData::new(
&mut ctx,
&l0_dev,
mem::transmute(&ctx),
l0_dev,
0,
true,
ptr::null_mut(),
@ -58,20 +59,18 @@ impl Device {
if let Some(ref prop) = self.properties {
return Ok(prop);
}
match self.base.get_properties() {
Ok(prop) => Ok(self.properties.get_or_insert(prop)),
Err(e) => Err(e),
}
let mut props = Default::default();
self.base.get_properties(&mut props)?;
Ok(self.properties.get_or_insert(Box::new(props)))
}
fn get_image_properties(&mut self) -> l0::Result<&l0::sys::ze_device_image_properties_t> {
if let Some(ref prop) = self.image_properties {
return Ok(prop);
}
match self.base.get_image_properties() {
Ok(prop) => Ok(self.image_properties.get_or_insert(prop)),
Err(e) => Err(e),
}
let mut props = Default::default();
self.base.get_image_properties(&mut props)?;
Ok(self.image_properties.get_or_insert(Box::new(props)))
}
fn get_memory_properties(&mut self) -> l0::Result<&[l0::sys::ze_device_memory_properties_t]> {
@ -88,10 +87,9 @@ impl Device {
if let Some(ref prop) = self.compute_properties {
return Ok(prop);
}
match self.base.get_compute_properties() {
Ok(prop) => Ok(self.compute_properties.get_or_insert(prop)),
Err(e) => Err(e),
}
let mut props = Default::default();
self.base.get_compute_properties(&mut props)?;
Ok(self.compute_properties.get_or_insert(Box::new(props)))
}
pub fn late_init(&mut self) {
@ -351,7 +349,11 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> {
}
// TODO: add support if Level 0 exposes it
pub fn get_luid(luid: *mut c_char, dev_node_mask: *mut c_uint, _dev_idx: Index) -> Result<(), CUresult> {
pub fn get_luid(
luid: *mut c_char,
dev_node_mask: *mut c_uint,
_dev_idx: Index,
) -> Result<(), CUresult> {
unsafe { ptr::write_bytes(luid, 0u8, 8) };
unsafe { *dev_node_mask = 0 };
Ok(())

View file

@ -144,14 +144,14 @@ pub fn launch_kernel(
func.base
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
func.legacy_args.reset();
let mut cmd_list = stream.command_list()?;
let cmd_list = stream.command_list()?;
cmd_list.append_launch_kernel(
&mut func.base,
&[grid_dim_x, grid_dim_y, grid_dim_z],
None,
&mut [],
)?;
stream.queue.execute(cmd_list)?;
stream.queue.execute_and_synchronize(cmd_list)?;
Ok(())
})?
}

View file

@ -4,7 +4,7 @@ use std::{ffi::c_void, mem};
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
let ptr = GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?)
})??;
unsafe { *dptr = ptr };
Ok(())
@ -12,9 +12,9 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult>
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let mut cmd_list = stream.command_list()?;
unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
stream.queue.execute(cmd_list)?;
let cmd_list = stream.command_list()?;
unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut [])? };
stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
@ -22,29 +22,29 @@ pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
Ok::<_, CUresult>(dev.l0_context.mem_free(ptr)?)
})
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
}
pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let mut cmd_list = stream.command_list()?;
let cmd_list = stream.command_list()?;
unsafe {
cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::<u32>() * n, None, &mut [])
}?;
stream.queue.execute(cmd_list)?;
stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let mut cmd_list = stream.command_list()?;
let cmd_list = stream.command_list()?;
unsafe {
cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::<u8>() * n, None, &mut [])
}?;
stream.queue.execute(cmd_list)?;
stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}

View file

@ -41,7 +41,7 @@ pub struct SpirvModule {
}
pub struct CompiledModule {
pub base: l0::Module,
pub base: l0::Module<'static>,
pub kernels: HashMap<CString, Box<Function>>,
}
@ -78,7 +78,11 @@ impl SpirvModule {
})
}
pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> {
pub fn compile<'a>(
&self,
ctx: &'a l0::Context,
dev: l0::Device,
) -> Result<l0::Module<'a>, CUresult> {
let byte_il = unsafe {
slice::from_raw_parts(
self.binaries.as_ptr() as *const u8,
@ -86,13 +90,11 @@ impl SpirvModule {
)
};
let l0_module = match self.should_link_ptx_impl {
None => {
l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str()))
}
None => l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())),
Some(ptx_impl) => {
l0::Module::build_link_spirv(
ctx,
&dev,
dev,
&[ptx_impl, byte_il],
Some(self.build_options.as_c_str()),
)
@ -119,7 +121,7 @@ pub fn get_function(
hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => {
let new_module = CompiledModule {
base: module.spirv.compile(&mut device.l0_context, &device.base)?,
base: module.spirv.compile(&mut device.l0_context, device.base)?,
kernels: HashMap::new(),
};
entry.insert(new_module)
@ -135,7 +137,7 @@ pub fn get_function(
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
})
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
let mut kernel =
let kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
kernel.set_indirect_access(
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
@ -165,7 +167,7 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device };
let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?;
let l0_module = spirv_data.compile(&device.l0_context, device.base)?;
let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule {
base: l0_module,

View file

@ -33,11 +33,11 @@ impl HasLivenessCookie for StreamData {
pub struct StreamData {
pub context: *mut ContextData,
pub queue: l0::CommandQueue,
pub queue: l0::CommandQueue<'static>,
}
impl StreamData {
pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> {
pub fn new_unitialized(ctx: &'static l0::Context, dev: l0::Device) -> Result<Self, CUresult> {
Ok(StreamData {
context: ptr::null_mut(),
queue: l0::CommandQueue::new(ctx, dev)?,
@ -45,7 +45,7 @@ impl StreamData {
}
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
let l0_dev = &unsafe { &*ctx.device }.base;
let l0_dev = unsafe { &*ctx.device }.base;
Ok(StreamData {
context: ctx as *mut _,
queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
@ -55,7 +55,7 @@ impl StreamData {
pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> {
let ctx = unsafe { &mut *self.context };
let dev = unsafe { &mut *ctx.device };
l0::CommandList::new(&mut dev.l0_context, &dev.base)
l0::CommandList::new(&mut dev.l0_context, dev.base)
}
}

View file

@ -127,7 +127,8 @@ pub(crate) fn system_get_driver_version(
len: 0,
};
for d in drivers {
let props = d.get_properties()?;
let mut props = Default::default();
d.get_properties(&mut props)?;
let driver_version = props.driverVersion;
write!(&mut output_write, "{}", driver_version)
.map_err(|_| nvmlReturn_t::NVML_ERROR_UNKNOWN)?;