diff --git a/dark_api/src/fatbin.rs b/dark_api/src/fatbin.rs index 86cff8e..c1772e0 100644 --- a/dark_api/src/fatbin.rs +++ b/dark_api/src/fatbin.rs @@ -1,6 +1,6 @@ // This file contains a higher-level interface for parsing fatbins -use std::ptr; +use std::{borrow::Cow, ptr}; use cuda_types::dark_api::*; @@ -125,9 +125,9 @@ pub enum FatbinIter<'a> { } impl<'a> FatbinIter<'a> { - pub fn next(&mut self) -> Result>, ParseError> { + pub fn next(&mut self) -> Option, ParseError>> { match self { - FatbinIter::V1(opt) => Ok(opt.take()), + FatbinIter::V1(opt) => Ok(opt.take()).transpose(), FatbinIter::V2(iter) => unsafe { iter.next() }, } } @@ -139,15 +139,18 @@ pub struct FatbinSubmoduleIterator<'a> { } impl<'a> FatbinSubmoduleIterator<'a> { - pub unsafe fn next(&mut self) -> Result>, ParseError> { + pub unsafe fn next(&mut self) -> Option, ParseError>> { if *self.fatbins != ptr::null() { let header = *self.fatbins as *const FatbinHeader; self.fatbins = self.fatbins.add(1); - Ok(Some(FatbinSubmodule::new(header.as_ref().ok_or( - ParseError::NullPointer("FatbinSubmoduleIterator"), - )?))) + Some( + header + .as_ref() + .ok_or(ParseError::NullPointer("FatbinSubmoduleIterator")) + .map(FatbinSubmodule::new), + ) } else { - Ok(None) + None } } } @@ -161,6 +164,14 @@ impl<'a> FatbinFile<'a> { Self { header } } + pub fn kind(&self) -> &'static str { + match self.header.kind { + FatbinFileHeader::HEADER_KIND_PTX => "ptx", + FatbinFileHeader::HEADER_KIND_ELF => "elf", + _ => "bin", + } + } + pub unsafe fn get_payload(&'a self) -> &'a [u8] { let start = std::ptr::from_ref(self.header) .cast::() @@ -168,28 +179,38 @@ impl<'a> FatbinFile<'a> { std::slice::from_raw_parts(start, self.header.payload_size as usize) } - pub unsafe fn decompress(&'a self) -> Result, FatbinError> { + pub unsafe fn get_or_decompress_content(&'a self) -> Result, FatbinError> { let mut payload = if self .header .flags .contains(FatbinFileHeaderFlags::CompressedLz4) { - unsafe { decompress_lz4(self) }? + Cow::Owned(unsafe { decompress_lz4(self) }?) } else if self .header .flags .contains(FatbinFileHeaderFlags::CompressedZstd) { - unsafe { decompress_zstd(self) }? + Cow::Owned(unsafe { decompress_zstd(self) }?) } else { - unsafe { self.get_payload().to_vec() } + Cow::Borrowed(unsafe { self.get_payload() }) }; - while payload.last() == Some(&0) { - // remove trailing zeros - payload.pop(); + // Remove trailing zeros + if self.header.kind == FatbinFileHeader::HEADER_KIND_PTX { + match payload { + Cow::Borrowed(ref mut slice) => { + while slice.last() == Some(&0) { + *slice = &slice[..slice.len() - 1]; + } + } + Cow::Owned(ref mut vec) => { + while vec.last() == Some(&0) { + vec.pop(); + } + } + } } - Ok(payload) } } @@ -208,9 +229,9 @@ impl<'a> FatbinFileIterator<'a> { Self { file_buffer } } - pub unsafe fn next(&mut self) -> Result>, ParseError> { + pub unsafe fn next(&mut self) -> Option, ParseError>> { if self.file_buffer.len() < std::mem::size_of::() { - return Ok(None); + return None; } let this = &*self.file_buffer.as_ptr().cast::(); let next_element = self @@ -218,12 +239,14 @@ impl<'a> FatbinFileIterator<'a> { .split_at_checked(this.header_size as usize + this.padded_payload_size as usize) .map(|(_, next)| next); self.file_buffer = next_element.unwrap_or(&[]); - ParseError::check_fields( - "FATBIN_FILE_HEADER_VERSION_CURRENT", - this.version, - [FatbinFileHeader::HEADER_VERSION_CURRENT], - )?; - Ok(Some(FatbinFile::new(this))) + Some( + ParseError::check_fields( + "FATBIN_FILE_HEADER_VERSION_CURRENT", + this.version, + [FatbinFileHeader::HEADER_VERSION_CURRENT], + ) + .map(|_| FatbinFile::new(this)), + ) } } diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs index c611f9f..339d861 100644 --- a/zluda_common/src/lib.rs +++ b/zluda_common/src/lib.rs @@ -2,20 +2,17 @@ use cuda_types::{ cublas::*, cublaslt::cublasLtHandle_t, cuda::*, - dark_api::{FatbinFileHeaderFlags, FatbinHeader, FatbincWrapper}, + dark_api::{FatbinHeader, FatbincWrapper}, nvml::*, }; -use dark_api::fatbin::{ - Fatbin, FatbinError, FatbinFile, FatbinFileIterator, FatbinIter, FatbinSubmodule, ParseError, -}; +use dark_api::fatbin::{Fatbin, FatbinError, FatbinFile, FatbinSubmodule}; use hip_runtime_sys::*; use rocblas_sys::*; use std::{ - borrow::Cow, ffi::{c_void, CStr}, mem::{self, ManuallyDrop, MaybeUninit}, - ops::ControlFlow, ptr, + str::Utf8Error, }; pub trait CudaErrorType { @@ -480,14 +477,16 @@ pub enum CodeLibaryRef<'a> { FatbincWrapper(Fatbin<'a>), FatbinHeader(FatbinSubmodule<'a>), Text(&'a str), - Elf(*mut c_void), + Elf(*const c_void), + Archive(*const c_void), } impl<'a> CodeLibaryRef<'a> { const ELFMAG: [u8; 4] = *b"\x7FELF"; + const AR_MAGIC: [u8; 8] = *b"!\x0A"; - unsafe fn try_load(ptr: *mut c_void) -> Option { - Some(match *ptr.cast::<[u8; 4]>() { + pub unsafe fn try_load(ptr: *const c_void) -> Result { + Ok(match *ptr.cast::<[u8; 4]>() { FatbincWrapper::MAGIC => Self::FatbincWrapper(Fatbin { wrapper: &*(ptr.cast()), }), @@ -495,13 +494,16 @@ impl<'a> CodeLibaryRef<'a> { header: &*(ptr.cast()), }), Self::ELFMAG => Self::Elf(ptr), - _ => CStr::from_ptr(ptr.cast()).to_str().ok().map(Self::Text)?, + _ => match *ptr.cast::<[u8; 8]>() { + Self::AR_MAGIC => Self::Archive(ptr), + _ => CStr::from_ptr(ptr.cast()).to_str().map(Self::Text)?, + }, }) } - unsafe fn iterate_modules( + pub unsafe fn iterate_modules( &self, - fn_: &mut impl FnMut((usize, usize), Result), + mut fn_: impl FnMut(Option<(usize, Option)>, Result), ) { match self { CodeLibaryRef::FatbincWrapper(fatbin) => { @@ -511,124 +513,53 @@ impl<'a> CodeLibaryRef<'a> { let mut module_index = 0; while let Some(maybe_submodule) = iter.next() { match maybe_submodule { - Ok(submodule) => Self::iterate_modules( - &CodeLibaryRef::FatbinHeader(submodule), - &mut |(_, subindex), module| { - fn_((module_index, subindex), module) + Ok(submodule) => iterate_modules_fatbin_header( + &mut |subindex, module| { + let (subindex, _) = subindex.unwrap(); + fn_(Some((module_index, Some(subindex))), module) }, + &submodule, + ), + Err(err) => fn_( + Some((module_index, None)), + Err(FatbinError::ParseFailure(err)), ), - Err(err) => { - fn_((module_index, 0), Err(FatbinError::ParseFailure(err))) - } } module_index += 1; } } - Err(err) => fn_((0, 0), Err(err)), + Err(err) => fn_(None, Err(err)), } } CodeLibaryRef::FatbinHeader(submodule) => { - let mut iter = submodule.get_files(); - let mut index = 0; - while let Some(file) = iter.next() { - fn_( - (0, index), - file.map(CodeModule::File) - .map_err(FatbinError::ParseFailure), - ); - index += 1; - } + iterate_modules_fatbin_header(&mut fn_, submodule); } - CodeLibaryRef::Text(text) => fn_((0, 0), Ok(CodeModule::Text(*text))), - CodeLibaryRef::Elf(elf) => fn_((0, 0), Ok(CodeModule::Elf(*elf))), + CodeLibaryRef::Text(text) => fn_(None, Ok(CodeModule::Text(*text))), + CodeLibaryRef::Elf(elf) => fn_(None, Ok(CodeModule::Elf(*elf))), + CodeLibaryRef::Archive(ar) => fn_(None, Ok(CodeModule::Archive(*ar))), } } } -enum CodeModule<'a> { +unsafe fn iterate_modules_fatbin_header( + fn_: &mut impl FnMut(Option<(usize, Option)>, Result), + submodule: &FatbinSubmodule<'_>, +) { + let mut iter = submodule.get_files(); + let mut index = 0; + while let Some(file) = iter.next() { + fn_( + Some((index, None)), + file.map(CodeModule::File) + .map_err(FatbinError::ParseFailure), + ); + index += 1; + } +} + +pub enum CodeModule<'a> { File(FatbinFile<'a>), Text(&'a str), - Elf(*mut c_void), + Elf(*const c_void), + Archive(*const c_void), } - -/* -enum PtxIterator<'a> { - FatbincWrapper(FatbinIter<'a>), - FatbinHeader(FatbinFileIterator<'a>), - Text(std::iter::Once<&'a str>), -} - -impl<'a> PtxIterator<'a> { - fn next(&mut self) -> Option> { - match self { - PtxIterator::FatbincWrapper(iter) => { - while let Ok(Some(submodule)) = iter.next() { - let mut files = submodule.get_files(); - while let Some(file) = unsafe { files.next().ok()? } { - if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX { - return Some( - unsafe { file.decompress().ok()? } - .as_slice() - .strip_suffix(&[0])? - .as_ref() - .and_then(|s| std::str::from_utf8(s).ok()), - ); - } - } - } - None - } - PtxIterator::FatbinHeader(iter) => unsafe { - iter.next().map(|file| file.map(|file| file.decompress())) - }, - PtxIterator::Text(iter) => iter.next(), - } - } -} - -fn decompress_payload<'a>(file: &'a FatbinFile) -> Result, FatbinError> { - let mut payload = if file - .header - .flags - .contains(FatbinFileHeaderFlags::CompressedLz4) - { - Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())), - continue - )) - } else if file - .header - .flags - .contains(FatbinFileHeaderFlags::CompressedZstd) - { - Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())), - continue - )) - } else { - Cow::Borrowed(file.get_payload()) - }; - match file.header.kind { - FatbinFileHeader::HEADER_KIND_PTX => { - while payload.last() == Some(&0) { - // remove trailing zeros - payload.to_mut().pop(); - } - state.record_new_submodule(module, &*payload, fn_logger, "ptx") - } - FatbinFileHeader::HEADER_KIND_ELF => { - state.record_new_submodule(module, &*payload, fn_logger, "elf") - } - _ => { - fn_logger.log(log::ErrorEntry::UnexpectedBinaryField { - field_name: "FATBIN_FILE_HEADER_KIND", - expected: vec![ - UInt::U16(FatbinFileHeader::HEADER_KIND_PTX), - UInt::U16(FatbinFileHeader::HEADER_KIND_ELF), - ], - observed: UInt::U16(file.header.kind), - }); - } - } -} -*/ diff --git a/zluda_trace/Cargo.toml b/zluda_trace/Cargo.toml index e709f17..a6c4120 100644 --- a/zluda_trace/Cargo.toml +++ b/zluda_trace/Cargo.toml @@ -12,6 +12,7 @@ crate-type = ["cdylib"] ptx = { path = "../ptx" } ptx_parser = { path = "../ptx_parser" } zluda_trace_common = { path = "../zluda_trace_common" } +zluda_common = { path = "../zluda_common" } format = { path = "../format" } dark_api = { path = "../dark_api" } regex = "1.4" diff --git a/zluda_trace/src/lib.rs b/zluda_trace/src/lib.rs index 50603f1..f61fed6 100644 --- a/zluda_trace/src/lib.rs +++ b/zluda_trace/src/lib.rs @@ -1,7 +1,5 @@ -use ::dark_api::fatbin::FatbinFileIterator; use ::dark_api::FnFfi; use cuda_types::cuda::*; -use cuda_types::dark_api::FatbinHeader; use dark_api::DarkApiState2; use log::{CudaFunctionName, ErrorEntry}; use parking_lot::ReentrantMutex; @@ -289,9 +287,7 @@ impl DarkApiTrace { fn_logger: &mut FnCallLog, _result: CUresult, ) { - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules_from_wrapped_fatbin(*module, fatbinc_wrapper, fn_logger, state) - }); + state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger) } fn get_module_from_cubin_ext1_post( @@ -325,9 +321,7 @@ impl DarkApiTrace { observed: UInt::U32(arg5), }); } - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules_from_wrapped_fatbin(*module, fatbinc_wrapper, fn_logger, state) - }); + state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger) } fn get_module_from_cubin_ext2_post( @@ -361,18 +355,7 @@ impl DarkApiTrace { observed: UInt::U32(arg5), }); } - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules( - *module, - fn_logger, - state, - FatbinFileIterator::new( - fatbin_header - .as_ref() - .ok_or(ErrorEntry::NullPointer("FatbinHeader"))?, - ), - ) - }); + state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger) } } @@ -1324,7 +1307,7 @@ pub(crate) fn cuModuleLoadData_Post( fn_logger: &mut FnCallLog, _result: CUresult, ) { - state.record_new_module(unsafe { *module }, raw_image, fn_logger) + state.record_new_library(unsafe { *module }, raw_image, fn_logger) } #[allow(non_snake_case)] @@ -1402,19 +1385,7 @@ pub(crate) fn cuModuleLoadFatBinary_Post( fn_logger: &mut FnCallLog, _result: CUresult, ) { - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules( - *module, - fn_logger, - state, - FatbinFileIterator::new( - fatbin_header - .cast::() - .as_ref() - .ok_or(ErrorEntry::NullPointer("FatbinHeader"))?, - ), - ) - }); + state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger) } #[allow(non_snake_case)] @@ -1427,16 +1398,7 @@ pub(crate) fn cuLibraryGetModule_Post( ) { match state.libraries.get(&library).copied() { None => fn_logger.log(log::ErrorEntry::UnknownLibrary(library)), - Some(code) => { - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules_from_wrapped_fatbin( - *module, - code.0.cast(), - fn_logger, - state, - ) - }); - } + Some(code) => state.record_new_library(unsafe { *module }, code.0, fn_logger), } } @@ -1459,5 +1421,5 @@ pub(crate) fn cuLibraryLoadData_Post( .insert(unsafe { *library }, trace::CodePointer(code)); // TODO: this is not correct, but it's enough for now, we just want to // save the binary to disk - state.record_new_module(unsafe { CUmodule((*library).0.cast()) }, code, fn_logger); + state.record_new_library(unsafe { CUmodule((*library).0.cast()) }, code, fn_logger); } diff --git a/zluda_trace/src/trace.rs b/zluda_trace/src/trace.rs index ae0662a..ec27fa0 100644 --- a/zluda_trace/src/trace.rs +++ b/zluda_trace/src/trace.rs @@ -1,18 +1,11 @@ use crate::{ - log::{self, UInt}, - trace, ErrorEntry, FnCallLog, Settings, -}; -use cuda_types::{ - cuda::*, - dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper}, -}; -use dark_api::fatbin::{ - decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule, + log::{self}, + ErrorEntry, FnCallLog, Settings, }; +use cuda_types::cuda::*; use goblin::{elf, elf32, elf64}; use rustc_hash::{FxHashMap, FxHashSet}; use std::{ - borrow::Cow, ffi::{c_void, CStr, CString}, fs::{self, File}, io::{self, Read, Write}, @@ -29,8 +22,7 @@ pub(crate) struct StateTracker { writer: DumpWriter, pub(crate) libraries: FxHashMap, saved_modules: FxHashSet, - module_counter: usize, - submodule_counter: usize, + library_counter: usize, pub(crate) override_cc: Option<(u32, u32)>, } @@ -46,8 +38,7 @@ impl StateTracker { writer: DumpWriter::new(settings.dump_dir.clone()), libraries: FxHashMap::default(), saved_modules: FxHashSet::default(), - module_counter: 0, - submodule_counter: 0, + library_counter: 0, override_cc: settings.override_cc, } } @@ -78,25 +69,68 @@ impl StateTracker { let mut module_file = fs::File::open(file_name)?; let mut read_buff = Vec::new(); module_file.read_to_end(&mut read_buff)?; - self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger); + self.record_new_library(module, read_buff.as_ptr() as *const _, fn_logger); Ok(()) } + pub(crate) fn record_new_library( + &mut self, + cu_module: CUmodule, + raw_image: *const c_void, + fn_logger: &mut FnCallLog, + ) { + self.saved_modules.insert(cu_module); + self.library_counter += 1; + let code_ref = fn_logger.try_return(|| { + unsafe { zluda_common::CodeLibaryRef::try_load(raw_image) } + .map_err(ErrorEntry::NonUtf8ModuleText) + }); + let code_ref = unwrap_some_or!(code_ref, return); + unsafe { + code_ref.iterate_modules(|index, module| match module { + Err(err) => fn_logger.log(ErrorEntry::from(err)), + Ok(zluda_common::CodeModule::Elf(elf)) => match get_elf_size(elf) { + Some(len) => { + let elf_image = std::slice::from_raw_parts(elf.cast::(), len); + self.record_new_submodule(index, elf_image, fn_logger, "elf"); + } + None => fn_logger.log(log::ErrorEntry::UnsupportedModule { + module: cu_module, + raw_image: elf, + kind: "ELF", + }), + }, + Ok(zluda_common::CodeModule::Archive(archive)) => { + fn_logger.log(log::ErrorEntry::UnsupportedModule { + module: cu_module, + raw_image: archive, + kind: "archive", + }) + } + Ok(zluda_common::CodeModule::File(file)) => { + if let Some(buffer) = fn_logger + .try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from)) + { + self.record_new_submodule(index, &*buffer, fn_logger, file.kind()); + } + } + Ok(zluda_common::CodeModule::Text(ptx)) => { + self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx"); + } + }); + }; + } + pub(crate) fn record_new_submodule( &mut self, - module: CUmodule, + index: Option<(usize, Option)>, submodule: &[u8], fn_logger: &mut FnCallLog, type_: &'static str, ) { - if self.saved_modules.insert(module) { - self.module_counter += 1; - self.submodule_counter = 0; - } - self.submodule_counter += 1; fn_logger.log_io_error(self.writer.save_module( - self.module_counter, - Some(self.submodule_counter), + self.library_counter, + index, submodule, type_, )); @@ -107,8 +141,8 @@ impl StateTracker { Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)), Ok(submodule_text) => self.try_parse_and_record_kernels( fn_logger, - self.module_counter, - Some(self.submodule_counter), + self.library_counter, + index, submodule_text, ), }, @@ -116,80 +150,11 @@ impl StateTracker { } } - pub(crate) fn record_new_module( - &mut self, - module: CUmodule, - raw_image: *const c_void, - fn_logger: &mut FnCallLog, - ) { - self.module_counter += 1; - if unsafe { *(raw_image as *const [u8; 4]) } == *elf64::header::ELFMAG { - self.saved_modules.insert(module); - match unsafe { get_elf_size(raw_image) } { - Some(len) => { - let elf_image = - unsafe { std::slice::from_raw_parts(raw_image.cast::(), len) }; - self.record_new_submodule(module, elf_image, fn_logger, "elf"); - } - None => fn_logger.log(log::ErrorEntry::UnsupportedModule { - module, - raw_image, - kind: "ELF", - }), - } - } else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC { - self.saved_modules.insert(module); - // TODO: Figure out how to get size of archive module and write it to disk - fn_logger.log(log::ErrorEntry::UnsupportedModule { - module, - raw_image, - kind: "archive", - }) - } else if unsafe { *(raw_image as *const [u8; 4]) } == FatbincWrapper::MAGIC { - unsafe { - fn_logger.try_(|fn_logger| { - trace::record_submodules_from_wrapped_fatbin( - module, - raw_image as *const FatbincWrapper, - fn_logger, - self, - ) - }); - } - } else { - self.record_module_ptx(module, raw_image, fn_logger) - } - } - - fn record_module_ptx( - &mut self, - module: CUmodule, - raw_image: *const c_void, - fn_logger: &mut FnCallLog, - ) { - self.saved_modules.insert(module); - let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str(); - let module_text = match module_text { - Ok(m) => m, - Err(utf8_err) => { - fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err)); - return; - } - }; - fn_logger.log_io_error(self.writer.save_module( - self.module_counter, - None, - module_text.as_bytes(), - "ptx", - )); - self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text); - } - fn try_parse_and_record_kernels( &mut self, fn_logger: &mut FnCallLog, module_index: usize, - submodule_index: Option, + submodule_index: Option<(usize, Option)>, module_text: &str, ) { let errors = ptx_parser::parse_for_errors(module_text); @@ -359,7 +324,7 @@ impl DumpWriter { fn save_module( &self, module_index: usize, - submodule_index: Option, + submodule_index: Option<(usize, Option)>, buffer: &[u8], kind: &'static str, ) -> io::Result<()> { @@ -368,7 +333,7 @@ impl DumpWriter { Some(d) => d.clone(), }; dump_file.push(Self::get_file_name(module_index, submodule_index, kind)); - let mut file = File::create(dump_file)?; + let mut file = File::create_new(dump_file)?; file.write_all(buffer)?; Ok(()) } @@ -376,7 +341,7 @@ impl DumpWriter { fn save_module_error_log<'input>( &self, module_index: usize, - submodule_index: Option, + submodule_index: Option<(usize, Option)>, errors: &[ptx_parser::PtxError<'input>], ) -> io::Result<()> { let mut log_file = match &self.dump_dir { @@ -391,92 +356,24 @@ impl DumpWriter { Ok(()) } - fn get_file_name(module_index: usize, submodule_index: Option, kind: &str) -> String { + fn get_file_name( + module_index: usize, + submodule_index: Option<(usize, Option)>, + kind: &str, + ) -> String { match submodule_index { None => { format!("module_{:04}.{:02}", module_index, kind) } - Some(submodule_index) => { - format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind) + Some((sub_index, None)) => { + format!("module_{:04}_{:02}.{}", module_index, sub_index, kind) + } + Some((sub_index, Some(subsub_index))) => { + format!( + "module_{:04}_{:02}_{:02}.{}", + module_index, sub_index, subsub_index, kind + ) } } } } - -pub(crate) unsafe fn record_submodules_from_wrapped_fatbin( - module: CUmodule, - fatbinc_wrapper: *const FatbincWrapper, - fn_logger: &mut FnCallLog, - state: &mut StateTracker, -) -> Result<(), ErrorEntry> { - let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?; - let mut submodules = fatbin.get_submodules()?; - while let Some(current) = submodules.next()? { - record_submodules_from_fatbin(module, current, fn_logger, state)?; - } - Ok(()) -} - -pub(crate) unsafe fn record_submodules_from_fatbin( - module: CUmodule, - submodule: FatbinSubmodule, - logger: &mut FnCallLog, - state: &mut StateTracker, -) -> Result<(), ErrorEntry> { - record_submodules(module, logger, state, submodule.get_files())?; - Ok(()) -} - -pub(crate) unsafe fn record_submodules( - module: CUmodule, - fn_logger: &mut FnCallLog, - state: &mut StateTracker, - mut files: FatbinFileIterator, -) -> Result<(), ErrorEntry> { - while let Some(file) = files.next()? { - let mut payload = if file - .header - .flags - .contains(FatbinFileHeaderFlags::CompressedLz4) - { - Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())), - continue - )) - } else if file - .header - .flags - .contains(FatbinFileHeaderFlags::CompressedZstd) - { - Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())), - continue - )) - } else { - Cow::Borrowed(file.get_payload()) - }; - match file.header.kind { - FatbinFileHeader::HEADER_KIND_PTX => { - while payload.last() == Some(&0) { - // remove trailing zeros - payload.to_mut().pop(); - } - state.record_new_submodule(module, &*payload, fn_logger, "ptx") - } - FatbinFileHeader::HEADER_KIND_ELF => { - state.record_new_submodule(module, &*payload, fn_logger, "elf") - } - _ => { - fn_logger.log(log::ErrorEntry::UnexpectedBinaryField { - field_name: "FATBIN_FILE_HEADER_KIND", - expected: vec![ - UInt::U16(FatbinFileHeader::HEADER_KIND_PTX), - UInt::U16(FatbinFileHeader::HEADER_KIND_ELF), - ], - observed: UInt::U16(file.header.kind), - }); - } - } - } - Ok(()) -}