More runtime fixes, add mma instruction (#509)
Some checks are pending
ZLUDA / Build (Linux) (push) Waiting to run
ZLUDA / Build (Windows) (push) Waiting to run
ZLUDA / Build AMD GPU unit tests (push) Waiting to run
ZLUDA / Run AMD GPU unit tests (push) Blocked by required conditions

This commit is contained in:
Andrzej Janik 2025-09-18 20:15:22 +02:00 committed by GitHub
commit b5f41c7cd0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 639 additions and 154 deletions

View file

@ -45,7 +45,7 @@ body:
./train_gpt2fp32cu
4. Build and run the tests:
make test_gpt2fp32cu
LD_LIBRARY_PATH=<ZLUDA_TRACE_DIR> ./test_gpt2fp32cu
LD_LIBRARY_PATH=<ZLUDA_LOG_DIR> ./test_gpt2fp32cu
validations:
required: true
- type: input

1
Cargo.lock generated
View file

@ -2573,7 +2573,6 @@ dependencies = [
"ptx_parser",
"quick-error",
"rustc-hash 2.0.0",
"serde",
"smallvec",
"strum 0.26.3",
"strum_macros 0.26.4",

View file

@ -219,6 +219,10 @@ pub fn compile_bitcode(
compile_to_exec.set_isa_name(gcn_arch)?;
compile_to_exec.set_language(Language::LlvmIr)?;
let common_options = [
c"-mllvm",
c"-ignore-tti-inline-compatible",
// c"-mllvm",
// c"-amdgpu-early-inline-all=true",
// This makes no sense, but it makes ockl linking work
c"-Xclang",
c"-mno-link-builtin-bitcode-postopt",
@ -237,8 +241,7 @@ pub fn compile_bitcode(
]
.into_iter();
let opt_options = if cfg!(debug_assertions) {
//[c"-g", c"-mllvm", c"-print-before-all", c"", c""]
[c"-g", c"", c"", c"", c""]
[c"-g", c"-mamdgpu-precise-memory-op", c"", c"", c""]
} else {
[
c"-g0",

View file

@ -21,9 +21,14 @@ pub struct Options {
output_dir: Option<PathBuf>,
#[bpaf(long("arch"))]
/// Target architecture
/// Target GPU architecture
arch: Option<String>,
#[bpaf(long("ignore-errors"))]
/// Try to ignore errors. This will try and produce output even if there are
/// parsing errors (e.g. an unimplemented instruction)
ignore_errors: bool,
#[bpaf(positional("filename"))]
/// PTX file
ptx_path: String,
@ -48,7 +53,10 @@ fn main_core() -> Result<(), CompilerError> {
.unwrap_or("output");
let mut output_path = match opts.output_dir {
Some(value) => value,
Some(value) => {
std::fs::create_dir_all(&value)?;
value
}
None => match ptx_path.parent() {
Some(dir) => dir.to_path_buf(),
None => env::current_dir()?,
@ -68,7 +76,7 @@ fn main_core() -> Result<(), CompilerError> {
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?;
let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?;
let llvm = ptx_to_llvm(opts.ignore_errors, ptx).map_err(CompilerError::from)?;
write_to_file(&llvm.llvm_ir, output_path.with_extension("ll").as_path())?;
@ -92,8 +100,12 @@ fn main_core() -> Result<(), CompilerError> {
Ok(())
}
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?;
fn ptx_to_llvm(ignore_errors: bool, ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
let ast = if ignore_errors {
ptx_parser::parse_module_unchecked(ptx)
} else {
ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?
};
let mut start = Instant::now();
let module = ptx::to_llvm_module(
ast,

View file

@ -116,7 +116,7 @@ in order to demonstrate all of zluda_trace's features.
```bash
nvcc add.cu -o add -arch sm_80
LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_TRACE_DIR=/tmp/zluda ./add
LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_LOG_DIR=/tmp/zluda ./add
```
The last few lines should look something like:

View file

@ -22,7 +22,6 @@ microlp = "0.2.11"
int-enum = "1.1"
unwrap_or = "1.0.1"
smallvec = "1.15.1"
serde = { version = "1.0.219", features = ["derive"] }
[dev-dependencies]
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }

Binary file not shown.

View file

@ -1,17 +1,21 @@
// Every time this file changes it must be rebuilt, you need `rocm-llvm-dev` and `llvm-17`
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
#include <cstddef>
#include <cstdint>
#include <bit>
#include <cmath>
#include <hip/hip_runtime.h>
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/hip_fp8.h>
#define SHARED_SPACE __attribute__((address_space(3)))
#define CONSTANT_SPACE __attribute__((address_space(4)))
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
typedef float float8 __attribute__((ext_vector_type(8)));
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
#define DECLARE_ATTR(TYPE, NAME) \
@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
uint32_t x3 = load_single_matrix_trans(address, 24);
return uint4::Native_vec_{x0, x1, x2, x3};
}
// Reinterprets the upper 16 bits of `value` as a half-precision float.
// The bits are converted with std::bit_cast rather than by dereferencing a
// reinterpret_cast'ed pointer: reading a uint16_t object through a _Float16
// lvalue violates the strict-aliasing rule and is undefined behavior.
static inline __device__ _Float16 top16_as_fp16(uint32_t value) {
    const uint16_t half_bits = static_cast<uint16_t>(value >> 16);
    return std::bit_cast<_Float16>(half_bits);
}
// Reinterprets the lower 16 bits of `value` as a half-precision float.
// The bits are converted with std::bit_cast rather than by dereferencing a
// reinterpret_cast'ed pointer: reading a uint16_t object through a _Float16
// lvalue violates the strict-aliasing rule and is undefined behavior.
static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) {
    const uint16_t half_bits = static_cast<uint16_t>(value & 0xFFFF);
    return std::bit_cast<_Float16>(half_bits);
}
// Thin wrappers over the HIP ds_bpermute helpers, one per operand type.
// `lane` is scaled by 4 before being handed to the helper.
// NOTE(review): presumably the helper expects a byte address (4 bytes per
// lane) and yields the value of `x` held by that lane - confirm against the
// HIP headers.
static inline __device__ float bpermute_lane(int lane, float x) {
    const int scaled_index = 4 * lane;
    return __hip_ds_bpermutef(scaled_index, x);
}
static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) {
    const int scaled_index = 4 * lane;
    return __hip_ds_bpermute(scaled_index, x);
}
// Rearranges the A operand of the m16n8k16 mma from the four 32-bit registers
// each lane passes in into the sixteen half values per lane that the wave32
// WMMA builtin consumes.
// NOTE(review): the row/column mapping itself is unchanged from the previous
// version and has not been re-derived here.
static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) {
    // Index of this lane inside its wavefront. threadIdx.x is only a valid
    // stand-in for one-dimensional blocks; with multi-dimensional blocks the
    // lane is derived from the linear thread id, so ask the hardware instead.
    // With an all-ones mask mbcnt_lo counts the lanes below the current one,
    // which is the lane id in wave32 mode (this file is built with
    // -mno-wavefrontsize64 and uses the _w32 WMMA builtins).
    const unsigned lIdx = __builtin_amdgcn_mbcnt_lo(~0u, 0u);
    const int lane = lIdx % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode)
    half16 aFrag;
    for (int vGPR = 0; vGPR < 8; ++vGPR) {
        // The first four iterations read a_reg.x/.y, the last four a_reg.z/.w.
        const bool firstPair = (vGPR / 4) == 0;
        const int cudaTID = (vGPR % 4 + lane * 4) % 32;
        const uint32_t reg0 = firstPair ? a_reg.x : a_reg.z;
        const uint32_t reg1 = firstPair ? a_reg.y : a_reg.w;
        // Every lane must execute both permutes, even though it keeps only
        // one result: other lanes read this lane's operands through them.
        const uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0);
        const uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1);
        const uint32_t a_Frag_reg = (lane < 8) ? a_tmp0 : a_tmp1;
        aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg);
        aFrag[2 * vGPR + 1] = top16_as_fp16(a_Frag_reg);
    }
    return aFrag;
}
// Rearranges the B operand of the m16n8k16 mma from the two 32-bit registers
// each lane passes in into the sixteen half values per lane that the wave32
// WMMA builtin consumes. Lanes whose index modulo 16 is 8 or more receive an
// all-zero fragment.
// NOTE(review): the row/column mapping itself is unchanged from the previous
// version and has not been re-derived here.
static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) {
    // Index of this lane inside its wavefront. threadIdx.x is only a valid
    // stand-in for one-dimensional blocks; with multi-dimensional blocks the
    // lane is derived from the linear thread id, so ask the hardware instead.
    // With an all-ones mask mbcnt_lo counts the lanes below the current one,
    // which is the lane id in wave32 mode.
    const unsigned lIdx = __builtin_amdgcn_mbcnt_lo(~0u, 0u);
    const int lane = lIdx % 16;
    half16 bFrag;
    for (int vGPR = 0; vGPR < 8; ++vGPR) {
        const int cudaChunk = vGPR / 4; // will be 0 or 1
        const int cudaTID = vGPR % 4 + (lane * 4) % 64;
        const uint32_t reg = (cudaChunk == 0) ? b_reg.x : b_reg.y;
        // Executed by every lane, including those that discard the result:
        // other lanes read this lane's operand through the permute.
        const uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg);
        if (lane < 8) {
            bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg);
            bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg);
        } else {
            bFrag[2 * vGPR] = 0.0f;
            bFrag[2 * vGPR + 1] = 0.0f;
        }
    }
    return bFrag;
}
// Rearranges the accumulator (C) operand of the m16n8k16 mma from the four
// floats each lane passes in into the eight floats per lane that the wave32
// WMMA builtin consumes. Lanes 8-15 and 24-31 receive an all-zero fragment.
// NOTE(review): the row/column mapping itself is unchanged from the previous
// version and has not been re-derived here.
static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) {
    // Index of this lane inside its wavefront (0-31). The comparisons below
    // are written against that range, so threadIdx.x cannot be used directly:
    // it exceeds 31 for every wavefront after the first one of a block.
    // With an all-ones mask mbcnt_lo counts the lanes below the current one,
    // which is the lane id in wave32 mode.
    const int lIdx = (int)__builtin_amdgcn_mbcnt_lo(~0u, 0u);
    const int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8;
    const bool zeroed = (lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32);
    float8 cFrag;
    // Loop over the eight vector GPRs.
    for (int vGPR = 0; vGPR < 8; ++vGPR) {
        const int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use.
        const int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2;
        float ctmp0, ctmp1;
        // Every lane executes both permutes, including lanes that end up
        // zeroed: other lanes read this lane's operands through them.
        if (cudaChunk == 0) {
            ctmp0 = bpermute_lane(cudaTID, c_reg.x);
            ctmp1 = bpermute_lane(cudaTID, c_reg.y);
        } else { // cudaChunk == 2
            ctmp0 = bpermute_lane(cudaTID, c_reg.z);
            ctmp1 = bpermute_lane(cudaTID, c_reg.w);
        }
        // Select one of the two values based on the lane index's LSB.
        cFrag[vGPR] = (lIdx & 1) ? ctmp1 : ctmp0;
        // Zero out for specific lane indices.
        if (zeroed)
            cFrag[vGPR] = 0.0f;
    }
    return cFrag;
}
// Rearranges the eight-float result of the wave32 WMMA builtin back into the
// four floats per lane that the caller of the m16n8k16 mma expects.
// NOTE(review): the row/column mapping itself is unchanged from the previous
// version and has not been re-derived here.
static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) {
    // Index of this lane inside its wavefront (0-31). The range checks below
    // are written against that range, so threadIdx.x cannot be used directly:
    // for every wavefront after the first one of a block it is 32 or more and
    // all lanes would fall into the last branch.
    // With an all-ones mask mbcnt_lo counts the lanes below the current one,
    // which is the lane id in wave32 mode.
    const int lIdx = (int)__builtin_amdgcn_mbcnt_lo(~0u, 0u);
    const int add8 = (lIdx & 0x4) ? 8 : 0;
    float4::Native_vec_ d_out;
    for (int cChunk = 0; cChunk < 4; ++cChunk) {
        const int r_vGPR = (cChunk / 2) * 4;
        const int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8;
        // All four permutes run on every lane; each lane keeps the one that
        // matches its group of eight lanes.
        const float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]);
        const float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]);
        const float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]);
        const float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]);
        float val;
        if (lIdx < 8) {
            val = d_tmp0;
        } else if (lIdx < 16) {
            val = d_tmp1;
        } else if (lIdx < 24) {
            val = d_tmp2;
        } else {
            val = d_tmp3;
        }
        if (cChunk == 0) d_out.x = val;
        else if (cChunk == 1) d_out.y = val;
        else if (cChunk == 2) d_out.z = val;
        else d_out.w = val;
    }
    return d_out;
}
// m16n8k16 mma with f16 inputs and f32 accumulation: converts the three input
// fragments with shuffle_a/shuffle_b/shuffle_c, runs the 16x16x16 wave32 WMMA
// builtin on them and converts its result back with shuffle_d.
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
    // Inputs, rearranged for the builtin.
    half16 lhs = shuffle_a(a_reg);
    half16 rhs = shuffle_b(b_reg);
    float8 acc = shuffle_c(c_reg);
    // The builtin yields a float8 accumulator.
    float8 product = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(lhs, rhs, acc);
    // Back to the float4 the caller expects.
    return shuffle_d(product);
}
// m16n8k16 mma with bf16 inputs and f32 accumulation: converts the three input
// fragments with shuffle_a/shuffle_b/shuffle_c, runs the 16x16x16 wave32 bf16
// WMMA builtin on them and converts its result back with shuffle_d.
// NOTE(review): the 16-bit inputs travel through half16 fragments here -
// presumably only their bit patterns matter to the bf16 builtin; confirm.
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
    // Inputs, rearranged for the builtin.
    half16 lhs = shuffle_a(a_reg);
    half16 rhs = shuffle_b(b_reg);
    float8 acc = shuffle_c(c_reg);
    // The builtin yields a float8 accumulator.
    float8 product = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(lhs, rhs, acc);
    // Back to the float4 the caller expects.
    return shuffle_d(product);
}
}

View file

@ -197,7 +197,9 @@ fn run_instruction<'input>(
| ast::Instruction::Xor { .. }
| ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)),
| ast::Instruction::GridDepControl { .. }
| ast::Instruction::LdMatrix { .. }
| ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)),
ast::Instruction::Add {
data:
ast::ArithDetails::Float(ast::ArithFloat {

View file

@ -1855,7 +1855,9 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
| ast::Instruction::AtomCas { .. }
| ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => InstructionModes::none(),
| ast::Instruction::GridDepControl { .. }
| ast::Instruction::LdMatrix { .. }
| ast::Instruction::Mma { .. } => InstructionModes::none(),
ast::Instruction::Add {
data: ast::ArithDetails::Integer(_),
..

View file

@ -153,6 +153,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
for (i, param) in method.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name);
if let Some(align) = param.align {
unsafe { LLVMSetParamAlignment(value, align) };
}
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value);
if method.is_kernel {
@ -519,6 +522,7 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop
ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop
ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop
ast::Instruction::GridDepControl { .. } => Ok(()), // nop
// replaced by a function call
ast::Instruction::Bfe { .. }
| ast::Instruction::Bar { .. }
@ -529,7 +533,8 @@ impl<'a> MethodEmitContext<'a> {
| ast::Instruction::Vote { .. }
| ast::Instruction::Nanosleep { .. }
| ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()),
| ast::Instruction::LdMatrix { .. }
| ast::Instruction::Mma { .. } => return Err(error_unreachable()),
}
}

View file

@ -51,7 +51,7 @@ quick_error! {
}
/// GPU attributes needed at compile time.
#[derive(serde::Serialize)]
#[derive(Copy, Clone)]
pub struct Attributes {
/// Clock frequency in kHz.
pub clock_rate: u32,

View file

@ -351,6 +351,35 @@ fn run_instruction<'input>(
let name = "sqrt_rn_ftz_f32";
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Mma {
data:
ast::MmaDetails {
alayout,
blayout,
dtype_scalar,
atype_scalar,
btype_scalar,
ctype_scalar,
},
..
} => {
let name = format!(
"mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}",
match alayout {
ast::MatrixLayout::Row => "row",
ast::MatrixLayout::Col => "col",
},
match blayout {
ast::MatrixLayout::Row => "row",
ast::MatrixLayout::Col => "col",
},
scalar_to_ptx_name(dtype_scalar),
scalar_to_ptx_name(atype_scalar),
scalar_to_ptx_name(btype_scalar),
scalar_to_ptx_name(ctype_scalar),
);
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Sqrt {
data:
ast::RcpData {

View file

@ -3,8 +3,8 @@ use super::{
StateSpace, VectorPrefix,
};
use crate::{
FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction,
ShiftDirection, ShuffleMode, VoteMode,
FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError,
PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode,
};
use bitflags::bitflags;
use derive_more::Display;
@ -721,6 +721,30 @@ ptx_parser_macros::generate_instruction_type!(
space: { data.state_space },
}
}
},
GridDepControl {
data: crate::GridDepControlAction,
},
Mma {
data: MmaDetails,
arguments<T>: {
dst: {
repr: T,
type: { data.dtype() },
},
src1: {
repr: T,
type: { data.atype() },
},
src2: {
repr: T,
type: { data.btype() },
},
src3: {
repr: T,
type: { data.ctype() },
}
}
}
}
);
@ -2378,3 +2402,27 @@ pub struct ReduxSyncData {
pub type_: ScalarType,
pub reduction: Reduction,
}
/// Modifiers parsed from an `mma` instruction: the layouts of the A and B
/// matrices and the scalar type of each of the four operands.
pub struct MmaDetails {
    pub alayout: MatrixLayout,
    pub blayout: MatrixLayout,
    pub dtype_scalar: ScalarType,
    pub atype_scalar: ScalarType,
    pub btype_scalar: ScalarType,
    pub ctype_scalar: ScalarType,
}

impl MmaDetails {
    /// Register type of the destination operand: always four `f32` values,
    /// independent of `dtype_scalar`.
    pub fn dtype(&self) -> Type {
        Type::Vector(4, ScalarType::F32)
    }

    /// Register type of the A operand: always four `u32` values, independent
    /// of `atype_scalar`.
    pub fn atype(&self) -> Type {
        Type::Vector(4, ScalarType::U32)
    }

    /// Register type of the B operand: always two `u32` values, independent
    /// of `btype_scalar`.
    pub fn btype(&self) -> Type {
        Type::Vector(2, ScalarType::U32)
    }

    /// Register type of the accumulator operand; it is the same type as the
    /// destination operand.
    pub fn ctype(&self) -> Type {
        self.dtype()
    }
}

View file

@ -1862,6 +1862,9 @@ derive_parser!(
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum MatrixNumber { }
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum MatrixLayout { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => {
Instruction::Mov {
@ -3897,6 +3900,37 @@ derive_parser!(
.type: ScalarType = {.b16, .b8};
// .dst_fmt = { .b8x16 };
// .src_fmt = { .b6x16_p32, .b4x16_p64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol
griddepcontrol.action => {
Instruction::GridDepControl {
data: action
}
}
.action: GridDepControlAction = { .launch_dependents, .wait };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => {
if dtype != ScalarType::F32 || ctype != ScalarType::F32 {
state.errors.push(PtxError::Todo);
}
Instruction::Mma {
data: MmaDetails {
alayout,
blayout,
dtype_scalar: dtype,
atype_scalar: ScalarType::BF16,
btype_scalar: ScalarType::BF16,
ctype_scalar: ctype,
},
arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c }
}
}
.alayout: MatrixLayout = {.row};
.blayout: MatrixLayout = {.col};
.ctype: ScalarType = {.f16, .f32};
.dtype: ScalarType = {.f16, .f32};
);
#[cfg(test)]

View file

@ -1,6 +1,4 @@
use bpaf::{any, doc::Style, Bpaf, Parser};
use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600};
use std::{ffi::CStr, mem};
use bpaf::{any, choice, doc::Style, literal, Bpaf, Parser};
#[derive(Debug, Clone, Bpaf)]
#[allow(dead_code)]
@ -12,6 +10,8 @@ pub struct Options {
#[bpaf(short, long)]
verbose: bool,
#[bpaf(external)]
lineinfo: bool,
#[bpaf(external)]
gpu_name: String,
#[bpaf(long, short('O'), fallback(3))]
opt_level: usize,
@ -19,48 +19,32 @@ pub struct Options {
input: String,
}
/// Parser for the line-info switch, accepted as either `-lineinfo` or
/// `--lineinfo` at any position on the command line.
///
/// Yields `true` when one of the two spellings is present and `false`
/// otherwise. The optionality is applied to the combined parser and mapped
/// through `is_some()`; mapping every outcome of an already-optional parser to
/// `true` would report the switch as set even when it is absent.
fn lineinfo() -> impl Parser<bool> {
    choice(
        ["-lineinfo", "--lineinfo"]
            .into_iter()
            .map(|s| literal(s).anywhere().boxed()),
    )
    .optional()
    .map(|flag| flag.is_some())
}
// #[bpaf(long, long("gpu_name"), fallback_with(default_arch))]
fn gpu_name() -> impl Parser<String> {
any("", move |s: String| {
Some(s.strip_prefix("-arch=")?.to_owned())
Some(
s.strip_prefix("-arch=")
.or_else(|| s.strip_prefix("--gpu-name="))?
.to_owned(),
)
})
.metavar(&[("-arch=", Style::Literal), ("ARG", Style::Metavar)])
.metavar(&[("--gpu-name=", Style::Literal), ("SM", Style::Metavar)])
.anywhere()
.fallback_with(|| Ok::<String, &'static str>("sm_52".to_string()))
}
fn main() {
let options = options().run();
let comgr = comgr::Comgr::new().unwrap();
unsafe { hip_runtime_sys::hipInit(0) }.unwrap();
let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() };
let (gpu_arch, clock_rate) = get_gpu_arch_and_clock_rate(&mut dev_props);
let input = std::fs::read_to_string(options.input).unwrap();
let ast = ptx_parser::parse_module_checked(&input).unwrap();
let llvm = ptx::to_llvm_module(
ast,
ptx::Attributes {
clock_rate: clock_rate as u32,
},
|_| {},
)
.unwrap();
let elf_binary = comgr::compile_bitcode(
&comgr,
gpu_arch,
&*llvm.llvm_ir.write_bitcode_to_memory(),
&*llvm.linked_bitcode(),
&*llvm.attributes_ir.write_bitcode_to_memory(),
None,
)
.unwrap();
std::fs::write(options.output, elf_binary).unwrap();
}
fn get_gpu_arch_and_clock_rate<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> (&'a str, i32) {
unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }.unwrap();
let gcn_arch_name = &dev_props.gcnArchName;
let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) };
let gcn_arch_name = gcn_arch_name.to_str();
(gcn_arch_name.unwrap(), dev_props.clockRate)
std::fs::copy(&options.input, &options.output).unwrap();
}

View file

@ -89,7 +89,15 @@ pub(crate) fn get_attribute(
*pi = 32;
return Ok(());
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER => {
// TODO: maintain a table, certain RDNAs are 1/16, some are 1/32
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => {
*pi = 32;
return Ok(());
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED => {
*pi = 0;
return Ok(());
}
@ -211,9 +219,6 @@ pub(crate) fn get_attribute(
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE => {
return get_device_prop(pi, dev_idx, |props| props.persistingL2CacheMaxSize)
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => {
return get_device_prop(pi, dev_idx, |props| props.singleToDoublePrecisionPerfRatio)
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE => {
return get_device_prop(pi, dev_idx, |props| props.accessPolicyMaxWindowSize)
}

View file

@ -487,9 +487,9 @@ pub(crate) unsafe fn occupancy_max_active_blocks_per_multiprocessor_with_flags(
dynamic_smem_size: usize,
flags: ::core::ffi::c_uint,
) -> hipError_t {
hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
num_blocks,
func.0.cast(),
func,
block_size,
dynamic_smem_size,
flags,

12
zluda/src/impl/hipfix.rs Normal file
View file

@ -0,0 +1,12 @@
/// Works around a bug in `hipDrvPointerGetAttributes`: it returns
/// `HIP_ERROR_INVALID_VALUE` if the pointer is null, while it works correctly
/// for any other invalid pointer.
///
/// A non-null `ptr` is returned unchanged; a null `ptr` is replaced with a
/// pointer whose address is `usize::MAX`.
pub(crate) fn get_attributes(
    ptr: hip_runtime_sys::hipDeviceptr_t,
) -> hip_runtime_sys::hipDeviceptr_t {
    if !ptr.0.is_null() {
        return ptr;
    }
    hip_runtime_sys::hipDeviceptr_t(usize::MAX as _)
}

View file

@ -7,6 +7,7 @@ pub(super) mod driver;
pub(super) mod event;
pub(super) mod function;
pub(super) mod graph;
pub(super) mod hipfix;
pub(super) mod kernel;
pub(super) mod library;
pub(super) mod memory;

View file

@ -20,6 +20,10 @@ impl ZludaObject for Module {
}
}
// PTX module text consisting only of the `.version`, `.target` and
// `.address_size` directives. `load_hip_module` substitutes it in non-debug
// builds when `get_ptx` fails.
static EMPTY_PTX: &str = ".version 6.5
.target sm_30
.address_size 64";
// get_ptx takes an `image` that can be anything we support and returns a
// String containing a ptx extracted from `image`.
fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result<Cow<'a, str>, CUerror> {
@ -58,11 +62,17 @@ fn cow_bytes_to_str<'a>(data: Cow<'a, [u8]>) -> Option<Cow<'a, str>> {
pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CUerror> {
let global_state = driver::global_state()?;
let text = get_ptx(library)?;
let maybe_ptx = get_ptx(library);
let text = if cfg!(debug_assertions) {
maybe_ptx?
} else {
maybe_ptx.unwrap_or_else(|_| Cow::Borrowed(EMPTY_PTX))
};
let hip_properties = get_hip_properties()?;
let gcn_arch = get_gcn_arch(&hip_properties)?;
let attributes = ptx::Attributes {
let attributes = ExtraCacheAttributes {
clock_rate: hip_properties.clockRate as u32,
is_debug: cfg!(debug_assertions),
};
let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| {
let cache = zluda_cache::ModuleCache::open(p)?;
@ -84,6 +94,12 @@ pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CU
Ok(hip_module)
}
// Compile-time inputs that are serialized into the module cache key.
// NOTE(review): presumably reordering the fields changes the serialized bytes
// and therefore the cache key - confirm before touching the layout.
#[derive(serde::Serialize)]
struct ExtraCacheAttributes {
// Set from `cfg!(debug_assertions)` by `load_hip_module`.
is_debug: bool,
// Taken from the `clockRate` HIP device property and forwarded to
// `ptx::Attributes` when compiling.
clock_rate: u32,
}
fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> {
let hip_dev = super::context::get_current_device()?;
let mut props = unsafe { mem::zeroed() };
@ -100,7 +116,7 @@ fn get_cache_key<'a, 'b>(
global_state: &'static driver::GlobalState,
isa: &'a str,
text: &str,
attributes: &ptx::Attributes,
attributes: &impl serde::Serialize,
) -> Option<zluda_cache::ModuleKey<'a>> {
// Serialization here is deterministic. When marking a type with
// #[derive(serde::Serialize)] the derived implementation will just write
@ -129,7 +145,7 @@ fn load_cached_binary(
fn compile_from_ptx_and_cache(
comgr: &comgr::Comgr,
gcn_arch: &str,
attributes: ptx::Attributes,
attributes: ExtraCacheAttributes,
text: &str,
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
) -> Result<Vec<u8>, CUerror> {
@ -138,7 +154,14 @@ fn compile_from_ptx_and_cache(
} else {
ptx_parser::parse_module_unchecked(text)
};
let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?;
let llvm_module = ptx::to_llvm_module(
ast,
ptx::Attributes {
clock_rate: attributes.clock_rate,
},
|_| {},
)
.map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode(
comgr,
gcn_arch,

View file

@ -2,7 +2,7 @@ use cuda_types::cuda::*;
use hip_runtime_sys::*;
use std::{ffi::c_void, ptr};
use crate::r#impl::driver;
use crate::r#impl::{driver, hipfix};
// TODO: handle hipMemoryTypeUnregistered
fn to_cu_memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
@ -59,7 +59,12 @@ pub(crate) unsafe fn get_attributes(
data: &mut *mut ::core::ffi::c_void,
ptr: hipDeviceptr_t,
) -> CUresult {
hipDrvPointerGetAttributes(num_attributes, attributes, data, ptr)?;
hipDrvPointerGetAttributes(
num_attributes,
attributes,
data,
hipfix::get_attributes(ptr),
)?;
let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize);
let data = std::slice::from_raw_parts_mut(data, num_attributes as usize);
for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) {
@ -88,7 +93,7 @@ mod tests {
use crate::tests::CudaApi;
use cuda_macros::test_cuda;
use cuda_types::cuda::*;
use std::{ffi::c_void, mem, ptr};
use std::{ffi::c_void, i32, mem, ptr, usize};
#[test_cuda]
pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) {
@ -162,4 +167,47 @@ mod tests {
);
assert_eq!(context, CUcontext(ptr::null_mut()));
}
// Querying several attributes of a null device pointer in one call must
// succeed and overwrite every output slot.
#[test_cuda]
pub unsafe fn null_ptr_attributes_success(api: impl CudaApi) {
api.cuInit(0);
api.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0);
// Every output starts out with a sentinel that differs from the value
// asserted at the end, so an untouched slot fails the test.
let mut context = CUcontext(1 as _);
let mut mem_type = mem::transmute::<_, CUmemorytype>(u32::MAX);
let mut dev_ptr = mem::transmute::<_, *mut c_void>(usize::MAX);
let mut host_ptr = mem::transmute::<_, *mut c_void>(usize::MAX);
let mut is_managed = true;
let mut ordinal = i32::MAX;
let mut attrs = [
CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED,
CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
];
// Output pointers, in the same order as `attrs`.
let mut values = [
std::ptr::from_mut(&mut context).cast::<c_void>(),
std::ptr::from_mut(&mut mem_type).cast::<c_void>(),
std::ptr::from_mut(&mut dev_ptr).cast::<c_void>(),
std::ptr::from_mut(&mut host_ptr).cast::<c_void>(),
std::ptr::from_mut(&mut is_managed).cast::<c_void>(),
std::ptr::from_mut(&mut ordinal).cast::<c_void>(),
];
assert_eq!(
CUresult::SUCCESS,
api.cuPointerGetAttributes_unchecked(
attrs.len() as u32,
attrs.as_mut_ptr(),
values.as_mut_ptr(),
CUdeviceptr_v2(ptr::null_mut())
)
);
assert_eq!(context, CUcontext(ptr::null_mut()));
assert_eq!(mem_type, CUmemorytype(0));
assert_eq!(dev_ptr, ptr::null_mut());
assert_eq!(host_ptr, ptr::null_mut());
assert_eq!(is_managed, false);
// NOTE(review): -2 is presumably the ordinal reported for a pointer that
// belongs to no device - confirm against the CUDA driver behavior.
assert_eq!(ordinal, -2);
}
}

View file

@ -43,6 +43,86 @@ pub(crate) unsafe fn device_get_count_v2(device_count: &mut ::core::ffi::c_uint)
rsmi_num_monitor_devices(device_count)
}
// Looks up the device whose PCI address matches `pci_bus_id` and stores a
// handle to it in `device`.
//
// Returns `INVALID_ARGUMENT` when `pci_bus_id` cannot be parsed and
// `ERROR_NOT_FOUND` when no monitored device reports the same id; failures of
// the rsmi calls are propagated with `?`.
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
pci_bus_id: &std::ffi::CStr,
device: &mut cuda_types::nvml::nvmlDevice_t,
) -> nvmlReturn_t {
let pci = parse_pci_bus_id(pci_bus_id).ok_or(nvmlError_t::INVALID_ARGUMENT)?;
// Packed form of the requested address, compared below with the id that
// rsmi reports for each device.
let bdfid = pci.to_bdfid();
let mut device_count = 0;
rsmi_num_monitor_devices(&mut device_count)?;
for dv_ind in 0..device_count {
let mut curr_bdfid = 0;
rsmi_dev_pci_id_get(dv_ind, &mut curr_bdfid)?;
if curr_bdfid == bdfid {
*device = Device { _index: dv_ind }.wrap();
return nvmlReturn_t::SUCCESS;
}
}
nvmlReturn_t::ERROR_NOT_FOUND
}
/// PCI address of a device, split into its `domain:bus:device.function`
/// components.
#[derive(Clone, Copy)]
struct PciBusId {
    domain: u16,
    bus: u8,
    device: u8,
    function: u8,
}

impl PciBusId {
    /// Packs the address into one integer: `domain` shifted left by 32, `bus`
    /// by 8, `device` by 3, and `function` in the lowest bits.
    fn to_bdfid(self) -> u64 {
        ((self.domain as u64) << 32)
            | ((self.bus as u64) << 8)
            | ((self.device as u64) << 3)
            | (self.function as u64)
    }
}

/// Parses a textual PCI bus id. All numbers are hexadecimal and surrounding
/// whitespace is ignored. Accepted shapes:
///
/// * `domain:bus:device.function`
/// * `domain:bus:device` (function defaults to 0)
/// * `bus:device.function` (domain defaults to 0)
/// * `bus:device` (domain and function default to 0)
///
/// The domain may be written with one to eight hex digits (so both the
/// `0000:03:00.0` and the zero-padded `00000000:03:00.0` spellings work), but
/// its value must fit in 16 bits. Bus, device and function take at most two
/// hex digits each.
///
/// Returns `None` for non-UTF-8 input, a wrong number of fields, empty or
/// over-long fields, or any character that is not a hex digit.
fn parse_pci_bus_id(id: &std::ffi::CStr) -> Option<PciBusId> {
    /// Parses a domain field of up to eight hex digits whose value fits `u16`.
    fn hex_domain(s: &str) -> Option<u16> {
        if s.is_empty() || s.len() > 8 {
            return None;
        }
        // Leading zeros are padding only; strip them so that the 16-bit
        // parser sees at most the significant digits.
        match s.trim_start_matches('0') {
            "" => Some(0),
            significant => hex_u16(significant),
        }
    }

    let s = id.to_str().ok()?.trim();
    let mut fields = s.split(':');
    let first = fields.next()?;
    let second = fields.next()?;
    // Two fields mean the domain was omitted, three mean it leads the id,
    // anything longer is malformed.
    let (domain, bus_part, tail) = match (fields.next(), fields.next()) {
        (None, _) => (0, first, second),
        (Some(tail), None) => (hex_domain(first)?, second, tail),
        (Some(_), Some(_)) => return None,
    };
    // Only the first '.' separates device from function; anything after a
    // second '.' ends up in the function field and is rejected there.
    let (dev_part, function) = match tail.split_once('.') {
        Some((dev, func)) => (dev, hex_u8(func)?),
        None => (tail, 0),
    };
    Some(PciBusId {
        domain,
        bus: hex_u8(bus_part)?,
        device: hex_u8(dev_part)?,
        function,
    })
}

/// Parses one to four hex digits into a `u16`. Signs, whitespace and empty
/// input are rejected.
fn hex_u16(s: &str) -> Option<u16> {
    // `from_str_radix` alone would accept a leading '+'.
    if s.len() > 4 || !s.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    u16::from_str_radix(s, 16).ok()
}

/// Parses one or two hex digits into a `u8`. Signs, whitespace and empty
/// input are rejected.
fn hex_u8(s: &str) -> Option<u8> {
    // `from_str_radix` alone would accept a leading '+'.
    if s.len() > 2 || !s.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    u8::from_str_radix(s, 16).ok()
}
pub(crate) unsafe fn device_get_field_values(
_device: &Device,
values_count: ::core::ffi::c_int,
@ -75,3 +155,36 @@ pub(crate) fn device_get_handle_by_index_v2(
*device = Device { _index: index }.wrap();
nvmlReturn_t::SUCCESS
}
// Unit tests for `parse_pci_bus_id`; they only check how the text is split
// into fields.
#[cfg(test)]
mod tests {
// All four fields present.
#[test]
fn parse_pci_bus_id_full() {
let id = std::ffi::CString::new("0100:65:a0.f").unwrap();
let parsed = super::parse_pci_bus_id(&id).unwrap();
assert_eq!(parsed.domain, 0x0100);
assert_eq!(parsed.bus, 0x65);
assert_eq!(parsed.device, 0xa0);
assert_eq!(parsed.function, 0xf);
}
// A missing `.function` suffix defaults the function to 0.
#[test]
fn parse_pci_bus_id_no_func() {
let id = std::ffi::CString::new("0100:65:a0").unwrap();
let parsed = super::parse_pci_bus_id(&id).unwrap();
assert_eq!(parsed.domain, 0x0100);
assert_eq!(parsed.bus, 0x65);
assert_eq!(parsed.device, 0xa0);
assert_eq!(parsed.function, 0);
}
// A missing `domain:` prefix defaults the domain to 0.
#[test]
fn parse_pci_bus_id_no_domain() {
let id = std::ffi::CString::new("65:a0.f").unwrap();
let parsed = super::parse_pci_bus_id(&id).unwrap();
assert_eq!(parsed.domain, 0);
assert_eq!(parsed.bus, 0x65);
assert_eq!(parsed.device, 0xa0);
assert_eq!(parsed.function, 0xf);
}
}

View file

@ -23,6 +23,13 @@ pub(crate) unsafe fn device_get_count_v2(_device_count: &mut ::core::ffi::c_uint
crate::impl_common::unimplemented()
}
/// Placeholder for looking up a device by PCI bus id in this backend; it
/// ignores both arguments and returns whatever
/// `crate::impl_common::unimplemented()` produces.
///
/// The parameters carry a leading underscore, like those of the neighbouring
/// placeholders, so that they do not trigger unused-variable warnings.
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
    _pci_bus_id: &std::ffi::CStr,
    _device: &mut cuda_types::nvml::nvmlDevice_t,
) -> nvmlReturn_t {
    crate::impl_common::unimplemented()
}
pub(crate) unsafe fn device_get_field_values(
_device: cuda_types::nvml::nvmlDevice_t,
_values_count: ::core::ffi::c_int,
@ -31,10 +38,6 @@ pub(crate) unsafe fn device_get_field_values(
crate::impl_common::unimplemented()
}
unsafe fn get_field_value(_field: &mut nvmlFieldValue_st) -> Result<(), nvmlError_t> {
crate::impl_common::unimplemented()
}
pub(crate) unsafe fn device_get_gpu_fabric_info(
_device: cuda_types::nvml::nvmlDevice_t,
_gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t,

View file

@ -48,6 +48,7 @@ cuda_macros::nvml_function_declarations!(
nvmlDeviceGetFieldValues,
nvmlDeviceGetGpuFabricInfo,
nvmlDeviceGetHandleByIndex_v2,
nvmlDeviceGetHandleByPciBusId_v2,
nvmlInit,
nvmlInitWithFlags,
nvmlInit_v2,

View file

@ -303,6 +303,7 @@ pub(crate) enum ErrorEntry {
},
NullPointer(&'static str),
UnknownLibrary(CUlibrary),
SavedModule(String),
}
unsafe impl Send for ErrorEntry {}
@ -344,93 +345,94 @@ impl Display for ErrorEntry {
match self {
ErrorEntry::IoError(e) => e.fmt(f),
ErrorEntry::CreatedDumpDirectory(dir) => {
write!(
f,
"Created trace directory {} ",
dir.as_os_str().to_string_lossy()
)
}
write!(
f,
"Created trace directory {} ",
dir.as_os_str().to_string_lossy()
)
}
ErrorEntry::ErrorBox(e) => e.fmt(f),
ErrorEntry::UnsupportedModule {
module,
raw_image,
kind,
} => {
write!(
f,
"Unsupported {} module {:?} loaded from module image {:?}",
kind, module, raw_image
)
}
module,
raw_image,
kind,
} => {
write!(
f,
"Unsupported {} module {:?} loaded from module image {:?}",
kind, module, raw_image
)
}
ErrorEntry::MalformedModulePath(e) => e.fmt(f),
ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f),
ErrorEntry::ModuleParsingError(file_name) => {
write!(
f,
"Error parsing module, log has been written to {}",
file_name
)
}
write!(
f,
"Error parsing module, log has been written to {}",
file_name
)
}
ErrorEntry::NulInsideModuleText(e) => e.fmt(f),
ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)),
ErrorEntry::UnexpectedBinaryField {
field_name,
expected,
observed,
} => write!(
f,
"Unexpected field {}. Expected one of: [{}], observed: {}",
field_name,
expected
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", "),
observed
),
field_name,
expected,
observed,
} => write!(
f,
"Unexpected field {}. Expected one of: [{}], observed: {}",
field_name,
expected
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", "),
observed
),
ErrorEntry::UnexpectedArgument {
arg_name,
expected,
observed,
} => write!(
f,
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
arg_name,
expected
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", "),
observed
),
arg_name,
expected,
observed,
} => write!(
f,
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
arg_name,
expected
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", "),
observed
),
ErrorEntry::InvalidEnvVar {
var,
pattern,
value,
} => write!(
f,
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
),
var,
pattern,
value,
} => write!(
f,
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
),
ErrorEntry::FunctionNotFound(cuda_function_name) => write!(
f,
"No function {cuda_function_name} in the underlying library"
),
f,
"No function {cuda_function_name} in the underlying library"
),
ErrorEntry::UnexpectedExportTableSize { expected, computed } => {
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
}
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
}
ErrorEntry::IntegrityCheck { original, overriden } => {
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
}
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
}
ErrorEntry::NullPointer(type_) => {
write!(f, "Null pointer of type {type_} encountered")
}
write!(f, "Null pointer of type {type_} encountered")
}
ErrorEntry::UnknownLibrary(culibrary) => {
write!(f, "Unknown library: ")?;
let mut temp_buffer = Vec::new();
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
}
write!(f, "Unknown library: ")?;
let mut temp_buffer = Vec::new();
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
}
ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"),
}
}
}

