diff --git a/cuda_types/src/dark_api.rs b/cuda_types/src/dark_api.rs
index 442c0b6..435b472 100644
--- a/cuda_types/src/dark_api.rs
+++ b/cuda_types/src/dark_api.rs
@@ -77,13 +77,17 @@ bitflags! {
 }
 
 impl FatbincWrapper {
-    pub const MAGIC: c_uint = 0x466243B1;
+    /// In-memory bytes of the wrapper magic: the native-endian encoding of
+    /// `0x466243B1`, so raw `[u8; 4]` reads and `u32::from_ne_bytes` both match it.
+    pub const MAGIC: [u8; 4] = 0x466243B1_u32.to_ne_bytes();
     pub const VERSION_V1: c_uint = 0x1;
     pub const VERSION_V2: c_uint = 0x2;
 }
 
 impl FatbinHeader {
-    pub const MAGIC: c_uint = 0xBA55ED50;
+    /// In-memory bytes of the header magic: the native-endian encoding of
+    /// `0xBA55ED50`, so raw `[u8; 4]` reads and `u32::from_ne_bytes` both match it.
+    pub const MAGIC: [u8; 4] = 0xBA55ED50_u32.to_ne_bytes();
     pub const VERSION: c_ushort = 0x01;
 }
 
diff --git a/dark_api/src/fatbin.rs b/dark_api/src/fatbin.rs
index 8d7868d..86cff8e 100644
--- a/dark_api/src/fatbin.rs
+++ b/dark_api/src/fatbin.rs
@@ -43,7 +43,11 @@ pub fn parse_fatbinc_wrapper(ptr: &*const T) -> Result<&FatbincWrapper
     unsafe { ptr.cast::().as_ref() }
         .ok_or(ParseError::NullPointer("FatbincWrapper"))
         .and_then(|ptr| {
-            ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [FatbincWrapper::MAGIC])?;
+            ParseError::check_fields(
+                "FATBINC_MAGIC",
+                ptr.magic,
+                [u32::from_ne_bytes(FatbincWrapper::MAGIC)],
+            )?;
             ParseError::check_fields(
                 "FATBINC_VERSION",
                 ptr.version,
@@ -57,7 +61,11 @@ fn parse_fatbin_header(ptr: &*const T) -> Result<&FatbinHeader, ParseE
     unsafe { ptr.cast::().as_ref() }
         .ok_or(ParseError::NullPointer("FatbinHeader"))
         .and_then(|ptr| {
-            ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [FatbinHeader::MAGIC])?;
+            ParseError::check_fields(
+                "FATBIN_MAGIC",
+                ptr.magic,
+                [u32::from_ne_bytes(FatbinHeader::MAGIC)],
+            )?;
             ParseError::check_fields("FATBIN_VERSION", ptr.version, [FatbinHeader::VERSION])?;
             Ok(ptr)
         })
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs
index a58830e..e186653 100644
--- a/ptx_parser/src/ast.rs
+++ b/ptx_parser/src/ast.rs
@@ -1495,6 +1495,18 @@ pub struct Module<'input> {
     pub invalid_directives: usize,
 }
 
