Implement ldmatrix (#503)
Some checks failed
ZLUDA / Build (Linux) (push) Has been cancelled
ZLUDA / Build (Windows) (push) Has been cancelled
ZLUDA / Build AMD GPU unit tests (push) Has been cancelled
ZLUDA / Run AMD GPU unit tests (push) Has been cancelled

This commit is contained in:
Violet 2025-09-09 19:31:56 -07:00 committed by GitHub
commit 7b5fdb30c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 607 additions and 11 deletions

Binary file not shown.

View file

@ -9,6 +9,7 @@
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/hip_fp8.h>
#define SHARED_SPACE __attribute__((address_space(3)))
#define CONSTANT_SPACE __attribute__((address_space(4)))
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
@ -577,4 +578,50 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
REDUX_SYNC_IMPL(add);
REDUX_SYNC_IMPL(min);
REDUX_SYNC_IMPL(max);
__device__ inline static uint32_t load_single_matrix(void SHARED_SPACE * lds_address, uint32_t warp_offset)
{
uint32_t laneid = __zluda_ptx_impl_sreg_laneid();
int32_t row_address = __builtin_amdgcn_ds_bpermute((int32_t)(warp_offset + (laneid / 4U)) << 2U, (int32_t)lds_address);
uint32_t matrix_cell_address = (uint32_t)row_address + ((laneid % 4) * 4);
return *((uint32_t SHARED_SPACE*)matrix_cell_address);
}
__device__ inline static uint32_t load_single_matrix_trans(void SHARED_SPACE * lds_address, uint32_t warp_offset)
{
uint32_t laneid = __zluda_ptx_impl_sreg_laneid();
int32_t row_address_lo = __builtin_amdgcn_ds_bpermute((int32_t)(warp_offset + ((laneid % 4U) * 2)) << 2U, (int32_t)lds_address);
uint32_t address_lo = (uint32_t)row_address_lo + ((laneid / 4) * 2);
uint16_t lo = *((uint16_t SHARED_SPACE*)address_lo);
int32_t row_address_hi = __builtin_amdgcn_ds_bpermute((int32_t)(warp_offset + ((laneid % 4U) * 2) + 1) << 2U, (int32_t)lds_address);
uint32_t address_hi = (uint32_t)row_address_hi + ((laneid / 4) * 2);
uint16_t hi = *((uint16_t SHARED_SPACE*)address_hi);
return std::bit_cast<uint32_t>(ushort2::Native_vec_ { lo, hi });
}
uint2::Native_vec_ FUNC(ldmatrix_m8n8_x2_b16)(void SHARED_SPACE * address)
{
uint32_t x0 = load_single_matrix(address, 0);
uint32_t x1 = load_single_matrix(address, 8);
return uint2::Native_vec_{x0, x1};
}
uint4::Native_vec_ FUNC(ldmatrix_m8n8_x4_b16)(void SHARED_SPACE * address)
{
uint32_t x0 = load_single_matrix(address, 0);
uint32_t x1 = load_single_matrix(address, 8);
uint32_t x2 = load_single_matrix(address, 16);
uint32_t x3 = load_single_matrix(address, 24);
return uint4::Native_vec_{x0, x1, x2, x3};
}
uint4::Native_vec_ FUNC(ldmatrix_m8n8_x4_trans_b16)(void SHARED_SPACE * address)
{
uint32_t x0 = load_single_matrix_trans(address, 0);
uint32_t x1 = load_single_matrix_trans(address, 8);
uint32_t x2 = load_single_matrix_trans(address, 16);
uint32_t x3 = load_single_matrix_trans(address, 24);
return uint4::Native_vec_{x0, x1, x2, x3};
}
}

View file

@ -196,7 +196,8 @@ fn run_instruction<'input>(
| ast::Instruction::Trap {}
| ast::Instruction::Xor { .. }
| ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. } => result.push(Statement::Instruction(instruction)),
| ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)),
ast::Instruction::Add {
data:
ast::ArithDetails::Float(ast::ArithFloat {

View file

@ -1854,7 +1854,8 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
| ast::Instruction::Nanosleep { .. }
| ast::Instruction::AtomCas { .. }
| ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. } => InstructionModes::none(),
| ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => InstructionModes::none(),
ast::Instruction::Add {
data: ast::ArithDetails::Integer(_),
..

View file

@ -336,7 +336,7 @@ fn get_input_argument_type(
state_space: ast::StateSpace,
) -> Result<LLVMTypeRef, TranslateError> {
match state_space {
ast::StateSpace::ParamEntry => {
ast::StateSpace::ParamEntry | ast::StateSpace::Shared => {
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
}
ast::StateSpace::Reg => get_type(context, v_type),
@ -527,7 +527,8 @@ impl<'a> MethodEmitContext<'a> {
| ast::Instruction::ShflSync { .. }
| ast::Instruction::Vote { .. }
| ast::Instruction::Nanosleep { .. }
| ast::Instruction::ReduxSync { .. } => return Err(error_unreachable()),
| ast::Instruction::ReduxSync { .. }
| ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()),
}
}
@ -693,6 +694,7 @@ impl<'a> MethodEmitContext<'a> {
}
}
(ast::Type::Vector(..), ast::Type::Scalar(..))
| (ast::Type::Scalar(..), ast::Type::Vector(..))
| (ast::Type::Scalar(..), ast::Type::Array(..))
| (ast::Type::Array(..), ast::Type::Scalar(..)) => {
let dst_type = get_type(self.context, to_type)?;
@ -780,11 +782,6 @@ impl<'a> MethodEmitContext<'a> {
panic!()
}
}
for (_, space) in data.input_arguments.iter() {
if *space != ast::StateSpace::Reg {
panic!()
}
}
}
let name = match &*arguments.return_arguments {
[dst] => self.resolver.get_or_add_raw(*dst),

View file

@ -466,6 +466,30 @@ fn run_instruction<'input>(
i,
)?
}
i @ ptx_parser::Instruction::LdMatrix { data, .. } => {
let shape = match data.shape {
ptx_parser::MatrixShape::M8n8 => "m8n8",
ptx_parser::MatrixShape::M16n16 => return Err(error_todo()),
};
let number = match data.number {
ptx_parser::MatrixNumber::X2 => "x2",
ptx_parser::MatrixNumber::X4 => "x4",
ptx_parser::MatrixNumber::X1 => return Err(error_todo()),
};
let trans = if data.transpose { "_trans" } else { "" };
let type_str = match data.type_ {
ptx_parser::ScalarType::B16 => "b16",
ptx_parser::ScalarType::B8 => return Err(error_todo()),
_ => return Err(error_unreachable()),
};
to_call(
resolver,
fn_declarations,
format!("ldmatrix_{}_{}{}_{}", shape, number, trans, type_str).into(),
i,
)?
}
i => i,
})
}

View file

@ -0,0 +1,99 @@
@values_g = addrspace(1) global [64 x i32] [i32 340, i32 122, i32 527, i32 693, i32 958, i32 394, i32 668, i32 432, i32 646, i32 354, i32 761, i32 449, i32 252, i32 778, i32 218, i32 800, i32 656, i32 493, i32 659, i32 787, i32 672, i32 203, i32 343, i32 845, i32 318, i32 286, i32 206, i32 253, i32 194, i32 489, i32 29, i32 323, i32 7, i32 619, i32 998, i32 930, i32 773, i32 749, i32 172, i32 465, i32 937, i32 96, i32 88, i32 621, i32 909, i32 298, i32 283, i32 286, i32 779, i32 290, i32 429, i32 930, i32 25, i32 687, i32 423, i32 200, i32 918, i32 10, i32 515, i32 248, i32 158, i32 911, i32 270, i32 459]
@values_s = external addrspace(3) global [64 x i32], align 16
declare hidden <2 x i32> @__zluda_ptx_impl_ldmatrix_m8n8_x2_b16(ptr addrspace(3)) #0
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @ldmatrix(ptr addrspace(4) byref(i64) %"55") #1 {
%"56" = alloca i64, align 8, addrspace(5)
%"57" = alloca i32, align 4, addrspace(5)
%"58" = alloca i64, align 8, addrspace(5)
%"59" = alloca i64, align 8, addrspace(5)
%"60" = alloca i32, align 4, addrspace(5)
%"61" = alloca i64, align 8, addrspace(5)
%"62" = alloca i32, align 4, addrspace(5)
%"63" = alloca i32, align 4, addrspace(5)
%"64" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"52"
"52": ; preds = %1
%"65" = load i64, ptr addrspace(4) %"55", align 8
store i64 %"65", ptr addrspace(5) %"56", align 8
%"40" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"53"
"53": ; preds = %"52"
store i32 %"40", ptr addrspace(5) %"57", align 4
%"68" = load i32, ptr addrspace(5) %"57", align 4
%"67" = zext i32 %"68" to i64
store i64 %"67", ptr addrspace(5) %"58", align 8
store i64 ptrtoint (ptr addrspace(1) @values_g to i64), ptr addrspace(5) %"59", align 8
%"71" = load i64, ptr addrspace(5) %"58", align 8
%"72" = load i64, ptr addrspace(5) %"59", align 8
%2 = mul i64 %"71", 4
%"70" = add i64 %2, %"72"
store i64 %"70", ptr addrspace(5) %"59", align 8
%"74" = load i64, ptr addrspace(5) %"59", align 8
%"106" = inttoptr i64 %"74" to ptr addrspace(1)
%"105" = load i32, ptr addrspace(1) %"106", align 4
store i32 %"105", ptr addrspace(5) %"62", align 4
store i32 ptrtoint (ptr addrspace(3) @values_s to i32), ptr addrspace(5) %"60", align 4
%"77" = load i32, ptr addrspace(5) %"57", align 4
%"78" = load i32, ptr addrspace(5) %"60", align 4
%3 = mul i32 %"77", 4
%"108" = add i32 %3, %"78"
store i32 %"108", ptr addrspace(5) %"60", align 4
%"79" = load i32, ptr addrspace(5) %"60", align 4
%"80" = load i32, ptr addrspace(5) %"62", align 4
%"110" = inttoptr i32 %"79" to ptr addrspace(3)
store i32 %"80", ptr addrspace(3) %"110", align 4
%"81" = load i64, ptr addrspace(5) %"59", align 8
%"112" = inttoptr i64 %"81" to ptr addrspace(1)
%"44" = getelementptr inbounds i8, ptr addrspace(1) %"112", i64 128
%"113" = load i32, ptr addrspace(1) %"44", align 4
store i32 %"113", ptr addrspace(5) %"62", align 4
%"83" = load i32, ptr addrspace(5) %"60", align 4
%"114" = inttoptr i32 %"83" to ptr addrspace(3)
%"46" = getelementptr inbounds i8, ptr addrspace(3) %"114", i64 128
%"84" = load i32, ptr addrspace(5) %"62", align 4
store i32 %"84", ptr addrspace(3) %"46", align 4
store i64 ptrtoint (ptr addrspace(3) @values_s to i64), ptr addrspace(5) %"61", align 8
%"87" = load i64, ptr addrspace(5) %"61", align 8
%4 = inttoptr i64 %"87" to ptr addrspace(3)
%"86" = addrspacecast ptr addrspace(3) %4 to ptr
store ptr %"86", ptr addrspace(5) %"61", align 8
%"89" = load i64, ptr addrspace(5) %"58", align 8
%"90" = load i64, ptr addrspace(5) %"61", align 8
%5 = mul i64 %"89", 16
%"117" = add i64 %5, %"90"
store i64 %"117", ptr addrspace(5) %"61", align 8
%"91" = load i64, ptr addrspace(5) %"61", align 8
%"119" = inttoptr i64 %"91" to ptr addrspace(3)
%"48" = call <2 x i32> @__zluda_ptx_impl_ldmatrix_m8n8_x2_b16(ptr addrspace(3) %"119")
%"120" = extractelement <2 x i32> %"48", i8 0
%"121" = extractelement <2 x i32> %"48", i8 1
store i32 %"120", ptr addrspace(5) %"63", align 4
store i32 %"121", ptr addrspace(5) %"64", align 4
%"95" = load i64, ptr addrspace(5) %"58", align 8
%"96" = load i64, ptr addrspace(5) %"56", align 8
%6 = mul i64 %"95", 8
%"94" = add i64 %6, %"96"
store i64 %"94", ptr addrspace(5) %"56", align 8
%"97" = load i64, ptr addrspace(5) %"56", align 8
%"98" = load i32, ptr addrspace(5) %"63", align 4
%"122" = inttoptr i64 %"97" to ptr
store i32 %"98", ptr %"122", align 4
%"99" = load i64, ptr addrspace(5) %"56", align 8
%"123" = inttoptr i64 %"99" to ptr
%"51" = getelementptr inbounds i8, ptr %"123", i64 4
%"100" = load i32, ptr addrspace(5) %"64", align 4
store i32 %"100", ptr %"51", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,164 @@
@values_g = addrspace(1) global [256 x i16] [i16 1340, i16 122, i16 527, i16 693, i16 958, i16 394, i16 668, i16 432, i16 646, i16 354, i16 761, i16 449, i16 252, i16 778, i16 218, i16 800, i16 656, i16 493, i16 659, i16 787, i16 672, i16 203, i16 343, i16 845, i16 318, i16 286, i16 206, i16 253, i16 194, i16 489, i16 29, i16 323, i16 7, i16 619, i16 998, i16 930, i16 773, i16 749, i16 172, i16 465, i16 937, i16 96, i16 88, i16 621, i16 909, i16 298, i16 283, i16 286, i16 779, i16 290, i16 429, i16 930, i16 25, i16 687, i16 423, i16 200, i16 918, i16 10, i16 515, i16 248, i16 158, i16 911, i16 270, i16 459, i16 5832, i16 3864, i16 7868, i16 6538, i16 3898, i16 8685, i16 356, i16 3655, i16 3398, i16 8529, i16 2866, i16 1432, i16 4078, i16 1674, i16 498, i16 1124, i16 1576, i16 6490, i16 9895, i16 2152, i16 9668, i16 7349, i16 1948, i16 6239, i16 7944, i16 7630, i16 9699, i16 1957, i16 3360, i16 2291, i16 3832, i16 7370, i16 2683, i16 7465, i16 3107, i16 9822, i16 2510, i16 1642, i16 3240, i16 8860, i16 4935, i16 1935, i16 9328, i16 5164, i16 2759, i16 4816, i16 1049, i16 725, i16 9774, i16 5110, i16 5071, i16 8047, i16 7267, i16 7716, i16 1622, i16 9645, i16 6382, i16 1210, i16 2742, i16 2248, i16 6789, i16 5282, i16 5653, i16 5407, i16 29007, i16 29415, i16 25313, i16 -21396, i16 -15994, i16 21119, i16 -9745, i16 -22804, i16 -1897, i16 13898, i16 -7216, i16 20222, i16 31469, i16 -30937, i16 -676, i16 -4865, i16 4232, i16 -9793, i16 -11737, i16 -21717, i16 14011, i16 12369, i16 -8916, i16 13717, i16 12500, i16 -6672, i16 -31251, i16 -8199, i16 20956, i16 4977, i16 -16240, i16 19215, i16 -18975, i16 -1326, i16 -20663, i16 -29785, i16 15886, i16 14343, i16 966, i16 3529, i16 6132, i16 -8396, i16 -5346, i16 10303, i16 -22494, i16 2064, i16 22282, i16 -3981, i16 25824, i16 31442, i16 -8521, i16 -14400, i16 -24621, i16 30984, i16 -7274, i16 13983, i16 -23474, i16 11128, i16 -18559, i16 4030, i16 -29438, i16 22884, i16 16603, i16 -5437, i16 23344, i16 23968, i16 6079, i16 19797, i16 19404, i16 -30128, i16 12579, i16 13888, i16 -25241, i16 -25296, i16 3729, i16 -22983, i16 24354, i16 14074, i16 -15135, i16 -11424, i16 -28936, i16 -17901, i16 7766, i16 20953, i16 -24581, i16 -18991, i16 3574, i16 -29309, i16 -24581, i16 3027, i16 -14649, i16 -21970, i16 414, i16 8664, i16 -3920, i16 21636, i16 18637, i16 -26803, i16 -23932, i16 -12453, i16 -7462, i16 -3651, i16 22010, i16 -3233, i16 -2100, i16 -20960, i16 5954, i16 30529, i16 -8346, i16 -10708, i16 -8246, i16 -26229, i16 635, i16 28677, i16 29798, i16 13493, i16 14433, i16 16122, i16 6113, i16 29240, i16 22212, i16 16841, i16 -30165, i16 29695, i16 2862, i16 26519, i16 -13825, i16 -26725]
@values_s = external addrspace(3) global [256 x i16], align 16
declare hidden <4 x i32> @__zluda_ptx_impl_ldmatrix_m8n8_x4_trans_b16(ptr addrspace(3)) #0
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @ldmatrix_trans(ptr addrspace(4) byref(i64) %"86") #1 {
%"87" = alloca i64, align 8, addrspace(5)
%"88" = alloca i32, align 4, addrspace(5)
%"89" = alloca i64, align 8, addrspace(5)
%"90" = alloca i64, align 8, addrspace(5)
%"91" = alloca i32, align 4, addrspace(5)
%"92" = alloca i64, align 8, addrspace(5)
%"93" = alloca i32, align 4, addrspace(5)
%"94" = alloca i64, align 8, addrspace(5)
%"95" = alloca i64, align 8, addrspace(5)
%"96" = alloca i32, align 4, addrspace(5)
%"97" = alloca i32, align 4, addrspace(5)
%"98" = alloca i32, align 4, addrspace(5)
%"99" = alloca i32, align 4, addrspace(5)
%"100" = alloca <2 x i16>, align 4, addrspace(5)
%"101" = alloca <2 x i16>, align 4, addrspace(5)
%"102" = alloca <2 x i16>, align 4, addrspace(5)
%"103" = alloca <2 x i16>, align 4, addrspace(5)
%"108" = alloca i1, align 1, addrspace(5)
br label %1
1: ; preds = %0
br label %"83"
"83": ; preds = %1
%"104" = load i64, ptr addrspace(4) %"86", align 8
store i64 %"104", ptr addrspace(5) %"87", align 8
%"52" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"84"
"84": ; preds = %"83"
store i32 %"52", ptr addrspace(5) %"88", align 4
%"107" = load i32, ptr addrspace(5) %"88", align 4
%"106" = zext i32 %"107" to i64
store i64 %"106", ptr addrspace(5) %"89", align 8
%"110" = load i32, ptr addrspace(5) %"88", align 4
%2 = icmp uge i32 %"110", 32
store i1 %2, ptr addrspace(5) %"108", align 1
%"111" = load i1, ptr addrspace(5) %"108", align 1
br i1 %"111", label %"12", label %"32"
"32": ; preds = %"84"
store i64 ptrtoint (ptr addrspace(1) @values_g to i64), ptr addrspace(5) %"90", align 8
%"114" = load i64, ptr addrspace(5) %"89", align 8
%"115" = load i64, ptr addrspace(5) %"90", align 8
%3 = mul i64 %"114", 16
%"113" = add i64 %3, %"115"
store i64 %"113", ptr addrspace(5) %"90", align 8
%"116" = load i64, ptr addrspace(5) %"90", align 8
%"166" = inttoptr i64 %"116" to ptr addrspace(1)
%"55" = load <2 x i64>, ptr addrspace(1) %"166", align 16
%"167" = extractelement <2 x i64> %"55", i8 0
%"168" = extractelement <2 x i64> %"55", i8 1
store i64 %"167", ptr addrspace(5) %"94", align 8
store i64 %"168", ptr addrspace(5) %"95", align 8
store i32 ptrtoint (ptr addrspace(3) @values_s to i32), ptr addrspace(5) %"91", align 4
%"121" = load i32, ptr addrspace(5) %"88", align 4
%"122" = load i32, ptr addrspace(5) %"91", align 4
%4 = mul i32 %"121", 16
%"170" = add i32 %4, %"122"
store i32 %"170", ptr addrspace(5) %"91", align 4
%"123" = load i64, ptr addrspace(5) %"94", align 8
%"124" = load i64, ptr addrspace(5) %"95", align 8
%5 = insertelement <2 x i64> undef, i64 %"123", i8 0
%"57" = insertelement <2 x i64> %5, i64 %"124", i8 1
%"125" = load i32, ptr addrspace(5) %"91", align 4
%"174" = inttoptr i32 %"125" to ptr addrspace(3)
store <2 x i64> %"57", ptr addrspace(3) %"174", align 16
store i32 ptrtoint (ptr addrspace(3) @values_s to i32), ptr addrspace(5) %"91", align 4
%"128" = load i32, ptr addrspace(5) %"88", align 4
%"129" = load i32, ptr addrspace(5) %"91", align 4
%6 = mul i32 %"128", 16
%"176" = add i32 %6, %"129"
store i32 %"176", ptr addrspace(5) %"91", align 4
%"130" = load i32, ptr addrspace(5) %"91", align 4
%"178" = inttoptr i32 %"130" to ptr addrspace(3)
%"59" = call <4 x i32> @__zluda_ptx_impl_ldmatrix_m8n8_x4_trans_b16(ptr addrspace(3) %"178")
%"131" = extractelement <4 x i32> %"59", i8 0
%"132" = extractelement <4 x i32> %"59", i8 1
%"133" = extractelement <4 x i32> %"59", i8 2
%"134" = extractelement <4 x i32> %"59", i8 3
store i32 %"131", ptr addrspace(5) %"96", align 4
store i32 %"132", ptr addrspace(5) %"97", align 4
store i32 %"133", ptr addrspace(5) %"98", align 4
store i32 %"134", ptr addrspace(5) %"99", align 4
%"136" = load i64, ptr addrspace(5) %"89", align 8
%"137" = load i64, ptr addrspace(5) %"87", align 8
%7 = mul i64 %"136", 32
%"135" = add i64 %7, %"137"
store i64 %"135", ptr addrspace(5) %"87", align 8
%"139" = load i32, ptr addrspace(5) %"96", align 4
%"138" = bitcast i32 %"139" to <2 x i16>
store <2 x i16> %"138", ptr addrspace(5) %"100", align 4
%"140" = load <2 x i16>, ptr addrspace(5) %"100", align 4
%"61" = extractelement <2 x i16> %"140", i8 0
%"141" = load i64, ptr addrspace(5) %"87", align 8
%"180" = inttoptr i64 %"141" to ptr
store i16 %"61", ptr %"180", align 2
%"142" = load i64, ptr addrspace(5) %"87", align 8
%"181" = inttoptr i64 %"142" to ptr
%"63" = getelementptr inbounds i8, ptr %"181", i64 4
%"143" = load <2 x i16>, ptr addrspace(5) %"100", align 4
%"64" = extractelement <2 x i16> %"143", i8 1
store i16 %"64", ptr %"63", align 2
%"145" = load i32, ptr addrspace(5) %"97", align 4
%"144" = bitcast i32 %"145" to <2 x i16>
store <2 x i16> %"144", ptr addrspace(5) %"101", align 4
%"146" = load i64, ptr addrspace(5) %"87", align 8
%"183" = inttoptr i64 %"146" to ptr
%"66" = getelementptr inbounds i8, ptr %"183", i64 8
%"147" = load <2 x i16>, ptr addrspace(5) %"101", align 4
%"67" = extractelement <2 x i16> %"147", i8 0
store i16 %"67", ptr %"66", align 2
%"148" = load i64, ptr addrspace(5) %"87", align 8
%"184" = inttoptr i64 %"148" to ptr
%"69" = getelementptr inbounds i8, ptr %"184", i64 12
%"149" = load <2 x i16>, ptr addrspace(5) %"101", align 4
%"70" = extractelement <2 x i16> %"149", i8 1
store i16 %"70", ptr %"69", align 2
%"151" = load i32, ptr addrspace(5) %"98", align 4
%"150" = bitcast i32 %"151" to <2 x i16>
store <2 x i16> %"150", ptr addrspace(5) %"102", align 4
%"152" = load i64, ptr addrspace(5) %"87", align 8
%"186" = inttoptr i64 %"152" to ptr
%"72" = getelementptr inbounds i8, ptr %"186", i64 16
%"153" = load <2 x i16>, ptr addrspace(5) %"102", align 4
%"73" = extractelement <2 x i16> %"153", i8 0
store i16 %"73", ptr %"72", align 2
%"154" = load i64, ptr addrspace(5) %"87", align 8
%"187" = inttoptr i64 %"154" to ptr
%"75" = getelementptr inbounds i8, ptr %"187", i64 20
%"155" = load <2 x i16>, ptr addrspace(5) %"102", align 4
%"76" = extractelement <2 x i16> %"155", i8 1
store i16 %"76", ptr %"75", align 2
%"157" = load i32, ptr addrspace(5) %"99", align 4
%"156" = bitcast i32 %"157" to <2 x i16>
store <2 x i16> %"156", ptr addrspace(5) %"103", align 4
%"158" = load i64, ptr addrspace(5) %"87", align 8
%"189" = inttoptr i64 %"158" to ptr
%"78" = getelementptr inbounds i8, ptr %"189", i64 24
%"159" = load <2 x i16>, ptr addrspace(5) %"103", align 4
%"79" = extractelement <2 x i16> %"159", i8 0
store i16 %"79", ptr %"78", align 2
%"160" = load i64, ptr addrspace(5) %"87", align 8
%"190" = inttoptr i64 %"160" to ptr
%"81" = getelementptr inbounds i8, ptr %"190", i64 28
%"161" = load <2 x i16>, ptr addrspace(5) %"103", align 4
%"82" = extractelement <2 x i16> %"161", i8 1
store i16 %"82", ptr %"81", align 2
br label %"12"
"12": ; preds = %"32", %"84"
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,50 @@
.version 6.5
.target sm_75
.address_size 64
.global .u32 values_g[64] = {
340, 122, 527, 693, 958, 394, 668, 432, 646, 354, 761, 449, 252, 778, 218, 800,
656, 493, 659, 787, 672, 203, 343, 845, 318, 286, 206, 253, 194, 489, 29, 323,
7, 619, 998, 930, 773, 749, 172, 465, 937, 96, 88, 621, 909, 298, 283, 286,
779, 290, 429, 930, 25, 687, 423, 200, 918, 10, 515, 248, 158, 911, 270, 459
};
.shared .align 16 .u32 values_s[64];
.visible .entry ldmatrix(
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .u32 tid;
.reg .u64 tid_64;
.reg .u64 values_g_addr;
.reg .b32 values_s_addr;
.reg .b64 values_s_addr_64;
.reg .u32 temp;
.reg .u32 x<2>;
ld.param.u64 out_addr, [output];
mov.b32 tid, %tid.x;
cvt.u64.u32 tid_64, tid;
mov.b64 values_g_addr, values_g;
mad.lo.u64 values_g_addr, tid_64, 4, values_g_addr;
ld.global.b32 temp, [values_g_addr];
mov.b32 values_s_addr, values_s;
mad.lo.u32 values_s_addr, tid, 4, values_s_addr;
st.shared.b32 [values_s_addr], temp;
ld.global.b32 temp, [values_g_addr+128];
st.shared.b32 [values_s_addr+128], temp;
mov.b64 values_s_addr_64, values_s;
cvta.shared.u64 values_s_addr_64, values_s_addr_64;
mad.lo.u64 values_s_addr_64, tid_64, 16, values_s_addr_64;
ldmatrix.sync.aligned.m8n8.x2.b16 {x0, x1}, [values_s_addr_64];
mad.lo.u64 out_addr, tid_64, 8, out_addr;
st.u32 [out_addr], x0;
st.u32 [out_addr+4], x1;
ret;
}

View file

@ -0,0 +1,97 @@
.version 6.5
.target sm_75
.address_size 64
.global .u16 values_g[256] = {
// matrix 1
1340, 122, 527, 693, 958, 394, 668, 432,
646, 354, 761, 449, 252, 778, 218, 800,
656, 493, 659, 787, 672, 203, 343, 845,
318, 286, 206, 253, 194, 489, 29, 323,
7, 619, 998, 930, 773, 749, 172, 465,
937, 96, 88, 621, 909, 298, 283, 286,
779, 290, 429, 930, 25, 687, 423, 200,
918, 10, 515, 248, 158, 911, 270, 459,
// matrix 2
5832, 3864, 7868, 6538, 3898, 8685, 356, 3655,
3398, 8529, 2866, 1432, 4078, 1674, 498, 1124,
1576, 6490, 9895, 2152, 9668, 7349, 1948, 6239,
7944, 7630, 9699, 1957, 3360, 2291, 3832, 7370,
2683, 7465, 3107, 9822, 2510, 1642, 3240, 8860,
4935, 1935, 9328, 5164, 2759, 4816, 1049, 725,
9774, 5110, 5071, 8047, 7267, 7716, 1622, 9645,
6382, 1210, 2742, 2248, 6789, 5282, 5653, 5407,
// matrix 3
29007, 29415, 25313, 44140, 49542, 21119, 55791, 42732,
63639, 13898, 58320, 20222, 31469, 34599, 64860, 60671,
4232, 55743, 53799, 43819, 14011, 12369, 56620, 13717,
12500, 58864, 34285, 57337, 20956, 4977, 49296, 19215,
46561, 64210, 44873, 35751, 15886, 14343, 966, 3529,
6132, 57140, 60190, 10303, 43042, 2064, 22282, 61555,
25824, 31442, 57015, 51136, 40915, 30984, 58262, 13983,
42062, 11128, 46977, 4030, 36098, 22884, 16603, 60099,
// matrix 4
23344, 23968, 6079, 19797, 19404, 35408, 12579, 13888,
40295, 40240, 3729, 42553, 24354, 14074, 50401, 54112,
36600, 47635, 7766, 20953, 40955, 46545, 3574, 36227,
40955, 3027, 50887, 43566, 414, 8664, 61616, 21636,
18637, 38733, 41604, 53083, 58074, 61885, 22010, 62303,
63436, 44576, 5954, 30529, 57190, 54828, 57290, 39307,
635, 28677, 29798, 13493, 14433, 16122, 6113, 29240,
22212, 16841, 35371, 29695, 2862, 26519, 51711, 38811
};
.shared .align 16 .u16 values_s[256];
.visible .entry ldmatrix_trans(
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .u32 tid;
.reg .u64 tid_64;
.reg .u64 values_g_addr;
.reg .b32 values_s_addr;
.reg .b64 values_s_addr_64;
.reg .u32 temp;
.reg .u64 temp_64_<2>;
.reg .b32 x<4>;
.reg .v2.b16 x16_<4>;
ld.param.u64 out_addr, [output];
mov.b32 tid, %tid.x;
cvt.u64.u32 tid_64, tid;
.reg .pred not_first_warp;
setp.ge.u32 not_first_warp, tid, 32;
@not_first_warp bra END;
// copy constants from global to shared
mov.b64 values_g_addr, values_g;
mad.lo.u64 values_g_addr, tid_64, 16, values_g_addr;
ld.global.v2.b64 {temp_64_0, temp_64_1}, [values_g_addr];
mov.b32 values_s_addr, values_s;
mad.lo.u32 values_s_addr, tid, 16, values_s_addr;
st.shared.v2.b64 [values_s_addr], {temp_64_0, temp_64_1};
mov.b32 values_s_addr, values_s;
mad.lo.u32 values_s_addr, tid, 16, values_s_addr;
ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {x0, x1, x2, x3}, [values_s_addr];
mad.lo.u64 out_addr, tid_64, 32, out_addr;
mov.b32 x16_0, x0;
st.b16 [out_addr], x16_0.x;
st.b16 [out_addr+4], x16_0.y;
mov.b32 x16_1, x1;
st.b16 [out_addr+8], x16_1.x;
st.b16 [out_addr+12], x16_1.y;
mov.b32 x16_2, x2;
st.b16 [out_addr+16], x16_2.x;
st.b16 [out_addr+20], x16_2.y;
mov.b32 x16_3, x3;
st.b16 [out_addr+24], x16_3.x;
st.b16 [out_addr+28], x16_3.y;
END:
ret;
}

View file

@ -487,6 +487,40 @@ test_ptx_warp!(
752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32
]
);
test_ptx_warp!(
ldmatrix,
[
340u32, 7u32, 122u32, 619u32, 527u32, 998u32, 693u32, 930u32, 958u32, 773u32, 394u32,
749u32, 668u32, 172u32, 432u32, 465u32, 646u32, 937u32, 354u32, 96u32, 761u32, 88u32,
449u32, 621u32, 252u32, 909u32, 778u32, 298u32, 218u32, 283u32, 800u32, 286u32, 656u32,
779u32, 493u32, 290u32, 659u32, 429u32, 787u32, 930u32, 672u32, 25u32, 203u32, 687u32,
343u32, 423u32, 845u32, 200u32, 318u32, 918u32, 286u32, 10u32, 206u32, 515u32, 253u32,
248u32, 194u32, 158u32, 489u32, 911u32, 29u32, 270u32, 323u32, 459u32
]
);
test_ptx_warp!(
ldmatrix_trans,
[
1340, 646, 5832, 3398, 29007, 63639, 23344, 40295, 656, 318, 1576, 7944, 4232, 12500,
36600, 40955, 7, 937, 2683, 4935, 46561, 6132, 18637, 63436, 779, 918, 9774, 6382, 25824,
42062, 635, 22212, 122, 354, 3864, 8529, 29415, 13898, 23968, 40240, 493, 286, 6490, 7630,
55743, 58864, 47635, 3027, 619, 96, 7465, 1935, 64210, 57140, 38733, 44576, 290, 10, 5110,
1210, 31442, 11128, 28677, 16841, 527, 761, 7868, 2866, 25313, 58320, 6079, 3729, 659, 206,
9895, 9699, 53799, 34285, 7766, 50887, 998, 88, 3107, 9328, 44873, 60190, 41604, 5954, 429,
515, 5071, 2742, 57015, 46977, 29798, 35371, 693, 449, 6538, 1432, 44140, 20222, 19797,
42553, 787, 253, 2152, 1957, 43819, 57337, 20953, 43566, 930, 621, 9822, 5164, 35751,
10303, 53083, 30529, 930, 248, 8047, 2248, 51136, 4030, 13493, 29695, 958, 252, 3898, 4078,
49542, 31469, 19404, 24354, 672, 194, 9668, 3360, 14011, 20956, 40955, 414, 773, 909, 2510,
2759, 15886, 43042, 58074, 57190, 25, 158, 7267, 6789, 40915, 36098, 14433, 2862, 394, 778,
8685, 1674, 21119, 34599, 35408, 14074, 203, 489, 7349, 2291, 12369, 4977, 46545, 8664,
749, 298, 1642, 4816, 14343, 2064, 61885, 54828, 687, 911, 7716, 5282, 30984, 22884, 16122,
26519, 668, 218, 356, 498, 55791, 64860, 12579, 50401, 343, 29, 1948, 3832, 56620, 49296,
3574, 61616, 172, 283, 3240, 1049, 966, 22282, 22010, 57290, 423, 270, 1622, 5653, 58262,
16603, 6113, 51711, 432, 800, 3655, 1124, 42732, 60671, 13888, 54112, 845, 323, 6239, 7370,
13717, 19215, 36227, 21636, 465, 286, 8860, 725, 3529, 61555, 62303, 39307, 200, 459, 9645,
5407, 13983, 60099, 29240, 38811
]
);
struct DisplayError<T: Debug> {
err: T,

View file

@ -3,8 +3,8 @@ use super::{
StateSpace, VectorPrefix,
};
use crate::{
FunnelShiftMode, Mul24Control, PtxError, PtxParserState, Reduction, ShiftDirection,
ShuffleMode, VoteMode,
FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction,
ShiftDirection, ShuffleMode, VoteMode,
};
use bitflags::bitflags;
use derive_more::Display;
@ -707,6 +707,20 @@ ptx_parser_macros::generate_instruction_type!(
type: { Type::Scalar(ScalarType::U32) },
}
}
},
LdMatrix {
type: data.get_loaded_type(),
data: LdMatrixDetails,
arguments<T>: {
dst: {
repr: T,
relaxed_type_check: true,
},
src: {
repr: T,
space: { data.state_space },
}
}
}
}
);
@ -1460,6 +1474,47 @@ pub struct LdDetails {
pub non_coherent: bool,
}
impl MatrixNumber {
fn get(&self) -> u8 {
match self {
MatrixNumber::X1 => 1,
MatrixNumber::X2 => 2,
MatrixNumber::X4 => 4,
}
}
}
#[derive(Copy, Clone)]
pub struct LdMatrixDetails {
pub shape: MatrixShape,
pub number: MatrixNumber,
pub transpose: bool,
pub state_space: StateSpace,
pub type_: ScalarType,
}
impl LdMatrixDetails {
pub fn new(
shape: MatrixShape,
number: MatrixNumber,
transpose: bool,
ss: Option<StateSpace>,
type_: ScalarType,
) -> Self {
Self {
shape,
number,
transpose,
state_space: ss.unwrap_or(StateSpace::Shared),
type_,
}
}
pub fn get_loaded_type(&self) -> Type {
let count = self.number.get();
Type::Vector(count, ScalarType::B32)
}
}
pub struct StData {
pub qualifier: LdStQualifier,
pub state_space: StateSpace,

View file

@ -1856,6 +1856,12 @@ derive_parser!(
Ballot
}
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum MatrixShape { }
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum MatrixNumber { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => {
Instruction::Mov {
@ -3870,6 +3876,27 @@ derive_parser!(
// redux.sync.op{.abs.}{.NaN}.f32 dst, src, membermask;
// .op = { .min, .max }
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p] => {
let data = LdMatrixDetails::new(shape, num, trans, ss, type_);
Instruction::LdMatrix {
data,
arguments: LdMatrixArgs {
dst: r,
src: p
}
}
}
// ldmatrix.sync.aligned.m8n16.num{.ss}.dst_fmt.src_fmt r, [p];
// ldmatrix.sync.aligned.m16n16.num.trans{.ss}.dst_fmt.src_fmt r, [p];
.shape: MatrixShape = {.m8n8, .m16n16};
.num: MatrixNumber = {.x1, .x2, .x4};
.ss: StateSpace = {.shared{::cta}};
.type: ScalarType = {.b16, .b8};
// .dst_fmt = { .b8x16 };
// .src_fmt = { .b6x16_p32, .b4x16_p64 };
);
#[cfg(test)]