mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-10-01 21:59:38 +00:00
Handle compressed files in fatbins more correctly
This commit is contained in:
parent
6f14025e9b
commit
7a8a1984ae
2 changed files with 20 additions and 8 deletions
|
@ -45,13 +45,14 @@ pub struct FatbinHeader {
|
|||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
pub struct FatbinFileHeader {
|
||||
pub kind: c_ushort,
|
||||
pub version: c_ushort,
|
||||
pub header_size: c_uint,
|
||||
pub padded_payload_size: c_uint,
|
||||
pub unknown0: c_uint, // check if it's written into separately
|
||||
pub payload_size: c_uint,
|
||||
pub unknown0: c_uint, // check if it's written into separately
|
||||
pub compressed_size: c_uint,
|
||||
pub unknown1: c_uint,
|
||||
pub unknown2: c_uint,
|
||||
pub sm_version: c_uint,
|
||||
|
@ -63,6 +64,7 @@ pub struct FatbinFileHeader {
|
|||
}
|
||||
|
||||
bitflags! {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct FatbinFileHeaderFlags: u64 {
|
||||
const Is64Bit = 0x0000000000000001;
|
||||
const Debug = 0x0000000000000002;
|
||||
|
|
|
@ -182,13 +182,20 @@ impl<'a> FatbinFile<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub unsafe fn get_payload(self) -> &'a [u8] {
|
||||
pub unsafe fn get_non_compressed_payload(self) -> &'a [u8] {
|
||||
let start = std::ptr::from_ref(self.header)
|
||||
.cast::<u8>()
|
||||
.add(self.header.header_size as usize);
|
||||
std::slice::from_raw_parts(start, self.header.payload_size as usize)
|
||||
}
|
||||
|
||||
pub unsafe fn get_compressed_payload(self) -> &'a [u8] {
|
||||
let start = std::ptr::from_ref(self.header)
|
||||
.cast::<u8>()
|
||||
.add(self.header.header_size as usize);
|
||||
std::slice::from_raw_parts(start, self.header.compressed_size as usize)
|
||||
}
|
||||
|
||||
pub unsafe fn get_or_decompress_content(self) -> Result<Cow<'a, [u8]>, FatbinError> {
|
||||
let mut payload = if self
|
||||
.header
|
||||
|
@ -203,7 +210,7 @@ impl<'a> FatbinFile<'a> {
|
|||
{
|
||||
Cow::Owned(unsafe { decompress_zstd(self) }?)
|
||||
} else {
|
||||
Cow::Borrowed(unsafe { self.get_payload() })
|
||||
Cow::Borrowed(unsafe { self.get_non_compressed_payload() })
|
||||
};
|
||||
|
||||
// Remove trailing zeros
|
||||
|
@ -246,7 +253,10 @@ impl<'a> FatbinFileIterator<'a> {
|
|||
let this = &*self.file_buffer.as_ptr().cast::<FatbinFileHeader>();
|
||||
let next_element = self
|
||||
.file_buffer
|
||||
.split_at_checked(this.header_size as usize + this.padded_payload_size as usize)
|
||||
.split_at_checked(
|
||||
this.header_size as usize
|
||||
+ u32::max(this.payload_size, this.compressed_size) as usize,
|
||||
)
|
||||
.map(|(_, next)| next);
|
||||
self.file_buffer = next_element.unwrap_or(&[]);
|
||||
Some(
|
||||
|
@ -267,9 +277,9 @@ pub unsafe fn decompress_lz4(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
|||
let mut decompressed_vec = vec![0u8; decompressed_size];
|
||||
loop {
|
||||
match lz4_sys::LZ4_decompress_safe(
|
||||
file.get_payload().as_ptr() as *const _,
|
||||
file.get_compressed_payload().as_ptr() as *const _,
|
||||
decompressed_vec.as_mut_ptr() as *mut _,
|
||||
file.header.payload_size as _,
|
||||
file.header.compressed_size as _,
|
||||
decompressed_vec.len() as _,
|
||||
) {
|
||||
error if error < 0 => {
|
||||
|
@ -289,7 +299,7 @@ pub unsafe fn decompress_lz4(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
|||
|
||||
pub unsafe fn decompress_zstd(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
||||
let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize);
|
||||
let payload = file.get_payload();
|
||||
let payload = file.get_compressed_payload();
|
||||
match zstd_safe::decompress(&mut result, payload) {
|
||||
Ok(actual_size) => {
|
||||
result.truncate(actual_size);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue