Support old PTX compression scheme (#188)

This commit is contained in:
Andrzej Janik 2024-03-29 02:03:23 +01:00 committed by GitHub
parent 7d4147c8b2
commit b695f44c18
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 43 additions and 31 deletions

View file

@ -14,6 +14,7 @@ either = "1.9"
bit-vec = "0.6.3"
paste = "1.0"
lz4-sys = "1.9"
cloudflare-zlib = "0.2.10"
thread-id = "4.1.0"
# we don't need elf32, but goblin has a bug where elf64 does not build without elf32
goblin = { version = "0.5.1", default-features = false, features = ["elf64", "elf32"] }

View file

@ -687,13 +687,19 @@ pub enum FatbinModule {
/// One file entry of a fatbin, built from a `FatbinFileHeader` by `FatbinFile::try_new`.
pub struct FatbinFile {
/// Start of the payload: the header's own address advanced by its `header_size` bytes.
data: *const u8,
/// Kind of the payload; PTX payloads get special handling after LZ4 decompression.
pub kind: FatbinFileKind,
// NOTE(review): this listing shows both the boolean `compressed` flag and the
// `compression` enum below; presumably the enum supersedes the flag — confirm.
pub compressed: bool,
/// Compression scheme chosen from the header flags.
pub compression: FatbinCompression,
pub sm_version: u32,
/// Copied from the header's `padded_payload_size`; used as the byte length of the raw
/// payload slice when it is returned as-is or handed to the zlib inflater.
padded_payload_size: usize,
/// Copied from the header's `payload_size`.
payload_size: usize,
/// Used (with a floor of 1024) as the initial size of the LZ4 output buffer.
uncompressed_payload: usize,
}
/// Compression scheme applied to the payload of a [`FatbinFile`].
///
/// The scheme is selected from the file header flags when the file is parsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FatbinCompression {
    /// Neither compression flag is set; the payload is used as stored.
    None,
    /// The header carries the `CompressedOld` flag; the payload is inflated with zlib.
    Zlib,
    /// The header carries the `CompressedNew` flag; the payload is decompressed with LZ4.
    Lz4,
}
impl FatbinFile {
unsafe fn try_new(fatbin_file: &FatbinFileHeader) -> Result<Self, UnexpectedFieldError> {
let fatbin_file_version = fatbin_file.version;
@ -719,22 +725,19 @@ impl FatbinFile {
});
}
};
if fatbin_file
let compression = if fatbin_file
.flags
.contains(FatbinFileHeaderFlags::CompressedOld)
{
return Err(UnexpectedFieldError {
name: "FATBIN_FILE_HEADER_FLAGS",
expected: vec![
AnyUInt::U64(FatbinFileHeaderFlags::empty().bits()),
AnyUInt::U64(FatbinFileHeaderFlags::CompressedNew.bits()),
],
observed: AnyUInt::U64(fatbin_file.flags.bits()),
});
}
let compressed = fatbin_file
FatbinCompression::Zlib
} else if fatbin_file
.flags
.contains(FatbinFileHeaderFlags::CompressedNew);
.contains(FatbinFileHeaderFlags::CompressedNew)
{
FatbinCompression::Lz4
} else {
FatbinCompression::None
};
let data = (fatbin_file as *const _ as *const u8).add(fatbin_file.header_size as usize);
let padded_payload_size = fatbin_file.padded_payload_size as usize;
let payload_size = fatbin_file.payload_size as usize;
@ -743,7 +746,7 @@ impl FatbinFile {
Ok(Self {
data,
kind,
compressed,
compression,
padded_payload_size,
payload_size,
uncompressed_payload,
@ -753,28 +756,36 @@ impl FatbinFile {
// Returning static lifetime here because all known uses of this are related to fatbin files that
// are constants inside files
pub unsafe fn get_or_decompress(&self) -> Result<Cow<'static, [u8]>, Lz4DecompressionFailure> {
if self.compressed {
match self.decompress_kernel_module() {
Some(mut decompressed) => {
if self.kind == FatbinFileKind::Ptx {
decompressed.pop(); // remove trailing zero
pub unsafe fn get_or_decompress(&self) -> Result<Cow<'static, [u8]>, DecompressionFailure> {
match self.compression {
FatbinCompression::Lz4 => {
match self.decompress_kernel_module_lz4() {
Some(mut decompressed) => {
if self.kind == FatbinFileKind::Ptx {
decompressed.pop(); // remove trailing zero
}
Ok(Cow::Owned(decompressed))
}
Ok(Cow::Owned(decompressed))
None => Err(DecompressionFailure),
}
None => Err(Lz4DecompressionFailure),
}
} else {
Ok(Cow::Borrowed(slice::from_raw_parts(
FatbinCompression::Zlib => {
let compressed =
std::slice::from_raw_parts(self.data.cast(), self.padded_payload_size);
Ok(Cow::Owned(
cloudflare_zlib::inflate(compressed).map_err(|_| DecompressionFailure)?,
))
}
FatbinCompression::None => Ok(Cow::Borrowed(slice::from_raw_parts(
self.data,
self.padded_payload_size as usize,
)))
))),
}
}
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
unsafe fn decompress_kernel_module(&self) -> Option<Vec<u8>> {
unsafe fn decompress_kernel_module_lz4(&self) -> Option<Vec<u8>> {
let decompressed_size = usize::max(1024, self.uncompressed_payload as usize);
let mut decompressed_vec = vec![0u8; decompressed_size];
loop {
@ -801,7 +812,7 @@ impl FatbinFile {
}
#[derive(Debug)]
pub struct Lz4DecompressionFailure;
pub struct DecompressionFailure;
pub fn anti_zluda_hash<F: FnMut(u32) -> AntiZludaHashInputDevice>(
return_known_value: bool,

View file

@ -19,7 +19,7 @@ use std::path::PathBuf;
use std::str::Utf8Error;
use zluda_dark_api::AnyUInt;
use zluda_dark_api::FatbinFileKind;
use zluda_dark_api::Lz4DecompressionFailure;
use zluda_dark_api::DecompressionFailure;
use zluda_dark_api::UnexpectedFieldError;
const LOG_PREFIX: &[u8] = b"[ZLUDA_DUMP] ";
@ -447,7 +447,7 @@ impl Display for LogEntry {
file_name
)
}
LogEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
LogEntry::Lz4DecompressionFailure => write!(f, "Decompression failure"),
LogEntry::UnknownExportTableFn => write!(f, "Unknown export table function"),
LogEntry::UnexpectedBinaryField {
field_name,
@ -591,8 +591,8 @@ impl From<io::Error> for LogEntry {
}
}
impl From<Lz4DecompressionFailure> for LogEntry {
fn from(_err: Lz4DecompressionFailure) -> Self {
impl From<DecompressionFailure> for LogEntry {
fn from(_err: DecompressionFailure) -> Self {
LogEntry::Lz4DecompressionFailure
}
}