From 7a8a1984ae470a73f0c220a9f306d529fcaa5e74 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 8 Sep 2025 20:16:53 +0000 Subject: [PATCH] Handle compressed files in fatbins more correctly --- cuda_types/src/dark_api.rs | 6 ++++-- dark_api/src/fatbin.rs | 22 ++++++++++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/cuda_types/src/dark_api.rs b/cuda_types/src/dark_api.rs index 1db5b2b..bb7b2cf 100644 --- a/cuda_types/src/dark_api.rs +++ b/cuda_types/src/dark_api.rs @@ -45,13 +45,14 @@ pub struct FatbinHeader { } #[repr(C)] +#[derive(Debug)] pub struct FatbinFileHeader { pub kind: c_ushort, pub version: c_ushort, pub header_size: c_uint, - pub padded_payload_size: c_uint, - pub unknown0: c_uint, // check if it's written into separately pub payload_size: c_uint, + pub unknown0: c_uint, // check if it's written into separately + pub compressed_size: c_uint, pub unknown1: c_uint, pub unknown2: c_uint, pub sm_version: c_uint, @@ -63,6 +64,7 @@ pub struct FatbinFileHeader { } bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct FatbinFileHeaderFlags: u64 { const Is64Bit = 0x0000000000000001; const Debug = 0x0000000000000002; diff --git a/dark_api/src/fatbin.rs b/dark_api/src/fatbin.rs index e347f26..9488499 100644 --- a/dark_api/src/fatbin.rs +++ b/dark_api/src/fatbin.rs @@ -182,13 +182,20 @@ impl<'a> FatbinFile<'a> { } } - pub unsafe fn get_payload(self) -> &'a [u8] { + pub unsafe fn get_non_compressed_payload(self) -> &'a [u8] { let start = std::ptr::from_ref(self.header) .cast::() .add(self.header.header_size as usize); std::slice::from_raw_parts(start, self.header.payload_size as usize) } + pub unsafe fn get_compressed_payload(self) -> &'a [u8] { + let start = std::ptr::from_ref(self.header) + .cast::() + .add(self.header.header_size as usize); + std::slice::from_raw_parts(start, self.header.compressed_size as usize) + } + pub unsafe fn get_or_decompress_content(self) -> Result, FatbinError> { let mut payload = if self .header @@ -203,7 +210,7 @@ impl<'a> FatbinFile<'a> { { Cow::Owned(unsafe { decompress_zstd(self) }?) } else { - Cow::Borrowed(unsafe { self.get_payload() }) + Cow::Borrowed(unsafe { self.get_non_compressed_payload() }) }; // Remove trailing zeros @@ -246,7 +253,10 @@ impl<'a> FatbinFileIterator<'a> { let this = &*self.file_buffer.as_ptr().cast::(); let next_element = self .file_buffer - .split_at_checked(this.header_size as usize + this.padded_payload_size as usize) + .split_at_checked( + this.header_size as usize + + u32::max(this.payload_size, this.compressed_size) as usize, + ) .map(|(_, next)| next); self.file_buffer = next_element.unwrap_or(&[]); Some( @@ -267,9 +277,9 @@ pub unsafe fn decompress_lz4(file: FatbinFile) -> Result, FatbinError> { let mut decompressed_vec = vec![0u8; decompressed_size]; loop { match lz4_sys::LZ4_decompress_safe( - file.get_payload().as_ptr() as *const _, + file.get_compressed_payload().as_ptr() as *const _, decompressed_vec.as_mut_ptr() as *mut _, - file.header.payload_size as _, + file.header.compressed_size as _, decompressed_vec.len() as _, ) { error if error < 0 => { @@ -289,7 +299,7 @@ pub unsafe fn decompress_lz4(file: FatbinFile) -> Result, FatbinError> { pub unsafe fn decompress_zstd(file: FatbinFile) -> Result, FatbinError> { let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize); - let payload = file.get_payload(); + let payload = file.get_compressed_payload(); match zstd_safe::decompress(&mut result, payload) { Ok(actual_size) => { result.truncate(actual_size);