mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-10-02 06:10:06 +00:00
Handle compressed files in fatbins more correctly
This commit is contained in:
parent
6f14025e9b
commit
7a8a1984ae
2 changed files with 20 additions and 8 deletions
|
@@ -45,13 +45,14 @@ pub struct FatbinHeader {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct FatbinFileHeader {
|
pub struct FatbinFileHeader {
|
||||||
pub kind: c_ushort,
|
pub kind: c_ushort,
|
||||||
pub version: c_ushort,
|
pub version: c_ushort,
|
||||||
pub header_size: c_uint,
|
pub header_size: c_uint,
|
||||||
pub padded_payload_size: c_uint,
|
|
||||||
pub unknown0: c_uint, // check if it's written into separately
|
|
||||||
pub payload_size: c_uint,
|
pub payload_size: c_uint,
|
||||||
|
pub unknown0: c_uint, // check if it's written into separately
|
||||||
|
pub compressed_size: c_uint,
|
||||||
pub unknown1: c_uint,
|
pub unknown1: c_uint,
|
||||||
pub unknown2: c_uint,
|
pub unknown2: c_uint,
|
||||||
pub sm_version: c_uint,
|
pub sm_version: c_uint,
|
||||||
|
@@ -63,6 +64,7 @@ pub struct FatbinFileHeader {
|
||||||
}
|
}
|
||||||
|
|
||||||
bitflags! {
|
bitflags! {
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub struct FatbinFileHeaderFlags: u64 {
|
pub struct FatbinFileHeaderFlags: u64 {
|
||||||
const Is64Bit = 0x0000000000000001;
|
const Is64Bit = 0x0000000000000001;
|
||||||
const Debug = 0x0000000000000002;
|
const Debug = 0x0000000000000002;
|
||||||
|
|
|
@@ -182,13 +182,20 @@ impl<'a> FatbinFile<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_payload(self) -> &'a [u8] {
|
pub unsafe fn get_non_compressed_payload(self) -> &'a [u8] {
|
||||||
let start = std::ptr::from_ref(self.header)
|
let start = std::ptr::from_ref(self.header)
|
||||||
.cast::<u8>()
|
.cast::<u8>()
|
||||||
.add(self.header.header_size as usize);
|
.add(self.header.header_size as usize);
|
||||||
std::slice::from_raw_parts(start, self.header.payload_size as usize)
|
std::slice::from_raw_parts(start, self.header.payload_size as usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub unsafe fn get_compressed_payload(self) -> &'a [u8] {
|
||||||
|
let start = std::ptr::from_ref(self.header)
|
||||||
|
.cast::<u8>()
|
||||||
|
.add(self.header.header_size as usize);
|
||||||
|
std::slice::from_raw_parts(start, self.header.compressed_size as usize)
|
||||||
|
}
|
||||||
|
|
||||||
pub unsafe fn get_or_decompress_content(self) -> Result<Cow<'a, [u8]>, FatbinError> {
|
pub unsafe fn get_or_decompress_content(self) -> Result<Cow<'a, [u8]>, FatbinError> {
|
||||||
let mut payload = if self
|
let mut payload = if self
|
||||||
.header
|
.header
|
||||||
|
@@ -203,7 +210,7 @@ impl<'a> FatbinFile<'a> {
|
||||||
{
|
{
|
||||||
Cow::Owned(unsafe { decompress_zstd(self) }?)
|
Cow::Owned(unsafe { decompress_zstd(self) }?)
|
||||||
} else {
|
} else {
|
||||||
Cow::Borrowed(unsafe { self.get_payload() })
|
Cow::Borrowed(unsafe { self.get_non_compressed_payload() })
|
||||||
};
|
};
|
||||||
|
|
||||||
// Remove trailing zeros
|
// Remove trailing zeros
|
||||||
|
@@ -246,7 +253,10 @@ impl<'a> FatbinFileIterator<'a> {
|
||||||
let this = &*self.file_buffer.as_ptr().cast::<FatbinFileHeader>();
|
let this = &*self.file_buffer.as_ptr().cast::<FatbinFileHeader>();
|
||||||
let next_element = self
|
let next_element = self
|
||||||
.file_buffer
|
.file_buffer
|
||||||
.split_at_checked(this.header_size as usize + this.padded_payload_size as usize)
|
.split_at_checked(
|
||||||
|
this.header_size as usize
|
||||||
|
+ u32::max(this.payload_size, this.compressed_size) as usize,
|
||||||
|
)
|
||||||
.map(|(_, next)| next);
|
.map(|(_, next)| next);
|
||||||
self.file_buffer = next_element.unwrap_or(&[]);
|
self.file_buffer = next_element.unwrap_or(&[]);
|
||||||
Some(
|
Some(
|
||||||
|
@@ -267,9 +277,9 @@ pub unsafe fn decompress_lz4(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
||||||
let mut decompressed_vec = vec![0u8; decompressed_size];
|
let mut decompressed_vec = vec![0u8; decompressed_size];
|
||||||
loop {
|
loop {
|
||||||
match lz4_sys::LZ4_decompress_safe(
|
match lz4_sys::LZ4_decompress_safe(
|
||||||
file.get_payload().as_ptr() as *const _,
|
file.get_compressed_payload().as_ptr() as *const _,
|
||||||
decompressed_vec.as_mut_ptr() as *mut _,
|
decompressed_vec.as_mut_ptr() as *mut _,
|
||||||
file.header.payload_size as _,
|
file.header.compressed_size as _,
|
||||||
decompressed_vec.len() as _,
|
decompressed_vec.len() as _,
|
||||||
) {
|
) {
|
||||||
error if error < 0 => {
|
error if error < 0 => {
|
||||||
|
@@ -289,7 +299,7 @@ pub unsafe fn decompress_lz4(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
||||||
|
|
||||||
pub unsafe fn decompress_zstd(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
pub unsafe fn decompress_zstd(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
||||||
let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize);
|
let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize);
|
||||||
let payload = file.get_payload();
|
let payload = file.get_compressed_payload();
|
||||||
match zstd_safe::decompress(&mut result, payload) {
|
match zstd_safe::decompress(&mut result, payload) {
|
||||||
Ok(actual_size) => {
|
Ok(actual_size) => {
|
||||||
result.truncate(actual_size);
|
result.truncate(actual_size);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue