More runtime fixes, add mma instruction (#509)
Some checks are pending
ZLUDA / Build (Linux) (push) Waiting to run
ZLUDA / Build (Windows) (push) Waiting to run
ZLUDA / Build AMD GPU unit tests (push) Waiting to run
ZLUDA / Run AMD GPU unit tests (push) Blocked by required conditions

This commit is contained in:
Andrzej Janik 2025-09-18 20:15:22 +02:00 committed by GitHub
commit b5f41c7cd0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 639 additions and 154 deletions

View file

@ -45,7 +45,7 @@ body:
./train_gpt2fp32cu ./train_gpt2fp32cu
4. Build and run the tests: 4. Build and run the tests:
make test_gpt2fp32cu make test_gpt2fp32cu
LD_LIBRARY_PATH=<ZLUDA_TRACE_DIR> ./test_gpt2fp32cu LD_LIBRARY_PATH=<ZLUDA_LOG_DIR> ./test_gpt2fp32cu
validations: validations:
required: true required: true
- type: input - type: input

1
Cargo.lock generated
View file

@ -2573,7 +2573,6 @@ dependencies = [
"ptx_parser", "ptx_parser",
"quick-error", "quick-error",
"rustc-hash 2.0.0", "rustc-hash 2.0.0",
"serde",
"smallvec", "smallvec",
"strum 0.26.3", "strum 0.26.3",
"strum_macros 0.26.4", "strum_macros 0.26.4",

View file

@ -219,6 +219,10 @@ pub fn compile_bitcode(
compile_to_exec.set_isa_name(gcn_arch)?; compile_to_exec.set_isa_name(gcn_arch)?;
compile_to_exec.set_language(Language::LlvmIr)?; compile_to_exec.set_language(Language::LlvmIr)?;
let common_options = [ let common_options = [
c"-mllvm",
c"-ignore-tti-inline-compatible",
// c"-mllvm",
// c"-amdgpu-early-inline-all=true",
// This makes no sense, but it makes ockl linking work // This makes no sense, but it makes ockl linking work
c"-Xclang", c"-Xclang",
c"-mno-link-builtin-bitcode-postopt", c"-mno-link-builtin-bitcode-postopt",
@ -237,8 +241,7 @@ pub fn compile_bitcode(
] ]
.into_iter(); .into_iter();
let opt_options = if cfg!(debug_assertions) { let opt_options = if cfg!(debug_assertions) {
//[c"-g", c"-mllvm", c"-print-before-all", c"", c""] [c"-g", c"-mamdgpu-precise-memory-op", c"", c"", c""]
[c"-g", c"", c"", c"", c""]
} else { } else {
[ [
c"-g0", c"-g0",

View file

@ -21,9 +21,14 @@ pub struct Options {
output_dir: Option<PathBuf>, output_dir: Option<PathBuf>,
#[bpaf(long("arch"))] #[bpaf(long("arch"))]
/// Target architecture /// Target GPU architecture
arch: Option<String>, arch: Option<String>,
#[bpaf(long("ignore-errors"))]
/// Try to ignore errors. This will try and produce output even if there are
/// parsing errors (e.g. an unimplemented instruction)
ignore_errors: bool,
#[bpaf(positional("filename"))] #[bpaf(positional("filename"))]
/// PTX file /// PTX file
ptx_path: String, ptx_path: String,
@ -48,7 +53,10 @@ fn main_core() -> Result<(), CompilerError> {
.unwrap_or("output"); .unwrap_or("output");
let mut output_path = match opts.output_dir { let mut output_path = match opts.output_dir {
Some(value) => value, Some(value) => {
std::fs::create_dir_all(&value)?;
value
}
None => match ptx_path.parent() { None => match ptx_path.parent() {
Some(dir) => dir.to_path_buf(), Some(dir) => dir.to_path_buf(),
None => env::current_dir()?, None => env::current_dir()?,
@ -68,7 +76,7 @@ fn main_core() -> Result<(), CompilerError> {
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?; let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?;
let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?; let llvm = ptx_to_llvm(opts.ignore_errors, ptx).map_err(CompilerError::from)?;
write_to_file(&llvm.llvm_ir, output_path.with_extension("ll").as_path())?; write_to_file(&llvm.llvm_ir, output_path.with_extension("ll").as_path())?;
@ -92,8 +100,12 @@ fn main_core() -> Result<(), CompilerError> {
Ok(()) Ok(())
} }
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> { fn ptx_to_llvm(ignore_errors: bool, ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; let ast = if ignore_errors {
ptx_parser::parse_module_unchecked(ptx)
} else {
ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?
};
let mut start = Instant::now(); let mut start = Instant::now();
let module = ptx::to_llvm_module( let module = ptx::to_llvm_module(
ast, ast,

View file

@ -116,7 +116,7 @@ in order to demonstrate all of zluda_trace's features.
```bash ```bash
nvcc add.cu -o add -arch sm_80 nvcc add.cu -o add -arch sm_80
LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_TRACE_DIR=/tmp/zluda ./add LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_LOG_DIR=/tmp/zluda ./add
``` ```
The last few lines should look something like: The last few lines should look something like:

View file

@ -22,7 +22,6 @@ microlp = "0.2.11"
int-enum = "1.1" int-enum = "1.1"
unwrap_or = "1.0.1" unwrap_or = "1.0.1"
smallvec = "1.15.1" smallvec = "1.15.1"
serde = { version = "1.0.219", features = ["derive"] }
[dev-dependencies] [dev-dependencies]
hip_runtime-sys = { path = "../ext/hip_runtime-sys" } hip_runtime-sys = { path = "../ext/hip_runtime-sys" }

Binary file not shown.

View file

@ -1,17 +1,21 @@
// Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17` // Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17`
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining // `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc // /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <bit> #include <bit>
#include <cmath> #include <cmath>
#include <hip/hip_runtime.h>
#include <hip/amd_detail/amd_device_functions.h> #include <hip/amd_detail/amd_device_functions.h>
#include <hip/hip_fp8.h> #include <hip/hip_fp8.h>
#define SHARED_SPACE __attribute__((address_space(3))) #define SHARED_SPACE __attribute__((address_space(3)))
#define CONSTANT_SPACE __attribute__((address_space(4))) #define CONSTANT_SPACE __attribute__((address_space(4)))
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
typedef float float8 __attribute__((ext_vector_type(8)));
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME #define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
#define DECLARE_ATTR(TYPE, NAME) \ #define DECLARE_ATTR(TYPE, NAME) \
@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
uint32_t x3 = load_single_matrix_trans(address, 24); uint32_t x3 = load_single_matrix_trans(address, 24);
return uint4::Native_vec_{x0, x1, x2, x3}; return uint4::Native_vec_{x0, x1, x2, x3};
} }
static inline __device__ _Float16 top16_as_fp16(uint32_t value) {
uint16_t half_bits = static_cast<uint16_t>((value >> 16) & 0xFFFF);
return *reinterpret_cast<_Float16*>(&half_bits);
}
static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) {
uint16_t half_bits = static_cast<uint16_t>(value & 0xFFFF);
return *reinterpret_cast<_Float16*>(&half_bits);
}
static inline __device__ float bpermute_lane(int lane, float x) {
return __hip_ds_bpermutef(4 * lane, x);
}
static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) {
return __hip_ds_bpermute(4 * lane, x);
}
static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) {
const unsigned lIdx = threadIdx.x;
const int lane = lIdx % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode)
half16 aFrag;
for (int vGPR = 0; vGPR < 8; ++vGPR) {
int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2
int cudaTID = (vGPR % 4 + lane * 4) % 32;
uint32_t reg0, reg1;
// Select the two consecutive elements from a_reg:
if (cudaChunk == 0) {
reg0 = a_reg.x;
reg1 = a_reg.y;
} else { // cudaChunk==2
reg0 = a_reg.z;
reg1 = a_reg.w;
}
uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0);
uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1);
uint32_t a_Frag_reg = (lane < 8) ? a_tmp0 : a_tmp1;
aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg);
aFrag[2 * vGPR + 1] = top16_as_fp16(a_Frag_reg);
}
return aFrag;
}
static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) {
const unsigned lIdx = threadIdx.x;
const int lane = lIdx % 16;
half16 bFrag;
for (int vGPR = 0; vGPR < 8; ++vGPR) {
int cudaChunk = vGPR / 4; // will be 0 or 1
int cudaTID = vGPR % 4 + (lane * 4) % 64;
uint32_t reg = (cudaChunk == 0) ? b_reg.x : b_reg.y;
uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg);
if (lane < 8) {
bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg);
bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg);
} else {
bFrag[2 * vGPR] = 0.0f;
bFrag[2 * vGPR + 1] = 0.0f;
}
}
return bFrag;
}
static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) {
const int lIdx = (int)threadIdx.x;
float8 cFrag;
// Loop over the eight vector GPRs.
for (int vGPR = 0; vGPR < 8; ++vGPR) {
int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use.
int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8;
int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2;
float ctmp0, ctmp1;
if (cudaChunk == 0) {
ctmp0 = bpermute_lane(cudaTID, c_reg.x);
ctmp1 = bpermute_lane(cudaTID, c_reg.y);
} else { // cudaChunk == 2
ctmp0 = bpermute_lane(cudaTID, c_reg.z);
ctmp1 = bpermute_lane(cudaTID, c_reg.w);
}
// Select one of the two values based on the thread index's LSB.
cFrag[vGPR] = (lIdx & 1) ? ctmp1 : ctmp0;
// Zero out for specific thread indices.
if ((lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32))
cFrag[vGPR] = 0.0f;
}
return cFrag;
}
static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) {
const int lIdx = (int)threadIdx.x;
float4::Native_vec_ d_out;
for (int cChunk = 0; cChunk < 4; ++cChunk) {
int r_vGPR = (cChunk / 2) * 4;
int add8 = (lIdx & 0x4) ? 8 : 0;
int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8;
float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]);
float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]);
float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]);
float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]);
float val;
if (lIdx < 8) {
val = d_tmp0;
} else if (lIdx < 16) {
val = d_tmp1;
} else if (lIdx < 24) {
val = d_tmp2;
} else {
val = d_tmp3;
}
if (cChunk == 0) d_out.x = val;
else if (cChunk == 1) d_out.y = val;
else if (cChunk == 2) d_out.z = val;
else d_out.w = val;
}
return d_out;
}
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
// Reshuffle from Nvidia-like register layout to AMD layout:
half16 aFrag = shuffle_a(a_reg);
half16 bFrag = shuffle_b(b_reg);
float8 cFrag = shuffle_c(c_reg);
// Call the (builtin) 16x16 MMA instruction. It returns a float8.
float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aFrag, bFrag, cFrag);
// Unshuffle back into Nvidia expected float4 result
float4::Native_vec_ d_out = shuffle_d(dFrag);
return d_out;
}
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
// Reshuffle from Nvidia-like register layout to AMD layout:
half16 aFrag = shuffle_a(a_reg);
half16 bFrag = shuffle_b(b_reg);
float8 cFrag = shuffle_c(c_reg);
// Call the (builtin) 16x16 MMA instruction. It returns a float8.
float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aFrag, bFrag, cFrag);
// Unshuffle back into Nvidia expected float4 result
float4::Native_vec_ d_out = shuffle_d(dFrag);
return d_out;
}
} }

View file

@ -197,7 +197,9 @@ fn run_instruction<'input>(
| ast::Instruction::Xor { .. } | ast::Instruction::Xor { .. }
| ast::Instruction::Vote { .. } | ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. } | ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)), | ast::Instruction::GridDepControl { .. }
| ast::Instruction::LdMatrix { .. }
| ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)),
ast::Instruction::Add { ast::Instruction::Add {
data: data:
ast::ArithDetails::Float(ast::ArithFloat { ast::ArithDetails::Float(ast::ArithFloat {

View file

@ -1855,7 +1855,9 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
| ast::Instruction::AtomCas { .. } | ast::Instruction::AtomCas { .. }
| ast::Instruction::Vote { .. } | ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. } | ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => InstructionModes::none(), | ast::Instruction::GridDepControl { .. }
| ast::Instruction::LdMatrix { .. }
| ast::Instruction::Mma { .. } => InstructionModes::none(),
ast::Instruction::Add { ast::Instruction::Add {
data: ast::ArithDetails::Integer(_), data: ast::ArithDetails::Integer(_),
.. ..

View file

@ -153,6 +153,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
for (i, param) in method.input_arguments.iter().enumerate() { for (i, param) in method.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) }; let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name); let name = self.resolver.get_or_add(param.name);
if let Some(align) = param.align {
unsafe { LLVMSetParamAlignment(value, align) };
}
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value); self.resolver.register(param.name, value);
if method.is_kernel { if method.is_kernel {
@ -519,6 +522,7 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop
ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop
ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop
ast::Instruction::GridDepControl { .. } => Ok(()), // nop
// replaced by a function call // replaced by a function call
ast::Instruction::Bfe { .. } ast::Instruction::Bfe { .. }
| ast::Instruction::Bar { .. } | ast::Instruction::Bar { .. }
@ -529,7 +533,8 @@ impl<'a> MethodEmitContext<'a> {
| ast::Instruction::Vote { .. } | ast::Instruction::Vote { .. }
| ast::Instruction::Nanosleep { .. } | ast::Instruction::Nanosleep { .. }
| ast::Instruction::ReduxSync { .. } | ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()), | ast::Instruction::LdMatrix { .. }
| ast::Instruction::Mma { .. } => return Err(error_unreachable()),
} }
} }

View file

@ -51,7 +51,7 @@ quick_error! {
} }
/// GPU attributes needed at compile time. /// GPU attributes needed at compile time.
#[derive(serde::Serialize)] #[derive(Copy, Clone)]
pub struct Attributes { pub struct Attributes {
/// Clock frequency in kHz. /// Clock frequency in kHz.
pub clock_rate: u32, pub clock_rate: u32,

View file

@ -351,6 +351,35 @@ fn run_instruction<'input>(
let name = "sqrt_rn_ftz_f32"; let name = "sqrt_rn_ftz_f32";
to_call(resolver, fn_declarations, name.into(), i)? to_call(resolver, fn_declarations, name.into(), i)?
} }
i @ ptx_parser::Instruction::Mma {
data:
ast::MmaDetails {
alayout,
blayout,
dtype_scalar,
atype_scalar,
btype_scalar,
ctype_scalar,
},
..
} => {
let name = format!(
"mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}",
match alayout {
ast::MatrixLayout::Row => "row",
ast::MatrixLayout::Col => "col",
},
match blayout {
ast::MatrixLayout::Row => "row",
ast::MatrixLayout::Col => "col",
},
scalar_to_ptx_name(dtype_scalar),
scalar_to_ptx_name(atype_scalar),
scalar_to_ptx_name(btype_scalar),
scalar_to_ptx_name(ctype_scalar),
);
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Sqrt { i @ ptx_parser::Instruction::Sqrt {
data: data:
ast::RcpData { ast::RcpData {

View file

@ -3,8 +3,8 @@ use super::{
StateSpace, VectorPrefix, StateSpace, VectorPrefix,
}; };
use crate::{ use crate::{
FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction, FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError,
ShiftDirection, ShuffleMode, VoteMode, PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode,
}; };
use bitflags::bitflags; use bitflags::bitflags;
use derive_more::Display; use derive_more::Display;
@ -721,6 +721,30 @@ ptx_parser_macros::generate_instruction_type!(
space: { data.state_space }, space: { data.state_space },
} }
} }
},
GridDepControl {
data: crate::GridDepControlAction,
},
Mma {
data: MmaDetails,
arguments<T>: {
dst: {
repr: T,
type: { data.dtype() },
},
src1: {
repr: T,
type: { data.atype() },
},
src2: {
repr: T,
type: { data.btype() },
},
src3: {
repr: T,
type: { data.ctype() },
}
}
} }
} }
); );
@ -2378,3 +2402,27 @@ pub struct ReduxSyncData {
pub type_: ScalarType, pub type_: ScalarType,
pub reduction: Reduction, pub reduction: Reduction,
} }
pub struct MmaDetails {
pub alayout: MatrixLayout,
pub blayout: MatrixLayout,
pub dtype_scalar: ScalarType,
pub atype_scalar: ScalarType,
pub btype_scalar: ScalarType,
pub ctype_scalar: ScalarType,
}
impl MmaDetails {
pub fn dtype(&self) -> Type {
Type::Vector(4, ScalarType::F32)
}
pub fn atype(&self) -> Type {
Type::Vector(4, ScalarType::U32)
}
pub fn btype(&self) -> Type {
Type::Vector(2, ScalarType::U32)
}
pub fn ctype(&self) -> Type {
Type::Vector(4, ScalarType::F32)
}
}

View file

@ -1862,6 +1862,9 @@ derive_parser!(
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum MatrixNumber { } pub enum MatrixNumber { }
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum MatrixLayout { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => { mov{.vec}.type d, a => {
Instruction::Mov { Instruction::Mov {
@ -3897,6 +3900,37 @@ derive_parser!(
.type: ScalarType = {.b16, .b8}; .type: ScalarType = {.b16, .b8};
// .dst_fmt = { .b8x16 }; // .dst_fmt = { .b8x16 };
// .src_fmt = { .b6x16_p32, .b4x16_p64 }; // .src_fmt = { .b6x16_p32, .b4x16_p64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol
griddepcontrol.action => {
Instruction::GridDepControl {
data: action
}
}
.action: GridDepControlAction = { .launch_dependents, .wait };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => {
if dtype != ScalarType::F32 || ctype != ScalarType::F32 {
state.errors.push(PtxError::Todo);
}
Instruction::Mma {
data: MmaDetails {
alayout,
blayout,
dtype_scalar: dtype,
atype_scalar: ScalarType::BF16,
btype_scalar: ScalarType::BF16,
ctype_scalar: ctype,
},
arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c }
}
}
.alayout: MatrixLayout = {.row};
.blayout: MatrixLayout = {.col};
.ctype: ScalarType = {.f16, .f32};
.dtype: ScalarType = {.f16, .f32};
); );
#[cfg(test)] #[cfg(test)]

View file

@ -1,6 +1,4 @@
use bpaf::{any, doc::Style, Bpaf, Parser}; use bpaf::{any, choice, doc::Style, literal, Bpaf, Parser};
use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600};
use std::{ffi::CStr, mem};
#[derive(Debug, Clone, Bpaf)] #[derive(Debug, Clone, Bpaf)]
#[allow(dead_code)] #[allow(dead_code)]
@ -12,6 +10,8 @@ pub struct Options {
#[bpaf(short, long)] #[bpaf(short, long)]
verbose: bool, verbose: bool,
#[bpaf(external)] #[bpaf(external)]
lineinfo: bool,
#[bpaf(external)]
gpu_name: String, gpu_name: String,
#[bpaf(long, short('O'), fallback(3))] #[bpaf(long, short('O'), fallback(3))]
opt_level: usize, opt_level: usize,
@ -19,48 +19,32 @@ pub struct Options {
input: String, input: String,
} }
fn lineinfo() -> impl Parser<bool> {
choice(["-lineinfo", "--lineinfo"].into_iter().map(|s| {
literal(s)
.anywhere()
.optional()
.map(|_| true)
.fallback(false)
.boxed()
}))
}
// #[bpaf(long, long("gpu_name"), fallback_with(default_arch))] // #[bpaf(long, long("gpu_name"), fallback_with(default_arch))]
fn gpu_name() -> impl Parser<String> { fn gpu_name() -> impl Parser<String> {
any("", move |s: String| { any("", move |s: String| {
Some(s.strip_prefix("-arch=")?.to_owned()) Some(
s.strip_prefix("-arch=")
.or_else(|| s.strip_prefix("--gpu-name="))?
.to_owned(),
)
}) })
.metavar(&[("-arch=", Style::Literal), ("ARG", Style::Metavar)]) .metavar(&[("--gpu-name=", Style::Literal), ("SM", Style::Metavar)])
.anywhere() .anywhere()
.fallback_with(|| Ok::<String, &'static str>("sm_52".to_string())) .fallback_with(|| Ok::<String, &'static str>("sm_52".to_string()))
} }
fn main() { fn main() {
let options = options().run(); let options = options().run();
let comgr = comgr::Comgr::new().unwrap(); std::fs::copy(&options.input, &options.output).unwrap();
unsafe { hip_runtime_sys::hipInit(0) }.unwrap();
let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() };
let (gpu_arch, clock_rate) = get_gpu_arch_and_clock_rate(&mut dev_props);
let input = std::fs::read_to_string(options.input).unwrap();
let ast = ptx_parser::parse_module_checked(&input).unwrap();
let llvm = ptx::to_llvm_module(
ast,
ptx::Attributes {
clock_rate: clock_rate as u32,
},
|_| {},
)
.unwrap();
let elf_binary = comgr::compile_bitcode(
&comgr,
gpu_arch,
&*llvm.llvm_ir.write_bitcode_to_memory(),
&*llvm.linked_bitcode(),
&*llvm.attributes_ir.write_bitcode_to_memory(),
None,
)
.unwrap();
std::fs::write(options.output, elf_binary).unwrap();
}
fn get_gpu_arch_and_clock_rate<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> (&'a str, i32) {
unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }.unwrap();
let gcn_arch_name = &dev_props.gcnArchName;
let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) };
let gcn_arch_name = gcn_arch_name.to_str();
(gcn_arch_name.unwrap(), dev_props.clockRate)
} }

