From b7f3a647d7add50bb618d3b6bce7b0669dfb0fcc Mon Sep 17 00:00:00 2001 From: Violet Date: Thu, 4 Sep 2025 17:29:20 -0700 Subject: [PATCH] Implement `fma.rn.fn.bf16x2` (#496) * Add fma bf16x2 test * Implement fma.rn.fn.bf16x2 * cargo fmt --- ptx/src/pass/llvm/emit.rs | 2 +- ptx/src/pass/llvm/mod.rs | 2 +- ptx/src/test/ll/fma_bf16x2.ll | 51 +++++++++++++++++++++++++++ ptx/src/test/spirv_run/fma_bf16x2.ptx | 25 +++++++++++++ ptx/src/test/spirv_run/mod.rs | 5 +++ ptx_parser/src/lib.rs | 22 ++++++++++-- 6 files changed, 102 insertions(+), 5 deletions(-) create mode 100644 ptx/src/test/ll/fma_bf16x2.ll create mode 100644 ptx/src/test/spirv_run/fma_bf16x2.ptx diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 2403d90..dc9befa 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -3133,7 +3133,7 @@ impl std::fmt::Display for LLVMTypeDisplay { ast::ScalarType::F64 => write!(f, "f64"), ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"), ast::ScalarType::F16x2 => write!(f, "v2f16"), - ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"), + ptx_parser::ScalarType::BF16x2 => write!(f, "v2bf16"), } } } diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index 24f790e..40781fc 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -172,7 +172,7 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR ast::ScalarType::U16x2 => todo!(), ast::ScalarType::S16x2 => todo!(), ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) }, - ast::ScalarType::BF16x2 => todo!(), + ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) }, } } diff --git a/ptx/src/test/ll/fma_bf16x2.ll b/ptx/src/test/ll/fma_bf16x2.ll new file mode 100644 index 0000000..ff7a638 --- /dev/null +++ b/ptx/src/test/ll/fma_bf16x2.ll @@ -0,0 +1,51 @@ +define amdgpu_kernel void @fma_bf16x2(ptr addrspace(4) byref(i64) %"39", ptr addrspace(4) byref(i64) %"40") #0 { + %"41" = alloca i64, align 8, addrspace(5) + %"42" = alloca i64, align 8, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + %"45" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"38" + +"38": ; preds = %1 + %"46" = load i64, ptr addrspace(4) %"39", align 8 + store i64 %"46", ptr addrspace(5) %"41", align 8 + %"47" = load i64, ptr addrspace(4) %"40", align 8 + store i64 %"47", ptr addrspace(5) %"42", align 8 + %"49" = load i64, ptr addrspace(5) %"41", align 8 + %"60" = inttoptr i64 %"49" to ptr + %"48" = load i32, ptr %"60", align 4 + store i32 %"48", ptr addrspace(5) %"43", align 4 + %"50" = load i64, ptr addrspace(5) %"41", align 8 + %"61" = inttoptr i64 %"50" to ptr + %"35" = getelementptr inbounds i8, ptr %"61", i64 4 + %"51" = load i32, ptr %"35", align 4 + store i32 %"51", ptr addrspace(5) %"44", align 4 + %"52" = load i64, ptr addrspace(5) %"41", align 8 + %"62" = inttoptr i64 %"52" to ptr + %"37" = getelementptr inbounds i8, ptr %"62", i64 8 + %"53" = load i32, ptr %"37", align 4 + store i32 %"53", ptr addrspace(5) %"45", align 4 + %"55" = load i32, ptr addrspace(5) %"43", align 4 + %"56" = load i32, ptr addrspace(5) %"44", align 4 + %"57" = load i32, ptr addrspace(5) %"45", align 4 + %"64" = bitcast i32 %"55" to <2 x bfloat> + %"65" = bitcast i32 %"56" to <2 x bfloat> + %"66" = bitcast i32 %"57" to <2 x bfloat> + %"63" = call <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat> %"64", <2 x bfloat> %"65", <2 x bfloat> %"66") + %"54" = bitcast <2 x bfloat> %"63" to i32 + store i32 %"54", ptr addrspace(5) %"43", align 4 + %"58" = load i64, ptr addrspace(5) %"42", align 8 + %"59" = load i32, ptr addrspace(5) %"43", align 4 + %"67" = inttoptr i64 %"58" to ptr + store i32 %"59", ptr %"67", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat>, <2 x bfloat>, <2 x bfloat>) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/ptx/src/test/spirv_run/fma_bf16x2.ptx b/ptx/src/test/spirv_run/fma_bf16x2.ptx new file mode 100644 index 0000000..ac112bd --- /dev/null +++ b/ptx/src/test/spirv_run/fma_bf16x2.ptx @@ -0,0 +1,25 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry fma_bf16x2( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 temp1; + .reg .b32 temp2; + .reg .b32 temp3; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 temp1, [in_addr]; + ld.b32 temp2, [in_addr+4]; + ld.b32 temp3, [in_addr+8]; + fma.rn.bf16x2 temp1, temp1, temp2, temp3; + st.b32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 39e58a5..a0758ff 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -166,6 +166,11 @@ test_ptx!(and, [6u32, 3u32], [2u32]); test_ptx!(selp, [100u16, 200u16], [200u16]); test_ptx!(selp_true, [100u16, 200u16], [100u16]); test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]); +test_ptx!( + fma_bf16x2, + [0x40004040, 0x40404080, 0x40A04040], + [0x41304170] +); test_ptx!(shared_variable, [513u64], [513u64]); test_ptx!(shared_ptr_32, [513u64], [513u64]); test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]); diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 6e4e167..261849e 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -2712,14 +2712,30 @@ derive_parser!( arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } } } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16 }; //fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c; //fma.rnd{.ftz}.relu.f16 d, a, b, c; //fma.rnd{.ftz}.relu.f16x2 d, a, b, c; //fma.rnd{.relu}.bf16 d, a, b, c; - //fma.rnd{.relu}.bf16x2 d, a, b, c; - //fma.rnd.oob.{relu}.type d, a, b, c; + fma.rnd{.relu}.bf16x2 d, a, b, c => { + if relu { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: bf16x2, + rounding: rnd.into(), + flush_to_zero: None, + saturate: false, + is_fusable: false + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } .rnd: RawRoundingMode = { .rn }; - ScalarType = { .f16 }; + ScalarType = { .bf16x2 }; + //fma.rnd.oob.{relu}.type d, a, b, c; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub