diff --git a/Cargo.lock b/Cargo.lock index 0a07aad..a6ded6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,12 +138,6 @@ dependencies = [ "syn 2.0.89", ] -[[package]] -name = "bit-vec" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" - [[package]] name = "bit-vec" version = "0.8.0" @@ -363,10 +357,9 @@ dependencies = [ name = "compiler" version = "0.0.0" dependencies = [ - "amd_comgr-sys", "bpaf", "comgr", - "hip_runtime-sys", + "libloading", "ptx", "ptx_parser", "thiserror 2.0.12", @@ -462,7 +455,7 @@ dependencies = [ name = "dark_api" version = "0.0.0" dependencies = [ - "bit-vec 0.8.0", + "bit-vec", "cglue", "cuda_types", "format", @@ -2586,7 +2579,7 @@ dependencies = [ name = "ptx" version = "0.0.0" dependencies = [ - "bit-vec 0.6.3", + "bit-vec", "bitflags 1.3.2", "comgr", "cuda_macros", @@ -3797,6 +3790,7 @@ name = "zluda_common" version = "0.1.0" dependencies = [ "cuda_types", + "dark_api", "hip_runtime-sys", "rocblas-sys", ] @@ -3889,6 +3883,7 @@ dependencies = [ "unwrap_or", "wchar", "winapi", + "zluda_common", "zluda_trace_common", "zstd-safe", ] diff --git a/compiler/Cargo.toml b/compiler/Cargo.toml index 16dca14..7b4c4df 100644 --- a/compiler/Cargo.toml +++ b/compiler/Cargo.toml @@ -10,12 +10,11 @@ name = "zoc" path = "src/main.rs" [dependencies] -amd_comgr-sys = { path = "../ext/amd_comgr-sys" } bpaf = { version = "0.9.19", features = ["derive"] } comgr = { path = "../comgr" } -hip_runtime-sys = { path = "../ext/hip_runtime-sys" } ptx = { path = "../ptx" } ptx_parser = { path = "../ptx_parser" } +libloading = "0.8" thiserror = "2.0.12" [package.metadata.zluda] diff --git a/compiler/src/error.rs b/compiler/src/error.rs index f5bfe11..9da1b7e 100644 --- a/compiler/src/error.rs +++ b/compiler/src/error.rs @@ -1,15 +1,15 @@ +use ptx::TranslateError; +use ptx_parser::PtxError; use std::ffi::FromBytesUntilNulError; use std::io; use std::str::Utf8Error; -use hip_runtime_sys::hipErrorCode_t; -use ptx::TranslateError; -use ptx_parser::PtxError; - #[derive(Debug, thiserror::Error)] pub enum CompilerError { #[error("HIP error code: {0:?}")] - HipError(hipErrorCode_t), + HipError(u32), + #[error(transparent)] + Libloading(#[from] libloading::Error), #[error(transparent)] ComgrError(#[from] comgr::Error), #[error(transparent)] @@ -26,12 +26,6 @@ pub enum CompilerError { }, } -impl From for CompilerError { - fn from(error_code: hipErrorCode_t) -> Self { - CompilerError::HipError(error_code) - } -} - impl From>> for CompilerError { fn from(causes: Vec) -> Self { let errors: Vec = causes diff --git a/compiler/src/main.rs b/compiler/src/main.rs index fb8feb0..5effaaf 100644 --- a/compiler/src/main.rs +++ b/compiler/src/main.rs @@ -1,3 +1,5 @@ +use bpaf::Bpaf; +use error::CompilerError; use std::ffi::CStr; use std::fs::{self, File}; use std::io::{self, Write}; @@ -6,11 +8,7 @@ use std::process::ExitCode; use std::str; use std::{env, mem}; -use bpaf::Bpaf; - mod error; -use error::CompilerError; -use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600, hipInit}; const DEFAULT_ARCH: &'static str = "gfx1100"; @@ -60,12 +58,17 @@ fn main_core() -> Result<(), CompilerError> { let arch: String = match opts.arch { Some(s) => s, None => { - unsafe { hipInit(0) }?; - let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() }; - unsafe { hipGetDevicePropertiesR0600(&mut dev_props, 0) }?; + (|| { + let runtime = hip::Runtime::load()?; + runtime.init()?; + get_gpu_arch(&runtime) + })() + .unwrap_or_else(|_| DEFAULT_ARCH.to_owned()) + /* get_gpu_arch(&mut dev_props) .map(String::from) .unwrap_or(DEFAULT_ARCH.to_owned()) + */ } }; @@ -122,12 +125,13 @@ struct LLVMArtifacts { llvm_ir: Vec, } -fn get_gpu_arch<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> Result<&'a str, CompilerError> { - unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }?; +fn get_gpu_arch(runtime: &hip::Runtime) -> Result { + let mut dev_props = unsafe { mem::zeroed() }; + runtime.device_get_properties(&mut dev_props, 0)?; let gcn_arch_name = &dev_props.gcnArchName; let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) }; - let gcn_arch_name = gcn_arch_name.to_str(); - gcn_arch_name.map_err(CompilerError::from) + let gcn_arch_name = gcn_arch_name.to_str()?; + Ok(gcn_arch_name.to_string()) } fn write_to_file(content: &[u8], path: &Path) -> io::Result<()> { @@ -137,3 +141,316 @@ fn write_to_file(content: &[u8], path: &Path) -> io::Result<()> { println!("Wrote to {}", path.to_str().unwrap()); Ok(()) } + +mod hip { + use crate::error::CompilerError; + + // We lazy load HIP runtime because we want to work on systems with no + // HIP driver installed + pub struct Runtime(libloading::Library); + + impl Runtime { + fn hip_check(err: u32) -> Result<(), CompilerError> { + match err { + 0 => Ok(()), + err_code => Err(CompilerError::HipError(err_code)), + } + } + + pub fn load() -> Result { + #[cfg(windows)] + let lib_name = "amdhip64_6.dll\0"; + #[cfg(unix)] + let lib_name = "libamdhip64.so.6\0"; + let library = unsafe { libloading::Library::new(lib_name)? }; + Ok(Self(library)) + } + + pub fn init(&self) -> Result<(), CompilerError> { + unsafe { + let hip_init: libloading::Symbol u32> = + self.0.get(b"hipInit\0")?; + Self::hip_check(hip_init(0)) + } + } + + pub fn device_get_properties( + &self, + prop: &mut hipDeviceProp_tR0600, + device: i32, + ) -> Result<(), CompilerError> { + unsafe { + let hip_get_device_properties: libloading::Symbol< + unsafe extern "C" fn(*mut hipDeviceProp_tR0600, i32) -> u32, + > = self.0.get(b"hipGetDevicePropertiesR0600\0")?; + Self::hip_check(hip_get_device_properties(prop, device)) + } + } + } + + #[allow(non_snake_case, non_camel_case_types)] + #[repr(C)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub struct hipDeviceProp_tR0600 { + ///< Device name. + pub name: [::core::ffi::c_char; 256usize], + ///< UUID of a device + pub uuid: hipUUID, + ///< 8-byte unique identifier. Only valid on windows + pub luid: [::core::ffi::c_char; 8usize], + ///< LUID node mask + pub luidDeviceNodeMask: ::core::ffi::c_uint, + ///< Size of global memory region (in bytes). + pub totalGlobalMem: usize, + ///< Size of shared memory per block (in bytes). + pub sharedMemPerBlock: usize, + ///< Registers per block. + pub regsPerBlock: ::core::ffi::c_int, + ///< Warp size. + pub warpSize: ::core::ffi::c_int, + /**< Maximum pitch in bytes allowed by memory copies + < pitched memory*/ + pub memPitch: usize, + ///< Max work items per work group or workgroup max size. + pub maxThreadsPerBlock: ::core::ffi::c_int, + ///< Max number of threads in each dimension (XYZ) of a block. + pub maxThreadsDim: [::core::ffi::c_int; 3usize], + ///< Max grid dimensions (XYZ). + pub maxGridSize: [::core::ffi::c_int; 3usize], + ///< Max clock frequency of the multiProcessors in khz. + pub clockRate: ::core::ffi::c_int, + /**< Size of shared constant memory region on the device + < (in bytes).*/ + pub totalConstMem: usize, + /**< Major compute capability. On HCC, this is an approximation and features may + < differ from CUDA CC. See the arch feature flags for portable ways to query + < feature caps.*/ + pub major: ::core::ffi::c_int, + /**< Minor compute capability. On HCC, this is an approximation and features may + < differ from CUDA CC. See the arch feature flags for portable ways to query + < feature caps.*/ + pub minor: ::core::ffi::c_int, + ///< Alignment requirement for textures + pub textureAlignment: usize, + ///< Pitch alignment requirement for texture references bound to + pub texturePitchAlignment: usize, + ///< Deprecated. Use asyncEngineCount instead + pub deviceOverlap: ::core::ffi::c_int, + ///< Number of multi-processors (compute units). + pub multiProcessorCount: ::core::ffi::c_int, + ///< Run time limit for kernels executed on the device + pub kernelExecTimeoutEnabled: ::core::ffi::c_int, + ///< APU vs dGPU + pub integrated: ::core::ffi::c_int, + ///< Check whether HIP can map host memory + pub canMapHostMemory: ::core::ffi::c_int, + ///< Compute mode. + pub computeMode: ::core::ffi::c_int, + ///< Maximum number of elements in 1D images + pub maxTexture1D: ::core::ffi::c_int, + ///< Maximum 1D mipmap texture size + pub maxTexture1DMipmap: ::core::ffi::c_int, + ///< Maximum size for 1D textures bound to linear memory + pub maxTexture1DLinear: ::core::ffi::c_int, + ///< Maximum dimensions (width, height) of 2D images, in image elements + pub maxTexture2D: [::core::ffi::c_int; 2usize], + ///< Maximum number of elements in 2D array mipmap of images + pub maxTexture2DMipmap: [::core::ffi::c_int; 2usize], + ///< Maximum 2D tex dimensions if tex are bound to pitched memory + pub maxTexture2DLinear: [::core::ffi::c_int; 3usize], + ///< Maximum 2D tex dimensions if gather has to be performed + pub maxTexture2DGather: [::core::ffi::c_int; 2usize], + /**< Maximum dimensions (width, height, depth) of 3D images, in image + < elements*/ + pub maxTexture3D: [::core::ffi::c_int; 3usize], + ///< Maximum alternate 3D texture dims + pub maxTexture3DAlt: [::core::ffi::c_int; 3usize], + ///< Maximum cubemap texture dims + pub maxTextureCubemap: ::core::ffi::c_int, + ///< Maximum number of elements in 1D array images + pub maxTexture1DLayered: [::core::ffi::c_int; 2usize], + ///< Maximum number of elements in 2D array images + pub maxTexture2DLayered: [::core::ffi::c_int; 3usize], + ///< Maximum cubemaps layered texture dims + pub maxTextureCubemapLayered: [::core::ffi::c_int; 2usize], + ///< Maximum 1D surface size + pub maxSurface1D: ::core::ffi::c_int, + ///< Maximum 2D surface size + pub maxSurface2D: [::core::ffi::c_int; 2usize], + ///< Maximum 3D surface size + pub maxSurface3D: [::core::ffi::c_int; 3usize], + ///< Maximum 1D layered surface size + pub maxSurface1DLayered: [::core::ffi::c_int; 2usize], + ///< Maximum 2D layared surface size + pub maxSurface2DLayered: [::core::ffi::c_int; 3usize], + ///< Maximum cubemap surface size + pub maxSurfaceCubemap: ::core::ffi::c_int, + ///< Maximum cubemap layered surface size + pub maxSurfaceCubemapLayered: [::core::ffi::c_int; 2usize], + ///< Alignment requirement for surface + pub surfaceAlignment: usize, + ///< Device can possibly execute multiple kernels concurrently. + pub concurrentKernels: ::core::ffi::c_int, + ///< Device has ECC support enabled + pub ECCEnabled: ::core::ffi::c_int, + ///< PCI Bus ID. + pub pciBusID: ::core::ffi::c_int, + ///< PCI Device ID. + pub pciDeviceID: ::core::ffi::c_int, + ///< PCI Domain ID + pub pciDomainID: ::core::ffi::c_int, + ///< 1:If device is Tesla device using TCC driver, else 0 + pub tccDriver: ::core::ffi::c_int, + ///< Number of async engines + pub asyncEngineCount: ::core::ffi::c_int, + ///< Does device and host share unified address space + pub unifiedAddressing: ::core::ffi::c_int, + ///< Max global memory clock frequency in khz. + pub memoryClockRate: ::core::ffi::c_int, + ///< Global memory bus width in bits. + pub memoryBusWidth: ::core::ffi::c_int, + ///< L2 cache size. + pub l2CacheSize: ::core::ffi::c_int, + ///< Device's max L2 persisting lines in bytes + pub persistingL2CacheMaxSize: ::core::ffi::c_int, + ///< Maximum resident threads per multi-processor. + pub maxThreadsPerMultiProcessor: ::core::ffi::c_int, + ///< Device supports stream priority + pub streamPrioritiesSupported: ::core::ffi::c_int, + ///< Indicates globals are cached in L1 + pub globalL1CacheSupported: ::core::ffi::c_int, + ///< Locals are cahced in L1 + pub localL1CacheSupported: ::core::ffi::c_int, + ///< Amount of shared memory available per multiprocessor. + pub sharedMemPerMultiprocessor: usize, + ///< registers available per multiprocessor + pub regsPerMultiprocessor: ::core::ffi::c_int, + ///< Device supports allocating managed memory on this system + pub managedMemory: ::core::ffi::c_int, + ///< 1 if device is on a multi-GPU board, 0 if not. + pub isMultiGpuBoard: ::core::ffi::c_int, + ///< Unique identifier for a group of devices on same multiboard GPU + pub multiGpuBoardGroupID: ::core::ffi::c_int, + ///< Link between host and device supports native atomics + pub hostNativeAtomicSupported: ::core::ffi::c_int, + ///< Deprecated. CUDA only. + pub singleToDoublePrecisionPerfRatio: ::core::ffi::c_int, + /**< Device supports coherently accessing pageable memory + < without calling hipHostRegister on it*/ + pub pageableMemoryAccess: ::core::ffi::c_int, + /**< Device can coherently access managed memory concurrently with + < the CPU*/ + pub concurrentManagedAccess: ::core::ffi::c_int, + ///< Is compute preemption supported on the device + pub computePreemptionSupported: ::core::ffi::c_int, + /**< Device can access host registered memory with same + < address as the host*/ + pub canUseHostPointerForRegisteredMem: ::core::ffi::c_int, + ///< HIP device supports cooperative launch + pub cooperativeLaunch: ::core::ffi::c_int, + /**< HIP device supports cooperative launch on multiple + < devices*/ + pub cooperativeMultiDeviceLaunch: ::core::ffi::c_int, + ///< Per device m ax shared mem per block usable by special opt in + pub sharedMemPerBlockOptin: usize, + /**< Device accesses pageable memory via the host's + < page tables*/ + pub pageableMemoryAccessUsesHostPageTables: ::core::ffi::c_int, + /**< Host can directly access managed memory on the device + < without migration*/ + pub directManagedMemAccessFromHost: ::core::ffi::c_int, + ///< Max number of blocks on CU + pub maxBlocksPerMultiProcessor: ::core::ffi::c_int, + ///< Max value of access policy window + pub accessPolicyMaxWindowSize: ::core::ffi::c_int, + ///< Shared memory reserved by driver per block + pub reservedSharedMemPerBlock: usize, + ///< Device supports hipHostRegister + pub hostRegisterSupported: ::core::ffi::c_int, + ///< Indicates if device supports sparse hip arrays + pub sparseHipArraySupported: ::core::ffi::c_int, + /**< Device supports using the hipHostRegisterReadOnly flag + < with hipHostRegistger*/ + pub hostRegisterReadOnlySupported: ::core::ffi::c_int, + ///< Indicates external timeline semaphore support + pub timelineSemaphoreInteropSupported: ::core::ffi::c_int, + ///< Indicates if device supports hipMallocAsync and hipMemPool APIs + pub memoryPoolsSupported: ::core::ffi::c_int, + ///< Indicates device support of RDMA APIs + pub gpuDirectRDMASupported: ::core::ffi::c_int, + /**< Bitmask to be interpreted according to + < hipFlushGPUDirectRDMAWritesOptions*/ + pub gpuDirectRDMAFlushWritesOptions: ::core::ffi::c_uint, + ///< value of hipGPUDirectRDMAWritesOrdering + pub gpuDirectRDMAWritesOrdering: ::core::ffi::c_int, + ///< Bitmask of handle types support with mempool based IPC + pub memoryPoolSupportedHandleTypes: ::core::ffi::c_uint, + /**< Device supports deferred mapping HIP arrays and HIP + < mipmapped arrays*/ + pub deferredMappingHipArraySupported: ::core::ffi::c_int, + ///< Device supports IPC events + pub ipcEventSupported: ::core::ffi::c_int, + ///< Device supports cluster launch + pub clusterLaunch: ::core::ffi::c_int, + ///< Indicates device supports unified function pointers + pub unifiedFunctionPointers: ::core::ffi::c_int, + ///< CUDA Reserved. + pub reserved: [::core::ffi::c_int; 63usize], + ///< Reserved for adding new entries for HIP/CUDA. + pub hipReserved: [::core::ffi::c_int; 32usize], + ///< AMD GCN Arch Name. HIP Only. + pub gcnArchName: [::core::ffi::c_char; 256usize], + ///< Maximum Shared Memory Per CU. HIP Only. + pub maxSharedMemoryPerMultiProcessor: usize, + /**< Frequency in khz of the timer used by the device-side "clock*" + < instructions. New for HIP.*/ + pub clockInstructionRate: ::core::ffi::c_int, + ///< Architectural feature flags. New for HIP. + pub arch: hipDeviceArch_t, + ///< Addres of HDP_MEM_COHERENCY_FLUSH_CNTL register + pub hdpMemFlushCntl: *mut ::core::ffi::c_uint, + ///< Addres of HDP_REG_COHERENCY_FLUSH_CNTL register + pub hdpRegFlushCntl: *mut ::core::ffi::c_uint, + /**< HIP device supports cooperative launch on + < multiple*/ + pub cooperativeMultiDeviceUnmatchedFunc: ::core::ffi::c_int, + /**< HIP device supports cooperative launch on + < multiple*/ + pub cooperativeMultiDeviceUnmatchedGridDim: ::core::ffi::c_int, + /**< HIP device supports cooperative launch on + < multiple*/ + pub cooperativeMultiDeviceUnmatchedBlockDim: ::core::ffi::c_int, + /**< HIP device supports cooperative launch on + < multiple*/ + pub cooperativeMultiDeviceUnmatchedSharedMem: ::core::ffi::c_int, + ///< 1: if it is a large PCI bar device, else 0 + pub isLargeBar: ::core::ffi::c_int, + ///< Revision of the GPU in this device + pub asicRevision: ::core::ffi::c_int, + } + + #[allow(non_snake_case, non_camel_case_types)] + #[repr(C)] + #[repr(align(4))] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub struct hipDeviceArch_t { + pub _bitfield_align_1: [u8; 0], + pub _bitfield_1: __BindgenBitfieldUnit<[u8; 3usize]>, + pub __bindgen_padding_0: u8, + } + + #[repr(C)] + #[derive(Copy, Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] + pub struct __BindgenBitfieldUnit { + storage: Storage, + } + + #[allow(non_camel_case_types)] + #[repr(C)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub struct hipUUID_t { + pub bytes: [::core::ffi::c_char; 16usize], + } + #[allow(non_camel_case_types)] + pub type hipUUID = hipUUID_t; +} diff --git a/cuda_types/src/dark_api.rs b/cuda_types/src/dark_api.rs index 442c0b6..bb7b2cf 100644 --- a/cuda_types/src/dark_api.rs +++ b/cuda_types/src/dark_api.rs @@ -45,13 +45,14 @@ pub struct FatbinHeader { } #[repr(C)] +#[derive(Debug)] pub struct FatbinFileHeader { pub kind: c_ushort, pub version: c_ushort, pub header_size: c_uint, - pub padded_payload_size: c_uint, - pub unknown0: c_uint, // check if it's written into separately pub payload_size: c_uint, + pub unknown0: c_uint, // check if it's written into separately + pub compressed_size: c_uint, pub unknown1: c_uint, pub unknown2: c_uint, pub sm_version: c_uint, @@ -63,6 +64,7 @@ pub struct FatbinFileHeader { } bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct FatbinFileHeaderFlags: u64 { const Is64Bit = 0x0000000000000001; const Debug = 0x0000000000000002; @@ -77,13 +79,13 @@ bitflags! { } impl FatbincWrapper { - pub const MAGIC: c_uint = 0x466243B1; + pub const MAGIC: [u8; 4] = 0x466243B1u32.to_le_bytes(); pub const VERSION_V1: c_uint = 0x1; pub const VERSION_V2: c_uint = 0x2; } impl FatbinHeader { - pub const MAGIC: c_uint = 0xBA55ED50; + pub const MAGIC: [u8; 4] = 0xBA55ED50u32.to_le_bytes(); pub const VERSION: c_ushort = 0x01; } diff --git a/dark_api/src/fatbin.rs b/dark_api/src/fatbin.rs index 8d7868d..9488499 100644 --- a/dark_api/src/fatbin.rs +++ b/dark_api/src/fatbin.rs @@ -1,6 +1,6 @@ // This file contains a higher-level interface for parsing fatbins -use std::ptr; +use std::{borrow::Cow, ptr}; use cuda_types::dark_api::*; @@ -43,7 +43,11 @@ pub fn parse_fatbinc_wrapper(ptr: &*const T) -> Result<&FatbincWrapper unsafe { ptr.cast::().as_ref() } .ok_or(ParseError::NullPointer("FatbincWrapper")) .and_then(|ptr| { - ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [FatbincWrapper::MAGIC])?; + ParseError::check_fields( + "FATBINC_MAGIC", + ptr.magic, + [u32::from_ne_bytes(FatbincWrapper::MAGIC)], + )?; ParseError::check_fields( "FATBINC_VERSION", ptr.version, @@ -57,12 +61,17 @@ fn parse_fatbin_header(ptr: &*const T) -> Result<&FatbinHeader, ParseE unsafe { ptr.cast::().as_ref() } .ok_or(ParseError::NullPointer("FatbinHeader")) .and_then(|ptr| { - ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [FatbinHeader::MAGIC])?; + ParseError::check_fields( + "FATBIN_MAGIC", + ptr.magic, + [u32::from_ne_bytes(FatbinHeader::MAGIC)], + )?; ParseError::check_fields("FATBIN_VERSION", ptr.version, [FatbinHeader::VERSION])?; Ok(ptr) }) } +#[derive(Clone, Copy)] pub struct Fatbin<'a> { pub wrapper: &'a FatbincWrapper, } @@ -75,7 +84,7 @@ impl<'a> Fatbin<'a> { Ok(Fatbin { wrapper }) } - pub fn get_submodules(&self) -> Result, FatbinError> { + pub fn get_submodules(self) -> Result, FatbinError> { match self.wrapper.version { FatbincWrapper::VERSION_V2 => Ok(FatbinIter::V2(FatbinSubmoduleIterator { fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void, @@ -97,6 +106,10 @@ impl<'a> Fatbin<'a> { } } +unsafe impl Send for Fatbin<'static> {} +unsafe impl Sync for Fatbin<'static> {} + +#[derive(Clone, Copy)] pub struct FatbinSubmodule<'a> { pub header: &'a FatbinHeader, // TODO: maybe make private } @@ -117,9 +130,13 @@ pub enum FatbinIter<'a> { } impl<'a> FatbinIter<'a> { - pub fn next(&mut self) -> Result>, ParseError> { + pub fn multi_module(&self) -> bool { + matches!(self, FatbinIter::V2(_)) + } + + pub fn next(&mut self) -> Option, ParseError>> { match self { - FatbinIter::V1(opt) => Ok(opt.take()), + FatbinIter::V1(opt) => Ok(opt.take()).transpose(), FatbinIter::V2(iter) => unsafe { iter.next() }, } } @@ -131,19 +148,23 @@ pub struct FatbinSubmoduleIterator<'a> { } impl<'a> FatbinSubmoduleIterator<'a> { - pub unsafe fn next(&mut self) -> Result>, ParseError> { + pub unsafe fn next(&mut self) -> Option, ParseError>> { if *self.fatbins != ptr::null() { let header = *self.fatbins as *const FatbinHeader; self.fatbins = self.fatbins.add(1); - Ok(Some(FatbinSubmodule::new(header.as_ref().ok_or( - ParseError::NullPointer("FatbinSubmoduleIterator"), - )?))) + Some( + header + .as_ref() + .ok_or(ParseError::NullPointer("FatbinSubmoduleIterator")) + .map(FatbinSubmodule::new), + ) } else { - Ok(None) + None } } } +#[derive(Clone, Copy)] pub struct FatbinFile<'a> { pub header: &'a FatbinFileHeader, } @@ -153,35 +174,60 @@ impl<'a> FatbinFile<'a> { Self { header } } - pub unsafe fn get_payload(&'a self) -> &'a [u8] { + pub fn kind(&self) -> &'static str { + match self.header.kind { + FatbinFileHeader::HEADER_KIND_PTX => "ptx", + FatbinFileHeader::HEADER_KIND_ELF => "elf", + _ => "bin", + } + } + + pub unsafe fn get_non_compressed_payload(self) -> &'a [u8] { let start = std::ptr::from_ref(self.header) .cast::() .add(self.header.header_size as usize); std::slice::from_raw_parts(start, self.header.payload_size as usize) } - pub unsafe fn decompress(&'a self) -> Result, FatbinError> { + pub unsafe fn get_compressed_payload(self) -> &'a [u8] { + let start = std::ptr::from_ref(self.header) + .cast::() + .add(self.header.header_size as usize); + std::slice::from_raw_parts(start, self.header.compressed_size as usize) + } + + pub unsafe fn get_or_decompress_content(self) -> Result, FatbinError> { let mut payload = if self .header .flags .contains(FatbinFileHeaderFlags::CompressedLz4) { - unsafe { decompress_lz4(self) }? + Cow::Owned(unsafe { decompress_lz4(self) }?) } else if self .header .flags .contains(FatbinFileHeaderFlags::CompressedZstd) { - unsafe { decompress_zstd(self) }? + Cow::Owned(unsafe { decompress_zstd(self) }?) } else { - unsafe { self.get_payload().to_vec() } + Cow::Borrowed(unsafe { self.get_non_compressed_payload() }) }; - while payload.last() == Some(&0) { - // remove trailing zeros - payload.pop(); + // Remove trailing zeros + if self.header.kind == FatbinFileHeader::HEADER_KIND_PTX { + match payload { + Cow::Borrowed(ref mut slice) => { + while slice.last() == Some(&0) { + *slice = &slice[..slice.len() - 1]; + } + } + Cow::Owned(ref mut vec) => { + while vec.last() == Some(&0) { + vec.pop(); + } + } + } } - Ok(payload) } } @@ -200,35 +246,40 @@ impl<'a> FatbinFileIterator<'a> { Self { file_buffer } } - pub unsafe fn next(&mut self) -> Result>, ParseError> { + pub unsafe fn next(&mut self) -> Option, ParseError>> { if self.file_buffer.len() < std::mem::size_of::() { - return Ok(None); + return None; } let this = &*self.file_buffer.as_ptr().cast::(); let next_element = self .file_buffer - .split_at_checked(this.header_size as usize + this.padded_payload_size as usize) + .split_at_checked( + this.header_size as usize + + u32::max(this.payload_size, this.compressed_size) as usize, + ) .map(|(_, next)| next); self.file_buffer = next_element.unwrap_or(&[]); - ParseError::check_fields( - "FATBIN_FILE_HEADER_VERSION_CURRENT", - this.version, - [FatbinFileHeader::HEADER_VERSION_CURRENT], - )?; - Ok(Some(FatbinFile::new(this))) + Some( + ParseError::check_fields( + "FATBIN_FILE_HEADER_VERSION_CURRENT", + this.version, + [FatbinFileHeader::HEADER_VERSION_CURRENT], + ) + .map(|_| FatbinFile::new(this)), + ) } } const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024; -pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result, FatbinError> { +pub unsafe fn decompress_lz4(file: FatbinFile) -> Result, FatbinError> { let decompressed_size = usize::max(1024, file.header.uncompressed_payload as usize); let mut decompressed_vec = vec![0u8; decompressed_size]; loop { match lz4_sys::LZ4_decompress_safe( - file.get_payload().as_ptr() as *const _, + file.get_compressed_payload().as_ptr() as *const _, decompressed_vec.as_mut_ptr() as *mut _, - file.header.payload_size as _, + file.header.compressed_size as _, decompressed_vec.len() as _, ) { error if error < 0 => { @@ -246,9 +297,9 @@ pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result, FatbinError> } } -pub unsafe fn decompress_zstd(file: &FatbinFile) -> Result, FatbinError> { +pub unsafe fn decompress_zstd(file: FatbinFile) -> Result, FatbinError> { let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize); - let payload = file.get_payload(); + let payload = file.get_compressed_payload(); match zstd_safe::decompress(&mut result, payload) { Ok(actual_size) => { result.truncate(actual_size); diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 2f9b174..c9a5a6b 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -11,7 +11,7 @@ ptx_parser = { path = "../ptx_parser" } llvm_zluda = { path = "../llvm_zluda" } quick-error = "1.2" thiserror = "1.0" -bit-vec = "0.6" +bit-vec = "0.8" half ="1.6" bitflags = "1.2" rustc-hash = "2.0.0" diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index 3b96ac0..6f67e93 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -15,4 +15,4 @@ rustc-hash = "2.0.0" strum = { version = "0.27.1", features = ["derive"] } thiserror = "1.0" winnow = { version = "0.6.18" } -#winnow = { version = "0.6.18", features = ["debug"] } +# winnow = { version = "0.6.18", features = ["debug"] } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 38d4aed..0570eb4 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1519,6 +1519,17 @@ pub enum Directive<'input, O: Operand> { pub struct Module<'input> { pub version: (u8, u8), pub directives: Vec>>, + pub invalid_directives: usize, +} + +impl Module<'_> { + pub fn empty() -> Self { + Module { + version: (1, 0), + directives: Vec::new(), + invalid_directives: usize::MAX, + } + } } #[derive(Copy, Clone)] diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index b482fa5..29b348a 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3,8 +3,8 @@ use logos::Logos; use ptx_parser_macros::derive_parser; use rustc_hash::FxHashMap; use std::fmt::Debug; -use std::iter; use std::num::{NonZeroU8, ParseFloatError, ParseIntError}; +use std::{iter, usize}; use winnow::ascii::dec_uint; use winnow::combinator::*; use winnow::error::{ErrMode, ErrorKind}; @@ -414,7 +414,7 @@ pub fn parse_module_checked<'input>( .map_err(|err| PtxError::Parser(err.into_inner())) }; match parse_result { - Ok(result) if errors.is_empty() => Ok(result), + Ok(result) if errors.is_empty() && result.invalid_directives == 0 => Ok(result), Ok(_) => Err(errors), Err(err) => { errors.push(err); @@ -423,6 +423,39 @@ pub fn parse_module_checked<'input>( } } +pub fn parse_module_unchecked<'input>(text: &'input str) -> ast::Module<'input> { + let mut lexer = Token::lexer(text); + let mut errors = Vec::new(); + let mut tokens = Vec::new(); + loop { + let maybe_token = match lexer.next() { + Some(maybe_token) => maybe_token, + None => break, + }; + match maybe_token { + Ok(token) => tokens.push((token, lexer.span())), + Err(mut err) => { + err.0 = lexer.span(); + errors.push(PtxError::from(err)) + } + } + } + if !errors.is_empty() { + return ast::Module::empty(); + } + let parse_result = { + let state = PtxParserState::new(text, &mut errors); + let parser = PtxParser { + state, + input: &tokens[..], + }; + module + .parse(parser) + .map_err(|err| PtxError::Parser(err.into_inner())) + }; + parse_result.unwrap_or(ast::Module::empty()) +} + fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { trace( "module", @@ -430,13 +463,16 @@ fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult(stream: &mut &str) -> PResult<(u32, Option)> { fn directive<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult>>> { - trace( + let errors = stream.state.errors.len(); + let directive = trace( "directive", with_recovery( alt(( @@ -501,7 +538,11 @@ fn directive<'a, 'input>( ) .map(Option::flatten), ) - .parse_next(stream) + .parse_next(stream)?; + if errors != stream.state.errors.len() { + return Ok(None); + } + Ok(directive) } fn module_variable<'a, 'input>( @@ -1230,6 +1271,25 @@ fn repeat_without_none>( ) } +fn repeat_without_none_and_count>( + parser: impl Parser, Error>, +) -> impl Parser, usize), Error> { + trace( + "repeat_without_none_and_count", + repeat(0.., parser).fold( + || (Vec::new(), 0), + |(mut accumulator, mut nones): (Vec<_>, usize), item| { + if let Some(item) = item { + accumulator.push(item); + } else { + nones += 1; + } + (accumulator, nones) + }, + ), + ) +} + fn ident_literal< 'a, 'input, @@ -3789,6 +3849,7 @@ derive_parser!( #[cfg(test)] mod tests { use crate::first_optional; + use crate::module; use crate::parse_module_checked; use crate::section; use crate::PtxError; @@ -4086,4 +4147,38 @@ mod tests { assert!(section.parse(stream).is_ok()); assert_eq!(errors.len(), 0); } + + #[test] + fn report_unknown_directives() { + let text = " + .version 6.5 + .target sm_30 + .address_size 64 + + .global .b32 global[4] = { unknown (1), 2, 3, 4}; + + .visible .entry func1() + { + st.u64 [out_addr], temp2; + ret; + } + + .visible .entry func1() + { + broken_instruction; + ret; + }"; + let tokens = Token::lexer(text) + .map(|t| t.map(|t| (t, Span::default()))) + .collect::, _>>() + .unwrap(); + let mut errors = Vec::new(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(text, &mut errors), + }; + let module = module.parse(stream).unwrap(); + assert_eq!(module.directives.len(), 1); + assert_eq!(module.invalid_directives, 2); + } } diff --git a/zluda/build.rs b/zluda/build.rs index 1f4e654..82508db 100644 --- a/zluda/build.rs +++ b/zluda/build.rs @@ -1,7 +1,7 @@ use vergen_gix::{Emitter, GixBuilder}; fn main() { - let git = GixBuilder::all_git().unwrap(); + let git = GixBuilder::default().sha(false).build().unwrap(); Emitter::default() .add_instructions(&git) .unwrap() diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 4862a45..8933116 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -102,7 +102,11 @@ pub(crate) fn get_current_device() -> Result { stack .try_borrow() .map_err(|_| CUerror::UNKNOWN) - .and_then(|s| s.last().ok_or(CUerror::UNKNOWN).map(|(_, dev)| *dev)) + .and_then(|s| { + s.last() + .ok_or(CUerror::INVALID_CONTEXT) + .map(|(_, dev)| *dev) + }) }) } diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 074fcb4..03b2dcc 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -517,3 +517,23 @@ pub(crate) unsafe fn primary_context_get_state( *active_out = active; Ok(()) } + +#[cfg(test)] +mod tests { + use crate::tests::CudaApi; + use cuda_macros::test_cuda; + use std::{mem, ptr}; + + #[test_cuda] + unsafe fn primary_ctx_retain_does_not_make_it_active(api: impl CudaApi) { + api.cuInit(0); + let mut current_ctx = mem::zeroed(); + api.cuCtxGetCurrent(&mut current_ctx); + assert_eq!(current_ctx.0, ptr::null_mut()); + let mut primary_ctx = mem::zeroed(); + api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0); + assert_ne!(primary_ctx.0, ptr::null_mut()); + api.cuCtxGetCurrent(&mut current_ctx); + assert_eq!(current_ctx.0, ptr::null_mut()); + } +} diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index 3a54ac7..dbddb1a 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -1,4 +1,4 @@ -use crate::r#impl::{context, device, function}; +use crate::r#impl::{self, context, device, function}; use comgr::Comgr; use cuda_types::cuda::*; use hip_runtime_sys::*; @@ -154,7 +154,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _module: *mut cuda_types::cuda::CUmodule, _fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper, ) -> cuda_types::cuda::CUresult { - todo!() + Err(r#impl::unimplemented()) } unsafe extern "system" fn cudart_interface_fn2( @@ -176,11 +176,11 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _arg4: *mut std::ffi::c_void, _arg5: u32, ) -> cuda_types::cuda::CUresult { - todo!() + Err(r#impl::unimplemented()) } unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult { - todo!() + Err(r#impl::unimplemented()) } unsafe extern "system" fn get_module_from_cubin_ext2( @@ -190,7 +190,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _arg4: *mut std::ffi::c_void, _arg5: u32, ) -> cuda_types::cuda::CUresult { - todo!() + Err(r#impl::unimplemented()) } unsafe extern "system" fn get_unknown_buffer1( @@ -276,7 +276,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _flags: ::std::os::raw::c_uint, _dev: cuda_types::cuda::CUdevice, ) -> cuda_types::cuda::CUresult { - todo!() + Err(r#impl::unimplemented()) } unsafe extern "system" fn heap_alloc( @@ -284,14 +284,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _arg2: usize, _arg3: usize, ) -> cuda_types::cuda::CUresult { - todo!() + Err(r#impl::unimplemented()) } unsafe extern "system" fn heap_free( _heap_alloc_record_ptr: *const std::ffi::c_void, _arg2: *mut usize, ) -> cuda_types::cuda::CUresult { - todo!() + Err(r#impl::unimplemented()) } unsafe extern "system" fn device_get_attribute_ext( @@ -300,14 +300,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { _unknown: std::ffi::c_int, _result: *mut [usize; 2], ) -> cuda_types::cuda::CUresult { - todo!() + Err(r#impl::unimplemented()) } unsafe extern "system" fn device_get_something( _result: *mut std::ffi::c_uchar, _dev: cuda_types::cuda::CUdevice, ) -> cuda_types::cuda::CUresult { - todo!() + Err(r#impl::unimplemented()) } unsafe extern "system" fn integrity_check( diff --git a/zluda/src/impl/library.rs b/zluda/src/impl/library.rs index 6841cfc..6ca3745 100644 --- a/zluda/src/impl/library.rs +++ b/zluda/src/impl/library.rs @@ -1,10 +1,38 @@ -use super::module; +use crate::r#impl::{context, driver, module}; use cuda_types::cuda::*; use hip_runtime_sys::*; -use zluda_common::ZludaObject; +use std::{ffi::c_void, sync::OnceLock}; +use zluda_common::{CodeLibraryRef, ZludaObject}; pub(crate) struct Library { - base: hipModule_t, + data: LibraryData, + modules: Vec>>, +} + +impl Library { + pub(crate) fn get_module_for_device(&self, device: usize) -> Result { + let module_lock = self.modules.get(device).ok_or(CUerror::INVALID_DEVICE)?; + *module_lock.get_or_init(|| match self.data { + LibraryData::Lazy(lib) => module::load_hip_module(lib), + LibraryData::Eager(()) => Err(CUerror::NOT_SUPPORTED), + }) + } +} + +enum LibraryData { + Lazy(CodeLibraryRef<'static>), + Eager(()), +} + +impl LibraryData { + unsafe fn new(ptr: *mut c_void, static_lifetime: bool) -> Result { + if static_lifetime { + let lib = CodeLibraryRef::try_load(ptr).map_err(|_| CUerror::INVALID_VALUE)?; + Ok(LibraryData::Lazy(lib)) + } else { + Ok(LibraryData::Eager(())) + } + } } impl ZludaObject for Library { @@ -14,34 +42,89 @@ impl ZludaObject for Library { type CudaHandle = CUlibrary; fn drop_checked(&mut self) -> CUresult { - // TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice. - unsafe { hipModuleUnload(self.base) }?; + // TODO: implement unloading + // TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice Ok(()) } } -/// This implementation simply loads the code as a HIP module for now. The various JIT and library options are ignored. -pub(crate) fn load_data( - library: &mut CUlibrary, +pub(crate) unsafe fn load_data( + result: &mut CUlibrary, code: *const ::core::ffi::c_void, - _jit_options: &mut CUjit_option, - _jit_options_values: &mut *mut ::core::ffi::c_void, + _jit_options: Option<&mut CUjit_option>, + _jit_options_values: Option<&mut *mut ::core::ffi::c_void>, _num_jit_options: ::core::ffi::c_uint, - _library_options: &mut CUlibraryOption, - _library_option_values: &mut *mut ::core::ffi::c_void, - _num_library_options: ::core::ffi::c_uint, + library_options: Option<&mut CUlibraryOption>, + library_option_values: Option<&mut *mut ::core::ffi::c_void>, + num_library_options: ::core::ffi::c_uint, ) -> CUresult { - let hip_module = module::load_hip_module(code)?; - *library = Library { base: hip_module }.wrap(); + let global_state = driver::global_state()?; + let options = + LibraryOptions::load(library_options, library_option_values, num_library_options)?; + let library = Library { + data: LibraryData::new(code as *mut c_void, options.preserve_binary)?, + modules: vec![OnceLock::new(); global_state.devices.len()], + }; + *result = library.wrap(); Ok(()) } +struct LibraryOptions { + preserve_binary: bool, +} + +impl LibraryOptions { + unsafe fn load( + library_options: Option<&mut CUlibraryOption>, + library_option_values: Option<&mut *mut ::core::ffi::c_void>, + num_library_options: ::core::ffi::c_uint, + ) -> Result { + if num_library_options == 0 { + return Ok(LibraryOptions { + preserve_binary: false, + }); + } + let (library_options, library_option_values) = + match (library_options, library_option_values) { + (Some(library_options), Some(library_option_values)) => { + let library_options = + std::slice::from_raw_parts(library_options, num_library_options as usize); + let library_option_values = std::slice::from_raw_parts( + library_option_values, + num_library_options as usize, + ); + (library_options, library_option_values) + } + _ => return Err(CUerror::INVALID_VALUE), + }; + let mut preserve_binary = false; + for (option, value) in library_options + .iter() + .copied() + .zip(library_option_values.iter()) + { + match option { + CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => { + preserve_binary = *(value.cast::()); + } + _ => return Err(CUerror::NOT_SUPPORTED), + } + } + Ok(LibraryOptions { preserve_binary }) + } +} + pub(crate) unsafe fn unload(library: CUlibrary) -> CUresult { zluda_common::drop_checked::(library) } pub(crate) unsafe fn get_module(out: &mut CUmodule, library: &Library) -> CUresult { - *out = module::Module { base: library.base }.wrap(); + let device = context::get_current_device()?; + // TODO: lifetime is very wrong here + let library = module::Module { + base: library.get_module_for_device(device as usize)?, + }; + *out = library.wrap(); Ok(()) } @@ -49,8 +132,11 @@ pub(crate) unsafe fn get_kernel( kernel: *mut hipFunction_t, library: &Library, name: *const ::core::ffi::c_char, -) -> hipError_t { - hipModuleGetFunction(kernel, library.base, name) +) -> CUresult { + let device = context::get_current_device()?; + let module = library.get_module_for_device(device as usize)?; + hipModuleGetFunction(kernel, module, name)?; + Ok(()) } pub(crate) unsafe fn get_global( @@ -58,6 +144,59 @@ pub(crate) unsafe fn get_global( bytes: *mut usize, library: &Library, name: *const ::core::ffi::c_char, -) -> hipError_t { - hipModuleGetGlobal(dptr, bytes, library.base, name) +) -> CUresult { + let device = context::get_current_device()?; + let module = library.get_module_for_device(device as usize)?; + hipModuleGetGlobal(dptr, bytes, module, name)?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use crate::tests::CudaApi; + use cuda_macros::test_cuda; + use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts}; + use std::{ + ffi::{c_void, CStr}, + mem, ptr, + }; + + #[test_cuda] + unsafe fn library_loads_without_context(api: impl CudaApi) { + static PTX: &'static CStr = c" + .version 7.0 + .target sm_70 + .address_size 64 + + .visible .entry foobar() { + ret; + } + "; + api.cuInit(0); + let mut device = mem::zeroed(); + assert_eq!( + CUresult::ERROR_INVALID_CONTEXT, + api.cuCtxGetDevice_unchecked(&mut device) + ); + let mut module = mem::zeroed(); + assert_eq!( + CUresult::ERROR_INVALID_CONTEXT, + api.cuModuleLoadData_unchecked(&mut module, PTX.as_ptr().cast()) + ); + let mut library = mem::zeroed(); + api.cuLibraryLoadData( + &mut library, + PTX.as_ptr().cast(), + ptr::null_mut(), + ptr::null_mut(), + 0, + [CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(), + [(&true as *const bool) as *mut c_void].as_mut_ptr(), + 1, + ); + assert_eq!( + CUresult::ERROR_INVALID_CONTEXT, + api.cuLibraryGetModule_unchecked(&mut mem::zeroed(), library) + ); + } } diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 32d3455..f73a972 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -15,13 +15,13 @@ pub(super) mod pointer; pub(super) mod stream; #[cfg(debug_assertions)] -pub(crate) fn unimplemented() -> CUresult { +pub(crate) fn unimplemented() -> CUerror { unimplemented!() } #[cfg(not(debug_assertions))] -pub(crate) fn unimplemented() -> CUresult { - CUresult::ERROR_NOT_SUPPORTED +pub(crate) fn unimplemented() -> CUerror { + CUerror::NOT_SUPPORTED } from_cuda_object!(module::Module, context::Context, library::Library); diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index f8db917..ae9dcf1 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -1,12 +1,8 @@ use super::driver; -use cuda_types::{ - cuda::*, - dark_api::{FatbinFileHeader, FatbincWrapper}, -}; -use dark_api::fatbin::Fatbin; +use cuda_types::{cuda::*, dark_api::FatbinFileHeader}; use hip_runtime_sys::*; -use std::{ffi::CStr, mem}; -use zluda_common::ZludaObject; +use std::{borrow::Cow, ffi::CStr, mem}; +use zluda_common::{CodeLibraryRef, CodeModuleRef, ZludaObject}; pub(crate) struct Module { pub(crate) base: hipModule_t, @@ -24,47 +20,45 @@ impl ZludaObject for Module { } } -fn get_ptx_from_wrapped_fatbin(image: *const ::core::ffi::c_void) -> Result, CUerror> { - let fatbin = Fatbin::new(&image).map_err(|_| CUerror::UNKNOWN)?; - let mut submodules = fatbin.get_submodules().map_err(|_| CUerror::UNKNOWN)?; - - while let Some(current) = submodules.next().map_err(|_| CUerror::UNKNOWN)? { - let mut files = current.get_files(); - while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } { - if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX { - let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?; - return Ok(decompressed); +// get_ptx takes an `image` that can be anything we support and returns a +// String containing a ptx extracted from `image`. +fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result, CUerror> { + let mut ptx_modules = Vec::new(); + unsafe { + CodeLibraryRef::iterate_modules(image, |_, module| match module { + Ok(CodeModuleRef::Text(ptx)) => { + ptx_modules.push(Cow::<'a, _>::Borrowed(ptx)); } - } - } - - Err(CUerror::NO_BINARY_FOR_GPU) -} - -/// get_ptx takes an `image` that can be either a fatbin or a NULL-terminated ptx, and returns a String containing a ptx extracted from `image`. -fn get_ptx(image: *const ::core::ffi::c_void) -> Result { - if image.is_null() { - return Err(CUerror::INVALID_VALUE); - } - - let ptx = if unsafe { *(image as *const u32) } == FatbincWrapper::MAGIC { - let ptx_bytes = get_ptx_from_wrapped_fatbin(image)?; - std::str::from_utf8(&ptx_bytes) - .map_err(|_| CUerror::UNKNOWN)? - .to_owned() - } else { - unsafe { CStr::from_ptr(image.cast()) } - .to_str() - .map_err(|_| CUerror::INVALID_VALUE)? - .to_owned() + Ok(CodeModuleRef::<'a>::File(file)) => { + if file.header.kind != FatbinFileHeader::HEADER_KIND_PTX { + return; + } + if let Ok(text) = file.get_or_decompress_content() { + if let Some(text) = cow_bytes_to_str(text) { + ptx_modules.push(text); + } + } + } + _ => {} + }) }; - - Ok(ptx) + // TODO: instead of getting first PTX module, try and get the best match + ptx_modules + .into_iter() + .next() + .ok_or(CUerror::NO_BINARY_FOR_GPU) } -pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result { +fn cow_bytes_to_str<'a>(data: Cow<'a, [u8]>) -> Option> { + match data { + Cow::Borrowed(bytes) => std::str::from_utf8(bytes).map(Cow::Borrowed).ok(), + Cow::Owned(bytes) => String::from_utf8(bytes).map(Cow::Owned).ok(), + } +} + +pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result { let global_state = driver::global_state()?; - let text = get_ptx(image)?; + let text = get_ptx(library)?; let hip_properties = get_hip_properties()?; let gcn_arch = get_gcn_arch(&hip_properties)?; let attributes = ptx::Attributes { @@ -139,7 +133,11 @@ fn compile_from_ptx_and_cache( text: &str, cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>, ) -> Result, CUerror> { - let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?; + let ast = if cfg!(debug_assertions) { + ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)? + } else { + ptx_parser::parse_module_unchecked(text) + }; let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?; let elf_module = comgr::compile_bitcode( comgr, @@ -158,7 +156,9 @@ fn compile_from_ptx_and_cache( } pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult { - let hip_module = load_hip_module(image)?; + let library = + unsafe { CodeLibraryRef::try_load(image) }.map_err(|_| CUerror::NO_BINARY_FOR_GPU)?; + let hip_module = load_hip_module(library)?; *module = Module { base: hip_module }.wrap(); Ok(()) } @@ -185,6 +185,6 @@ pub(crate) fn get_global_v2( } pub(crate) fn get_loading_mode(mode: &mut cuda_types::cuda::CUmoduleLoadingMode) -> CUresult { - *mode = cuda_types::cuda::CUmoduleLoadingMode::CU_MODULE_EAGER_LOADING; + *mode = cuda_types::cuda::CUmoduleLoadingMode::CU_MODULE_LAZY_LOADING; Ok(()) } diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index e9559e8..8d06f10 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -23,7 +23,7 @@ macro_rules! unimplemented { #[allow(improper_ctypes)] #[allow(improper_ctypes_definitions)] pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - crate::r#impl::unimplemented() + Err(r#impl::unimplemented()) } )* }; diff --git a/zluda_common/Cargo.toml b/zluda_common/Cargo.toml index 4c528e5..ca70ab8 100644 --- a/zluda_common/Cargo.toml +++ b/zluda_common/Cargo.toml @@ -8,3 +8,4 @@ edition = "2021" cuda_types = { path = "../cuda_types" } hip_runtime-sys = { path = "../ext/hip_runtime-sys" } rocblas-sys = { path = "../ext/rocblas-sys" } +dark_api = { path = "../dark_api" } diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs index 94c795b..c857f6b 100644 --- a/zluda_common/src/lib.rs +++ b/zluda_common/src/lib.rs @@ -1,10 +1,18 @@ -use cuda_types::{cublas::*, cublaslt::cublasLtHandle_t, cuda::*, nvml::*}; +use cuda_types::{ + cublas::*, + cublaslt::cublasLtHandle_t, + cuda::*, + dark_api::{FatbinHeader, FatbincWrapper}, + nvml::*, +}; +use dark_api::fatbin::{Fatbin, FatbinError, FatbinFile, FatbinSubmodule}; use hip_runtime_sys::*; use rocblas_sys::*; use std::{ - ffi::CStr, + ffi::{c_void, CStr}, mem::{self, ManuallyDrop, MaybeUninit}, ptr, + str::Utf8Error, }; pub trait CudaErrorType { @@ -412,3 +420,156 @@ pub fn drop_checked(handle: T::CudaHandle) -> Result<(), T::Erro unsafe { ManuallyDrop::drop(&mut wrapped_object) }; underlying_error } + +/* +pub struct CodeModuleRef<'a> { + pub kind: CodeModuleKind, + pub data: &'a [u8], +} + +impl<'a> CodeModule<'a> { + /// Interprets `data` as a code module of some kind. + /// + /// This does not validate the contents of `data`, it only looks at the headers to determine + /// what kind of data it is. + pub fn parse(data: *mut c_void) -> Result { + if data.len() >= 4 { + let kind = match &data[0..4] { + FatbincWrapper::MAGIC => CodeModuleKind::FatbincWrapper, + FatbinHeader::MAGIC => CodeModuleKind::FatbinHeader, + elf64::header::ELFMAG => CodeModuleKind::Elf, + _ => { + if data.ends_with(&[0]) && data.iter().all(|&c| c != 0) { + CodeModuleKind::Ptx + } else { + CodeModuleKind::ForeignElf + } + } + }; + Ok(CodeModule { kind, data }) + } else { + Err(CUerror::INVALID_VALUE) + } + } +} + */ + +// We receive module as an opaque pointer. We want to handle three different +// lifetime-related scenarios: +// * The module has a 'static lifetime, but we don't want to use it just yet +// (`cuLibraryLoadData` with CU_LIBRARY_BINARY_IS_PRESERVED = 1), we might +// never use it. +// In this case we just keep the void pointer, we can pass it to +// the consuming function later +// * The module has a non-'static lifetime, and we will use it in the future +// (`cuLibraryLoadData` with CU_LIBRARY_BINARY_IS_PRESERVED = 0) +// In this case we need to copy the data into its own buffers +// * The module lifetime is scoped to the current function. E.g. zluda_trace +// might to parse a module to inspect and save it or it's cuModuleLoadData +// In this case we need to return either the compatible ELF or the +// iterator over `Cow` with decompressed PTX strings +// Even here there are two cases: +// * The consumer is cuModuleLoadData, if it's our ELF then it wants +// to load it directly from the pointer +// * The consumer is zluda_trace, it wants to compute the length of +// the ELF and save it to a file +#[derive(Clone, Copy)] +pub enum CodeLibraryRef<'a> { + FatbincWrapper(Fatbin<'a>), + FatbinHeader(FatbinSubmodule<'a>), + Text(&'a str), + Elf(&'a c_void), + Archive(&'a c_void), +} + +impl<'a> CodeLibraryRef<'a> { + const ELFMAG: [u8; 4] = *b"\x7FELF"; + const AR_MAGIC: [u8; 8] = *b"!\x0A"; + + pub unsafe fn try_load(ptr: *const c_void) -> Result { + Ok(match *ptr.cast::<[u8; 4]>() { + FatbincWrapper::MAGIC => Self::FatbincWrapper(Fatbin { + wrapper: &*(ptr.cast()), + }), + FatbinHeader::MAGIC => Self::FatbinHeader(FatbinSubmodule { + header: &*(ptr.cast()), + }), + Self::ELFMAG => Self::Elf(&*ptr), + _ => match *ptr.cast::<[u8; 8]>() { + Self::AR_MAGIC => Self::Archive(&*ptr), + _ => CStr::from_ptr(ptr.cast()).to_str().map(Self::Text)?, + }, + }) + } + + pub unsafe fn iterate_modules( + self, + mut fn_: impl FnMut(Option<(usize, Option)>, Result, FatbinError>), + ) { + match self { + CodeLibraryRef::FatbincWrapper(fatbin) => { + let module_iter = fatbin.get_submodules(); + match module_iter { + Ok(mut iter) => { + let mut module_index = if iter.multi_module() { + None + } else { + Some(0usize) + }; + while let Some(maybe_submodule) = iter.next() { + match maybe_submodule { + Ok(submodule) => iterate_modules_fatbin_header( + |subindex, module| { + let index = match module_index { + Some(index) => (index, Some(subindex)), + None => (subindex, None), + }; + fn_(Some(index), module) + }, + submodule, + ), + Err(err) => fn_( + module_index.map(|module_index| (module_index, None)), + Err(FatbinError::ParseFailure(err)), + ), + } + module_index = module_index.map(|index| index + 1); + } + } + Err(err) => fn_(None, Err(err)), + } + } + CodeLibraryRef::FatbinHeader(submodule) => iterate_modules_fatbin_header( + |index, module| fn_(Some((index, None)), module), + submodule, + ), + CodeLibraryRef::Text(text) => fn_(None, Ok(CodeModuleRef::Text(text))), + CodeLibraryRef::Elf(elf) => fn_(None, Ok(CodeModuleRef::Elf(elf))), + CodeLibraryRef::Archive(ar) => fn_(None, Ok(CodeModuleRef::Archive(ar))), + } + } +} + +unsafe fn iterate_modules_fatbin_header<'x>( + mut fn_: impl FnMut(usize, Result, FatbinError>), + submodule: FatbinSubmodule<'x>, +) { + let mut iter = submodule.get_files(); + let mut index = 0; + while let Some(file) = iter.next() { + fn_( + index, + file.map(CodeModuleRef::File) + .map_err(FatbinError::ParseFailure), + ); + index += 1; + } +} + +#[derive(Clone, Copy)] +pub enum CodeModuleRef<'a> { + File(FatbinFile<'a>), + Text(&'a str), + Elf(*const c_void), + Archive(*const c_void), +} diff --git a/zluda_trace/Cargo.toml b/zluda_trace/Cargo.toml index e709f17..a6c4120 100644 --- a/zluda_trace/Cargo.toml +++ b/zluda_trace/Cargo.toml @@ -12,6 +12,7 @@ crate-type = ["cdylib"] ptx = { path = "../ptx" } ptx_parser = { path = "../ptx_parser" } zluda_trace_common = { path = "../zluda_trace_common" } +zluda_common = { path = "../zluda_common" } format = { path = "../format" } dark_api = { path = "../dark_api" } regex = "1.4" diff --git a/zluda_trace/src/lib.rs b/zluda_trace/src/lib.rs index 50603f1..f61fed6 100644 --- a/zluda_trace/src/lib.rs +++ b/zluda_trace/src/lib.rs @@ -1,7 +1,5 @@ -use ::dark_api::fatbin::FatbinFileIterator; use ::dark_api::FnFfi; use cuda_types::cuda::*; -use cuda_types::dark_api::FatbinHeader; use dark_api::DarkApiState2; use log::{CudaFunctionName, ErrorEntry}; use parking_lot::ReentrantMutex; @@ -289,9 +287,7 @@ impl DarkApiTrace { fn_logger: &mut FnCallLog, _result: CUresult, ) { - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules_from_wrapped_fatbin(*module, fatbinc_wrapper, fn_logger, state) - }); + state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger) } fn get_module_from_cubin_ext1_post( @@ -325,9 +321,7 @@ impl DarkApiTrace { observed: UInt::U32(arg5), }); } - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules_from_wrapped_fatbin(*module, fatbinc_wrapper, fn_logger, state) - }); + state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger) } fn get_module_from_cubin_ext2_post( @@ -361,18 +355,7 @@ impl DarkApiTrace { observed: UInt::U32(arg5), }); } - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules( - *module, - fn_logger, - state, - FatbinFileIterator::new( - fatbin_header - .as_ref() - .ok_or(ErrorEntry::NullPointer("FatbinHeader"))?, - ), - ) - }); + state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger) } } @@ -1324,7 +1307,7 @@ pub(crate) fn cuModuleLoadData_Post( fn_logger: &mut FnCallLog, _result: CUresult, ) { - state.record_new_module(unsafe { *module }, raw_image, fn_logger) + state.record_new_library(unsafe { *module }, raw_image, fn_logger) } #[allow(non_snake_case)] @@ -1402,19 +1385,7 @@ pub(crate) fn cuModuleLoadFatBinary_Post( fn_logger: &mut FnCallLog, _result: CUresult, ) { - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules( - *module, - fn_logger, - state, - FatbinFileIterator::new( - fatbin_header - .cast::() - .as_ref() - .ok_or(ErrorEntry::NullPointer("FatbinHeader"))?, - ), - ) - }); + state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger) } #[allow(non_snake_case)] @@ -1427,16 +1398,7 @@ pub(crate) fn cuLibraryGetModule_Post( ) { match state.libraries.get(&library).copied() { None => fn_logger.log(log::ErrorEntry::UnknownLibrary(library)), - Some(code) => { - fn_logger.try_(|fn_logger| unsafe { - trace::record_submodules_from_wrapped_fatbin( - *module, - code.0.cast(), - fn_logger, - state, - ) - }); - } + Some(code) => state.record_new_library(unsafe { *module }, code.0, fn_logger), } } @@ -1459,5 +1421,5 @@ pub(crate) fn cuLibraryLoadData_Post( .insert(unsafe { *library }, trace::CodePointer(code)); // TODO: this is not correct, but it's enough for now, we just want to // save the binary to disk - state.record_new_module(unsafe { CUmodule((*library).0.cast()) }, code, fn_logger); + state.record_new_library(unsafe { CUmodule((*library).0.cast()) }, code, fn_logger); } diff --git a/zluda_trace/src/trace.rs b/zluda_trace/src/trace.rs index 1242df6..e71aacd 100644 --- a/zluda_trace/src/trace.rs +++ b/zluda_trace/src/trace.rs @@ -1,18 +1,11 @@ use crate::{ - log::{self, UInt}, - trace, ErrorEntry, FnCallLog, Settings, -}; -use cuda_types::{ - cuda::*, - dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper}, -}; -use dark_api::fatbin::{ - decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule, + log::{self}, + ErrorEntry, FnCallLog, Settings, }; +use cuda_types::cuda::*; use goblin::{elf, elf32, elf64}; use rustc_hash::{FxHashMap, FxHashSet}; use std::{ - borrow::Cow, ffi::{c_void, CStr, CString}, fs::{self, File}, io::{self, Read, Write}, @@ -29,8 +22,7 @@ pub(crate) struct StateTracker { writer: DumpWriter, pub(crate) libraries: FxHashMap, saved_modules: FxHashSet, - module_counter: usize, - submodule_counter: usize, + library_counter: usize, pub(crate) override_cc: Option<(u32, u32)>, } @@ -46,8 +38,7 @@ impl StateTracker { writer: DumpWriter::new(settings.dump_dir.clone()), libraries: FxHashMap::default(), saved_modules: FxHashSet::default(), - module_counter: 0, - submodule_counter: 0, + library_counter: 0, override_cc: settings.override_cc, } } @@ -78,25 +69,68 @@ impl StateTracker { let mut module_file = fs::File::open(file_name)?; let mut read_buff = Vec::new(); module_file.read_to_end(&mut read_buff)?; - self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger); + self.record_new_library(module, read_buff.as_ptr() as *const _, fn_logger); Ok(()) } + pub(crate) fn record_new_library( + &mut self, + cu_module: CUmodule, + raw_image: *const c_void, + fn_logger: &mut FnCallLog, + ) { + self.saved_modules.insert(cu_module); + self.library_counter += 1; + let code_ref = fn_logger.try_return(|| { + unsafe { zluda_common::CodeLibraryRef::try_load(raw_image) } + .map_err(ErrorEntry::NonUtf8ModuleText) + }); + let code_ref = unwrap_some_or!(code_ref, return); + unsafe { + code_ref.iterate_modules(|index, module| match module { + Err(err) => fn_logger.log(ErrorEntry::from(err)), + Ok(zluda_common::CodeModuleRef::Elf(elf)) => match get_elf_size(elf) { + Some(len) => { + let elf_image = std::slice::from_raw_parts(elf.cast::(), len); + self.record_new_submodule(index, elf_image, fn_logger, "elf"); + } + None => fn_logger.log(log::ErrorEntry::UnsupportedModule { + module: cu_module, + raw_image: elf, + kind: "ELF", + }), + }, + Ok(zluda_common::CodeModuleRef::Archive(archive)) => { + fn_logger.log(log::ErrorEntry::UnsupportedModule { + module: cu_module, + raw_image: archive, + kind: "archive", + }) + } + Ok(zluda_common::CodeModuleRef::File(file)) => { + if let Some(buffer) = fn_logger + .try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from)) + { + self.record_new_submodule(index, &*buffer, fn_logger, file.kind()); + } + } + Ok(zluda_common::CodeModuleRef::Text(ptx)) => { + self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx"); + } + }); + }; + } + pub(crate) fn record_new_submodule( &mut self, - module: CUmodule, + index: Option<(usize, Option)>, submodule: &[u8], fn_logger: &mut FnCallLog, type_: &'static str, ) { - if self.saved_modules.insert(module) { - self.module_counter += 1; - self.submodule_counter = 0; - } - self.submodule_counter += 1; fn_logger.log_io_error(self.writer.save_module( - self.module_counter, - Some(self.submodule_counter), + self.library_counter, + index, submodule, type_, )); @@ -107,8 +141,8 @@ impl StateTracker { Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)), Ok(submodule_text) => self.try_parse_and_record_kernels( fn_logger, - self.module_counter, - Some(self.submodule_counter), + self.library_counter, + index, submodule_text, ), }, @@ -116,80 +150,11 @@ impl StateTracker { } } - pub(crate) fn record_new_module( - &mut self, - module: CUmodule, - raw_image: *const c_void, - fn_logger: &mut FnCallLog, - ) { - self.module_counter += 1; - if unsafe { *(raw_image as *const [u8; 4]) } == *elf64::header::ELFMAG { - self.saved_modules.insert(module); - match unsafe { get_elf_size(raw_image) } { - Some(len) => { - let elf_image = - unsafe { std::slice::from_raw_parts(raw_image.cast::(), len) }; - self.record_new_submodule(module, elf_image, fn_logger, "elf"); - } - None => fn_logger.log(log::ErrorEntry::UnsupportedModule { - module, - raw_image, - kind: "ELF", - }), - } - } else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC { - self.saved_modules.insert(module); - // TODO: Figure out how to get size of archive module and write it to disk - fn_logger.log(log::ErrorEntry::UnsupportedModule { - module, - raw_image, - kind: "archive", - }) - } else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC { - unsafe { - fn_logger.try_(|fn_logger| { - trace::record_submodules_from_wrapped_fatbin( - module, - raw_image as *const FatbincWrapper, - fn_logger, - self, - ) - }); - } - } else { - self.record_module_ptx(module, raw_image, fn_logger) - } - } - - fn record_module_ptx( - &mut self, - module: CUmodule, - raw_image: *const c_void, - fn_logger: &mut FnCallLog, - ) { - self.saved_modules.insert(module); - let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str(); - let module_text = match module_text { - Ok(m) => m, - Err(utf8_err) => { - fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err)); - return; - } - }; - fn_logger.log_io_error(self.writer.save_module( - self.module_counter, - None, - module_text.as_bytes(), - "ptx", - )); - self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text); - } - fn try_parse_and_record_kernels( &mut self, fn_logger: &mut FnCallLog, module_index: usize, - submodule_index: Option, + submodule_index: Option<(usize, Option)>, module_text: &str, ) { let errors = ptx_parser::parse_for_errors(module_text); @@ -359,7 +324,7 @@ impl DumpWriter { fn save_module( &self, module_index: usize, - submodule_index: Option, + submodule_index: Option<(usize, Option)>, buffer: &[u8], kind: &'static str, ) -> io::Result<()> { @@ -368,7 +333,7 @@ impl DumpWriter { Some(d) => d.clone(), }; dump_file.push(Self::get_file_name(module_index, submodule_index, kind)); - let mut file = File::create(dump_file)?; + let mut file = File::create_new(dump_file)?; file.write_all(buffer)?; Ok(()) } @@ -376,7 +341,7 @@ impl DumpWriter { fn save_module_error_log<'input>( &self, module_index: usize, - submodule_index: Option, + submodule_index: Option<(usize, Option)>, errors: &[ptx_parser::PtxError<'input>], ) -> io::Result<()> { let mut log_file = match &self.dump_dir { @@ -391,92 +356,27 @@ impl DumpWriter { Ok(()) } - fn get_file_name(module_index: usize, submodule_index: Option, kind: &str) -> String { + fn get_file_name( + module_index: usize, + submodule_index: Option<(usize, Option)>, + kind: &str, + ) -> String { match submodule_index { None => { format!("module_{:04}.{:02}", module_index, kind) } - Some(submodule_index) => { - format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind) + Some((sub_index, None)) => { + format!("module_{:04}_{:02}.{}", module_index, sub_index + 1, kind) + } + Some((sub_index, Some(subsub_index))) => { + format!( + "module_{:04}_{:02}_{:02}.{}", + module_index, + sub_index + 1, + subsub_index + 1, + kind + ) } } } } - -pub(crate) unsafe fn record_submodules_from_wrapped_fatbin( - module: CUmodule, - fatbinc_wrapper: *const FatbincWrapper, - fn_logger: &mut FnCallLog, - state: &mut StateTracker, -) -> Result<(), ErrorEntry> { - let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?; - let mut submodules = fatbin.get_submodules()?; - while let Some(current) = submodules.next()? { - record_submodules_from_fatbin(module, current, fn_logger, state)?; - } - Ok(()) -} - -pub(crate) unsafe fn record_submodules_from_fatbin( - module: CUmodule, - submodule: FatbinSubmodule, - logger: &mut FnCallLog, - state: &mut StateTracker, -) -> Result<(), ErrorEntry> { - record_submodules(module, logger, state, submodule.get_files())?; - Ok(()) -} - -pub(crate) unsafe fn record_submodules( - module: CUmodule, - fn_logger: &mut FnCallLog, - state: &mut StateTracker, - mut files: FatbinFileIterator, -) -> Result<(), ErrorEntry> { - while let Some(file) = files.next()? { - let mut payload = if file - .header - .flags - .contains(FatbinFileHeaderFlags::CompressedLz4) - { - Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())), - continue - )) - } else if file - .header - .flags - .contains(FatbinFileHeaderFlags::CompressedZstd) - { - Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())), - continue - )) - } else { - Cow::Borrowed(file.get_payload()) - }; - match file.header.kind { - FatbinFileHeader::HEADER_KIND_PTX => { - while payload.last() == Some(&0) { - // remove trailing zeros - payload.to_mut().pop(); - } - state.record_new_submodule(module, &*payload, fn_logger, "ptx") - } - FatbinFileHeader::HEADER_KIND_ELF => { - state.record_new_submodule(module, &*payload, fn_logger, "elf") - } - _ => { - fn_logger.log(log::ErrorEntry::UnexpectedBinaryField { - field_name: "FATBIN_FILE_HEADER_KIND", - expected: vec![ - UInt::U16(FatbinFileHeader::HEADER_KIND_PTX), - UInt::U16(FatbinFileHeader::HEADER_KIND_ELF), - ], - observed: UInt::U16(file.header.kind), - }); - } - } - } - Ok(()) -}