Implement cuLibraryLoadData (#388)

This commit is contained in:
Violet 2025-06-18 16:05:53 -07:00 committed by GitHub
commit 4da3978f94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 447 additions and 189 deletions

44
Cargo.lock generated
View file

@ -333,8 +333,10 @@ dependencies = [
"cglue",
"cuda_types",
"format",
"lz4-sys",
"paste",
"uuid",
"zstd-safe",
]
[[package]]
@ -500,6 +502,18 @@ dependencies = [
"uuid",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi",
]
[[package]]
name = "glob"
version = "0.3.1"
@ -588,10 +602,11 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]]
name = "jobserver"
version = "0.1.32"
version = "0.1.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a"
dependencies = [
"getrandom",
"libc",
]
@ -1124,6 +1139,12 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rawpointer"
version = "0.2.1"
@ -1522,6 +1543,15 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wchar"
version = "0.6.1"
@ -1666,6 +1696,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.9.1",
]
[[package]]
name = "xattr"
version = "1.5.0"
@ -1790,7 +1829,6 @@ dependencies = [
"format",
"goblin",
"libc",
"lz4-sys",
"parking_lot",
"paste",
"ptx",

View file

@ -199,7 +199,8 @@ impl VisitMut for FixFnSignatures {
}
const MODULES: &[&str] = &[
"context", "device", "driver", "function", "link", "memory", "module", "pointer", "stream",
"context", "device", "driver", "function", "library", "link", "memory", "module", "pointer",
"stream",
];
#[proc_macro]

View file

@ -78,122 +78,17 @@ bitflags! {
impl FatbincWrapper {
pub const MAGIC: c_uint = 0x466243B1;
const VERSION_V1: c_uint = 0x1;
pub const VERSION_V1: c_uint = 0x1;
pub const VERSION_V2: c_uint = 0x2;
pub fn new<'a, T: Sized>(ptr: &*const T) -> Result<&'a Self, ParseError> {
unsafe { ptr.cast::<Self>().as_ref() }
.ok_or(ParseError::NullPointer("FatbincWrapper"))
.and_then(|ptr| {
ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [Self::MAGIC])?;
ParseError::check_fields(
"FATBINC_VERSION",
ptr.version,
[Self::VERSION_V1, Self::VERSION_V2],
)?;
Ok(ptr)
})
}
}
impl FatbinHeader {
const MAGIC: c_uint = 0xBA55ED50;
const VERSION: c_ushort = 0x01;
pub fn new<'a, T: Sized>(ptr: &'a *const T) -> Result<&'a Self, ParseError> {
unsafe { ptr.cast::<Self>().as_ref() }
.ok_or(ParseError::NullPointer("FatbinHeader"))
.and_then(|ptr| {
ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [Self::MAGIC])?;
ParseError::check_fields("FATBIN_VERSION", ptr.version, [Self::VERSION])?;
Ok(ptr)
})
}
pub unsafe fn get_content<'a>(&'a self) -> &'a [u8] {
let start = std::ptr::from_ref(self)
.cast::<u8>()
.add(self.header_size as usize);
std::slice::from_raw_parts(start, self.files_size as usize)
}
pub const MAGIC: c_uint = 0xBA55ED50;
pub const VERSION: c_ushort = 0x01;
}
impl FatbinFileHeader {
pub const HEADER_KIND_PTX: c_ushort = 0x01;
pub const HEADER_KIND_ELF: c_ushort = 0x02;
const HEADER_VERSION_CURRENT: c_ushort = 0x101;
pub fn new_ptx<T: Sized>(ptr: *const T) -> Result<Option<&'static Self>, ParseError> {
unsafe { ptr.cast::<Self>().as_ref() }
.ok_or(ParseError::NullPointer("FatbinFileHeader"))
.and_then(|ptr| {
ParseError::check_fields(
"FATBIN_FILE_HEADER_VERSION_CURRENT",
ptr.version,
[Self::HEADER_VERSION_CURRENT],
)?;
match ptr.kind {
Self::HEADER_KIND_PTX => Ok(Some(ptr)),
Self::HEADER_KIND_ELF => Ok(None),
_ => Err(ParseError::UnexpectedBinaryField {
field_name: "FATBIN_FILE_HEADER_KIND",
observed: ptr.kind.into(),
expected: vec![Self::HEADER_KIND_PTX.into(), Self::HEADER_KIND_ELF.into()],
}),
}
})
}
pub unsafe fn next<'a>(slice: &'a mut &[u8]) -> Result<Option<&'a Self>, ParseError> {
if slice.len() < std::mem::size_of::<Self>() {
return Ok(None);
}
let this = &*slice.as_ptr().cast::<Self>();
let next_element = slice
.split_at_checked(this.header_size as usize + this.padded_payload_size as usize)
.map(|(_, next)| next);
*slice = next_element.unwrap_or(&[]);
ParseError::check_fields(
"FATBIN_FILE_HEADER_VERSION_CURRENT",
this.version,
[Self::HEADER_VERSION_CURRENT],
)?;
Ok(Some(this))
}
pub unsafe fn get_payload<'a>(&'a self) -> &'a [u8] {
let start = std::ptr::from_ref(self)
.cast::<u8>()
.add(self.header_size as usize);
std::slice::from_raw_parts(start, self.payload_size as usize)
}
}
pub enum ParseError {
NullPointer(&'static str),
UnexpectedBinaryField {
field_name: &'static str,
observed: u32,
expected: Vec<u32>,
},
}
impl ParseError {
pub(crate) fn check_fields<const N: usize, T: Into<u32> + Eq + Copy>(
name: &'static str,
observed: T,
expected: [T; N],
) -> Result<(), Self> {
if expected.contains(&observed) {
Ok(())
} else {
let observed = observed.into();
let expected = expected.into_iter().map(Into::into).collect();
Err(ParseError::UnexpectedBinaryField {
field_name: name,
expected,
observed,
})
}
}
pub const HEADER_VERSION_CURRENT: c_ushort = 0x101;
}

View file

@ -10,3 +10,5 @@ uuid = "1.16"
paste = "1.0"
bit-vec = "0.8.0"
cglue = "0.3.5"
lz4-sys = "1.9"
zstd-safe = { version = "7.2.4", features = ["std"] }

235
dark_api/src/fatbin.rs Normal file
View file

@ -0,0 +1,235 @@
// This file contains a higher-level interface for parsing fatbins
use std::ptr;
use cuda_types::dark_api::*;
/// Errors produced while validating raw fatbin structures.
pub enum ParseError {
    /// A pointer that had to be non-null was null; the payload names the
    /// structure that was being parsed.
    NullPointer(&'static str),
    /// A header field held a value outside the accepted set.
    UnexpectedBinaryField {
        /// Name of the field that failed validation.
        field_name: &'static str,
        /// Value found in the binary, converted to `u32`.
        observed: u32,
        /// Every value that would have been accepted, converted to `u32`.
        expected: Vec<u32>,
    },
}

impl ParseError {
    /// Checks that `observed` is one of the `expected` values.
    ///
    /// Returns `Ok(())` on a match. Otherwise returns
    /// [`ParseError::UnexpectedBinaryField`] carrying `name`, the observed
    /// value and the full list of accepted values.
    pub(crate) fn check_fields<const N: usize, T: Into<u32> + Eq + Copy>(
        name: &'static str,
        observed: T,
        expected: [T; N],
    ) -> Result<(), Self> {
        if expected.iter().any(|candidate| *candidate == observed) {
            return Ok(());
        }
        Err(ParseError::UnexpectedBinaryField {
            field_name: name,
            observed: observed.into(),
            expected: expected.iter().copied().map(Into::into).collect(),
        })
    }
}
/// Errors that can occur while reading or decompressing fatbin contents.
pub enum FatbinError {
    /// A fatbin structure failed validation; wraps the underlying [`ParseError`].
    ParseFailure(ParseError),
    /// LZ4 decompression did not succeed.
    Lz4DecompressionFailure,
    /// Zstd decompression failed; carries the error value reported by
    /// `zstd_safe::decompress`.
    ZstdDecompressionFailure(usize),
}
/// Reinterprets `ptr` as a [`FatbincWrapper`] and validates its header.
///
/// # Errors
///
/// * [`ParseError::NullPointer`] when `*ptr` is null.
/// * [`ParseError::UnexpectedBinaryField`] when the magic number is not
///   `FatbincWrapper::MAGIC` or the version is neither `VERSION_V1` nor
///   `VERSION_V2`.
///
/// Only nullness, magic and version are checked; the caller must make sure
/// a non-null `*ptr` actually points at readable memory of the right size.
pub fn parse_fatbinc_wrapper<T: Sized>(ptr: &*const T) -> Result<&FatbincWrapper, ParseError> {
    let candidate = unsafe { ptr.cast::<FatbincWrapper>().as_ref() };
    let Some(wrapper) = candidate else {
        return Err(ParseError::NullPointer("FatbincWrapper"));
    };
    ParseError::check_fields("FATBINC_MAGIC", wrapper.magic, [FatbincWrapper::MAGIC])?;
    let accepted_versions = [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2];
    ParseError::check_fields("FATBINC_VERSION", wrapper.version, accepted_versions)?;
    Ok(wrapper)
}
/// Reinterprets `ptr` as a [`FatbinHeader`] and validates its magic number
/// and version.
///
/// # Errors
///
/// * [`ParseError::NullPointer`] when `*ptr` is null.
/// * [`ParseError::UnexpectedBinaryField`] when the magic number is not
///   `FatbinHeader::MAGIC` or the version is not `FatbinHeader::VERSION`.
fn parse_fatbin_header<T: Sized>(ptr: &*const T) -> Result<&FatbinHeader, ParseError> {
    let candidate = unsafe { ptr.cast::<FatbinHeader>().as_ref() };
    let Some(header) = candidate else {
        return Err(ParseError::NullPointer("FatbinHeader"));
    };
    ParseError::check_fields("FATBIN_MAGIC", header.magic, [FatbinHeader::MAGIC])?;
    ParseError::check_fields("FATBIN_VERSION", header.version, [FatbinHeader::VERSION])?;
    Ok(header)
}
/// View over a wrapped fatbin, borrowing its validated [`FatbincWrapper`].
pub struct Fatbin<'a> {
    pub wrapper: &'a FatbincWrapper,
}

impl<'a> Fatbin<'a> {
    /// Validates the wrapper that `ptr` points at and builds a [`Fatbin`]
    /// around it.
    ///
    /// Any [`ParseError`] is reported as [`FatbinError::ParseFailure`].
    pub fn new<T>(ptr: &'a *const T) -> Result<Self, FatbinError> {
        parse_fatbinc_wrapper(ptr)
            .map(|wrapper| Fatbin { wrapper })
            .map_err(FatbinError::ParseFailure)
    }

    /// Returns the submodule whose header is referenced by the wrapper's
    /// `data` field, after validating that header.
    pub fn get_first(&self) -> Result<FatbinSubmodule, FatbinError> {
        match parse_fatbin_header(&self.wrapper.data) {
            Ok(header) => Ok(FatbinSubmodule::new(header)),
            Err(parse_error) => Err(FatbinError::ParseFailure(parse_error)),
        }
    }

    /// Returns an iterator over the additional submodules of a version 2
    /// wrapper, or `None` for any other wrapper version.
    ///
    /// The wrapper's `filename_or_fatbins` field is reinterpreted as a
    /// pointer to an array of fatbin pointers.
    pub fn get_submodules(&self) -> Option<FatbinSubmoduleIterator> {
        if self.wrapper.version != FatbincWrapper::VERSION_V2 {
            return None;
        }
        let fatbins = self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void;
        Some(FatbinSubmoduleIterator { fatbins })
    }
}
/// A single fatbin, identified by a borrowed [`FatbinHeader`].
pub struct FatbinSubmodule<'a> {
    pub header: &'a FatbinHeader, // TODO: maybe make private
}

impl<'a> FatbinSubmodule<'a> {
    /// Wraps `header` without performing any validation.
    pub fn new(header: &'a FatbinHeader) -> Self {
        Self { header }
    }

    /// Returns an iterator over the files stored after this submodule's header.
    pub fn get_files(&self) -> FatbinFileIterator {
        let header = self.header;
        // The iterator derives its byte range from the header's own
        // `header_size` and `files_size` fields.
        unsafe { FatbinFileIterator::new(header) }
    }
}
/// Cursor over an array of fatbin pointers that ends with a null entry.
pub struct FatbinSubmoduleIterator {
    fatbins: *const *const std::ffi::c_void,
}

impl FatbinSubmoduleIterator {
    /// Returns the submodule at the cursor and advances by one entry, or
    /// `Ok(None)` (without advancing) once the null entry is reached.
    ///
    /// The referenced header is not validated here.
    ///
    /// # Safety
    ///
    /// The cursor must point at a readable array entry, and every non-null
    /// entry must point at a readable `FatbinHeader`.
    pub unsafe fn next(&mut self) -> Result<Option<FatbinSubmodule>, ParseError> {
        let current = *self.fatbins;
        if current == ptr::null() {
            return Ok(None);
        }
        self.fatbins = self.fatbins.add(1);
        let header = current
            .cast::<FatbinHeader>()
            .as_ref()
            .ok_or(ParseError::NullPointer("FatbinSubmoduleIterator"))?;
        Ok(Some(FatbinSubmodule::new(header)))
    }
}
/// One file stored inside a fatbin, identified by a borrowed
/// [`FatbinFileHeader`].
pub struct FatbinFile<'a> {
    pub header: &'a FatbinFileHeader,
}

impl<'a> FatbinFile<'a> {
    /// Wraps `header` without performing any validation.
    pub fn new(header: &'a FatbinFileHeader) -> Self {
        Self { header }
    }

    /// Returns the raw (possibly compressed) payload bytes of this file.
    ///
    /// The payload starts `header_size` bytes after the start of the file
    /// header and is `payload_size` bytes long.
    ///
    /// # Safety
    ///
    /// `self.header` must be located inside a fatbin image in which the
    /// `header_size + payload_size` bytes starting at the header are readable.
    pub unsafe fn get_payload(&'a self) -> &'a [u8] {
        // The offset must be applied to the address of the header inside the
        // fatbin image, NOT to the address of this `FatbinFile` wrapper
        // (which is just a Rust value holding a reference).
        let start = std::ptr::from_ref(self.header)
            .cast::<u8>()
            .add(self.header.header_size as usize);
        std::slice::from_raw_parts(start, self.header.payload_size as usize)
    }

    /// Returns the file contents as an owned buffer, decompressing them when
    /// the header flags mark the payload as LZ4- or Zstd-compressed (LZ4 is
    /// checked first). Uncompressed payloads are copied as-is.
    ///
    /// # Safety
    ///
    /// Same requirements as [`FatbinFile::get_payload`].
    pub unsafe fn decompress(&'a self) -> Result<Vec<u8>, FatbinError> {
        let payload = if self
            .header
            .flags
            .contains(FatbinFileHeaderFlags::CompressedLz4)
        {
            unsafe { decompress_lz4(self) }?
        } else if self
            .header
            .flags
            .contains(FatbinFileHeaderFlags::CompressedZstd)
        {
            unsafe { decompress_zstd(self) }?
        } else {
            unsafe { self.get_payload().to_vec() }
        };
        Ok(payload)
    }
}
/// Cursor over the files of a fatbin; holds the bytes not yet consumed.
pub struct FatbinFileIterator<'a> {
    file_buffer: &'a [u8],
}

impl<'a> FatbinFileIterator<'a> {
    /// Creates an iterator over the `files_size` bytes that start
    /// `header_size` bytes after `header`.
    ///
    /// # Safety
    ///
    /// That byte range must be readable for the lifetime `'a`.
    pub unsafe fn new(header: &'a FatbinHeader) -> Self {
        let base: *const u8 = std::ptr::from_ref(header).cast();
        let first_file = base.add(header.header_size as usize);
        Self {
            file_buffer: std::slice::from_raw_parts(first_file, header.files_size as usize),
        }
    }

