Minor fixes, add tests

This commit is contained in:
Andrzej Janik 2025-09-04 17:41:43 +00:00
commit 6f068f2737
11 changed files with 147 additions and 25 deletions

View file

@ -77,13 +77,13 @@ bitflags! {
}
impl FatbincWrapper {
pub const MAGIC: c_uint = 0x466243B1;
pub const MAGIC: [u8; 4] = [0x46, 0x62, 0x43, 0xB1];
pub const VERSION_V1: c_uint = 0x1;
pub const VERSION_V2: c_uint = 0x2;
}
impl FatbinHeader {
pub const MAGIC: c_uint = 0xBA55ED50;
pub const MAGIC: [u8; 4] = [0xBA, 0x55, 0xED, 0x50];
pub const VERSION: c_ushort = 0x01;
}

View file

@ -43,7 +43,11 @@ pub fn parse_fatbinc_wrapper<T: Sized>(ptr: &*const T) -> Result<&FatbincWrapper
unsafe { ptr.cast::<FatbincWrapper>().as_ref() }
.ok_or(ParseError::NullPointer("FatbincWrapper"))
.and_then(|ptr| {
ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [FatbincWrapper::MAGIC])?;
ParseError::check_fields(
"FATBINC_MAGIC",
ptr.magic,
[u32::from_ne_bytes(FatbincWrapper::MAGIC)],
)?;
ParseError::check_fields(
"FATBINC_VERSION",
ptr.version,
@ -57,7 +61,11 @@ fn parse_fatbin_header<T: Sized>(ptr: &*const T) -> Result<&FatbinHeader, ParseE
unsafe { ptr.cast::<FatbinHeader>().as_ref() }
.ok_or(ParseError::NullPointer("FatbinHeader"))
.and_then(|ptr| {
ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [FatbinHeader::MAGIC])?;
ParseError::check_fields(
"FATBIN_MAGIC",
ptr.magic,
[u32::from_ne_bytes(FatbinHeader::MAGIC)],
)?;
ParseError::check_fields("FATBIN_VERSION", ptr.version, [FatbinHeader::VERSION])?;
Ok(ptr)
})

View file

@ -1495,6 +1495,16 @@ pub struct Module<'input> {
pub invalid_directives: usize,
}
impl Module<'_> {
pub fn empty() -> Self {
Module {
version: (1, 0),
directives: Vec::new(),
invalid_directives: usize::MAX,
}
}
}
#[derive(Copy, Clone)]
pub enum MulDetails {
Integer {

View file

@ -3,8 +3,8 @@ use logos::Logos;
use ptx_parser_macros::derive_parser;
use rustc_hash::FxHashMap;
use std::fmt::Debug;
use std::iter;
use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
use std::{iter, usize};
use winnow::ascii::dec_uint;
use winnow::combinator::*;
use winnow::error::{ErrMode, ErrorKind};
@ -401,7 +401,7 @@ pub fn parse_module_checked<'input>(
.map_err(|err| PtxError::Parser(err.into_inner()))
};
match parse_result {
Ok(result) if errors.is_empty() => Ok(result),
Ok(result) if errors.is_empty() && result.invalid_directives == 0 => Ok(result),
Ok(_) => Err(errors),
Err(err) => {
errors.push(err);
@ -410,6 +410,39 @@ pub fn parse_module_checked<'input>(
}
}
pub fn parse_module_unchecked<'input>(text: &'input str) -> ast::Module<'input> {
let mut lexer = Token::lexer(text);
let mut errors = Vec::new();
let mut tokens = Vec::new();
loop {
let maybe_token = match lexer.next() {
Some(maybe_token) => maybe_token,
None => break,
};
match maybe_token {
Ok(token) => tokens.push((token, lexer.span())),
Err(mut err) => {
err.0 = lexer.span();
errors.push(PtxError::from(err))
}
}
}
if !errors.is_empty() {
return ast::Module::empty();
}
let parse_result = {
let state = PtxParserState::new(text, &mut errors);
let parser = PtxParser {
state,
input: &tokens[..],
};
module
.parse(parser)
.map_err(|err| PtxError::Parser(err.into_inner()))
};
parse_result.unwrap_or(ast::Module::empty())
}
fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {
trace(
"module",

View file

@ -1,7 +1,7 @@
use vergen_gix::{Emitter, GixBuilder};
fn main() {
let git = GixBuilder::all_git().unwrap();
let git = GixBuilder::default().sha(false).build().unwrap();
Emitter::default()
.add_instructions(&git)
.unwrap()

View file

@ -102,7 +102,11 @@ pub(crate) fn get_current_device() -> Result<hipDevice_t, CUerror> {
stack
.try_borrow()
.map_err(|_| CUerror::UNKNOWN)
.and_then(|s| s.last().ok_or(CUerror::UNKNOWN).map(|(_, dev)| *dev))
.and_then(|s| {
s.last()
.ok_or(CUerror::INVALID_CONTEXT)
.map(|(_, dev)| *dev)
})
})
}

View file

@ -517,3 +517,23 @@ pub(crate) unsafe fn primary_context_get_state(
*active_out = active;
Ok(())
}
#[cfg(test)]
mod tests {
use crate::tests::CudaApi;
use cuda_macros::test_cuda;
use std::{mem, ptr};
#[test_cuda]
unsafe fn primary_ctx_retain_does_not_make_it_active(api: impl CudaApi) {
api.cuInit(0);
let mut current_ctx = mem::zeroed();
api.cuCtxGetCurrent(&mut current_ctx);
assert_eq!(current_ctx.0, ptr::null_mut());
let mut primary_ctx = mem::zeroed();
api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0);
assert_ne!(primary_ctx.0, ptr::null_mut());
api.cuCtxGetCurrent(&mut current_ctx);
assert_eq!(current_ctx.0, ptr::null_mut());
}
}

