From c26ab5daedc9a855b7407c0a449b7a40922ae243 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 25 Feb 2020 23:08:11 +0100 Subject: [PATCH] Add malloc and context setter/getter --- level_zero-sys/README | 2 +- level_zero-sys/src/zex_api.rs | 55 +++++++++++++++++------------------ notcuda/src/cu.rs | 12 +++++++- notcuda/src/lib.rs | 51 ++++++++++++++++++++++++++++---- notcuda/src/ze.rs | 10 +++++++ 5 files changed, 94 insertions(+), 36 deletions(-) diff --git a/level_zero-sys/README b/level_zero-sys/README index 70fbfb5..2534908 100644 --- a/level_zero-sys/README +++ b/level_zero-sys/README @@ -1 +1 @@ -bindgen --default-enum-style=rust --whitelist-function ze.* /usr/include/level_zero/zex_api.h -o zex_api.rs -- -x c++ && sed -i 's/pub enum _ze_result_t/#[must_use]\npub enum _ze_result_t/g' zex_api.rs \ No newline at end of file +bindgen --size_t-is-usize --default-enum-style=rust --whitelist-function ze.* /usr/include/level_zero/zex_api.h -o zex_api.rs -- -x c++ && sed -i 's/pub enum _ze_result_t/#[must_use]\npub enum _ze_result_t/g' zex_api.rs \ No newline at end of file diff --git a/level_zero-sys/src/zex_api.rs b/level_zero-sys/src/zex_api.rs index c016272..bcdf079 100644 --- a/level_zero-sys/src/zex_api.rs +++ b/level_zero-sys/src/zex_api.rs @@ -3,7 +3,6 @@ pub type __uint8_t = ::std::os::raw::c_uchar; pub type __uint32_t = ::std::os::raw::c_uint; pub type __uint64_t = ::std::os::raw::c_ulong; -pub type size_t = ::std::os::raw::c_ulong; #[doc = ""] #[doc = " @brief compiler-independent type"] pub type ze_bool_t = u8; @@ -1990,14 +1989,14 @@ pub struct _ze_device_cache_properties_t { #[doc = "< section vs Generic Cache)"] pub intermediateCacheControlSupported: ze_bool_t, #[doc = "< [out] Per-cache Intermediate Cache (L1/L2) size, in bytes"] - pub intermediateCacheSize: size_t, + pub intermediateCacheSize: usize, #[doc = "< [out] Cacheline size in bytes for intermediate cacheline (L1/L2)."] pub intermediateCachelineSize: u32, #[doc = "< [out] Support User control on Last Level Cache (i.e. Resize SLM"] #[doc = "< section vs Generic Cache)."] pub lastLevelCacheSizeControlSupported: ze_bool_t, #[doc = "< [out] Per-cache Last Level Cache (L3) size, in bytes"] - pub lastLevelCacheSize: size_t, + pub lastLevelCacheSize: usize, #[doc = "< [out] Cacheline size in bytes for last-level cacheline (L3)."] pub lastLevelCachelineSize: u32, } @@ -3062,7 +3061,7 @@ extern "C" { pub fn zeCommandListAppendMemoryRangesBarrier( hCommandList: ze_command_list_handle_t, numRanges: u32, - pRangeSizes: *const size_t, + pRangeSizes: *const usize, pRanges: *mut *const ::std::os::raw::c_void, hSignalEvent: ze_event_handle_t, numWaitEvents: u32, @@ -3123,7 +3122,7 @@ extern "C" { hCommandList: ze_command_list_handle_t, dstptr: *mut ::std::os::raw::c_void, srcptr: *const ::std::os::raw::c_void, - size: size_t, + size: usize, hEvent: ze_event_handle_t, ) -> ze_result_t; } @@ -3160,8 +3159,8 @@ extern "C" { hCommandList: ze_command_list_handle_t, ptr: *mut ::std::os::raw::c_void, pattern: *const ::std::os::raw::c_void, - pattern_size: size_t, - size: size_t, + pattern_size: usize, + size: usize, hEvent: ze_event_handle_t, ) -> ze_result_t; } @@ -3554,7 +3553,7 @@ extern "C" { pub fn zeCommandListAppendMemoryPrefetch( hCommandList: ze_command_list_handle_t, ptr: *const ::std::os::raw::c_void, - size: size_t, + size: usize, ) -> ze_result_t; } #[repr(u32)] @@ -3623,7 +3622,7 @@ extern "C" { hCommandList: ze_command_list_handle_t, hDevice: ze_device_handle_t, ptr: *const ::std::os::raw::c_void, - size: size_t, + size: usize, advice: ze_memory_advice_t, ) -> ze_result_t; } @@ -5130,8 +5129,8 @@ extern "C" { hDriver: ze_driver_handle_t, device_desc: *const ze_device_mem_alloc_desc_t, host_desc: *const ze_host_mem_alloc_desc_t, - size: size_t, - alignment: size_t, + size: usize, + alignment: usize, hDevice: ze_device_handle_t, pptr: *mut *mut ::std::os::raw::c_void, ) -> ze_result_t; @@ -5171,8 +5170,8 @@ extern "C" { pub fn zeDriverAllocDeviceMem( hDriver: ze_driver_handle_t, device_desc: *const ze_device_mem_alloc_desc_t, - size: size_t, - alignment: size_t, + size: usize, + alignment: usize, hDevice: ze_device_handle_t, pptr: *mut *mut ::std::os::raw::c_void, ) -> ze_result_t; @@ -5213,8 +5212,8 @@ extern "C" { pub fn zeDriverAllocHostMem( hDriver: ze_driver_handle_t, host_desc: *const ze_host_mem_alloc_desc_t, - size: size_t, - alignment: size_t, + size: usize, + alignment: usize, pptr: *mut *mut ::std::os::raw::c_void, ) -> ze_result_t; } @@ -5389,7 +5388,7 @@ extern "C" { hDriver: ze_driver_handle_t, ptr: *const ::std::os::raw::c_void, pBase: *mut *mut ::std::os::raw::c_void, - pSize: *mut size_t, + pSize: *mut usize, ) -> ze_result_t; } extern "C" { @@ -5583,7 +5582,7 @@ pub struct _ze_module_desc_t { #[doc = "< [in] Module format passed in with pInputModule"] pub format: ze_module_format_t, #[doc = "< [in] size of input IL or ISA from pInputModule."] - pub inputSize: size_t, + pub inputSize: usize, #[doc = "< [in] pointer to IL or ISA"] pub pInputModule: *const u8, #[doc = "< [in] string containing compiler flags. See programming guide for build"] @@ -5790,7 +5789,7 @@ extern "C" { #[doc = " + `nullptr == pSize`"] pub fn zeModuleBuildLogGetString( hModuleBuildLog: ze_module_build_log_handle_t, - pSize: *mut size_t, + pSize: *mut usize, pBuildLog: *mut ::std::os::raw::c_char, ) -> ze_result_t; } @@ -5820,7 +5819,7 @@ extern "C" { #[doc = " + `nullptr == pSize`"] pub fn zeModuleGetNativeBinary( hModule: ze_module_handle_t, - pSize: *mut size_t, + pSize: *mut usize, pModuleNativeBinary: *mut u8, ) -> ze_result_t; } @@ -6129,7 +6128,7 @@ extern "C" { pub fn zeKernelSetArgumentValue( hKernel: ze_kernel_handle_t, argIndex: u32, - argSize: size_t, + argSize: usize, pArgValue: *const ::std::os::raw::c_void, ) -> ze_result_t; } @@ -6606,7 +6605,7 @@ extern "C" { pub fn zeDeviceMakeMemoryResident( hDevice: ze_device_handle_t, ptr: *mut ::std::os::raw::c_void, - size: size_t, + size: usize, ) -> ze_result_t; } extern "C" { @@ -6633,7 +6632,7 @@ extern "C" { pub fn zeDeviceEvictMemory( hDevice: ze_device_handle_t, ptr: *mut ::std::os::raw::c_void, - size: size_t, + size: usize, ) -> ze_result_t; } extern "C" { @@ -7090,7 +7089,7 @@ extern "C" { hCommandGraph: zex_command_graph_handle_t, phCommandNode: *mut zex_command_graph_handle_t, phParentNodes: *mut zex_command_graph_handle_t, - noParentNodes: size_t, + noParentNodes: usize, nodeType: COMMANDGRAPH_TYPE, ) -> ze_result_t; } @@ -7099,7 +7098,7 @@ extern "C" { hCommandGraph: zex_command_graph_handle_t, phCommandNode: *mut zex_command_graph_handle_t, phParentNodes: *mut zex_command_graph_handle_t, - noParentNodes: size_t, + noParentNodes: usize, regDestination: ALU_REG, regSourceAddress: ALU_REG, ) -> ze_result_t; @@ -7109,7 +7108,7 @@ extern "C" { hCommandGraph: zex_command_graph_handle_t, phCommandNode: *mut zex_command_graph_handle_t, phParentNodes: *mut zex_command_graph_handle_t, - noParentNodes: size_t, + noParentNodes: usize, regDestinationAddress: ALU_REG, regSource: ALU_REG, ) -> ze_result_t; @@ -7118,7 +7117,7 @@ extern "C" { pub fn zexCommandGraphNodeAddChildren( hCommandNode: zex_command_graph_handle_t, phChildrenNodes: *mut zex_command_graph_handle_t, - noChildrenNodes: size_t, + noChildrenNodes: usize, ) -> ze_result_t; } extern "C" { @@ -7152,7 +7151,7 @@ extern "C" { #[doc = " - ::ZE_RESULT_ERROR_UNKNOWN"] pub fn zexCommandListReserveSpace( hCommandList: zex_command_list_handle_t, - size: size_t, + size: usize, ptr: *mut *mut ::std::os::raw::c_void, ) -> ze_result_t; } @@ -7188,7 +7187,7 @@ extern "C" { pub fn zexCommandListAppendMIMath( hCommandList: zex_command_list_handle_t, opArray: *mut zex_alu_operation_t, - noOperations: size_t, + noOperations: usize, ) -> ze_result_t; } extern "C" { diff --git a/notcuda/src/cu.rs b/notcuda/src/cu.rs index 75eca4f..fa6f1a5 100644 --- a/notcuda/src/cu.rs +++ b/notcuda/src/cu.rs @@ -1,6 +1,7 @@ use num_enum::TryFromPrimitive; use std::convert::TryFrom; use std::os::raw::c_int; +use std::ptr; #[repr(C)] #[allow(non_camel_case_types)] @@ -156,4 +157,13 @@ pub struct Uuid { pub struct Device(pub c_int); #[repr(transparent)] -pub struct DevicePtr(c_int); \ No newline at end of file +pub struct DevicePtr(usize); + +#[repr(transparent)] +#[derive(Clone, PartialEq)] +pub struct Context(*mut ()); +impl Context { + pub fn null() -> Context { + Context(ptr::null_mut()) + } +} \ No newline at end of file diff --git a/notcuda/src/lib.rs b/notcuda/src/lib.rs index 81f2b92..fad9f48 100644 --- a/notcuda/src/lib.rs +++ b/notcuda/src/lib.rs @@ -3,9 +3,12 @@ extern crate level_zero_sys as l0; extern crate lazy_static; use std::convert::TryFrom; -use std::sync::Mutex; -use std::ptr; use std::os::raw::{c_char, c_int, c_uint}; +use std::ptr; +use std::cell::RefCell; +use std::sync::Mutex; + +use ze::Versioned; #[macro_use] macro_rules! l0_check_err { @@ -23,11 +26,15 @@ mod cu; mod export_table; mod ze; -lazy_static! { - pub static ref GLOBAL_STATE: Mutex> = Mutex::new(None); +thread_local! { + static CONTEXT_STACK: RefCell> = RefCell::new(Vec::new()); } -pub struct Driver { +lazy_static! { + static ref GLOBAL_STATE: Mutex> = Mutex::new(None); +} + +struct Driver { base: l0::ze_driver_handle_t, devices: Vec:: } @@ -180,8 +187,40 @@ pub extern "C" fn cuDeviceGetUuid(uuid: *mut cu::Uuid, dev_idx: cu::Device) -> c Driver::call_device(dev_idx, |dev| dev.get_uuid(uuid)) } +#[no_mangle] +pub extern "C" fn cuCtxGetCurrent(pctx: *mut cu::Context) -> cu::Result { + let ctx = CONTEXT_STACK.with(|stack| { + match stack.borrow().last() { + Some(ctx) => ctx.clone(), + None => cu::Context::null() + } + }); + unsafe { *pctx = ctx }; + cu::Result::SUCCESS +} + +#[no_mangle] +pub extern "C" fn cuCtxSetCurrent(ctx: cu::Context) -> cu::Result { + CONTEXT_STACK.with(|stack| { + let mut stack = stack.borrow_mut(); + stack.pop(); + if ctx != cu::Context::null() { + stack.push(ctx); + } + }); + cu::Result::SUCCESS +} #[no_mangle] pub extern "C" fn cuMemAlloc_v2(dptr: *mut cu::DevicePtr, bytesize: usize) -> cu::Result { - unimplemented!() + if dptr == ptr::null_mut() || bytesize == 0 { + return cu::Result::ERROR_INVALID_VALUE; + } + Driver::call(|drv| { + let mut descr = l0::ze_device_mem_alloc_desc_t::new(); + descr.flags = l0::ze_device_mem_alloc_flag_t::ZE_DEVICE_MEM_ALLOC_FLAG_DEFAULT; + descr.ordinal = 0; + // TODO: check current context for the device + unsafe { l0::zeDriverAllocDeviceMem(drv.base, &descr, bytesize, 0, drv.devices[0].0, dptr as *mut _) } + }) } \ No newline at end of file diff --git a/notcuda/src/ze.rs b/notcuda/src/ze.rs index e1fc804..748691e 100644 --- a/notcuda/src/ze.rs +++ b/notcuda/src/ze.rs @@ -77,6 +77,16 @@ impl Versioned for ze_device_image_properties_t { } } +impl Versioned for ze_device_mem_alloc_desc_t { + type Version = ze_device_mem_alloc_desc_version_t; + fn current() -> Self::Version { + ze_device_mem_alloc_desc_version_t::ZE_DEVICE_MEM_ALLOC_DESC_VERSION_CURRENT + } + fn version(&mut self) -> &mut Self::Version { + &mut self.version + } +} + #[derive(Clone, Copy)] #[repr(transparent)] // required so a Vec can be safely transmutted to Vec pub struct Device(pub ze_device_handle_t);