View file

@ -128,12 +128,11 @@ impl StateTracker {
fn_logger: &mut FnCallLog,
type_: &'static str,
) {
fn_logger.log_io_error(self.writer.save_module(
self.library_counter,
index,
submodule,
type_,
));
fn_logger.try_(|fn_logger| {
self.writer
.save_module(fn_logger, self.library_counter, index, submodule, type_)
.map_err(ErrorEntry::IoError)
});
if type_ == "ptx" {
match CString::new(submodule) {
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
@ -323,6 +322,7 @@ impl DumpWriter {
fn save_module(
&self,
fn_logger: &mut FnCallLog,
module_index: usize,
submodule_index: Option<(usize, Option<usize>)>,
buffer: &[u8],
@ -332,9 +332,13 @@ impl DumpWriter {
None => return Ok(()),
Some(d) => d.clone(),
};
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
let mut file = File::create_new(dump_file)?;
file.write_all(buffer)?;
let file_name = Self::get_file_name(module_index, submodule_index, kind);
dump_file.push(&file_name);
{
let mut file = File::create_new(dump_file)?;
file.write_all(buffer)?;
}
fn_logger.log(ErrorEntry::SavedModule(file_name));
Ok(())
}
@ -349,7 +353,7 @@ impl DumpWriter {
Some(d) => d.clone(),
};
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
let mut file = File::create(log_file)?;
let mut file = File::create_new(log_file)?;
for error in errors {
writeln!(file, "{}", error)?;
}