mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Refactor L0 bindings
This commit is contained in:
parent
58a7fe53c6
commit
e40785aa74
9 changed files with 577 additions and 419 deletions
File diff suppressed because it is too large
Load diff
|
@ -201,8 +201,8 @@ impl<T: Debug> error::Error for DisplayError<T> {}
|
|||
|
||||
fn test_ptx_assert<
|
||||
'a,
|
||||
Input: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq,
|
||||
Output: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq,
|
||||
Input: From<u8> + Debug + Copy + PartialEq,
|
||||
Output: From<u8> + Debug + Copy + PartialEq,
|
||||
>(
|
||||
name: &str,
|
||||
ptx_text: &'a str,
|
||||
|
@ -220,10 +220,7 @@ fn test_ptx_assert<
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn run_spirv<
|
||||
Input: From<u8> + ze::SafeRepr + Copy + Debug,
|
||||
Output: From<u8> + ze::SafeRepr + Copy + Debug,
|
||||
>(
|
||||
fn run_spirv<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug>(
|
||||
name: &CStr,
|
||||
module: translate::Module,
|
||||
input: &[Input],
|
||||
|
@ -242,25 +239,25 @@ fn run_spirv<
|
|||
.get(name.to_str().unwrap())
|
||||
.map(|info| info.uses_shared_mem)
|
||||
.unwrap_or(false);
|
||||
let mut result = vec![0u8.into(); output.len()];
|
||||
let result = vec![0u8.into(); output.len()];
|
||||
{
|
||||
let mut drivers = ze::Driver::get()?;
|
||||
let drv = drivers.drain(0..1).next().unwrap();
|
||||
let mut ctx = ze::Context::new(&drv)?;
|
||||
let mut devices = drv.devices()?;
|
||||
let dev = devices.drain(0..1).next().unwrap();
|
||||
let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
|
||||
let ctx = ze::Context::new(drv, None)?;
|
||||
let queue = ze::CommandQueue::new(&ctx, dev)?;
|
||||
let (module, maybe_log) = match module.should_link_ptx_impl {
|
||||
Some(ptx_impl) => ze::Module::build_link_spirv(
|
||||
&mut ctx,
|
||||
&dev,
|
||||
&ctx,
|
||||
dev,
|
||||
&[ptx_impl, byte_il],
|
||||
Some(module.build_options.as_c_str()),
|
||||
),
|
||||
None => {
|
||||
let (module, log) = ze::Module::build_spirv_logged(
|
||||
&mut ctx,
|
||||
&dev,
|
||||
&ctx,
|
||||
dev,
|
||||
byte_il,
|
||||
Some(module.build_options.as_c_str()),
|
||||
);
|
||||
|
@ -271,38 +268,38 @@ fn run_spirv<
|
|||
Ok(m) => m,
|
||||
Err(err) => {
|
||||
let raw_err_string = maybe_log
|
||||
.map(|log| log.get_cstring())
|
||||
.map(|log| log.to_cstring())
|
||||
.transpose()?
|
||||
.unwrap_or(CString::default());
|
||||
let err_string = raw_err_string.to_string_lossy();
|
||||
panic!("{:?}\n{}", err, err_string);
|
||||
}
|
||||
};
|
||||
let mut kernel = ze::Kernel::new_resident(&module, name)?;
|
||||
let kernel = ze::Kernel::new_resident(&module, name)?;
|
||||
kernel.set_indirect_access(
|
||||
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
|
||||
)?;
|
||||
let mut inp_b = ze::DeviceBuffer::<Input>::new(&mut ctx, &dev, cmp::max(input.len(), 1))?;
|
||||
let mut out_b = ze::DeviceBuffer::<Output>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?;
|
||||
let inp_b_ptr_mut: ze::BufferPtrMut<Input> = (&mut inp_b).into();
|
||||
let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
|
||||
let inp_b = ze::DeviceBuffer::<Input>::new(&ctx, dev, cmp::max(input.len(), 1))?;
|
||||
let out_b = ze::DeviceBuffer::<Output>::new(&ctx, dev, cmp::max(output.len(), 1))?;
|
||||
let event_pool = ze::EventPool::new(&ctx, 3, Some(&[dev]))?;
|
||||
let ev0 = ze::Event::new(&event_pool, 0)?;
|
||||
let ev1 = ze::Event::new(&event_pool, 1)?;
|
||||
let mut ev2 = ze::Event::new(&event_pool, 2)?;
|
||||
let mut cmd_list = ze::CommandList::new(&mut ctx, &dev)?;
|
||||
let out_b_ptr_mut: ze::BufferPtrMut<Output> = (&mut out_b).into();
|
||||
let mut init_evs = [ev0, ev1];
|
||||
cmd_list.append_memory_copy(inp_b_ptr_mut, input, Some(&mut init_evs[0]), &mut [])?;
|
||||
cmd_list.append_memory_fill(out_b_ptr_mut, 0, Some(&mut init_evs[1]), &mut [])?;
|
||||
kernel.set_group_size(1, 1, 1)?;
|
||||
kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
|
||||
kernel.set_arg_buffer(1, out_b_ptr_mut)?;
|
||||
if use_shared_mem {
|
||||
unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
|
||||
let ev2 = ze::Event::new(&event_pool, 2)?;
|
||||
{
|
||||
let cmd_list = ze::CommandList::new(&ctx, dev)?;
|
||||
let init_evs = [ev0, ev1];
|
||||
cmd_list.append_memory_copy(&inp_b, input, Some(&init_evs[0]), &[])?;
|
||||
cmd_list.append_memory_fill(&out_b, 0, Some(&init_evs[1]), &[])?;
|
||||
kernel.set_group_size(1, 1, 1)?;
|
||||
kernel.set_arg_buffer(0, &inp_b)?;
|
||||
kernel.set_arg_buffer(1, &out_b)?;
|
||||
if use_shared_mem {
|
||||
unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
|
||||
}
|
||||
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &init_evs)?;
|
||||
cmd_list.append_memory_copy(&*result, &out_b, None, &[ev2])?;
|
||||
queue.execute_and_synchronize(cmd_list)?;
|
||||
}
|
||||
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?;
|
||||
cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?;
|
||||
queue.execute(cmd_list)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
|
|
@ -98,8 +98,8 @@ pub struct ContextData {
|
|||
|
||||
impl ContextData {
|
||||
pub fn new(
|
||||
l0_ctx: &mut l0::Context,
|
||||
l0_dev: &l0::Device,
|
||||
l0_ctx: &'static l0::Context,
|
||||
l0_dev: l0::Device,
|
||||
flags: c_uint,
|
||||
is_primary: bool,
|
||||
dev: *mut device::Device,
|
||||
|
@ -137,7 +137,7 @@ pub fn create_v2(
|
|||
let dev_ptr = dev as *mut _;
|
||||
let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
|
||||
&mut dev.l0_context,
|
||||
&dev.base,
|
||||
dev.base,
|
||||
flags,
|
||||
false,
|
||||
dev_ptr as *mut _,
|
||||
|
|
|
@ -18,7 +18,7 @@ pub struct Index(pub c_int);
|
|||
pub struct Device {
|
||||
pub index: Index,
|
||||
pub base: l0::Device,
|
||||
pub default_queue: l0::CommandQueue,
|
||||
pub default_queue: l0::CommandQueue<'static>,
|
||||
pub l0_context: l0::Context,
|
||||
pub primary_context: context::Context,
|
||||
properties: Option<Box<l0::sys::ze_device_properties_t>>,
|
||||
|
@ -31,12 +31,13 @@ unsafe impl Send for Device {}
|
|||
|
||||
impl Device {
|
||||
// Unsafe because it does not fully initalize primary_context
|
||||
// and we transmute lifetimes left and right
|
||||
unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
|
||||
let mut ctx = l0::Context::new(drv)?;
|
||||
let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?;
|
||||
let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?;
|
||||
let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?;
|
||||
let primary_context = context::Context::new(context::ContextData::new(
|
||||
&mut ctx,
|
||||
&l0_dev,
|
||||
mem::transmute(&ctx),
|
||||
l0_dev,
|
||||
0,
|
||||
true,
|
||||
ptr::null_mut(),
|
||||
|
@ -58,20 +59,18 @@ impl Device {
|
|||
if let Some(ref prop) = self.properties {
|
||||
return Ok(prop);
|
||||
}
|
||||
match self.base.get_properties() {
|
||||
Ok(prop) => Ok(self.properties.get_or_insert(prop)),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
let mut props = Default::default();
|
||||
self.base.get_properties(&mut props)?;
|
||||
Ok(self.properties.get_or_insert(Box::new(props)))
|
||||
}
|
||||
|
||||
fn get_image_properties(&mut self) -> l0::Result<&l0::sys::ze_device_image_properties_t> {
|
||||
if let Some(ref prop) = self.image_properties {
|
||||
return Ok(prop);
|
||||
}
|
||||
match self.base.get_image_properties() {
|
||||
Ok(prop) => Ok(self.image_properties.get_or_insert(prop)),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
let mut props = Default::default();
|
||||
self.base.get_image_properties(&mut props)?;
|
||||
Ok(self.image_properties.get_or_insert(Box::new(props)))
|
||||
}
|
||||
|
||||
fn get_memory_properties(&mut self) -> l0::Result<&[l0::sys::ze_device_memory_properties_t]> {
|
||||
|
@ -88,10 +87,9 @@ impl Device {
|
|||
if let Some(ref prop) = self.compute_properties {
|
||||
return Ok(prop);
|
||||
}
|
||||
match self.base.get_compute_properties() {
|
||||
Ok(prop) => Ok(self.compute_properties.get_or_insert(prop)),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
let mut props = Default::default();
|
||||
self.base.get_compute_properties(&mut props)?;
|
||||
Ok(self.compute_properties.get_or_insert(Box::new(props)))
|
||||
}
|
||||
|
||||
pub fn late_init(&mut self) {
|
||||
|
@ -351,7 +349,11 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> {
|
|||
}
|
||||
|
||||
// TODO: add support if Level 0 exposes it
|
||||
pub fn get_luid(luid: *mut c_char, dev_node_mask: *mut c_uint, _dev_idx: Index) -> Result<(), CUresult> {
|
||||
pub fn get_luid(
|
||||
luid: *mut c_char,
|
||||
dev_node_mask: *mut c_uint,
|
||||
_dev_idx: Index,
|
||||
) -> Result<(), CUresult> {
|
||||
unsafe { ptr::write_bytes(luid, 0u8, 8) };
|
||||
unsafe { *dev_node_mask = 0 };
|
||||
Ok(())
|
||||
|
|
|
@ -144,14 +144,14 @@ pub fn launch_kernel(
|
|||
func.base
|
||||
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
|
||||
func.legacy_args.reset();
|
||||
let mut cmd_list = stream.command_list()?;
|
||||
let cmd_list = stream.command_list()?;
|
||||
cmd_list.append_launch_kernel(
|
||||
&mut func.base,
|
||||
&[grid_dim_x, grid_dim_y, grid_dim_z],
|
||||
None,
|
||||
&mut [],
|
||||
)?;
|
||||
stream.queue.execute(cmd_list)?;
|
||||
stream.queue.execute_and_synchronize(cmd_list)?;
|
||||
Ok(())
|
||||
})?
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use std::{ffi::c_void, mem};
|
|||
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
|
||||
let ptr = GlobalState::lock_current_context(|ctx| {
|
||||
let dev = unsafe { &mut *ctx.device };
|
||||
Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
|
||||
Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?)
|
||||
})??;
|
||||
unsafe { *dptr = ptr };
|
||||
Ok(())
|
||||
|
@ -12,9 +12,9 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult>
|
|||
|
||||
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
|
||||
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
|
||||
let mut cmd_list = stream.command_list()?;
|
||||
unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
|
||||
stream.queue.execute(cmd_list)?;
|
||||
let cmd_list = stream.command_list()?;
|
||||
unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut [])? };
|
||||
stream.queue.execute_and_synchronize(cmd_list)?;
|
||||
Ok::<_, CUresult>(())
|
||||
})?
|
||||
}
|
||||
|
@ -22,29 +22,29 @@ pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<
|
|||
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
|
||||
GlobalState::lock_current_context(|ctx| {
|
||||
let dev = unsafe { &mut *ctx.device };
|
||||
Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
|
||||
Ok::<_, CUresult>(dev.l0_context.mem_free(ptr)?)
|
||||
})
|
||||
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
|
||||
}
|
||||
|
||||
pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> {
|
||||
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
|
||||
let mut cmd_list = stream.command_list()?;
|
||||
let cmd_list = stream.command_list()?;
|
||||
unsafe {
|
||||
cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::<u32>() * n, None, &mut [])
|
||||
}?;
|
||||
stream.queue.execute(cmd_list)?;
|
||||
stream.queue.execute_and_synchronize(cmd_list)?;
|
||||
Ok::<_, CUresult>(())
|
||||
})?
|
||||
}
|
||||
|
||||
pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> {
|
||||
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
|
||||
let mut cmd_list = stream.command_list()?;
|
||||
let cmd_list = stream.command_list()?;
|
||||
unsafe {
|
||||
cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::<u8>() * n, None, &mut [])
|
||||
}?;
|
||||
stream.queue.execute(cmd_list)?;
|
||||
stream.queue.execute_and_synchronize(cmd_list)?;
|
||||
Ok::<_, CUresult>(())
|
||||
})?
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ pub struct SpirvModule {
|
|||
}
|
||||
|
||||
pub struct CompiledModule {
|
||||
pub base: l0::Module,
|
||||
pub base: l0::Module<'static>,
|
||||
pub kernels: HashMap<CString, Box<Function>>,
|
||||
}
|
||||
|
||||
|
@ -78,7 +78,11 @@ impl SpirvModule {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> {
|
||||
pub fn compile<'a>(
|
||||
&self,
|
||||
ctx: &'a l0::Context,
|
||||
dev: l0::Device,
|
||||
) -> Result<l0::Module<'a>, CUresult> {
|
||||
let byte_il = unsafe {
|
||||
slice::from_raw_parts(
|
||||
self.binaries.as_ptr() as *const u8,
|
||||
|
@ -86,13 +90,11 @@ impl SpirvModule {
|
|||
)
|
||||
};
|
||||
let l0_module = match self.should_link_ptx_impl {
|
||||
None => {
|
||||
l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str()))
|
||||
}
|
||||
None => l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())),
|
||||
Some(ptx_impl) => {
|
||||
l0::Module::build_link_spirv(
|
||||
ctx,
|
||||
&dev,
|
||||
dev,
|
||||
&[ptx_impl, byte_il],
|
||||
Some(self.build_options.as_c_str()),
|
||||
)
|
||||
|
@ -119,7 +121,7 @@ pub fn get_function(
|
|||
hash_map::Entry::Occupied(entry) => entry.into_mut(),
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let new_module = CompiledModule {
|
||||
base: module.spirv.compile(&mut device.l0_context, &device.base)?,
|
||||
base: module.spirv.compile(&mut device.l0_context, device.base)?,
|
||||
kernels: HashMap::new(),
|
||||
};
|
||||
entry.insert(new_module)
|
||||
|
@ -135,7 +137,7 @@ pub fn get_function(
|
|||
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
|
||||
})
|
||||
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
|
||||
let mut kernel =
|
||||
let kernel =
|
||||
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
|
||||
kernel.set_indirect_access(
|
||||
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
|
||||
|
@ -165,7 +167,7 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
|
|||
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
|
||||
let module = GlobalState::lock_current_context(|ctx| {
|
||||
let device = unsafe { &mut *ctx.device };
|
||||
let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?;
|
||||
let l0_module = spirv_data.compile(&device.l0_context, device.base)?;
|
||||
let mut device_binaries = HashMap::new();
|
||||
let compiled_module = CompiledModule {
|
||||
base: l0_module,
|
||||
|
|
|
@ -33,11 +33,11 @@ impl HasLivenessCookie for StreamData {
|
|||
|
||||
pub struct StreamData {
|
||||
pub context: *mut ContextData,
|
||||
pub queue: l0::CommandQueue,
|
||||
pub queue: l0::CommandQueue<'static>,
|
||||
}
|
||||
|
||||
impl StreamData {
|
||||
pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> {
|
||||
pub fn new_unitialized(ctx: &'static l0::Context, dev: l0::Device) -> Result<Self, CUresult> {
|
||||
Ok(StreamData {
|
||||
context: ptr::null_mut(),
|
||||
queue: l0::CommandQueue::new(ctx, dev)?,
|
||||
|
@ -45,7 +45,7 @@ impl StreamData {
|
|||
}
|
||||
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
|
||||
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
|
||||
let l0_dev = &unsafe { &*ctx.device }.base;
|
||||
let l0_dev = unsafe { &*ctx.device }.base;
|
||||
Ok(StreamData {
|
||||
context: ctx as *mut _,
|
||||
queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
|
||||
|
@ -55,7 +55,7 @@ impl StreamData {
|
|||
pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> {
|
||||
let ctx = unsafe { &mut *self.context };
|
||||
let dev = unsafe { &mut *ctx.device };
|
||||
l0::CommandList::new(&mut dev.l0_context, &dev.base)
|
||||
l0::CommandList::new(&mut dev.l0_context, dev.base)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -127,7 +127,8 @@ pub(crate) fn system_get_driver_version(
|
|||
len: 0,
|
||||
};
|
||||
for d in drivers {
|
||||
let props = d.get_properties()?;
|
||||
let mut props = Default::default();
|
||||
d.get_properties(&mut props)?;
|
||||
let driver_version = props.driverVersion;
|
||||
write!(&mut output_write, "{}", driver_version)
|
||||
.map_err(|_| nvmlReturn_t::NVML_ERROR_UNKNOWN)?;
|
||||
|
|
Loading…
Add table
Reference in a new issue