diff --git a/Cargo.lock b/Cargo.lock index 816c9fb..7b7948b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -333,8 +333,10 @@ dependencies = [ "cglue", "cuda_types", "format", + "lz4-sys", "paste", "uuid", + "zstd-safe", ] [[package]] @@ -500,6 +502,18 @@ dependencies = [ "uuid", ] +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi", +] + [[package]] name = "glob" version = "0.3.1" @@ -588,10 +602,11 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom", "libc", ] @@ -1124,6 +1139,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "rawpointer" version = "0.2.1" @@ -1522,6 +1543,15 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wasi" +version = "0.14.2+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wchar" version = "0.6.1" @@ -1666,6 +1696,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +dependencies = [ + "bitflags 2.9.1", +] + [[package]] name = "xattr" version = "1.5.0" @@ -1790,7 +1829,6 @@ dependencies = [ "format", "goblin", "libc", - "lz4-sys", "parking_lot", "paste", "ptx", diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 0f40843..6cef62d 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -199,7 +199,8 @@ impl VisitMut for FixFnSignatures { } const MODULES: &[&str] = &[ - "context", "device", "driver", "function", "link", "memory", "module", "pointer", "stream", + "context", "device", "driver", "function", "library", "link", "memory", "module", "pointer", + "stream", ]; #[proc_macro] diff --git a/cuda_types/src/dark_api.rs b/cuda_types/src/dark_api.rs index 0ee85d7..442c0b6 100644 --- a/cuda_types/src/dark_api.rs +++ b/cuda_types/src/dark_api.rs @@ -78,122 +78,17 @@ bitflags! { impl FatbincWrapper { pub const MAGIC: c_uint = 0x466243B1; - const VERSION_V1: c_uint = 0x1; + pub const VERSION_V1: c_uint = 0x1; pub const VERSION_V2: c_uint = 0x2; - - pub fn new<'a, T: Sized>(ptr: &*const T) -> Result<&'a Self, ParseError> { - unsafe { ptr.cast::().as_ref() } - .ok_or(ParseError::NullPointer("FatbincWrapper")) - .and_then(|ptr| { - ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [Self::MAGIC])?; - ParseError::check_fields( - "FATBINC_VERSION", - ptr.version, - [Self::VERSION_V1, Self::VERSION_V2], - )?; - Ok(ptr) - }) - } } impl FatbinHeader { - const MAGIC: c_uint = 0xBA55ED50; - const VERSION: c_ushort = 0x01; - - pub fn new<'a, T: Sized>(ptr: &'a *const T) -> Result<&'a Self, ParseError> { - unsafe { ptr.cast::().as_ref() } - .ok_or(ParseError::NullPointer("FatbinHeader")) - .and_then(|ptr| { - ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [Self::MAGIC])?; - ParseError::check_fields("FATBIN_VERSION", ptr.version, [Self::VERSION])?; - Ok(ptr) - }) - } - - pub unsafe fn get_content<'a>(&'a self) -> &'a [u8] { - let start = std::ptr::from_ref(self) - .cast::() - .add(self.header_size as usize); - std::slice::from_raw_parts(start, self.files_size as usize) - } + pub const MAGIC: c_uint = 0xBA55ED50; + pub const VERSION: c_ushort = 0x01; } impl FatbinFileHeader { pub const HEADER_KIND_PTX: c_ushort = 0x01; pub const HEADER_KIND_ELF: c_ushort = 0x02; - const HEADER_VERSION_CURRENT: c_ushort = 0x101; - - pub fn new_ptx(ptr: *const T) -> Result, ParseError> { - unsafe { ptr.cast::().as_ref() } - .ok_or(ParseError::NullPointer("FatbinFileHeader")) - .and_then(|ptr| { - ParseError::check_fields( - "FATBIN_FILE_HEADER_VERSION_CURRENT", - ptr.version, - [Self::HEADER_VERSION_CURRENT], - )?; - match ptr.kind { - Self::HEADER_KIND_PTX => Ok(Some(ptr)), - Self::HEADER_KIND_ELF => Ok(None), - _ => Err(ParseError::UnexpectedBinaryField { - field_name: "FATBIN_FILE_HEADER_KIND", - observed: ptr.kind.into(), - expected: vec![Self::HEADER_KIND_PTX.into(), Self::HEADER_KIND_ELF.into()], - }), - } - }) - } - - pub unsafe fn next<'a>(slice: &'a mut &[u8]) -> Result, ParseError> { - if slice.len() < std::mem::size_of::() { - return Ok(None); - } - let this = &*slice.as_ptr().cast::(); - let next_element = slice - .split_at_checked(this.header_size as usize + this.padded_payload_size as usize) - .map(|(_, next)| next); - *slice = next_element.unwrap_or(&[]); - ParseError::check_fields( - "FATBIN_FILE_HEADER_VERSION_CURRENT", - this.version, - [Self::HEADER_VERSION_CURRENT], - )?; - Ok(Some(this)) - } - - pub unsafe fn get_payload<'a>(&'a self) -> &'a [u8] { - let start = std::ptr::from_ref(self) - .cast::() - .add(self.header_size as usize); - std::slice::from_raw_parts(start, self.payload_size as usize) - } -} - -pub enum ParseError { - NullPointer(&'static str), - UnexpectedBinaryField { - field_name: &'static str, - observed: u32, - expected: Vec, - }, -} - -impl ParseError { - pub(crate) fn check_fields + Eq + Copy>( - name: &'static str, - observed: T, - expected: [T; N], - ) -> Result<(), Self> { - if expected.contains(&observed) { - Ok(()) - } else { - let observed = observed.into(); - let expected = expected.into_iter().map(Into::into).collect(); - Err(ParseError::UnexpectedBinaryField { - field_name: name, - expected, - observed, - }) - } - } + pub const HEADER_VERSION_CURRENT: c_ushort = 0x101; } diff --git a/dark_api/Cargo.toml b/dark_api/Cargo.toml index 313b203..f65e538 100644 --- a/dark_api/Cargo.toml +++ b/dark_api/Cargo.toml @@ -10,3 +10,5 @@ uuid = "1.16" paste = "1.0" bit-vec = "0.8.0" cglue = "0.3.5" +lz4-sys = "1.9" +zstd-safe = { version = "7.2.4", features = ["std"] } diff --git a/dark_api/src/fatbin.rs b/dark_api/src/fatbin.rs new file mode 100644 index 0000000..a34f806 --- /dev/null +++ b/dark_api/src/fatbin.rs @@ -0,0 +1,235 @@ +// This file contains a higher-level interface for parsing fatbins + +use std::ptr; + +use cuda_types::dark_api::*; + +pub enum ParseError { + NullPointer(&'static str), + UnexpectedBinaryField { + field_name: &'static str, + observed: u32, + expected: Vec, + }, +} + +impl ParseError { + pub(crate) fn check_fields + Eq + Copy>( + name: &'static str, + observed: T, + expected: [T; N], + ) -> Result<(), Self> { + if expected.contains(&observed) { + Ok(()) + } else { + let observed = observed.into(); + let expected = expected.into_iter().map(Into::into).collect(); + Err(ParseError::UnexpectedBinaryField { + field_name: name, + expected, + observed, + }) + } + } +} + +pub enum FatbinError { + ParseFailure(ParseError), + Lz4DecompressionFailure, + ZstdDecompressionFailure(usize), +} + +pub fn parse_fatbinc_wrapper(ptr: &*const T) -> Result<&FatbincWrapper, ParseError> { + unsafe { ptr.cast::().as_ref() } + .ok_or(ParseError::NullPointer("FatbincWrapper")) + .and_then(|ptr| { + ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [FatbincWrapper::MAGIC])?; + ParseError::check_fields( + "FATBINC_VERSION", + ptr.version, + [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2], + )?; + Ok(ptr) + }) +} + +fn parse_fatbin_header(ptr: &*const T) -> Result<&FatbinHeader, ParseError> { + unsafe { ptr.cast::().as_ref() } + .ok_or(ParseError::NullPointer("FatbinHeader")) + .and_then(|ptr| { + ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [FatbinHeader::MAGIC])?; + ParseError::check_fields("FATBIN_VERSION", ptr.version, [FatbinHeader::VERSION])?; + Ok(ptr) + }) +} + +pub struct Fatbin<'a> { + pub wrapper: &'a FatbincWrapper, +} + +impl<'a> Fatbin<'a> { + pub fn new(ptr: &'a *const T) -> Result { + let wrapper: &FatbincWrapper = + parse_fatbinc_wrapper(ptr).map_err(|e| FatbinError::ParseFailure(e))?; + + Ok(Fatbin { wrapper }) + } + + pub fn get_first(&self) -> Result { + let header: &FatbinHeader = + parse_fatbin_header(&self.wrapper.data).map_err(|e| FatbinError::ParseFailure(e))?; + Ok(FatbinSubmodule::new(header)) + } + + pub fn get_submodules(&self) -> Option { + let is_version_2 = self.wrapper.version == FatbincWrapper::VERSION_V2; + if !is_version_2 { + return None; + } + + Some(FatbinSubmoduleIterator { + fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void, + }) + } +} + +pub struct FatbinSubmodule<'a> { + pub header: &'a FatbinHeader, // TODO: maybe make private +} + +impl<'a> FatbinSubmodule<'a> { + pub fn new(header: &'a FatbinHeader) -> Self { + FatbinSubmodule { header } + } + + pub fn get_files(&self) -> FatbinFileIterator { + unsafe { FatbinFileIterator::new(self.header) } + } +} + +pub struct FatbinSubmoduleIterator { + fatbins: *const *const std::ffi::c_void, +} + +impl FatbinSubmoduleIterator { + pub unsafe fn next(&mut self) -> Result, ParseError> { + if *self.fatbins != ptr::null() { + let header = *self.fatbins as *const FatbinHeader; + self.fatbins = self.fatbins.add(1); + Ok(Some(FatbinSubmodule::new(header.as_ref().ok_or( + ParseError::NullPointer("FatbinSubmoduleIterator"), + )?))) + } else { + Ok(None) + } + } +} + +pub struct FatbinFile<'a> { + pub header: &'a FatbinFileHeader, +} + +impl<'a> FatbinFile<'a> { + pub fn new(header: &'a FatbinFileHeader) -> Self { + Self { header } + } + + pub unsafe fn get_payload(&'a self) -> &'a [u8] { + let start = std::ptr::from_ref(self) + .cast::() + .add(self.header.header_size as usize); + std::slice::from_raw_parts(start, self.header.payload_size as usize) + } + + pub unsafe fn decompress(&'a self) -> Result, FatbinError> { + let payload = if self + .header + .flags + .contains(FatbinFileHeaderFlags::CompressedLz4) + { + unsafe { decompress_lz4(self) }? + } else if self + .header + .flags + .contains(FatbinFileHeaderFlags::CompressedZstd) + { + unsafe { decompress_zstd(self) }? + } else { + unsafe { self.get_payload().to_vec() } + }; + + Ok(payload) + } +} + +pub struct FatbinFileIterator<'a> { + file_buffer: &'a [u8], +} + +impl<'a> FatbinFileIterator<'a> { + pub unsafe fn new(header: &'a FatbinHeader) -> Self { + let start = std::ptr::from_ref(header) + .cast::() + .add(header.header_size as usize); + let file_buffer = std::slice::from_raw_parts(start, header.files_size as usize); + + Self { file_buffer } + } + + pub unsafe fn next(&mut self) -> Result, ParseError> { + if self.file_buffer.len() < std::mem::size_of::() { + return Ok(None); + } + let this = &*self.file_buffer.as_ptr().cast::(); + let next_element = self + .file_buffer + .split_at_checked(this.header_size as usize + this.padded_payload_size as usize) + .map(|(_, next)| next); + self.file_buffer = next_element.unwrap_or(&[]); + ParseError::check_fields( + "FATBIN_FILE_HEADER_VERSION_CURRENT", + this.version, + [FatbinFileHeader::HEADER_VERSION_CURRENT], + )?; + Ok(Some(FatbinFile::new(this))) + } +} + +const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024; + +pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result, FatbinError> { + let decompressed_size = usize::max(1024, file.header.uncompressed_payload as usize); + let mut decompressed_vec = vec![0u8; decompressed_size]; + loop { + match lz4_sys::LZ4_decompress_safe( + file.get_payload().as_ptr() as *const _, + decompressed_vec.as_mut_ptr() as *mut _, + file.header.payload_size as _, + decompressed_vec.len() as _, + ) { + error if error < 0 => { + let new_size = decompressed_vec.len() * 2; + if new_size > MAX_MODULE_DECOMPRESSION_BOUND { + return Err(FatbinError::Lz4DecompressionFailure); + } + decompressed_vec.resize(decompressed_vec.len() * 2, 0); + } + real_decompressed_size => { + decompressed_vec.truncate(real_decompressed_size as usize); + return Ok(decompressed_vec); + } + } + } +} + +pub unsafe fn decompress_zstd(file: &FatbinFile) -> Result, FatbinError> { + let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize); + let payload = file.get_payload(); + match zstd_safe::decompress(&mut result, payload) { + Ok(actual_size) => { + result.truncate(actual_size); + Ok(result) + } + Err(err) => Err(FatbinError::ZstdDecompressionFailure(err)), + } +} diff --git a/dark_api/src/lib.rs b/dark_api/src/lib.rs index 3cd1f0e..5f87c0d 100644 --- a/dark_api/src/lib.rs +++ b/dark_api/src/lib.rs @@ -2,6 +2,8 @@ use std::ffi::c_void; use cuda_types::cuda::CUuuid; +pub mod fatbin; + macro_rules! dark_api_init { (SIZE_OF, $table_len:literal, $type_:ty) => { (std::mem::size_of::() * $table_len) as *const std::ffi::c_void diff --git a/zluda/src/impl/library.rs b/zluda/src/impl/library.rs new file mode 100644 index 0000000..d9a99cc --- /dev/null +++ b/zluda/src/impl/library.rs @@ -0,0 +1,38 @@ +use super::module; + +use super::ZludaObject; + +use cuda_types::cuda::*; +use hip_runtime_sys::*; + +pub(crate) struct Library { + base: hipModule_t, +} + +impl ZludaObject for Library { + const COOKIE: usize = 0xb328a916cc234d7c; + + type CudaHandle = CUlibrary; + + fn drop_checked(&mut self) -> CUresult { + // TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice. + unsafe { hipModuleUnload(self.base) }?; + Ok(()) + } +} + +/// This implementation simply loads the code as a HIP module for now. The various JIT and library options are ignored. +pub(crate) fn load_data( + library: &mut CUlibrary, + code: *const ::core::ffi::c_void, + _jit_options: &mut CUjit_option, + _jit_options_values: &mut *mut ::core::ffi::c_void, + _num_jit_options: ::core::ffi::c_uint, + _library_options: &mut CUlibraryOption, + _library_option_values: &mut *mut ::core::ffi::c_void, + _num_library_options: ::core::ffi::c_uint, +) -> CUresult { + let hip_module = module::load_hip_module(code)?; + *library = Library { base: hip_module }.wrap(); + Ok(()) +} diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 0f05db8..e9d675b 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -10,6 +10,7 @@ pub(super) mod context; pub(super) mod device; pub(super) mod driver; pub(super) mod function; +pub(super) mod library; pub(super) mod memory; pub(super) mod module; pub(super) mod pointer; @@ -135,6 +136,9 @@ from_cuda_nop!( cuda_types::cuda::CUdevprop, CUdevice_attribute, CUdriverProcAddressQueryResult, + CUjit_option, + CUlibrary, + CUlibraryOption, CUmoduleLoadingMode, CUuuid ); @@ -169,6 +173,15 @@ impl<'a> FromCuda<'a, *const ::core::ffi::c_char> for &CStr { } } +impl<'a> FromCuda<'a, *const ::core::ffi::c_void> for &'a ::core::ffi::c_void { + fn from_cuda(x: &'a *const ::core::ffi::c_void) -> Result { + match unsafe { x.as_ref() } { + Some(x) => Ok(x), + None => Err(CUerror::INVALID_VALUE), + } + } +} + pub(crate) trait ZludaObject: Sized + Send + Sync { const COOKIE: usize; const LIVENESS_FAIL: CUerror = cuda_types::cuda::CUerror::INVALID_VALUE; diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index e0170a2..c55bfa6 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -1,5 +1,9 @@ use super::{driver, ZludaObject}; -use cuda_types::cuda::*; +use cuda_types::{ + cuda::*, + dark_api::{FatbinFileHeader, FatbincWrapper}, +}; +use dark_api::fatbin::Fatbin; use hip_runtime_sys::*; use std::{ffi::CStr, mem}; @@ -18,12 +22,47 @@ impl ZludaObject for Module { } } -pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) -> CUresult { +fn get_ptx_from_wrapped_fatbin(image: *const ::core::ffi::c_void) -> Result, CUerror> { + let fatbin = Fatbin::new(&image).map_err(|_| CUerror::UNKNOWN)?; + let first = fatbin.get_first().map_err(|_| CUerror::UNKNOWN)?; + let mut files = first.get_files(); + + while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } { + // Eventually we will want to get the PTX for the highest hardware version that we can get to compile. But for now we just get the first PTX we can find. + if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX { + let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?; + return Ok(decompressed); + } + } + + Err(CUerror::NO_BINARY_FOR_GPU) +} + +/// get_ptx takes an `image` that can be either a fatbin or a NULL-terminated ptx, and returns a String containing a ptx extracted from `image`. +fn get_ptx(image: *const ::core::ffi::c_void) -> Result { + if image.is_null() { + return Err(CUerror::INVALID_VALUE); + } + + let ptx = if unsafe { *(image as *const u32) } == FatbincWrapper::MAGIC { + let ptx_bytes = get_ptx_from_wrapped_fatbin(image)?; + str::from_utf8(&ptx_bytes) + .map_err(|_| CUerror::UNKNOWN)? + .to_owned() + } else { + unsafe { CStr::from_ptr(image.cast()) } + .to_str() + .map_err(|_| CUerror::INVALID_VALUE)? + .to_owned() + }; + + Ok(ptx) +} + +pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result { let global_state = driver::global_state()?; - let text = unsafe { CStr::from_ptr(image.cast()) } - .to_str() - .map_err(|_| CUerror::INVALID_VALUE)?; - let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?; + let text = get_ptx(image)?; + let ast = ptx_parser::parse_module_checked(&text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?; let llvm_module = ptx::to_llvm_module(ast).map_err(|_| CUerror::UNKNOWN)?; let mut dev = 0; unsafe { hipCtxGetDevice(&mut dev) }?; @@ -38,6 +77,11 @@ pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) - .map_err(|_| CUerror::UNKNOWN)?; let mut hip_module = unsafe { mem::zeroed() }; unsafe { hipModuleLoadData(&mut hip_module, elf_module.as_ptr().cast()) }?; + Ok(hip_module) +} + +pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult { + let hip_module = load_hip_module(image)?; *module = Module { base: hip_module }.wrap(); Ok(()) } diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index 0deec1a..0b8c784 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -66,6 +66,7 @@ cuda_base::cuda_function_declarations!( cuGetProcAddress, cuGetProcAddress_v2, cuInit, + cuLibraryLoadData, cuMemAlloc_v2, cuMemFree_v2, cuMemGetAddressRange_v2, @@ -84,4 +85,4 @@ cuda_base::cuda_function_declarations!( implemented_in_function <= [ cuLaunchKernel, ] -); \ No newline at end of file +); diff --git a/zluda_dump/Cargo.toml b/zluda_dump/Cargo.toml index 3dd97a8..234b58f 100644 --- a/zluda_dump/Cargo.toml +++ b/zluda_dump/Cargo.toml @@ -14,7 +14,6 @@ ptx_parser = { path = "../ptx_parser" } zluda_dump_common = { path = "../zluda_dump_common" } format = { path = "../format" } dark_api = { path = "../dark_api" } -lz4-sys = "1.9" regex = "1.4" dynasm = "1.2" dynasmrt = "1.2" diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index d5fd0f8..9440dab 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -1,3 +1,4 @@ +use ::dark_api::fatbin::FatbinFileIterator; use ::dark_api::FnFfi; use cuda_types::cuda::*; use dark_api::DarkApiState2; @@ -360,7 +361,16 @@ impl DarkApiDump { }); } fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules_from_fatbin(*module, fatbin_header, fn_logger, state) + trace::record_submodules( + *module, + fn_logger, + state, + FatbinFileIterator::new( + fatbin_header + .as_ref() + .ok_or(ErrorEntry::NullPointer("get_module_from_cubin_ext2_post"))?, + ), + ) }); } } diff --git a/zluda_dump/src/log.rs b/zluda_dump/src/log.rs index 337bc75..6ce7be2 100644 --- a/zluda_dump/src/log.rs +++ b/zluda_dump/src/log.rs @@ -308,11 +308,11 @@ pub(crate) enum ErrorEntry { unsafe impl Send for ErrorEntry {} unsafe impl Sync for ErrorEntry {} -impl From for ErrorEntry { - fn from(e: cuda_types::dark_api::ParseError) -> Self { +impl From for ErrorEntry { + fn from(e: dark_api::fatbin::ParseError) -> Self { match e { - cuda_types::dark_api::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s), - cuda_types::dark_api::ParseError::UnexpectedBinaryField { + dark_api::fatbin::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s), + dark_api::fatbin::ParseError::UnexpectedBinaryField { field_name, observed, expected, @@ -325,6 +325,20 @@ impl From for ErrorEntry { } } +impl From for ErrorEntry { + fn from(e: dark_api::fatbin::FatbinError) -> Self { + match e { + dark_api::fatbin::FatbinError::ParseFailure(parse_error) => parse_error.into(), + dark_api::fatbin::FatbinError::Lz4DecompressionFailure => { + ErrorEntry::Lz4DecompressionFailure + } + dark_api::fatbin::FatbinError::ZstdDecompressionFailure(c) => { + ErrorEntry::ZstdDecompressionFailure(c) + } + } + } +} + impl Display for ErrorEntry { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/zluda_dump/src/trace.rs b/zluda_dump/src/trace.rs index 81d416d..23665fc 100644 --- a/zluda_dump/src/trace.rs +++ b/zluda_dump/src/trace.rs @@ -4,7 +4,10 @@ use crate::{ }; use cuda_types::{ cuda::*, - dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbinHeader, FatbincWrapper}, + dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper}, +}; +use dark_api::fatbin::{ + decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule, }; use rustc_hash::{FxHashMap, FxHashSet}; use std::{ @@ -13,7 +16,6 @@ use std::{ fs::{self, File}, io::{self, Read, Write}, path::PathBuf, - ptr, }; use unwrap_or::unwrap_some_or; @@ -259,14 +261,12 @@ pub(crate) unsafe fn record_submodules_from_wrapped_fatbin( fn_logger: &mut FnCallLog, state: &mut StateTracker, ) -> Result<(), ErrorEntry> { - let fatbinc_wrapper = FatbincWrapper::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?; - let is_version_2 = fatbinc_wrapper.version == FatbincWrapper::VERSION_V2; - record_submodules_from_fatbin(module, (*fatbinc_wrapper).data, fn_logger, state)?; - if is_version_2 { - let mut current = (*fatbinc_wrapper).filename_or_fatbins as *const *const c_void; - while *current != ptr::null() { - record_submodules_from_fatbin(module, *current as *const _, fn_logger, state)?; - current = current.add(1); + let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?; + let first = fatbin.get_first().map_err(ErrorEntry::from)?; + record_submodules_from_fatbin(module, first, fn_logger, state)?; + if let Some(mut submodules) = fatbin.get_submodules() { + while let Some(current) = submodules.next()? { + record_submodules_from_fatbin(module, current, fn_logger, state)?; } } Ok(()) @@ -274,37 +274,43 @@ pub(crate) unsafe fn record_submodules_from_wrapped_fatbin( pub(crate) unsafe fn record_submodules_from_fatbin( module: CUmodule, - fatbin_header: *const FatbinHeader, + submodule: FatbinSubmodule, logger: &mut FnCallLog, state: &mut StateTracker, ) -> Result<(), ErrorEntry> { - let header = FatbinHeader::new(&fatbin_header).map_err(ErrorEntry::from)?; - let file = header.get_content(); - record_submodules(module, logger, state, file)?; + record_submodules(module, logger, state, submodule.get_files())?; Ok(()) } -unsafe fn record_submodules( +pub(crate) unsafe fn record_submodules( module: CUmodule, fn_logger: &mut FnCallLog, state: &mut StateTracker, - mut file_buffer: &[u8], + mut files: FatbinFileIterator, ) -> Result<(), ErrorEntry> { - while let Some(file) = FatbinFileHeader::next(&mut file_buffer)? { - let mut payload = if file.flags.contains(FatbinFileHeaderFlags::CompressedLz4) { + while let Some(file) = files.next()? { + let mut payload = if file + .header + .flags + .contains(FatbinFileHeaderFlags::CompressedLz4) + { Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_lz4(file)), + fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())), continue )) - } else if file.flags.contains(FatbinFileHeaderFlags::CompressedZstd) { + } else if file + .header + .flags + .contains(FatbinFileHeaderFlags::CompressedZstd) + { Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_zstd(file)), + fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())), continue )) } else { Cow::Borrowed(file.get_payload()) }; - match file.kind { + match file.header.kind { FatbinFileHeader::HEADER_KIND_PTX => { while payload.last() == Some(&0) { // remove trailing zeros @@ -322,50 +328,10 @@ unsafe fn record_submodules( UInt::U16(FatbinFileHeader::HEADER_KIND_PTX), UInt::U16(FatbinFileHeader::HEADER_KIND_ELF), ], - observed: UInt::U16(file.kind), + observed: UInt::U16(file.header.kind), }); } } } Ok(()) } - -const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024; - -unsafe fn decompress_lz4(file: &FatbinFileHeader) -> Result, ErrorEntry> { - let decompressed_size = usize::max(1024, (*file).uncompressed_payload as usize); - let mut decompressed_vec = vec![0u8; decompressed_size]; - loop { - match lz4_sys::LZ4_decompress_safe( - file.get_payload().as_ptr() as *const _, - decompressed_vec.as_mut_ptr() as *mut _, - (*file).payload_size as _, - decompressed_vec.len() as _, - ) { - error if error < 0 => { - let new_size = decompressed_vec.len() * 2; - if new_size > MAX_MODULE_DECOMPRESSION_BOUND { - return Err(ErrorEntry::Lz4DecompressionFailure); - } - decompressed_vec.resize(decompressed_vec.len() * 2, 0); - } - real_decompressed_size => { - decompressed_vec.truncate(real_decompressed_size as usize); - return Ok(decompressed_vec); - } - } - } -} - -unsafe fn decompress_zstd(file: &FatbinFileHeader) -> Result, ErrorEntry> { - let mut result = Vec::with_capacity(file.uncompressed_payload as usize); - let payload = file.get_payload(); - dbg!((payload.len(), file.uncompressed_payload, file.payload_size)); - match zstd_safe::decompress(&mut result, payload) { - Ok(actual_size) => { - result.truncate(actual_size); - Ok(result) - } - Err(err) => Err(ErrorEntry::ZstdDecompressionFailure(err)), - } -}