View file

@ -1,4 +1,4 @@
use crate::r#impl::{context, device, function};
use crate::r#impl::{self, context, device, function};
use comgr::Comgr;
use cuda_types::cuda::*;
use hip_runtime_sys::*;
@ -154,7 +154,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_module: *mut cuda_types::cuda::CUmodule,
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
) -> cuda_types::cuda::CUresult {
todo!()
r#impl::unimplemented()
}
unsafe extern "system" fn cudart_interface_fn2(
@ -176,11 +176,11 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_arg4: *mut std::ffi::c_void,
_arg5: u32,
) -> cuda_types::cuda::CUresult {
todo!()
r#impl::unimplemented()
}
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
todo!()
r#impl::unimplemented()
}
unsafe extern "system" fn get_module_from_cubin_ext2(
@ -190,7 +190,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_arg4: *mut std::ffi::c_void,
_arg5: u32,
) -> cuda_types::cuda::CUresult {
todo!()
r#impl::unimplemented()
}
unsafe extern "system" fn get_unknown_buffer1(
@ -276,7 +276,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_flags: ::std::os::raw::c_uint,
_dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult {
todo!()
r#impl::unimplemented()
}
unsafe extern "system" fn heap_alloc(
@ -284,14 +284,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_arg2: usize,
_arg3: usize,
) -> cuda_types::cuda::CUresult {
todo!()
r#impl::unimplemented()
}
unsafe extern "system" fn heap_free(
_heap_alloc_record_ptr: *const std::ffi::c_void,
_arg2: *mut usize,
) -> cuda_types::cuda::CUresult {
todo!()
r#impl::unimplemented()
}
unsafe extern "system" fn device_get_attribute_ext(
@ -300,14 +300,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_unknown: std::ffi::c_int,
_result: *mut [usize; 2],
) -> cuda_types::cuda::CUresult {
todo!()
r#impl::unimplemented()
}
unsafe extern "system" fn device_get_something(
_result: *mut std::ffi::c_uchar,
_dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult {
todo!()
r#impl::unimplemented()
}
unsafe extern "system" fn integrity_check(

View file

@ -24,11 +24,11 @@ impl ZludaObject for Library {
pub(crate) fn load_data(
library: &mut CUlibrary,
code: *const ::core::ffi::c_void,
_jit_options: &mut CUjit_option,
_jit_options_values: &mut *mut ::core::ffi::c_void,
_jit_options: Option<&mut CUjit_option>,
_jit_options_values: Option<&mut *mut ::core::ffi::c_void>,
_num_jit_options: ::core::ffi::c_uint,
_library_options: &mut CUlibraryOption,
_library_option_values: &mut *mut ::core::ffi::c_void,
_library_options: Option<&mut CUlibraryOption>,
_library_option_values: Option<&mut *mut ::core::ffi::c_void>,
_num_library_options: ::core::ffi::c_uint,
) -> CUresult {
let hip_module = module::load_hip_module(code)?;
@ -61,3 +61,46 @@ pub(crate) unsafe fn get_global(
) -> hipError_t {
hipModuleGetGlobal(dptr, bytes, library.base, name)
}
#[cfg(test)]
mod tests {
use crate::tests::CudaApi;
use cuda_macros::test_cuda;
use cuda_types::cuda::{CUresult, CUresultConsts};
use std::{ffi::CStr, mem, ptr};
#[test_cuda]
unsafe fn library_loads_without_context(api: impl CudaApi) {
static PTX: &'static CStr = c"
.version 7.0
.target sm_70
.address_size 64
.visible .entry foobar() {
ret;
}
";
api.cuInit(0);
let mut device = mem::zeroed();
assert_eq!(
CUresult::ERROR_INVALID_CONTEXT,
api.cuCtxGetDevice_unchecked(&mut device)
);
let mut module = mem::zeroed();
assert_eq!(
CUresult::ERROR_INVALID_CONTEXT,
api.cuModuleLoadData_unchecked(&mut module, PTX.as_ptr().cast())
);
let mut library = mem::zeroed();
api.cuLibraryLoadData(
&mut library,
PTX.as_ptr().cast(),
ptr::null_mut(),
ptr::null_mut(),
0,
ptr::null_mut(),
ptr::null_mut(),
0,
);
}
}

View file

@ -47,7 +47,7 @@ fn get_ptx(image: *const ::core::ffi::c_void) -> Result<String, CUerror> {
return Err(CUerror::INVALID_VALUE);
}
let ptx = if unsafe { *(image as *const u32) } == FatbincWrapper::MAGIC {
let ptx = if unsafe { *(image as *const [u8; 4]) } == FatbincWrapper::MAGIC {
let ptx_bytes = get_ptx_from_wrapped_fatbin(image)?;
std::str::from_utf8(&ptx_bytes)
.map_err(|_| CUerror::UNKNOWN)?
@ -139,7 +139,11 @@ fn compile_from_ptx_and_cache(
text: &str,
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
) -> Result<Vec<u8>, CUerror> {
let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
let ast = if cfg!(debug_assertions) {
ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?
} else {
ptx_parser::parse_module_unchecked(text)
};
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode(
comgr,

View file

@ -145,7 +145,7 @@ impl StateTracker {
raw_image,
kind: "archive",
})
} else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC {
} else if unsafe { *(raw_image as *const [u8; 4]) } == FatbincWrapper::MAGIC {
unsafe {
fn_logger.try_(|fn_logger| {
trace::record_submodules_from_wrapped_fatbin(