diff --git a/dark_api/src/fatbin.rs b/dark_api/src/fatbin.rs index b14a79f..c9ff08f 100644 --- a/dark_api/src/fatbin.rs +++ b/dark_api/src/fatbin.rs @@ -75,21 +75,24 @@ impl<'a> Fatbin<'a> { Ok(Fatbin { wrapper }) } - pub fn get_first(&self) -> Result { - let header: &FatbinHeader = - parse_fatbin_header(&self.wrapper.data).map_err(|e| FatbinError::ParseFailure(e))?; - Ok(FatbinSubmodule::new(header)) - } - - pub fn get_submodules(&self) -> Option { - let is_version_2 = self.wrapper.version == FatbincWrapper::VERSION_V2; - if !is_version_2 { - return None; + pub fn get_submodules(&self) -> Result, FatbinError> { + match self.wrapper.version { + FatbincWrapper::VERSION_V2 => + Ok(FatbinIter::V2(FatbinSubmoduleIterator { + fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void, + _phantom: std::marker::PhantomData, + })), + FatbincWrapper::VERSION_V1 => { + let header = parse_fatbin_header(&self.wrapper.data) + .map_err(FatbinError::ParseFailure)?; + Ok(FatbinIter::V1(Some(FatbinSubmodule::new(header)))) + } + version => Err(FatbinError::ParseFailure(ParseError::UnexpectedBinaryField{ + field_name: "FATBINC_VERSION", + observed: version, + expected: [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2].into(), + })), } - - Some(FatbinSubmoduleIterator { - fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void, - }) } } @@ -107,12 +110,27 @@ impl<'a> FatbinSubmodule<'a> { } } -pub struct FatbinSubmoduleIterator { - fatbins: *const *const std::ffi::c_void, +pub enum FatbinIter<'a> { + V1(Option>), + V2(FatbinSubmoduleIterator<'a>), } -impl FatbinSubmoduleIterator { - pub unsafe fn next(&mut self) -> Result, ParseError> { +impl<'a> FatbinIter<'a> { + pub fn next(&mut self) -> Result>, ParseError> { + match self { + FatbinIter::V1(opt) => Ok(opt.take()), + FatbinIter::V2(iter) => unsafe { iter.next() }, + } + } +} + +pub struct FatbinSubmoduleIterator<'a> { + fatbins: *const *const std::ffi::c_void, + _phantom: std::marker::PhantomData<&'a FatbinHeader>, +} + +impl<'a> FatbinSubmoduleIterator<'a> { + pub unsafe fn next(&mut self) -> Result>, ParseError> { if *self.fatbins != ptr::null() { let header = *self.fatbins as *const FatbinHeader; self.fatbins = self.fatbins.add(1); diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index c55bfa6..f7b9f22 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -24,14 +24,15 @@ impl ZludaObject for Module { fn get_ptx_from_wrapped_fatbin(image: *const ::core::ffi::c_void) -> Result, CUerror> { let fatbin = Fatbin::new(&image).map_err(|_| CUerror::UNKNOWN)?; - let first = fatbin.get_first().map_err(|_| CUerror::UNKNOWN)?; - let mut files = first.get_files(); + let mut submodules = fatbin.get_submodules().map_err(|_| CUerror::UNKNOWN)?; - while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } { - // Eventually we will want to get the PTX for the highest hardware version that we can get to compile. But for now we just get the first PTX we can find. - if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX { - let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?; - return Ok(decompressed); + while let Some(current) = unsafe { submodules.next().map_err(|_| CUerror::UNKNOWN)? } { + let mut files = current.get_files(); + while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } { + if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX { + let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?; + return Ok(decompressed); + } } } diff --git a/zluda_dump/src/trace.rs b/zluda_dump/src/trace.rs index 23665fc..13e2f4a 100644 --- a/zluda_dump/src/trace.rs +++ b/zluda_dump/src/trace.rs @@ -262,12 +262,9 @@ pub(crate) unsafe fn record_submodules_from_wrapped_fatbin( state: &mut StateTracker, ) -> Result<(), ErrorEntry> { let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?; - let first = fatbin.get_first().map_err(ErrorEntry::from)?; - record_submodules_from_fatbin(module, first, fn_logger, state)?; - if let Some(mut submodules) = fatbin.get_submodules() { - while let Some(current) = submodules.next()? { - record_submodules_from_fatbin(module, current, fn_logger, state)?; - } + let mut submodules = fatbin.get_submodules()?; + while let Some(current) = submodules.next()? { + record_submodules_from_fatbin(module, current, fn_logger, state)?; } Ok(()) }