View file

@ -89,7 +89,15 @@ pub(crate) fn get_attribute(
*pi = 32; *pi = 32;
return Ok(()); return Ok(());
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER => { // TODO: maintain a table, certain RDNAs are 1/16, some are 1/32
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => {
*pi = 32;
return Ok(());
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED => {
*pi = 0; *pi = 0;
return Ok(()); return Ok(());
} }
@ -211,9 +219,6 @@ pub(crate) fn get_attribute(
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE => {
return get_device_prop(pi, dev_idx, |props| props.persistingL2CacheMaxSize) return get_device_prop(pi, dev_idx, |props| props.persistingL2CacheMaxSize)
} }
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => {
return get_device_prop(pi, dev_idx, |props| props.singleToDoublePrecisionPerfRatio)
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE => {
return get_device_prop(pi, dev_idx, |props| props.accessPolicyMaxWindowSize) return get_device_prop(pi, dev_idx, |props| props.accessPolicyMaxWindowSize)
} }

View file

@ -487,9 +487,9 @@ pub(crate) unsafe fn occupancy_max_active_blocks_per_multiprocessor_with_flags(
dynamic_smem_size: usize, dynamic_smem_size: usize,
flags: ::core::ffi::c_uint, flags: ::core::ffi::c_uint,
) -> hipError_t { ) -> hipError_t {
hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
num_blocks, num_blocks,
func.0.cast(), func,
block_size, block_size,
dynamic_smem_size, dynamic_smem_size,
flags, flags,

12
zluda/src/impl/hipfix.rs Normal file
View file

@ -0,0 +1,12 @@
// There's a bug in hipDrvPointerGetAttributes where it returns
// HIP_ERROR_INVALID_VALUE if the pointer is null. It works correctly for any
// other invalid pointer
pub(crate) fn get_attributes(
ptr: hip_runtime_sys::hipDeviceptr_t,
) -> hip_runtime_sys::hipDeviceptr_t {
if ptr.0.is_null() {
hip_runtime_sys::hipDeviceptr_t(usize::MAX as _)
} else {
ptr
}
}

View file

@ -7,6 +7,7 @@ pub(super) mod driver;
pub(super) mod event; pub(super) mod event;
pub(super) mod function; pub(super) mod function;
pub(super) mod graph; pub(super) mod graph;
pub(super) mod hipfix;
pub(super) mod kernel; pub(super) mod kernel;
pub(super) mod library; pub(super) mod library;
pub(super) mod memory; pub(super) mod memory;

View file

@ -20,6 +20,10 @@ impl ZludaObject for Module {
} }
} }
static EMPTY_PTX: &str = ".version 6.5
.target sm_30
.address_size 64";
// get_ptx takes an `image` that can be anything we support and returns a // get_ptx takes an `image` that can be anything we support and returns a
// String containing a ptx extracted from `image`. // String containing a ptx extracted from `image`.
fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result<Cow<'a, str>, CUerror> { fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result<Cow<'a, str>, CUerror> {
@ -58,11 +62,17 @@ fn cow_bytes_to_str<'a>(data: Cow<'a, [u8]>) -> Option<Cow<'a, str>> {
pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CUerror> { pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CUerror> {
let global_state = driver::global_state()?; let global_state = driver::global_state()?;
let text = get_ptx(library)?; let maybe_ptx = get_ptx(library);
let text = if cfg!(debug_assertions) {
maybe_ptx?
} else {
maybe_ptx.unwrap_or_else(|_| Cow::Borrowed(EMPTY_PTX))
};
let hip_properties = get_hip_properties()?; let hip_properties = get_hip_properties()?;
let gcn_arch = get_gcn_arch(&hip_properties)?; let gcn_arch = get_gcn_arch(&hip_properties)?;
let attributes = ptx::Attributes { let attributes = ExtraCacheAttributes {
clock_rate: hip_properties.clockRate as u32, clock_rate: hip_properties.clockRate as u32,
is_debug: cfg!(debug_assertions),
}; };
let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| { let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| {
let cache = zluda_cache::ModuleCache::open(p)?; let cache = zluda_cache::ModuleCache::open(p)?;
@ -84,6 +94,12 @@ pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CU
Ok(hip_module) Ok(hip_module)
} }
#[derive(serde::Serialize)]
struct ExtraCacheAttributes {
is_debug: bool,
clock_rate: u32,
}
fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> { fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> {
let hip_dev = super::context::get_current_device()?; let hip_dev = super::context::get_current_device()?;
let mut props = unsafe { mem::zeroed() }; let mut props = unsafe { mem::zeroed() };
@ -100,7 +116,7 @@ fn get_cache_key<'a, 'b>(
global_state: &'static driver::GlobalState, global_state: &'static driver::GlobalState,
isa: &'a str, isa: &'a str,
text: &str, text: &str,
attributes: &ptx::Attributes, attributes: &impl serde::Serialize,
) -> Option<zluda_cache::ModuleKey<'a>> { ) -> Option<zluda_cache::ModuleKey<'a>> {
// Serialization here is deterministic. When marking a type with // Serialization here is deterministic. When marking a type with
// #[derive(serde::Serialize)] the derived implementation will just write // #[derive(serde::Serialize)] the derived implementation will just write
@ -129,7 +145,7 @@ fn load_cached_binary(
fn compile_from_ptx_and_cache( fn compile_from_ptx_and_cache(
comgr: &comgr::Comgr, comgr: &comgr::Comgr,
gcn_arch: &str, gcn_arch: &str,
attributes: ptx::Attributes, attributes: ExtraCacheAttributes,
text: &str, text: &str,
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>, cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
) -> Result<Vec<u8>, CUerror> { ) -> Result<Vec<u8>, CUerror> {
@ -138,7 +154,14 @@ fn compile_from_ptx_and_cache(
} else { } else {
ptx_parser::parse_module_unchecked(text) ptx_parser::parse_module_unchecked(text)
}; };
let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?; let llvm_module = ptx::to_llvm_module(
ast,
ptx::Attributes {
clock_rate: attributes.clock_rate,
},
|_| {},
)
.map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode( let elf_module = comgr::compile_bitcode(
comgr, comgr,
gcn_arch, gcn_arch,

View file

@ -2,7 +2,7 @@ use cuda_types::cuda::*;
use hip_runtime_sys::*; use hip_runtime_sys::*;
use std::{ffi::c_void, ptr}; use std::{ffi::c_void, ptr};
use crate::r#impl::driver; use crate::r#impl::{driver, hipfix};
// TODO: handlehipMemoryTypeUnregistered // TODO: handlehipMemoryTypeUnregistered
fn to_cu_memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> { fn to_cu_memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
@ -59,7 +59,12 @@ pub(crate) unsafe fn get_attributes(
data: &mut *mut ::core::ffi::c_void, data: &mut *mut ::core::ffi::c_void,
ptr: hipDeviceptr_t, ptr: hipDeviceptr_t,
) -> CUresult { ) -> CUresult {
hipDrvPointerGetAttributes(num_attributes, attributes, data, ptr)?; hipDrvPointerGetAttributes(
num_attributes,
attributes,
data,
hipfix::get_attributes(ptr),
)?;
let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize); let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize);
let data = std::slice::from_raw_parts_mut(data, num_attributes as usize); let data = std::slice::from_raw_parts_mut(data, num_attributes as usize);
for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) { for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) {
@ -88,7 +93,7 @@ mod tests {
use crate::tests::CudaApi; use crate::tests::CudaApi;
use cuda_macros::test_cuda; use cuda_macros::test_cuda;
use cuda_types::cuda::*; use cuda_types::cuda::*;
use std::{ffi::c_void, mem, ptr}; use std::{ffi::c_void, i32, mem, ptr, usize};
#[test_cuda] #[test_cuda]
pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) { pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) {
@ -162,4 +167,47 @@ mod tests {
); );
assert_eq!(context, CUcontext(ptr::null_mut())); assert_eq!(context, CUcontext(ptr::null_mut()));
} }
#[test_cuda]
pub unsafe fn null_ptr_attributes_success(api: impl CudaApi) {
api.cuInit(0);
api.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0);
let mut context = CUcontext(1 as _);
let mut mem_type = mem::transmute::<_, CUmemorytype>(u32::MAX);
let mut dev_ptr = mem::transmute::<_, *mut c_void>(usize::MAX);
let mut host_ptr = mem::transmute::<_, *mut c_void>(usize::MAX);
let mut is_managed = true;
let mut ordinal = i32::MAX;
let mut attrs = [
CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
];
let mut values = [
std::ptr::from_mut(&mut context).cast::<c_void>(),
std::ptr::from_mut(&mut mem_type).cast::<c_void>(),
std::ptr::from_mut(&mut dev_ptr).cast::<c_void>(),
std::ptr::from_mut(&mut host_ptr).cast::<c_void>(),
std::ptr::from_mut(&mut is_managed).cast::<c_void>(),
std::ptr::from_mut(&mut ordinal).cast::<c_void>(),
];
assert_eq!(
CUresult::SUCCESS,
api.cuPointerGetAttributes_unchecked(
attrs.len() as u32,
attrs.as_mut_ptr(),
values.as_mut_ptr(),
CUdeviceptr_v2(ptr::null_mut())
)
);
assert_eq!(context, CUcontext(ptr::null_mut()));
assert_eq!(mem_type, CUmemorytype(0));
assert_eq!(dev_ptr, ptr::null_mut());
assert_eq!(host_ptr, ptr::null_mut());
assert_eq!(is_managed, false);
assert_eq!(ordinal, -2);
}
} }

View file

@ -43,6 +43,86 @@ pub(crate) unsafe fn device_get_count_v2(device_count: &mut ::core::ffi::c_uint)
rsmi_num_monitor_devices(device_count) rsmi_num_monitor_devices(device_count)
} }
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
pci_bus_id: &std::ffi::CStr,
device: &mut cuda_types::nvml::nvmlDevice_t,
) -> nvmlReturn_t {
let pci = parse_pci_bus_id(pci_bus_id).ok_or(nvmlError_t::INVALID_ARGUMENT)?;
let bdfid = pci.to_bdfid();
let mut device_count = 0;
rsmi_num_monitor_devices(&mut device_count)?;
for dv_ind in 0..device_count {
let mut curr_bdfid = 0;
rsmi_dev_pci_id_get(dv_ind, &mut curr_bdfid)?;
if curr_bdfid == bdfid {
*device = Device { _index: dv_ind }.wrap();
return nvmlReturn_t::SUCCESS;
}
}
nvmlReturn_t::ERROR_NOT_FOUND
}
#[derive(Clone, Copy)]
struct PciBusId {
domain: u16,
bus: u8,
device: u8,
function: u8,
}
impl PciBusId {
fn to_bdfid(self) -> u64 {
((self.domain as u64) << 32)
| ((self.bus as u64) << 8)
| ((self.device as u64) << 3)
| (self.function as u64)
}
}
fn parse_pci_bus_id(id: &std::ffi::CStr) -> Option<PciBusId> {
let s = id.to_str().ok()?.trim();
let mut domain: u16 = 0;
let mut rest = s;
if let Some(colon1) = s.find(':') {
if colon1 == 4 {
domain = hex_u16(&s[..4])?;
rest = &s[5..];
}
}
let mut parts = rest.split(':');
let bus_part = parts.next()?;
let tail = parts.next()?;
if parts.next().is_some() {
return None;
}
let mut dev_func = tail.split('.');
let dev_part = dev_func.next()?;
let func_part = dev_func.next();
let function = match func_part {
Some(f) => hex_u8(f)?,
None => 0,
};
Some(PciBusId {
domain,
bus: hex_u8(bus_part)?,
device: hex_u8(dev_part)?,
function,
})
}
fn hex_u16(s: &str) -> Option<u16> {
if s.len() > 4 {
return None;
}
u16::from_str_radix(s, 16).ok()
}
fn hex_u8(s: &str) -> Option<u8> {
if s.len() > 2 {
return None;
}
u8::from_str_radix(s, 16).ok()
}
pub(crate) unsafe fn device_get_field_values( pub(crate) unsafe fn device_get_field_values(
_device: &Device, _device: &Device,
values_count: ::core::ffi::c_int, values_count: ::core::ffi::c_int,
@ -75,3 +155,36 @@ pub(crate) fn device_get_handle_by_index_v2(
*device = Device { _index: index }.wrap(); *device = Device { _index: index }.wrap();
nvmlReturn_t::SUCCESS nvmlReturn_t::SUCCESS
} }
#[cfg(test)]
mod tests {
#[test]
fn parse_pci_bus_id_full() {
let id = std::ffi::CString::new("0100:65:a0.f").unwrap();
let parsed = super::parse_pci_bus_id(&id).unwrap();
assert_eq!(parsed.domain, 0x0100);
assert_eq!(parsed.bus, 0x65);
assert_eq!(parsed.device, 0xa0);
assert_eq!(parsed.function, 0xf);
}
#[test]
fn parse_pci_bus_id_no_func() {
let id = std::ffi::CString::new("0100:65:a0").unwrap();
let parsed = super::parse_pci_bus_id(&id).unwrap();
assert_eq!(parsed.domain, 0x0100);
assert_eq!(parsed.bus, 0x65);
assert_eq!(parsed.device, 0xa0);
assert_eq!(parsed.function, 0);
}
#[test]
fn parse_pci_bus_id_no_domain() {
let id = std::ffi::CString::new("65:a0.f").unwrap();
let parsed = super::parse_pci_bus_id(&id).unwrap();
assert_eq!(parsed.domain, 0);
assert_eq!(parsed.bus, 0x65);
assert_eq!(parsed.device, 0xa0);
assert_eq!(parsed.function, 0xf);
}
}

View file

@ -23,6 +23,13 @@ pub(crate) unsafe fn device_get_count_v2(_device_count: &mut ::core::ffi::c_uint
crate::impl_common::unimplemented() crate::impl_common::unimplemented()
} }
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
pci_bus_id: &std::ffi::CStr,
device: &mut cuda_types::nvml::nvmlDevice_t,
) -> nvmlReturn_t {
crate::impl_common::unimplemented()
}
pub(crate) unsafe fn device_get_field_values( pub(crate) unsafe fn device_get_field_values(
_device: cuda_types::nvml::nvmlDevice_t, _device: cuda_types::nvml::nvmlDevice_t,
_values_count: ::core::ffi::c_int, _values_count: ::core::ffi::c_int,
@ -31,10 +38,6 @@ pub(crate) unsafe fn device_get_field_values(
crate::impl_common::unimplemented() crate::impl_common::unimplemented()
} }
unsafe fn get_field_value(_field: &mut nvmlFieldValue_st) -> Result<(), nvmlError_t> {
crate::impl_common::unimplemented()
}
pub(crate) unsafe fn device_get_gpu_fabric_info( pub(crate) unsafe fn device_get_gpu_fabric_info(
_device: cuda_types::nvml::nvmlDevice_t, _device: cuda_types::nvml::nvmlDevice_t,
_gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t, _gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t,

View file

@ -48,6 +48,7 @@ cuda_macros::nvml_function_declarations!(
nvmlDeviceGetFieldValues, nvmlDeviceGetFieldValues,
nvmlDeviceGetGpuFabricInfo, nvmlDeviceGetGpuFabricInfo,
nvmlDeviceGetHandleByIndex_v2, nvmlDeviceGetHandleByIndex_v2,
nvmlDeviceGetHandleByPciBusId_v2,
nvmlInit, nvmlInit,
nvmlInitWithFlags, nvmlInitWithFlags,
nvmlInit_v2, nvmlInit_v2,

View file

@ -303,6 +303,7 @@ pub(crate) enum ErrorEntry {
}, },
NullPointer(&'static str), NullPointer(&'static str),
UnknownLibrary(CUlibrary), UnknownLibrary(CUlibrary),
SavedModule(String),
} }
unsafe impl Send for ErrorEntry {} unsafe impl Send for ErrorEntry {}
@ -344,93 +345,94 @@ impl Display for ErrorEntry {
match self { match self {
ErrorEntry::IoError(e) => e.fmt(f), ErrorEntry::IoError(e) => e.fmt(f),
ErrorEntry::CreatedDumpDirectory(dir) => { ErrorEntry::CreatedDumpDirectory(dir) => {
write!( write!(
f, f,
"Created trace directory {} ", "Created trace directory {} ",
dir.as_os_str().to_string_lossy() dir.as_os_str().to_string_lossy()
) )
} }
ErrorEntry::ErrorBox(e) => e.fmt(f), ErrorEntry::ErrorBox(e) => e.fmt(f),
ErrorEntry::UnsupportedModule { ErrorEntry::UnsupportedModule {
module, module,
raw_image, raw_image,
kind, kind,
} => { } => {
write!( write!(
f, f,
"Unsupported {} module {:?} loaded from module image {:?}", "Unsupported {} module {:?} loaded from module image {:?}",
kind, module, raw_image kind, module, raw_image
) )
} }
ErrorEntry::MalformedModulePath(e) => e.fmt(f), ErrorEntry::MalformedModulePath(e) => e.fmt(f),
ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f), ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f),
ErrorEntry::ModuleParsingError(file_name) => { ErrorEntry::ModuleParsingError(file_name) => {
write!( write!(
f, f,
"Error parsing module, log has been written to {}", "Error parsing module, log has been written to {}",
file_name file_name
) )
} }
ErrorEntry::NulInsideModuleText(e) => e.fmt(f), ErrorEntry::NulInsideModuleText(e) => e.fmt(f),
ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"), ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)), ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)),
ErrorEntry::UnexpectedBinaryField { ErrorEntry::UnexpectedBinaryField {
field_name, field_name,
expected, expected,
observed, observed,
} => write!( } => write!(
f, f,
"Unexpected field {}. Expected one of: [{}], observed: {}", "Unexpected field {}. Expected one of: [{}], observed: {}",
field_name, field_name,
expected expected
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "), .join(", "),
observed observed
), ),
ErrorEntry::UnexpectedArgument { ErrorEntry::UnexpectedArgument {
arg_name, arg_name,
expected, expected,
observed, observed,
} => write!( } => write!(
f, f,
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}", "Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
arg_name, arg_name,
expected expected
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "), .join(", "),
observed observed
), ),
ErrorEntry::InvalidEnvVar { ErrorEntry::InvalidEnvVar {
var, var,
pattern, pattern,
value, value,
} => write!( } => write!(
f, f,
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
), ),
ErrorEntry::FunctionNotFound(cuda_function_name) => write!( ErrorEntry::FunctionNotFound(cuda_function_name) => write!(
f, f,
"No function {cuda_function_name} in the underlying library" "No function {cuda_function_name} in the underlying library"
), ),
ErrorEntry::UnexpectedExportTableSize { expected, computed } => { ErrorEntry::UnexpectedExportTableSize { expected, computed } => {
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
} }
ErrorEntry::IntegrityCheck { original, overriden } => { ErrorEntry::IntegrityCheck { original, overriden } => {
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
} }
ErrorEntry::NullPointer(type_) => { ErrorEntry::NullPointer(type_) => {
write!(f, "Null pointer of type {type_} encountered") write!(f, "Null pointer of type {type_} encountered")
} }
ErrorEntry::UnknownLibrary(culibrary) => { ErrorEntry::UnknownLibrary(culibrary) => {
write!(f, "Unknown library: ")?; write!(f, "Unknown library: ")?;
let mut temp_buffer = Vec::new(); let mut temp_buffer = Vec::new();
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok(); CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) }) f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
} }
ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"),
} }
} }
} }

