Progress compilation despite parsing errors (#495)

Previously, if we ran into a broken instruction, we'd fail the whole compilation. This PR changes it so that (only in Release mode) we try to progress at all costs, meaning that if we have trouble parsing an instruction, we just remove the function from the output and continue.

For some workloads we can still compile a semi-broken but meaningful subset of a module.
This commit is contained in:
Andrzej Janik 2025-09-08 23:35:29 +02:00 committed by GitHub
commit 869d291099
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1043 additions and 391 deletions

15
Cargo.lock generated
View file

@ -138,12 +138,6 @@ dependencies = [
"syn 2.0.89",
]
[[package]]
name = "bit-vec"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bit-vec"
version = "0.8.0"
@ -363,10 +357,9 @@ dependencies = [
name = "compiler"
version = "0.0.0"
dependencies = [
"amd_comgr-sys",
"bpaf",
"comgr",
"hip_runtime-sys",
"libloading",
"ptx",
"ptx_parser",
"thiserror 2.0.12",
@ -462,7 +455,7 @@ dependencies = [
name = "dark_api"
version = "0.0.0"
dependencies = [
"bit-vec 0.8.0",
"bit-vec",
"cglue",
"cuda_types",
"format",
@ -2586,7 +2579,7 @@ dependencies = [
name = "ptx"
version = "0.0.0"
dependencies = [
"bit-vec 0.6.3",
"bit-vec",
"bitflags 1.3.2",
"comgr",
"cuda_macros",
@ -3797,6 +3790,7 @@ name = "zluda_common"
version = "0.1.0"
dependencies = [
"cuda_types",
"dark_api",
"hip_runtime-sys",
"rocblas-sys",
]
@ -3889,6 +3883,7 @@ dependencies = [
"unwrap_or",
"wchar",
"winapi",
"zluda_common",
"zluda_trace_common",
"zstd-safe",
]

View file

@ -10,12 +10,11 @@ name = "zoc"
path = "src/main.rs"
[dependencies]
amd_comgr-sys = { path = "../ext/amd_comgr-sys" }
bpaf = { version = "0.9.19", features = ["derive"] }
comgr = { path = "../comgr" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
ptx = { path = "../ptx" }
ptx_parser = { path = "../ptx_parser" }
libloading = "0.8"
thiserror = "2.0.12"
[package.metadata.zluda]

View file

@ -1,15 +1,15 @@
use ptx::TranslateError;
use ptx_parser::PtxError;
use std::ffi::FromBytesUntilNulError;
use std::io;
use std::str::Utf8Error;
use hip_runtime_sys::hipErrorCode_t;
use ptx::TranslateError;
use ptx_parser::PtxError;
#[derive(Debug, thiserror::Error)]
pub enum CompilerError {
#[error("HIP error code: {0:?}")]
HipError(hipErrorCode_t),
HipError(u32),
#[error(transparent)]
Libloading(#[from] libloading::Error),
#[error(transparent)]
ComgrError(#[from] comgr::Error),
#[error(transparent)]
@ -26,12 +26,6 @@ pub enum CompilerError {
},
}
impl From<hipErrorCode_t> for CompilerError {
fn from(error_code: hipErrorCode_t) -> Self {
CompilerError::HipError(error_code)
}
}
impl From<Vec<PtxError<'_>>> for CompilerError {
fn from(causes: Vec<PtxError>) -> Self {
let errors: Vec<String> = causes

View file

@ -1,3 +1,5 @@
use bpaf::Bpaf;
use error::CompilerError;
use std::ffi::CStr;
use std::fs::{self, File};
use std::io::{self, Write};
@ -6,11 +8,7 @@ use std::process::ExitCode;
use std::str;
use std::{env, mem};
use bpaf::Bpaf;
mod error;
use error::CompilerError;
use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600, hipInit};
const DEFAULT_ARCH: &'static str = "gfx1100";
@ -60,12 +58,17 @@ fn main_core() -> Result<(), CompilerError> {
let arch: String = match opts.arch {
Some(s) => s,
None => {
unsafe { hipInit(0) }?;
let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut dev_props, 0) }?;
(|| {
let runtime = hip::Runtime::load()?;
runtime.init()?;
get_gpu_arch(&runtime)
})()
.unwrap_or_else(|_| DEFAULT_ARCH.to_owned())
/*
get_gpu_arch(&mut dev_props)
.map(String::from)
.unwrap_or(DEFAULT_ARCH.to_owned())
*/
}
};
@ -122,12 +125,13 @@ struct LLVMArtifacts {
llvm_ir: Vec<u8>,
}
fn get_gpu_arch<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> Result<&'a str, CompilerError> {
unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }?;
fn get_gpu_arch(runtime: &hip::Runtime) -> Result<String, CompilerError> {
let mut dev_props = unsafe { mem::zeroed() };
runtime.device_get_properties(&mut dev_props, 0)?;
let gcn_arch_name = &dev_props.gcnArchName;
let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) };
let gcn_arch_name = gcn_arch_name.to_str();
gcn_arch_name.map_err(CompilerError::from)
let gcn_arch_name = gcn_arch_name.to_str()?;
Ok(gcn_arch_name.to_string())
}
fn write_to_file(content: &[u8], path: &Path) -> io::Result<()> {
@ -137,3 +141,316 @@ fn write_to_file(content: &[u8], path: &Path) -> io::Result<()> {
println!("Wrote to {}", path.to_str().unwrap());
Ok(())
}
mod hip {
use crate::error::CompilerError;
// We lazy load HIP runtime because we want to work on systems with no
// HIP driver installed
/// Handle to the dynamically loaded HIP runtime library. HIP entry points are
/// looked up by name on this handle each time one of the methods is called.
pub struct Runtime(libloading::Library);
impl Runtime {
fn hip_check(err: u32) -> Result<(), CompilerError> {
match err {
0 => Ok(()),
err_code => Err(CompilerError::HipError(err_code)),
}
}
pub fn load() -> Result<Self, CompilerError> {
#[cfg(windows)]
let lib_name = "amdhip64_6.dll\0";
#[cfg(unix)]
let lib_name = "libamdhip64.so.6\0";
let library = unsafe { libloading::Library::new(lib_name)? };
Ok(Self(library))
}
pub fn init(&self) -> Result<(), CompilerError> {
unsafe {
let hip_init: libloading::Symbol<unsafe extern "C" fn(u32) -> u32> =
self.0.get(b"hipInit\0")?;
Self::hip_check(hip_init(0))
}
}
pub fn device_get_properties(
&self,
prop: &mut hipDeviceProp_tR0600,
device: i32,
) -> Result<(), CompilerError> {
unsafe {
let hip_get_device_properties: libloading::Symbol<
unsafe extern "C" fn(*mut hipDeviceProp_tR0600, i32) -> u32,
> = self.0.get(b"hipGetDevicePropertiesR0600\0")?;
Self::hip_check(hip_get_device_properties(prop, device))
}
}
}
/// `#[repr(C)]` declaration of the device-properties record that
/// `hipGetDevicePropertiesR0600` fills in through a `*mut` pointer.
///
/// NOTE(review): presumably copied from the generated `hip_runtime-sys`
/// bindings so that this crate does not have to link against HIP. Field order
/// and types must stay in sync with the HIP header; confirm when updating HIP.
#[allow(non_snake_case, non_camel_case_types)]
#[repr(C)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct hipDeviceProp_tR0600 {
///< Device name.
pub name: [::core::ffi::c_char; 256usize],
///< UUID of a device
pub uuid: hipUUID,
///< 8-byte unique identifier. Only valid on windows
pub luid: [::core::ffi::c_char; 8usize],
///< LUID node mask
pub luidDeviceNodeMask: ::core::ffi::c_uint,
///< Size of global memory region (in bytes).
pub totalGlobalMem: usize,
///< Size of shared memory per block (in bytes).
pub sharedMemPerBlock: usize,
///< Registers per block.
pub regsPerBlock: ::core::ffi::c_int,
///< Warp size.
pub warpSize: ::core::ffi::c_int,
/**< Maximum pitch in bytes allowed by memory copies
< pitched memory*/
pub memPitch: usize,
///< Max work items per work group or workgroup max size.
pub maxThreadsPerBlock: ::core::ffi::c_int,
///< Max number of threads in each dimension (XYZ) of a block.
pub maxThreadsDim: [::core::ffi::c_int; 3usize],
///< Max grid dimensions (XYZ).
pub maxGridSize: [::core::ffi::c_int; 3usize],
///< Max clock frequency of the multiProcessors in khz.
pub clockRate: ::core::ffi::c_int,
/**< Size of shared constant memory region on the device
< (in bytes).*/
pub totalConstMem: usize,
/**< Major compute capability. On HCC, this is an approximation and features may
< differ from CUDA CC. See the arch feature flags for portable ways to query
< feature caps.*/
pub major: ::core::ffi::c_int,
/**< Minor compute capability. On HCC, this is an approximation and features may
< differ from CUDA CC. See the arch feature flags for portable ways to query
< feature caps.*/
pub minor: ::core::ffi::c_int,
///< Alignment requirement for textures
pub textureAlignment: usize,
///< Pitch alignment requirement for texture references bound to
pub texturePitchAlignment: usize,
///< Deprecated. Use asyncEngineCount instead
pub deviceOverlap: ::core::ffi::c_int,
///< Number of multi-processors (compute units).
pub multiProcessorCount: ::core::ffi::c_int,
///< Run time limit for kernels executed on the device
pub kernelExecTimeoutEnabled: ::core::ffi::c_int,
///< APU vs dGPU
pub integrated: ::core::ffi::c_int,
///< Check whether HIP can map host memory
pub canMapHostMemory: ::core::ffi::c_int,
///< Compute mode.
pub computeMode: ::core::ffi::c_int,
///< Maximum number of elements in 1D images
pub maxTexture1D: ::core::ffi::c_int,
///< Maximum 1D mipmap texture size
pub maxTexture1DMipmap: ::core::ffi::c_int,
///< Maximum size for 1D textures bound to linear memory
pub maxTexture1DLinear: ::core::ffi::c_int,
///< Maximum dimensions (width, height) of 2D images, in image elements
pub maxTexture2D: [::core::ffi::c_int; 2usize],
///< Maximum number of elements in 2D array mipmap of images
pub maxTexture2DMipmap: [::core::ffi::c_int; 2usize],
///< Maximum 2D tex dimensions if tex are bound to pitched memory
pub maxTexture2DLinear: [::core::ffi::c_int; 3usize],
///< Maximum 2D tex dimensions if gather has to be performed
pub maxTexture2DGather: [::core::ffi::c_int; 2usize],
/**< Maximum dimensions (width, height, depth) of 3D images, in image
< elements*/
pub maxTexture3D: [::core::ffi::c_int; 3usize],
///< Maximum alternate 3D texture dims
pub maxTexture3DAlt: [::core::ffi::c_int; 3usize],
///< Maximum cubemap texture dims
pub maxTextureCubemap: ::core::ffi::c_int,
///< Maximum number of elements in 1D array images
pub maxTexture1DLayered: [::core::ffi::c_int; 2usize],
///< Maximum number of elements in 2D array images
pub maxTexture2DLayered: [::core::ffi::c_int; 3usize],
///< Maximum cubemaps layered texture dims
pub maxTextureCubemapLayered: [::core::ffi::c_int; 2usize],
///< Maximum 1D surface size
pub maxSurface1D: ::core::ffi::c_int,
///< Maximum 2D surface size
pub maxSurface2D: [::core::ffi::c_int; 2usize],
///< Maximum 3D surface size
pub maxSurface3D: [::core::ffi::c_int; 3usize],
///< Maximum 1D layered surface size
pub maxSurface1DLayered: [::core::ffi::c_int; 2usize],
///< Maximum 2D layered surface size
pub maxSurface2DLayered: [::core::ffi::c_int; 3usize],
///< Maximum cubemap surface size
pub maxSurfaceCubemap: ::core::ffi::c_int,
///< Maximum cubemap layered surface size
pub maxSurfaceCubemapLayered: [::core::ffi::c_int; 2usize],
///< Alignment requirement for surface
pub surfaceAlignment: usize,
///< Device can possibly execute multiple kernels concurrently.
pub concurrentKernels: ::core::ffi::c_int,
///< Device has ECC support enabled
pub ECCEnabled: ::core::ffi::c_int,
///< PCI Bus ID.
pub pciBusID: ::core::ffi::c_int,
///< PCI Device ID.
pub pciDeviceID: ::core::ffi::c_int,
///< PCI Domain ID
pub pciDomainID: ::core::ffi::c_int,
///< 1:If device is Tesla device using TCC driver, else 0
pub tccDriver: ::core::ffi::c_int,
///< Number of async engines
pub asyncEngineCount: ::core::ffi::c_int,
///< Does device and host share unified address space
pub unifiedAddressing: ::core::ffi::c_int,
///< Max global memory clock frequency in khz.
pub memoryClockRate: ::core::ffi::c_int,
///< Global memory bus width in bits.
pub memoryBusWidth: ::core::ffi::c_int,
///< L2 cache size.
pub l2CacheSize: ::core::ffi::c_int,
///< Device's max L2 persisting lines in bytes
pub persistingL2CacheMaxSize: ::core::ffi::c_int,
///< Maximum resident threads per multi-processor.
pub maxThreadsPerMultiProcessor: ::core::ffi::c_int,
///< Device supports stream priority
pub streamPrioritiesSupported: ::core::ffi::c_int,
///< Indicates globals are cached in L1
pub globalL1CacheSupported: ::core::ffi::c_int,
///< Locals are cached in L1
pub localL1CacheSupported: ::core::ffi::c_int,
///< Amount of shared memory available per multiprocessor.
pub sharedMemPerMultiprocessor: usize,
///< registers available per multiprocessor
pub regsPerMultiprocessor: ::core::ffi::c_int,
///< Device supports allocating managed memory on this system
pub managedMemory: ::core::ffi::c_int,
///< 1 if device is on a multi-GPU board, 0 if not.
pub isMultiGpuBoard: ::core::ffi::c_int,
///< Unique identifier for a group of devices on same multiboard GPU
pub multiGpuBoardGroupID: ::core::ffi::c_int,
///< Link between host and device supports native atomics
pub hostNativeAtomicSupported: ::core::ffi::c_int,
///< Deprecated. CUDA only.
pub singleToDoublePrecisionPerfRatio: ::core::ffi::c_int,
/**< Device supports coherently accessing pageable memory
< without calling hipHostRegister on it*/
pub pageableMemoryAccess: ::core::ffi::c_int,
/**< Device can coherently access managed memory concurrently with
< the CPU*/
pub concurrentManagedAccess: ::core::ffi::c_int,
///< Is compute preemption supported on the device
pub computePreemptionSupported: ::core::ffi::c_int,
/**< Device can access host registered memory with same
< address as the host*/
pub canUseHostPointerForRegisteredMem: ::core::ffi::c_int,
///< HIP device supports cooperative launch
pub cooperativeLaunch: ::core::ffi::c_int,
/**< HIP device supports cooperative launch on multiple
< devices*/
pub cooperativeMultiDeviceLaunch: ::core::ffi::c_int,
///< Per device max shared mem per block usable by special opt in
pub sharedMemPerBlockOptin: usize,
/**< Device accesses pageable memory via the host's
< page tables*/
pub pageableMemoryAccessUsesHostPageTables: ::core::ffi::c_int,
/**< Host can directly access managed memory on the device
< without migration*/
pub directManagedMemAccessFromHost: ::core::ffi::c_int,
///< Max number of blocks on CU
pub maxBlocksPerMultiProcessor: ::core::ffi::c_int,
///< Max value of access policy window
pub accessPolicyMaxWindowSize: ::core::ffi::c_int,
///< Shared memory reserved by driver per block
pub reservedSharedMemPerBlock: usize,
///< Device supports hipHostRegister
pub hostRegisterSupported: ::core::ffi::c_int,
///< Indicates if device supports sparse hip arrays
pub sparseHipArraySupported: ::core::ffi::c_int,
/**< Device supports using the hipHostRegisterReadOnly flag
< with hipHostRegister*/
pub hostRegisterReadOnlySupported: ::core::ffi::c_int,
///< Indicates external timeline semaphore support
pub timelineSemaphoreInteropSupported: ::core::ffi::c_int,
///< Indicates if device supports hipMallocAsync and hipMemPool APIs
pub memoryPoolsSupported: ::core::ffi::c_int,
///< Indicates device support of RDMA APIs
pub gpuDirectRDMASupported: ::core::ffi::c_int,
/**< Bitmask to be interpreted according to
< hipFlushGPUDirectRDMAWritesOptions*/
pub gpuDirectRDMAFlushWritesOptions: ::core::ffi::c_uint,
///< value of hipGPUDirectRDMAWritesOrdering
pub gpuDirectRDMAWritesOrdering: ::core::ffi::c_int,
///< Bitmask of handle types support with mempool based IPC
pub memoryPoolSupportedHandleTypes: ::core::ffi::c_uint,
/**< Device supports deferred mapping HIP arrays and HIP
< mipmapped arrays*/
pub deferredMappingHipArraySupported: ::core::ffi::c_int,
///< Device supports IPC events
pub ipcEventSupported: ::core::ffi::c_int,
///< Device supports cluster launch
pub clusterLaunch: ::core::ffi::c_int,
///< Indicates device supports unified function pointers
pub unifiedFunctionPointers: ::core::ffi::c_int,
///< CUDA Reserved.
pub reserved: [::core::ffi::c_int; 63usize],
///< Reserved for adding new entries for HIP/CUDA.
pub hipReserved: [::core::ffi::c_int; 32usize],
///< AMD GCN Arch Name. HIP Only.
pub gcnArchName: [::core::ffi::c_char; 256usize],
///< Maximum Shared Memory Per CU. HIP Only.
pub maxSharedMemoryPerMultiProcessor: usize,
/**< Frequency in khz of the timer used by the device-side "clock*"
< instructions. New for HIP.*/
pub clockInstructionRate: ::core::ffi::c_int,
///< Architectural feature flags. New for HIP.
pub arch: hipDeviceArch_t,
///< Address of HDP_MEM_COHERENCY_FLUSH_CNTL register
pub hdpMemFlushCntl: *mut ::core::ffi::c_uint,
///< Address of HDP_REG_COHERENCY_FLUSH_CNTL register
pub hdpRegFlushCntl: *mut ::core::ffi::c_uint,
/**< HIP device supports cooperative launch on
< multiple*/
pub cooperativeMultiDeviceUnmatchedFunc: ::core::ffi::c_int,
/**< HIP device supports cooperative launch on
< multiple*/
pub cooperativeMultiDeviceUnmatchedGridDim: ::core::ffi::c_int,
/**< HIP device supports cooperative launch on
< multiple*/
pub cooperativeMultiDeviceUnmatchedBlockDim: ::core::ffi::c_int,
/**< HIP device supports cooperative launch on
< multiple*/
pub cooperativeMultiDeviceUnmatchedSharedMem: ::core::ffi::c_int,
///< 1: if it is a large PCI bar device, else 0
pub isLargeBar: ::core::ffi::c_int,
///< Revision of the GPU in this device
pub asicRevision: ::core::ffi::c_int,
}
/// Type of the `arch` field of `hipDeviceProp_tR0600`. The flags are kept as
/// an opaque 3-byte bitfield followed by one padding byte, so the struct is
/// 4 bytes with 4-byte alignment. No individual flag is decoded in this module.
#[allow(non_snake_case, non_camel_case_types)]
#[repr(C)]
#[repr(align(4))]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct hipDeviceArch_t {
// Zero-sized marker; occupies no space.
pub _bitfield_align_1: [u8; 0],
// Raw bitfield storage (3 bytes).
pub _bitfield_1: __BindgenBitfieldUnit<[u8; 3usize]>,
// Explicit trailing padding byte.
pub __bindgen_padding_0: u8,
}
/// Opaque `#[repr(C)]` wrapper around the raw storage of a bitfield.
/// The storage is private and no accessors are defined in this module, so the
/// bits are only carried along, never read or written individually.
/// NOTE(review): the name suggests this mirrors bindgen's helper type — confirm.
#[repr(C)]
#[derive(Copy, Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct __BindgenBitfieldUnit<Storage> {
storage: Storage,
}
/// 16-byte identifier, laid out as a plain `#[repr(C)]` array of `c_char`.
/// Used as the type of the `uuid` field of `hipDeviceProp_tR0600` (through the
/// `hipUUID` alias).
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct hipUUID_t {
pub bytes: [::core::ffi::c_char; 16usize],
}
/// Alias for `hipUUID_t`.
#[allow(non_camel_case_types)]
pub type hipUUID = hipUUID_t;
}

View file

@ -45,13 +45,14 @@ pub struct FatbinHeader {
}
#[repr(C)]
#[derive(Debug)]
pub struct FatbinFileHeader {
pub kind: c_ushort,
pub version: c_ushort,
pub header_size: c_uint,
pub padded_payload_size: c_uint,
pub unknown0: c_uint, // check if it's written into separately
pub payload_size: c_uint,
pub unknown0: c_uint, // check if it's written into separately
pub compressed_size: c_uint,
pub unknown1: c_uint,
pub unknown2: c_uint,
pub sm_version: c_uint,
@ -63,6 +64,7 @@ pub struct FatbinFileHeader {
}
bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FatbinFileHeaderFlags: u64 {
const Is64Bit = 0x0000000000000001;
const Debug = 0x0000000000000002;
@ -77,13 +79,13 @@ bitflags! {
}
impl FatbincWrapper {
pub const MAGIC: c_uint = 0x466243B1;
pub const MAGIC: [u8; 4] = 0x466243B1u32.to_le_bytes();
pub const VERSION_V1: c_uint = 0x1;
pub const VERSION_V2: c_uint = 0x2;
}
impl FatbinHeader {
pub const MAGIC: c_uint = 0xBA55ED50;
pub const MAGIC: [u8; 4] = 0xBA55ED50u32.to_le_bytes();
pub const VERSION: c_ushort = 0x01;
}

View file

@ -1,6 +1,6 @@
// This file contains a higher-level interface for parsing fatbins
use std::ptr;
use std::{borrow::Cow, ptr};
use cuda_types::dark_api::*;
@ -43,7 +43,11 @@ pub fn parse_fatbinc_wrapper<T: Sized>(ptr: &*const T) -> Result<&FatbincWrapper
unsafe { ptr.cast::<FatbincWrapper>().as_ref() }
.ok_or(ParseError::NullPointer("FatbincWrapper"))
.and_then(|ptr| {
ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [FatbincWrapper::MAGIC])?;
ParseError::check_fields(
"FATBINC_MAGIC",
ptr.magic,
[u32::from_ne_bytes(FatbincWrapper::MAGIC)],
)?;
ParseError::check_fields(
"FATBINC_VERSION",
ptr.version,
@ -57,12 +61,17 @@ fn parse_fatbin_header<T: Sized>(ptr: &*const T) -> Result<&FatbinHeader, ParseE
unsafe { ptr.cast::<FatbinHeader>().as_ref() }
.ok_or(ParseError::NullPointer("FatbinHeader"))
.and_then(|ptr| {
ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [FatbinHeader::MAGIC])?;
ParseError::check_fields(
"FATBIN_MAGIC",
ptr.magic,
[u32::from_ne_bytes(FatbinHeader::MAGIC)],
)?;
ParseError::check_fields("FATBIN_VERSION", ptr.version, [FatbinHeader::VERSION])?;
Ok(ptr)
})
}
#[derive(Clone, Copy)]
pub struct Fatbin<'a> {
pub wrapper: &'a FatbincWrapper,
}
@ -75,7 +84,7 @@ impl<'a> Fatbin<'a> {
Ok(Fatbin { wrapper })
}
pub fn get_submodules(&self) -> Result<FatbinIter<'a>, FatbinError> {
pub fn get_submodules(self) -> Result<FatbinIter<'a>, FatbinError> {
match self.wrapper.version {
FatbincWrapper::VERSION_V2 => Ok(FatbinIter::V2(FatbinSubmoduleIterator {
fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void,
@ -97,6 +106,10 @@ impl<'a> Fatbin<'a> {
}
}
unsafe impl Send for Fatbin<'static> {}
unsafe impl Sync for Fatbin<'static> {}
#[derive(Clone, Copy)]
pub struct FatbinSubmodule<'a> {
pub header: &'a FatbinHeader, // TODO: maybe make private
}
@ -117,9 +130,13 @@ pub enum FatbinIter<'a> {
}
impl<'a> FatbinIter<'a> {
pub fn next(&mut self) -> Result<Option<FatbinSubmodule<'a>>, ParseError> {
/// Reports whether this iterator is the `V2` variant (`true`) or the `V1`
/// variant (`false`).
pub fn multi_module(&self) -> bool {
    match self {
        FatbinIter::V1(_) => false,
        FatbinIter::V2(_) => true,
    }
}
pub fn next(&mut self) -> Option<Result<FatbinSubmodule<'a>, ParseError>> {
match self {
FatbinIter::V1(opt) => Ok(opt.take()),
FatbinIter::V1(opt) => Ok(opt.take()).transpose(),
FatbinIter::V2(iter) => unsafe { iter.next() },
}
}
@ -131,19 +148,23 @@ pub struct FatbinSubmoduleIterator<'a> {
}
impl<'a> FatbinSubmoduleIterator<'a> {
pub unsafe fn next(&mut self) -> Result<Option<FatbinSubmodule<'a>>, ParseError> {
pub unsafe fn next(&mut self) -> Option<Result<FatbinSubmodule<'a>, ParseError>> {
if *self.fatbins != ptr::null() {
let header = *self.fatbins as *const FatbinHeader;
self.fatbins = self.fatbins.add(1);
Ok(Some(FatbinSubmodule::new(header.as_ref().ok_or(
ParseError::NullPointer("FatbinSubmoduleIterator"),
)?)))
Some(
header
.as_ref()
.ok_or(ParseError::NullPointer("FatbinSubmoduleIterator"))
.map(FatbinSubmodule::new),
)
} else {
Ok(None)
None
}
}
}
#[derive(Clone, Copy)]
pub struct FatbinFile<'a> {
pub header: &'a FatbinFileHeader,
}
@ -153,35 +174,60 @@ impl<'a> FatbinFile<'a> {
Self { header }
}
pub unsafe fn get_payload(&'a self) -> &'a [u8] {
/// Short label for the file's `kind` header field: `"ptx"` for
/// `HEADER_KIND_PTX`, `"elf"` for `HEADER_KIND_ELF`, and `"bin"` for every
/// other value.
pub fn kind(&self) -> &'static str {
    let kind = self.header.kind;
    if kind == FatbinFileHeader::HEADER_KIND_PTX {
        "ptx"
    } else if kind == FatbinFileHeader::HEADER_KIND_ELF {
        "elf"
    } else {
        "bin"
    }
}
pub unsafe fn get_non_compressed_payload(self) -> &'a [u8] {
let start = std::ptr::from_ref(self.header)
.cast::<u8>()
.add(self.header.header_size as usize);
std::slice::from_raw_parts(start, self.header.payload_size as usize)
}
pub unsafe fn decompress(&'a self) -> Result<Vec<u8>, FatbinError> {
pub unsafe fn get_compressed_payload(self) -> &'a [u8] {
let start = std::ptr::from_ref(self.header)
.cast::<u8>()
.add(self.header.header_size as usize);
std::slice::from_raw_parts(start, self.header.compressed_size as usize)
}
pub unsafe fn get_or_decompress_content(self) -> Result<Cow<'a, [u8]>, FatbinError> {
let mut payload = if self
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedLz4)
{
unsafe { decompress_lz4(self) }?
Cow::Owned(unsafe { decompress_lz4(self) }?)
} else if self
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedZstd)
{
unsafe { decompress_zstd(self) }?
Cow::Owned(unsafe { decompress_zstd(self) }?)
} else {
unsafe { self.get_payload().to_vec() }
Cow::Borrowed(unsafe { self.get_non_compressed_payload() })
};
while payload.last() == Some(&0) {
// remove trailing zeros
payload.pop();
// Remove trailing zeros
if self.header.kind == FatbinFileHeader::HEADER_KIND_PTX {
match payload {
Cow::Borrowed(ref mut slice) => {
while slice.last() == Some(&0) {
*slice = &slice[..slice.len() - 1];
}
}
Cow::Owned(ref mut vec) => {
while vec.last() == Some(&0) {
vec.pop();
}
}
}
}
Ok(payload)
}
}
@ -200,35 +246,40 @@ impl<'a> FatbinFileIterator<'a> {
Self { file_buffer }
}
pub unsafe fn next(&mut self) -> Result<Option<FatbinFile<'a>>, ParseError> {
pub unsafe fn next(&mut self) -> Option<Result<FatbinFile<'a>, ParseError>> {
if self.file_buffer.len() < std::mem::size_of::<FatbinFileHeader>() {
return Ok(None);
return None;
}
let this = &*self.file_buffer.as_ptr().cast::<FatbinFileHeader>();
let next_element = self
.file_buffer
.split_at_checked(this.header_size as usize + this.padded_payload_size as usize)
.split_at_checked(
this.header_size as usize
+ u32::max(this.payload_size, this.compressed_size) as usize,
)
.map(|(_, next)| next);
self.file_buffer = next_element.unwrap_or(&[]);
Some(
ParseError::check_fields(
"FATBIN_FILE_HEADER_VERSION_CURRENT",
this.version,
[FatbinFileHeader::HEADER_VERSION_CURRENT],
)?;
Ok(Some(FatbinFile::new(this)))
)
.map(|_| FatbinFile::new(this)),
)
}
}
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
pub unsafe fn decompress_lz4(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
let decompressed_size = usize::max(1024, file.header.uncompressed_payload as usize);
let mut decompressed_vec = vec![0u8; decompressed_size];
loop {
match lz4_sys::LZ4_decompress_safe(
file.get_payload().as_ptr() as *const _,
file.get_compressed_payload().as_ptr() as *const _,
decompressed_vec.as_mut_ptr() as *mut _,
file.header.payload_size as _,
file.header.compressed_size as _,
decompressed_vec.len() as _,
) {
error if error < 0 => {
@ -246,9 +297,9 @@ pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result<Vec<u8>, FatbinError>
}
}
pub unsafe fn decompress_zstd(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
pub unsafe fn decompress_zstd(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize);
let payload = file.get_payload();
let payload = file.get_compressed_payload();
match zstd_safe::decompress(&mut result, payload) {
Ok(actual_size) => {
result.truncate(actual_size);

View file

@ -11,7 +11,7 @@ ptx_parser = { path = "../ptx_parser" }
llvm_zluda = { path = "../llvm_zluda" }
quick-error = "1.2"
thiserror = "1.0"
bit-vec = "0.6"
bit-vec = "0.8"
half ="1.6"
bitflags = "1.2"
rustc-hash = "2.0.0"

View file

@ -1519,6 +1519,17 @@ pub enum Directive<'input, O: Operand> {
/// Result of parsing a whole module: its version, the directives that were
/// parsed successfully, and a count of the ones that were not.
pub struct Module<'input> {
// Version pair — presumably (major, minor) of the `.version` line; confirm
// against the `version` parser.
pub version: (u8, u8),
// Directives that were produced by the parser.
pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
// Number of directives for which the parser yielded nothing. It is
// `usize::MAX` in the placeholder returned by `Module::empty()`.
pub invalid_directives: usize,
}
impl Module<'_> {
pub fn empty() -> Self {
Module {
version: (1, 0),
directives: Vec::new(),
invalid_directives: usize::MAX,
}
}
}
#[derive(Copy, Clone)]

View file

@ -3,8 +3,8 @@ use logos::Logos;
use ptx_parser_macros::derive_parser;
use rustc_hash::FxHashMap;
use std::fmt::Debug;
use std::iter;
use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
use std::{iter, usize};
use winnow::ascii::dec_uint;
use winnow::combinator::*;
use winnow::error::{ErrMode, ErrorKind};
@ -414,7 +414,7 @@ pub fn parse_module_checked<'input>(
.map_err(|err| PtxError::Parser(err.into_inner()))
};
match parse_result {
Ok(result) if errors.is_empty() => Ok(result),
Ok(result) if errors.is_empty() && result.invalid_directives == 0 => Ok(result),
Ok(_) => Err(errors),
Err(err) => {
errors.push(err);
@ -423,6 +423,39 @@ pub fn parse_module_checked<'input>(
}
}
/// Lenient parsing entry point that always returns a module.
///
/// The whole input is lexed first. If any token fails to lex, or if the module
/// parser itself returns an error, the result is `ast::Module::empty()`.
/// Otherwise the module produced by the parser is returned as-is, even when
/// its `invalid_directives` count is non-zero. Errors gathered along the way
/// are discarded.
pub fn parse_module_unchecked<'input>(text: &'input str) -> ast::Module<'input> {
    let mut errors = Vec::new();
    let mut tokens = Vec::new();
    let mut lexer = Token::lexer(text);
    while let Some(lexed) = lexer.next() {
        let span = lexer.span();
        match lexed {
            Ok(token) => tokens.push((token, span)),
            Err(mut err) => {
                err.0 = span;
                errors.push(PtxError::from(err));
            }
        }
    }
    // A single lexing failure makes the token stream unusable.
    if !errors.is_empty() {
        return ast::Module::empty();
    }
    let parser = PtxParser {
        state: PtxParserState::new(text, &mut errors),
        input: &tokens[..],
    };
    match module.parse(parser) {
        Ok(parsed) => parsed,
        Err(_) => ast::Module::empty(),
    }
}
fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {
trace(
"module",
@ -430,13 +463,16 @@ fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module
version,
target,
opt(address_size),
repeat_without_none(directive),
repeat_without_none_and_count(directive),
eof,
)
.map(|(version, _, _, directives, _)| ast::Module {
.map(
|(version, _, _, (directives, invalid_directives), _)| ast::Module {
version,
directives,
}),
invalid_directives,
},
),
)
.parse_next(stream)
}
@ -471,7 +507,8 @@ fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option<char>)> {
fn directive<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<Option<ast::Directive<'input, ast::ParsedOperand<&'input str>>>> {
trace(
let errors = stream.state.errors.len();
let directive = trace(
"directive",
with_recovery(
alt((
@ -501,7 +538,11 @@ fn directive<'a, 'input>(
)
.map(Option::flatten),
)
.parse_next(stream)
.parse_next(stream)?;
if errors != stream.state.errors.len() {
return Ok(None);
}
Ok(directive)
}
fn module_variable<'a, 'input>(
@ -1230,6 +1271,25 @@ fn repeat_without_none<Input: Stream, Output, Error: ParserError<Input>>(
)
}
/// Applies `parser` zero or more times. Every `Some` output is collected, in
/// order, into the returned `Vec`; every `None` output is only counted. The
/// combined parser yields `(collected_items, none_count)`.
fn repeat_without_none_and_count<Input: Stream, Output, Error: ParserError<Input>>(
    parser: impl Parser<Input, Option<Output>, Error>,
) -> impl Parser<Input, (Vec<Output>, usize), Error> {
    let collect_and_count = repeat(0.., parser).fold(
        || (Vec::new(), 0),
        |(mut kept, skipped): (Vec<_>, usize), maybe_item| match maybe_item {
            Some(item) => {
                kept.push(item);
                (kept, skipped)
            }
            None => (kept, skipped + 1),
        },
    );
    trace("repeat_without_none_and_count", collect_and_count)
}
fn ident_literal<
'a,
'input,
@ -3789,6 +3849,7 @@ derive_parser!(
#[cfg(test)]
mod tests {
use crate::first_optional;
use crate::module;
use crate::parse_module_checked;
use crate::section;
use crate::PtxError;
@ -4086,4 +4147,38 @@ mod tests {
assert!(section.parse(stream).is_ok());
assert_eq!(errors.len(), 0);
}
// Runs the `module` parser directly on input that mixes well-formed and
// malformed directives, and checks that parsing still succeeds while the
// malformed directives are counted instead of returned.
#[test]
fn report_unknown_directives() {
let text = "
.version 6.5
.target sm_30
.address_size 64
.global .b32 global[4] = { unknown (1), 2, 3, 4};
.visible .entry func1()
{
st.u64 [out_addr], temp2;
ret;
}
.visible .entry func1()
{
broken_instruction;
ret;
}";
// Every token is paired with a default span; lexing itself must succeed.
let tokens = Token::lexer(text)
.map(|t| t.map(|t| (t, Span::default())))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let mut errors = Vec::new();
let stream = super::PtxParser {
input: &tokens[..],
state: PtxParserState::new(text, &mut errors),
};
let module = module.parse(stream).unwrap();
// NOTE(review): presumably the surviving directive is the first `func1`, and
// the two counted as invalid are the global with the `unknown (1)`
// initializer and the kernel containing `broken_instruction` — confirm.
assert_eq!(module.directives.len(), 1);
assert_eq!(module.invalid_directives, 2);
}
}

View file

@ -1,7 +1,7 @@
use vergen_gix::{Emitter, GixBuilder};
fn main() {
let git = GixBuilder::all_git().unwrap();
let git = GixBuilder::default().sha(false).build().unwrap();
Emitter::default()
.add_instructions(&git)
.unwrap()

View file

@ -102,7 +102,11 @@ pub(crate) fn get_current_device() -> Result<hipDevice_t, CUerror> {
stack
.try_borrow()
.map_err(|_| CUerror::UNKNOWN)
.and_then(|s| s.last().ok_or(CUerror::UNKNOWN).map(|(_, dev)| *dev))
.and_then(|s| {
s.last()
.ok_or(CUerror::INVALID_CONTEXT)
.map(|(_, dev)| *dev)
})
})
}

View file

@ -517,3 +517,23 @@ pub(crate) unsafe fn primary_context_get_state(
*active_out = active;
Ok(())
}
#[cfg(test)]
mod tests {
use crate::tests::CudaApi;
use cuda_macros::test_cuda;
use std::{mem, ptr};
// Retaining the primary context must not change the current context:
// `cuCtxGetCurrent` has to report a null context both before and after
// `cuDevicePrimaryCtxRetain`, while the retained context itself is non-null.
#[test_cuda]
unsafe fn primary_ctx_retain_does_not_make_it_active(api: impl CudaApi) {
api.cuInit(0);
// Right after initialization there is no current context.
let mut current_ctx = mem::zeroed();
api.cuCtxGetCurrent(&mut current_ctx);
assert_eq!(current_ctx.0, ptr::null_mut());
// Retain the primary context of device 0.
let mut primary_ctx = mem::zeroed();
api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0);
assert_ne!(primary_ctx.0, ptr::null_mut());
// The current context is still null after the retain.
api.cuCtxGetCurrent(&mut current_ctx);
assert_eq!(current_ctx.0, ptr::null_mut());
}
}

View file

@ -1,4 +1,4 @@
use crate::r#impl::{context, device, function};
use crate::r#impl::{self, context, device, function};
use comgr::Comgr;
use cuda_types::cuda::*;
use hip_runtime_sys::*;
@ -154,7 +154,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_module: *mut cuda_types::cuda::CUmodule,
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
) -> cuda_types::cuda::CUresult {
todo!()
Err(r#impl::unimplemented())
}
unsafe extern "system" fn cudart_interface_fn2(
@ -176,11 +176,11 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_arg4: *mut std::ffi::c_void,
_arg5: u32,
) -> cuda_types::cuda::CUresult {
todo!()
Err(r#impl::unimplemented())
}
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
todo!()
Err(r#impl::unimplemented())
}
unsafe extern "system" fn get_module_from_cubin_ext2(
@ -190,7 +190,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_arg4: *mut std::ffi::c_void,
_arg5: u32,
) -> cuda_types::cuda::CUresult {
todo!()
Err(r#impl::unimplemented())
}
unsafe extern "system" fn get_unknown_buffer1(
@ -276,7 +276,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_flags: ::std::os::raw::c_uint,
_dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult {
todo!()
Err(r#impl::unimplemented())
}
unsafe extern "system" fn heap_alloc(
@ -284,14 +284,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_arg2: usize,
_arg3: usize,
) -> cuda_types::cuda::CUresult {
todo!()
Err(r#impl::unimplemented())
}
unsafe extern "system" fn heap_free(
_heap_alloc_record_ptr: *const std::ffi::c_void,
_arg2: *mut usize,
) -> cuda_types::cuda::CUresult {
todo!()
Err(r#impl::unimplemented())
}
unsafe extern "system" fn device_get_attribute_ext(
@ -300,14 +300,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
_unknown: std::ffi::c_int,
_result: *mut [usize; 2],
) -> cuda_types::cuda::CUresult {
todo!()
Err(r#impl::unimplemented())
}
unsafe extern "system" fn device_get_something(
_result: *mut std::ffi::c_uchar,
_dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult {
todo!()
Err(r#impl::unimplemented())
}
unsafe extern "system" fn integrity_check(

View file

@ -1,10 +1,38 @@
use super::module;
use crate::r#impl::{context, driver, module};
use cuda_types::cuda::*;
use hip_runtime_sys::*;
use zluda_common::ZludaObject;
use std::{ffi::c_void, sync::OnceLock};
use zluda_common::{CodeLibraryRef, ZludaObject};
pub(crate) struct Library {
base: hipModule_t,
data: LibraryData,
modules: Vec<OnceLock<Result<hipModule_t, CUerror>>>,
}
impl Library {
    /// Returns the HIP module of this library for the device slot `device`.
    ///
    /// Every slot in `self.modules` is a `OnceLock`: the first call for a slot runs the
    /// loader and stores its outcome (module or error); that stored outcome is copied out
    /// on this and every later call.
    ///
    /// # Errors
    /// * `CUerror::INVALID_DEVICE` when `device` has no slot in `self.modules`.
    /// * `CUerror::NOT_SUPPORTED` when the library holds `LibraryData::Eager`.
    /// * Whatever `module::load_hip_module` reports for `LibraryData::Lazy`.
    pub(crate) fn get_module_for_device(&self, device: usize) -> Result<hipModule_t, CUerror> {
        let Some(slot) = self.modules.get(device) else {
            return Err(CUerror::INVALID_DEVICE);
        };
        let load = || match self.data {
            LibraryData::Eager(()) => Err(CUerror::NOT_SUPPORTED),
            LibraryData::Lazy(code) => module::load_hip_module(code),
        };
        *slot.get_or_init(load)
    }
}
/// How the code behind a `Library` is held until a module is requested.
enum LibraryData {
/// Reference to the caller's code image; `Library::get_module_for_device` hands it to
/// `module::load_hip_module` the first time a device slot is requested.
Lazy(CodeLibraryRef<'static>),
/// Placeholder that carries no data; `Library::get_module_for_device` answers
/// `CUerror::NOT_SUPPORTED` for it.
Eager(()),
}
impl LibraryData {
    /// Builds the library payload for the code image at `ptr`.
    ///
    /// When `static_lifetime` is `false` the pointer is not inspected and the `Eager`
    /// placeholder is returned. Otherwise the image is classified with
    /// `CodeLibraryRef::try_load` and kept by reference as `Lazy`; a classification failure
    /// is reported as `CUerror::INVALID_VALUE`.
    ///
    /// # Safety
    /// With `static_lifetime == true`, `ptr` must satisfy the requirements of
    /// `CodeLibraryRef::try_load`, and the resulting reference is typed `'static`.
    unsafe fn new(ptr: *mut c_void, static_lifetime: bool) -> Result<Self, CUerror> {
        if !static_lifetime {
            return Ok(LibraryData::Eager(()));
        }
        match CodeLibraryRef::try_load(ptr) {
            Ok(code) => Ok(LibraryData::Lazy(code)),
            Err(_) => Err(CUerror::INVALID_VALUE),
        }
    }
}
impl ZludaObject for Library {
@ -14,34 +42,89 @@ impl ZludaObject for Library {
type CudaHandle = CUlibrary;
fn drop_checked(&mut self) -> CUresult {
// TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice.
unsafe { hipModuleUnload(self.base) }?;
// TODO: implement unloading
// TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice
Ok(())
}
}
/// This implementation simply loads the code as a HIP module for now. The various JIT and library options are ignored.
pub(crate) fn load_data(
library: &mut CUlibrary,
pub(crate) unsafe fn load_data(
result: &mut CUlibrary,
code: *const ::core::ffi::c_void,
_jit_options: &mut CUjit_option,
_jit_options_values: &mut *mut ::core::ffi::c_void,
_jit_options: Option<&mut CUjit_option>,
_jit_options_values: Option<&mut *mut ::core::ffi::c_void>,
_num_jit_options: ::core::ffi::c_uint,
_library_options: &mut CUlibraryOption,
_library_option_values: &mut *mut ::core::ffi::c_void,
_num_library_options: ::core::ffi::c_uint,
library_options: Option<&mut CUlibraryOption>,
library_option_values: Option<&mut *mut ::core::ffi::c_void>,
num_library_options: ::core::ffi::c_uint,
) -> CUresult {
let hip_module = module::load_hip_module(code)?;
*library = Library { base: hip_module }.wrap();
let global_state = driver::global_state()?;
let options =
LibraryOptions::load(library_options, library_option_values, num_library_options)?;
let library = Library {
data: LibraryData::new(code as *mut c_void, options.preserve_binary)?,
modules: vec![OnceLock::new(); global_state.devices.len()],
};
*result = library.wrap();
Ok(())
}
/// Library options decoded from the parallel option/value arrays given to `load_data`.
struct LibraryOptions {
    /// Value supplied for `CU_LIBRARY_BINARY_IS_PRESERVED`; `false` when the option is absent.
    preserve_binary: bool,
}

impl LibraryOptions {
    /// Decodes `num_library_options` option/value pairs.
    ///
    /// With a count of zero the pointers are ignored and every option keeps its default.
    /// With a non-zero count both arrays must be present, otherwise
    /// `CUerror::INVALID_VALUE` is returned. Any option other than
    /// `CU_LIBRARY_BINARY_IS_PRESERVED` is rejected with `CUerror::NOT_SUPPORTED`.
    ///
    /// # Safety
    /// When `num_library_options` is non-zero, each reference must point at the first of
    /// `num_library_options` readable elements, and every value paired with
    /// `CU_LIBRARY_BINARY_IS_PRESERVED` must point at a readable byte.
    unsafe fn load(
        library_options: Option<&mut CUlibraryOption>,
        library_option_values: Option<&mut *mut ::core::ffi::c_void>,
        num_library_options: ::core::ffi::c_uint,
    ) -> Result<Self, CUerror> {
        if num_library_options == 0 {
            return Ok(LibraryOptions {
                preserve_binary: false,
            });
        }
        let (Some(first_option), Some(first_value)) = (library_options, library_option_values)
        else {
            return Err(CUerror::INVALID_VALUE);
        };
        let count = num_library_options as usize;
        let options = std::slice::from_raw_parts(first_option, count);
        let values = std::slice::from_raw_parts(first_value, count);
        let mut preserve_binary = false;
        for (&option, value) in options.iter().zip(values) {
            match option {
                CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => {
                    // The flag is written by foreign code. Read it as a raw byte and test for
                    // non-zero: materialising a Rust `bool` from a byte other than 0 or 1
                    // would be undefined behaviour.
                    preserve_binary = *value.cast::<u8>() != 0;
                }
                _ => return Err(CUerror::NOT_SUPPORTED),
            }
        }
        Ok(LibraryOptions { preserve_binary })
    }
}
/// Unloads `library` by forwarding the handle to `zluda_common::drop_checked::<Library>`
/// and returning its result unchanged.
pub(crate) unsafe fn unload(library: CUlibrary) -> CUresult {
zluda_common::drop_checked::<Library>(library)
}
pub(crate) unsafe fn get_module(out: &mut CUmodule, library: &Library) -> CUresult {
*out = module::Module { base: library.base }.wrap();
let device = context::get_current_device()?;
// TODO: lifetime is very wrong here
let library = module::Module {
base: library.get_module_for_device(device as usize)?,
};
*out = library.wrap();
Ok(())
}
@ -49,8 +132,11 @@ pub(crate) unsafe fn get_kernel(
kernel: *mut hipFunction_t,
library: &Library,
name: *const ::core::ffi::c_char,
) -> hipError_t {
hipModuleGetFunction(kernel, library.base, name)
) -> CUresult {
let device = context::get_current_device()?;
let module = library.get_module_for_device(device as usize)?;
hipModuleGetFunction(kernel, module, name)?;
Ok(())
}
pub(crate) unsafe fn get_global(
@ -58,6 +144,59 @@ pub(crate) unsafe fn get_global(
bytes: *mut usize,
library: &Library,
name: *const ::core::ffi::c_char,
) -> hipError_t {
hipModuleGetGlobal(dptr, bytes, library.base, name)
) -> CUresult {
let device = context::get_current_device()?;
let module = library.get_module_for_device(device as usize)?;
hipModuleGetGlobal(dptr, bytes, module, name)?;
Ok(())
}
// Tests for library loading, run through the `test_cuda` attribute macro.
#[cfg(test)]
mod tests {
use crate::tests::CudaApi;
use cuda_macros::test_cuda;
use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts};
use std::{
ffi::{c_void, CStr},
mem, ptr,
};
// With no current context, `cuCtxGetDevice` and `cuModuleLoadData` must fail with
// ERROR_INVALID_CONTEXT, while `cuLibraryLoadData` with CU_LIBRARY_BINARY_IS_PRESERVED
// still succeeds; only the later `cuLibraryGetModule` reports ERROR_INVALID_CONTEXT.
#[test_cuda]
unsafe fn library_loads_without_context(api: impl CudaApi) {
static PTX: &'static CStr = c"
.version 7.0
.target sm_70
.address_size 64
.visible .entry foobar() {
ret;
}
";
api.cuInit(0);
let mut device = mem::zeroed();
assert_eq!(
CUresult::ERROR_INVALID_CONTEXT,
api.cuCtxGetDevice_unchecked(&mut device)
);
let mut module = mem::zeroed();
assert_eq!(
CUresult::ERROR_INVALID_CONTEXT,
api.cuModuleLoadData_unchecked(&mut module, PTX.as_ptr().cast())
);
let mut library = mem::zeroed();
// One library option is passed: the option array and a value array whose single entry
// points at a `true` flag. Both temporaries live until the end of this statement.
api.cuLibraryLoadData(
&mut library,
PTX.as_ptr().cast(),
ptr::null_mut(),
ptr::null_mut(),
0,
[CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(),
[(&true as *const bool) as *mut c_void].as_mut_ptr(),
1,
);
assert_eq!(
CUresult::ERROR_INVALID_CONTEXT,
api.cuLibraryGetModule_unchecked(&mut mem::zeroed(), library)
);
}
}

View file

@ -15,13 +15,13 @@ pub(super) mod pointer;
pub(super) mod stream;
#[cfg(debug_assertions)]
pub(crate) fn unimplemented() -> CUresult {
pub(crate) fn unimplemented() -> CUerror {
unimplemented!()
}
#[cfg(not(debug_assertions))]
pub(crate) fn unimplemented() -> CUresult {
CUresult::ERROR_NOT_SUPPORTED
pub(crate) fn unimplemented() -> CUerror {
CUerror::NOT_SUPPORTED
}
from_cuda_object!(module::Module, context::Context, library::Library);

View file

@ -1,12 +1,8 @@
use super::driver;
use cuda_types::{
cuda::*,
dark_api::{FatbinFileHeader, FatbincWrapper},
};
use dark_api::fatbin::Fatbin;
use cuda_types::{cuda::*, dark_api::FatbinFileHeader};
use hip_runtime_sys::*;
use std::{ffi::CStr, mem};
use zluda_common::ZludaObject;
use std::{borrow::Cow, ffi::CStr, mem};
use zluda_common::{CodeLibraryRef, CodeModuleRef, ZludaObject};
pub(crate) struct Module {
pub(crate) base: hipModule_t,
@ -24,47 +20,45 @@ impl ZludaObject for Module {
}
}
fn get_ptx_from_wrapped_fatbin(image: *const ::core::ffi::c_void) -> Result<Vec<u8>, CUerror> {
let fatbin = Fatbin::new(&image).map_err(|_| CUerror::UNKNOWN)?;
let mut submodules = fatbin.get_submodules().map_err(|_| CUerror::UNKNOWN)?;
while let Some(current) = submodules.next().map_err(|_| CUerror::UNKNOWN)? {
let mut files = current.get_files();
while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } {
if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX {
let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?;
return Ok(decompressed);
// get_ptx takes an `image` that can be anything we support and returns a
// String containing a ptx extracted from `image`.
fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result<Cow<'a, str>, CUerror> {
let mut ptx_modules = Vec::new();
unsafe {
CodeLibraryRef::iterate_modules(image, |_, module| match module {
Ok(CodeModuleRef::Text(ptx)) => {
ptx_modules.push(Cow::<'a, _>::Borrowed(ptx));
}
Ok(CodeModuleRef::<'a>::File(file)) => {
if file.header.kind != FatbinFileHeader::HEADER_KIND_PTX {
return;
}
if let Ok(text) = file.get_or_decompress_content() {
if let Some(text) = cow_bytes_to_str(text) {
ptx_modules.push(text);
}
}
}
Err(CUerror::NO_BINARY_FOR_GPU)
}
/// get_ptx takes an `image` that can be either a fatbin or a NULL-terminated ptx, and returns a String containing a ptx extracted from `image`.
fn get_ptx(image: *const ::core::ffi::c_void) -> Result<String, CUerror> {
if image.is_null() {
return Err(CUerror::INVALID_VALUE);
}
let ptx = if unsafe { *(image as *const u32) } == FatbincWrapper::MAGIC {
let ptx_bytes = get_ptx_from_wrapped_fatbin(image)?;
std::str::from_utf8(&ptx_bytes)
.map_err(|_| CUerror::UNKNOWN)?
.to_owned()
} else {
unsafe { CStr::from_ptr(image.cast()) }
.to_str()
.map_err(|_| CUerror::INVALID_VALUE)?
.to_owned()
_ => {}
})
};
Ok(ptx)
// TODO: instead of getting first PTX module, try and get the best match
ptx_modules
.into_iter()
.next()
.ok_or(CUerror::NO_BINARY_FOR_GPU)
}
pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModule_t, CUerror> {
/// Reinterprets a byte `Cow` as a string `Cow` without copying.
///
/// A borrowed input stays borrowed and an owned input keeps its allocation.
/// Returns `None` when the bytes are not valid UTF-8.
fn cow_bytes_to_str<'a>(data: Cow<'a, [u8]>) -> Option<Cow<'a, str>> {
    Some(match data {
        Cow::Owned(raw) => Cow::Owned(String::from_utf8(raw).ok()?),
        Cow::Borrowed(raw) => Cow::Borrowed(std::str::from_utf8(raw).ok()?),
    })
}
pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CUerror> {
let global_state = driver::global_state()?;
let text = get_ptx(image)?;
let text = get_ptx(library)?;
let hip_properties = get_hip_properties()?;
let gcn_arch = get_gcn_arch(&hip_properties)?;
let attributes = ptx::Attributes {
@ -139,7 +133,11 @@ fn compile_from_ptx_and_cache(
text: &str,
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
) -> Result<Vec<u8>, CUerror> {
let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
let ast = if cfg!(debug_assertions) {
ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?
} else {
ptx_parser::parse_module_unchecked(text)
};
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode(
comgr,
@ -158,7 +156,9 @@ fn compile_from_ptx_and_cache(
}
pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult {
let hip_module = load_hip_module(image)?;
let library =
unsafe { CodeLibraryRef::try_load(image) }.map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
let hip_module = load_hip_module(library)?;
*module = Module { base: hip_module }.wrap();
Ok(())
}
@ -185,6 +185,6 @@ pub(crate) fn get_global_v2(
}
pub(crate) fn get_loading_mode(mode: &mut cuda_types::cuda::CUmoduleLoadingMode) -> CUresult {
*mode = cuda_types::cuda::CUmoduleLoadingMode::CU_MODULE_EAGER_LOADING;
*mode = cuda_types::cuda::CUmoduleLoadingMode::CU_MODULE_LAZY_LOADING;
Ok(())
}

View file

@ -23,7 +23,7 @@ macro_rules! unimplemented {
#[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
crate::r#impl::unimplemented()
Err(r#impl::unimplemented())
}
)*
};

View file

@ -8,3 +8,4 @@ edition = "2021"
cuda_types = { path = "../cuda_types" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
rocblas-sys = { path = "../ext/rocblas-sys" }
dark_api = { path = "../dark_api" }

View file

@ -1,10 +1,18 @@
use cuda_types::{cublas::*, cublaslt::cublasLtHandle_t, cuda::*, nvml::*};
use cuda_types::{
cublas::*,
cublaslt::cublasLtHandle_t,
cuda::*,
dark_api::{FatbinHeader, FatbincWrapper},
nvml::*,
};
use dark_api::fatbin::{Fatbin, FatbinError, FatbinFile, FatbinSubmodule};
use hip_runtime_sys::*;
use rocblas_sys::*;
use std::{
ffi::CStr,
ffi::{c_void, CStr},
mem::{self, ManuallyDrop, MaybeUninit},
ptr,
str::Utf8Error,
};
pub trait CudaErrorType {
@ -412,3 +420,156 @@ pub fn drop_checked<T: ZludaObject>(handle: T::CudaHandle) -> Result<(), T::Erro
unsafe { ManuallyDrop::drop(&mut wrapped_object) };
underlying_error
}
/*
pub struct CodeModuleRef<'a> {
pub kind: CodeModuleKind,
pub data: &'a [u8],
}
impl<'a> CodeModule<'a> {
/// Interprets `data` as a code module of some kind.
///
/// This does not validate the contents of `data`, it only looks at the headers to determine
/// what kind of data it is.
pub fn parse(data: *mut c_void) -> Result<Self, CUerror> {
if data.len() >= 4 {
let kind = match &data[0..4] {
FatbincWrapper::MAGIC => CodeModuleKind::FatbincWrapper,
FatbinHeader::MAGIC => CodeModuleKind::FatbinHeader,
elf64::header::ELFMAG => CodeModuleKind::Elf,
_ => {
if data.ends_with(&[0]) && data.iter().all(|&c| c != 0) {
CodeModuleKind::Ptx
} else {
CodeModuleKind::ForeignElf
}
}
};
Ok(CodeModule { kind, data })
} else {
Err(CUerror::INVALID_VALUE)
}
}
}
*/
// We receive module as an opaque pointer. We want to handle three different
// lifetime-related scenarios:
// * The module has a 'static lifetime, but we don't want to use it just yet
// (`cuLibraryLoadData` with CU_LIBRARY_BINARY_IS_PRESERVED = 1), we might
// never use it.
// In this case we just keep the void pointer, we can pass it to
// the consuming function later
// * The module has a non-'static lifetime, and we will use it in the future
// (`cuLibraryLoadData` with CU_LIBRARY_BINARY_IS_PRESERVED = 0)
// In this case we need to copy the data into its own buffers
// * The module lifetime is scoped to the current function, e.g. zluda_trace
// might want to parse a module to inspect and save it, or the caller is
// cuModuleLoadData.
// In this case we need to return either the compatible ELF or the
// iterator over `Cow` with decompressed PTX strings
// Even here there are two cases:
// * The consumer is cuModuleLoadData, if it's our ELF then it wants
// to load it directly from the pointer
// * The consumer is zluda_trace, it wants to compute the length of
// the ELF and save it to a file
/// Borrowed view of a code image, tagged with the container format recognised by
/// `CodeLibraryRef::try_load`.
#[derive(Clone, Copy)]
pub enum CodeLibraryRef<'a> {
/// Image that starts with `FatbincWrapper::MAGIC`.
FatbincWrapper(Fatbin<'a>),
/// Image that starts with `FatbinHeader::MAGIC`.
FatbinHeader(FatbinSubmodule<'a>),
/// NUL-terminated UTF-8 text; the fallback when no magic matches.
Text(&'a str),
/// Image that starts with the ELF magic; only a reference to its first byte is kept.
Elf(&'a c_void),
/// Image that starts with the `ar` archive magic; only a reference to its first byte is kept.
Archive(&'a c_void),
}
impl<'a> CodeLibraryRef<'a> {
    const ELFMAG: [u8; 4] = *b"\x7FELF";
    const AR_MAGIC: [u8; 8] = *b"!<arch>\x0A";

    /// Returns `true` when the bytes at `ptr` begin with `magic`.
    ///
    /// Bytes are read one at a time and reading stops at the first mismatch, so a
    /// NUL-terminated text shorter than `magic` is never read past its terminator as long
    /// as `magic` contains no zero byte. `ELFMAG` and `AR_MAGIC` have none.
    /// NOTE(review): assumed to hold for the two fatbin magics as well — confirm.
    unsafe fn has_magic(ptr: *const u8, magic: &[u8]) -> bool {
        for (offset, &expected) in magic.iter().enumerate() {
            if *ptr.add(offset) != expected {
                return false;
            }
        }
        true
    }

    /// Classifies the code image at `ptr` by its leading magic bytes.
    ///
    /// Magics are tried in this order: fatbin wrapper, fatbin header, ELF, `ar` archive.
    /// When none matches, the image is read as a NUL-terminated string and must be valid
    /// UTF-8; otherwise the `Utf8Error` is returned.
    ///
    /// # Safety
    /// `ptr` must be non-null and point either at a complete image of one of the recognised
    /// binary formats (suitably aligned for its header type) or at a NUL-terminated string.
    /// The data must stay valid and unmodified for `'a`.
    pub unsafe fn try_load(ptr: *const c_void) -> Result<Self, Utf8Error> {
        let bytes = ptr.cast::<u8>();
        if Self::has_magic(bytes, &FatbincWrapper::MAGIC) {
            return Ok(Self::FatbincWrapper(Fatbin {
                wrapper: &*(ptr.cast()),
            }));
        }
        if Self::has_magic(bytes, &FatbinHeader::MAGIC) {
            return Ok(Self::FatbinHeader(FatbinSubmodule {
                header: &*(ptr.cast()),
            }));
        }
        if Self::has_magic(bytes, &Self::ELFMAG) {
            return Ok(Self::Elf(&*ptr));
        }
        if Self::has_magic(bytes, &Self::AR_MAGIC) {
            return Ok(Self::Archive(&*ptr));
        }
        CStr::from_ptr(ptr.cast()).to_str().map(Self::Text)
    }

    /// Calls `fn_` once for every module found in this library, or once with an error when
    /// a container cannot be walked.
    ///
    /// The first callback argument is the module's position:
    /// * `None` for text, ELF and archive images, and for a fatbin wrapper whose submodule
    ///   list could not be obtained;
    /// * `Some((file, None))` for files of a bare fatbin header;
    /// * for a fatbin wrapper, either `Some((submodule, Some(file)))` or
    ///   `Some((file, None))`, depending on `multi_module()` (see the note below).
    ///
    /// # Safety
    /// The image this value was created from must still be valid.
    pub unsafe fn iterate_modules(
        self,
        mut fn_: impl FnMut(Option<(usize, Option<usize>)>, Result<CodeModuleRef<'a>, FatbinError>),
    ) {
        match self {
            CodeLibraryRef::FatbincWrapper(fatbin) => {
                let module_iter = fatbin.get_submodules();
                match module_iter {
                    Ok(mut iter) => {
                        // NOTE(review): a wrapper reporting `multi_module()` gets NO outer
                        // index (file positions then restart at 0 for every submodule),
                        // while the other case counts submodules from 0. This reads as
                        // inverted relative to the method name — confirm against
                        // `multi_module()`.
                        let mut module_index = if iter.multi_module() {
                            None
                        } else {
                            Some(0usize)
                        };
                        while let Some(maybe_submodule) = iter.next() {
                            match maybe_submodule {
                                Ok(submodule) => iterate_modules_fatbin_header(
                                    |subindex, module| {
                                        let index = match module_index {
                                            Some(index) => (index, Some(subindex)),
                                            None => (subindex, None),
                                        };
                                        fn_(Some(index), module)
                                    },
                                    submodule,
                                ),
                                Err(err) => fn_(
                                    module_index.map(|module_index| (module_index, None)),
                                    Err(FatbinError::ParseFailure(err)),
                                ),
                            }
                            // Advance the outer index, if one is being tracked at all.
                            module_index = module_index.map(|index| index + 1);
                        }
                    }
                    Err(err) => fn_(None, Err(err)),
                }
            }
            CodeLibraryRef::FatbinHeader(submodule) => iterate_modules_fatbin_header(
                |index, module| fn_(Some((index, None)), module),
                submodule,
            ),
            CodeLibraryRef::Text(text) => fn_(None, Ok(CodeModuleRef::Text(text))),
            CodeLibraryRef::Elf(elf) => fn_(None, Ok(CodeModuleRef::Elf(elf))),
            CodeLibraryRef::Archive(ar) => fn_(None, Ok(CodeModuleRef::Archive(ar))),
        }
    }
}
/// Walks the files of one fatbin submodule and reports each to `fn_` together with its
/// zero-based position.
///
/// A file that fails to parse is reported as `FatbinError::ParseFailure` and still takes
/// up a position.
unsafe fn iterate_modules_fatbin_header<'x>(
    mut fn_: impl FnMut(usize, Result<CodeModuleRef<'x>, FatbinError>),
    submodule: FatbinSubmodule<'x>,
) {
    let mut files = submodule.get_files();
    let mut position = 0;
    while let Some(entry) = files.next() {
        let module = match entry {
            Ok(file) => Ok(CodeModuleRef::File(file)),
            Err(err) => Err(FatbinError::ParseFailure(err)),
        };
        fn_(position, module);
        position += 1;
    }
}
/// A single module passed to the callback of `CodeLibraryRef::iterate_modules`.
#[derive(Clone, Copy)]
pub enum CodeModuleRef<'a> {
/// One file entry found inside a fatbin submodule.
File(FatbinFile<'a>),
/// Text module, forwarded unchanged from `CodeLibraryRef::Text`.
Text(&'a str),
/// Pointer to the start of an ELF image, forwarded from `CodeLibraryRef::Elf`.
Elf(*const c_void),
/// Pointer to the start of an archive image, forwarded from `CodeLibraryRef::Archive`.
Archive(*const c_void),
}

View file

@ -12,6 +12,7 @@ crate-type = ["cdylib"]
ptx = { path = "../ptx" }
ptx_parser = { path = "../ptx_parser" }
zluda_trace_common = { path = "../zluda_trace_common" }
zluda_common = { path = "../zluda_common" }
format = { path = "../format" }
dark_api = { path = "../dark_api" }
regex = "1.4"

View file

@ -1,7 +1,5 @@
use ::dark_api::fatbin::FatbinFileIterator;
use ::dark_api::FnFfi;
use cuda_types::cuda::*;
use cuda_types::dark_api::FatbinHeader;
use dark_api::DarkApiState2;
use log::{CudaFunctionName, ErrorEntry};
use parking_lot::ReentrantMutex;
@ -289,9 +287,7 @@ impl DarkApiTrace {
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
fn_logger.try_(|fn_logger| unsafe {
trace::record_submodules_from_wrapped_fatbin(*module, fatbinc_wrapper, fn_logger, state)
});
state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger)
}
fn get_module_from_cubin_ext1_post(
@ -325,9 +321,7 @@ impl DarkApiTrace {
observed: UInt::U32(arg5),
});
}
fn_logger.try_(|fn_logger| unsafe {
trace::record_submodules_from_wrapped_fatbin(*module, fatbinc_wrapper, fn_logger, state)
});
state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger)
}
fn get_module_from_cubin_ext2_post(
@ -361,18 +355,7 @@ impl DarkApiTrace {
observed: UInt::U32(arg5),
});
}
fn_logger.try_(|fn_logger| unsafe {
trace::record_submodules(
*module,
fn_logger,
state,
FatbinFileIterator::new(
fatbin_header
.as_ref()
.ok_or(ErrorEntry::NullPointer("FatbinHeader"))?,
),
)
});
state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger)
}
}
@ -1324,7 +1307,7 @@ pub(crate) fn cuModuleLoadData_Post(
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
state.record_new_module(unsafe { *module }, raw_image, fn_logger)
state.record_new_library(unsafe { *module }, raw_image, fn_logger)
}
#[allow(non_snake_case)]
@ -1402,19 +1385,7 @@ pub(crate) fn cuModuleLoadFatBinary_Post(
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
fn_logger.try_(|fn_logger| unsafe {
trace::record_submodules(
*module,
fn_logger,
state,
FatbinFileIterator::new(
fatbin_header
.cast::<FatbinHeader>()
.as_ref()
.ok_or(ErrorEntry::NullPointer("FatbinHeader"))?,
),
)
});
state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger)
}
#[allow(non_snake_case)]
@ -1427,16 +1398,7 @@ pub(crate) fn cuLibraryGetModule_Post(
) {
match state.libraries.get(&library).copied() {
None => fn_logger.log(log::ErrorEntry::UnknownLibrary(library)),
Some(code) => {
fn_logger.try_(|fn_logger| unsafe {
trace::record_submodules_from_wrapped_fatbin(
*module,
code.0.cast(),
fn_logger,
state,
)
});
}
Some(code) => state.record_new_library(unsafe { *module }, code.0, fn_logger),
}
}
@ -1459,5 +1421,5 @@ pub(crate) fn cuLibraryLoadData_Post(
.insert(unsafe { *library }, trace::CodePointer(code));
// TODO: this is not correct, but it's enough for now, we just want to
// save the binary to disk
state.record_new_module(unsafe { CUmodule((*library).0.cast()) }, code, fn_logger);
state.record_new_library(unsafe { CUmodule((*library).0.cast()) }, code, fn_logger);
}

View file

@ -1,18 +1,11 @@
use crate::{
log::{self, UInt},
trace, ErrorEntry, FnCallLog, Settings,
};
use cuda_types::{
cuda::*,
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
};
use dark_api::fatbin::{
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
log::{self},
ErrorEntry, FnCallLog, Settings,
};
use cuda_types::cuda::*;
use goblin::{elf, elf32, elf64};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{
borrow::Cow,
ffi::{c_void, CStr, CString},
fs::{self, File},
io::{self, Read, Write},
@ -29,8 +22,7 @@ pub(crate) struct StateTracker {
writer: DumpWriter,
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
saved_modules: FxHashSet<CUmodule>,
module_counter: usize,
submodule_counter: usize,
library_counter: usize,
pub(crate) override_cc: Option<(u32, u32)>,
}
@ -46,8 +38,7 @@ impl StateTracker {
writer: DumpWriter::new(settings.dump_dir.clone()),
libraries: FxHashMap::default(),
saved_modules: FxHashSet::default(),
module_counter: 0,
submodule_counter: 0,
library_counter: 0,
override_cc: settings.override_cc,
}
}
@ -78,25 +69,68 @@ impl StateTracker {
let mut module_file = fs::File::open(file_name)?;
let mut read_buff = Vec::new();
module_file.read_to_end(&mut read_buff)?;
self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger);
self.record_new_library(module, read_buff.as_ptr() as *const _, fn_logger);
Ok(())
}
/// Records every module found in the code image at `raw_image` under a new library number.
///
/// `cu_module` is added to `saved_modules` and `library_counter` is incremented before the
/// image is inspected. The image is classified with `zluda_common::CodeLibraryRef::try_load`;
/// each module it yields is passed to `record_new_submodule`, and anything that cannot be
/// recorded is reported through `fn_logger`.
pub(crate) fn record_new_library(
&mut self,
cu_module: CUmodule,
raw_image: *const c_void,
fn_logger: &mut FnCallLog,
) {
self.saved_modules.insert(cu_module);
self.library_counter += 1;
// A classification failure is mapped to `ErrorEntry::NonUtf8ModuleText`.
let code_ref = fn_logger.try_return(|| {
unsafe { zluda_common::CodeLibraryRef::try_load(raw_image) }
.map_err(ErrorEntry::NonUtf8ModuleText)
});
// presumably leaves the function when `try_return` produced nothing — see `unwrap_some_or!`
let code_ref = unwrap_some_or!(code_ref, return);
unsafe {
code_ref.iterate_modules(|index, module| match module {
Err(err) => fn_logger.log(ErrorEntry::from(err)),
// An ELF arrives as a bare pointer; its byte length comes from `get_elf_size`.
Ok(zluda_common::CodeModuleRef::Elf(elf)) => match get_elf_size(elf) {
Some(len) => {
let elf_image = std::slice::from_raw_parts(elf.cast::<u8>(), len);
self.record_new_submodule(index, elf_image, fn_logger, "elf");
}
None => fn_logger.log(log::ErrorEntry::UnsupportedModule {
module: cu_module,
raw_image: elf,
kind: "ELF",
}),
},
// Archives are never saved, only logged as unsupported.
Ok(zluda_common::CodeModuleRef::Archive(archive)) => {
fn_logger.log(log::ErrorEntry::UnsupportedModule {
module: cu_module,
raw_image: archive,
kind: "archive",
})
}
// Fatbin files are saved under the kind string reported by `file.kind()`.
Ok(zluda_common::CodeModuleRef::File(file)) => {
if let Some(buffer) = fn_logger
.try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from))
{
self.record_new_submodule(index, &*buffer, fn_logger, file.kind());
}
}
Ok(zluda_common::CodeModuleRef::Text(ptx)) => {
self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx");
}
});
};
}
pub(crate) fn record_new_submodule(
&mut self,
module: CUmodule,
index: Option<(usize, Option<usize>)>,
submodule: &[u8],
fn_logger: &mut FnCallLog,
type_: &'static str,
) {
if self.saved_modules.insert(module) {
self.module_counter += 1;
self.submodule_counter = 0;
}
self.submodule_counter += 1;
fn_logger.log_io_error(self.writer.save_module(
self.module_counter,
Some(self.submodule_counter),
self.library_counter,
index,
submodule,
type_,
));
@ -107,8 +141,8 @@ impl StateTracker {
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
Ok(submodule_text) => self.try_parse_and_record_kernels(
fn_logger,
self.module_counter,
Some(self.submodule_counter),
self.library_counter,
index,
submodule_text,
),
},
@ -116,80 +150,11 @@ impl StateTracker {
}
}
/// Records the code image at `raw_image` as a new module, dispatching on its leading bytes.
///
/// `module_counter` is incremented first. The ELF magic, the archive magic and the fatbin
/// wrapper magic are tried in that order; anything else is handed to `record_module_ptx`.
pub(crate) fn record_new_module(
&mut self,
module: CUmodule,
raw_image: *const c_void,
fn_logger: &mut FnCallLog,
) {
self.module_counter += 1;
if unsafe { *(raw_image as *const [u8; 4]) } == *elf64::header::ELFMAG {
self.saved_modules.insert(module);
// The ELF is saved only when `get_elf_size` can determine its length.
match unsafe { get_elf_size(raw_image) } {
Some(len) => {
let elf_image =
unsafe { std::slice::from_raw_parts(raw_image.cast::<u8>(), len) };
self.record_new_submodule(module, elf_image, fn_logger, "elf");
}
None => fn_logger.log(log::ErrorEntry::UnsupportedModule {
module,
raw_image,
kind: "ELF",
}),
}
} else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC {
self.saved_modules.insert(module);
// TODO: Figure out how to get size of archive module and write it to disk
fn_logger.log(log::ErrorEntry::UnsupportedModule {
module,
raw_image,
kind: "archive",
})
} else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC {
// Wrapped fatbin: its submodules are recorded by the shared helper.
unsafe {
fn_logger.try_(|fn_logger| {
trace::record_submodules_from_wrapped_fatbin(
module,
raw_image as *const FatbincWrapper,
fn_logger,
self,
)
});
}
} else {
self.record_module_ptx(module, raw_image, fn_logger)
}
}
/// Records `raw_image` as a NUL-terminated PTX text module.
///
/// `module` is added to `saved_modules`. When the text is not valid UTF-8 the error is
/// logged and nothing is saved; otherwise the text is written out with kind `"ptx"` and
/// passed to `try_parse_and_record_kernels`.
fn record_module_ptx(
&mut self,
module: CUmodule,
raw_image: *const c_void,
fn_logger: &mut FnCallLog,
) {
self.saved_modules.insert(module);
let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str();
let module_text = match module_text {
Ok(m) => m,
Err(utf8_err) => {
fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err));
return;
}
};
// Saved without a submodule index (`None`).
fn_logger.log_io_error(self.writer.save_module(
self.module_counter,
None,
module_text.as_bytes(),
"ptx",
));
self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text);
}
fn try_parse_and_record_kernels(
&mut self,
fn_logger: &mut FnCallLog,
module_index: usize,
submodule_index: Option<usize>,
submodule_index: Option<(usize, Option<usize>)>,
module_text: &str,
) {
let errors = ptx_parser::parse_for_errors(module_text);
@ -359,7 +324,7 @@ impl DumpWriter {
fn save_module(
&self,
module_index: usize,
submodule_index: Option<usize>,
submodule_index: Option<(usize, Option<usize>)>,
buffer: &[u8],
kind: &'static str,
) -> io::Result<()> {
@ -368,7 +333,7 @@ impl DumpWriter {
Some(d) => d.clone(),
};
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
let mut file = File::create(dump_file)?;
let mut file = File::create_new(dump_file)?;
file.write_all(buffer)?;
Ok(())
}
@ -376,7 +341,7 @@ impl DumpWriter {
fn save_module_error_log<'input>(
&self,
module_index: usize,
submodule_index: Option<usize>,
submodule_index: Option<(usize, Option<usize>)>,
errors: &[ptx_parser::PtxError<'input>],
) -> io::Result<()> {
let mut log_file = match &self.dump_dir {
@ -391,92 +356,27 @@ impl DumpWriter {
Ok(())
}
fn get_file_name(module_index: usize, submodule_index: Option<usize>, kind: &str) -> String {
fn get_file_name(
module_index: usize,
submodule_index: Option<(usize, Option<usize>)>,
kind: &str,
) -> String {
match submodule_index {
None => {
format!("module_{:04}.{:02}", module_index, kind)
}
Some(submodule_index) => {
format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind)
Some((sub_index, None)) => {
format!("module_{:04}_{:02}.{}", module_index, sub_index + 1, kind)
}
Some((sub_index, Some(subsub_index))) => {
format!(
"module_{:04}_{:02}_{:02}.{}",
module_index,
sub_index + 1,
subsub_index + 1,
kind
)
}
}
}
}
/// Records every submodule of the wrapped fatbin at `fatbinc_wrapper`.
///
/// Stops at the first failure — creating the `Fatbin`, obtaining or advancing the
/// submodule iterator, or recording a submodule — and returns it as an `ErrorEntry`.
pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
    module: CUmodule,
    fatbinc_wrapper: *const FatbincWrapper,
    fn_logger: &mut FnCallLog,
    state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
    let wrapped = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
    let mut remaining = wrapped.get_submodules()?;
    loop {
        match remaining.next()? {
            Some(submodule) => {
                record_submodules_from_fatbin(module, submodule, fn_logger, state)?
            }
            None => return Ok(()),
        }
    }
}
/// Records all files of one fatbin submodule by handing its file iterator to
/// `record_submodules`.
pub(crate) unsafe fn record_submodules_from_fatbin(
    module: CUmodule,
    submodule: FatbinSubmodule,
    logger: &mut FnCallLog,
    state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
    let files = submodule.get_files();
    record_submodules(module, logger, state, files)
}
/// Records every file produced by `files` as a submodule of `module`.
///
/// Each payload is decompressed according to the LZ4 / Zstd header flags, or borrowed as-is
/// when neither flag is set. PTX and ELF files are passed to
/// `StateTracker::record_new_submodule`; any other kind is logged as an unexpected
/// `FATBIN_FILE_HEADER_KIND`. An error from the iterator itself aborts the walk and is
/// returned.
pub(crate) unsafe fn record_submodules(
module: CUmodule,
fn_logger: &mut FnCallLog,
state: &mut StateTracker,
mut files: FatbinFileIterator,
) -> Result<(), ErrorEntry> {
while let Some(file) = files.next()? {
// presumably a failed decompression is logged by `try_return` and the file is skipped
// via the `continue` given to `unwrap_some_or!` — confirm against the macro
let mut payload = if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedLz4)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
continue
))
} else if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedZstd)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
continue
))
} else {
Cow::Borrowed(file.get_payload())
};
match file.header.kind {
FatbinFileHeader::HEADER_KIND_PTX => {
while payload.last() == Some(&0) {
// remove trailing zeros
payload.to_mut().pop();
}
state.record_new_submodule(module, &*payload, fn_logger, "ptx")
}
FatbinFileHeader::HEADER_KIND_ELF => {
state.record_new_submodule(module, &*payload, fn_logger, "elf")
}
_ => {
fn_logger.log(log::ErrorEntry::UnexpectedBinaryField {
field_name: "FATBIN_FILE_HEADER_KIND",
expected: vec![
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
],
observed: UInt::U16(file.header.kind),
});
}
}
}
Ok(())
}