mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-26 19:29:05 +00:00
Add mma
This commit is contained in:
parent
f96ea498bd
commit
00d7cd131b
8 changed files with 265 additions and 6 deletions
Binary file not shown.
|
@ -1,17 +1,21 @@
|
|||
// Every time this file changes it must be rebuilt, you need `rocm-llvm-dev` and `llvm-17`
|
||||
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
|
||||
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <bit>
|
||||
#include <cmath>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/amd_detail/amd_device_functions.h>
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
#define SHARED_SPACE __attribute__((address_space(3)))
|
||||
#define CONSTANT_SPACE __attribute__((address_space(4)))
|
||||
|
||||
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
||||
typedef float float8 __attribute__((ext_vector_type(8)));
|
||||
|
||||
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
|
||||
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
|
||||
#define DECLARE_ATTR(TYPE, NAME) \
|
||||
|
@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
|
|||
uint32_t x3 = load_single_matrix_trans(address, 24);
|
||||
return uint4::Native_vec_{x0, x1, x2, x3};
|
||||
}
|
||||
|
||||
// Returns the upper 16 bits of `value` reinterpreted as a half-precision float.
// The bit pattern is carried over unchanged; no numeric conversion happens.
static inline __device__ _Float16 top16_as_fp16(uint32_t value) {
    const uint16_t half_bits = static_cast<uint16_t>(value >> 16);
    // std::bit_cast rather than dereferencing a reinterpret_cast'ed pointer:
    // reading a uint16_t object through a _Float16 lvalue breaks strict aliasing.
    return std::bit_cast<_Float16>(half_bits);
}
|
||||
// Returns the lower 16 bits of `value` reinterpreted as a half-precision float.
// The bit pattern is carried over unchanged; no numeric conversion happens.
static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) {
    const uint16_t half_bits = static_cast<uint16_t>(value & 0xFFFFu);
    // std::bit_cast rather than dereferencing a reinterpret_cast'ed pointer:
    // reading a uint16_t object through a _Float16 lvalue breaks strict aliasing.
    return std::bit_cast<_Float16>(half_bits);
}
|
||||
|
||||
// Float overload: forwards `x` to __hip_ds_bpermutef, addressing the source
// lane `lane` as a byte offset (4 bytes per lane).
static inline __device__ float bpermute_lane(int lane, float x) {
    const int byte_offset = 4 * lane;
    return __hip_ds_bpermutef(byte_offset, x);
}
|
||||
// 32-bit integer overload: forwards `x` to __hip_ds_bpermute, addressing the
// source lane `lane` as a byte offset (4 bytes per lane).
static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) {
    const int byte_offset = 4 * lane;
    return __hip_ds_bpermute(byte_offset, x);
}
|
||||
|
||||
// Rearranges the A operand of the m16n8k16 mma from the Nvidia-like register
// layout (four 32-bit registers, each packing two 16-bit elements) into the
// 16-element fragment passed to the wave32 WMMA builtin.
//
// Must be executed by every lane of the wave: each lane both supplies its own
// registers to the cross-lane exchange and reads another lane's registers.
static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) {
    // Position inside the wave from the hardware lane counter. threadIdx.x is
    // only equal to the wave lane for 1-D blocks; this file is built for
    // wave32, so mbcnt_lo alone covers every lane.
    const int waveLane = static_cast<int>(__builtin_amdgcn_mbcnt_lo(~0u, 0u));
    const int lane = waveLane % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode)
    half16 aFrag;

    for (int vGPR = 0; vGPR < 8; ++vGPR) {
        // vGPRs 0-3 are fed from the (x, y) register pair, vGPRs 4-7 from (z, w).
        const bool lowPair = vGPR < 4;
        const uint32_t reg0 = lowPair ? a_reg.x : a_reg.z;
        const uint32_t reg1 = lowPair ? a_reg.y : a_reg.w;
        // Lane of the wave whose registers hold the elements for this slot.
        const int cudaTID = (vGPR % 4 + lane * 4) % 32;
        // Both exchanges run unconditionally: the value comes from the source
        // lane's register, so the choice can only be made after the exchange.
        const uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0);
        const uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1);
        const uint32_t packed = (lane < 8) ? a_tmp0 : a_tmp1;
        aFrag[2 * vGPR] = bottom16_as_fp16(packed);
        aFrag[2 * vGPR + 1] = top16_as_fp16(packed);
    }
    return aFrag;
}
|
||||
|
||||
// Rearranges the B operand of the m16n8k16 mma from the Nvidia-like register
// layout (two 32-bit registers, each packing two 16-bit elements) into the
// 16-element fragment passed to the wave32 WMMA builtin. Lanes whose index
// modulo 16 is 8 or above receive an all-zero fragment.
//
// Must be executed by every lane of the wave (cross-lane exchange).
static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) {
    // Position inside the wave from the hardware lane counter. threadIdx.x is
    // only equal to the wave lane for 1-D blocks; this file is built for
    // wave32, so mbcnt_lo alone covers every lane.
    const int waveLane = static_cast<int>(__builtin_amdgcn_mbcnt_lo(~0u, 0u));
    const int lane = waveLane % 16;
    const bool holdsData = lane < 8;
    half16 bFrag;

    for (int vGPR = 0; vGPR < 8; ++vGPR) {
        // vGPRs 0-3 are fed from b_reg.x, vGPRs 4-7 from b_reg.y.
        const uint32_t reg = (vGPR < 4) ? b_reg.x : b_reg.y;
        // Kept within the 32 lanes of the wave, matching shuffle_a. For lanes
        // that keep the result (lane < 8) the index is below 32 anyway, so the
        // wrap only affects lanes whose fetched value is discarded below.
        const int cudaTID = (vGPR % 4 + lane * 4) % 32;
        const uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg);
        if (holdsData) {
            bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg);
            bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg);
        } else {
            bFrag[2 * vGPR] = 0.0f;
            bFrag[2 * vGPR + 1] = 0.0f;
        }
    }
    return bFrag;
}
|
||||
|
||||
// Rearranges the C (accumulator) operand of the m16n8k16 mma from the
// Nvidia-like layout (four floats per lane) into the eight-float fragment
// passed to the wave32 WMMA builtin. Lanes 8-15 and 24-31 of the wave receive
// an all-zero fragment.
//
// Must be executed by every lane of the wave (cross-lane exchange).
static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) {
    // Position inside the wave (0..31) from the hardware lane counter. Using
    // threadIdx.x directly made every range check below fail for any wave other
    // than the first one of a block. This file is built for wave32, so mbcnt_lo
    // alone covers every lane.
    const int lIdx = static_cast<int>(__builtin_amdgcn_mbcnt_lo(~0u, 0u));
    // Loop-invariant pieces of the source-lane computation.
    const int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8;
    const bool zeroLane = (lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32);
    float8 cFrag;

    // Loop over the eight vector GPRs.
    for (int vGPR = 0; vGPR < 8; ++vGPR) {
        // vGPRs 0-3 are fed from the (x, y) pair, vGPRs 4-7 from (z, w).
        const bool lowPair = vGPR < 4;
        const int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2;
        // Both exchanges run unconditionally; the pick happens afterwards
        // because the data comes from the source lane's registers.
        const float ctmp0 = bpermute_lane(cudaTID, lowPair ? c_reg.x : c_reg.z);
        const float ctmp1 = bpermute_lane(cudaTID, lowPair ? c_reg.y : c_reg.w);
        // Select one of the two values based on the lane index's LSB.
        const float picked = (lIdx & 1) ? ctmp1 : ctmp0;
        cFrag[vGPR] = zeroLane ? 0.0f : picked;
    }
    return cFrag;
}
|
||||
|
||||
// Rearranges the eight-float result fragment of the wave32 WMMA builtin back
// into the Nvidia-like four-floats-per-lane layout of the mma destination.
//
// Must be executed by every lane of the wave (cross-lane exchange).
static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) {
    // Position inside the wave (0..31) from the hardware lane counter. Using
    // threadIdx.x directly made every thread with threadIdx.x >= 32 fall into
    // the last selection branch. This file is built for wave32, so mbcnt_lo
    // alone covers every lane.
    const int lIdx = static_cast<int>(__builtin_amdgcn_mbcnt_lo(~0u, 0u));
    // Loop-invariant part of the source-lane computation.
    const int add8 = (lIdx & 0x4) ? 8 : 0;
    const int laneBase = (lIdx % 8) * 2 + add8;
    float4::Native_vec_ d_out;

    for (int cChunk = 0; cChunk < 4; ++cChunk) {
        // Output components 0-1 read vGPRs 0-3, components 2-3 read vGPRs 4-7.
        const int r_vGPR = (cChunk / 2) * 4;
        const int r_lIdx = (cChunk % 2) + laneBase;
        // All four exchanges run unconditionally; which vGPR is wanted depends
        // on the destination lane and is resolved after the exchange.
        const float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]);
        const float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]);
        const float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]);
        const float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]);
        // Each group of eight lanes takes its value from a different vGPR.
        float val;
        if (lIdx < 8) {
            val = d_tmp0;
        } else if (lIdx < 16) {
            val = d_tmp1;
        } else if (lIdx < 24) {
            val = d_tmp2;
        } else {
            val = d_tmp3;
        }
        if (cChunk == 0) d_out.x = val;
        else if (cChunk == 1) d_out.y = val;
        else if (cChunk == 2) d_out.z = val;
        else d_out.w = val;
    }
    return d_out;
}
|
||||
|
||||
// Implementation of mma.sync.aligned.m16n8k16.row.col with f16 inputs and f32
// accumulator/result, built on the wave32 16x16x16 WMMA builtin.
//   a_reg: A operand, four 32-bit registers
//   b_reg: B operand, two 32-bit registers
//   c_reg: accumulator, four floats
// Returns the four result floats in the Nvidia-like layout.
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
    // Nvidia-like register layout -> AMD fragment layout.
    const half16 lhs = shuffle_a(a_reg);
    const half16 rhs = shuffle_b(b_reg);
    const float8 accIn = shuffle_c(c_reg);

    // 16x16x16 multiply-accumulate; the builtin yields a float8 fragment.
    const float8 accOut = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(lhs, rhs, accIn);

    // AMD fragment layout -> float4 expected by the caller.
    return shuffle_d(accOut);
}
|
||||
|
||||
// Implementation of mma.sync.aligned.m16n8k16.row.col with bf16 inputs and f32
// accumulator/result, built on the wave32 16x16x16 bf16 WMMA builtin.
//   a_reg: A operand, four 32-bit registers
//   b_reg: B operand, two 32-bit registers
//   c_reg: accumulator, four floats
// Returns the four result floats in the Nvidia-like layout.
float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
    // Nvidia-like register layout -> AMD fragment layout.
    // NOTE(review): the fragments are typed half16 here although they carry
    // bf16 bit patterns; this presumably relies on a bit-preserving vector
    // conversion at the builtin call - confirm against the builtin's signature.
    const half16 lhs = shuffle_a(a_reg);
    const half16 rhs = shuffle_b(b_reg);
    const float8 accIn = shuffle_c(c_reg);

    // 16x16x16 multiply-accumulate; the builtin yields a float8 fragment.
    const float8 accOut = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(lhs, rhs, accIn);

    // AMD fragment layout -> float4 expected by the caller.
    return shuffle_d(accOut);
}
|
||||
}
|
||||
|
|
|
@ -198,7 +198,8 @@ fn run_instruction<'input>(
|
|||
| ast::Instruction::Vote { .. }
|
||||
| ast::Instruction::ReduxSync { .. }
|
||||
| ast::Instruction::GridDepControl { .. }
|
||||
| ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)),
|
||||
| ast::Instruction::LdMatrix { .. }
|
||||
| ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)),
|
||||
ast::Instruction::Add {
|
||||
data:
|
||||
ast::ArithDetails::Float(ast::ArithFloat {
|
||||
|
|
|
@ -1856,7 +1856,8 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
|||
| ast::Instruction::Vote { .. }
|
||||
| ast::Instruction::ReduxSync { .. }
|
||||
| ast::Instruction::GridDepControl { .. }
|
||||
| ast::Instruction::LdMatrix { .. } => InstructionModes::none(),
|
||||
| ast::Instruction::LdMatrix { .. }
|
||||
| ast::Instruction::Mma { .. } => InstructionModes::none(),
|
||||
ast::Instruction::Add {
|
||||
data: ast::ArithDetails::Integer(_),
|
||||
..
|
||||
|
|
|
@ -533,7 +533,8 @@ impl<'a> MethodEmitContext<'a> {
|
|||
| ast::Instruction::Vote { .. }
|
||||
| ast::Instruction::Nanosleep { .. }
|
||||
| ast::Instruction::ReduxSync { .. }
|
||||
| ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()),
|
||||
| ast::Instruction::LdMatrix { .. }
|
||||
| ast::Instruction::Mma { .. } => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -351,6 +351,35 @@ fn run_instruction<'input>(
|
|||
let name = "sqrt_rn_ftz_f32";
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Mma {
|
||||
data:
|
||||
ast::MmaDetails {
|
||||
alayout,
|
||||
blayout,
|
||||
dtype_scalar,
|
||||
atype_scalar,
|
||||
btype_scalar,
|
||||
ctype_scalar,
|
||||
},
|
||||
..
|
||||
} => {
|
||||
let name = format!(
|
||||
"mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}",
|
||||
match alayout {
|
||||
ast::MatrixLayout::Row => "row",
|
||||
ast::MatrixLayout::Col => "col",
|
||||
},
|
||||
match blayout {
|
||||
ast::MatrixLayout::Row => "row",
|
||||
ast::MatrixLayout::Col => "col",
|
||||
},
|
||||
scalar_to_ptx_name(dtype_scalar),
|
||||
scalar_to_ptx_name(atype_scalar),
|
||||
scalar_to_ptx_name(btype_scalar),
|
||||
scalar_to_ptx_name(ctype_scalar),
|
||||
);
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Sqrt {
|
||||
data:
|
||||
ast::RcpData {
|
||||
|
|
|
@ -3,8 +3,8 @@ use super::{
|
|||
StateSpace, VectorPrefix,
|
||||
};
|
||||
use crate::{
|
||||
FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction,
|
||||
ShiftDirection, ShuffleMode, VoteMode,
|
||||
FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError,
|
||||
PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode,
|
||||
};
|
||||
use bitflags::bitflags;
|
||||
use derive_more::Display;
|
||||
|
@ -724,6 +724,27 @@ ptx_parser_macros::generate_instruction_type!(
|
|||
},
|
||||
GridDepControl {
|
||||
data: crate::GridDepControlAction,
|
||||
},
|
||||
Mma {
|
||||
data: MmaDetails,
|
||||
arguments<T>: {
|
||||
dst: {
|
||||
repr: T,
|
||||
type: { data.dtype() },
|
||||
},
|
||||
src1: {
|
||||
repr: T,
|
||||
type: { data.atype() },
|
||||
},
|
||||
src2: {
|
||||
repr: T,
|
||||
type: { data.btype() },
|
||||
},
|
||||
src3: {
|
||||
repr: T,
|
||||
type: { data.ctype() },
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
@ -2381,3 +2402,27 @@ pub struct ReduxSyncData {
|
|||
pub type_: ScalarType,
|
||||
pub reduction: Reduction,
|
||||
}
|
||||
|
||||
pub struct MmaDetails {
|
||||
pub alayout: MatrixLayout,
|
||||
pub blayout: MatrixLayout,
|
||||
pub dtype_scalar: ScalarType,
|
||||
pub atype_scalar: ScalarType,
|
||||
pub btype_scalar: ScalarType,
|
||||
pub ctype_scalar: ScalarType,
|
||||
}
|
||||
|
||||
impl MmaDetails {
|
||||
pub fn dtype(&self) -> Type {
|
||||
Type::Vector(4, ScalarType::F32)
|
||||
}
|
||||
pub fn atype(&self) -> Type {
|
||||
Type::Vector(4, ScalarType::U32)
|
||||
}
|
||||
pub fn btype(&self) -> Type {
|
||||
Type::Vector(2, ScalarType::U32)
|
||||
}
|
||||
pub fn ctype(&self) -> Type {
|
||||
Type::Vector(4, ScalarType::F32)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1862,6 +1862,9 @@ derive_parser!(
|
|||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum MatrixNumber { }
|
||||
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum MatrixLayout { }
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
||||
mov{.vec}.type d, a => {
|
||||
Instruction::Mov {
|
||||
|
@ -3905,6 +3908,29 @@ derive_parser!(
|
|||
}
|
||||
}
|
||||
.action: GridDepControlAction = { .launch_dependents, .wait };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
|
||||
mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => {
|
||||
if dtype != ScalarType::F32 || ctype != ScalarType::F32 {
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
Instruction::Mma {
|
||||
data: MmaDetails {
|
||||
alayout,
|
||||
blayout,
|
||||
dtype_scalar: dtype,
|
||||
atype_scalar: ScalarType::BF16,
|
||||
btype_scalar: ScalarType::BF16,
|
||||
ctype_scalar: ctype,
|
||||
},
|
||||
arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c }
|
||||
}
|
||||
}
|
||||
|
||||
.alayout: MatrixLayout = {.row};
|
||||
.blayout: MatrixLayout = {.col};
|
||||
.ctype: ScalarType = {.f16, .f32};
|
||||
.dtype: ScalarType = {.f16, .f32};
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue