diff --git a/.github/ISSUE_TEMPLATE/zluda_dump.yml b/.github/ISSUE_TEMPLATE/zluda_dump.yml index ee2738a..a199cf4 100644 --- a/.github/ISSUE_TEMPLATE/zluda_dump.yml +++ b/.github/ISSUE_TEMPLATE/zluda_dump.yml @@ -45,7 +45,7 @@ body: ./train_gpt2fp32cu 4. Build and run the tests: make test_gpt2fp32cu - LD_LIBRARY_PATH= ./test_gpt2fp32cu + LD_LIBRARY_PATH= ./test_gpt2fp32cu validations: required: true - type: input diff --git a/Cargo.lock b/Cargo.lock index baddd3a..cfe4cff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2573,7 +2573,6 @@ dependencies = [ "ptx_parser", "quick-error", "rustc-hash 2.0.0", - "serde", "smallvec", "strum 0.26.3", "strum_macros 0.26.4", diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index 9c5671b..8546203 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -219,6 +219,10 @@ pub fn compile_bitcode( compile_to_exec.set_isa_name(gcn_arch)?; compile_to_exec.set_language(Language::LlvmIr)?; let common_options = [ + c"-mllvm", + c"-ignore-tti-inline-compatible", + // c"-mllvm", + // c"-amdgpu-early-inline-all=true", // This makes no sense, but it makes ockl linking work c"-Xclang", c"-mno-link-builtin-bitcode-postopt", @@ -237,8 +241,7 @@ pub fn compile_bitcode( ] .into_iter(); let opt_options = if cfg!(debug_assertions) { - //[c"-g", c"-mllvm", c"-print-before-all", c"", c""] - [c"-g", c"", c"", c"", c""] + [c"-g", c"-mamdgpu-precise-memory-op", c"", c"", c""] } else { [ c"-g0", diff --git a/compiler/src/main.rs b/compiler/src/main.rs index 9d1a5d1..a58ad98 100644 --- a/compiler/src/main.rs +++ b/compiler/src/main.rs @@ -21,9 +21,14 @@ pub struct Options { output_dir: Option, #[bpaf(long("arch"))] - /// Target architecture + /// Target GPU architecture arch: Option, + #[bpaf(long("ignore-errors"))] + /// Try to ignore errors. This will try and produce output even if there are + /// parsing errors (e.g. an unimplemented instruction) + ignore_errors: bool, + #[bpaf(positional("filename"))] /// PTX file ptx_path: String, @@ -48,7 +53,10 @@ fn main_core() -> Result<(), CompilerError> { .unwrap_or("output"); let mut output_path = match opts.output_dir { - Some(value) => value, + Some(value) => { + std::fs::create_dir_all(&value)?; + value + } None => match ptx_path.parent() { Some(dir) => dir.to_path_buf(), None => env::current_dir()?, @@ -68,7 +76,7 @@ fn main_core() -> Result<(), CompilerError> { let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?; - let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?; + let llvm = ptx_to_llvm(opts.ignore_errors, ptx).map_err(CompilerError::from)?; write_to_file(&llvm.llvm_ir, output_path.with_extension("ll").as_path())?; @@ -92,8 +100,12 @@ fn main_core() -> Result<(), CompilerError> { Ok(()) } -fn ptx_to_llvm(ptx: &str) -> Result { - let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; +fn ptx_to_llvm(ignore_errors: bool, ptx: &str) -> Result { + let ast = if ignore_errors { + ptx_parser::parse_module_unchecked(ptx) + } else { + ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)? + }; let mut start = Instant::now(); let module = ptx::to_llvm_module( ast, diff --git a/docs/src/troubleshooting.md b/docs/src/troubleshooting.md index ce1189b..cc75399 100644 --- a/docs/src/troubleshooting.md +++ b/docs/src/troubleshooting.md @@ -116,7 +116,7 @@ in order to demonstrate all of zluda_trace's features. ```bash nvcc add.cu -o add -arch sm_80 -LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_TRACE_DIR=/tmp/zluda ./add +LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_LOG_DIR=/tmp/zluda ./add ``` The last few lines should look something like: diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index c9a5a6b..7ee6e43 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -22,7 +22,6 @@ microlp = "0.2.11" int-enum = "1.1" unwrap_or = "1.0.1" smallvec = "1.15.1" -serde = { version = "1.0.219", features = ["derive"] } [dev-dependencies] hip_runtime-sys = { path = "../ext/hip_runtime-sys" } diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 2b62aeb..bc375c3 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index cc1d973..6174ec1 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -1,17 +1,21 @@ // Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17` // `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining -// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc +// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc #include #include #include #include +#include #include #include #define SHARED_SPACE __attribute__((address_space(3))) #define CONSTANT_SPACE __attribute__((address_space(4))) +typedef _Float16 half16 __attribute__((ext_vector_type(16))); +typedef float float8 __attribute__((ext_vector_type(8))); + #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME #define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME #define DECLARE_ATTR(TYPE, NAME) \ @@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); uint32_t x3 = load_single_matrix_trans(address, 24); return uint4::Native_vec_{x0, x1, x2, x3}; } + + static inline __device__ _Float16 top16_as_fp16(uint32_t value) { + uint16_t half_bits = static_cast((value >> 16) & 0xFFFF); + return *reinterpret_cast<_Float16*>(&half_bits); + } + static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) { + uint16_t half_bits = static_cast(value & 0xFFFF); + return *reinterpret_cast<_Float16*>(&half_bits); + } + + static inline __device__ float bpermute_lane(int lane, float x) { + return __hip_ds_bpermutef(4 * lane, x); + } + static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) { + return __hip_ds_bpermute(4 * lane, x); + } + + static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) { + const unsigned lIdx = threadIdx.x; + const int lane = lIdx % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode) + half16 aFrag; + + for (int vGPR = 0; vGPR < 8; ++vGPR) { + int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2 + int cudaTID = (vGPR % 4 + lane * 4) % 32; + uint32_t reg0, reg1; + // Select the two consecutive elements from a_reg: + if (cudaChunk == 0) { + reg0 = a_reg.x; + reg1 = a_reg.y; + } else { // cudaChunk==2 + reg0 = a_reg.z; + reg1 = a_reg.w; + } + uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0); + uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1); + uint32_t a_Frag_reg = (lane < 8) ? a_tmp0 : a_tmp1; + aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg); + aFrag[2 * vGPR + 1] = top16_as_fp16(a_Frag_reg); + } + return aFrag; + } + + static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) { + const unsigned lIdx = threadIdx.x; + const int lane = lIdx % 16; + half16 bFrag; + + for (int vGPR = 0; vGPR < 8; ++vGPR) { + int cudaChunk = vGPR / 4; // will be 0 or 1 + int cudaTID = vGPR % 4 + (lane * 4) % 64; + uint32_t reg = (cudaChunk == 0) ? b_reg.x : b_reg.y; + uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg); + if (lane < 8) { + bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg); + bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg); + } else { + bFrag[2 * vGPR] = 0.0f; + bFrag[2 * vGPR + 1] = 0.0f; + } + } + return bFrag; + } + + static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) { + const int lIdx = (int)threadIdx.x; + float8 cFrag; + + // Loop over the eight vector GPRs. + for (int vGPR = 0; vGPR < 8; ++vGPR) { + int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use. + int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8; + int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2; + float ctmp0, ctmp1; + + if (cudaChunk == 0) { + ctmp0 = bpermute_lane(cudaTID, c_reg.x); + ctmp1 = bpermute_lane(cudaTID, c_reg.y); + } else { // cudaChunk == 2 + ctmp0 = bpermute_lane(cudaTID, c_reg.z); + ctmp1 = bpermute_lane(cudaTID, c_reg.w); + } + + // Select one of the two values based on the thread index's LSB. + cFrag[vGPR] = (lIdx & 1) ? ctmp1 : ctmp0; + + // Zero out for specific thread indices. + if ((lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32)) + cFrag[vGPR] = 0.0f; + } + return cFrag; + } + + static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) { + const int lIdx = (int)threadIdx.x; + float4::Native_vec_ d_out; + + for (int cChunk = 0; cChunk < 4; ++cChunk) { + int r_vGPR = (cChunk / 2) * 4; + int add8 = (lIdx & 0x4) ? 8 : 0; + int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8; + float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]); + float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]); + float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]); + float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]); + float val; + if (lIdx < 8) { + val = d_tmp0; + } else if (lIdx < 16) { + val = d_tmp1; + } else if (lIdx < 24) { + val = d_tmp2; + } else { + val = d_tmp3; + } + if (cChunk == 0) d_out.x = val; + else if (cChunk == 1) d_out.y = val; + else if (cChunk == 2) d_out.z = val; + else d_out.w = val; + } + return d_out; + } + + float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) { + // Reshuffle from Nvidia-like register layout to AMD layout: + half16 aFrag = shuffle_a(a_reg); + half16 bFrag = shuffle_b(b_reg); + float8 cFrag = shuffle_c(c_reg); + + // Call the (built‐in) 16x16 MMA instruction. It returns a float8. + float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aFrag, bFrag, cFrag); + + // Unshuffle back into Nvidia expected float4 result + float4::Native_vec_ d_out = shuffle_d(dFrag); + + return d_out; + } + + float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) { + // Reshuffle from Nvidia-like register layout to AMD layout: + half16 aFrag = shuffle_a(a_reg); + half16 bFrag = shuffle_b(b_reg); + float8 cFrag = shuffle_c(c_reg); + + // Call the (built‐in) 16x16 MMA instruction. It returns a float8. + float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aFrag, bFrag, cFrag); + + // Unshuffle back into Nvidia expected float4 result + float4::Native_vec_ d_out = shuffle_d(dFrag); + + return d_out; + } } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index 904bf37..525ae15 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -197,7 +197,9 @@ fn run_instruction<'input>( | ast::Instruction::Xor { .. } | ast::Instruction::Vote { .. } | ast::Instruction::ReduxSync { .. } - | ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)), + | ast::Instruction::GridDepControl { .. } + | ast::Instruction::LdMatrix { .. } + | ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)), ast::Instruction::Add { data: ast::ArithDetails::Float(ast::ArithFloat { diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index a4c2dc4..d365e29 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1855,7 +1855,9 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::AtomCas { .. } | ast::Instruction::Vote { .. } | ast::Instruction::ReduxSync { .. } - | ast::Instruction::LdMatrix { .. } => InstructionModes::none(), + | ast::Instruction::GridDepControl { .. } + | ast::Instruction::LdMatrix { .. } + | ast::Instruction::Mma { .. } => InstructionModes::none(), ast::Instruction::Add { data: ast::ArithDetails::Integer(_), .. diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index c811a53..144f5e6 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -153,6 +153,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { for (i, param) in method.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; let name = self.resolver.get_or_add(param.name); + if let Some(align) = param.align { + unsafe { LLVMSetParamAlignment(value, align) }; + } unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; self.resolver.register(param.name, value); if method.is_kernel { @@ -519,6 +522,7 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop + ast::Instruction::GridDepControl { .. } => Ok(()), // nop // replaced by a function call ast::Instruction::Bfe { .. } | ast::Instruction::Bar { .. } @@ -529,7 +533,8 @@ impl<'a> MethodEmitContext<'a> { | ast::Instruction::Vote { .. } | ast::Instruction::Nanosleep { .. } | ast::Instruction::ReduxSync { .. } - | ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()), + | ast::Instruction::LdMatrix { .. } + | ast::Instruction::Mma { .. } => return Err(error_unreachable()), } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 0b9ef79..4f87dc3 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -51,7 +51,7 @@ quick_error! { } /// GPU attributes needed at compile time. -#[derive(serde::Serialize)] +#[derive(Copy, Clone)] pub struct Attributes { /// Clock frequency in kHz. pub clock_rate: u32, diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index 19e16e7..f7c976e 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -351,6 +351,35 @@ fn run_instruction<'input>( let name = "sqrt_rn_ftz_f32"; to_call(resolver, fn_declarations, name.into(), i)? } + i @ ptx_parser::Instruction::Mma { + data: + ast::MmaDetails { + alayout, + blayout, + dtype_scalar, + atype_scalar, + btype_scalar, + ctype_scalar, + }, + .. + } => { + let name = format!( + "mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}", + match alayout { + ast::MatrixLayout::Row => "row", + ast::MatrixLayout::Col => "col", + }, + match blayout { + ast::MatrixLayout::Row => "row", + ast::MatrixLayout::Col => "col", + }, + scalar_to_ptx_name(dtype_scalar), + scalar_to_ptx_name(atype_scalar), + scalar_to_ptx_name(btype_scalar), + scalar_to_ptx_name(ctype_scalar), + ); + to_call(resolver, fn_declarations, name.into(), i)? + } i @ ptx_parser::Instruction::Sqrt { data: ast::RcpData { diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 9fecba3..1bc622c 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -3,8 +3,8 @@ use super::{ StateSpace, VectorPrefix, }; use crate::{ - FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction, - ShiftDirection, ShuffleMode, VoteMode, + FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError, + PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode, }; use bitflags::bitflags; use derive_more::Display; @@ -721,6 +721,30 @@ ptx_parser_macros::generate_instruction_type!( space: { data.state_space }, } } + }, + GridDepControl { + data: crate::GridDepControlAction, + }, + Mma { + data: MmaDetails, + arguments: { + dst: { + repr: T, + type: { data.dtype() }, + }, + src1: { + repr: T, + type: { data.atype() }, + }, + src2: { + repr: T, + type: { data.btype() }, + }, + src3: { + repr: T, + type: { data.ctype() }, + } + } } } ); @@ -2378,3 +2402,27 @@ pub struct ReduxSyncData { pub type_: ScalarType, pub reduction: Reduction, } + +pub struct MmaDetails { + pub alayout: MatrixLayout, + pub blayout: MatrixLayout, + pub dtype_scalar: ScalarType, + pub atype_scalar: ScalarType, + pub btype_scalar: ScalarType, + pub ctype_scalar: ScalarType, +} + +impl MmaDetails { + pub fn dtype(&self) -> Type { + Type::Vector(4, ScalarType::F32) + } + pub fn atype(&self) -> Type { + Type::Vector(4, ScalarType::U32) + } + pub fn btype(&self) -> Type { + Type::Vector(2, ScalarType::U32) + } + pub fn ctype(&self) -> Type { + Type::Vector(4, ScalarType::F32) + } +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 4253ae6..a4f9080 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1862,6 +1862,9 @@ derive_parser!( #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum MatrixNumber { } + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] + pub enum MatrixLayout { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -3897,6 +3900,37 @@ derive_parser!( .type: ScalarType = {.b16, .b8}; // .dst_fmt = { .b8x16 }; // .src_fmt = { .b6x16_p32, .b4x16_p64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol + griddepcontrol.action => { + Instruction::GridDepControl { + data: action + } + } + .action: GridDepControlAction = { .launch_dependents, .wait }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma + mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => { + if dtype != ScalarType::F32 || ctype != ScalarType::F32 { + state.errors.push(PtxError::Todo); + } + Instruction::Mma { + data: MmaDetails { + alayout, + blayout, + dtype_scalar: dtype, + atype_scalar: ScalarType::BF16, + btype_scalar: ScalarType::BF16, + ctype_scalar: ctype, + }, + arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + + .alayout: MatrixLayout = {.row}; + .blayout: MatrixLayout = {.col}; + .ctype: ScalarType = {.f16, .f32}; + .dtype: ScalarType = {.f16, .f32}; ); #[cfg(test)] diff --git a/ptxas/src/main.rs b/ptxas/src/main.rs index 0ffe841..f8a52aa 100644 --- a/ptxas/src/main.rs +++ b/ptxas/src/main.rs @@ -1,6 +1,4 @@ -use bpaf::{any, doc::Style, Bpaf, Parser}; -use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600}; -use std::{ffi::CStr, mem}; +use bpaf::{any, choice, doc::Style, literal, Bpaf, Parser}; #[derive(Debug, Clone, Bpaf)] #[allow(dead_code)] @@ -12,6 +10,8 @@ pub struct Options { #[bpaf(short, long)] verbose: bool, #[bpaf(external)] + lineinfo: bool, + #[bpaf(external)] gpu_name: String, #[bpaf(long, short('O'), fallback(3))] opt_level: usize, @@ -19,48 +19,32 @@ pub struct Options { input: String, } +fn lineinfo() -> impl Parser { + choice(["-lineinfo", "--lineinfo"].into_iter().map(|s| { + literal(s) + .anywhere() + .optional() + .map(|_| true) + .fallback(false) + .boxed() + })) +} + // #[bpaf(long, long("gpu_name"), fallback_with(default_arch))] fn gpu_name() -> impl Parser { any("", move |s: String| { - Some(s.strip_prefix("-arch=")?.to_owned()) + Some( + s.strip_prefix("-arch=") + .or_else(|| s.strip_prefix("--gpu-name="))? + .to_owned(), + ) }) - .metavar(&[("-arch=", Style::Literal), ("ARG", Style::Metavar)]) + .metavar(&[("--gpu-name=", Style::Literal), ("SM", Style::Metavar)]) .anywhere() .fallback_with(|| Ok::("sm_52".to_string())) } fn main() { let options = options().run(); - let comgr = comgr::Comgr::new().unwrap(); - unsafe { hip_runtime_sys::hipInit(0) }.unwrap(); - let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() }; - let (gpu_arch, clock_rate) = get_gpu_arch_and_clock_rate(&mut dev_props); - let input = std::fs::read_to_string(options.input).unwrap(); - let ast = ptx_parser::parse_module_checked(&input).unwrap(); - let llvm = ptx::to_llvm_module( - ast, - ptx::Attributes { - clock_rate: clock_rate as u32, - }, - |_| {}, - ) - .unwrap(); - let elf_binary = comgr::compile_bitcode( - &comgr, - gpu_arch, - &*llvm.llvm_ir.write_bitcode_to_memory(), - &*llvm.linked_bitcode(), - &*llvm.attributes_ir.write_bitcode_to_memory(), - None, - ) - .unwrap(); - std::fs::write(options.output, elf_binary).unwrap(); -} - -fn get_gpu_arch_and_clock_rate<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> (&'a str, i32) { - unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }.unwrap(); - let gcn_arch_name = &dev_props.gcnArchName; - let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) }; - let gcn_arch_name = gcn_arch_name.to_str(); - (gcn_arch_name.unwrap(), dev_props.clockRate) + std::fs::copy(&options.input, &options.output).unwrap(); } diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 6816994..ed8bb8c 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -89,7 +89,15 @@ pub(crate) fn get_attribute( *pi = 32; return Ok(()); } - CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER => { + // TODO: maintain a table, certain RDNAs are 1/16, some are 1/32 + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => { + *pi = 32; + return Ok(()); + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER + | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED + | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES + | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED => { *pi = 0; return Ok(()); } @@ -211,9 +219,6 @@ pub(crate) fn get_attribute( CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE => { return get_device_prop(pi, dev_idx, |props| props.persistingL2CacheMaxSize) } - CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => { - return get_device_prop(pi, dev_idx, |props| props.singleToDoublePrecisionPerfRatio) - } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE => { return get_device_prop(pi, dev_idx, |props| props.accessPolicyMaxWindowSize) } diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index 737f5c3..ad8310e 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -487,9 +487,9 @@ pub(crate) unsafe fn occupancy_max_active_blocks_per_multiprocessor_with_flags( dynamic_smem_size: usize, flags: ::core::ffi::c_uint, ) -> hipError_t { - hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( num_blocks, - func.0.cast(), + func, block_size, dynamic_smem_size, flags, diff --git a/zluda/src/impl/hipfix.rs b/zluda/src/impl/hipfix.rs new file mode 100644 index 0000000..f957849 --- /dev/null +++ b/zluda/src/impl/hipfix.rs @@ -0,0 +1,12 @@ +// There's a bug in hipDrvPointerGetAttributes where it returns +// HIP_ERROR_INVALID_VALUE if the pointer is null. It works correctly for any +// other invalid pointer +pub(crate) fn get_attributes( + ptr: hip_runtime_sys::hipDeviceptr_t, +) -> hip_runtime_sys::hipDeviceptr_t { + if ptr.0.is_null() { + hip_runtime_sys::hipDeviceptr_t(usize::MAX as _) + } else { + ptr + } +} diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index f73a972..60ecb80 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -7,6 +7,7 @@ pub(super) mod driver; pub(super) mod event; pub(super) mod function; pub(super) mod graph; +pub(super) mod hipfix; pub(super) mod kernel; pub(super) mod library; pub(super) mod memory; diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 506f824..da7c145 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -20,6 +20,10 @@ impl ZludaObject for Module { } } +static EMPTY_PTX: &str = ".version 6.5 +.target sm_30 +.address_size 64"; + // get_ptx takes an `image` that can be anything we support and returns a // String containing a ptx extracted from `image`. fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result, CUerror> { @@ -58,11 +62,17 @@ fn cow_bytes_to_str<'a>(data: Cow<'a, [u8]>) -> Option> { pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result { let global_state = driver::global_state()?; - let text = get_ptx(library)?; + let maybe_ptx = get_ptx(library); + let text = if cfg!(debug_assertions) { + maybe_ptx? + } else { + maybe_ptx.unwrap_or_else(|_| Cow::Borrowed(EMPTY_PTX)) + }; let hip_properties = get_hip_properties()?; let gcn_arch = get_gcn_arch(&hip_properties)?; - let attributes = ptx::Attributes { + let attributes = ExtraCacheAttributes { clock_rate: hip_properties.clockRate as u32, + is_debug: cfg!(debug_assertions), }; let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| { let cache = zluda_cache::ModuleCache::open(p)?; @@ -84,6 +94,12 @@ pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result() -> Result { let hip_dev = super::context::get_current_device()?; let mut props = unsafe { mem::zeroed() }; @@ -100,7 +116,7 @@ fn get_cache_key<'a, 'b>( global_state: &'static driver::GlobalState, isa: &'a str, text: &str, - attributes: &ptx::Attributes, + attributes: &impl serde::Serialize, ) -> Option> { // Serialization here is deterministic. When marking a type with // #[derive(serde::Serialize)] the derived implementation will just write @@ -129,7 +145,7 @@ fn load_cached_binary( fn compile_from_ptx_and_cache( comgr: &comgr::Comgr, gcn_arch: &str, - attributes: ptx::Attributes, + attributes: ExtraCacheAttributes, text: &str, cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>, ) -> Result, CUerror> { @@ -138,7 +154,14 @@ fn compile_from_ptx_and_cache( } else { ptx_parser::parse_module_unchecked(text) }; - let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?; + let llvm_module = ptx::to_llvm_module( + ast, + ptx::Attributes { + clock_rate: attributes.clock_rate, + }, + |_| {}, + ) + .map_err(|_| CUerror::UNKNOWN)?; let elf_module = comgr::compile_bitcode( comgr, gcn_arch, diff --git a/zluda/src/impl/pointer.rs b/zluda/src/impl/pointer.rs index 8eda15e..6541fce 100644 --- a/zluda/src/impl/pointer.rs +++ b/zluda/src/impl/pointer.rs @@ -2,7 +2,7 @@ use cuda_types::cuda::*; use hip_runtime_sys::*; use std::{ffi::c_void, ptr}; -use crate::r#impl::driver; +use crate::r#impl::{driver, hipfix}; // TODO: handlehipMemoryTypeUnregistered fn to_cu_memory_type(cu: hipMemoryType) -> Result { @@ -59,7 +59,12 @@ pub(crate) unsafe fn get_attributes( data: &mut *mut ::core::ffi::c_void, ptr: hipDeviceptr_t, ) -> CUresult { - hipDrvPointerGetAttributes(num_attributes, attributes, data, ptr)?; + hipDrvPointerGetAttributes( + num_attributes, + attributes, + data, + hipfix::get_attributes(ptr), + )?; let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize); let data = std::slice::from_raw_parts_mut(data, num_attributes as usize); for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) { @@ -88,7 +93,7 @@ mod tests { use crate::tests::CudaApi; use cuda_macros::test_cuda; use cuda_types::cuda::*; - use std::{ffi::c_void, mem, ptr}; + use std::{ffi::c_void, i32, mem, ptr, usize}; #[test_cuda] pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) { @@ -162,4 +167,47 @@ mod tests { ); assert_eq!(context, CUcontext(ptr::null_mut())); } + + #[test_cuda] + pub unsafe fn null_ptr_attributes_success(api: impl CudaApi) { + api.cuInit(0); + api.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0); + let mut context = CUcontext(1 as _); + let mut mem_type = mem::transmute::<_, CUmemorytype>(u32::MAX); + let mut dev_ptr = mem::transmute::<_, *mut c_void>(usize::MAX); + let mut host_ptr = mem::transmute::<_, *mut c_void>(usize::MAX); + let mut is_managed = true; + let mut ordinal = i32::MAX; + let mut attrs = [ + CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + ]; + let mut values = [ + std::ptr::from_mut(&mut context).cast::(), + std::ptr::from_mut(&mut mem_type).cast::(), + std::ptr::from_mut(&mut dev_ptr).cast::(), + std::ptr::from_mut(&mut host_ptr).cast::(), + std::ptr::from_mut(&mut is_managed).cast::(), + std::ptr::from_mut(&mut ordinal).cast::(), + ]; + assert_eq!( + CUresult::SUCCESS, + api.cuPointerGetAttributes_unchecked( + attrs.len() as u32, + attrs.as_mut_ptr(), + values.as_mut_ptr(), + CUdeviceptr_v2(ptr::null_mut()) + ) + ); + assert_eq!(context, CUcontext(ptr::null_mut())); + assert_eq!(mem_type, CUmemorytype(0)); + assert_eq!(dev_ptr, ptr::null_mut()); + assert_eq!(host_ptr, ptr::null_mut()); + assert_eq!(is_managed, false); + assert_eq!(ordinal, -2); + } } diff --git a/zluda_ml/src/impl_unix.rs b/zluda_ml/src/impl_unix.rs index 93d04e3..55437a6 100644 --- a/zluda_ml/src/impl_unix.rs +++ b/zluda_ml/src/impl_unix.rs @@ -43,6 +43,86 @@ pub(crate) unsafe fn device_get_count_v2(device_count: &mut ::core::ffi::c_uint) rsmi_num_monitor_devices(device_count) } +pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2( + pci_bus_id: &std::ffi::CStr, + device: &mut cuda_types::nvml::nvmlDevice_t, +) -> nvmlReturn_t { + let pci = parse_pci_bus_id(pci_bus_id).ok_or(nvmlError_t::INVALID_ARGUMENT)?; + let bdfid = pci.to_bdfid(); + let mut device_count = 0; + rsmi_num_monitor_devices(&mut device_count)?; + for dv_ind in 0..device_count { + let mut curr_bdfid = 0; + rsmi_dev_pci_id_get(dv_ind, &mut curr_bdfid)?; + if curr_bdfid == bdfid { + *device = Device { _index: dv_ind }.wrap(); + return nvmlReturn_t::SUCCESS; + } + } + nvmlReturn_t::ERROR_NOT_FOUND +} + +#[derive(Clone, Copy)] +struct PciBusId { + domain: u16, + bus: u8, + device: u8, + function: u8, +} +impl PciBusId { + fn to_bdfid(self) -> u64 { + ((self.domain as u64) << 32) + | ((self.bus as u64) << 8) + | ((self.device as u64) << 3) + | (self.function as u64) + } +} + +fn parse_pci_bus_id(id: &std::ffi::CStr) -> Option { + let s = id.to_str().ok()?.trim(); + let mut domain: u16 = 0; + let mut rest = s; + if let Some(colon1) = s.find(':') { + if colon1 == 4 { + domain = hex_u16(&s[..4])?; + rest = &s[5..]; + } + } + let mut parts = rest.split(':'); + let bus_part = parts.next()?; + let tail = parts.next()?; + if parts.next().is_some() { + return None; + } + let mut dev_func = tail.split('.'); + let dev_part = dev_func.next()?; + let func_part = dev_func.next(); + let function = match func_part { + Some(f) => hex_u8(f)?, + None => 0, + }; + Some(PciBusId { + domain, + bus: hex_u8(bus_part)?, + device: hex_u8(dev_part)?, + function, + }) +} + +fn hex_u16(s: &str) -> Option { + if s.len() > 4 { + return None; + } + u16::from_str_radix(s, 16).ok() +} + +fn hex_u8(s: &str) -> Option { + if s.len() > 2 { + return None; + } + u8::from_str_radix(s, 16).ok() +} + pub(crate) unsafe fn device_get_field_values( _device: &Device, values_count: ::core::ffi::c_int, @@ -75,3 +155,36 @@ pub(crate) fn device_get_handle_by_index_v2( *device = Device { _index: index }.wrap(); nvmlReturn_t::SUCCESS } + +#[cfg(test)] +mod tests { + #[test] + fn parse_pci_bus_id_full() { + let id = std::ffi::CString::new("0100:65:a0.f").unwrap(); + let parsed = super::parse_pci_bus_id(&id).unwrap(); + assert_eq!(parsed.domain, 0x0100); + assert_eq!(parsed.bus, 0x65); + assert_eq!(parsed.device, 0xa0); + assert_eq!(parsed.function, 0xf); + } + + #[test] + fn parse_pci_bus_id_no_func() { + let id = std::ffi::CString::new("0100:65:a0").unwrap(); + let parsed = super::parse_pci_bus_id(&id).unwrap(); + assert_eq!(parsed.domain, 0x0100); + assert_eq!(parsed.bus, 0x65); + assert_eq!(parsed.device, 0xa0); + assert_eq!(parsed.function, 0); + } + + #[test] + fn parse_pci_bus_id_no_domain() { + let id = std::ffi::CString::new("65:a0.f").unwrap(); + let parsed = super::parse_pci_bus_id(&id).unwrap(); + assert_eq!(parsed.domain, 0); + assert_eq!(parsed.bus, 0x65); + assert_eq!(parsed.device, 0xa0); + assert_eq!(parsed.function, 0xf); + } +} diff --git a/zluda_ml/src/impl_win.rs b/zluda_ml/src/impl_win.rs index 35f0dfc..205e792 100644 --- a/zluda_ml/src/impl_win.rs +++ b/zluda_ml/src/impl_win.rs @@ -23,6 +23,13 @@ pub(crate) unsafe fn device_get_count_v2(_device_count: &mut ::core::ffi::c_uint crate::impl_common::unimplemented() } +pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2( + pci_bus_id: &std::ffi::CStr, + device: &mut cuda_types::nvml::nvmlDevice_t, +) -> nvmlReturn_t { + crate::impl_common::unimplemented() +} + pub(crate) unsafe fn device_get_field_values( _device: cuda_types::nvml::nvmlDevice_t, _values_count: ::core::ffi::c_int, @@ -31,10 +38,6 @@ pub(crate) unsafe fn device_get_field_values( crate::impl_common::unimplemented() } -unsafe fn get_field_value(_field: &mut nvmlFieldValue_st) -> Result<(), nvmlError_t> { - crate::impl_common::unimplemented() -} - pub(crate) unsafe fn device_get_gpu_fabric_info( _device: cuda_types::nvml::nvmlDevice_t, _gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t, diff --git a/zluda_ml/src/lib.rs b/zluda_ml/src/lib.rs index fe8271c..40a7e30 100644 --- a/zluda_ml/src/lib.rs +++ b/zluda_ml/src/lib.rs @@ -48,6 +48,7 @@ cuda_macros::nvml_function_declarations!( nvmlDeviceGetFieldValues, nvmlDeviceGetGpuFabricInfo, nvmlDeviceGetHandleByIndex_v2, + nvmlDeviceGetHandleByPciBusId_v2, nvmlInit, nvmlInitWithFlags, nvmlInit_v2, diff --git a/zluda_trace/src/log.rs b/zluda_trace/src/log.rs index b3f9716..9cbb9cc 100644 --- a/zluda_trace/src/log.rs +++ b/zluda_trace/src/log.rs @@ -303,6 +303,7 @@ pub(crate) enum ErrorEntry { }, NullPointer(&'static str), UnknownLibrary(CUlibrary), + SavedModule(String), } unsafe impl Send for ErrorEntry {} @@ -344,93 +345,94 @@ impl Display for ErrorEntry { match self { ErrorEntry::IoError(e) => e.fmt(f), ErrorEntry::CreatedDumpDirectory(dir) => { - write!( - f, - "Created trace directory {} ", - dir.as_os_str().to_string_lossy() - ) - } + write!( + f, + "Created trace directory {} ", + dir.as_os_str().to_string_lossy() + ) + } ErrorEntry::ErrorBox(e) => e.fmt(f), ErrorEntry::UnsupportedModule { - module, - raw_image, - kind, - } => { - write!( - f, - "Unsupported {} module {:?} loaded from module image {:?}", - kind, module, raw_image - ) - } + module, + raw_image, + kind, + } => { + write!( + f, + "Unsupported {} module {:?} loaded from module image {:?}", + kind, module, raw_image + ) + } ErrorEntry::MalformedModulePath(e) => e.fmt(f), ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f), ErrorEntry::ModuleParsingError(file_name) => { - write!( - f, - "Error parsing module, log has been written to {}", - file_name - ) - } + write!( + f, + "Error parsing module, log has been written to {}", + file_name + ) + } ErrorEntry::NulInsideModuleText(e) => e.fmt(f), ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"), ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)), ErrorEntry::UnexpectedBinaryField { - field_name, - expected, - observed, - } => write!( - f, - "Unexpected field {}. Expected one of: [{}], observed: {}", - field_name, - expected - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "), - observed - ), + field_name, + expected, + observed, + } => write!( + f, + "Unexpected field {}. Expected one of: [{}], observed: {}", + field_name, + expected + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + observed + ), ErrorEntry::UnexpectedArgument { - arg_name, - expected, - observed, - } => write!( - f, - "Unexpected argument {}. Expected one of: {{{}}}, observed: {}", - arg_name, - expected - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "), - observed - ), + arg_name, + expected, + observed, + } => write!( + f, + "Unexpected argument {}. Expected one of: {{{}}}, observed: {}", + arg_name, + expected + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + observed + ), ErrorEntry::InvalidEnvVar { - var, - pattern, - value, - } => write!( - f, - "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" - ), + var, + pattern, + value, + } => write!( + f, + "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" + ), ErrorEntry::FunctionNotFound(cuda_function_name) => write!( - f, - "No function {cuda_function_name} in the underlying library" - ), + f, + "No function {cuda_function_name} in the underlying library" + ), ErrorEntry::UnexpectedExportTableSize { expected, computed } => { - write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") - } + write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") + } ErrorEntry::IntegrityCheck { original, overriden } => { - write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") - } + write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") + } ErrorEntry::NullPointer(type_) => { - write!(f, "Null pointer of type {type_} encountered") - } + write!(f, "Null pointer of type {type_} encountered") + } ErrorEntry::UnknownLibrary(culibrary) => { - write!(f, "Unknown library: ")?; - let mut temp_buffer = Vec::new(); - CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok(); - f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) }) - } + write!(f, "Unknown library: ")?; + let mut temp_buffer = Vec::new(); + CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok(); + f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) }) + } + ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"), } } } diff --git a/zluda_trace/src/trace.rs b/zluda_trace/src/trace.rs index e71aacd..f397d34 100644 --- a/zluda_trace/src/trace.rs +++ b/zluda_trace/src/trace.rs @@ -128,12 +128,11 @@ impl StateTracker { fn_logger: &mut FnCallLog, type_: &'static str, ) { - fn_logger.log_io_error(self.writer.save_module( - self.library_counter, - index, - submodule, - type_, - )); + fn_logger.try_(|fn_logger| { + self.writer + .save_module(fn_logger, self.library_counter, index, submodule, type_) + .map_err(ErrorEntry::IoError) + }); if type_ == "ptx" { match CString::new(submodule) { Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)), @@ -323,6 +322,7 @@ impl DumpWriter { fn save_module( &self, + fn_logger: &mut FnCallLog, module_index: usize, submodule_index: Option<(usize, Option)>, buffer: &[u8], @@ -332,9 +332,13 @@ impl DumpWriter { None => return Ok(()), Some(d) => d.clone(), }; - dump_file.push(Self::get_file_name(module_index, submodule_index, kind)); - let mut file = File::create_new(dump_file)?; - file.write_all(buffer)?; + let file_name = Self::get_file_name(module_index, submodule_index, kind); + dump_file.push(&file_name); + { + let mut file = File::create_new(dump_file)?; + file.write_all(buffer)?; + } + fn_logger.log(ErrorEntry::SavedModule(file_name)); Ok(()) } @@ -349,7 +353,7 @@ impl DumpWriter { Some(d) => d.clone(), }; log_file.push(Self::get_file_name(module_index, submodule_index, "log")); - let mut file = File::create(log_file)?; + let mut file = File::create_new(log_file)?; for error in errors { writeln!(file, "{}", error)?; }