+impl Module<'_> {
+    /// Returns a placeholder module: version 1.0, no directives, and
+    /// `invalid_directives` set to `usize::MAX`.
+    pub fn empty() -> Self {
+        Module {
+            version: (1, 0),
+            directives: Vec::new(),
+            invalid_directives: usize::MAX,
+        }
+    }
+}
+
#[derive(Copy, Clone)] pub enum MulDetails { Integer { diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index dd54a35..6201e43 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3,8 +3,8 @@ use logos::Logos; use ptx_parser_macros::derive_parser; use rustc_hash::FxHashMap; use std::fmt::Debug; -use std::iter; use std::num::{NonZeroU8, ParseFloatError, ParseIntError}; +use std::{iter, usize}; use winnow::ascii::dec_uint; use winnow::combinator::*; use winnow::error::{ErrMode, ErrorKind}; @@ -401,7 +401,7 @@ pub fn parse_module_checked<'input>( .map_err(|err| PtxError::Parser(err.into_inner())) }; match parse_result { - Ok(result) if errors.is_empty() => Ok(result), + Ok(result) if errors.is_empty() && result.invalid_directives == 0 => Ok(result), Ok(_) => Err(errors), Err(err) => { errors.push(err); @@ -410,6 +410,39 @@ pub fn parse_module_checked<'input>( } } +pub fn parse_module_unchecked<'input>(text: &'input str) -> ast::Module<'input> { + let mut lexer = Token::lexer(text); + let mut errors = Vec::new(); + let mut tokens = Vec::new(); + loop { + let maybe_token = match lexer.next() { + Some(maybe_token) => maybe_token, + None => break, + }; + match maybe_token { + Ok(token) => tokens.push((token, lexer.span())), + Err(mut err) => { + err.0 = lexer.span(); + errors.push(PtxError::from(err)) + } + } + } + if !errors.is_empty() { + return ast::Module::empty(); + } + let parse_result = { + let state = PtxParserState::new(text, &mut errors); + let parser = PtxParser { + state, + input: &tokens[..], + }; + module + .parse(parser) + .map_err(|err| PtxError::Parser(err.into_inner())) + }; + parse_result.unwrap_or(ast::Module::empty()) +} + fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { trace( "module", diff --git a/zluda/build.rs b/zluda/build.rs index 1f4e654..82508db 100644 --- a/zluda/build.rs +++ b/zluda/build.rs @@ -1,7 +1,7 @@ use vergen_gix::{Emitter, GixBuilder}; fn main() { - let git = 
GixBuilder::all_git().unwrap(); + let git = GixBuilder::default().sha(false).build().unwrap(); Emitter::default() .add_instructions(&git) .unwrap() diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 4862a45..8933116 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -102,7 +102,11 @@ pub(crate) fn get_current_device() -> Result { stack .try_borrow() .map_err(|_| CUerror::UNKNOWN) - .and_then(|s| s.last().ok_or(CUerror::UNKNOWN).map(|(_, dev)| *dev)) + .and_then(|s| { + s.last() + .ok_or(CUerror::INVALID_CONTEXT) + .map(|(_, dev)| *dev) + }) }) } diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 074fcb4..03b2dcc 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -517,3 +517,23 @@ pub(crate) unsafe fn primary_context_get_state( *active_out = active; Ok(()) } + +#[cfg(test)] +mod tests { + use crate::tests::CudaApi; + use cuda_macros::test_cuda; + use std::{mem, ptr}; + + #[test_cuda] + unsafe fn primary_ctx_retain_does_not_make_it_active(api: impl CudaApi) { + api.cuInit(0); + let mut current_ctx = mem::zeroed(); + api.cuCtxGetCurrent(&mut current_ctx); + assert_eq!(current_ctx.0, ptr::null_mut()); + let mut primary_ctx = mem::zeroed(); + api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0); + assert_ne!(primary_ctx.0, ptr::null_mut()); + api.cuCtxGetCurrent(&mut current_ctx); + assert_eq!(current_ctx.0, ptr::null_mut()); + } +} diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index 3a54ac7..f14d990 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -1,4 +1,4 @@ -use crate::r#impl::{context, device, function}; +use crate::r#impl::{self, context, device, function}; use comgr::Comgr; use cuda_types::cuda::*; use hip_runtime_sys::*; @@ -154,7 +154,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _module: *mut cuda_types::cuda::CUmodule, _fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper, ) -> cuda_types::cuda::CUresult { - todo!() + 
r#impl::unimplemented() } unsafe extern "system" fn cudart_interface_fn2( @@ -176,11 +176,11 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _arg4: *mut std::ffi::c_void, _arg5: u32, ) -> cuda_types::cuda::CUresult { - todo!() + r#impl::unimplemented() } unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult { - todo!() + r#impl::unimplemented() } unsafe extern "system" fn get_module_from_cubin_ext2( @@ -190,7 +190,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _arg4: *mut std::ffi::c_void, _arg5: u32, ) -> cuda_types::cuda::CUresult { - todo!() + r#impl::unimplemented() } unsafe extern "system" fn get_unknown_buffer1( @@ -276,7 +276,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _flags: ::std::os::raw::c_uint, _dev: cuda_types::cuda::CUdevice, ) -> cuda_types::cuda::CUresult { - todo!() + r#impl::unimplemented() } unsafe extern "system" fn heap_alloc( @@ -284,14 +284,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _arg2: usize, _arg3: usize, ) -> cuda_types::cuda::CUresult { - todo!() + r#impl::unimplemented() } unsafe extern "system" fn heap_free( _heap_alloc_record_ptr: *const std::ffi::c_void, _arg2: *mut usize, ) -> cuda_types::cuda::CUresult { - todo!() + r#impl::unimplemented() } unsafe extern "system" fn device_get_attribute_ext( @@ -300,14 +300,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _unknown: std::ffi::c_int, _result: *mut [usize; 2], ) -> cuda_types::cuda::CUresult { - todo!() + r#impl::unimplemented() } unsafe extern "system" fn device_get_something( _result: *mut std::ffi::c_uchar, _dev: cuda_types::cuda::CUdevice, ) -> cuda_types::cuda::CUresult { - todo!() + r#impl::unimplemented() } unsafe extern "system" fn integrity_check( diff --git a/zluda/src/impl/library.rs b/zluda/src/impl/library.rs index 6841cfc..3c213bf 100644 --- a/zluda/src/impl/library.rs +++ b/zluda/src/impl/library.rs @@ -24,11 +24,11 @@ impl ZludaObject for Library { pub(crate) fn load_data( library: &mut 
CUlibrary, code: *const ::core::ffi::c_void, - _jit_options: &mut CUjit_option, - _jit_options_values: &mut *mut ::core::ffi::c_void, + _jit_options: Option<&mut CUjit_option>, + _jit_options_values: Option<&mut *mut ::core::ffi::c_void>, _num_jit_options: ::core::ffi::c_uint, - _library_options: &mut CUlibraryOption, - _library_option_values: &mut *mut ::core::ffi::c_void, + _library_options: Option<&mut CUlibraryOption>, + _library_option_values: Option<&mut *mut ::core::ffi::c_void>, _num_library_options: ::core::ffi::c_uint, ) -> CUresult { let hip_module = module::load_hip_module(code)?; @@ -61,3 +61,46 @@ pub(crate) unsafe fn get_global( ) -> hipError_t { hipModuleGetGlobal(dptr, bytes, library.base, name) } + +#[cfg(test)] +mod tests { + use crate::tests::CudaApi; + use cuda_macros::test_cuda; + use cuda_types::cuda::{CUresult, CUresultConsts}; + use std::{ffi::CStr, mem, ptr}; + + #[test_cuda] + unsafe fn library_loads_without_context(api: impl CudaApi) { + static PTX: &'static CStr = c" + .version 7.0 + .target sm_70 + .address_size 64 + + .visible .entry foobar() { + ret; + } + "; + api.cuInit(0); + let mut device = mem::zeroed(); + assert_eq!( + CUresult::ERROR_INVALID_CONTEXT, + api.cuCtxGetDevice_unchecked(&mut device) + ); + let mut module = mem::zeroed(); + assert_eq!( + CUresult::ERROR_INVALID_CONTEXT, + api.cuModuleLoadData_unchecked(&mut module, PTX.as_ptr().cast()) + ); + let mut library = mem::zeroed(); + api.cuLibraryLoadData( + &mut library, + PTX.as_ptr().cast(), + ptr::null_mut(), + ptr::null_mut(), + 0, + ptr::null_mut(), + ptr::null_mut(), + 0, + ); + } +} diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index f8db917..8cfe493 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -47,7 +47,7 @@ fn get_ptx(image: *const ::core::ffi::c_void) -> Result { return Err(CUerror::INVALID_VALUE); } - let ptx = if unsafe { *(image as *const u32) } == FatbincWrapper::MAGIC { + let ptx = if unsafe { *(image as *const 
[u8; 4]) } == FatbincWrapper::MAGIC { let ptx_bytes = get_ptx_from_wrapped_fatbin(image)?; std::str::from_utf8(&ptx_bytes) .map_err(|_| CUerror::UNKNOWN)? @@ -139,7 +139,11 @@ fn compile_from_ptx_and_cache( text: &str, cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>, ) -> Result, CUerror> { - let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?; + let ast = if cfg!(debug_assertions) { + ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)? + } else { + ptx_parser::parse_module_unchecked(text) + }; let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?; let elf_module = comgr::compile_bitcode( comgr, diff --git a/zluda_trace/src/trace.rs b/zluda_trace/src/trace.rs index 1242df6..ae0662a 100644 --- a/zluda_trace/src/trace.rs +++ b/zluda_trace/src/trace.rs @@ -145,7 +145,7 @@ impl StateTracker { raw_image, kind: "archive", }) - } else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC { + } else if unsafe { *(raw_image as *const [u8; 4]) } == FatbincWrapper::MAGIC { unsafe { fn_logger.try_(|fn_logger| { trace::record_submodules_from_wrapped_fatbin(