diff --git a/notcuda/src/cu.rs b/notcuda/src/cu.rs index df07099..311552c 100644 --- a/notcuda/src/cu.rs +++ b/notcuda/src/cu.rs @@ -1,3 +1,5 @@ +use std::os::raw::c_int; + #[repr(C)] #[allow(non_camel_case_types)] pub enum Result { @@ -96,4 +98,7 @@ impl Result { #[derive(PartialEq, Eq)] pub struct Uuid { pub x: [std::os::raw::c_uchar; 16] -} \ No newline at end of file +} + +#[repr(transparent)] +pub struct Device(pub c_int); \ No newline at end of file diff --git a/notcuda/src/export_table.rs b/notcuda/src/export_table.rs index 0985c6e..0fef013 100644 --- a/notcuda/src/export_table.rs +++ b/notcuda/src/export_table.rs @@ -2,57 +2,82 @@ use super::cu; use std::mem; use std::ptr; +use std::os::raw::{c_int, c_ulong}; #[no_mangle] -pub unsafe extern "stdcall" fn cuGetExportTable( +pub unsafe extern "C" fn cuGetExportTable( table: *mut *const std::os::raw::c_void, id: *const cu::Uuid, ) -> cu::Result { - if *id == GUID0 { + if *id == CU_ETID_ToolsRuntimeCallbackHooks { *table = TABLE0.as_ptr() as *const _; + } else if *id == CU_ETID_CudartInterface { + *table = TABLE1.as_ptr() as *const _; } return cu::Result::SUCCESS; } -const GUID0: cu::Uuid = cu::Uuid { +const CU_ETID_ToolsRuntimeCallbackHooks: cu::Uuid = cu::Uuid { x: [ 0xa0, 0x94, 0x79, 0x8c, 0x2e, 0x74, 0x2e, 0x74, 0x93, 0xf2, 0x08, 0x00, 0x20, 0x0c, 0x0a, 0x66, ], }; #[repr(C)] -union Table0Member { - count: usize, +union PtrOrLength { ptr: *const (), + length: usize, } -unsafe impl Sync for Table0Member {} +unsafe impl Sync for PtrOrLength {} const TABLE0_LEN: usize = 7; -static TABLE0: [Table0Member; TABLE0_LEN] = [ - Table0Member { - count: mem::size_of::<[Table0Member; TABLE0_LEN]>(), +static TABLE0: [PtrOrLength; TABLE0_LEN] = [ + PtrOrLength { + length: mem::size_of::<[PtrOrLength; TABLE0_LEN]>(), }, - Table0Member { ptr: ptr::null() }, - Table0Member { + PtrOrLength { ptr: ptr::null() }, + PtrOrLength { ptr: table0_fn1 as *const (), }, - Table0Member { ptr: ptr::null() }, - Table0Member { ptr: ptr::null() }, - Table0Member { ptr: ptr::null() }, - Table0Member { + PtrOrLength { ptr: ptr::null() }, + PtrOrLength { ptr: ptr::null() }, + PtrOrLength { ptr: ptr::null() }, + PtrOrLength { ptr: table0_fn5 as *const (), }, ]; static mut TABLE0_FN1_SPACE: [u8; 512] = [0; 512]; static mut TABLE0_FN5_SPACE: [u8; 2] = [0; 2]; -unsafe extern "stdcall" fn table0_fn1(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 { +unsafe extern "C" fn table0_fn1(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 { *ptr = TABLE0_FN1_SPACE.as_mut_ptr(); *size = TABLE0_FN1_SPACE.len(); return TABLE0_FN1_SPACE.as_mut_ptr(); } -unsafe extern "stdcall" fn table0_fn5(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 { +unsafe extern "C" fn table0_fn5(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 { *ptr = TABLE0_FN5_SPACE.as_mut_ptr(); *size = TABLE0_FN5_SPACE.len(); return TABLE0_FN5_SPACE.as_mut_ptr(); } + +const CU_ETID_CudartInterface: cu::Uuid = cu::Uuid { + x: [ + 0x6b, 0xd5, 0xfb, 0x6c, 0x5b, 0xf4, 0xe7, 0x4a, 0x89, 0x87, 0xd9, 0x39, 0x12, 0xfd, 0x9d, + 0xf9 + ], +}; + +const TABLE1_LEN: usize = 3; +static TABLE1: [PtrOrLength; TABLE1_LEN] = [ + PtrOrLength { + length: mem::size_of::<[PtrOrLength; TABLE1_LEN]>(), + }, + PtrOrLength { ptr: ptr::null() }, + PtrOrLength { + ptr: table1_fn1 as *const (), + }, +]; + +unsafe extern "C" fn table1_fn1(_: *mut c_ulong, _: c_int) -> c_int { + 0 +} \ No newline at end of file diff --git a/notcuda/src/lib.rs b/notcuda/src/lib.rs index 29c5a70..27b34a6 100644 --- a/notcuda/src/lib.rs +++ b/notcuda/src/lib.rs @@ -4,6 +4,8 @@ extern crate lazy_static; use std::sync::Mutex; use std::ptr; +use std::cmp; +use std::os::raw::{c_char, c_int}; mod cu; mod export_table; @@ -24,9 +26,9 @@ lazy_static! { } pub struct Driver { - base: l0::ze_driver_handle_t + base: l0::ze_driver_handle_t, + devices: Vec:: } - unsafe impl Send for Driver {} unsafe impl Sync for Driver {} @@ -34,8 +36,15 @@ impl Driver { fn new() -> Result { let mut driver_count = 1; let mut handle = ptr::null_mut(); - l0_check!{ l0::zeDriverGet(&mut driver_count, &mut handle) }; - Ok(Driver{ base: handle }) + l0_check!{ l0::zeDriverGet(&mut driver_count, &mut handle) }; + let mut count = 0; + l0_check! { l0::zeDeviceGet(handle, &mut count, ptr::null_mut()) } + let mut devices = vec![ptr::null_mut(); count as usize]; + l0_check! { l0::zeDeviceGet(handle, &mut count, devices.as_mut_ptr()) } + if (count as usize) < devices.len() { + devices.truncate(count as usize); + } + Ok(Driver{ base: handle, devices: devices }) } fn call l0::ze_result_t>(f: F) -> cu::Result { @@ -53,33 +62,44 @@ impl Driver { } fn device_get_count(&self, count: *mut i32) -> l0::ze_result_t { - unsafe { l0::zeDeviceGet(self.base, count as *mut _ as *mut _, ptr::null_mut()) } + unsafe { *count = self.devices.len() as i32 }; + l0::ze_result_t::ZE_RESULT_SUCCESS } - fn device_get(&self, device: *mut l0::ze_device_handle_t, ordinal: ::std::os::raw::c_int) -> l0::ze_result_t { - let count = (ordinal as u32) + 1; - let mut devices_found = count; - let mut handles = vec![ptr::null_mut(); count as usize]; - let result = unsafe { l0::zeDeviceGet(self.base, &mut devices_found, handles.as_mut_ptr()) }; - if result != l0::ze_result_t::ZE_RESULT_SUCCESS { - return result; - } - if devices_found < count { + fn device_get(&self, device: *mut cu::Device, ordinal: c_int) -> l0::ze_result_t { + if ordinal < 0 || (ordinal as usize) >= self.devices.len() { return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT; } - unsafe { *device = handles[(count as usize) - 1] }; + unsafe { *device = cu::Device(ordinal) }; + l0::ze_result_t::ZE_RESULT_SUCCESS + } + + fn device_get_name(&self, name: *mut c_char, len: c_int, cu::Device(dev): cu::Device) -> l0::ze_result_t { + if len <= 0 || dev < 0 || (dev as usize) >= self.devices.len() { + return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT; + } + let mut props = Box::new(unsafe { std::mem::zeroed::() }); + props.version = l0::ze_device_properties_version_t::ZE_DEVICE_PROPERTIES_VERSION_CURRENT; + let result = unsafe { l0::zeDeviceGetProperties(self.devices[dev as usize], props.as_mut()) }; + if result != l0::ze_result_t::ZE_RESULT_SUCCESS { + return result; + } + let null_pos = props.name.iter().position(|&c| c == 0).unwrap_or(0); + let dst_null_pos = cmp::min((len - 1) as usize, null_pos); + unsafe { *(name.add(dst_null_pos)) = 0 }; + unsafe { std::ptr::copy_nonoverlapping(props.name.as_ptr(), name, dst_null_pos) }; l0::ze_result_t::ZE_RESULT_SUCCESS } } #[no_mangle] -pub extern "stdcall" fn cuDriverGetVersion(version: &mut std::os::raw::c_int) -> cu::Result { +pub extern "C" fn cuDriverGetVersion(version: &mut c_int) -> cu::Result { *version = i32::max_value(); return cu::Result::SUCCESS; } #[no_mangle] -pub unsafe extern "stdcall" fn cuInit(_: *const std::os::raw::c_uint) -> cu::Result { +pub unsafe extern "C" fn cuInit(_: *const c_int) -> cu::Result { let l0_init = l0::zeInit(l0::ze_init_flag_t::ZE_INIT_FLAG_GPU_ONLY); if l0_init != l0::ze_result_t::ZE_RESULT_SUCCESS { return cu::Result::from_l0(l0_init); @@ -99,11 +119,16 @@ pub unsafe extern "stdcall" fn cuInit(_: *const std::os::raw::c_uint) -> cu::Res } #[no_mangle] -pub extern "stdcall" fn cuDeviceGetCount(count: *mut std::os::raw::c_int) -> cu::Result { +pub extern "C" fn cuDeviceGetCount(count: *mut c_int) -> cu::Result { Driver::call(|driver| driver.device_get_count(count)) } #[no_mangle] -pub extern "stdcall" fn cuDeviceGet(device: *mut l0::ze_device_handle_t, ordinal: ::std::os::raw::c_int) -> cu::Result { +pub extern "C" fn cuDeviceGet(device: *mut cu::Device, ordinal: c_int) -> cu::Result { Driver::call(|driver| driver.device_get(device, ordinal)) +} + +#[no_mangle] +pub extern "C" fn cuDeviceGetName(name: *mut c_char, len: c_int, dev: cu::Device) -> cu::Result { + Driver::call(|driver| driver.device_get_name(name, len, dev)) } \ No newline at end of file