mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-21 08:49:04 +00:00
More runtime fixes, add mma instruction (#509)
This commit is contained in:
parent
150ce171cf
commit
b5f41c7cd0
27 changed files with 639 additions and 154 deletions
2
.github/ISSUE_TEMPLATE/zluda_dump.yml
vendored
2
.github/ISSUE_TEMPLATE/zluda_dump.yml
vendored
|
@ -45,7 +45,7 @@ body:
|
|||
./train_gpt2fp32cu
|
||||
4. Build and run the tests:
|
||||
make test_gpt2fp32cu
|
||||
LD_LIBRARY_PATH=<ZLUDA_TRACE_DIR> ./test_gpt2fp32cu
|
||||
LD_LIBRARY_PATH=<ZLUDA_LOG_DIR> ./test_gpt2fp32cu
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
|
|
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2573,7 +2573,6 @@ dependencies = [
|
|||
"ptx_parser",
|
||||
"quick-error",
|
||||
"rustc-hash 2.0.0",
|
||||
"serde",
|
||||
"smallvec",
|
||||
"strum 0.26.3",
|
||||
"strum_macros 0.26.4",
|
||||
|
|
|
@ -219,6 +219,10 @@ pub fn compile_bitcode(
|
|||
compile_to_exec.set_isa_name(gcn_arch)?;
|
||||
compile_to_exec.set_language(Language::LlvmIr)?;
|
||||
let common_options = [
|
||||
c"-mllvm",
|
||||
c"-ignore-tti-inline-compatible",
|
||||
// c"-mllvm",
|
||||
// c"-amdgpu-early-inline-all=true",
|
||||
// This makes no sense, but it makes ockl linking work
|
||||
c"-Xclang",
|
||||
c"-mno-link-builtin-bitcode-postopt",
|
||||
|
@ -237,8 +241,7 @@ pub fn compile_bitcode(
|
|||
]
|
||||
.into_iter();
|
||||
let opt_options = if cfg!(debug_assertions) {
|
||||
//[c"-g", c"-mllvm", c"-print-before-all", c"", c""]
|
||||
[c"-g", c"", c"", c"", c""]
|
||||
[c"-g", c"-mamdgpu-precise-memory-op", c"", c"", c""]
|
||||
} else {
|
||||
[
|
||||
c"-g0",
|
||||
|
|
|
@ -21,9 +21,14 @@ pub struct Options {
|
|||
output_dir: Option<PathBuf>,
|
||||
|
||||
#[bpaf(long("arch"))]
|
||||
/// Target architecture
|
||||
/// Target GPU architecture
|
||||
arch: Option<String>,
|
||||
|
||||
#[bpaf(long("ignore-errors"))]
|
||||
/// Try to ignore errors. This will try and produce output even if there are
|
||||
/// parsing errors (e.g. an unimplemented instruction)
|
||||
ignore_errors: bool,
|
||||
|
||||
#[bpaf(positional("filename"))]
|
||||
/// PTX file
|
||||
ptx_path: String,
|
||||
|
@ -48,7 +53,10 @@ fn main_core() -> Result<(), CompilerError> {
|
|||
.unwrap_or("output");
|
||||
|
||||
let mut output_path = match opts.output_dir {
|
||||
Some(value) => value,
|
||||
Some(value) => {
|
||||
std::fs::create_dir_all(&value)?;
|
||||
value
|
||||
}
|
||||
None => match ptx_path.parent() {
|
||||
Some(dir) => dir.to_path_buf(),
|
||||
None => env::current_dir()?,
|
||||
|
@ -68,7 +76,7 @@ fn main_core() -> Result<(), CompilerError> {
|
|||
|
||||
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
|
||||
let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?;
|
||||
let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?;
|
||||
let llvm = ptx_to_llvm(opts.ignore_errors, ptx).map_err(CompilerError::from)?;
|
||||
|
||||
write_to_file(&llvm.llvm_ir, output_path.with_extension("ll").as_path())?;
|
||||
|
||||
|
@ -92,8 +100,12 @@ fn main_core() -> Result<(), CompilerError> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
|
||||
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?;
|
||||
fn ptx_to_llvm(ignore_errors: bool, ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
|
||||
let ast = if ignore_errors {
|
||||
ptx_parser::parse_module_unchecked(ptx)
|
||||
} else {
|
||||
ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?
|
||||
};
|
||||
let mut start = Instant::now();
|
||||
let module = ptx::to_llvm_module(
|
||||
ast,
|
||||
|
|
|
@ -116,7 +116,7 @@ in order to demonstrate all of zluda_trace's features.
|
|||
|
||||
```bash
|
||||
nvcc add.cu -o add -arch sm_80
|
||||
LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_TRACE_DIR=/tmp/zluda ./add
|
||||
LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_LOG_DIR=/tmp/zluda ./add
|
||||
```
|
||||
|
||||
The last few lines should look something like:
|
||||
|
|
|
@ -22,7 +22,6 @@ microlp = "0.2.11"
|
|||
int-enum = "1.1"
|
||||
unwrap_or = "1.0.1"
|
||||
smallvec = "1.15.1"
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||
|
|
Binary file not shown.
|
@ -1,17 +1,21 @@
|
|||
// Every time this file changes it must be rebuilt, you need `rocm-llvm-dev` and `llvm-17`
|
||||
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
|
||||
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <bit>
|
||||
#include <cmath>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/amd_detail/amd_device_functions.h>
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
#define SHARED_SPACE __attribute__((address_space(3)))
|
||||
#define CONSTANT_SPACE __attribute__((address_space(4)))
|
||||
|
||||
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
||||
typedef float float8 __attribute__((ext_vector_type(8)));
|
||||
|
||||
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
|
||||
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
|
||||
#define DECLARE_ATTR(TYPE, NAME) \
|
||||
|
@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
|
|||
uint32_t x3 = load_single_matrix_trans(address, 24);
|
||||
return uint4::Native_vec_{x0, x1, x2, x3};
|
||||
}
|
||||
|
||||
// Reinterprets the upper 16 bits of `value` as a half-precision float.
// The bits are converted with std::bit_cast instead of a pointer cast so that
// a uint16_t object is never read through an unrelated _Float16 lvalue.
static inline __device__ _Float16 top16_as_fp16(uint32_t value) {
    const uint16_t half_bits = static_cast<uint16_t>((value >> 16) & 0xFFFF);
    return std::bit_cast<_Float16>(half_bits);
}
|
||||
// Reinterprets the lower 16 bits of `value` as a half-precision float.
// The bits are converted with std::bit_cast instead of a pointer cast so that
// a uint16_t object is never read through an unrelated _Float16 lvalue.
static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) {
    const uint16_t half_bits = static_cast<uint16_t>(value & 0xFFFF);
    return std::bit_cast<_Float16>(half_bits);
}
|
||||
|
||||
// Float overload: forwards `x` to __hip_ds_bpermutef with the lane index
// scaled by 4 (presumably the intrinsic expects a byte offset - TODO confirm).
static inline __device__ float bpermute_lane(int lane, float x) {
    const int scaled_index = 4 * lane;
    return __hip_ds_bpermutef(scaled_index, x);
}
|
||||
// 32-bit integer overload: forwards `x` to __hip_ds_bpermute with the lane
// index scaled by 4 (presumably the intrinsic expects a byte offset - TODO confirm).
static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) {
    const int scaled_index = 4 * lane;
    return __hip_ds_bpermute(scaled_index, x);
}
|
||||
|
||||
// Builds the 16-element half vector used as the A operand of the WMMA builtin
// from the four packed 32-bit registers in `a_reg`.
//
// Each of the eight iterations fetches one 32-bit register (two packed halves)
// from another lane and unpacks it into two consecutive fragment elements.
// NOTE(review): the lane is derived from threadIdx.x, which presumably assumes
// a 1-D thread block - confirm.
static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) {
    // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode)
    const int lane = threadIdx.x % 16;
    half16 frag;

    for (int slot = 0; slot < 8; ++slot) {
        // Slots 0-3 read the (x, y) pair, slots 4-7 read the (z, w) pair.
        const bool second_pair = (slot / 4) != 0;
        const int src_lane = (slot % 4 + lane * 4) % 32;

        const uint32_t first = second_pair ? a_reg.z : a_reg.x;
        const uint32_t second = second_pair ? a_reg.w : a_reg.y;

        // Both permutes are issued unconditionally; the selection happens afterwards.
        const uint32_t fetched_first = bpermute_lane(src_lane, first);
        const uint32_t fetched_second = bpermute_lane(src_lane, second);
        const uint32_t packed = (lane < 8) ? fetched_first : fetched_second;

        frag[2 * slot] = bottom16_as_fp16(packed);
        frag[2 * slot + 1] = top16_as_fp16(packed);
    }
    return frag;
}
|
||||
|
||||
// Builds the 16-element half vector used as the B operand of the WMMA builtin
// from the two packed 32-bit registers in `b_reg`.
//
// Lanes whose (threadIdx.x % 16) is 8 or above receive an all-zero fragment.
// NOTE(review): the lane is derived from threadIdx.x, which presumably assumes
// a 1-D thread block - confirm.
static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) {
    const int lane = threadIdx.x % 16;
    half16 frag;

    for (int slot = 0; slot < 8; ++slot) {
        // Slots 0-3 read b_reg.x, slots 4-7 read b_reg.y.
        const uint32_t source = (slot / 4 == 0) ? b_reg.x : b_reg.y;
        const int src_lane = slot % 4 + (lane * 4) % 64;

        // The permute is issued by every lane, including those that discard it.
        const uint32_t packed = bpermute_lane(src_lane, source);

        if (lane >= 8) {
            frag[2 * slot] = 0.0f;
            frag[2 * slot + 1] = 0.0f;
        } else {
            frag[2 * slot] = bottom16_as_fp16(packed);
            frag[2 * slot + 1] = top16_as_fp16(packed);
        }
    }
    return frag;
}
|
||||
|
||||
// Rearranges the four accumulator floats in `c_reg` into the eight floats this
// lane passes to the WMMA builtin as its C operand.
//
// Every lane-relative decision below (`lIdx < 8`, `lIdx - 8`, `lIdx & 1`, the
// zeroed ranges 8-15 and 24-31) is written for a lane number in 0..31, so the
// lane is taken as threadIdx.x % 32. Using threadIdx.x unreduced is only
// correct for the first 32 threads of a block.
// NOTE(review): like shuffle_a/shuffle_b this still derives the lane from
// threadIdx.x, which presumably assumes a 1-D thread block - confirm.
static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) {
    const int lIdx = (int)(threadIdx.x % 32);
    float8 cFrag;

    // Loop over the eight vector GPRs.
    for (int vGPR = 0; vGPR < 8; ++vGPR) {
        int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use.
        int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8;
        int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2;
        float ctmp0, ctmp1;

        // Both permutes of the selected pair are issued by every lane.
        if (cudaChunk == 0) {
            ctmp0 = bpermute_lane(cudaTID, c_reg.x);
            ctmp1 = bpermute_lane(cudaTID, c_reg.y);
        } else { // cudaChunk == 2
            ctmp0 = bpermute_lane(cudaTID, c_reg.z);
            ctmp1 = bpermute_lane(cudaTID, c_reg.w);
        }

        // Select one of the two values based on the lane index's LSB.
        cFrag[vGPR] = (lIdx & 1) ? ctmp1 : ctmp0;

        // Zero out for specific lane indices.
        if ((lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32))
            cFrag[vGPR] = 0.0f;
    }
    return cFrag;
}
|
||||
|
||||
// Rearranges the eight result floats `dFrag` returned by the WMMA builtin back
// into the four floats returned to the caller.
//
// The `lIdx < 8 / < 16 / < 24 / else` chain selects one of four fetched values
// by lane group, so it needs a lane number in 0..31; the lane is therefore
// taken as threadIdx.x % 32. With threadIdx.x unreduced, every thread past the
// first 32 of a block would always take the last branch.
// NOTE(review): like shuffle_a/shuffle_b this still derives the lane from
// threadIdx.x, which presumably assumes a 1-D thread block - confirm.
static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) {
    const int lIdx = (int)(threadIdx.x % 32);
    float4::Native_vec_ d_out;

    for (int cChunk = 0; cChunk < 4; ++cChunk) {
        // Output components 0-1 read fragment elements 0-3, components 2-3 read 4-7.
        int r_vGPR = (cChunk / 2) * 4;
        int add8 = (lIdx & 0x4) ? 8 : 0;
        int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8;

        // All four permutes are issued by every lane; the selection happens afterwards.
        float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]);
        float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]);
        float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]);
        float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]);

        float val;
        if (lIdx < 8) {
            val = d_tmp0;
        } else if (lIdx < 16) {
            val = d_tmp1;
        } else if (lIdx < 24) {
            val = d_tmp2;
        } else {
            val = d_tmp3;
        }

        if (cChunk == 0) d_out.x = val;
        else if (cChunk == 1) d_out.y = val;
        else if (cChunk == 2) d_out.z = val;
        else d_out.w = val;
    }
    return d_out;
}
|
||||
|
||||
// Implementation of the f16 x f16 -> f32 m16n8k16 row/col `mma` call.
// Operands are rearranged by shuffle_a/b/c, multiplied with the w32 f16 WMMA
// builtin, and the result is rearranged by shuffle_d.
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
    // Nvidia-like register layout -> AMD layout (same order as before: A, B, C).
    const half16 a_frag = shuffle_a(a_reg);
    const half16 b_frag = shuffle_b(b_reg);
    const float8 c_frag = shuffle_c(c_reg);

    // Built-in 16x16 MMA instruction; yields a float8.
    const float8 d_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);

    // AMD layout -> the float4 layout the caller expects.
    return shuffle_d(d_frag);
}
|
||||
|
||||
// Implementation of the bf16 x bf16 -> f32 m16n8k16 row/col `mma` call.
// Operands are rearranged by shuffle_a/b/c, multiplied with the w32 bf16 WMMA
// builtin, and the result is rearranged by shuffle_d.
// NOTE(review): the A/B fragments are typed half16 even though they carry bf16
// data; shuffle_a/shuffle_b only move raw 16-bit patterns - confirm the builtin
// accepts this operand type as intended.
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
    // Nvidia-like register layout -> AMD layout (same order as before: A, B, C).
    const half16 a_frag = shuffle_a(a_reg);
    const half16 b_frag = shuffle_b(b_reg);
    const float8 c_frag = shuffle_c(c_reg);

    // Built-in 16x16 MMA instruction; yields a float8.
    const float8 d_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, c_frag);

    // AMD layout -> the float4 layout the caller expects.
    return shuffle_d(d_frag);
}
|
||||
}
|
||||
|
|
|
@ -197,7 +197,9 @@ fn run_instruction<'input>(
|
|||
| ast::Instruction::Xor { .. }
|
||||
| ast::Instruction::Vote { .. }
|
||||
| ast::Instruction::ReduxSync { .. }
|
||||
| ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)),
|
||||
| ast::Instruction::GridDepControl { .. }
|
||||
| ast::Instruction::LdMatrix { .. }
|
||||
| ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)),
|
||||
ast::Instruction::Add {
|
||||
data:
|
||||
ast::ArithDetails::Float(ast::ArithFloat {
|
||||
|
|
|
@ -1855,7 +1855,9 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
|||
| ast::Instruction::AtomCas { .. }
|
||||
| ast::Instruction::Vote { .. }
|
||||
| ast::Instruction::ReduxSync { .. }
|
||||
| ast::Instruction::LdMatrix { .. } => InstructionModes::none(),
|
||||
| ast::Instruction::GridDepControl { .. }
|
||||
| ast::Instruction::LdMatrix { .. }
|
||||
| ast::Instruction::Mma { .. } => InstructionModes::none(),
|
||||
ast::Instruction::Add {
|
||||
data: ast::ArithDetails::Integer(_),
|
||||
..
|
||||
|
|
|
@ -153,6 +153,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
for (i, param) in method.input_arguments.iter().enumerate() {
|
||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||
let name = self.resolver.get_or_add(param.name);
|
||||
if let Some(align) = param.align {
|
||||
unsafe { LLVMSetParamAlignment(value, align) };
|
||||
}
|
||||
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
||||
self.resolver.register(param.name, value);
|
||||
if method.is_kernel {
|
||||
|
@ -519,6 +522,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop
|
||||
ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop
|
||||
ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop
|
||||
ast::Instruction::GridDepControl { .. } => Ok(()), // nop
|
||||
// replaced by a function call
|
||||
ast::Instruction::Bfe { .. }
|
||||
| ast::Instruction::Bar { .. }
|
||||
|
@ -529,7 +533,8 @@ impl<'a> MethodEmitContext<'a> {
|
|||
| ast::Instruction::Vote { .. }
|
||||
| ast::Instruction::Nanosleep { .. }
|
||||
| ast::Instruction::ReduxSync { .. }
|
||||
| ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()),
|
||||
| ast::Instruction::LdMatrix { .. }
|
||||
| ast::Instruction::Mma { .. } => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ quick_error! {
|
|||
}
|
||||
|
||||
/// GPU attributes needed at compile time.
|
||||
#[derive(serde::Serialize)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct Attributes {
|
||||
/// Clock frequency in kHz.
|
||||
pub clock_rate: u32,
|
||||
|
|
|
@ -351,6 +351,35 @@ fn run_instruction<'input>(
|
|||
let name = "sqrt_rn_ftz_f32";
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Mma {
|
||||
data:
|
||||
ast::MmaDetails {
|
||||
alayout,
|
||||
blayout,
|
||||
dtype_scalar,
|
||||
atype_scalar,
|
||||
btype_scalar,
|
||||
ctype_scalar,
|
||||
},
|
||||
..
|
||||
} => {
|
||||
let name = format!(
|
||||
"mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}",
|
||||
match alayout {
|
||||
ast::MatrixLayout::Row => "row",
|
||||
ast::MatrixLayout::Col => "col",
|
||||
},
|
||||
match blayout {
|
||||
ast::MatrixLayout::Row => "row",
|
||||
ast::MatrixLayout::Col => "col",
|
||||
},
|
||||
scalar_to_ptx_name(dtype_scalar),
|
||||
scalar_to_ptx_name(atype_scalar),
|
||||
scalar_to_ptx_name(btype_scalar),
|
||||
scalar_to_ptx_name(ctype_scalar),
|
||||
);
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Sqrt {
|
||||
data:
|
||||
ast::RcpData {
|
||||
|
|
|
@ -3,8 +3,8 @@ use super::{
|
|||
StateSpace, VectorPrefix,
|
||||
};
|
||||
use crate::{
|
||||
FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction,
|
||||
ShiftDirection, ShuffleMode, VoteMode,
|
||||
FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError,
|
||||
PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode,
|
||||
};
|
||||
use bitflags::bitflags;
|
||||
use derive_more::Display;
|
||||
|
@ -721,6 +721,30 @@ ptx_parser_macros::generate_instruction_type!(
|
|||
space: { data.state_space },
|
||||
}
|
||||
}
|
||||
},
|
||||
GridDepControl {
|
||||
data: crate::GridDepControlAction,
|
||||
},
|
||||
Mma {
|
||||
data: MmaDetails,
|
||||
arguments<T>: {
|
||||
dst: {
|
||||
repr: T,
|
||||
type: { data.dtype() },
|
||||
},
|
||||
src1: {
|
||||
repr: T,
|
||||
type: { data.atype() },
|
||||
},
|
||||
src2: {
|
||||
repr: T,
|
||||
type: { data.btype() },
|
||||
},
|
||||
src3: {
|
||||
repr: T,
|
||||
type: { data.ctype() },
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
@ -2378,3 +2402,27 @@ pub struct ReduxSyncData {
|
|||
pub type_: ScalarType,
|
||||
pub reduction: Reduction,
|
||||
}
|
||||
|
||||
pub struct MmaDetails {
|
||||
pub alayout: MatrixLayout,
|
||||
pub blayout: MatrixLayout,
|
||||
pub dtype_scalar: ScalarType,
|
||||
pub atype_scalar: ScalarType,
|
||||
pub btype_scalar: ScalarType,
|
||||
pub ctype_scalar: ScalarType,
|
||||
}
|
||||
|
||||
impl MmaDetails {
|
||||
pub fn dtype(&self) -> Type {
|
||||
Type::Vector(4, ScalarType::F32)
|
||||
}
|
||||
pub fn atype(&self) -> Type {
|
||||
Type::Vector(4, ScalarType::U32)
|
||||
}
|
||||
pub fn btype(&self) -> Type {
|
||||
Type::Vector(2, ScalarType::U32)
|
||||
}
|
||||
pub fn ctype(&self) -> Type {
|
||||
Type::Vector(4, ScalarType::F32)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1862,6 +1862,9 @@ derive_parser!(
|
|||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum MatrixNumber { }
|
||||
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum MatrixLayout { }
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
||||
mov{.vec}.type d, a => {
|
||||
Instruction::Mov {
|
||||
|
@ -3897,6 +3900,37 @@ derive_parser!(
|
|||
.type: ScalarType = {.b16, .b8};
|
||||
// .dst_fmt = { .b8x16 };
|
||||
// .src_fmt = { .b6x16_p32, .b4x16_p64 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol
|
||||
griddepcontrol.action => {
|
||||
Instruction::GridDepControl {
|
||||
data: action
|
||||
}
|
||||
}
|
||||
.action: GridDepControlAction = { .launch_dependents, .wait };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
|
||||
mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => {
|
||||
if dtype != ScalarType::F32 || ctype != ScalarType::F32 {
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
Instruction::Mma {
|
||||
data: MmaDetails {
|
||||
alayout,
|
||||
blayout,
|
||||
dtype_scalar: dtype,
|
||||
atype_scalar: ScalarType::BF16,
|
||||
btype_scalar: ScalarType::BF16,
|
||||
ctype_scalar: ctype,
|
||||
},
|
||||
arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c }
|
||||
}
|
||||
}
|
||||
|
||||
.alayout: MatrixLayout = {.row};
|
||||
.blayout: MatrixLayout = {.col};
|
||||
.ctype: ScalarType = {.f16, .f32};
|
||||
.dtype: ScalarType = {.f16, .f32};
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
use bpaf::{any, doc::Style, Bpaf, Parser};
|
||||
use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600};
|
||||
use std::{ffi::CStr, mem};
|
||||
use bpaf::{any, choice, doc::Style, literal, Bpaf, Parser};
|
||||
|
||||
#[derive(Debug, Clone, Bpaf)]
|
||||
#[allow(dead_code)]
|
||||
|
@ -12,6 +10,8 @@ pub struct Options {
|
|||
#[bpaf(short, long)]
|
||||
verbose: bool,
|
||||
#[bpaf(external)]
|
||||
lineinfo: bool,
|
||||
#[bpaf(external)]
|
||||
gpu_name: String,
|
||||
#[bpaf(long, short('O'), fallback(3))]
|
||||
opt_level: usize,
|
||||
|
@ -19,48 +19,32 @@ pub struct Options {
|
|||
input: String,
|
||||
}
|
||||
|
||||
fn lineinfo() -> impl Parser<bool> {
|
||||
choice(["-lineinfo", "--lineinfo"].into_iter().map(|s| {
|
||||
literal(s)
|
||||
.anywhere()
|
||||
.optional()
|
||||
.map(|_| true)
|
||||
.fallback(false)
|
||||
.boxed()
|
||||
}))
|
||||
}
|
||||
|
||||
// #[bpaf(long, long("gpu_name"), fallback_with(default_arch))]
|
||||
fn gpu_name() -> impl Parser<String> {
|
||||
any("", move |s: String| {
|
||||
Some(s.strip_prefix("-arch=")?.to_owned())
|
||||
Some(
|
||||
s.strip_prefix("-arch=")
|
||||
.or_else(|| s.strip_prefix("--gpu-name="))?
|
||||
.to_owned(),
|
||||
)
|
||||
})
|
||||
.metavar(&[("-arch=", Style::Literal), ("ARG", Style::Metavar)])
|
||||
.metavar(&[("--gpu-name=", Style::Literal), ("SM", Style::Metavar)])
|
||||
.anywhere()
|
||||
.fallback_with(|| Ok::<String, &'static str>("sm_52".to_string()))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let options = options().run();
|
||||
let comgr = comgr::Comgr::new().unwrap();
|
||||
unsafe { hip_runtime_sys::hipInit(0) }.unwrap();
|
||||
let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() };
|
||||
let (gpu_arch, clock_rate) = get_gpu_arch_and_clock_rate(&mut dev_props);
|
||||
let input = std::fs::read_to_string(options.input).unwrap();
|
||||
let ast = ptx_parser::parse_module_checked(&input).unwrap();
|
||||
let llvm = ptx::to_llvm_module(
|
||||
ast,
|
||||
ptx::Attributes {
|
||||
clock_rate: clock_rate as u32,
|
||||
},
|
||||
|_| {},
|
||||
)
|
||||
.unwrap();
|
||||
let elf_binary = comgr::compile_bitcode(
|
||||
&comgr,
|
||||
gpu_arch,
|
||||
&*llvm.llvm_ir.write_bitcode_to_memory(),
|
||||
&*llvm.linked_bitcode(),
|
||||
&*llvm.attributes_ir.write_bitcode_to_memory(),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
std::fs::write(options.output, elf_binary).unwrap();
|
||||
}
|
||||
|
||||
fn get_gpu_arch_and_clock_rate<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> (&'a str, i32) {
|
||||
unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }.unwrap();
|
||||
let gcn_arch_name = &dev_props.gcnArchName;
|
||||
let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) };
|
||||
let gcn_arch_name = gcn_arch_name.to_str();
|
||||
(gcn_arch_name.unwrap(), dev_props.clockRate)
|
||||
std::fs::copy(&options.input, &options.output).unwrap();
|
||||
}
|
||||
|
|
|
@ -89,7 +89,15 @@ pub(crate) fn get_attribute(
|
|||
*pi = 32;
|
||||
return Ok(());
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER => {
|
||||
// TODO: maintain a table, certain RDNAs are 1/16, some are 1/32
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => {
|
||||
*pi = 32;
|
||||
return Ok(());
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER
|
||||
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED
|
||||
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES
|
||||
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED => {
|
||||
*pi = 0;
|
||||
return Ok(());
|
||||
}
|
||||
|
@ -211,9 +219,6 @@ pub(crate) fn get_attribute(
|
|||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE => {
|
||||
return get_device_prop(pi, dev_idx, |props| props.persistingL2CacheMaxSize)
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => {
|
||||
return get_device_prop(pi, dev_idx, |props| props.singleToDoublePrecisionPerfRatio)
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE => {
|
||||
return get_device_prop(pi, dev_idx, |props| props.accessPolicyMaxWindowSize)
|
||||
}
|
||||
|
|
|
@ -487,9 +487,9 @@ pub(crate) unsafe fn occupancy_max_active_blocks_per_multiprocessor_with_flags(
|
|||
dynamic_smem_size: usize,
|
||||
flags: ::core::ffi::c_uint,
|
||||
) -> hipError_t {
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
|
||||
hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
|
||||
num_blocks,
|
||||
func.0.cast(),
|
||||
func,
|
||||
block_size,
|
||||
dynamic_smem_size,
|
||||
flags,
|
||||
|
|
12
zluda/src/impl/hipfix.rs
Normal file
12
zluda/src/impl/hipfix.rs
Normal file
|
@ -0,0 +1,12 @@
|
|||
// There's a bug in hipDrvPointerGetAttributes where it returns
|
||||
// HIP_ERROR_INVALID_VALUE if the pointer is null. It works correctly for any
|
||||
// other invalid pointer
|
||||
pub(crate) fn get_attributes(
|
||||
ptr: hip_runtime_sys::hipDeviceptr_t,
|
||||
) -> hip_runtime_sys::hipDeviceptr_t {
|
||||
if ptr.0.is_null() {
|
||||
hip_runtime_sys::hipDeviceptr_t(usize::MAX as _)
|
||||
} else {
|
||||
ptr
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@ pub(super) mod driver;
|
|||
pub(super) mod event;
|
||||
pub(super) mod function;
|
||||
pub(super) mod graph;
|
||||
pub(super) mod hipfix;
|
||||
pub(super) mod kernel;
|
||||
pub(super) mod library;
|
||||
pub(super) mod memory;
|
||||
|
|
|
@ -20,6 +20,10 @@ impl ZludaObject for Module {
|
|||
}
|
||||
}
|
||||
|
||||
static EMPTY_PTX: &str = ".version 6.5
|
||||
.target sm_30
|
||||
.address_size 64";
|
||||
|
||||
// get_ptx takes an `image` that can be anything we support and returns a
|
||||
// String containing a ptx extracted from `image`.
|
||||
fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result<Cow<'a, str>, CUerror> {
|
||||
|
@ -58,11 +62,17 @@ fn cow_bytes_to_str<'a>(data: Cow<'a, [u8]>) -> Option<Cow<'a, str>> {
|
|||
|
||||
pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CUerror> {
|
||||
let global_state = driver::global_state()?;
|
||||
let text = get_ptx(library)?;
|
||||
let maybe_ptx = get_ptx(library);
|
||||
let text = if cfg!(debug_assertions) {
|
||||
maybe_ptx?
|
||||
} else {
|
||||
maybe_ptx.unwrap_or_else(|_| Cow::Borrowed(EMPTY_PTX))
|
||||
};
|
||||
let hip_properties = get_hip_properties()?;
|
||||
let gcn_arch = get_gcn_arch(&hip_properties)?;
|
||||
let attributes = ptx::Attributes {
|
||||
let attributes = ExtraCacheAttributes {
|
||||
clock_rate: hip_properties.clockRate as u32,
|
||||
is_debug: cfg!(debug_assertions),
|
||||
};
|
||||
let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| {
|
||||
let cache = zluda_cache::ModuleCache::open(p)?;
|
||||
|
@ -84,6 +94,12 @@ pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CU
|
|||
Ok(hip_module)
|
||||
}
|
||||
|
||||
/// Compile-time inputs gathered in `load_hip_module` before a module is
/// compiled. The type derives `serde::Serialize`; it is presumably what gets
/// serialized into the module cache key - TODO confirm against the
/// `get_cache_key` call site.
#[derive(serde::Serialize)]
|
||||
struct ExtraCacheAttributes {
|
||||
// Set from `cfg!(debug_assertions)` by the caller.
is_debug: bool,
|
||||
// Set from the HIP device property `clockRate` by the caller and forwarded
// to `ptx::Attributes::clock_rate` when compiling.
clock_rate: u32,
|
||||
}
|
||||
|
||||
fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> {
|
||||
let hip_dev = super::context::get_current_device()?;
|
||||
let mut props = unsafe { mem::zeroed() };
|
||||
|
@ -100,7 +116,7 @@ fn get_cache_key<'a, 'b>(
|
|||
global_state: &'static driver::GlobalState,
|
||||
isa: &'a str,
|
||||
text: &str,
|
||||
attributes: &ptx::Attributes,
|
||||
attributes: &impl serde::Serialize,
|
||||
) -> Option<zluda_cache::ModuleKey<'a>> {
|
||||
// Serialization here is deterministic. When marking a type with
|
||||
// #[derive(serde::Serialize)] the derived implementation will just write
|
||||
|
@ -129,7 +145,7 @@ fn load_cached_binary(
|
|||
fn compile_from_ptx_and_cache(
|
||||
comgr: &comgr::Comgr,
|
||||
gcn_arch: &str,
|
||||
attributes: ptx::Attributes,
|
||||
attributes: ExtraCacheAttributes,
|
||||
text: &str,
|
||||
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
|
||||
) -> Result<Vec<u8>, CUerror> {
|
||||
|
@ -138,7 +154,14 @@ fn compile_from_ptx_and_cache(
|
|||
} else {
|
||||
ptx_parser::parse_module_unchecked(text)
|
||||
};
|
||||
let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?;
|
||||
let llvm_module = ptx::to_llvm_module(
|
||||
ast,
|
||||
ptx::Attributes {
|
||||
clock_rate: attributes.clock_rate,
|
||||
},
|
||||
|_| {},
|
||||
)
|
||||
.map_err(|_| CUerror::UNKNOWN)?;
|
||||
let elf_module = comgr::compile_bitcode(
|
||||
comgr,
|
||||
gcn_arch,
|
||||
|
|
|
@ -2,7 +2,7 @@ use cuda_types::cuda::*;
|
|||
use hip_runtime_sys::*;
|
||||
use std::{ffi::c_void, ptr};
|
||||
|
||||
use crate::r#impl::driver;
|
||||
use crate::r#impl::{driver, hipfix};
|
||||
|
||||
// TODO: handle hipMemoryTypeUnregistered
|
||||
fn to_cu_memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
|
||||
|
@ -59,7 +59,12 @@ pub(crate) unsafe fn get_attributes(
|
|||
data: &mut *mut ::core::ffi::c_void,
|
||||
ptr: hipDeviceptr_t,
|
||||
) -> CUresult {
|
||||
hipDrvPointerGetAttributes(num_attributes, attributes, data, ptr)?;
|
||||
hipDrvPointerGetAttributes(
|
||||
num_attributes,
|
||||
attributes,
|
||||
data,
|
||||
hipfix::get_attributes(ptr),
|
||||
)?;
|
||||
let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize);
|
||||
let data = std::slice::from_raw_parts_mut(data, num_attributes as usize);
|
||||
for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) {
|
||||
|
@ -88,7 +93,7 @@ mod tests {
|
|||
use crate::tests::CudaApi;
|
||||
use cuda_macros::test_cuda;
|
||||
use cuda_types::cuda::*;
|
||||
use std::{ffi::c_void, mem, ptr};
|
||||
use std::{ffi::c_void, i32, mem, ptr, usize};
|
||||
|
||||
#[test_cuda]
|
||||
pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) {
|
||||
|
@ -162,4 +167,47 @@ mod tests {
|
|||
);
|
||||
assert_eq!(context, CUcontext(ptr::null_mut()));
|
||||
}
|
||||
|
||||
#[test_cuda]
|
||||
pub unsafe fn null_ptr_attributes_success(api: impl CudaApi) {
|
||||
api.cuInit(0);
|
||||
api.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0);
|
||||
let mut context = CUcontext(1 as _);
|
||||
let mut mem_type = mem::transmute::<_, CUmemorytype>(u32::MAX);
|
||||
let mut dev_ptr = mem::transmute::<_, *mut c_void>(usize::MAX);
|
||||
let mut host_ptr = mem::transmute::<_, *mut c_void>(usize::MAX);
|
||||
let mut is_managed = true;
|
||||
let mut ordinal = i32::MAX;
|
||||
let mut attrs = [
|
||||
CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT,
|
||||
CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
|
||||
CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER,
|
||||
CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER,
|
||||
CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED,
|
||||
CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
|
||||
];
|
||||
let mut values = [
|
||||
std::ptr::from_mut(&mut context).cast::<c_void>(),
|
||||
std::ptr::from_mut(&mut mem_type).cast::<c_void>(),
|
||||
std::ptr::from_mut(&mut dev_ptr).cast::<c_void>(),
|
||||
std::ptr::from_mut(&mut host_ptr).cast::<c_void>(),
|
||||
std::ptr::from_mut(&mut is_managed).cast::<c_void>(),
|
||||
std::ptr::from_mut(&mut ordinal).cast::<c_void>(),
|
||||
];
|
||||
assert_eq!(
|
||||
CUresult::SUCCESS,
|
||||
api.cuPointerGetAttributes_unchecked(
|
||||
attrs.len() as u32,
|
||||
attrs.as_mut_ptr(),
|
||||
values.as_mut_ptr(),
|
||||
CUdeviceptr_v2(ptr::null_mut())
|
||||
)
|
||||
);
|
||||
assert_eq!(context, CUcontext(ptr::null_mut()));
|
||||
assert_eq!(mem_type, CUmemorytype(0));
|
||||
assert_eq!(dev_ptr, ptr::null_mut());
|
||||
assert_eq!(host_ptr, ptr::null_mut());
|
||||
assert_eq!(is_managed, false);
|
||||
assert_eq!(ordinal, -2);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,6 +43,86 @@ pub(crate) unsafe fn device_get_count_v2(device_count: &mut ::core::ffi::c_uint)
|
|||
rsmi_num_monitor_devices(device_count)
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
|
||||
pci_bus_id: &std::ffi::CStr,
|
||||
device: &mut cuda_types::nvml::nvmlDevice_t,
|
||||
) -> nvmlReturn_t {
|
||||
let pci = parse_pci_bus_id(pci_bus_id).ok_or(nvmlError_t::INVALID_ARGUMENT)?;
|
||||
let bdfid = pci.to_bdfid();
|
||||
let mut device_count = 0;
|
||||
rsmi_num_monitor_devices(&mut device_count)?;
|
||||
for dv_ind in 0..device_count {
|
||||
let mut curr_bdfid = 0;
|
||||
rsmi_dev_pci_id_get(dv_ind, &mut curr_bdfid)?;
|
||||
if curr_bdfid == bdfid {
|
||||
*device = Device { _index: dv_ind }.wrap();
|
||||
return nvmlReturn_t::SUCCESS;
|
||||
}
|
||||
}
|
||||
nvmlReturn_t::ERROR_NOT_FOUND
|
||||
}
|
||||
|
||||
/// Numeric components of a PCI address (`domain:bus:device.function`).
#[derive(Clone, Copy)]
struct PciBusId {
    domain: u16,
    bus: u8,
    device: u8,
    function: u8,
}

impl PciBusId {
    /// Packs the components into one `u64`: domain shifted left by 32, bus
    /// by 8, device by 3, function in the lowest bits, all OR-ed together.
    ///
    /// NOTE(review): no masking is applied, so a `device` above `0x1f`
    /// spills into the bus bits and a `function` above `0x7` spills into the
    /// device bits. Presumably callers only pass valid PCI values; confirm.
    fn to_bdfid(self) -> u64 {
        let Self {
            domain,
            bus,
            device,
            function,
        } = self;
        let domain_bits = u64::from(domain) << 32;
        let bus_bits = u64::from(bus) << 8;
        let device_bits = u64::from(device) << 3;
        let function_bits = u64::from(function);
        domain_bits | bus_bits | device_bits | function_bits
    }
}
|
||||
|
||||
fn parse_pci_bus_id(id: &std::ffi::CStr) -> Option<PciBusId> {
|
||||
let s = id.to_str().ok()?.trim();
|
||||
let mut domain: u16 = 0;
|
||||
let mut rest = s;
|
||||
if let Some(colon1) = s.find(':') {
|
||||
if colon1 == 4 {
|
||||
domain = hex_u16(&s[..4])?;
|
||||
rest = &s[5..];
|
||||
}
|
||||
}
|
||||
let mut parts = rest.split(':');
|
||||
let bus_part = parts.next()?;
|
||||
let tail = parts.next()?;
|
||||
if parts.next().is_some() {
|
||||
return None;
|
||||
}
|
||||
let mut dev_func = tail.split('.');
|
||||
let dev_part = dev_func.next()?;
|
||||
let func_part = dev_func.next();
|
||||
let function = match func_part {
|
||||
Some(f) => hex_u8(f)?,
|
||||
None => 0,
|
||||
};
|
||||
Some(PciBusId {
|
||||
domain,
|
||||
bus: hex_u8(bus_part)?,
|
||||
device: hex_u8(dev_part)?,
|
||||
function,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parses `s` as a hexadecimal `u16`.
///
/// Inputs longer than four bytes are rejected before parsing; anything
/// `u16::from_str_radix` refuses (including the empty string) gives `None`.
fn hex_u16(s: &str) -> Option<u16> {
    match s.len() {
        0..=4 => u16::from_str_radix(s, 16).ok(),
        _ => None,
    }
}
|
||||
|
||||
/// Parses `s` as a hexadecimal `u8`.
///
/// Inputs longer than two bytes are rejected before parsing; anything
/// `u8::from_str_radix` refuses (including the empty string) gives `None`.
fn hex_u8(s: &str) -> Option<u8> {
    Some(s)
        .filter(|text| text.len() <= 2)
        .and_then(|text| u8::from_str_radix(text, 16).ok())
}
|
||||
|
||||
pub(crate) unsafe fn device_get_field_values(
|
||||
_device: &Device,
|
||||
values_count: ::core::ffi::c_int,
|
||||
|
@ -75,3 +155,36 @@ pub(crate) fn device_get_handle_by_index_v2(
|
|||
*device = Device { _index: index }.wrap();
|
||||
nvmlReturn_t::SUCCESS
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::ffi::CString;

    /// Parses `text` and returns `(domain, bus, device, function)`.
    /// Panics if `text` contains a NUL byte or is rejected by the parser.
    fn parsed_fields(text: &str) -> (u16, u8, u8, u8) {
        let id = CString::new(text).unwrap();
        let parsed = super::parse_pci_bus_id(&id).unwrap();
        (parsed.domain, parsed.bus, parsed.device, parsed.function)
    }

    // All four components present.
    #[test]
    fn parse_pci_bus_id_full() {
        assert_eq!(parsed_fields("0100:65:a0.f"), (0x0100, 0x65, 0xa0, 0xf));
    }

    // Omitted function defaults to zero.
    #[test]
    fn parse_pci_bus_id_no_func() {
        assert_eq!(parsed_fields("0100:65:a0"), (0x0100, 0x65, 0xa0, 0));
    }

    // Omitted domain defaults to zero.
    #[test]
    fn parse_pci_bus_id_no_domain() {
        assert_eq!(parsed_fields("65:a0.f"), (0, 0x65, 0xa0, 0xf));
    }
}
|
||||
|
|
|
@ -23,6 +23,13 @@ pub(crate) unsafe fn device_get_count_v2(_device_count: &mut ::core::ffi::c_uint
|
|||
crate::impl_common::unimplemented()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
|
||||
pci_bus_id: &std::ffi::CStr,
|
||||
device: &mut cuda_types::nvml::nvmlDevice_t,
|
||||
) -> nvmlReturn_t {
|
||||
crate::impl_common::unimplemented()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn device_get_field_values(
|
||||
_device: cuda_types::nvml::nvmlDevice_t,
|
||||
_values_count: ::core::ffi::c_int,
|
||||
|
@ -31,10 +38,6 @@ pub(crate) unsafe fn device_get_field_values(
|
|||
crate::impl_common::unimplemented()
|
||||
}
|
||||
|
||||
/// Placeholder: ignores `_field` and returns the result of
/// `crate::impl_common::unimplemented()`.
unsafe fn get_field_value(_field: &mut nvmlFieldValue_st) -> Result<(), nvmlError_t> {
|
||||
crate::impl_common::unimplemented()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn device_get_gpu_fabric_info(
|
||||
_device: cuda_types::nvml::nvmlDevice_t,
|
||||
_gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t,
|
||||
|
|
|
@ -48,6 +48,7 @@ cuda_macros::nvml_function_declarations!(
|
|||
nvmlDeviceGetFieldValues,
|
||||
nvmlDeviceGetGpuFabricInfo,
|
||||
nvmlDeviceGetHandleByIndex_v2,
|
||||
nvmlDeviceGetHandleByPciBusId_v2,
|
||||
nvmlInit,
|
||||
nvmlInitWithFlags,
|
||||
nvmlInit_v2,
|
||||
|
|
|
@ -303,6 +303,7 @@ pub(crate) enum ErrorEntry {
|
|||
},
|
||||
NullPointer(&'static str),
|
||||
UnknownLibrary(CUlibrary),
|
||||
SavedModule(String),
|
||||
}
|
||||
|
||||
// SAFETY: NOTE(review): the justification is not stated here. Presumably some
// variant payload (e.g. the `CUlibrary` handle in `UnknownLibrary`) is not
// automatically `Send` and is only ever formatted, never dereferenced, on
// another thread. TODO confirm against every variant.
unsafe impl Send for ErrorEntry {}
|
||||
|
@ -344,93 +345,94 @@ impl Display for ErrorEntry {
|
|||
match self {
|
||||
ErrorEntry::IoError(e) => e.fmt(f),
|
||||
ErrorEntry::CreatedDumpDirectory(dir) => {
|
||||
write!(
|
||||
f,
|
||||
"Created trace directory {} ",
|
||||
dir.as_os_str().to_string_lossy()
|
||||
)
|
||||
}
|
||||
write!(
|
||||
f,
|
||||
"Created trace directory {} ",
|
||||
dir.as_os_str().to_string_lossy()
|
||||
)
|
||||
}
|
||||
ErrorEntry::ErrorBox(e) => e.fmt(f),
|
||||
ErrorEntry::UnsupportedModule {
|
||||
module,
|
||||
raw_image,
|
||||
kind,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"Unsupported {} module {:?} loaded from module image {:?}",
|
||||
kind, module, raw_image
|
||||
)
|
||||
}
|
||||
module,
|
||||
raw_image,
|
||||
kind,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"Unsupported {} module {:?} loaded from module image {:?}",
|
||||
kind, module, raw_image
|
||||
)
|
||||
}
|
||||
ErrorEntry::MalformedModulePath(e) => e.fmt(f),
|
||||
ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f),
|
||||
ErrorEntry::ModuleParsingError(file_name) => {
|
||||
write!(
|
||||
f,
|
||||
"Error parsing module, log has been written to {}",
|
||||
file_name
|
||||
)
|
||||
}
|
||||
write!(
|
||||
f,
|
||||
"Error parsing module, log has been written to {}",
|
||||
file_name
|
||||
)
|
||||
}
|
||||
ErrorEntry::NulInsideModuleText(e) => e.fmt(f),
|
||||
ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
|
||||
ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)),
|
||||
ErrorEntry::UnexpectedBinaryField {
|
||||
field_name,
|
||||
expected,
|
||||
observed,
|
||||
} => write!(
|
||||
f,
|
||||
"Unexpected field {}. Expected one of: [{}], observed: {}",
|
||||
field_name,
|
||||
expected
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
observed
|
||||
),
|
||||
field_name,
|
||||
expected,
|
||||
observed,
|
||||
} => write!(
|
||||
f,
|
||||
"Unexpected field {}. Expected one of: [{}], observed: {}",
|
||||
field_name,
|
||||
expected
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
observed
|
||||
),
|
||||
ErrorEntry::UnexpectedArgument {
|
||||
arg_name,
|
||||
expected,
|
||||
observed,
|
||||
} => write!(
|
||||
f,
|
||||
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
|
||||
arg_name,
|
||||
expected
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
observed
|
||||
),
|
||||
arg_name,
|
||||
expected,
|
||||
observed,
|
||||
} => write!(
|
||||
f,
|
||||
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
|
||||
arg_name,
|
||||
expected
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
observed
|
||||
),
|
||||
ErrorEntry::InvalidEnvVar {
|
||||
var,
|
||||
pattern,
|
||||
value,
|
||||
} => write!(
|
||||
f,
|
||||
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
|
||||
),
|
||||
var,
|
||||
pattern,
|
||||
value,
|
||||
} => write!(
|
||||
f,
|
||||
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
|
||||
),
|
||||
ErrorEntry::FunctionNotFound(cuda_function_name) => write!(
|
||||
f,
|
||||
"No function {cuda_function_name} in the underlying library"
|
||||
),
|
||||
f,
|
||||
"No function {cuda_function_name} in the underlying library"
|
||||
),
|
||||
ErrorEntry::UnexpectedExportTableSize { expected, computed } => {
|
||||
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
|
||||
}
|
||||
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
|
||||
}
|
||||
ErrorEntry::IntegrityCheck { original, overriden } => {
|
||||
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
|
||||
}
|
||||
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
|
||||
}
|
||||
ErrorEntry::NullPointer(type_) => {
|
||||
write!(f, "Null pointer of type {type_} encountered")
|
||||
}
|
||||
write!(f, "Null pointer of type {type_} encountered")
|
||||
}
|
||||
ErrorEntry::UnknownLibrary(culibrary) => {
|
||||
write!(f, "Unknown library: ")?;
|
||||
let mut temp_buffer = Vec::new();
|
||||
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
|
||||
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
|
||||
}
|
||||
write!(f, "Unknown library: ")?;
|
||||
let mut temp_buffer = Vec::new();
|
||||
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
|
||||
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
|
||||
}
|
||||
ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -128,12 +128,11 @@ impl StateTracker {
|
|||
fn_logger: &mut FnCallLog,
|
||||
type_: &'static str,
|
||||
) {
|
||||
fn_logger.log_io_error(self.writer.save_module(
|
||||
self.library_counter,
|
||||
index,
|
||||
submodule,
|
||||
type_,
|
||||
));
|
||||
fn_logger.try_(|fn_logger| {
|
||||
self.writer
|
||||
.save_module(fn_logger, self.library_counter, index, submodule, type_)
|
||||
.map_err(ErrorEntry::IoError)
|
||||
});
|
||||
if type_ == "ptx" {
|
||||
match CString::new(submodule) {
|
||||
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
|
||||
|
@ -323,6 +322,7 @@ impl DumpWriter {
|
|||
|
||||
fn save_module(
|
||||
&self,
|
||||
fn_logger: &mut FnCallLog,
|
||||
module_index: usize,
|
||||
submodule_index: Option<(usize, Option<usize>)>,
|
||||
buffer: &[u8],
|
||||
|
@ -332,9 +332,13 @@ impl DumpWriter {
|
|||
None => return Ok(()),
|
||||
Some(d) => d.clone(),
|
||||
};
|
||||
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
|
||||
let mut file = File::create_new(dump_file)?;
|
||||
file.write_all(buffer)?;
|
||||
let file_name = Self::get_file_name(module_index, submodule_index, kind);
|
||||
dump_file.push(&file_name);
|
||||
{
|
||||
let mut file = File::create_new(dump_file)?;
|
||||
file.write_all(buffer)?;
|
||||
}
|
||||
fn_logger.log(ErrorEntry::SavedModule(file_name));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -349,7 +353,7 @@ impl DumpWriter {
|
|||
Some(d) => d.clone(),
|
||||
};
|
||||
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
|
||||
let mut file = File::create(log_file)?;
|
||||
let mut file = File::create_new(log_file)?;
|
||||
for error in errors {
|
||||
writeln!(file, "{}", error)?;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue