From 3f41f21acb51f7a1d305630dc2a4e5c5df5e4a83 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 24 Sep 2020 01:54:16 +0200 Subject: [PATCH] Implement more host code, moving execution further --- level_zero/src/ze.rs | 6 ++ notcuda/src/cuda.rs | 2 +- notcuda/src/impl/device.rs | 48 +++++++++++++--- notcuda/src/impl/export_table.rs | 25 ++++----- notcuda/src/impl/mod.rs | 15 +++-- notcuda/src/impl/module.rs | 75 +++++++++++++++++++++---- ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/pred_not.ptx | 28 +++++++++ ptx/src/test/spirv_run/pred_not.spvtxt | 78 ++++++++++++++++++++++++++ 9 files changed, 241 insertions(+), 37 deletions(-) create mode 100644 ptx/src/test/spirv_run/pred_not.ptx create mode 100644 ptx/src/test/spirv_run/pred_not.spvtxt diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index cee736c..16b9130 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -118,6 +118,12 @@ impl Device { Ok(props) } + pub fn get_compute_properties(&self) -> Result> { + let mut props = Box::new(unsafe { mem::zeroed::() }); + check! { sys::zeDeviceGetComputeProperties(self.0, props.as_mut()) }; + Ok(props) + } + pub unsafe fn mem_alloc_device( &mut self, ctx: &mut Context, diff --git a/notcuda/src/cuda.rs b/notcuda/src/cuda.rs index 3267042..122f0da 100644 --- a/notcuda/src/cuda.rs +++ b/notcuda/src/cuda.rs @@ -2501,7 +2501,7 @@ pub extern "C" fn cuModuleGetFunction( hmod: CUmodule, name: *const ::std::os::raw::c_char, ) -> CUresult { - r#impl::unimplemented() + r#impl::module::get_function(hfunc.decuda(), hmod.decuda(), name).encuda() } #[no_mangle] diff --git a/notcuda/src/impl/device.rs b/notcuda/src/impl/device.rs index 8a8f2f8..db39efd 100644 --- a/notcuda/src/impl/device.rs +++ b/notcuda/src/impl/device.rs @@ -1,4 +1,4 @@ -use super::{context, CUresult, Error}; +use super::{context, transmute_lifetime, CUresult, Error}; use crate::cuda; use cuda::{CUdevice_attribute, CUuuid_st}; use std::{ @@ -25,6 +25,7 @@ pub struct Device { properties: Option>, image_properties: Option>, memory_properties: Option>, + compute_properties: Option>, } unsafe impl Send for Device {} @@ -48,6 +49,7 @@ impl Device { properties: None, image_properties: None, memory_properties: None, + compute_properties: None, }) } @@ -80,6 +82,16 @@ impl Device { Err(e) => Err(e), } } + + fn get_compute_properties(&mut self) -> l0::Result<&l0::sys::ze_device_compute_properties_t> { + if let Some(ref prop) = self.compute_properties { + return Ok(prop); + } + match self.base.get_compute_properties() { + Ok(prop) => Ok(self.compute_properties.get_or_insert(prop)), + Err(e) => Err(e), + } + } } pub fn init(driver: &l0::Driver) -> l0::Result<()> { @@ -166,10 +178,6 @@ pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult> Ok(()) } -unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T { - mem::transmute(t) -} - pub fn total_mem_v2(bytes: *mut usize, dev: Index) -> Result<(), CUresult> { if bytes == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); @@ -232,6 +240,34 @@ pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Re .maxImageDims1D, c_int::max_value() as u32, ) as c_int, + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X => { + let props = dev.get_compute_properties().map_err(Error::L0)?; + cmp::max(i32::max_value() as u32, props.maxGroupCountX) as i32 + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y => { + let props = dev.get_compute_properties().map_err(Error::L0)?; + cmp::max(i32::max_value() as u32, props.maxGroupCountY) as i32 + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z => { + let props = dev.get_compute_properties().map_err(Error::L0)?; + cmp::max(i32::max_value() as u32, props.maxGroupCountZ) as i32 + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X => { + let props = dev.get_compute_properties().map_err(Error::L0)?; + cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32 + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y => { + let props = dev.get_compute_properties().map_err(Error::L0)?; + cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32 + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z => { + let props = dev.get_compute_properties().map_err(Error::L0)?; + cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32 + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK => { + let props = dev.get_compute_properties().map_err(Error::L0)?; + cmp::max(i32::max_value() as u32, props.maxTotalGroupSize) as i32 + } _ => { // TODO: support more attributes for CUDA runtime /* @@ -311,8 +347,6 @@ pub fn primary_ctx_retain(pctx: *mut *mut context::Context, dev: Index) -> Resul mod tests { use super::super::test::CudaDriverFns; use super::super::CUresult; - use crate::cuda::CUuuid; - use std::{ffi::c_void, mem, ptr}; cuda_driver_test!(primary_ctx_default_inactive); diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs index 233c496..9a6d72c 100644 --- a/notcuda/src/impl/export_table.rs +++ b/notcuda/src/impl/export_table.rs @@ -66,12 +66,11 @@ static TOOLS_RUNTIME_CALLBACK_HOOKS_VTABLE: [VTableEntry; TOOLS_RUNTIME_CALLBACK ptr: runtime_callback_hooks_fn5 as *const (), }, ]; -static mut TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE: [u8; 512] = [0; 512]; +static mut TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE: [usize; 512] = [0; 512]; -unsafe extern "C" fn runtime_callback_hooks_fn1(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 { +unsafe extern "C" fn runtime_callback_hooks_fn1(ptr: *mut *mut usize, size: *mut usize) { *ptr = TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE.as_mut_ptr(); *size = TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE.len(); - return TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE.as_mut_ptr(); } static mut TOOLS_RUNTIME_CALLBACK_HOOKS_FN5_SPACE: [u8; 2] = [0; 2]; @@ -198,9 +197,14 @@ struct FatbinFileHeader { unsafe extern "C" fn get_module_from_cubin( result: *mut CUmodule, fatbinc_wrapper: *const FatbincWrapper, - _: *mut c_void, - _: *mut c_void, + ptr1: *mut c_void, + ptr2: *mut c_void, ) -> CUresult { + // Not sure what those twoparameters are actually used for, + // they are somehow involved in __cudaRegisterHostVar + if ptr1 != ptr::null_mut() || ptr2 != ptr::null_mut() { + return CUresult::CUDA_ERROR_NOT_SUPPORTED; + } if result == ptr::null_mut() || (*fatbinc_wrapper).magic != FATBINC_MAGIC || (*fatbinc_wrapper).version != FATBINC_VERSION @@ -208,11 +212,6 @@ unsafe extern "C" fn get_module_from_cubin( return CUresult::CUDA_ERROR_INVALID_VALUE; } let result = result.decuda(); - let mut dev_count = 0; - let cu_result = device::get_count(&mut dev_count); - if cu_result != CUresult::CUDA_SUCCESS { - return cu_result; - } let fatbin_header = (*fatbinc_wrapper).data; if (*fatbin_header).magic != FATBIN_MAGIC || (*fatbin_header).version != FATBIN_VERSION { return CUresult::CUDA_ERROR_INVALID_VALUE; @@ -235,7 +234,7 @@ unsafe extern "C" fn get_module_from_cubin( }, Err(_) => continue, }; - let module = module::Module::compile(kernel_text_string, dev_count as usize); + let module = module::ModuleData::compile_spirv(kernel_text_string); match module { Ok(module) => { *result = Box::into_raw(Box::new(module)); @@ -310,7 +309,7 @@ unsafe extern "C" fn context_local_storage_ctor( } fn context_local_storage_ctor_impl( - cu_ctx: *mut context::Context, + mut cu_ctx: *mut context::Context, mgr: *mut cuda_impl::rt::ContextStateManager, ctx_state: *mut cuda_impl::rt::ContextState, dtor_cb: Option< @@ -322,7 +321,7 @@ fn context_local_storage_ctor_impl( >, ) -> Result<(), CUresult> { if cu_ctx == ptr::null_mut() { - return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED); + context::get_current(&mut cu_ctx)?; } unsafe { &*cu_ctx } .as_ref() diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs index 7813532..c37b85d 100644 --- a/notcuda/src/impl/mod.rs +++ b/notcuda/src/impl/mod.rs @@ -1,5 +1,5 @@ -use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUresult, CUmodule}; -use std::{ffi::c_void, mem::ManuallyDrop, os::raw::c_int, sync::Mutex}; +use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunction, CUmod_st, CUmodule, CUresult}; +use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex}; #[cfg(test)] #[macro_use] @@ -206,6 +206,10 @@ pub fn init() -> l0::Result<()> { Ok(()) } +unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T { + mem::transmute(t) +} + pub fn driver_get_version() -> c_int { i32::max_value() } @@ -234,7 +238,10 @@ impl Decuda<*mut c_void> for CUdeviceptr { } } -impl<'a> CudaRepr for CUmodule { - type Impl = *mut module::Module; +impl<'a> CudaRepr for CUmod_st { + type Impl = module::Module; } +impl<'a> CudaRepr for CUfunction { + type Impl = *mut module::Function; +} diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index 4b664b5..06d050d 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -1,8 +1,14 @@ +use std::{ffi::c_void, ffi::CStr, mem, os::raw::c_char, ptr, slice, sync::Mutex}; + +use super::{transmute_lifetime, CUresult}; use ptx; -pub struct Module { - spirv_code: Vec, - compiled_code: Vec>>, // size as big as the number of devices +use super::context; + +pub type Module = Mutex; + +pub struct ModuleData { + base: l0::Module, } pub enum ModuleCompileError<'a> { @@ -10,21 +16,35 @@ pub enum ModuleCompileError<'a> { Vec, Option, ptx::ast::PtxError>>, ), - Compile(ptx::SpirvError), + Compile(ptx::TranslateError), + L0(l0::sys::ze_result_t), + CUDA(CUresult), } impl<'a> ModuleCompileError<'a> { pub fn get_build_log(&self) {} } -impl<'a> From for ModuleCompileError<'a> { - fn from(err: ptx::SpirvError) -> Self { +impl<'a> From for ModuleCompileError<'a> { + fn from(err: ptx::TranslateError) -> Self { ModuleCompileError::Compile(err) } } -impl Module { - pub fn compile(ptx_text: &str, devices: usize) -> Result { +impl<'a> From for ModuleCompileError<'a> { + fn from(err: l0::sys::ze_result_t) -> Self { + ModuleCompileError::L0(err) + } +} + +impl<'a> From for ModuleCompileError<'a> { + fn from(err: CUresult) -> Self { + ModuleCompileError::CUDA(err) + } +} + +impl ModuleData { + pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result> { let mut errors = Vec::new(); let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text); let ast = match ast { @@ -33,9 +53,40 @@ impl Module { Ok(ast) => ast, }; let spirv = ptx::to_spirv(ast)?; - Ok(Self { - spirv_code: spirv, - compiled_code: vec![None; devices], - }) + let byte_il = unsafe { + slice::from_raw_parts::( + spirv.as_ptr() as *const _, + spirv.len() * mem::size_of::(), + ) + }; + let module = super::device::with_current_exclusive(|dev| { + l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None) + }); + match module { + Ok(Ok(module)) => Ok(Mutex::new(Self { base: module })), + Ok(Err(err)) => Err(ModuleCompileError::from(err)), + Err(err) => Err(ModuleCompileError::from(err)), + } } } + +pub struct Function { + base: l0::Kernel<'static>, +} + +pub fn get_function( + hfunc: *mut *mut Function, + hmod: *mut Module, + name: *const c_char, +) -> Result<(), CUresult> { + if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + let name = unsafe { CStr::from_ptr(name) }; + let kernel = unsafe { &*hmod } + .try_lock() + .map(|module| l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)) + .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??; + unsafe { *hfunc = Box::into_raw(Box::new(Function { base: kernel })) }; + Ok(()) +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 06843f0..78c3375 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -64,6 +64,7 @@ test_ptx!(reg_local, [12u64], [13u64]); test_ptx!(mov_address, [0xDEADu64], [0u64]); test_ptx!(b64tof64, [111u64], [111u64]); test_ptx!(implicit_param, [34u32], [34u32]); +test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/pred_not.ptx b/ptx/src/test/spirv_run/pred_not.ptx new file mode 100644 index 0000000..e058470 --- /dev/null +++ b/ptx/src/test/spirv_run/pred_not.ptx @@ -0,0 +1,28 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry pred_not( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + .reg .u64 temp3; + .reg .pred pred; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + ld.u64 temp2, [in_addr + 8]; + setp.lt.u64 pred, temp, temp2; + not.pred pred, pred; + @pred mov.u64 temp3, 1; + @!pred mov.u64 temp3, 2; + st.u64 [out_addr], temp3; + ret; +} diff --git a/ptx/src/test/spirv_run/pred_not.spvtxt b/ptx/src/test/spirv_run/pred_not.spvtxt new file mode 100644 index 0000000..410b1e4 --- /dev/null +++ b/ptx/src/test/spirv_run/pred_not.spvtxt @@ -0,0 +1,78 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + OpCapability Float64 + %44 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "pred_not" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %47 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_8 = OpConstant %ulong 8 + %true = OpConstantTrue %bool + %false = OpConstantFalse %bool + %ulong_1 = OpConstant %ulong 1 + %ulong_2 = OpConstant %ulong 2 + %1 = OpFunction %void None %47 + %14 = OpFunctionParameter %ulong + %15 = OpFunctionParameter %ulong + %42 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_bool Function + OpStore %2 %14 + OpStore %3 %15 + %17 = OpLoad %ulong %2 + %16 = OpCopyObject %ulong %17 + OpStore %4 %16 + %19 = OpLoad %ulong %3 + %18 = OpCopyObject %ulong %19 + OpStore %5 %18 + %21 = OpLoad %ulong %4 + %39 = OpConvertUToPtr %_ptr_Generic_ulong %21 + %20 = OpLoad %ulong %39 + OpStore %6 %20 + %23 = OpLoad %ulong %4 + %36 = OpIAdd %ulong %23 %ulong_8 + %40 = OpConvertUToPtr %_ptr_Generic_ulong %36 + %22 = OpLoad %ulong %40 + OpStore %7 %22 + %25 = OpLoad %ulong %6 + %26 = OpLoad %ulong %7 + %24 = OpULessThan %bool %25 %26 + OpStore %9 %24 + %28 = OpLoad %bool %9 + %27 = OpSelect %bool %28 %false %true + OpStore %9 %27 + %29 = OpLoad %bool %9 + OpBranchConditional %29 %10 %11 + %10 = OpLabel + %30 = OpCopyObject %ulong %ulong_1 + OpStore %8 %30 + OpBranch %11 + %11 = OpLabel + %31 = OpLoad %bool %9 + OpBranchConditional %31 %13 %12 + %12 = OpLabel + %32 = OpCopyObject %ulong %ulong_2 + OpStore %8 %32 + OpBranch %13 + %13 = OpLabel + %33 = OpLoad %ulong %5 + %34 = OpLoad %ulong %8 + %41 = OpConvertUToPtr %_ptr_Generic_ulong %33 + OpStore %41 %34 + OpReturn + OpFunctionEnd