Handle compressed files in fatbins more correctly

This commit is contained in:
Andrzej Janik 2025-09-08 20:16:53 +00:00
commit 7a8a1984ae
2 changed files with 20 additions and 8 deletions

View file

@ -45,13 +45,14 @@ pub struct FatbinHeader {
}
#[repr(C)]
#[derive(Debug)]
pub struct FatbinFileHeader {
pub kind: c_ushort,
pub version: c_ushort,
pub header_size: c_uint,
pub padded_payload_size: c_uint,
pub unknown0: c_uint, // check if it's written into separately
pub payload_size: c_uint,
pub unknown0: c_uint, // check if it's written into separately
pub compressed_size: c_uint,
pub unknown1: c_uint,
pub unknown2: c_uint,
pub sm_version: c_uint,
@ -63,6 +64,7 @@ pub struct FatbinFileHeader {
}
bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FatbinFileHeaderFlags: u64 {
const Is64Bit = 0x0000000000000001;
const Debug = 0x0000000000000002;

View file

@ -182,13 +182,20 @@ impl<'a> FatbinFile<'a> {
}
}
pub unsafe fn get_payload(self) -> &'a [u8] {
pub unsafe fn get_non_compressed_payload(self) -> &'a [u8] {
let start = std::ptr::from_ref(self.header)
.cast::<u8>()
.add(self.header.header_size as usize);
std::slice::from_raw_parts(start, self.header.payload_size as usize)
}
pub unsafe fn get_compressed_payload(self) -> &'a [u8] {
let start = std::ptr::from_ref(self.header)
.cast::<u8>()
.add(self.header.header_size as usize);
std::slice::from_raw_parts(start, self.header.compressed_size as usize)
}
pub unsafe fn get_or_decompress_content(self) -> Result<Cow<'a, [u8]>, FatbinError> {
let mut payload = if self
.header
@ -203,7 +210,7 @@ impl<'a> FatbinFile<'a> {
{
Cow::Owned(unsafe { decompress_zstd(self) }?)
} else {
Cow::Borrowed(unsafe { self.get_payload() })
Cow::Borrowed(unsafe { self.get_non_compressed_payload() })
};
// Remove trailing zeros
@ -246,7 +253,10 @@ impl<'a> FatbinFileIterator<'a> {
let this = &*self.file_buffer.as_ptr().cast::<FatbinFileHeader>();
let next_element = self
.file_buffer
.split_at_checked(this.header_size as usize + this.padded_payload_size as usize)
.split_at_checked(
this.header_size as usize
+ u32::max(this.payload_size, this.compressed_size) as usize,
)
.map(|(_, next)| next);
self.file_buffer = next_element.unwrap_or(&[]);
Some(
@ -267,9 +277,9 @@ pub unsafe fn decompress_lz4(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
let mut decompressed_vec = vec![0u8; decompressed_size];
loop {
match lz4_sys::LZ4_decompress_safe(
file.get_payload().as_ptr() as *const _,
file.get_compressed_payload().as_ptr() as *const _,
decompressed_vec.as_mut_ptr() as *mut _,
file.header.payload_size as _,
file.header.compressed_size as _,
decompressed_vec.len() as _,
) {
error if error < 0 => {
@ -289,7 +299,7 @@ pub unsafe fn decompress_lz4(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
pub unsafe fn decompress_zstd(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize);
let payload = file.get_payload();
let payload = file.get_compressed_payload();
match zstd_safe::decompress(&mut result, payload) {
Ok(actual_size) => {
result.truncate(actual_size);