    /// Returns the next file, or `Ok(None)` when fewer bytes remain than a
    /// file header occupies.
    ///
    /// The cursor is advanced past the entry (`header_size` +
    /// `padded_payload_size` bytes) before the header version is checked;
    /// when the entry claims more bytes than remain, the cursor becomes empty.
    ///
    /// # Errors
    ///
    /// [`ParseError::UnexpectedBinaryField`] when the file header version is
    /// not `FatbinFileHeader::HEADER_VERSION_CURRENT`.
    ///
    /// # Safety
    ///
    /// The remaining bytes must begin with a `FatbinFileHeader`.
    pub unsafe fn next(&mut self) -> Result<Option<FatbinFile>, ParseError> {
        let remaining = self.file_buffer;
        if remaining.len() < std::mem::size_of::<FatbinFileHeader>() {
            return Ok(None);
        }
        let header = &*remaining.as_ptr().cast::<FatbinFileHeader>();
        let entry_len = header.header_size as usize + header.padded_payload_size as usize;
        self.file_buffer = match remaining.split_at_checked(entry_len) {
            Some((_, rest)) => rest,
            None => &[],
        };
        ParseError::check_fields(
            "FATBIN_FILE_HEADER_VERSION_CURRENT",
            header.version,
            [FatbinFileHeader::HEADER_VERSION_CURRENT],
        )?;
        Ok(Some(FatbinFile::new(header)))
    }
}
/// Largest output buffer size in bytes (64 MiB) that `decompress_lz4` will
/// grow to before reporting a failure.
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
/// Decompresses the LZ4-compressed payload of `file`.
///
/// The output buffer starts at the size recorded in the header's
/// `uncompressed_payload` field (at least 1024 bytes). Every time
/// `LZ4_decompress_safe` reports a negative result the buffer is doubled and
/// the call is retried, until doubling would exceed
/// `MAX_MODULE_DECOMPRESSION_BOUND`.
///
/// # Errors
///
/// [`FatbinError::Lz4DecompressionFailure`] once the size bound is reached.
///
/// # Safety
///
/// Same requirements as [`FatbinFile::get_payload`].
pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
    let initial_len = (file.header.uncompressed_payload as usize).max(1024);
    let mut output = vec![0u8; initial_len];
    loop {
        let status = lz4_sys::LZ4_decompress_safe(
            file.get_payload().as_ptr() as *const _,
            output.as_mut_ptr() as *mut _,
            file.header.payload_size as _,
            output.len() as _,
        );
        if status >= 0 {
            // A non-negative result is the number of bytes actually written.
            output.truncate(status as usize);
            return Ok(output);
        }
        let doubled_len = output.len() * 2;
        if doubled_len > MAX_MODULE_DECOMPRESSION_BOUND {
            return Err(FatbinError::Lz4DecompressionFailure);
        }
        output.resize(doubled_len, 0);
    }
}
/// Decompresses the Zstd-compressed payload of `file` into a buffer whose
/// capacity is taken from the header's `uncompressed_payload` field.
///
/// # Errors
///
/// [`FatbinError::ZstdDecompressionFailure`] carrying the error value
/// returned by `zstd_safe::decompress`.
///
/// # Safety
///
/// Same requirements as [`FatbinFile::get_payload`].
pub unsafe fn decompress_zstd(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
    let mut output = Vec::with_capacity(file.header.uncompressed_payload as usize);
    let written = zstd_safe::decompress(&mut output, file.get_payload())
        .map_err(FatbinError::ZstdDecompressionFailure)?;
    output.truncate(written);
    Ok(output)
}

View file

@ -2,6 +2,8 @@ use std::ffi::c_void;
use cuda_types::cuda::CUuuid;
pub mod fatbin;
macro_rules! dark_api_init {
(SIZE_OF, $table_len:literal, $type_:ty) => {
(std::mem::size_of::<usize>() * $table_len) as *const std::ffi::c_void

38
zluda/src/impl/library.rs Normal file
View file

@ -0,0 +1,38 @@
use super::module;
use super::ZludaObject;
use cuda_types::cuda::*;
use hip_runtime_sys::*;
/// A CUDA library object, backed by a single HIP module.
pub(crate) struct Library {
    base: hipModule_t,
}

impl ZludaObject for Library {
    const COOKIE: usize = 0xb328a916cc234d7c;

    type CudaHandle = CUlibrary;

    /// Unloads the backing HIP module, propagating any HIP error.
    fn drop_checked(&mut self) -> CUresult {
        // TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice.
        let unload_status = unsafe { hipModuleUnload(self.base) };
        unload_status?;
        Ok(())
    }
}
/// Loads `code` as a HIP module and stores the wrapped result in `library`.
///
/// For now every JIT option and library option argument is ignored.
pub(crate) fn load_data(
    library: &mut CUlibrary,
    code: *const ::core::ffi::c_void,
    _jit_options: &mut CUjit_option,
    _jit_options_values: &mut *mut ::core::ffi::c_void,
    _num_jit_options: ::core::ffi::c_uint,
    _library_options: &mut CUlibraryOption,
    _library_option_values: &mut *mut ::core::ffi::c_void,
    _num_library_options: ::core::ffi::c_uint,
) -> CUresult {
    let base = module::load_hip_module(code)?;
    *library = Library { base }.wrap();
    Ok(())
}

View file

@ -10,6 +10,7 @@ pub(super) mod context;
pub(super) mod device;
pub(super) mod driver;
pub(super) mod function;
pub(super) mod library;
pub(super) mod memory;
pub(super) mod module;
pub(super) mod pointer;
@ -135,6 +136,9 @@ from_cuda_nop!(
cuda_types::cuda::CUdevprop,
CUdevice_attribute,
CUdriverProcAddressQueryResult,
CUjit_option,
CUlibrary,
CUlibraryOption,
CUmoduleLoadingMode,
CUuuid
);
@ -169,6 +173,15 @@ impl<'a> FromCuda<'a, *const ::core::ffi::c_char> for &CStr {
}
}
/// Converts a `*const c_void` argument into a reference, rejecting null.
impl<'a> FromCuda<'a, *const ::core::ffi::c_void> for &'a ::core::ffi::c_void {
    /// Returns `CUerror::INVALID_VALUE` when the pointer is null.
    fn from_cuda(x: &'a *const ::core::ffi::c_void) -> Result<Self, CUerror> {
        unsafe { x.as_ref() }.ok_or(CUerror::INVALID_VALUE)
    }
}
pub(crate) trait ZludaObject: Sized + Send + Sync {
const COOKIE: usize;
const LIVENESS_FAIL: CUerror = cuda_types::cuda::CUerror::INVALID_VALUE;

View file

@ -1,5 +1,9 @@
use super::{driver, ZludaObject};
use cuda_types::cuda::*;
use cuda_types::{
cuda::*,
dark_api::{FatbinFileHeader, FatbincWrapper},
};
use dark_api::fatbin::Fatbin;
use hip_runtime_sys::*;
use std::{ffi::CStr, mem};
@ -18,12 +22,47 @@ impl ZludaObject for Module {
}
}
pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) -> CUresult {
/// Extracts the (decompressed) bytes of the first PTX file found in the
/// first submodule of the wrapped fatbin at `image`.
///
/// # Errors
///
/// * `CUerror::UNKNOWN` when parsing or decompression fails.
/// * `CUerror::NO_BINARY_FOR_GPU` when the submodule contains no PTX file.
fn get_ptx_from_wrapped_fatbin(image: *const ::core::ffi::c_void) -> Result<Vec<u8>, CUerror> {
    let Ok(fatbin) = Fatbin::new(&image) else {
        return Err(CUerror::UNKNOWN);
    };
    let Ok(first) = fatbin.get_first() else {
        return Err(CUerror::UNKNOWN);
    };
    let mut files = first.get_files();
    loop {
        let Some(file) = unsafe { files.next() }.map_err(|_| CUerror::UNKNOWN)? else {
            return Err(CUerror::NO_BINARY_FOR_GPU);
        };
        // Eventually we will want to get the PTX for the highest hardware version that we can get to compile. But for now we just get the first PTX we can find.
        if file.header.kind != FatbinFileHeader::HEADER_KIND_PTX {
            continue;
        }
        return unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN);
    }
}
/// get_ptx takes an `image` that can be either a fatbin or a NULL-terminated ptx, and returns a String containing a ptx extracted from `image`.
///
/// # Errors
///
/// * `CUerror::INVALID_VALUE` when `image` is null or the PTX text is not
///   valid UTF-8.
/// * `CUerror::UNKNOWN` when PTX extracted from a fatbin is not valid UTF-8.
/// * Any error returned by `get_ptx_from_wrapped_fatbin`.
fn get_ptx(image: *const ::core::ffi::c_void) -> Result<String, CUerror> {
    if image.is_null() {
        return Err(CUerror::INVALID_VALUE);
    }
    // Detect the fatbin wrapper by comparing the leading bytes with the magic
    // number one byte at a time instead of loading a `u32`:
    //  * `image` may be plain PTX text with no 4-byte alignment, and
    //  * the text may be shorter than 4 bytes.
    // `all` stops at the first mismatch and the magic contains no zero byte,
    // so for NUL-terminated text the scan never reads past the terminator.
    let magic = FatbincWrapper::MAGIC.to_ne_bytes();
    debug_assert!(magic.iter().all(|&byte| byte != 0));
    let bytes = image.cast::<u8>();
    let is_wrapped_fatbin = magic
        .iter()
        .enumerate()
        .all(|(offset, &expected)| unsafe { bytes.add(offset).read() } == expected);
    if is_wrapped_fatbin {
        let ptx_bytes = get_ptx_from_wrapped_fatbin(image)?;
        // Reuses the decompressed buffer instead of copying it.
        String::from_utf8(ptx_bytes).map_err(|_| CUerror::UNKNOWN)
    } else {
        unsafe { CStr::from_ptr(image.cast()) }
            .to_str()
            .map(str::to_owned)
            .map_err(|_| CUerror::INVALID_VALUE)
    }
}
pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModule_t, CUerror> {
let global_state = driver::global_state()?;
let text = unsafe { CStr::from_ptr(image.cast()) }
.to_str()
.map_err(|_| CUerror::INVALID_VALUE)?;
let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
let text = get_ptx(image)?;
let ast = ptx_parser::parse_module_checked(&text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
let llvm_module = ptx::to_llvm_module(ast).map_err(|_| CUerror::UNKNOWN)?;
let mut dev = 0;
unsafe { hipCtxGetDevice(&mut dev) }?;
@ -38,6 +77,11 @@ pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) -
.map_err(|_| CUerror::UNKNOWN)?;
let mut hip_module = unsafe { mem::zeroed() };
unsafe { hipModuleLoadData(&mut hip_module, elf_module.as_ptr().cast()) }?;
Ok(hip_module)
}
/// Builds a HIP module from `image` and stores the wrapped result in `module`.
pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult {
    let base = load_hip_module(image)?;
    *module = Module { base }.wrap();
    Ok(())
}

View file

@ -66,6 +66,7 @@ cuda_base::cuda_function_declarations!(
cuGetProcAddress,
cuGetProcAddress_v2,
cuInit,
cuLibraryLoadData,
cuMemAlloc_v2,
cuMemFree_v2,
cuMemGetAddressRange_v2,
@ -84,4 +85,4 @@ cuda_base::cuda_function_declarations!(
implemented_in_function <= [
cuLaunchKernel,
]
);
);

View file

@ -14,7 +14,6 @@ ptx_parser = { path = "../ptx_parser" }
zluda_dump_common = { path = "../zluda_dump_common" }
format = { path = "../format" }
dark_api = { path = "../dark_api" }
lz4-sys = "1.9"
regex = "1.4"
dynasm = "1.2"
dynasmrt = "1.2"

View file

@ -1,3 +1,4 @@
use ::dark_api::fatbin::FatbinFileIterator;
use ::dark_api::FnFfi;
use cuda_types::cuda::*;
use dark_api::DarkApiState2;
@ -360,7 +361,16 @@ impl DarkApiDump {
});
}
fn_logger.try_(|fn_logger| unsafe {
trace::record_submodules_from_fatbin(*module, fatbin_header, fn_logger, state)
trace::record_submodules(
*module,
fn_logger,
state,
FatbinFileIterator::new(
fatbin_header
.as_ref()
.ok_or(ErrorEntry::NullPointer("get_module_from_cubin_ext2_post"))?,
),
)
});
}
}

View file

@ -308,11 +308,11 @@ pub(crate) enum ErrorEntry {
unsafe impl Send for ErrorEntry {}
unsafe impl Sync for ErrorEntry {}
impl From<cuda_types::dark_api::ParseError> for ErrorEntry {
fn from(e: cuda_types::dark_api::ParseError) -> Self {
impl From<dark_api::fatbin::ParseError> for ErrorEntry {
fn from(e: dark_api::fatbin::ParseError) -> Self {
match e {
cuda_types::dark_api::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s),
cuda_types::dark_api::ParseError::UnexpectedBinaryField {
dark_api::fatbin::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s),
dark_api::fatbin::ParseError::UnexpectedBinaryField {
field_name,
observed,
expected,
@ -325,6 +325,20 @@ impl From<cuda_types::dark_api::ParseError> for ErrorEntry {
}
}
/// Maps each fatbin error onto the log entry of the same meaning.
impl From<dark_api::fatbin::FatbinError> for ErrorEntry {
    fn from(e: dark_api::fatbin::FatbinError) -> Self {
        match e {
            dark_api::fatbin::FatbinError::Lz4DecompressionFailure => {
                Self::Lz4DecompressionFailure
            }
            dark_api::fatbin::FatbinError::ZstdDecompressionFailure(code) => {
                Self::ZstdDecompressionFailure(code)
            }
            // Parse errors go through the existing `ParseError` conversion.
            dark_api::fatbin::FatbinError::ParseFailure(inner) => Self::from(inner),
        }
    }
}
impl Display for ErrorEntry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {

View file

@ -4,7 +4,10 @@ use crate::{
};
use cuda_types::{
cuda::*,
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbinHeader, FatbincWrapper},
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
};
use dark_api::fatbin::{
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{
@ -13,7 +16,6 @@ use std::{
fs::{self, File},
io::{self, Read, Write},
path::PathBuf,
ptr,
};
use unwrap_or::unwrap_some_or;
@ -259,14 +261,12 @@ pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
fn_logger: &mut FnCallLog,
state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
let fatbinc_wrapper = FatbincWrapper::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
let is_version_2 = fatbinc_wrapper.version == FatbincWrapper::VERSION_V2;
record_submodules_from_fatbin(module, (*fatbinc_wrapper).data, fn_logger, state)?;
if is_version_2 {
let mut current = (*fatbinc_wrapper).filename_or_fatbins as *const *const c_void;
while *current != ptr::null() {
record_submodules_from_fatbin(module, *current as *const _, fn_logger, state)?;
current = current.add(1);
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
let first = fatbin.get_first().map_err(ErrorEntry::from)?;
record_submodules_from_fatbin(module, first, fn_logger, state)?;
if let Some(mut submodules) = fatbin.get_submodules() {
while let Some(current) = submodules.next()? {
record_submodules_from_fatbin(module, current, fn_logger, state)?;
}
}
Ok(())
@ -274,37 +274,43 @@ pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
pub(crate) unsafe fn record_submodules_from_fatbin(
module: CUmodule,
fatbin_header: *const FatbinHeader,
submodule: FatbinSubmodule,
logger: &mut FnCallLog,
state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
let header = FatbinHeader::new(&fatbin_header).map_err(ErrorEntry::from)?;
let file = header.get_content();
record_submodules(module, logger, state, file)?;
record_submodules(module, logger, state, submodule.get_files())?;
Ok(())
}
unsafe fn record_submodules(
pub(crate) unsafe fn record_submodules(
module: CUmodule,
fn_logger: &mut FnCallLog,
state: &mut StateTracker,
mut file_buffer: &[u8],
mut files: FatbinFileIterator,
) -> Result<(), ErrorEntry> {
while let Some(file) = FatbinFileHeader::next(&mut file_buffer)? {
let mut payload = if file.flags.contains(FatbinFileHeaderFlags::CompressedLz4) {
while let Some(file) = files.next()? {
let mut payload = if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedLz4)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_lz4(file)),
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
continue
))
} else if file.flags.contains(FatbinFileHeaderFlags::CompressedZstd) {
} else if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedZstd)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_zstd(file)),
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
continue
))
} else {
Cow::Borrowed(file.get_payload())
};
match file.kind {
match file.header.kind {
FatbinFileHeader::HEADER_KIND_PTX => {
while payload.last() == Some(&0) {
// remove trailing zeros
@ -322,50 +328,10 @@ unsafe fn record_submodules(
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
],
observed: UInt::U16(file.kind),
observed: UInt::U16(file.header.kind),
});
}
}
}
Ok(())
}
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
unsafe fn decompress_lz4(file: &FatbinFileHeader) -> Result<Vec<u8>, ErrorEntry> {
let decompressed_size = usize::max(1024, (*file).uncompressed_payload as usize);
let mut decompressed_vec = vec![0u8; decompressed_size];
loop {
match lz4_sys::LZ4_decompress_safe(
file.get_payload().as_ptr() as *const _,
decompressed_vec.as_mut_ptr() as *mut _,
(*file).payload_size as _,
decompressed_vec.len() as _,
) {
error if error < 0 => {
let new_size = decompressed_vec.len() * 2;
if new_size > MAX_MODULE_DECOMPRESSION_BOUND {
return Err(ErrorEntry::Lz4DecompressionFailure);
}
decompressed_vec.resize(decompressed_vec.len() * 2, 0);
}
real_decompressed_size => {
decompressed_vec.truncate(real_decompressed_size as usize);
return Ok(decompressed_vec);
}
}
}
}
unsafe fn decompress_zstd(file: &FatbinFileHeader) -> Result<Vec<u8>, ErrorEntry> {
let mut result = Vec::with_capacity(file.uncompressed_payload as usize);
let payload = file.get_payload();
dbg!((payload.len(), file.uncompressed_payload, file.payload_size));
match zstd_safe::decompress(&mut result, payload) {
Ok(actual_size) => {
result.truncate(actual_size);
Ok(result)
}
Err(err) => Err(ErrorEntry::ZstdDecompressionFailure(err)),
}
}