Andrzej Janik 2025-09-17 01:51:29 +00:00
commit 00d7cd131b
8 changed files with 265 additions and 6 deletions

Binary file not shown.


@@ -1,17 +1,21 @@
// Every time this file changes it must be rebuilt; you need `rocm-llvm-dev` and `llvm-17`
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
#include <cstddef>
#include <cstdint>
#include <bit>
#include <cmath>
#include <hip/hip_runtime.h>
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/hip_fp8.h>
#define SHARED_SPACE __attribute__((address_space(3)))
#define CONSTANT_SPACE __attribute__((address_space(4)))
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
typedef float float8 __attribute__((ext_vector_type(8)));
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
#define DECLARE_ATTR(TYPE, NAME) \
@@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
uint32_t x3 = load_single_matrix_trans(address, 24);
return uint4::Native_vec_{x0, x1, x2, x3};
}
static inline __device__ _Float16 top16_as_fp16(uint32_t value) {
uint16_t half_bits = static_cast<uint16_t>((value >> 16) & 0xFFFF);
return *reinterpret_cast<_Float16*>(&half_bits);
}
static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) {
uint16_t half_bits = static_cast<uint16_t>(value & 0xFFFF);
return *reinterpret_cast<_Float16*>(&half_bits);
}
static inline __device__ float bpermute_lane(int lane, float x) {
return __hip_ds_bpermutef(4 * lane, x);
}
static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) {
return __hip_ds_bpermute(4 * lane, x);
}
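// Aside: DS_BPERMUTE takes a *byte* address and every lane register is 4 bytes
// wide, which is why the helpers above scale the lane index by 4. A minimal
// sketch of the primitive (hypothetical helper, not used below): every lane
// reads lane 0's value, i.e. a whole-wave broadcast.
static inline __device__ uint32_t broadcast_lane0(uint32_t x) {
return __hip_ds_bpermute(4 * 0, x); // byte address of lane 0
}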
static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) {
const unsigned lIdx = threadIdx.x;
const int lane = lIdx % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode)
half16 aFrag;
for (int vGPR = 0; vGPR < 8; ++vGPR) {
int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2
int cudaTID = (vGPR % 4 + lane * 4) % 32;
uint32_t reg0, reg1;
// Select the two consecutive elements from a_reg:
if (cudaChunk == 0) {
reg0 = a_reg.x;
reg1 = a_reg.y;
} else { // cudaChunk==2
reg0 = a_reg.z;
reg1 = a_reg.w;
}
uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0);
uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1);
uint32_t a_Frag_reg = (lane < 8) ? a_tmp0 : a_tmp1;
aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg);
aFrag[2 * vGPR + 1] = top16_as_fp16(a_Frag_reg);
}
return aFrag;
}
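// Worked example of the index math above (illustrative): for lane == 0 and
// vGPR == 0, cudaChunk == 0 and cudaTID == 0, so this lane gathers thread 0's
// a_reg.x (lane < 8 selects a_tmp0) and unpacks its two packed f16 halves into
// aFrag[0] and aFrag[1]. For vGPR == 4, cudaChunk becomes 2 and the same walk
// runs over a_reg.z/a_reg.w instead.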
static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) {
const unsigned lIdx = threadIdx.x;
const int lane = lIdx % 16;
half16 bFrag;
for (int vGPR = 0; vGPR < 8; ++vGPR) {
int cudaChunk = vGPR / 4; // will be 0 or 1
int cudaTID = vGPR % 4 + (lane * 4) % 64;
uint32_t reg = (cudaChunk == 0) ? b_reg.x : b_reg.y;
uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg);
if (lane < 8) {
bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg);
bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg);
} else {
bFrag[2 * vGPR] = 0.0f;
bFrag[2 * vGPR + 1] = 0.0f;
}
}
return bFrag;
}
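// Why lanes 8-15 produce zeros: the NVIDIA B tile is 16x8 while the AMD WMMA
// consumes a 16x16 tile, so (on this reading of the layout) the extra eight
// columns are zero-padded. Zero columns in B contribute zeros to the unused
// half of the accumulator, which shuffle_d below never reads back.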
static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) {
const int lIdx = (int)threadIdx.x;
float8 cFrag;
// Loop over the eight vector GPRs.
for (int vGPR = 0; vGPR < 8; ++vGPR) {
int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use.
int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8;
int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2;
float ctmp0, ctmp1;
if (cudaChunk == 0) {
ctmp0 = bpermute_lane(cudaTID, c_reg.x);
ctmp1 = bpermute_lane(cudaTID, c_reg.y);
} else { // cudaChunk == 2
ctmp0 = bpermute_lane(cudaTID, c_reg.z);
ctmp1 = bpermute_lane(cudaTID, c_reg.w);
}
// Select one of the two values based on the thread index's LSB.
cFrag[vGPR] = (lIdx & 1) ? ctmp1 : ctmp0;
// Zero out lanes 8-15 and 24-31.
if ((lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32))
cFrag[vGPR] = 0.0f;
}
return cFrag;
}
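// Illustrative trace: for lIdx == 0 and vGPR == 0, cudaTID == 0 and the LSB
// selects ctmp0, so cFrag[0] receives thread 0's c_reg.x. The zeroed lane
// ranges (8-15 and 24-31) mirror the zero padding applied in shuffle_b.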
static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) {
const int lIdx = (int)threadIdx.x;
float4::Native_vec_ d_out;
for (int cChunk = 0; cChunk < 4; ++cChunk) {
int r_vGPR = (cChunk / 2) * 4;
int add8 = (lIdx & 0x4) ? 8 : 0;
int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8;
float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]);
float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]);
float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]);
float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]);
float val;
if (lIdx < 8) {
val = d_tmp0;
} else if (lIdx < 16) {
val = d_tmp1;
} else if (lIdx < 24) {
val = d_tmp2;
} else {
val = d_tmp3;
}
if (cChunk == 0) d_out.x = val;
else if (cChunk == 1) d_out.y = val;
else if (cChunk == 2) d_out.z = val;
else d_out.w = val;
}
return d_out;
}
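// shuffle_d inverts the mapping of shuffle_c. Illustrative trace: for
// lIdx == 0 and cChunk == 0, r_vGPR == 0, add8 == 0 and r_lIdx == 0, so
// d_out.x is lane 0 of dFrag[0], exactly where shuffle_c placed thread 0's
// c_reg.x contribution.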
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
// Reshuffle from Nvidia-like register layout to AMD layout:
half16 aFrag = shuffle_a(a_reg);
half16 bFrag = shuffle_b(b_reg);
float8 cFrag = shuffle_c(c_reg);
// Call the (builtin) 16x16 MMA instruction. It returns a float8.
float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aFrag, bFrag, cFrag);
// Unshuffle back into Nvidia expected float4 result
float4::Native_vec_ d_out = shuffle_d(dFrag);
return d_out;
}
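// For reference, the PTX shape this helper implements (operand names
// illustrative):
// mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
//     {%f0,%f1,%f2,%f3}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%f4,%f5,%f6,%f7};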
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
// Reshuffle from Nvidia-like register layout to AMD layout:
half16 aFrag = shuffle_a(a_reg);
half16 bFrag = shuffle_b(b_reg);
float8 cFrag = shuffle_c(c_reg);
// Call the (builtin) 16x16 MMA instruction. It returns a float8.
float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aFrag, bFrag, cFrag);
// Unshuffle back into Nvidia expected float4 result
float4::Native_vec_ d_out = shuffle_d(dFrag);
return d_out;
}
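// Note that shuffle_a/shuffle_b only transport 16-bit payloads and never
// convert numerically, so reusing the f16 shufflers for bf16 operands is
// sound; only the WMMA builtin interprets the bits.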
}


@@ -198,7 +198,8 @@ fn run_instruction<'input>(
| ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. }
| ast::Instruction::GridDepControl { .. }
| ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)),
| ast::Instruction::LdMatrix { .. }
| ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)),
ast::Instruction::Add {
data:
ast::ArithDetails::Float(ast::ArithFloat {


@@ -1856,7 +1856,8 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
| ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. }
| ast::Instruction::GridDepControl { .. }
| ast::Instruction::LdMatrix { .. } => InstructionModes::none(),
| ast::Instruction::LdMatrix { .. }
| ast::Instruction::Mma { .. } => InstructionModes::none(),
ast::Instruction::Add {
data: ast::ArithDetails::Integer(_),
..


@@ -533,7 +533,8 @@ impl<'a> MethodEmitContext<'a> {
| ast::Instruction::Vote { .. }
| ast::Instruction::Nanosleep { .. }
| ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()),
| ast::Instruction::LdMatrix { .. }
| ast::Instruction::Mma { .. } => return Err(error_unreachable()),
}
}


@@ -351,6 +351,35 @@ fn run_instruction<'input>(
let name = "sqrt_rn_ftz_f32";
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Mma {
data:
ast::MmaDetails {
alayout,
blayout,
dtype_scalar,
atype_scalar,
btype_scalar,
ctype_scalar,
},
..
} => {
let name = format!(
"mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}",
match alayout {
ast::MatrixLayout::Row => "row",
ast::MatrixLayout::Col => "col",
},
match blayout {
ast::MatrixLayout::Row => "row",
ast::MatrixLayout::Col => "col",
},
scalar_to_ptx_name(dtype_scalar),
scalar_to_ptx_name(atype_scalar),
scalar_to_ptx_name(btype_scalar),
scalar_to_ptx_name(ctype_scalar),
);
to_call(resolver, fn_declarations, name.into(), i)?
}
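// For example, assuming scalar_to_ptx_name renders F16/F32 as "f16"/"f32",
// row/col layouts with f32.f16.f16.f32 types produce
// "mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32", which (after the
// __zluda_ptx_impl_ prefix added by FUNC) matches the symbol exported from
// zluda_ptx_impl.cpp above.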
i @ ptx_parser::Instruction::Sqrt {
data:
ast::RcpData {


@@ -3,8 +3,8 @@ use super::{
StateSpace, VectorPrefix,
};
use crate::{
FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction,
ShiftDirection, ShuffleMode, VoteMode,
FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError,
PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode,
};
use bitflags::bitflags;
use derive_more::Display;
@@ -724,6 +724,27 @@
},
GridDepControl {
data: crate::GridDepControlAction,
},
Mma {
data: MmaDetails,
arguments<T>: {
dst: {
repr: T,
type: { data.dtype() },
},
src1: {
repr: T,
type: { data.atype() },
},
src2: {
repr: T,
type: { data.btype() },
},
src3: {
repr: T,
type: { data.ctype() },
}
}
}
}
);
@@ -2381,3 +2402,27 @@ pub struct ReduxSyncData {
pub type_: ScalarType,
pub reduction: Reduction,
}
pub struct MmaDetails {
pub alayout: MatrixLayout,
pub blayout: MatrixLayout,
pub dtype_scalar: ScalarType,
pub atype_scalar: ScalarType,
pub btype_scalar: ScalarType,
pub ctype_scalar: ScalarType,
}
impl MmaDetails {
pub fn dtype(&self) -> Type {
Type::Vector(4, ScalarType::F32)
}
pub fn atype(&self) -> Type {
Type::Vector(4, ScalarType::U32)
}
pub fn btype(&self) -> Type {
Type::Vector(2, ScalarType::U32)
}
pub fn ctype(&self) -> Type {
Type::Vector(4, ScalarType::F32)
}
}
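// Why the operand types are fixed regardless of the scalar types: for
// m16n8k16 each of the 32 threads in a warp holds
// A: 16x16 halves * 2 bytes / 32 threads = 16 bytes = 4 x u32,
// B: 16x8 halves * 2 bytes / 32 threads = 8 bytes = 2 x u32,
// C, D: 16x8 floats * 4 bytes / 32 threads = 16 bytes = 4 x f32,
// and f16/bf16 are both 16-bit, so the packed u32 shapes coincide.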


@@ -1862,6 +1862,9 @@ derive_parser!(
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum MatrixNumber { }
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum MatrixLayout { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => {
Instruction::Mov {
@@ -3905,6 +3908,29 @@ derive_parser!(
}
}
.action: GridDepControlAction = { .launch_dependents, .wait };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => {
if dtype != ScalarType::F32 || ctype != ScalarType::F32 {
state.errors.push(PtxError::Todo);
}
Instruction::Mma {
data: MmaDetails {
alayout,
blayout,
dtype_scalar: dtype,
atype_scalar: ScalarType::BF16,
btype_scalar: ScalarType::BF16,
ctype_scalar: ctype,
},
arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c }
}
}
.alayout: MatrixLayout = {.row};
.blayout: MatrixLayout = {.col};
.ctype: ScalarType = {.f16, .f32};
.dtype: ScalarType = {.f16, .f32};
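// An accepted example (dtype/ctype other than .f32 currently push
// PtxError::Todo; register names illustrative):
// mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32
//     {%f0,%f1,%f2,%f3}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%f4,%f5,%f6,%f7};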
);
#[cfg(test)]