mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-11 03:56:37 +00:00
Implement fma.rn.fn.bf16x2
(#496)
* Add fma bf16x2 test
* Implement fma.rn.fn.bf16x2
* cargo fmt
This commit is contained in:
parent
5309065cc1
commit
b7f3a647d7
6 changed files with 102 additions and 5 deletions
|
@ -3133,7 +3133,7 @@ impl std::fmt::Display for LLVMTypeDisplay {
|
|||
ast::ScalarType::F64 => write!(f, "f64"),
|
||||
ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"),
|
||||
ast::ScalarType::F16x2 => write!(f, "v2f16"),
|
||||
ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"),
|
||||
ptx_parser::ScalarType::BF16x2 => write!(f, "v2bf16"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -172,7 +172,7 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
|||
ast::ScalarType::U16x2 => todo!(),
|
||||
ast::ScalarType::S16x2 => todo!(),
|
||||
ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) },
|
||||
ast::ScalarType::BF16x2 => todo!(),
|
||||
ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) },
|
||||
}
|
||||
}
|
||||
|
||||
|
|
51
ptx/src/test/ll/fma_bf16x2.ll
Normal file
51
ptx/src/test/ll/fma_bf16x2.ll
Normal file
|
@ -0,0 +1,51 @@
|
|||
define amdgpu_kernel void @fma_bf16x2(ptr addrspace(4) byref(i64) %"39", ptr addrspace(4) byref(i64) %"40") #0 {
|
||||
%"41" = alloca i64, align 8, addrspace(5)
|
||||
%"42" = alloca i64, align 8, addrspace(5)
|
||||
%"43" = alloca i32, align 4, addrspace(5)
|
||||
%"44" = alloca i32, align 4, addrspace(5)
|
||||
%"45" = alloca i32, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"38"
|
||||
|
||||
"38": ; preds = %1
|
||||
%"46" = load i64, ptr addrspace(4) %"39", align 8
|
||||
store i64 %"46", ptr addrspace(5) %"41", align 8
|
||||
%"47" = load i64, ptr addrspace(4) %"40", align 8
|
||||
store i64 %"47", ptr addrspace(5) %"42", align 8
|
||||
%"49" = load i64, ptr addrspace(5) %"41", align 8
|
||||
%"60" = inttoptr i64 %"49" to ptr
|
||||
%"48" = load i32, ptr %"60", align 4
|
||||
store i32 %"48", ptr addrspace(5) %"43", align 4
|
||||
%"50" = load i64, ptr addrspace(5) %"41", align 8
|
||||
%"61" = inttoptr i64 %"50" to ptr
|
||||
%"35" = getelementptr inbounds i8, ptr %"61", i64 4
|
||||
%"51" = load i32, ptr %"35", align 4
|
||||
store i32 %"51", ptr addrspace(5) %"44", align 4
|
||||
%"52" = load i64, ptr addrspace(5) %"41", align 8
|
||||
%"62" = inttoptr i64 %"52" to ptr
|
||||
%"37" = getelementptr inbounds i8, ptr %"62", i64 8
|
||||
%"53" = load i32, ptr %"37", align 4
|
||||
store i32 %"53", ptr addrspace(5) %"45", align 4
|
||||
%"55" = load i32, ptr addrspace(5) %"43", align 4
|
||||
%"56" = load i32, ptr addrspace(5) %"44", align 4
|
||||
%"57" = load i32, ptr addrspace(5) %"45", align 4
|
||||
%"64" = bitcast i32 %"55" to <2 x bfloat>
|
||||
%"65" = bitcast i32 %"56" to <2 x bfloat>
|
||||
%"66" = bitcast i32 %"57" to <2 x bfloat>
|
||||
%"63" = call <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat> %"64", <2 x bfloat> %"65", <2 x bfloat> %"66")
|
||||
%"54" = bitcast <2 x bfloat> %"63" to i32
|
||||
store i32 %"54", ptr addrspace(5) %"43", align 4
|
||||
%"58" = load i64, ptr addrspace(5) %"42", align 8
|
||||
%"59" = load i32, ptr addrspace(5) %"43", align 4
|
||||
%"67" = inttoptr i64 %"58" to ptr
|
||||
store i32 %"59", ptr %"67", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
|
||||
declare <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat>, <2 x bfloat>, <2 x bfloat>) #1
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
|
25
ptx/src/test/spirv_run/fma_bf16x2.ptx
Normal file
25
ptx/src/test/spirv_run/fma_bf16x2.ptx
Normal file
|
@ -0,0 +1,25 @@
|
|||
.version 7.0
|
||||
.target sm_80
|
||||
.address_size 64
|
||||
|
||||
// Test kernel for `fma.rn.bf16x2`.
// Reads three consecutive 32-bit words from the buffer addressed by `input`,
// combines them with a single fused multiply-add on packed bf16 pairs, and
// writes the one 32-bit result to the buffer addressed by `output`.
.visible .entry fma_bf16x2(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
// Untyped 32-bit registers: each one carries a packed bf16x2 value.
.reg .b32 temp1;
|
||||
.reg .b32 temp2;
|
||||
.reg .b32 temp3;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
// Operands a, b, c sit at byte offsets 0, 4 and 8 of the input buffer.
ld.b32 temp1, [in_addr];
|
||||
ld.b32 temp2, [in_addr+4];
|
||||
ld.b32 temp3, [in_addr+8];
|
||||
// temp1 = temp1 * temp2 + temp3 on each bf16 half; the destination reuses operand a.
fma.rn.bf16x2 temp1, temp1, temp2, temp3;
|
||||
st.b32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
|
@ -166,6 +166,11 @@ test_ptx!(and, [6u32, 3u32], [2u32]);
|
|||
test_ptx!(selp, [100u16, 200u16], [200u16]);
|
||||
test_ptx!(selp_true, [100u16, 200u16], [100u16]);
|
||||
test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
|
||||
// Each 32-bit word packs two bf16 values (high half : low half).
//   a = 0x4000:0x4040 = (2.0, 3.0)
//   b = 0x4040:0x4080 = (3.0, 4.0)
//   c = 0x40A0:0x4040 = (5.0, 3.0)
// Expected a * b + c per half: (11.0, 15.0) = 0x4130:0x4170.
test_ptx!(
|
||||
fma_bf16x2,
|
||||
[0x40004040, 0x40404080, 0x40A04040],
|
||||
[0x41304170]
|
||||
);
|
||||
test_ptx!(shared_variable, [513u64], [513u64]);
|
||||
test_ptx!(shared_ptr_32, [513u64], [513u64]);
|
||||
test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]);
|
||||
|
|
|
@ -2712,14 +2712,30 @@ derive_parser!(
|
|||
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
|
||||
}
|
||||
}
|
||||
.rnd: RawRoundingMode = { .rn };
|
||||
ScalarType = { .f16 };
|
||||
//fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c;
|
||||
//fma.rnd{.ftz}.relu.f16 d, a, b, c;
|
||||
//fma.rnd{.ftz}.relu.f16x2 d, a, b, c;
|
||||
//fma.rnd{.relu}.bf16 d, a, b, c;
|
||||
//fma.rnd{.relu}.bf16x2 d, a, b, c;
|
||||
//fma.rnd.oob.{relu}.type d, a, b, c;
|
||||
// Packed bf16x2 form of fma. The `.relu` modifier is accepted by the grammar
// but not implemented: it records a `PtxError::Todo` and the instruction is
// still produced below.
fma.rnd{.relu}.bf16x2 d, a, b, c => {
|
||||
if relu {
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
ast::Instruction::Fma {
|
||||
data: ast::ArithFloat {
|
||||
type_: bf16x2,
|
||||
rounding: rnd.into(),
|
||||
// This form's syntax has no `.ftz` or `.sat` modifier, so both are fixed.
flush_to_zero: None,
|
||||
saturate: false,
|
||||
is_fusable: false
|
||||
},
|
||||
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
|
||||
}
|
||||
}
|
||||
.rnd: RawRoundingMode = { .rn };
|
||||
ScalarType = { .f16 };
|
||||
ScalarType = { .bf16x2 };
|
||||
//fma.rnd.oob.{relu}.type d, a, b, c;
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue