diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc
index 2b62aeb..bc375c3 100644
Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ
diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp
index cc1d973..6174ec1 100644
--- a/ptx/lib/zluda_ptx_impl.cpp
+++ b/ptx/lib/zluda_ptx_impl.cpp
@@ -1,17 +1,21 @@
 // Every time this file changes it must be rebuilt; you need `rocm-llvm-dev` and `llvm-17`
 // `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
-// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
+// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
 #include
 #include
 #include
 #include
+#include
 #include
 #include
 
 #define SHARED_SPACE __attribute__((address_space(3)))
 #define CONSTANT_SPACE __attribute__((address_space(4)))
 
+typedef _Float16 half16 __attribute__((ext_vector_type(16)));
+typedef float float8 __attribute__((ext_vector_type(8)));
+
 #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
 
 #define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
 #define DECLARE_ATTR(TYPE, NAME) \
@@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
         uint32_t x3 = load_single_matrix_trans(address, 24);
         return uint4::Native_vec_{x0, x1, x2, x3};
     }
+
+    static inline __device__ _Float16 top16_as_fp16(uint32_t value) {
+        uint16_t half_bits = static_cast<uint16_t>((value >> 16) & 0xFFFF);
+        return *reinterpret_cast<_Float16*>(&half_bits);
+    }
+    static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) {
+        uint16_t half_bits = static_cast<uint16_t>(value & 0xFFFF);
+        return *reinterpret_cast<_Float16*>(&half_bits);
+    }
+
+    static inline __device__ float bpermute_lane(int lane, float x) {
+        return __hip_ds_bpermutef(4 * lane, x);
+    }
+    static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) {
+        return __hip_ds_bpermute(4 * lane, x);
+    }
+
+    static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) {
+        const unsigned lIdx = threadIdx.x;
+        const int lane = lIdx % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode)
+        half16 aFrag;
+
+        for (int vGPR = 0; vGPR < 8; ++vGPR) {
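+            // For each output element pair, work out which source lane (cudaTID)
+            // and which packed 32-bit register pair (cudaChunk) hold it in the
+            // Nvidia A-fragment layout, then fetch it via ds_bpermute.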
+            int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2
+            int cudaTID = (vGPR % 4 + lane * 4) % 32;
+            uint32_t reg0, reg1;
+            // Select the two consecutive elements from a_reg:
+            if (cudaChunk == 0) {
+                reg0 = a_reg.x;
+                reg1 = a_reg.y;
+            } else { // cudaChunk == 2
+                reg0 = a_reg.z;
+                reg1 = a_reg.w;
+            }
+            uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0);
+            uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1);
+            uint32_t a_Frag_reg = (lane < 8) ? a_tmp0 : a_tmp1;
+            aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg);
+            aFrag[2 * vGPR + 1] = top16_as_fp16(a_Frag_reg);
+        }
+        return aFrag;
+    }
+
+    static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) {
+        const unsigned lIdx = threadIdx.x;
+        const int lane = lIdx % 16;
+        half16 bFrag;
+
+        for (int vGPR = 0; vGPR < 8; ++vGPR) {
+            int cudaChunk = vGPR / 4; // will be 0 or 1
+            int cudaTID = vGPR % 4 + (lane * 4) % 64;
+            uint32_t reg = (cudaChunk == 0) ? b_reg.x : b_reg.y;
+            uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg);
+            if (lane < 8) {
+                bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg);
+                bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg);
+            } else {
+                bFrag[2 * vGPR] = 0.0f;
+                bFrag[2 * vGPR + 1] = 0.0f;
+            }
+        }
+        return bFrag;
+    }
+
+    static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) {
+        const int lIdx = (int)threadIdx.x;
+        float8 cFrag;
+
+        // Loop over the eight vector GPRs.
+        for (int vGPR = 0; vGPR < 8; ++vGPR) {
+            int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use.
+            int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8;
+            int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2;
+            float ctmp0, ctmp1;
+
+            if (cudaChunk == 0) {
+                ctmp0 = bpermute_lane(cudaTID, c_reg.x);
+                ctmp1 = bpermute_lane(cudaTID, c_reg.y);
+            } else { // cudaChunk == 2
+                ctmp0 = bpermute_lane(cudaTID, c_reg.z);
+                ctmp1 = bpermute_lane(cudaTID, c_reg.w);
+            }
+
+            // Select one of the two values based on the thread index's LSB.
+            cFrag[vGPR] = (lIdx & 1) ? ctmp1 : ctmp0;
+
+            // Zero out for lanes 8-15 and 24-31.
+            if ((lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32))
+                cFrag[vGPR] = 0.0f;
+        }
+        return cFrag;
+    }
+
+    static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) {
+        const int lIdx = (int)threadIdx.x;
+        float4::Native_vec_ d_out;
+
+        for (int cChunk = 0; cChunk < 4; ++cChunk) {
+            int r_vGPR = (cChunk / 2) * 4;
+            int add8 = (lIdx & 0x4) ? 8 : 0;
+            int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8;
+            float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]);
+            float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]);
+            float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]);
+            float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]);
+            float val;
+            if (lIdx < 8) {
+                val = d_tmp0;
+            } else if (lIdx < 16) {
+                val = d_tmp1;
+            } else if (lIdx < 24) {
+                val = d_tmp2;
+            } else {
+                val = d_tmp3;
+            }
+            if (cChunk == 0) d_out.x = val;
+            else if (cChunk == 1) d_out.y = val;
+            else if (cChunk == 2) d_out.z = val;
+            else d_out.w = val;
+        }
+        return d_out;
+    }
+
+    float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
+        // Reshuffle from the Nvidia-like register layout to the AMD layout:
+        half16 aFrag = shuffle_a(a_reg);
+        half16 bFrag = shuffle_b(b_reg);
+        float8 cFrag = shuffle_c(c_reg);
+
+        // Call the (built-in) 16x16 MMA instruction. It returns a float8.
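+        // Assumption about the ROCm intrinsic: on gfx11 in wave32 mode,
+        // __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(half16, half16, float8)
+        // returns a float8, and lanes 0-15 and 16-31 must carry identical A/B
+        // fragments, which is why the shuffles above index with lane % 16.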
+        float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aFrag, bFrag, cFrag);
+
+        // Unshuffle back into the float4 result layout Nvidia expects.
+        float4::Native_vec_ d_out = shuffle_d(dFrag);
+
+        return d_out;
+    }
+
+    float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) {
+        // Reshuffle from the Nvidia-like register layout to the AMD layout:
+        half16 aFrag = shuffle_a(a_reg);
+        half16 bFrag = shuffle_b(b_reg);
+        float8 cFrag = shuffle_c(c_reg);
+
+        // Call the (built-in) 16x16 MMA instruction. It returns a float8.
+        float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aFrag, bFrag, cFrag);
+
+        // Unshuffle back into the float4 result layout Nvidia expects.
+        float4::Native_vec_ d_out = shuffle_d(dFrag);
+
+        return d_out;
+    }
 }
diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs
index d3e0b7b..525ae15 100644
--- a/ptx/src/pass/insert_post_saturation.rs
+++ b/ptx/src/pass/insert_post_saturation.rs
@@ -198,7 +198,8 @@ fn run_instruction<'input>(
         | ast::Instruction::Vote { .. }
         | ast::Instruction::ReduxSync { .. }
         | ast::Instruction::GridDepControl { .. }
-        | ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)),
+        | ast::Instruction::LdMatrix { .. }
+        | ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)),
         ast::Instruction::Add {
             data:
                 ast::ArithDetails::Float(ast::ArithFloat {
diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs
index 229e179..d365e29 100644
--- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs
+++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs
@@ -1856,7 +1856,8 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes {
         | ast::Instruction::Vote { .. }
         | ast::Instruction::ReduxSync { .. }
         | ast::Instruction::GridDepControl { .. }
-        | ast::Instruction::LdMatrix { .. } => InstructionModes::none(),
+        | ast::Instruction::LdMatrix { .. }
+        | ast::Instruction::Mma { .. } => InstructionModes::none(),
         ast::Instruction::Add {
             data: ast::ArithDetails::Integer(_),
             ..
diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs
index 0677345..144f5e6 100644
--- a/ptx/src/pass/llvm/emit.rs
+++ b/ptx/src/pass/llvm/emit.rs
@@ -533,7 +533,8 @@ impl<'a> MethodEmitContext<'a> {
             | ast::Instruction::Vote { .. }
             | ast::Instruction::Nanosleep { .. }
             | ast::Instruction::ReduxSync { .. }
-            | ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()),
+            | ast::Instruction::LdMatrix { .. }
+            | ast::Instruction::Mma { .. } => return Err(error_unreachable()),
         }
     }
 
diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs
index 19e16e7..f7c976e 100644
--- a/ptx/src/pass/replace_instructions_with_functions.rs
+++ b/ptx/src/pass/replace_instructions_with_functions.rs
@@ -351,6 +351,35 @@ fn run_instruction<'input>(
             let name = "sqrt_rn_ftz_f32";
             to_call(resolver, fn_declarations, name.into(), i)?
         }
+        i @ ptx_parser::Instruction::Mma {
+            data:
+                ast::MmaDetails {
+                    alayout,
+                    blayout,
+                    dtype_scalar,
+                    atype_scalar,
+                    btype_scalar,
+                    ctype_scalar,
+                },
+            ..
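+            // `..` ignores the remaining `arguments` field; the destructured
+            // layout/type fields below only select the name of the
+            // __zluda_ptx_impl_ function that `to_call` redirects this MMA to.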
+        } => {
+            let name = format!(
+                "mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}",
+                match alayout {
+                    ast::MatrixLayout::Row => "row",
+                    ast::MatrixLayout::Col => "col",
+                },
+                match blayout {
+                    ast::MatrixLayout::Row => "row",
+                    ast::MatrixLayout::Col => "col",
+                },
+                scalar_to_ptx_name(dtype_scalar),
+                scalar_to_ptx_name(atype_scalar),
+                scalar_to_ptx_name(btype_scalar),
+                scalar_to_ptx_name(ctype_scalar),
+            );
+            to_call(resolver, fn_declarations, name.into(), i)?
+        }
         i @ ptx_parser::Instruction::Sqrt {
             data:
                 ast::RcpData {
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs
index 04e1b6a..1bc622c 100644
--- a/ptx_parser/src/ast.rs
+++ b/ptx_parser/src/ast.rs
@@ -3,8 +3,8 @@ use super::{
     StateSpace, VectorPrefix,
 };
 use crate::{
-    FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction,
-    ShiftDirection, ShuffleMode, VoteMode,
+    FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError,
+    PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode,
 };
 use bitflags::bitflags;
 use derive_more::Display;
@@ -724,6 +724,27 @@ ptx_parser_macros::generate_instruction_type!(
         },
         GridDepControl {
             data: crate::GridDepControlAction,
+        },
+        Mma {
+            data: MmaDetails,
+            arguments: {
+                dst: {
+                    repr: T,
+                    type: { data.dtype() },
+                },
+                src1: {
+                    repr: T,
+                    type: { data.atype() },
+                },
+                src2: {
+                    repr: T,
+                    type: { data.btype() },
+                },
+                src3: {
+                    repr: T,
+                    type: { data.ctype() },
+                }
+            }
         }
     }
 );
@@ -2381,3 +2402,27 @@ pub struct ReduxSyncData {
     pub type_: ScalarType,
     pub reduction: Reduction,
 }
+
+pub struct MmaDetails {
+    pub alayout: MatrixLayout,
+    pub blayout: MatrixLayout,
+    pub dtype_scalar: ScalarType,
+    pub atype_scalar: ScalarType,
+    pub btype_scalar: ScalarType,
+    pub ctype_scalar: ScalarType,
+}
+
+impl MmaDetails {
+    pub fn dtype(&self) -> Type {
+        Type::Vector(4, ScalarType::F32)
+    }
+    pub fn atype(&self) -> Type {
+        Type::Vector(4, ScalarType::U32)
+    }
+    pub fn btype(&self) -> Type {
+        Type::Vector(2, ScalarType::U32)
+    }
+    pub fn ctype(&self) -> Type {
+        Type::Vector(4, ScalarType::F32)
+    }
+}
diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs
index 5389118..a4f9080 100644
--- a/ptx_parser/src/lib.rs
+++ b/ptx_parser/src/lib.rs
@@ -1862,6 +1862,9 @@ derive_parser!(
     #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
     pub enum MatrixNumber { }
 
+    #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
+    pub enum MatrixLayout { }
+
     // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
     mov{.vec}.type d, a => {
         Instruction::Mov {
@@ -3905,6 +3908,29 @@ derive_parser!(
         }
     }
     .action: GridDepControlAction = { .launch_dependents, .wait };
+
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
+    mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => {
+        if dtype != ScalarType::F32 || ctype != ScalarType::F32 {
+            state.errors.push(PtxError::Todo);
+        }
+        Instruction::Mma {
+            data: MmaDetails {
+                alayout,
+                blayout,
+                dtype_scalar: dtype,
+                atype_scalar: ScalarType::BF16,
+                btype_scalar: ScalarType::BF16,
+                ctype_scalar: ctype,
+            },
+            arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c }
+        }
+    }
+
+    .alayout: MatrixLayout = {.row};
+    .blayout: MatrixLayout = {.col};
+    .ctype: ScalarType = {.f16, .f32};
+    .dtype: ScalarType = {.f16, .f32};
 );
 
 #[cfg(test)]
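For context, a usage sketch (not part of the diff): a CUDA kernel whose generated PTX contains exactly the shape the new rule parses, mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32. The kernel name and per-lane indexing are illustrative assumptions; the mnemonic, operand counts, and register classes follow the PTX ISA and match the uint4/uint2/float4 signature of the implementation above.

// Hypothetical smoke test, assuming one full 32-lane warp and in-bounds
// buffers: each lane feeds 4 x .b32 of A (8 bf16 values), 2 x .b32 of B
// (4 bf16 values) and 4 x .f32 of C, then stores its 4 x .f32 D fragment.
#include <cstdint>

__global__ void mma_bf16_smoke(const uint32_t *a, const uint32_t *b,
                               const float *c, float *d) {
    unsigned t = threadIdx.x; // lane 0-31
    float d0, d1, d2, d3;
    asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
        : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
        : "r"(a[t * 4]), "r"(a[t * 4 + 1]), "r"(a[t * 4 + 2]), "r"(a[t * 4 + 3]),
          "r"(b[t * 2]), "r"(b[t * 2 + 1]),
          "f"(c[t * 4]), "f"(c[t * 4 + 1]), "f"(c[t * 4 + 2]), "f"(c[t * 4 + 3]));
    d[t * 4] = d0;
    d[t * 4 + 1] = d1;
    d[t * 4 + 2] = d2;
    d[t * 4 + 3] = d3;
}

When ZLUDA compiles that PTX, replace_instructions_with_functions rewrites the instruction into a call to __zluda_ptx_impl_mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32, i.e. the shuffle_a/shuffle_b/shuffle_c reshuffles, the gfx11 WMMA builtin, and the shuffle_d unshuffle added earlier in this diff.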