mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Add missing entry to the export table and
fix problems with device handling
This commit is contained in:
parent
21d091a47d
commit
796e030c4e
3 changed files with 92 additions and 37 deletions
|
@ -1,3 +1,5 @@
|
|||
use std::os::raw::c_int;
|
||||
|
||||
#[repr(C)]
|
||||
#[allow(non_camel_case_types)]
|
||||
pub enum Result {
|
||||
|
@ -96,4 +98,7 @@ impl Result {
|
|||
/// 16-byte identifier that is compared byte-for-byte.
///
/// It is plain data, so it is freely copyable and printable in addition to
/// being comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Uuid {
    /// Raw identifier bytes.
    pub x: [std::os::raw::c_uchar; 16],
}
|
||||
|
||||
/// Device identifier exchanged with callers: a thin wrapper over a C `int`.
///
/// `#[repr(transparent)]` gives it the same layout and ABI as the wrapped
/// `c_int`, so it can be passed by value across the C boundary.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Device(pub c_int);
|
|
@ -2,57 +2,82 @@ use super::cu;
|
|||
|
||||
use std::mem;
|
||||
use std::ptr;
|
||||
use std::os::raw::{c_int, c_ulong};
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "stdcall" fn cuGetExportTable(
|
||||
pub unsafe extern "C" fn cuGetExportTable(
|
||||
table: *mut *const std::os::raw::c_void,
|
||||
id: *const cu::Uuid,
|
||||
) -> cu::Result {
|
||||
if *id == GUID0 {
|
||||
if *id == CU_ETID_ToolsRuntimeCallbackHooks {
|
||||
*table = TABLE0.as_ptr() as *const _;
|
||||
} else if *id == CU_ETID_CudartInterface {
|
||||
*table = TABLE1.as_ptr() as *const _;
|
||||
}
|
||||
return cu::Result::SUCCESS;
|
||||
}
|
||||
|
||||
const GUID0: cu::Uuid = cu::Uuid {
|
||||
const CU_ETID_ToolsRuntimeCallbackHooks: cu::Uuid = cu::Uuid {
|
||||
x: [
|
||||
0xa0, 0x94, 0x79, 0x8c, 0x2e, 0x74, 0x2e, 0x74, 0x93, 0xf2, 0x08, 0x00, 0x20, 0x0c, 0x0a,
|
||||
0x66,
|
||||
],
|
||||
};
|
||||
/// One slot of an export table: either an untyped pointer or a length.
///
/// Both fields occupy the same storage (`#[repr(C)]` union), so a table can be
/// declared as a homogeneous array whose slots hold either kind of value.
#[repr(C)]
union PtrOrLength {
    /// Slot interpreted as an untyped pointer.
    ptr: *const (),
    /// Slot interpreted as a length.
    length: usize,
}

// The raw-pointer field makes the union `!Sync` by default, which would forbid
// storing it in a `static`.
// NOTE(review): presumably sound because the tables built from this type are
// immutable statics - confirm nothing mutates them.
unsafe impl Sync for PtrOrLength {}
|
||||
const TABLE0_LEN: usize = 7;
|
||||
static TABLE0: [Table0Member; TABLE0_LEN] = [
|
||||
Table0Member {
|
||||
count: mem::size_of::<[Table0Member; TABLE0_LEN]>(),
|
||||
static TABLE0: [PtrOrLength; TABLE0_LEN] = [
|
||||
PtrOrLength {
|
||||
length: mem::size_of::<[PtrOrLength; TABLE0_LEN]>(),
|
||||
},
|
||||
Table0Member { ptr: ptr::null() },
|
||||
Table0Member {
|
||||
PtrOrLength { ptr: ptr::null() },
|
||||
PtrOrLength {
|
||||
ptr: table0_fn1 as *const (),
|
||||
},
|
||||
Table0Member { ptr: ptr::null() },
|
||||
Table0Member { ptr: ptr::null() },
|
||||
Table0Member { ptr: ptr::null() },
|
||||
Table0Member {
|
||||
PtrOrLength { ptr: ptr::null() },
|
||||
PtrOrLength { ptr: ptr::null() },
|
||||
PtrOrLength { ptr: ptr::null() },
|
||||
PtrOrLength {
|
||||
ptr: table0_fn5 as *const (),
|
||||
},
|
||||
];
|
||||
/// Size in bytes of the scratch buffer reported by `table0_fn1`.
const TABLE0_FN1_SPACE_LEN: usize = 512;
/// Size in bytes of the scratch buffer reported by `table0_fn5`.
const TABLE0_FN5_SPACE_LEN: usize = 2;

/// Scratch buffer whose address `table0_fn1` hands out.
static mut TABLE0_FN1_SPACE: [u8; TABLE0_FN1_SPACE_LEN] = [0; TABLE0_FN1_SPACE_LEN];
/// Scratch buffer whose address `table0_fn5` hands out.
static mut TABLE0_FN5_SPACE: [u8; TABLE0_FN5_SPACE_LEN] = [0; TABLE0_FN5_SPACE_LEN];

/// Reports the `TABLE0_FN1_SPACE` buffer: stores its address in `*ptr`, its
/// length in `*size`, and also returns the address.
///
/// # Safety
///
/// `ptr` and `size` must be valid for writes.
unsafe extern "C" fn table0_fn1(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 {
    // Take the raw address directly: calling `.as_mut_ptr()` / `.len()` on a
    // `static mut` would first create a reference to it.
    let base = std::ptr::addr_of_mut!(TABLE0_FN1_SPACE) as *mut u8;
    *ptr = base;
    *size = TABLE0_FN1_SPACE_LEN;
    base
}

/// Reports the `TABLE0_FN5_SPACE` buffer: stores its address in `*ptr`, its
/// length in `*size`, and also returns the address.
///
/// # Safety
///
/// `ptr` and `size` must be valid for writes.
unsafe extern "C" fn table0_fn5(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 {
    // Same raw-address approach as `table0_fn1`; no reference to the
    // `static mut` is ever formed.
    let base = std::ptr::addr_of_mut!(TABLE0_FN5_SPACE) as *mut u8;
    *ptr = base;
    *size = TABLE0_FN5_SPACE_LEN;
    base
}
|
||||
|
||||
/// Identifier that `cuGetExportTable` answers with `TABLE1`.
const CU_ETID_CudartInterface: cu::Uuid = cu::Uuid {
    x: [
        0x6b, 0xd5, 0xfb, 0x6c, 0x5b, 0xf4, 0xe7, 0x4a,
        0x89, 0x87, 0xd9, 0x39, 0x12, 0xfd, 0x9d, 0xf9,
    ],
};
|
||||
|
||||
const TABLE1_LEN: usize = 3;
|
||||
static TABLE1: [PtrOrLength; TABLE1_LEN] = [
|
||||
PtrOrLength {
|
||||
length: mem::size_of::<[PtrOrLength; TABLE1_LEN]>(),
|
||||
},
|
||||
PtrOrLength { ptr: ptr::null() },
|
||||
PtrOrLength {
|
||||
ptr: table1_fn1 as *const (),
|
||||
},
|
||||
];
|
||||
|
||||
/// Callback stored in slot 2 of `TABLE1`.
///
/// Ignores both arguments (the pointer is never dereferenced) and always
/// returns `0`.
unsafe extern "C" fn table1_fn1(_arg0: *mut c_ulong, _arg1: c_int) -> c_int {
    0
}
|
|
@ -4,6 +4,8 @@ extern crate lazy_static;
|
|||
|
||||
use std::sync::Mutex;
|
||||
use std::ptr;
|
||||
use std::cmp;
|
||||
use std::os::raw::{c_char, c_int};
|
||||
|
||||
mod cu;
|
||||
mod export_table;
|
||||
|
@ -24,9 +26,9 @@ lazy_static! {
|
|||
}
|
||||
|
||||
pub struct Driver {
|
||||
base: l0::ze_driver_handle_t
|
||||
base: l0::ze_driver_handle_t,
|
||||
devices: Vec::<l0::ze_device_handle_t>
|
||||
}
|
||||
|
||||
unsafe impl Send for Driver {}
|
||||
unsafe impl Sync for Driver {}
|
||||
|
||||
|
@ -34,8 +36,15 @@ impl Driver {
|
|||
fn new() -> Result<Driver, l0::ze_result_t> {
|
||||
let mut driver_count = 1;
|
||||
let mut handle = ptr::null_mut();
|
||||
l0_check!{ l0::zeDriverGet(&mut driver_count, &mut handle) };
|
||||
Ok(Driver{ base: handle })
|
||||
l0_check!{ l0::zeDriverGet(&mut driver_count, &mut handle) };
|
||||
let mut count = 0;
|
||||
l0_check! { l0::zeDeviceGet(handle, &mut count, ptr::null_mut()) }
|
||||
let mut devices = vec![ptr::null_mut(); count as usize];
|
||||
l0_check! { l0::zeDeviceGet(handle, &mut count, devices.as_mut_ptr()) }
|
||||
if (count as usize) < devices.len() {
|
||||
devices.truncate(count as usize);
|
||||
}
|
||||
Ok(Driver{ base: handle, devices: devices })
|
||||
}
|
||||
|
||||
fn call<F: FnOnce(&mut Driver) -> l0::ze_result_t>(f: F) -> cu::Result {
|
||||
|
@ -53,33 +62,44 @@ impl Driver {
|
|||
}
|
||||
|
||||
fn device_get_count(&self, count: *mut i32) -> l0::ze_result_t {
|
||||
unsafe { l0::zeDeviceGet(self.base, count as *mut _ as *mut _, ptr::null_mut()) }
|
||||
unsafe { *count = self.devices.len() as i32 };
|
||||
l0::ze_result_t::ZE_RESULT_SUCCESS
|
||||
}
|
||||
|
||||
fn device_get(&self, device: *mut l0::ze_device_handle_t, ordinal: ::std::os::raw::c_int) -> l0::ze_result_t {
|
||||
let count = (ordinal as u32) + 1;
|
||||
let mut devices_found = count;
|
||||
let mut handles = vec![ptr::null_mut(); count as usize];
|
||||
let result = unsafe { l0::zeDeviceGet(self.base, &mut devices_found, handles.as_mut_ptr()) };
|
||||
if result != l0::ze_result_t::ZE_RESULT_SUCCESS {
|
||||
return result;
|
||||
}
|
||||
if devices_found < count {
|
||||
fn device_get(&self, device: *mut cu::Device, ordinal: c_int) -> l0::ze_result_t {
|
||||
if ordinal < 0 || (ordinal as usize) >= self.devices.len() {
|
||||
return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT;
|
||||
}
|
||||
unsafe { *device = handles[(count as usize) - 1] };
|
||||
unsafe { *device = cu::Device(ordinal) };
|
||||
l0::ze_result_t::ZE_RESULT_SUCCESS
|
||||
}
|
||||
|
||||
fn device_get_name(&self, name: *mut c_char, len: c_int, cu::Device(dev): cu::Device) -> l0::ze_result_t {
|
||||
if len <= 0 || dev < 0 || (dev as usize) >= self.devices.len() {
|
||||
return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT;
|
||||
}
|
||||
let mut props = Box::new(unsafe { std::mem::zeroed::<l0::ze_device_properties_t>() });
|
||||
props.version = l0::ze_device_properties_version_t::ZE_DEVICE_PROPERTIES_VERSION_CURRENT;
|
||||
let result = unsafe { l0::zeDeviceGetProperties(self.devices[dev as usize], props.as_mut()) };
|
||||
if result != l0::ze_result_t::ZE_RESULT_SUCCESS {
|
||||
return result;
|
||||
}
|
||||
let null_pos = props.name.iter().position(|&c| c == 0).unwrap_or(0);
|
||||
let dst_null_pos = cmp::min((len - 1) as usize, null_pos);
|
||||
unsafe { *(name.add(dst_null_pos)) = 0 };
|
||||
unsafe { std::ptr::copy_nonoverlapping(props.name.as_ptr(), name, dst_null_pos) };
|
||||
l0::ze_result_t::ZE_RESULT_SUCCESS
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "stdcall" fn cuDriverGetVersion(version: &mut std::os::raw::c_int) -> cu::Result {
|
||||
pub extern "C" fn cuDriverGetVersion(version: &mut c_int) -> cu::Result {
|
||||
*version = i32::max_value();
|
||||
return cu::Result::SUCCESS;
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "stdcall" fn cuInit(_: *const std::os::raw::c_uint) -> cu::Result {
|
||||
pub unsafe extern "C" fn cuInit(_: *const c_int) -> cu::Result {
|
||||
let l0_init = l0::zeInit(l0::ze_init_flag_t::ZE_INIT_FLAG_GPU_ONLY);
|
||||
if l0_init != l0::ze_result_t::ZE_RESULT_SUCCESS {
|
||||
return cu::Result::from_l0(l0_init);
|
||||
|
@ -99,11 +119,16 @@ pub unsafe extern "stdcall" fn cuInit(_: *const std::os::raw::c_uint) -> cu::Res
|
|||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "stdcall" fn cuDeviceGetCount(count: *mut std::os::raw::c_int) -> cu::Result {
|
||||
pub extern "C" fn cuDeviceGetCount(count: *mut c_int) -> cu::Result {
|
||||
Driver::call(|driver| driver.device_get_count(count))
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "stdcall" fn cuDeviceGet(device: *mut l0::ze_device_handle_t, ordinal: ::std::os::raw::c_int) -> cu::Result {
|
||||
pub extern "C" fn cuDeviceGet(device: *mut cu::Device, ordinal: c_int) -> cu::Result {
|
||||
Driver::call(|driver| driver.device_get(device, ordinal))
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn cuDeviceGetName(name: *mut c_char, len: c_int, dev: cu::Device) -> cu::Result {
|
||||
Driver::call(|driver| driver.device_get_name(name, len, dev))
|
||||
}
|
Loading…
Add table
Reference in a new issue