View file

@ -128,12 +128,11 @@ impl StateTracker {
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
type_: &'static str, type_: &'static str,
) { ) {
fn_logger.log_io_error(self.writer.save_module( fn_logger.try_(|fn_logger| {
self.library_counter, self.writer
index, .save_module(fn_logger, self.library_counter, index, submodule, type_)
submodule, .map_err(ErrorEntry::IoError)
type_, });
));
if type_ == "ptx" { if type_ == "ptx" {
match CString::new(submodule) { match CString::new(submodule) {
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)), Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
@ -323,6 +322,7 @@ impl DumpWriter {
fn save_module( fn save_module(
&self, &self,
fn_logger: &mut FnCallLog,
module_index: usize, module_index: usize,
submodule_index: Option<(usize, Option<usize>)>, submodule_index: Option<(usize, Option<usize>)>,
buffer: &[u8], buffer: &[u8],
@ -332,9 +332,13 @@ impl DumpWriter {
None => return Ok(()), None => return Ok(()),
Some(d) => d.clone(), Some(d) => d.clone(),
}; };
dump_file.push(Self::get_file_name(module_index, submodule_index, kind)); let file_name = Self::get_file_name(module_index, submodule_index, kind);
let mut file = File::create_new(dump_file)?; dump_file.push(&file_name);
file.write_all(buffer)?; {
let mut file = File::create_new(dump_file)?;
file.write_all(buffer)?;
}
fn_logger.log(ErrorEntry::SavedModule(file_name));
Ok(()) Ok(())
} }
@ -349,7 +353,7 @@ impl DumpWriter {
Some(d) => d.clone(), Some(d) => d.clone(),
}; };
log_file.push(Self::get_file_name(module_index, submodule_index, "log")); log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
let mut file = File::create(log_file)?; let mut file = File::create_new(log_file)?;
for error in errors { for error in errors {
writeln!(file, "{}", error)?; writeln!(file, "{}", error)?;
} }