mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-21 16:59:04 +00:00
More runtime fixes, add mma instruction (#509)
This commit is contained in:
parent
150ce171cf
commit
b5f41c7cd0
27 changed files with 639 additions and 154 deletions
2
.github/ISSUE_TEMPLATE/zluda_dump.yml
vendored
2
.github/ISSUE_TEMPLATE/zluda_dump.yml
vendored
|
@ -45,7 +45,7 @@ body:
|
||||||
./train_gpt2fp32cu
|
./train_gpt2fp32cu
|
||||||
4. Build and run the tests:
|
4. Build and run the tests:
|
||||||
make test_gpt2fp32cu
|
make test_gpt2fp32cu
|
||||||
LD_LIBRARY_PATH=<ZLUDA_TRACE_DIR> ./test_gpt2fp32cu
|
LD_LIBRARY_PATH=<ZLUDA_LOG_DIR> ./test_gpt2fp32cu
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: input
|
- type: input
|
||||||
|
|
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2573,7 +2573,6 @@ dependencies = [
|
||||||
"ptx_parser",
|
"ptx_parser",
|
||||||
"quick-error",
|
"quick-error",
|
||||||
"rustc-hash 2.0.0",
|
"rustc-hash 2.0.0",
|
||||||
"serde",
|
|
||||||
"smallvec",
|
"smallvec",
|
||||||
"strum 0.26.3",
|
"strum 0.26.3",
|
||||||
"strum_macros 0.26.4",
|
"strum_macros 0.26.4",
|
||||||
|
|
|
@ -219,6 +219,10 @@ pub fn compile_bitcode(
|
||||||
compile_to_exec.set_isa_name(gcn_arch)?;
|
compile_to_exec.set_isa_name(gcn_arch)?;
|
||||||
compile_to_exec.set_language(Language::LlvmIr)?;
|
compile_to_exec.set_language(Language::LlvmIr)?;
|
||||||
let common_options = [
|
let common_options = [
|
||||||
|
c"-mllvm",
|
||||||
|
c"-ignore-tti-inline-compatible",
|
||||||
|
// c"-mllvm",
|
||||||
|
// c"-amdgpu-early-inline-all=true",
|
||||||
// This makes no sense, but it makes ockl linking work
|
// This makes no sense, but it makes ockl linking work
|
||||||
c"-Xclang",
|
c"-Xclang",
|
||||||
c"-mno-link-builtin-bitcode-postopt",
|
c"-mno-link-builtin-bitcode-postopt",
|
||||||
|
@ -237,8 +241,7 @@ pub fn compile_bitcode(
|
||||||
]
|
]
|
||||||
.into_iter();
|
.into_iter();
|
||||||
let opt_options = if cfg!(debug_assertions) {
|
let opt_options = if cfg!(debug_assertions) {
|
||||||
//[c"-g", c"-mllvm", c"-print-before-all", c"", c""]
|
[c"-g", c"-mamdgpu-precise-memory-op", c"", c"", c""]
|
||||||
[c"-g", c"", c"", c"", c""]
|
|
||||||
} else {
|
} else {
|
||||||
[
|
[
|
||||||
c"-g0",
|
c"-g0",
|
||||||
|
|
|
@ -21,9 +21,14 @@ pub struct Options {
|
||||||
output_dir: Option<PathBuf>,
|
output_dir: Option<PathBuf>,
|
||||||
|
|
||||||
#[bpaf(long("arch"))]
|
#[bpaf(long("arch"))]
|
||||||
/// Target architecture
|
/// Target GPU architecture
|
||||||
arch: Option<String>,
|
arch: Option<String>,
|
||||||
|
|
||||||
|
#[bpaf(long("ignore-errors"))]
|
||||||
|
/// Try to ignore errors. This will try and produce output even if there are
|
||||||
|
/// parsing errors (e.g. an unimplemented instruction)
|
||||||
|
ignore_errors: bool,
|
||||||
|
|
||||||
#[bpaf(positional("filename"))]
|
#[bpaf(positional("filename"))]
|
||||||
/// PTX file
|
/// PTX file
|
||||||
ptx_path: String,
|
ptx_path: String,
|
||||||
|
@ -48,7 +53,10 @@ fn main_core() -> Result<(), CompilerError> {
|
||||||
.unwrap_or("output");
|
.unwrap_or("output");
|
||||||
|
|
||||||
let mut output_path = match opts.output_dir {
|
let mut output_path = match opts.output_dir {
|
||||||
Some(value) => value,
|
Some(value) => {
|
||||||
|
std::fs::create_dir_all(&value)?;
|
||||||
|
value
|
||||||
|
}
|
||||||
None => match ptx_path.parent() {
|
None => match ptx_path.parent() {
|
||||||
Some(dir) => dir.to_path_buf(),
|
Some(dir) => dir.to_path_buf(),
|
||||||
None => env::current_dir()?,
|
None => env::current_dir()?,
|
||||||
|
@ -68,7 +76,7 @@ fn main_core() -> Result<(), CompilerError> {
|
||||||
|
|
||||||
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
|
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
|
||||||
let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?;
|
let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?;
|
||||||
let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?;
|
let llvm = ptx_to_llvm(opts.ignore_errors, ptx).map_err(CompilerError::from)?;
|
||||||
|
|
||||||
write_to_file(&llvm.llvm_ir, output_path.with_extension("ll").as_path())?;
|
write_to_file(&llvm.llvm_ir, output_path.with_extension("ll").as_path())?;
|
||||||
|
|
||||||
|
@ -92,8 +100,12 @@ fn main_core() -> Result<(), CompilerError> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
|
fn ptx_to_llvm(ignore_errors: bool, ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
|
||||||
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?;
|
let ast = if ignore_errors {
|
||||||
|
ptx_parser::parse_module_unchecked(ptx)
|
||||||
|
} else {
|
||||||
|
ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?
|
||||||
|
};
|
||||||
let mut start = Instant::now();
|
let mut start = Instant::now();
|
||||||
let module = ptx::to_llvm_module(
|
let module = ptx::to_llvm_module(
|
||||||
ast,
|
ast,
|
||||||
|
|
|
@ -116,7 +116,7 @@ in order to demonstrate all of zluda_trace's features.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
nvcc add.cu -o add -arch sm_80
|
nvcc add.cu -o add -arch sm_80
|
||||||
LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_TRACE_DIR=/tmp/zluda ./add
|
LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_LOG_DIR=/tmp/zluda ./add
|
||||||
```
|
```
|
||||||
|
|
||||||
The last few lines should look something like:
|
The last few lines should look something like:
|
||||||
|
|
|
@ -22,7 +22,6 @@ microlp = "0.2.11"
|
||||||
int-enum = "1.1"
|
int-enum = "1.1"
|
||||||
unwrap_or = "1.0.1"
|
unwrap_or = "1.0.1"
|
||||||
smallvec = "1.15.1"
|
smallvec = "1.15.1"
|
||||||
serde = { version = "1.0.219", features = ["derive"] }
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||||
|
|
Binary file not shown.
|
@ -1,17 +1,21 @@
|
||||||
// Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17`
|
// Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17`
|
||||||
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
|
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
|
||||||
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <bit>
|
#include <bit>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
#include <hip/amd_detail/amd_device_functions.h>
|
#include <hip/amd_detail/amd_device_functions.h>
|
||||||
#include <hip/hip_fp8.h>
|
#include <hip/hip_fp8.h>
|
||||||
|
|
||||||
#define SHARED_SPACE __attribute__((address_space(3)))
|
#define SHARED_SPACE __attribute__((address_space(3)))
|
||||||
#define CONSTANT_SPACE __attribute__((address_space(4)))
|
#define CONSTANT_SPACE __attribute__((address_space(4)))
|
||||||
|
|
||||||
|
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
||||||
|
typedef float float8 __attribute__((ext_vector_type(8)));
|
||||||
|
|
||||||
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
|
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
|
||||||
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
|
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
|
||||||
#define DECLARE_ATTR(TYPE, NAME) \
|
#define DECLARE_ATTR(TYPE, NAME) \
|
||||||
|
@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
|
||||||
uint32_t x3 = load_single_matrix_trans(address, 24);
|
uint32_t x3 = load_single_matrix_trans(address, 24);
|
||||||
return uint4::Native_vec_{x0, x1, x2, x3};
|
return uint4::Native_vec_{x0, x1, x2, x3};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline __device__ _Float16 top16_as_fp16(uint32_t value) {
|
||||||
|
uint16_t half_bits = static_cast<uint16_t>((value >> 16) & 0xFFFF);
|
||||||
|
return *reinterpret_cast<_Float16*>(&half_bits);
|
||||||
|
}
|
||||||
|
static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) {
|
||||||
|
uint16_t half_bits = static_cast<uint16_t>(value & 0xFFFF);
|
||||||
|
return *reinterpret_cast<_Float16*>(&half_bits);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __device__ float bpermute_lane(int lane, float x) {
|
||||||
|
return __hip_ds_bpermutef(4 * lane, x);
|
||||||
|
}
|
||||||
|
static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) {
|
||||||
|
return __hip_ds_bpermute(4 * lane, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) {
|
||||||
|
const unsigned lIdx = threadIdx.x;
|
||||||
|
const int lane = lIdx % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode)
|
||||||
|
half16 aFrag;
|
||||||
|
|
||||||
|
for (int vGPR = 0; vGPR < 8; ++vGPR) {
|
||||||
|
int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2
|
||||||
|
int cudaTID = (vGPR % 4 + lane * 4) % 32;
|
||||||
|
uint32_t reg0, reg1;
|
||||||
|
// Select the two consecutive elements from a_reg:
|
||||||
|
if (cudaChunk == 0) {
|
||||||
|
reg0 = a_reg.x;
|
||||||
|
reg1 = a_reg.y;
|
||||||
|
} else { // cudaChunk==2
|
||||||
|
reg0 = a_reg.z;
|
||||||
|
reg1 = a_reg.w;
|
||||||
|
}
|
||||||
|
uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0);
|
||||||
|
uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1);
|
||||||
|
uint32_t a_Frag_reg = (lane < 8) ? a_tmp0 : a_tmp1;
|
||||||
|
aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg);
|
||||||
|
aFrag[2 * vGPR + 1] = top16_as_fp16(a_Frag_reg);
|
||||||
|
}
|
||||||
|
return aFrag;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) {
|
||||||
|
const unsigned lIdx = threadIdx.x;
|
||||||
|
const int lane = lIdx % 16;
|
||||||
|
half16 bFrag;
|
||||||
|
|
||||||
|
for (int vGPR = 0; vGPR < 8; ++vGPR) {
|
||||||
|
int cudaChunk = vGPR / 4; // will be 0 or 1
|
||||||
|
int cudaTID = vGPR % 4 + (lane * 4) % 64;
|
||||||
|
uint32_t reg = (cudaChunk == 0) ? b_reg.x : b_reg.y;
|
||||||
|
uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg);
|
||||||
|
if (lane < 8) {
|
||||||
|
bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg);
|
||||||
|
bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg);
|
||||||
|
} else {
|
||||||
|
bFrag[2 * vGPR] = 0.0f;
|
||||||
|
bFrag[2 * vGPR + 1] = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bFrag;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) {
|
||||||
|
const int lIdx = (int)threadIdx.x;
|
||||||
|
float8 cFrag;
|
||||||
|
|
||||||
|
// Loop over the eight vector GPRs.
|
||||||
|
for (int vGPR = 0; vGPR < 8; ++vGPR) {
|
||||||
|
int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use.
|
||||||
|
int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8;
|
||||||
|
int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2;
|
||||||
|
float ctmp0, ctmp1;
|
||||||
|
|
||||||
|
if (cudaChunk == 0) {
|
||||||
|
ctmp0 = bpermute_lane(cudaTID, c_reg.x);
|
||||||
|
ctmp1 = bpermute_lane(cudaTID, c_reg.y);
|
||||||
|
} else { // cudaChunk == 2
|
||||||
|
ctmp0 = bpermute_lane(cudaTID, c_reg.z);
|
||||||
|
ctmp1 = bpermute_lane(cudaTID, c_reg.w);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select one of the two values based on the thread index's LSB.
|
||||||
|
cFrag[vGPR] = (lIdx & 1) ? ctmp1 : ctmp0;
|
||||||
|
|
||||||
|
// Zero out for specific thread indices.
|
||||||
|
if ((lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32))
|
||||||
|
cFrag[vGPR] = 0.0f;
|
||||||
|
}
|
||||||
|
return cFrag;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) {
|
||||||
|
const int lIdx = (int)threadIdx.x;
|
||||||
|
float4::Native_vec_ d_out;
|
||||||
|
|
||||||
|
for (int cChunk = 0; cChunk < 4; ++cChunk) {
|
||||||
|
int r_vGPR = (cChunk / 2) * 4;
|
||||||
|
int add8 = (lIdx & 0x4) ? 8 : 0;
|
||||||
|
int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8;
|
||||||
|
float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]);
|
||||||
|
float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]);
|
||||||
|
float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]);
|
||||||
|
float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]);
|
||||||
|
float val;
|
||||||
|
if (lIdx < 8) {
|
||||||
|
val = d_tmp0;
|
||||||
|
} else if (lIdx < 16) {
|
||||||
|
val = d_tmp1;
|
||||||
|
} else if (lIdx < 24) {
|
||||||
|
val = d_tmp2;
|
||||||
|
} else {
|
||||||
|
val = d_tmp3;
|
||||||
|
}
|
||||||
|
if (cChunk == 0) d_out.x = val;
|
||||||
|
else if (cChunk == 1) d_out.y = val;
|
||||||
|
else if (cChunk == 2) d_out.z = val;
|
||||||
|
else d_out.w = val;
|
||||||
|
}
|
||||||
|
return d_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
|
||||||
|
// Reshuffle from Nvidia-like register layout to AMD layout:
|
||||||
|
half16 aFrag = shuffle_a(a_reg);
|
||||||
|
half16 bFrag = shuffle_b(b_reg);
|
||||||
|
float8 cFrag = shuffle_c(c_reg);
|
||||||
|
|
||||||
|
// Call the (built‐in) 16x16 MMA instruction. It returns a float8.
|
||||||
|
float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aFrag, bFrag, cFrag);
|
||||||
|
|
||||||
|
// Unshuffle back into Nvidia expected float4 result
|
||||||
|
float4::Native_vec_ d_out = shuffle_d(dFrag);
|
||||||
|
|
||||||
|
return d_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
|
||||||
|
// Reshuffle from Nvidia-like register layout to AMD layout:
|
||||||
|
half16 aFrag = shuffle_a(a_reg);
|
||||||
|
half16 bFrag = shuffle_b(b_reg);
|
||||||
|
float8 cFrag = shuffle_c(c_reg);
|
||||||
|
|
||||||
|
// Call the (built‐in) 16x16 MMA instruction. It returns a float8.
|
||||||
|
float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aFrag, bFrag, cFrag);
|
||||||
|
|
||||||
|
// Unshuffle back into Nvidia expected float4 result
|
||||||
|
float4::Native_vec_ d_out = shuffle_d(dFrag);
|
||||||
|
|
||||||
|
return d_out;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -197,7 +197,9 @@ fn run_instruction<'input>(
|
||||||
| ast::Instruction::Xor { .. }
|
| ast::Instruction::Xor { .. }
|
||||||
| ast::Instruction::Vote { .. }
|
| ast::Instruction::Vote { .. }
|
||||||
| ast::Instruction::ReduxSync { .. }
|
| ast::Instruction::ReduxSync { .. }
|
||||||
| ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)),
|
| ast::Instruction::GridDepControl { .. }
|
||||||
|
| ast::Instruction::LdMatrix { .. }
|
||||||
|
| ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)),
|
||||||
ast::Instruction::Add {
|
ast::Instruction::Add {
|
||||||
data:
|
data:
|
||||||
ast::ArithDetails::Float(ast::ArithFloat {
|
ast::ArithDetails::Float(ast::ArithFloat {
|
||||||
|
|
|
@ -1855,7 +1855,9 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||||
| ast::Instruction::AtomCas { .. }
|
| ast::Instruction::AtomCas { .. }
|
||||||
| ast::Instruction::Vote { .. }
|
| ast::Instruction::Vote { .. }
|
||||||
| ast::Instruction::ReduxSync { .. }
|
| ast::Instruction::ReduxSync { .. }
|
||||||
| ast::Instruction::LdMatrix { .. } => InstructionModes::none(),
|
| ast::Instruction::GridDepControl { .. }
|
||||||
|
| ast::Instruction::LdMatrix { .. }
|
||||||
|
| ast::Instruction::Mma { .. } => InstructionModes::none(),
|
||||||
ast::Instruction::Add {
|
ast::Instruction::Add {
|
||||||
data: ast::ArithDetails::Integer(_),
|
data: ast::ArithDetails::Integer(_),
|
||||||
..
|
..
|
||||||
|
|
|
@ -153,6 +153,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||||
for (i, param) in method.input_arguments.iter().enumerate() {
|
for (i, param) in method.input_arguments.iter().enumerate() {
|
||||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||||
let name = self.resolver.get_or_add(param.name);
|
let name = self.resolver.get_or_add(param.name);
|
||||||
|
if let Some(align) = param.align {
|
||||||
|
unsafe { LLVMSetParamAlignment(value, align) };
|
||||||
|
}
|
||||||
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
||||||
self.resolver.register(param.name, value);
|
self.resolver.register(param.name, value);
|
||||||
if method.is_kernel {
|
if method.is_kernel {
|
||||||
|
@ -519,6 +522,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop
|
ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop
|
||||||
ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop
|
ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop
|
||||||
ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop
|
ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop
|
||||||
|
ast::Instruction::GridDepControl { .. } => Ok(()), // nop
|
||||||
// replaced by a function call
|
// replaced by a function call
|
||||||
ast::Instruction::Bfe { .. }
|
ast::Instruction::Bfe { .. }
|
||||||
| ast::Instruction::Bar { .. }
|
| ast::Instruction::Bar { .. }
|
||||||
|
@ -529,7 +533,8 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
| ast::Instruction::Vote { .. }
|
| ast::Instruction::Vote { .. }
|
||||||
| ast::Instruction::Nanosleep { .. }
|
| ast::Instruction::Nanosleep { .. }
|
||||||
| ast::Instruction::ReduxSync { .. }
|
| ast::Instruction::ReduxSync { .. }
|
||||||
| ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()),
|
| ast::Instruction::LdMatrix { .. }
|
||||||
|
| ast::Instruction::Mma { .. } => return Err(error_unreachable()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ quick_error! {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GPU attributes needed at compile time.
|
/// GPU attributes needed at compile time.
|
||||||
#[derive(serde::Serialize)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct Attributes {
|
pub struct Attributes {
|
||||||
/// Clock frequency in kHz.
|
/// Clock frequency in kHz.
|
||||||
pub clock_rate: u32,
|
pub clock_rate: u32,
|
||||||
|
|
|
@ -351,6 +351,35 @@ fn run_instruction<'input>(
|
||||||
let name = "sqrt_rn_ftz_f32";
|
let name = "sqrt_rn_ftz_f32";
|
||||||
to_call(resolver, fn_declarations, name.into(), i)?
|
to_call(resolver, fn_declarations, name.into(), i)?
|
||||||
}
|
}
|
||||||
|
i @ ptx_parser::Instruction::Mma {
|
||||||
|
data:
|
||||||
|
ast::MmaDetails {
|
||||||
|
alayout,
|
||||||
|
blayout,
|
||||||
|
dtype_scalar,
|
||||||
|
atype_scalar,
|
||||||
|
btype_scalar,
|
||||||
|
ctype_scalar,
|
||||||
|
},
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
let name = format!(
|
||||||
|
"mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}",
|
||||||
|
match alayout {
|
||||||
|
ast::MatrixLayout::Row => "row",
|
||||||
|
ast::MatrixLayout::Col => "col",
|
||||||
|
},
|
||||||
|
match blayout {
|
||||||
|
ast::MatrixLayout::Row => "row",
|
||||||
|
ast::MatrixLayout::Col => "col",
|
||||||
|
},
|
||||||
|
scalar_to_ptx_name(dtype_scalar),
|
||||||
|
scalar_to_ptx_name(atype_scalar),
|
||||||
|
scalar_to_ptx_name(btype_scalar),
|
||||||
|
scalar_to_ptx_name(ctype_scalar),
|
||||||
|
);
|
||||||
|
to_call(resolver, fn_declarations, name.into(), i)?
|
||||||
|
}
|
||||||
i @ ptx_parser::Instruction::Sqrt {
|
i @ ptx_parser::Instruction::Sqrt {
|
||||||
data:
|
data:
|
||||||
ast::RcpData {
|
ast::RcpData {
|
||||||
|
|
|
@ -3,8 +3,8 @@ use super::{
|
||||||
StateSpace, VectorPrefix,
|
StateSpace, VectorPrefix,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction,
|
FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError,
|
||||||
ShiftDirection, ShuffleMode, VoteMode,
|
PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode,
|
||||||
};
|
};
|
||||||
use bitflags::bitflags;
|
use bitflags::bitflags;
|
||||||
use derive_more::Display;
|
use derive_more::Display;
|
||||||
|
@ -721,6 +721,30 @@ ptx_parser_macros::generate_instruction_type!(
|
||||||
space: { data.state_space },
|
space: { data.state_space },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
GridDepControl {
|
||||||
|
data: crate::GridDepControlAction,
|
||||||
|
},
|
||||||
|
Mma {
|
||||||
|
data: MmaDetails,
|
||||||
|
arguments<T>: {
|
||||||
|
dst: {
|
||||||
|
repr: T,
|
||||||
|
type: { data.dtype() },
|
||||||
|
},
|
||||||
|
src1: {
|
||||||
|
repr: T,
|
||||||
|
type: { data.atype() },
|
||||||
|
},
|
||||||
|
src2: {
|
||||||
|
repr: T,
|
||||||
|
type: { data.btype() },
|
||||||
|
},
|
||||||
|
src3: {
|
||||||
|
repr: T,
|
||||||
|
type: { data.ctype() },
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
@ -2378,3 +2402,27 @@ pub struct ReduxSyncData {
|
||||||
pub type_: ScalarType,
|
pub type_: ScalarType,
|
||||||
pub reduction: Reduction,
|
pub reduction: Reduction,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct MmaDetails {
|
||||||
|
pub alayout: MatrixLayout,
|
||||||
|
pub blayout: MatrixLayout,
|
||||||
|
pub dtype_scalar: ScalarType,
|
||||||
|
pub atype_scalar: ScalarType,
|
||||||
|
pub btype_scalar: ScalarType,
|
||||||
|
pub ctype_scalar: ScalarType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MmaDetails {
|
||||||
|
pub fn dtype(&self) -> Type {
|
||||||
|
Type::Vector(4, ScalarType::F32)
|
||||||
|
}
|
||||||
|
pub fn atype(&self) -> Type {
|
||||||
|
Type::Vector(4, ScalarType::U32)
|
||||||
|
}
|
||||||
|
pub fn btype(&self) -> Type {
|
||||||
|
Type::Vector(2, ScalarType::U32)
|
||||||
|
}
|
||||||
|
pub fn ctype(&self) -> Type {
|
||||||
|
Type::Vector(4, ScalarType::F32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1862,6 +1862,9 @@ derive_parser!(
|
||||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||||
pub enum MatrixNumber { }
|
pub enum MatrixNumber { }
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||||
|
pub enum MatrixLayout { }
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
||||||
mov{.vec}.type d, a => {
|
mov{.vec}.type d, a => {
|
||||||
Instruction::Mov {
|
Instruction::Mov {
|
||||||
|
@ -3897,6 +3900,37 @@ derive_parser!(
|
||||||
.type: ScalarType = {.b16, .b8};
|
.type: ScalarType = {.b16, .b8};
|
||||||
// .dst_fmt = { .b8x16 };
|
// .dst_fmt = { .b8x16 };
|
||||||
// .src_fmt = { .b6x16_p32, .b4x16_p64 };
|
// .src_fmt = { .b6x16_p32, .b4x16_p64 };
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol
|
||||||
|
griddepcontrol.action => {
|
||||||
|
Instruction::GridDepControl {
|
||||||
|
data: action
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.action: GridDepControlAction = { .launch_dependents, .wait };
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
|
||||||
|
mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => {
|
||||||
|
if dtype != ScalarType::F32 || ctype != ScalarType::F32 {
|
||||||
|
state.errors.push(PtxError::Todo);
|
||||||
|
}
|
||||||
|
Instruction::Mma {
|
||||||
|
data: MmaDetails {
|
||||||
|
alayout,
|
||||||
|
blayout,
|
||||||
|
dtype_scalar: dtype,
|
||||||
|
atype_scalar: ScalarType::BF16,
|
||||||
|
btype_scalar: ScalarType::BF16,
|
||||||
|
ctype_scalar: ctype,
|
||||||
|
},
|
||||||
|
arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.alayout: MatrixLayout = {.row};
|
||||||
|
.blayout: MatrixLayout = {.col};
|
||||||
|
.ctype: ScalarType = {.f16, .f32};
|
||||||
|
.dtype: ScalarType = {.f16, .f32};
|
||||||
);
|
);
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
use bpaf::{any, doc::Style, Bpaf, Parser};
|
use bpaf::{any, choice, doc::Style, literal, Bpaf, Parser};
|
||||||
use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600};
|
|
||||||
use std::{ffi::CStr, mem};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Bpaf)]
|
#[derive(Debug, Clone, Bpaf)]
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
@ -12,6 +10,8 @@ pub struct Options {
|
||||||
#[bpaf(short, long)]
|
#[bpaf(short, long)]
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
#[bpaf(external)]
|
#[bpaf(external)]
|
||||||
|
lineinfo: bool,
|
||||||
|
#[bpaf(external)]
|
||||||
gpu_name: String,
|
gpu_name: String,
|
||||||
#[bpaf(long, short('O'), fallback(3))]
|
#[bpaf(long, short('O'), fallback(3))]
|
||||||
opt_level: usize,
|
opt_level: usize,
|
||||||
|
@ -19,48 +19,32 @@ pub struct Options {
|
||||||
input: String,
|
input: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn lineinfo() -> impl Parser<bool> {
|
||||||
|
choice(["-lineinfo", "--lineinfo"].into_iter().map(|s| {
|
||||||
|
literal(s)
|
||||||
|
.anywhere()
|
||||||
|
.optional()
|
||||||
|
.map(|_| true)
|
||||||
|
.fallback(false)
|
||||||
|
.boxed()
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
// #[bpaf(long, long("gpu_name"), fallback_with(default_arch))]
|
// #[bpaf(long, long("gpu_name"), fallback_with(default_arch))]
|
||||||
fn gpu_name() -> impl Parser<String> {
|
fn gpu_name() -> impl Parser<String> {
|
||||||
any("", move |s: String| {
|
any("", move |s: String| {
|
||||||
Some(s.strip_prefix("-arch=")?.to_owned())
|
Some(
|
||||||
|
s.strip_prefix("-arch=")
|
||||||
|
.or_else(|| s.strip_prefix("--gpu-name="))?
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.metavar(&[("-arch=", Style::Literal), ("ARG", Style::Metavar)])
|
.metavar(&[("--gpu-name=", Style::Literal), ("SM", Style::Metavar)])
|
||||||
.anywhere()
|
.anywhere()
|
||||||
.fallback_with(|| Ok::<String, &'static str>("sm_52".to_string()))
|
.fallback_with(|| Ok::<String, &'static str>("sm_52".to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let options = options().run();
|
let options = options().run();
|
||||||
let comgr = comgr::Comgr::new().unwrap();
|
std::fs::copy(&options.input, &options.output).unwrap();
|
||||||
unsafe { hip_runtime_sys::hipInit(0) }.unwrap();
|
|
||||||
let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() };
|
|
||||||
let (gpu_arch, clock_rate) = get_gpu_arch_and_clock_rate(&mut dev_props);
|
|
||||||
let input = std::fs::read_to_string(options.input).unwrap();
|
|
||||||
let ast = ptx_parser::parse_module_checked(&input).unwrap();
|
|
||||||
let llvm = ptx::to_llvm_module(
|
|
||||||
ast,
|
|
||||||
ptx::Attributes {
|
|
||||||
clock_rate: clock_rate as u32,
|
|
||||||
},
|
|
||||||
|_| {},
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let elf_binary = comgr::compile_bitcode(
|
|
||||||
&comgr,
|
|
||||||
gpu_arch,
|
|
||||||
&*llvm.llvm_ir.write_bitcode_to_memory(),
|
|
||||||
&*llvm.linked_bitcode(),
|
|
||||||
&*llvm.attributes_ir.write_bitcode_to_memory(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
std::fs::write(options.output, elf_binary).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_gpu_arch_and_clock_rate<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> (&'a str, i32) {
|
|
||||||
unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }.unwrap();
|
|
||||||
let gcn_arch_name = &dev_props.gcnArchName;
|
|
||||||
let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) };
|
|
||||||
let gcn_arch_name = gcn_arch_name.to_str();
|
|
||||||
(gcn_arch_name.unwrap(), dev_props.clockRate)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,7 +89,15 @@ pub(crate) fn get_attribute(
|
||||||
*pi = 32;
|
*pi = 32;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER => {
|
// TODO: maintain a table, certain RDNAs are 1/16, some are 1/32
|
||||||
|
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => {
|
||||||
|
*pi = 32;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER
|
||||||
|
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED
|
||||||
|
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES
|
||||||
|
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED => {
|
||||||
*pi = 0;
|
*pi = 0;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
@ -211,9 +219,6 @@ pub(crate) fn get_attribute(
|
||||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE => {
|
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE => {
|
||||||
return get_device_prop(pi, dev_idx, |props| props.persistingL2CacheMaxSize)
|
return get_device_prop(pi, dev_idx, |props| props.persistingL2CacheMaxSize)
|
||||||
}
|
}
|
||||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => {
|
|
||||||
return get_device_prop(pi, dev_idx, |props| props.singleToDoublePrecisionPerfRatio)
|
|
||||||
}
|
|
||||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE => {
|
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE => {
|
||||||
return get_device_prop(pi, dev_idx, |props| props.accessPolicyMaxWindowSize)
|
return get_device_prop(pi, dev_idx, |props| props.accessPolicyMaxWindowSize)
|
||||||
}
|
}
|
||||||
|
|
|
@ -487,9 +487,9 @@ pub(crate) unsafe fn occupancy_max_active_blocks_per_multiprocessor_with_flags(
|
||||||
dynamic_smem_size: usize,
|
dynamic_smem_size: usize,
|
||||||
flags: ::core::ffi::c_uint,
|
flags: ::core::ffi::c_uint,
|
||||||
) -> hipError_t {
|
) -> hipError_t {
|
||||||
hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
|
hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
|
||||||
num_blocks,
|
num_blocks,
|
||||||
func.0.cast(),
|
func,
|
||||||
block_size,
|
block_size,
|
||||||
dynamic_smem_size,
|
dynamic_smem_size,
|
||||||
flags,
|
flags,
|
||||||
|
|
12
zluda/src/impl/hipfix.rs
Normal file
12
zluda/src/impl/hipfix.rs
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
// There's a bug in hipDrvPointerGetAttributes where it returns
|
||||||
|
// HIP_ERROR_INVALID_VALUE if the pointer is null. It works correctly for any
|
||||||
|
// other invalid pointer
|
||||||
|
pub(crate) fn get_attributes(
|
||||||
|
ptr: hip_runtime_sys::hipDeviceptr_t,
|
||||||
|
) -> hip_runtime_sys::hipDeviceptr_t {
|
||||||
|
if ptr.0.is_null() {
|
||||||
|
hip_runtime_sys::hipDeviceptr_t(usize::MAX as _)
|
||||||
|
} else {
|
||||||
|
ptr
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,6 +7,7 @@ pub(super) mod driver;
|
||||||
pub(super) mod event;
|
pub(super) mod event;
|
||||||
pub(super) mod function;
|
pub(super) mod function;
|
||||||
pub(super) mod graph;
|
pub(super) mod graph;
|
||||||
|
pub(super) mod hipfix;
|
||||||
pub(super) mod kernel;
|
pub(super) mod kernel;
|
||||||
pub(super) mod library;
|
pub(super) mod library;
|
||||||
pub(super) mod memory;
|
pub(super) mod memory;
|
||||||
|
|
|
@ -20,6 +20,10 @@ impl ZludaObject for Module {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static EMPTY_PTX: &str = ".version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64";
|
||||||
|
|
||||||
// get_ptx takes an `image` that can be anything we support and returns a
|
// get_ptx takes an `image` that can be anything we support and returns a
|
||||||
// String containing a ptx extracted from `image`.
|
// String containing a ptx extracted from `image`.
|
||||||
fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result<Cow<'a, str>, CUerror> {
|
fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result<Cow<'a, str>, CUerror> {
|
||||||
|
@ -58,11 +62,17 @@ fn cow_bytes_to_str<'a>(data: Cow<'a, [u8]>) -> Option<Cow<'a, str>> {
|
||||||
|
|
||||||
pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CUerror> {
|
pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CUerror> {
|
||||||
let global_state = driver::global_state()?;
|
let global_state = driver::global_state()?;
|
||||||
let text = get_ptx(library)?;
|
let maybe_ptx = get_ptx(library);
|
||||||
|
let text = if cfg!(debug_assertions) {
|
||||||
|
maybe_ptx?
|
||||||
|
} else {
|
||||||
|
maybe_ptx.unwrap_or_else(|_| Cow::Borrowed(EMPTY_PTX))
|
||||||
|
};
|
||||||
let hip_properties = get_hip_properties()?;
|
let hip_properties = get_hip_properties()?;
|
||||||
let gcn_arch = get_gcn_arch(&hip_properties)?;
|
let gcn_arch = get_gcn_arch(&hip_properties)?;
|
||||||
let attributes = ptx::Attributes {
|
let attributes = ExtraCacheAttributes {
|
||||||
clock_rate: hip_properties.clockRate as u32,
|
clock_rate: hip_properties.clockRate as u32,
|
||||||
|
is_debug: cfg!(debug_assertions),
|
||||||
};
|
};
|
||||||
let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| {
|
let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| {
|
||||||
let cache = zluda_cache::ModuleCache::open(p)?;
|
let cache = zluda_cache::ModuleCache::open(p)?;
|
||||||
|
@ -84,6 +94,12 @@ pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CU
|
||||||
Ok(hip_module)
|
Ok(hip_module)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
struct ExtraCacheAttributes {
|
||||||
|
is_debug: bool,
|
||||||
|
clock_rate: u32,
|
||||||
|
}
|
||||||
|
|
||||||
fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> {
|
fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> {
|
||||||
let hip_dev = super::context::get_current_device()?;
|
let hip_dev = super::context::get_current_device()?;
|
||||||
let mut props = unsafe { mem::zeroed() };
|
let mut props = unsafe { mem::zeroed() };
|
||||||
|
@ -100,7 +116,7 @@ fn get_cache_key<'a, 'b>(
|
||||||
global_state: &'static driver::GlobalState,
|
global_state: &'static driver::GlobalState,
|
||||||
isa: &'a str,
|
isa: &'a str,
|
||||||
text: &str,
|
text: &str,
|
||||||
attributes: &ptx::Attributes,
|
attributes: &impl serde::Serialize,
|
||||||
) -> Option<zluda_cache::ModuleKey<'a>> {
|
) -> Option<zluda_cache::ModuleKey<'a>> {
|
||||||
// Serialization here is deterministic. When marking a type with
|
// Serialization here is deterministic. When marking a type with
|
||||||
// #[derive(serde::Serialize)] the derived implementation will just write
|
// #[derive(serde::Serialize)] the derived implementation will just write
|
||||||
|
@ -129,7 +145,7 @@ fn load_cached_binary(
|
||||||
fn compile_from_ptx_and_cache(
|
fn compile_from_ptx_and_cache(
|
||||||
comgr: &comgr::Comgr,
|
comgr: &comgr::Comgr,
|
||||||
gcn_arch: &str,
|
gcn_arch: &str,
|
||||||
attributes: ptx::Attributes,
|
attributes: ExtraCacheAttributes,
|
||||||
text: &str,
|
text: &str,
|
||||||
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
|
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
|
||||||
) -> Result<Vec<u8>, CUerror> {
|
) -> Result<Vec<u8>, CUerror> {
|
||||||
|
@ -138,7 +154,14 @@ fn compile_from_ptx_and_cache(
|
||||||
} else {
|
} else {
|
||||||
ptx_parser::parse_module_unchecked(text)
|
ptx_parser::parse_module_unchecked(text)
|
||||||
};
|
};
|
||||||
let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?;
|
let llvm_module = ptx::to_llvm_module(
|
||||||
|
ast,
|
||||||
|
ptx::Attributes {
|
||||||
|
clock_rate: attributes.clock_rate,
|
||||||
|
},
|
||||||
|
|_| {},
|
||||||
|
)
|
||||||
|
.map_err(|_| CUerror::UNKNOWN)?;
|
||||||
let elf_module = comgr::compile_bitcode(
|
let elf_module = comgr::compile_bitcode(
|
||||||
comgr,
|
comgr,
|
||||||
gcn_arch,
|
gcn_arch,
|
||||||
|
|
|
@ -2,7 +2,7 @@ use cuda_types::cuda::*;
|
||||||
use hip_runtime_sys::*;
|
use hip_runtime_sys::*;
|
||||||
use std::{ffi::c_void, ptr};
|
use std::{ffi::c_void, ptr};
|
||||||
|
|
||||||
use crate::r#impl::driver;
|
use crate::r#impl::{driver, hipfix};
|
||||||
|
|
||||||
// TODO: handlehipMemoryTypeUnregistered
|
// TODO: handlehipMemoryTypeUnregistered
|
||||||
fn to_cu_memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
|
fn to_cu_memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
|
||||||
|
@ -59,7 +59,12 @@ pub(crate) unsafe fn get_attributes(
|
||||||
data: &mut *mut ::core::ffi::c_void,
|
data: &mut *mut ::core::ffi::c_void,
|
||||||
ptr: hipDeviceptr_t,
|
ptr: hipDeviceptr_t,
|
||||||
) -> CUresult {
|
) -> CUresult {
|
||||||
hipDrvPointerGetAttributes(num_attributes, attributes, data, ptr)?;
|
hipDrvPointerGetAttributes(
|
||||||
|
num_attributes,
|
||||||
|
attributes,
|
||||||
|
data,
|
||||||
|
hipfix::get_attributes(ptr),
|
||||||
|
)?;
|
||||||
let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize);
|
let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize);
|
||||||
let data = std::slice::from_raw_parts_mut(data, num_attributes as usize);
|
let data = std::slice::from_raw_parts_mut(data, num_attributes as usize);
|
||||||
for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) {
|
for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) {
|
||||||
|
@ -88,7 +93,7 @@ mod tests {
|
||||||
use crate::tests::CudaApi;
|
use crate::tests::CudaApi;
|
||||||
use cuda_macros::test_cuda;
|
use cuda_macros::test_cuda;
|
||||||
use cuda_types::cuda::*;
|
use cuda_types::cuda::*;
|
||||||
use std::{ffi::c_void, mem, ptr};
|
use std::{ffi::c_void, i32, mem, ptr, usize};
|
||||||
|
|
||||||
#[test_cuda]
|
#[test_cuda]
|
||||||
pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) {
|
pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) {
|
||||||
|
@ -162,4 +167,47 @@ mod tests {
|
||||||
);
|
);
|
||||||
assert_eq!(context, CUcontext(ptr::null_mut()));
|
assert_eq!(context, CUcontext(ptr::null_mut()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test_cuda]
|
||||||
|
pub unsafe fn null_ptr_attributes_success(api: impl CudaApi) {
|
||||||
|
api.cuInit(0);
|
||||||
|
api.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0);
|
||||||
|
let mut context = CUcontext(1 as _);
|
||||||
|
let mut mem_type = mem::transmute::<_, CUmemorytype>(u32::MAX);
|
||||||
|
let mut dev_ptr = mem::transmute::<_, *mut c_void>(usize::MAX);
|
||||||
|
let mut host_ptr = mem::transmute::<_, *mut c_void>(usize::MAX);
|
||||||
|
let mut is_managed = true;
|
||||||
|
let mut ordinal = i32::MAX;
|
||||||
|
let mut attrs = [
|
||||||
|
CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT,
|
||||||
|
CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
|
||||||
|
CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER,
|
||||||
|
CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER,
|
||||||
|
CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED,
|
||||||
|
CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
|
||||||
|
];
|
||||||
|
let mut values = [
|
||||||
|
std::ptr::from_mut(&mut context).cast::<c_void>(),
|
||||||
|
std::ptr::from_mut(&mut mem_type).cast::<c_void>(),
|
||||||
|
std::ptr::from_mut(&mut dev_ptr).cast::<c_void>(),
|
||||||
|
std::ptr::from_mut(&mut host_ptr).cast::<c_void>(),
|
||||||
|
std::ptr::from_mut(&mut is_managed).cast::<c_void>(),
|
||||||
|
std::ptr::from_mut(&mut ordinal).cast::<c_void>(),
|
||||||
|
];
|
||||||
|
assert_eq!(
|
||||||
|
CUresult::SUCCESS,
|
||||||
|
api.cuPointerGetAttributes_unchecked(
|
||||||
|
attrs.len() as u32,
|
||||||
|
attrs.as_mut_ptr(),
|
||||||
|
values.as_mut_ptr(),
|
||||||
|
CUdeviceptr_v2(ptr::null_mut())
|
||||||
|
)
|
||||||
|
);
|
||||||
|
assert_eq!(context, CUcontext(ptr::null_mut()));
|
||||||
|
assert_eq!(mem_type, CUmemorytype(0));
|
||||||
|
assert_eq!(dev_ptr, ptr::null_mut());
|
||||||
|
assert_eq!(host_ptr, ptr::null_mut());
|
||||||
|
assert_eq!(is_managed, false);
|
||||||
|
assert_eq!(ordinal, -2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,6 +43,86 @@ pub(crate) unsafe fn device_get_count_v2(device_count: &mut ::core::ffi::c_uint)
|
||||||
rsmi_num_monitor_devices(device_count)
|
rsmi_num_monitor_devices(device_count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
|
||||||
|
pci_bus_id: &std::ffi::CStr,
|
||||||
|
device: &mut cuda_types::nvml::nvmlDevice_t,
|
||||||
|
) -> nvmlReturn_t {
|
||||||
|
let pci = parse_pci_bus_id(pci_bus_id).ok_or(nvmlError_t::INVALID_ARGUMENT)?;
|
||||||
|
let bdfid = pci.to_bdfid();
|
||||||
|
let mut device_count = 0;
|
||||||
|
rsmi_num_monitor_devices(&mut device_count)?;
|
||||||
|
for dv_ind in 0..device_count {
|
||||||
|
let mut curr_bdfid = 0;
|
||||||
|
rsmi_dev_pci_id_get(dv_ind, &mut curr_bdfid)?;
|
||||||
|
if curr_bdfid == bdfid {
|
||||||
|
*device = Device { _index: dv_ind }.wrap();
|
||||||
|
return nvmlReturn_t::SUCCESS;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nvmlReturn_t::ERROR_NOT_FOUND
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
struct PciBusId {
|
||||||
|
domain: u16,
|
||||||
|
bus: u8,
|
||||||
|
device: u8,
|
||||||
|
function: u8,
|
||||||
|
}
|
||||||
|
impl PciBusId {
|
||||||
|
fn to_bdfid(self) -> u64 {
|
||||||
|
((self.domain as u64) << 32)
|
||||||
|
| ((self.bus as u64) << 8)
|
||||||
|
| ((self.device as u64) << 3)
|
||||||
|
| (self.function as u64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_pci_bus_id(id: &std::ffi::CStr) -> Option<PciBusId> {
|
||||||
|
let s = id.to_str().ok()?.trim();
|
||||||
|
let mut domain: u16 = 0;
|
||||||
|
let mut rest = s;
|
||||||
|
if let Some(colon1) = s.find(':') {
|
||||||
|
if colon1 == 4 {
|
||||||
|
domain = hex_u16(&s[..4])?;
|
||||||
|
rest = &s[5..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut parts = rest.split(':');
|
||||||
|
let bus_part = parts.next()?;
|
||||||
|
let tail = parts.next()?;
|
||||||
|
if parts.next().is_some() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let mut dev_func = tail.split('.');
|
||||||
|
let dev_part = dev_func.next()?;
|
||||||
|
let func_part = dev_func.next();
|
||||||
|
let function = match func_part {
|
||||||
|
Some(f) => hex_u8(f)?,
|
||||||
|
None => 0,
|
||||||
|
};
|
||||||
|
Some(PciBusId {
|
||||||
|
domain,
|
||||||
|
bus: hex_u8(bus_part)?,
|
||||||
|
device: hex_u8(dev_part)?,
|
||||||
|
function,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hex_u16(s: &str) -> Option<u16> {
|
||||||
|
if s.len() > 4 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
u16::from_str_radix(s, 16).ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hex_u8(s: &str) -> Option<u8> {
|
||||||
|
if s.len() > 2 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
u8::from_str_radix(s, 16).ok()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn device_get_field_values(
|
pub(crate) unsafe fn device_get_field_values(
|
||||||
_device: &Device,
|
_device: &Device,
|
||||||
values_count: ::core::ffi::c_int,
|
values_count: ::core::ffi::c_int,
|
||||||
|
@ -75,3 +155,36 @@ pub(crate) fn device_get_handle_by_index_v2(
|
||||||
*device = Device { _index: index }.wrap();
|
*device = Device { _index: index }.wrap();
|
||||||
nvmlReturn_t::SUCCESS
|
nvmlReturn_t::SUCCESS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#[test]
|
||||||
|
fn parse_pci_bus_id_full() {
|
||||||
|
let id = std::ffi::CString::new("0100:65:a0.f").unwrap();
|
||||||
|
let parsed = super::parse_pci_bus_id(&id).unwrap();
|
||||||
|
assert_eq!(parsed.domain, 0x0100);
|
||||||
|
assert_eq!(parsed.bus, 0x65);
|
||||||
|
assert_eq!(parsed.device, 0xa0);
|
||||||
|
assert_eq!(parsed.function, 0xf);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_pci_bus_id_no_func() {
|
||||||
|
let id = std::ffi::CString::new("0100:65:a0").unwrap();
|
||||||
|
let parsed = super::parse_pci_bus_id(&id).unwrap();
|
||||||
|
assert_eq!(parsed.domain, 0x0100);
|
||||||
|
assert_eq!(parsed.bus, 0x65);
|
||||||
|
assert_eq!(parsed.device, 0xa0);
|
||||||
|
assert_eq!(parsed.function, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_pci_bus_id_no_domain() {
|
||||||
|
let id = std::ffi::CString::new("65:a0.f").unwrap();
|
||||||
|
let parsed = super::parse_pci_bus_id(&id).unwrap();
|
||||||
|
assert_eq!(parsed.domain, 0);
|
||||||
|
assert_eq!(parsed.bus, 0x65);
|
||||||
|
assert_eq!(parsed.device, 0xa0);
|
||||||
|
assert_eq!(parsed.function, 0xf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -23,6 +23,13 @@ pub(crate) unsafe fn device_get_count_v2(_device_count: &mut ::core::ffi::c_uint
|
||||||
crate::impl_common::unimplemented()
|
crate::impl_common::unimplemented()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
|
||||||
|
pci_bus_id: &std::ffi::CStr,
|
||||||
|
device: &mut cuda_types::nvml::nvmlDevice_t,
|
||||||
|
) -> nvmlReturn_t {
|
||||||
|
crate::impl_common::unimplemented()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn device_get_field_values(
|
pub(crate) unsafe fn device_get_field_values(
|
||||||
_device: cuda_types::nvml::nvmlDevice_t,
|
_device: cuda_types::nvml::nvmlDevice_t,
|
||||||
_values_count: ::core::ffi::c_int,
|
_values_count: ::core::ffi::c_int,
|
||||||
|
@ -31,10 +38,6 @@ pub(crate) unsafe fn device_get_field_values(
|
||||||
crate::impl_common::unimplemented()
|
crate::impl_common::unimplemented()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn get_field_value(_field: &mut nvmlFieldValue_st) -> Result<(), nvmlError_t> {
|
|
||||||
crate::impl_common::unimplemented()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) unsafe fn device_get_gpu_fabric_info(
|
pub(crate) unsafe fn device_get_gpu_fabric_info(
|
||||||
_device: cuda_types::nvml::nvmlDevice_t,
|
_device: cuda_types::nvml::nvmlDevice_t,
|
||||||
_gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t,
|
_gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t,
|
||||||
|
|
|
@ -48,6 +48,7 @@ cuda_macros::nvml_function_declarations!(
|
||||||
nvmlDeviceGetFieldValues,
|
nvmlDeviceGetFieldValues,
|
||||||
nvmlDeviceGetGpuFabricInfo,
|
nvmlDeviceGetGpuFabricInfo,
|
||||||
nvmlDeviceGetHandleByIndex_v2,
|
nvmlDeviceGetHandleByIndex_v2,
|
||||||
|
nvmlDeviceGetHandleByPciBusId_v2,
|
||||||
nvmlInit,
|
nvmlInit,
|
||||||
nvmlInitWithFlags,
|
nvmlInitWithFlags,
|
||||||
nvmlInit_v2,
|
nvmlInit_v2,
|
||||||
|
|
|
@ -303,6 +303,7 @@ pub(crate) enum ErrorEntry {
|
||||||
},
|
},
|
||||||
NullPointer(&'static str),
|
NullPointer(&'static str),
|
||||||
UnknownLibrary(CUlibrary),
|
UnknownLibrary(CUlibrary),
|
||||||
|
SavedModule(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for ErrorEntry {}
|
unsafe impl Send for ErrorEntry {}
|
||||||
|
@ -344,93 +345,94 @@ impl Display for ErrorEntry {
|
||||||
match self {
|
match self {
|
||||||
ErrorEntry::IoError(e) => e.fmt(f),
|
ErrorEntry::IoError(e) => e.fmt(f),
|
||||||
ErrorEntry::CreatedDumpDirectory(dir) => {
|
ErrorEntry::CreatedDumpDirectory(dir) => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"Created trace directory {} ",
|
"Created trace directory {} ",
|
||||||
dir.as_os_str().to_string_lossy()
|
dir.as_os_str().to_string_lossy()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
ErrorEntry::ErrorBox(e) => e.fmt(f),
|
ErrorEntry::ErrorBox(e) => e.fmt(f),
|
||||||
ErrorEntry::UnsupportedModule {
|
ErrorEntry::UnsupportedModule {
|
||||||
module,
|
module,
|
||||||
raw_image,
|
raw_image,
|
||||||
kind,
|
kind,
|
||||||
} => {
|
} => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"Unsupported {} module {:?} loaded from module image {:?}",
|
"Unsupported {} module {:?} loaded from module image {:?}",
|
||||||
kind, module, raw_image
|
kind, module, raw_image
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
ErrorEntry::MalformedModulePath(e) => e.fmt(f),
|
ErrorEntry::MalformedModulePath(e) => e.fmt(f),
|
||||||
ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f),
|
ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f),
|
||||||
ErrorEntry::ModuleParsingError(file_name) => {
|
ErrorEntry::ModuleParsingError(file_name) => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"Error parsing module, log has been written to {}",
|
"Error parsing module, log has been written to {}",
|
||||||
file_name
|
file_name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
ErrorEntry::NulInsideModuleText(e) => e.fmt(f),
|
ErrorEntry::NulInsideModuleText(e) => e.fmt(f),
|
||||||
ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
|
ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
|
||||||
ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)),
|
ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)),
|
||||||
ErrorEntry::UnexpectedBinaryField {
|
ErrorEntry::UnexpectedBinaryField {
|
||||||
field_name,
|
field_name,
|
||||||
expected,
|
expected,
|
||||||
observed,
|
observed,
|
||||||
} => write!(
|
} => write!(
|
||||||
f,
|
f,
|
||||||
"Unexpected field {}. Expected one of: [{}], observed: {}",
|
"Unexpected field {}. Expected one of: [{}], observed: {}",
|
||||||
field_name,
|
field_name,
|
||||||
expected
|
expected
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.to_string())
|
.map(|x| x.to_string())
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(", "),
|
.join(", "),
|
||||||
observed
|
observed
|
||||||
),
|
),
|
||||||
ErrorEntry::UnexpectedArgument {
|
ErrorEntry::UnexpectedArgument {
|
||||||
arg_name,
|
arg_name,
|
||||||
expected,
|
expected,
|
||||||
observed,
|
observed,
|
||||||
} => write!(
|
} => write!(
|
||||||
f,
|
f,
|
||||||
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
|
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
|
||||||
arg_name,
|
arg_name,
|
||||||
expected
|
expected
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.to_string())
|
.map(|x| x.to_string())
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(", "),
|
.join(", "),
|
||||||
observed
|
observed
|
||||||
),
|
),
|
||||||
ErrorEntry::InvalidEnvVar {
|
ErrorEntry::InvalidEnvVar {
|
||||||
var,
|
var,
|
||||||
pattern,
|
pattern,
|
||||||
value,
|
value,
|
||||||
} => write!(
|
} => write!(
|
||||||
f,
|
f,
|
||||||
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
|
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
|
||||||
),
|
),
|
||||||
ErrorEntry::FunctionNotFound(cuda_function_name) => write!(
|
ErrorEntry::FunctionNotFound(cuda_function_name) => write!(
|
||||||
f,
|
f,
|
||||||
"No function {cuda_function_name} in the underlying library"
|
"No function {cuda_function_name} in the underlying library"
|
||||||
),
|
),
|
||||||
ErrorEntry::UnexpectedExportTableSize { expected, computed } => {
|
ErrorEntry::UnexpectedExportTableSize { expected, computed } => {
|
||||||
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
|
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
|
||||||
}
|
}
|
||||||
ErrorEntry::IntegrityCheck { original, overriden } => {
|
ErrorEntry::IntegrityCheck { original, overriden } => {
|
||||||
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
|
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
|
||||||
}
|
}
|
||||||
ErrorEntry::NullPointer(type_) => {
|
ErrorEntry::NullPointer(type_) => {
|
||||||
write!(f, "Null pointer of type {type_} encountered")
|
write!(f, "Null pointer of type {type_} encountered")
|
||||||
}
|
}
|
||||||
ErrorEntry::UnknownLibrary(culibrary) => {
|
ErrorEntry::UnknownLibrary(culibrary) => {
|
||||||
write!(f, "Unknown library: ")?;
|
write!(f, "Unknown library: ")?;
|
||||||
let mut temp_buffer = Vec::new();
|
let mut temp_buffer = Vec::new();
|
||||||
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
|
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
|
||||||
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
|
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
|
||||||
}
|
}
|
||||||
|
ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -128,12 +128,11 @@ impl StateTracker {
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
type_: &'static str,
|
type_: &'static str,
|
||||||
) {
|
) {
|
||||||
fn_logger.log_io_error(self.writer.save_module(
|
fn_logger.try_(|fn_logger| {
|
||||||
self.library_counter,
|
self.writer
|
||||||
index,
|
.save_module(fn_logger, self.library_counter, index, submodule, type_)
|
||||||
submodule,
|
.map_err(ErrorEntry::IoError)
|
||||||
type_,
|
});
|
||||||
));
|
|
||||||
if type_ == "ptx" {
|
if type_ == "ptx" {
|
||||||
match CString::new(submodule) {
|
match CString::new(submodule) {
|
||||||
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
|
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
|
||||||
|
@ -323,6 +322,7 @@ impl DumpWriter {
|
||||||
|
|
||||||
fn save_module(
|
fn save_module(
|
||||||
&self,
|
&self,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
module_index: usize,
|
module_index: usize,
|
||||||
submodule_index: Option<(usize, Option<usize>)>,
|
submodule_index: Option<(usize, Option<usize>)>,
|
||||||
buffer: &[u8],
|
buffer: &[u8],
|
||||||
|
@ -332,9 +332,13 @@ impl DumpWriter {
|
||||||
None => return Ok(()),
|
None => return Ok(()),
|
||||||
Some(d) => d.clone(),
|
Some(d) => d.clone(),
|
||||||
};
|
};
|
||||||
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
|
let file_name = Self::get_file_name(module_index, submodule_index, kind);
|
||||||
let mut file = File::create_new(dump_file)?;
|
dump_file.push(&file_name);
|
||||||
file.write_all(buffer)?;
|
{
|
||||||
|
let mut file = File::create_new(dump_file)?;
|
||||||
|
file.write_all(buffer)?;
|
||||||
|
}
|
||||||
|
fn_logger.log(ErrorEntry::SavedModule(file_name));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -349,7 +353,7 @@ impl DumpWriter {
|
||||||
Some(d) => d.clone(),
|
Some(d) => d.clone(),
|
||||||
};
|
};
|
||||||
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
|
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
|
||||||
let mut file = File::create(log_file)?;
|
let mut file = File::create_new(log_file)?;
|
||||||
for error in errors {
|
for error in errors {
|
||||||
writeln!(file, "{}", error)?;
|
writeln!(file, "{}", error)?;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue