Support old PTX compression scheme

This commit is contained in:
Andrzej Janik 2024-03-29 02:02:09 +01:00
commit 6e144971c4
3 changed files with 43 additions and 31 deletions

View file

@ -14,6 +14,7 @@ either = "1.9"
bit-vec = "0.6.3" bit-vec = "0.6.3"
paste = "1.0" paste = "1.0"
lz4-sys = "1.9" lz4-sys = "1.9"
cloudflare-zlib = "0.2.10"
thread-id = "4.1.0" thread-id = "4.1.0"
# we don't need elf32, but goblin has a bug where elf64 does not build without elf32 # we don't need elf32, but goblin has a bug where elf64 does not build without elf32
goblin = { version = "0.5.1", default-features = false, features = ["elf64", "elf32"] } goblin = { version = "0.5.1", default-features = false, features = ["elf64", "elf32"] }

View file

@ -687,13 +687,19 @@ pub enum FatbinModule {
pub struct FatbinFile { pub struct FatbinFile {
data: *const u8, data: *const u8,
pub kind: FatbinFileKind, pub kind: FatbinFileKind,
pub compressed: bool, pub compression: FatbinCompression,
pub sm_version: u32, pub sm_version: u32,
padded_payload_size: usize, padded_payload_size: usize,
payload_size: usize, payload_size: usize,
uncompressed_payload: usize, uncompressed_payload: usize,
} }
pub enum FatbinCompression {
None,
Zlib,
Lz4,
}
impl FatbinFile { impl FatbinFile {
unsafe fn try_new(fatbin_file: &FatbinFileHeader) -> Result<Self, UnexpectedFieldError> { unsafe fn try_new(fatbin_file: &FatbinFileHeader) -> Result<Self, UnexpectedFieldError> {
let fatbin_file_version = fatbin_file.version; let fatbin_file_version = fatbin_file.version;
@ -719,22 +725,19 @@ impl FatbinFile {
}); });
} }
}; };
if fatbin_file let compression = if fatbin_file
.flags .flags
.contains(FatbinFileHeaderFlags::CompressedOld) .contains(FatbinFileHeaderFlags::CompressedOld)
{ {
return Err(UnexpectedFieldError { FatbinCompression::Zlib
name: "FATBIN_FILE_HEADER_FLAGS", } else if fatbin_file
expected: vec![
AnyUInt::U64(FatbinFileHeaderFlags::empty().bits()),
AnyUInt::U64(FatbinFileHeaderFlags::CompressedNew.bits()),
],
observed: AnyUInt::U64(fatbin_file.flags.bits()),
});
}
let compressed = fatbin_file
.flags .flags
.contains(FatbinFileHeaderFlags::CompressedNew); .contains(FatbinFileHeaderFlags::CompressedNew)
{
FatbinCompression::Lz4
} else {
FatbinCompression::None
};
let data = (fatbin_file as *const _ as *const u8).add(fatbin_file.header_size as usize); let data = (fatbin_file as *const _ as *const u8).add(fatbin_file.header_size as usize);
let padded_payload_size = fatbin_file.padded_payload_size as usize; let padded_payload_size = fatbin_file.padded_payload_size as usize;
let payload_size = fatbin_file.payload_size as usize; let payload_size = fatbin_file.payload_size as usize;
@ -743,7 +746,7 @@ impl FatbinFile {
Ok(Self { Ok(Self {
data, data,
kind, kind,
compressed, compression,
padded_payload_size, padded_payload_size,
payload_size, payload_size,
uncompressed_payload, uncompressed_payload,
@ -753,28 +756,36 @@ impl FatbinFile {
// Returning static lifetime here because all known uses of this are related to fatbin files that // Returning static lifetime here because all known uses of this are related to fatbin files that
// are constants inside files // are constants inside files
pub unsafe fn get_or_decompress(&self) -> Result<Cow<'static, [u8]>, Lz4DecompressionFailure> { pub unsafe fn get_or_decompress(&self) -> Result<Cow<'static, [u8]>, DecompressionFailure> {
if self.compressed { match self.compression {
match self.decompress_kernel_module() { FatbinCompression::Lz4 => {
Some(mut decompressed) => { match self.decompress_kernel_module_lz4() {
if self.kind == FatbinFileKind::Ptx { Some(mut decompressed) => {
decompressed.pop(); // remove trailing zero if self.kind == FatbinFileKind::Ptx {
decompressed.pop(); // remove trailing zero
}
Ok(Cow::Owned(decompressed))
} }
Ok(Cow::Owned(decompressed)) None => Err(DecompressionFailure),
} }
None => Err(Lz4DecompressionFailure),
} }
} else { FatbinCompression::Zlib => {
Ok(Cow::Borrowed(slice::from_raw_parts( let compressed =
std::slice::from_raw_parts(self.data.cast(), self.padded_payload_size);
Ok(Cow::Owned(
cloudflare_zlib::inflate(compressed).map_err(|_| DecompressionFailure)?,
))
}
FatbinCompression::None => Ok(Cow::Borrowed(slice::from_raw_parts(
self.data, self.data,
self.padded_payload_size as usize, self.padded_payload_size as usize,
))) ))),
} }
} }
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024; const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
unsafe fn decompress_kernel_module(&self) -> Option<Vec<u8>> { unsafe fn decompress_kernel_module_lz4(&self) -> Option<Vec<u8>> {
let decompressed_size = usize::max(1024, self.uncompressed_payload as usize); let decompressed_size = usize::max(1024, self.uncompressed_payload as usize);
let mut decompressed_vec = vec![0u8; decompressed_size]; let mut decompressed_vec = vec![0u8; decompressed_size];
loop { loop {
@ -801,7 +812,7 @@ impl FatbinFile {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct Lz4DecompressionFailure; pub struct DecompressionFailure;
pub fn anti_zluda_hash<F: FnMut(u32) -> AntiZludaHashInputDevice>( pub fn anti_zluda_hash<F: FnMut(u32) -> AntiZludaHashInputDevice>(
return_known_value: bool, return_known_value: bool,

View file

@ -19,7 +19,7 @@ use std::path::PathBuf;
use std::str::Utf8Error; use std::str::Utf8Error;
use zluda_dark_api::AnyUInt; use zluda_dark_api::AnyUInt;
use zluda_dark_api::FatbinFileKind; use zluda_dark_api::FatbinFileKind;
use zluda_dark_api::Lz4DecompressionFailure; use zluda_dark_api::DecompressionFailure;
use zluda_dark_api::UnexpectedFieldError; use zluda_dark_api::UnexpectedFieldError;
const LOG_PREFIX: &[u8] = b"[ZLUDA_DUMP] "; const LOG_PREFIX: &[u8] = b"[ZLUDA_DUMP] ";
@ -447,7 +447,7 @@ impl Display for LogEntry {
file_name file_name
) )
} }
LogEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"), LogEntry::Lz4DecompressionFailure => write!(f, "Decompression failure"),
LogEntry::UnknownExportTableFn => write!(f, "Unknown export table function"), LogEntry::UnknownExportTableFn => write!(f, "Unknown export table function"),
LogEntry::UnexpectedBinaryField { LogEntry::UnexpectedBinaryField {
field_name, field_name,
@ -591,8 +591,8 @@ impl From<io::Error> for LogEntry {
} }
} }
impl From<Lz4DecompressionFailure> for LogEntry { impl From<DecompressionFailure> for LogEntry {
fn from(_err: Lz4DecompressionFailure) -> Self { fn from(_err: DecompressionFailure) -> Self {
LogEntry::Lz4DecompressionFailure LogEntry::Lz4DecompressionFailure
} }
} }