Add support for cvt_rn_bf16x2_f32 (#501)

Violet authored 2025-09-08 17:41:24 -07:00, committed by GitHub
commit d81456a549
6 changed files with 145 additions and 5 deletions

@@ -1646,9 +1646,39 @@ impl<'a> MethodEmitContext<'a> {
}
};
let src = self.resolver.value(arguments.src)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
llvm_fn(self.builder, src, dst_type, dst)
});
if let Some(src2) = arguments.src2 {
let packed_type = get_scalar_type(
self.context,
data.to
.packed_type()
.ok_or_else(|| error_mismatched_type())?,
);
let src2 = self.resolver.value(src2)?;
self.resolver.with_result(arguments.dst, |dst| {
let vec = unsafe {
LLVMBuildInsertElement(
self.builder,
LLVMGetPoison(dst_type),
llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32),
LLVM_UNNAMED.as_ptr(),
)
};
unsafe {
LLVMBuildInsertElement(
self.builder,
vec,
llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32),
dst,
)
}
})
} else {
self.resolver.with_result(arguments.dst, |dst| unsafe {
llvm_fn(self.builder, src, dst_type, dst)
})
};
Ok(())
}
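
Note: in the two-operand path above, the converted src is written to element 1 of the <2 x bfloat> and the converted src2 to element 0, which after the bitcast to b32 (little-endian) places src in the upper 16 bits and src2 in the lower 16 bits, matching the PTX semantics of cvt.rn.bf16x2.f32 d, a, b. A minimal host-side sketch of that packing, assuming the two bf16 bit patterns have already been produced (pack_bf16x2 is a hypothetical helper, not part of this change):

// Mirrors the layout of the emitted <2 x bfloat> once it is viewed as a b32:
// `src` (vector element 1) ends up in the upper half, `src2` (element 0) in the lower half.
fn pack_bf16x2(src: u16, src2: u16) -> u32 {
    ((src as u32) << 16) | (src2 as u32)
}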

@@ -0,0 +1,41 @@
define amdgpu_kernel void @cvt_rn_bf16x2_f32(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 {
%"39" = alloca i64, align 8, addrspace(5)
%"40" = alloca i64, align 8, addrspace(5)
%"41" = alloca float, align 4, addrspace(5)
%"42" = alloca float, align 4, addrspace(5)
%"43" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"36"
"36": ; preds = %1
%"44" = load i64, ptr addrspace(4) %"37", align 8
store i64 %"44", ptr addrspace(5) %"39", align 8
%"45" = load i64, ptr addrspace(4) %"38", align 8
store i64 %"45", ptr addrspace(5) %"40", align 8
%"47" = load i64, ptr addrspace(5) %"39", align 8
%"55" = inttoptr i64 %"47" to ptr
%"46" = load float, ptr %"55", align 4
store float %"46", ptr addrspace(5) %"41", align 4
%"48" = load i64, ptr addrspace(5) %"39", align 8
%"56" = inttoptr i64 %"48" to ptr
%"35" = getelementptr inbounds i8, ptr %"56", i64 4
%"49" = load float, ptr %"35", align 4
store float %"49", ptr addrspace(5) %"42", align 4
%"51" = load float, ptr addrspace(5) %"41", align 4
%"52" = load float, ptr addrspace(5) %"42", align 4
%2 = fptrunc float %"51" to bfloat
%3 = insertelement <2 x bfloat> poison, bfloat %2, i32 1
%4 = fptrunc float %"52" to bfloat
%"57" = insertelement <2 x bfloat> %3, bfloat %4, i32 0
%"50" = bitcast <2 x bfloat> %"57" to i32
store i32 %"50", ptr addrspace(5) %"43", align 4
%"53" = load i64, ptr addrspace(5) %"40", align 8
%"54" = load i32, ptr addrspace(5) %"43", align 4
%"58" = inttoptr i64 %"53" to ptr
store i32 %"54", ptr %"58", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

@@ -0,0 +1,25 @@
.version 7.8
.target sm_90
.address_size 64
.visible .entry cvt_rn_bf16x2_f32(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 in_a;
.reg .f32 in_b;
.reg .b32 result;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f32 in_a, [in_addr];
ld.f32 in_b, [in_addr + 4];
cvt.rn.bf16x2.f32 result, in_a, in_b;
st.b32 [out_addr], result;
ret;
}
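
Note: the expected value used by the test below can be reproduced on the host. This is only a sketch, assuming round-to-nearest-even and ignoring NaN handling; f32_to_bf16_rn is a hypothetical helper, not part of the codebase:

// Round-to-nearest-even truncation of an f32 to its bf16 bit pattern
// (no NaN special-casing; sufficient for the finite test inputs).
fn f32_to_bf16_rn(x: f32) -> u16 {
    let bits = x.to_bits();
    let rounded = bits.wrapping_add(0x7FFF + ((bits >> 16) & 1));
    (rounded >> 16) as u16
}

fn main() {
    let a = f32_to_bf16_rn(0.40625); // 0x3ED0
    let b = f32_to_bf16_rn(12.9);    // 0x414E
    // `a` fills the upper half of the b32 result, `b` the lower half.
    assert_eq!(((a as u32) << 16) | (b as u32), 0x3ED0414E);
}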

@@ -200,6 +200,7 @@ test_ptx!(
);
test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]);
test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]);
test_ptx!(cvt_rn_bf16x2_f32, [0.40625, 12.9f32], [0x3ED0414Eu32]);
test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
test_ptx!(

@@ -1174,6 +1174,35 @@ impl ScalarType {
ScalarType::Pred => ScalarKind::Pred,
}
}
pub fn packed_type(&self) -> Option<ScalarType> {
match self {
ScalarType::E4m3x2 => Some(ScalarType::B8),
ScalarType::E5m2x2 => Some(ScalarType::B8),
ScalarType::F16x2 => Some(ScalarType::F16),
ScalarType::BF16x2 => Some(ScalarType::BF16),
ScalarType::U16x2 => Some(ScalarType::U16),
ScalarType::S16x2 => Some(ScalarType::S16),
ScalarType::S16
| ScalarType::BF16
| ScalarType::U32
| ScalarType::S8
| ScalarType::S32
| ScalarType::Pred
| ScalarType::B8
| ScalarType::U64
| ScalarType::B16
| ScalarType::S64
| ScalarType::B32
| ScalarType::U8
| ScalarType::F32
| ScalarType::B64
| ScalarType::B128
| ScalarType::U16
| ScalarType::F64
| ScalarType::F16 => None,
}
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
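
Note: a quick illustration of the new helper's contract, as consumed by the emitter change above (a sketch; assumes ScalarType is the ptx_parser crate's public enum):

use ptx_parser::ScalarType;

fn main() {
    // Packed vector types report their element type; plain scalars report None.
    assert!(matches!(ScalarType::BF16x2.packed_type(), Some(ScalarType::BF16)));
    assert!(matches!(ScalarType::F32.packed_type(), None));
}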
@@ -1945,8 +1974,13 @@ impl CvtDetails {
(RoundingMode::NearestEven, false)
}
};
let dst_size = if dst.packed_type().is_some() {
dst.size_of() / 2
} else {
dst.size_of()
};
let mode = match (dst.kind(), src.kind()) {
(ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) {
(ScalarKind::Float, ScalarKind::Float) => match dst_size.cmp(&src.size_of()) {
Ordering::Less => {
let (rounding, is_integer_rounding) = unwrap_rounding();
CvtMode::FPTruncate {

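Note: the size halving is what routes the packed form into the truncating path: a bf16x2 destination is 4 bytes, so the per-element size becomes 2, which compares as Less against the 4-byte f32 source and selects CvtMode::FPTruncate. A standalone sketch of that comparison (the concrete sizes are spelled out here as assumptions for illustration):

use std::cmp::Ordering;

fn main() {
    // bf16x2 destination: 4 bytes total, halved to the 2-byte bf16 element width.
    let dst_size = 4usize / 2;
    let src_size = 4usize; // f32 source
    // Less => the cvt lowers to a per-element fptrunc, as seen in the LLVM output above.
    assert_eq!(dst_size.cmp(&src_size), Ordering::Less);
}
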
@@ -2442,7 +2442,16 @@ derive_parser!(
// cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a;
// cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b;
// cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a;
// cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b;
cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b => {
if relu || satfinite {
state.errors.push(PtxError::Todo);
}
let data = ast::CvtDetails::new(&mut state.errors, Some(frnd2), false, false, ScalarType::BF16x2, ScalarType::F32);
ast::Instruction::Cvt {
data,
arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) }
}
}
// cvt.rna{.satfinite}.tf32.f32 d, a;
// cvt.frnd2{.relu}.tf32.f32 d, a;
cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => {