Implement mul24

This commit is contained in:
Joëlle van Essen 2025-03-30 14:33:34 +02:00
parent d704e92c97
commit 4901aba163
No known key found for this signature in database
GPG key ID: 28D3B5CDD4B43882
11 changed files with 293 additions and 39 deletions

View file

@ -36,6 +36,7 @@ use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::{core::*, *};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
use ptx_parser::Mul24Control;
const LLVM_UNNAMED: &CStr = c"";
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
@ -2281,8 +2282,13 @@ impl<'a> MethodEmitContext<'a> {
) -> Result<(), TranslateError> {
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.emit_intrinsic(
c"llvm.amdgcn.mul.u24",
let name_lo = match data.type_ {
ast::ScalarType::U32 => c"llvm.amdgcn.mul.u24",
ast::ScalarType::S32 => c"llvm.amdgcn.mul.i24",
_ => return Err(error_unreachable()),
};
let res_lo = self.emit_intrinsic(
name_lo,
Some(arguments.dst),
Some(&ast::Type::Scalar(data.type_)),
vec![
@ -2290,6 +2296,37 @@ impl<'a> MethodEmitContext<'a> {
(src2, get_scalar_type(self.context, data.type_)),
],
)?;
if data.control == Mul24Control::Hi {
// There is an important difference between NVIDIA's mul24.hi and AMD's mulhi.[ui]24.
// NVIDIA: Returns bits 47..16 of the 64-bit result
// AMD: Returns bits 63..32 of the 64-bit result
// Hence we need to compute both hi and lo, shift the results and add them together to replicate NVIDIA's mul24.hi
let name_hi = match data.type_ {
ast::ScalarType::U32 => c"llvm.amdgcn.mulhi.u24",
ast::ScalarType::S32 => c"llvm.amdgcn.mulhi.i24",
_ => return Err(error_unreachable()),
};
let res_hi = self.emit_intrinsic(
name_hi,
None,
Some(&ast::Type::Scalar(data.type_)),
vec![
(src1, get_scalar_type(self.context, data.type_)),
(src2, get_scalar_type(self.context, data.type_)),
],
)?;
let shift_number = unsafe { LLVMConstInt(LLVMInt32TypeInContext(self.context), 16, 0) };
let res_lo_shr = unsafe {
LLVMBuildLShr(self.builder, res_lo, shift_number, c"res_lo_shr".as_ptr())
};
let res_hi_shl =
unsafe { LLVMBuildShl(self.builder, res_hi, shift_number, c"res_hi_shl".as_ptr()) };
self.resolver
.with_result(arguments.dst, |dst: *const i8| unsafe {
LLVMBuildAdd(self.builder, res_lo_shr, res_hi_shl, dst)
});
}
Ok(())
}

View file

@ -1,34 +0,0 @@
; Kernel @mul24: loads one i32 through the pointer value passed in %"32",
; multiplies it by the constant 2 via llvm.amdgcn.mul.u24, and stores the
; i32 product through the pointer value passed in %"33".
define amdgpu_kernel void @mul24(ptr addrspace(4) byref(i64) %"32", ptr addrspace(4) byref(i64) %"33") #0 {
; Local slots: %"34"/%"35" hold the two i64 pointer values, %"36" the loaded
; operand, %"37" the product.
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca i64, align 8, addrspace(5)
%"36" = alloca i32, align 4, addrspace(5)
%"37" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"31"
"31": ; preds = %1
; Copy both kernel arguments into their local slots.
%"38" = load i64, ptr addrspace(4) %"32", align 4
store i64 %"38", ptr addrspace(5) %"34", align 4
%"39" = load i64, ptr addrspace(4) %"33", align 4
store i64 %"39", ptr addrspace(5) %"35", align 4
; Turn the first argument into a pointer and load the i32 operand from it.
%"41" = load i64, ptr addrspace(5) %"34", align 4
%"46" = inttoptr i64 %"41" to ptr
%"40" = load i32, ptr %"46", align 4
store i32 %"40", ptr addrspace(5) %"36", align 4
; operand * 2 using the 24-bit unsigned multiply intrinsic.
%"43" = load i32, ptr addrspace(5) %"36", align 4
%"42" = call i32 @llvm.amdgcn.mul.u24(i32 %"43", i32 2)
store i32 %"42", ptr addrspace(5) %"37", align 4
; Turn the second argument into a pointer and store the product through it.
%"44" = load i64, ptr addrspace(5) %"35", align 4
%"45" = load i32, ptr addrspace(5) %"37", align 4
%"47" = inttoptr i64 %"44" to ptr
store i32 %"45", ptr %"47", align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.amdgcn.mul.u24(i32, i32) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View file

@ -0,0 +1,46 @@
; Kernel @mul24_hi_s32: loads one i32 value x through the pointer value passed
; in %"32", negates it, and combines llvm.amdgcn.mul.i24(-x, x) and
; llvm.amdgcn.mulhi.i24(-x, x) as (lo >> 16) + (hi << 16). The i32 result is
; stored through the pointer value passed in %"33".
define amdgpu_kernel void @mul24_hi_s32(ptr addrspace(4) byref(i64) %"32", ptr addrspace(4) byref(i64) %"33") #0 {
; Local slots: %"34"/%"35" hold the two i64 pointer values, %"36" holds x,
; %"37" holds -x, %"38" holds the final result.
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca i64, align 8, addrspace(5)
%"36" = alloca i32, align 4, addrspace(5)
%"37" = alloca i32, align 4, addrspace(5)
%"38" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"31"
"31": ; preds = %1
; Copy both kernel arguments into their local slots.
%"39" = load i64, ptr addrspace(4) %"32", align 4
store i64 %"39", ptr addrspace(5) %"34", align 4
%"40" = load i64, ptr addrspace(4) %"33", align 4
store i64 %"40", ptr addrspace(5) %"35", align 4
; Turn the first argument into a pointer and load x from it.
%"42" = load i64, ptr addrspace(5) %"34", align 4
%"50" = inttoptr i64 %"42" to ptr
%"41" = load i32, ptr %"50", align 4
store i32 %"41", ptr addrspace(5) %"36", align 4
; Negate: %"43" = 0 - x.
%"44" = load i32, ptr addrspace(5) %"36", align 4
%"43" = sub i32 0, %"44"
store i32 %"43", ptr addrspace(5) %"37", align 4
; Both multiply intrinsics are called with the same operands (-x, x).
%"46" = load i32, ptr addrspace(5) %"37", align 4
%"47" = load i32, ptr addrspace(5) %"36", align 4
%"45" = call i32 @llvm.amdgcn.mul.i24(i32 %"46", i32 %"47")
%2 = call i32 @llvm.amdgcn.mulhi.i24(i32 %"46", i32 %"47")
; Upper 16 bits of the low result plus lower 16 bits of the high result,
; shifted into place. The lshr is a logical shift, so no sign bits leak in.
%res_lo_shr = lshr i32 %"45", 16
%res_hi_shl = shl i32 %2, 16
; NOTE(review): the name %"451" looks like an automatic rename of a second
; value requested under the already-used name "45" - confirm against the emitter.
%"451" = add i32 %res_lo_shr, %res_hi_shl
store i32 %"451", ptr addrspace(5) %"38", align 4
; Turn the second argument into a pointer and store the result through it.
%"48" = load i64, ptr addrspace(5) %"35", align 4
%"49" = load i32, ptr addrspace(5) %"38", align 4
%"51" = inttoptr i64 %"48" to ptr
store i32 %"49", ptr %"51", align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.amdgcn.mul.i24(i32, i32) #1
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.amdgcn.mulhi.i24(i32, i32) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View file

@ -0,0 +1,42 @@
; Kernel @mul24_hi_u32: loads one i32 value x through the pointer value passed
; in %"31" and combines llvm.amdgcn.mul.u24(x, x) and
; llvm.amdgcn.mulhi.u24(x, x) as (lo >> 16) + (hi << 16). The i32 result is
; stored through the pointer value passed in %"32".
define amdgpu_kernel void @mul24_hi_u32(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 {
; Local slots: %"33"/%"34" hold the two i64 pointer values, %"35" holds x,
; %"36" holds the final result.
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca i32, align 4, addrspace(5)
%"36" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"30"
"30": ; preds = %1
; Copy both kernel arguments into their local slots.
%"37" = load i64, ptr addrspace(4) %"31", align 4
store i64 %"37", ptr addrspace(5) %"33", align 4
%"38" = load i64, ptr addrspace(4) %"32", align 4
store i64 %"38", ptr addrspace(5) %"34", align 4
; Turn the first argument into a pointer and load x from it.
%"40" = load i64, ptr addrspace(5) %"33", align 4
%"46" = inttoptr i64 %"40" to ptr
%"39" = load i32, ptr %"46", align 4
store i32 %"39", ptr addrspace(5) %"35", align 4
; x is loaded twice because it is used as both multiplication operands.
%"42" = load i32, ptr addrspace(5) %"35", align 4
%"43" = load i32, ptr addrspace(5) %"35", align 4
%"41" = call i32 @llvm.amdgcn.mul.u24(i32 %"42", i32 %"43")
%2 = call i32 @llvm.amdgcn.mulhi.u24(i32 %"42", i32 %"43")
; Upper 16 bits of the low result plus lower 16 bits of the high result,
; shifted into place.
%res_lo_shr = lshr i32 %"41", 16
%res_hi_shl = shl i32 %2, 16
; NOTE(review): the name %"411" looks like an automatic rename of a second
; value requested under the already-used name "41" - confirm against the emitter.
%"411" = add i32 %res_lo_shr, %res_hi_shl
store i32 %"411", ptr addrspace(5) %"36", align 4
; Turn the second argument into a pointer and store the result through it.
%"44" = load i64, ptr addrspace(5) %"34", align 4
%"45" = load i32, ptr addrspace(5) %"36", align 4
%"47" = inttoptr i64 %"44" to ptr
store i32 %"45", ptr %"47", align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.amdgcn.mul.u24(i32, i32) #1
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.amdgcn.mulhi.u24(i32, i32) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View file

@ -0,0 +1,39 @@
; Kernel @mul24_lo_s32: loads one i32 value x through the pointer value passed
; in %"32", negates it, and stores llvm.amdgcn.mul.i24(-x, x) through the
; pointer value passed in %"33".
define amdgpu_kernel void @mul24_lo_s32(ptr addrspace(4) byref(i64) %"32", ptr addrspace(4) byref(i64) %"33") #0 {
; Local slots: %"34"/%"35" hold the two i64 pointer values, %"36" holds x,
; %"37" holds -x, %"38" holds the product.
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca i64, align 8, addrspace(5)
%"36" = alloca i32, align 4, addrspace(5)
%"37" = alloca i32, align 4, addrspace(5)
%"38" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"31"
"31": ; preds = %1
; Copy both kernel arguments into their local slots.
%"39" = load i64, ptr addrspace(4) %"32", align 4
store i64 %"39", ptr addrspace(5) %"34", align 4
%"40" = load i64, ptr addrspace(4) %"33", align 4
store i64 %"40", ptr addrspace(5) %"35", align 4
; Turn the first argument into a pointer and load x from it.
%"42" = load i64, ptr addrspace(5) %"34", align 4
%"50" = inttoptr i64 %"42" to ptr
%"41" = load i32, ptr %"50", align 4
store i32 %"41", ptr addrspace(5) %"36", align 4
; Negate: %"43" = 0 - x.
%"44" = load i32, ptr addrspace(5) %"36", align 4
%"43" = sub i32 0, %"44"
store i32 %"43", ptr addrspace(5) %"37", align 4
; Signed 24-bit multiply of (-x, x); only the single "lo" intrinsic is needed.
%"46" = load i32, ptr addrspace(5) %"37", align 4
%"47" = load i32, ptr addrspace(5) %"36", align 4
%"45" = call i32 @llvm.amdgcn.mul.i24(i32 %"46", i32 %"47")
store i32 %"45", ptr addrspace(5) %"38", align 4
; Turn the second argument into a pointer and store the product through it.
%"48" = load i64, ptr addrspace(5) %"35", align 4
%"49" = load i32, ptr addrspace(5) %"38", align 4
%"51" = inttoptr i64 %"48" to ptr
store i32 %"49", ptr %"51", align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.amdgcn.mul.i24(i32, i32) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View file

@ -0,0 +1,35 @@
; Kernel @mul24_lo_u32: loads one i32 value x through the pointer value passed
; in %"31" and stores llvm.amdgcn.mul.u24(x, x) through the pointer value
; passed in %"32".
define amdgpu_kernel void @mul24_lo_u32(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 {
; Local slots: %"33"/%"34" hold the two i64 pointer values, %"35" holds x,
; %"36" holds the product.
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca i32, align 4, addrspace(5)
%"36" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"30"
"30": ; preds = %1
; Copy both kernel arguments into their local slots.
%"37" = load i64, ptr addrspace(4) %"31", align 4
store i64 %"37", ptr addrspace(5) %"33", align 4
%"38" = load i64, ptr addrspace(4) %"32", align 4
store i64 %"38", ptr addrspace(5) %"34", align 4
; Turn the first argument into a pointer and load x from it.
%"40" = load i64, ptr addrspace(5) %"33", align 4
%"46" = inttoptr i64 %"40" to ptr
%"39" = load i32, ptr %"46", align 4
store i32 %"39", ptr addrspace(5) %"35", align 4
; x is loaded twice because it is used as both multiplication operands.
%"42" = load i32, ptr addrspace(5) %"35", align 4
%"43" = load i32, ptr addrspace(5) %"35", align 4
%"41" = call i32 @llvm.amdgcn.mul.u24(i32 %"42", i32 %"43")
store i32 %"41", ptr addrspace(5) %"36", align 4
; Turn the second argument into a pointer and store the product through it.
%"44" = load i64, ptr addrspace(5) %"34", align 4
%"45" = load i32, ptr addrspace(5) %"36", align 4
%"47" = inttoptr i64 %"44" to ptr
store i32 %"45", ptr %"47", align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.amdgcn.mul.u24(i32, i32) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View file

@ -53,7 +53,26 @@ test_ptx!(mov, [1u64], [1u64]);
test_ptx!(mul_lo, [1u64], [2u64]);
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
test_ptx!(add, [1u64], [2u64]);
test_ptx!(mul24, [10u32], [20u32]);
// mul24 tests. All four use the same 24-bit input x = 0x755555; the second
// macro argument appears to be the kernel input and the third the expected
// output (matching the neighbouring test_ptx! invocations).
// Reference product: x * x = 0x35C7_1C23_8E39 (48 bits).
//
// mul24.lo.u32(x, x): low 32 bits of the product = 0x1C238E39.
test_ptx!(
mul24_lo_u32,
[0b01110101_01010101_01010101u32],
[0b00011100_00100011_10001110_00111001u32]
);
// mul24.hi.u32(x, x): bits 47..16 of the product = 0x35C71C23.
test_ptx!(
mul24_hi_u32,
[0b01110101_01010101_01010101u32],
[0b00110101_11000111_00011100_00100011u32]
);
// The signed kernels negate the input and compute mul24(-x, x), so the
// product is -0x35C7_1C23_8E39.
// mul24.lo.s32(-x, x): low 32 bits = -0x1C238E39.
test_ptx!(
mul24_lo_s32,
[0b01110101_01010101_01010101i32],
[-0b0011100_00100011_10001110_00111001i32]
);
// mul24.hi.s32(-x, x): bits 47..16 of the negative product = -0x35C71C24
// (one more in magnitude than the unsigned case, because the discarded low
// 16 bits are non-zero and the shift rounds towards negative infinity).
test_ptx!(
mul24_hi_s32,
[0b01110101_01010101_01010101i32],
[-0b0110101_11000111_00011100_00100100i32]
);
test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]);
test_ptx!(setp_gt, [f32::NAN, 1f32], [1f32]);
test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]);

View file

@ -0,0 +1,24 @@
// Test kernel for mul24.hi.s32: loads one s32 value x from the address given
// in `input`, computes mul24.hi.s32(-x, x), and stores the s32 result to the
// address given in `output`.
.version 6.5
.target sm_30
.address_size 64
.visible .entry mul24_hi_s32(
.param .s64 input,
.param .s64 output
)
{
// in_addr/out_addr hold the two parameter values used as addresses;
// temp = x, temp2 = -x, temp3 = result.
.reg .s64 in_addr;
.reg .s64 out_addr;
.reg .s32 temp;
.reg .s32 temp2;
.reg .s32 temp3;
ld.param.s64 in_addr, [input];
ld.param.s64 out_addr, [output];
ld.s32 temp, [in_addr];
// Negate the input so the multiplication has a negative product.
neg.s32 temp2, temp;
mul24.hi.s32 temp3, temp2, temp;
st.s32 [out_addr], temp3;
ret;
}

View file

@ -2,7 +2,7 @@
.target sm_30
.address_size 64
.visible .entry mul24(
.visible .entry mul24_hi_u32(
.param .u64 input,
.param .u64 output
)
@ -16,7 +16,7 @@
ld.param.u64 out_addr, [output];
ld.u32 temp, [in_addr];
mul24.lo.u32 temp2, temp, 2;
mul24.hi.u32 temp2, temp, temp;
st.u32 [out_addr], temp2;
ret;
}

View file

@ -0,0 +1,24 @@
// Test kernel for mul24.lo.s32: loads one s32 value x from the address given
// in `input`, computes mul24.lo.s32(-x, x), and stores the s32 result to the
// address given in `output`.
.version 6.5
.target sm_30
.address_size 64
.visible .entry mul24_lo_s32(
.param .s64 input,
.param .s64 output
)
{
// in_addr/out_addr hold the two parameter values used as addresses;
// temp = x, temp2 = -x, temp3 = result.
.reg .s64 in_addr;
.reg .s64 out_addr;
.reg .s32 temp;
.reg .s32 temp2;
.reg .s32 temp3;
ld.param.s64 in_addr, [input];
ld.param.s64 out_addr, [output];
ld.s32 temp, [in_addr];
// Negate the input so the multiplication has a negative product.
neg.s32 temp2, temp;
mul24.lo.s32 temp3, temp2, temp;
st.s32 [out_addr], temp3;
ret;
}

View file

@ -0,0 +1,22 @@
// Test kernel for mul24.lo.u32: loads one u32 value x from the address given
// in `input`, computes mul24.lo.u32(x, x), and stores the u32 result to the
// address given in `output`.
.version 6.5
.target sm_30
.address_size 64
.visible .entry mul24_lo_u32(
.param .u64 input,
.param .u64 output
)
{
// in_addr/out_addr hold the two parameter values used as addresses;
// temp = x, temp2 = result.
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u32 temp;
.reg .u32 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u32 temp, [in_addr];
// The input is used as both operands, i.e. it is squared.
mul24.lo.u32 temp2, temp, temp;
st.u32 [out_addr], temp2;
ret;
}