mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Implement mad.hi.cc (#196)
This commit is contained in:
parent
b695f44c18
commit
76bae5f91b
7 changed files with 153 additions and 33 deletions
|
@ -380,6 +380,7 @@ pub enum Instruction<P: ArgParams> {
|
|||
},
|
||||
MadCC {
|
||||
type_: ScalarType,
|
||||
is_hi: bool,
|
||||
arg: Arg4<P>,
|
||||
},
|
||||
Fma(ArithFloat, Arg4<P>),
|
||||
|
|
|
@ -621,8 +621,8 @@ fn emit_statement(
|
|||
crate::translate::Statement::MadC(MadCDetails { type_, is_hi, arg }) => {
|
||||
emit_inst_madc(ctx, type_, is_hi, &arg)?
|
||||
}
|
||||
crate::translate::Statement::MadCC(MadCCDetails { type_, arg }) => {
|
||||
emit_inst_madcc(ctx, type_, &arg)?
|
||||
crate::translate::Statement::MadCC(MadCCDetails { type_, is_hi, arg }) => {
|
||||
emit_inst_madcc(ctx, type_, is_hi, &arg)?
|
||||
}
|
||||
crate::translate::Statement::AddC(type_, arg) => emit_inst_add_c(ctx, type_, &arg)?,
|
||||
crate::translate::Statement::AddCC(type_, arg) => {
|
||||
|
@ -2079,16 +2079,17 @@ fn emit_inst_mad_lo(
|
|||
)
|
||||
}
|
||||
|
||||
// TODO: support mad.hi.cc
|
||||
fn emit_inst_madcc(
|
||||
ctx: &mut EmitContext,
|
||||
type_: ast::ScalarType,
|
||||
is_hi: bool,
|
||||
arg: &Arg4CarryOut<ExpandedArgParams>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let builder = ctx.builder.get();
|
||||
let src1 = ctx.names.value(arg.src1)?;
|
||||
let src2 = ctx.names.value(arg.src2)?;
|
||||
let mul_result = unsafe { LLVMBuildMul(builder, src1, src2, LLVM_UNNAMED) };
|
||||
let mul_result = if is_hi {
|
||||
emit_inst_mul_hi_impl(ctx, type_, None, arg.src1, arg.src2)?
|
||||
} else {
|
||||
emit_inst_mul_low_impl(ctx, None, arg.src1, arg.src2, LLVMBuildMul)?
|
||||
};
|
||||
emit_inst_addsub_cc_impl(
|
||||
ctx,
|
||||
"add",
|
||||
|
@ -2176,29 +2177,6 @@ fn emit_inst_madc(
|
|||
mul_result,
|
||||
args.src3,
|
||||
)
|
||||
/*
|
||||
let src3 = ctx.names.value(args.src3)?;
|
||||
let add_no_carry = unsafe { LLVMBuildAdd(builder, mul_result, src3, LLVM_UNNAMED) };
|
||||
let carry_flag = ctx.names.value(args.carry_in)?;
|
||||
let llvm_type = get_llvm_type(ctx, &ast::Type::Scalar(type_))?;
|
||||
let carry_flag = unsafe { LLVMBuildZExt(builder, carry_flag, llvm_type, LLVM_UNNAMED) };
|
||||
if let Some(carry_out) = args.carry_out {
|
||||
emit_inst_addsub_cc_impl(
|
||||
ctx,
|
||||
"add",
|
||||
type_,
|
||||
args.dst,
|
||||
carry_out,
|
||||
add_no_carry,
|
||||
carry_flag,
|
||||
)?;
|
||||
} else {
|
||||
ctx.names.register_result(args.dst, |dst| unsafe {
|
||||
LLVMBuildAdd(builder, add_no_carry, carry_flag, dst)
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
*/
|
||||
}
|
||||
|
||||
fn emit_inst_add_c(
|
||||
|
|
|
@ -1516,7 +1516,12 @@ InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#extended-precision-arithmetic-instructions-mad-cc
|
||||
InstMadCC: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"mad" ".lo" ".cc" <type_:IntType3264> <arg:Arg4> => ast::Instruction::MadCC{<>},
|
||||
"mad" ".lo" ".cc" <type_:IntType3264> <arg:Arg4> => {
|
||||
ast::Instruction::MadCC { type_, arg, is_hi: false }
|
||||
},
|
||||
"mad" ".hi" ".cc" <type_:IntType3264> <arg:Arg4> => {
|
||||
ast::Instruction::MadCC { type_, arg, is_hi: true }
|
||||
},
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#extended-precision-arithmetic-instructions-madc
|
||||
|
|
90
ptx/src/test/spirv_run/mad_hi_cc.ll
Normal file
90
ptx/src/test/spirv_run/mad_hi_cc.ll
Normal file
|
@ -0,0 +1,90 @@
|
|||
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
|
||||
target triple = "amdgcn-amd-amdhsa"
|
||||
|
||||
define protected amdgpu_kernel void @mad_hi_cc(ptr addrspace(4) byref(i64) %"61", ptr addrspace(4) byref(i64) %"62") #0 {
|
||||
"78":
|
||||
%"14" = alloca i1, align 1, addrspace(5)
|
||||
store i1 false, ptr addrspace(5) %"14", align 1
|
||||
%"15" = alloca i1, align 1, addrspace(5)
|
||||
store i1 false, ptr addrspace(5) %"15", align 1
|
||||
%"4" = alloca i64, align 8, addrspace(5)
|
||||
%"5" = alloca i64, align 8, addrspace(5)
|
||||
%"6" = alloca i32, align 4, addrspace(5)
|
||||
%"7" = alloca i32, align 4, addrspace(5)
|
||||
%"8" = alloca i32, align 4, addrspace(5)
|
||||
%"9" = alloca i32, align 4, addrspace(5)
|
||||
%"10" = alloca i32, align 4, addrspace(5)
|
||||
%"11" = alloca i32, align 4, addrspace(5)
|
||||
%"12" = alloca i32, align 4, addrspace(5)
|
||||
%"13" = alloca i32, align 4, addrspace(5)
|
||||
%"16" = load i64, ptr addrspace(4) %"61", align 8
|
||||
store i64 %"16", ptr addrspace(5) %"4", align 8
|
||||
%"17" = load i64, ptr addrspace(4) %"62", align 8
|
||||
store i64 %"17", ptr addrspace(5) %"5", align 8
|
||||
%"19" = load i64, ptr addrspace(5) %"4", align 8
|
||||
%"64" = inttoptr i64 %"19" to ptr
|
||||
%"63" = load i32, ptr %"64", align 4
|
||||
store i32 %"63", ptr addrspace(5) %"8", align 4
|
||||
%"21" = load i64, ptr addrspace(5) %"4", align 8
|
||||
%"65" = inttoptr i64 %"21" to ptr
|
||||
%"80" = getelementptr inbounds i8, ptr %"65", i64 4
|
||||
%"66" = load i32, ptr %"80", align 4
|
||||
store i32 %"66", ptr addrspace(5) %"9", align 4
|
||||
%"23" = load i64, ptr addrspace(5) %"4", align 8
|
||||
%"67" = inttoptr i64 %"23" to ptr
|
||||
%"82" = getelementptr inbounds i8, ptr %"67", i64 8
|
||||
%"22" = load i32, ptr %"82", align 4
|
||||
store i32 %"22", ptr addrspace(5) %"10", align 4
|
||||
%"26" = load i32, ptr addrspace(5) %"8", align 4
|
||||
%"27" = load i32, ptr addrspace(5) %"9", align 4
|
||||
%"28" = load i32, ptr addrspace(5) %"10", align 4
|
||||
%0 = sext i32 %"26" to i64
|
||||
%1 = sext i32 %"27" to i64
|
||||
%2 = mul nsw i64 %0, %1
|
||||
%3 = lshr i64 %2, 32
|
||||
%4 = trunc i64 %3 to i32
|
||||
%5 = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %4, i32 %"28")
|
||||
%"24" = extractvalue { i32, i1 } %5, 0
|
||||
%"25" = extractvalue { i32, i1 } %5, 1
|
||||
store i32 %"24", ptr addrspace(5) %"7", align 4
|
||||
store i1 %"25", ptr addrspace(5) %"14", align 1
|
||||
%6 = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 1, i32 -2)
|
||||
%"29" = extractvalue { i32, i1 } %6, 0
|
||||
%"30" = extractvalue { i32, i1 } %6, 1
|
||||
store i32 %"29", ptr addrspace(5) %"6", align 4
|
||||
store i1 %"30", ptr addrspace(5) %"14", align 1
|
||||
%"32" = load i1, ptr addrspace(5) %"14", align 1
|
||||
%7 = zext i1 %"32" to i32
|
||||
%"71" = add i32 0, %7
|
||||
store i32 %"71", ptr addrspace(5) %"12", align 4
|
||||
%8 = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 1, i32 -1)
|
||||
%"33" = extractvalue { i32, i1 } %8, 0
|
||||
%"34" = extractvalue { i32, i1 } %8, 1
|
||||
store i32 %"33", ptr addrspace(5) %"6", align 4
|
||||
store i1 %"34", ptr addrspace(5) %"14", align 1
|
||||
%"36" = load i1, ptr addrspace(5) %"14", align 1
|
||||
%9 = zext i1 %"36" to i32
|
||||
%"72" = add i32 0, %9
|
||||
store i32 %"72", ptr addrspace(5) %"13", align 4
|
||||
%"37" = load i64, ptr addrspace(5) %"5", align 8
|
||||
%"38" = load i32, ptr addrspace(5) %"7", align 4
|
||||
%"73" = inttoptr i64 %"37" to ptr
|
||||
store i32 %"38", ptr %"73", align 4
|
||||
%"39" = load i64, ptr addrspace(5) %"5", align 8
|
||||
%"40" = load i32, ptr addrspace(5) %"12", align 4
|
||||
%"74" = inttoptr i64 %"39" to ptr
|
||||
%"84" = getelementptr inbounds i8, ptr %"74", i64 4
|
||||
store i32 %"40", ptr %"84", align 4
|
||||
%"41" = load i64, ptr addrspace(5) %"5", align 8
|
||||
%"42" = load i32, ptr addrspace(5) %"13", align 4
|
||||
%"76" = inttoptr i64 %"41" to ptr
|
||||
%"86" = getelementptr inbounds i8, ptr %"76", i64 8
|
||||
store i32 %"42", ptr %"86", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
|
||||
declare { i32, i1 } @llvm.uadd.with.overflow.i32(i32, i32) #1
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
|
41
ptx/src/test/spirv_run/mad_hi_cc.ptx
Normal file
41
ptx/src/test/spirv_run/mad_hi_cc.ptx
Normal file
|
@ -0,0 +1,41 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
// Test kernel for `mad.hi.cc`.
// Reads three 32-bit words from `input`, then writes three 32-bit words to
// `output`: the result of a signed mad.hi.cc on the loaded values, followed by
// the carry flag observed after a non-overflowing and after an overflowing
// unsigned mad.hi.cc on immediates.
.visible .entry mad_hi_cc(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
// Sink for the two immediate-operand mad.hi.cc results; only their carry-out is inspected.
.reg .u32 unused;
|
||||
|
||||
.reg .s32 dst1;
|
||||
.reg .b32 src1;
|
||||
.reg .b32 src2;
|
||||
.reg .b32 src3;
|
||||
|
||||
// NOTE(review): result_1 is declared but never referenced in this kernel.
.reg .b32 result_1;
|
||||
.reg .b32 carry_out_1;
|
||||
.reg .b32 carry_out_2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
// test valid computational results
|
||||
ld.s32 src1, [in_addr];
|
||||
ld.s32 src2, [in_addr+4];
|
||||
ld.b32 src3, [in_addr+8];
|
||||
// dst1 = high 32 bits of (src1 * src2) + src3; the carry-out of this add is
// overwritten by the next mad.hi.cc before anything reads it.
mad.hi.cc.s32 dst1, src1, src2, src3;
|
||||
|
||||
// 65536 * 65536 = 2^32, so the high 32 bits of the product are 1.
// 1 + 4294967294 = 4294967295 still fits in 32 bits (no carry);
// 1 + 4294967295 = 2^32 does not (carry set).
// Each `addc 0, 0` materialises the preceding carry flag as 0 or 1.
mad.hi.cc.u32 unused, 65536, 65536, 4294967294; // non-overflowing
|
||||
addc.u32 carry_out_1, 0, 0; // carry_out_1 should be 0
|
||||
mad.hi.cc.u32 unused, 65536, 65536, 4294967295; // overflowing
|
||||
addc.u32 carry_out_2, 0, 0; // carry_out_2 should be 1
|
||||
|
||||
// Output layout: [0] = dst1, [4] = carry_out_1, [8] = carry_out_2.
st.s32 [out_addr], dst1;
|
||||
st.s32 [out_addr+4], carry_out_1;
|
||||
st.s32 [out_addr+8], carry_out_2;
|
||||
ret;
|
||||
|
||||
}
|
|
@ -290,6 +290,7 @@ test_ptx!(
|
|||
[2147487519u32, 4294934539]
|
||||
);
|
||||
test_ptx!(madc_cc2, [0xDEADu32], [0u32, 1, 1, 2]);
|
||||
// mad_hi_cc: the first array is presumably the kernel input (src1, src2, src3) and the second the
// expected output (mad.hi.cc.s32 result, carry after the non-overflowing case, carry after the
// overflowing case), matching the load/store order in mad_hi_cc.ptx.
test_ptx!(mad_hi_cc, [0x26223377u32, 0x70777766u32, 0x60666633u32], [0x71272866u32, 0u32, 1u32]); // Multi-tap :)
|
||||
test_ptx!(mov_vector_cast, [0x200000001u64], [2u32, 1u32]);
|
||||
test_ptx!(
|
||||
cvt_clamp,
|
||||
|
|
|
@ -1999,9 +1999,10 @@ fn insert_hardware_registers_impl<'input>(
|
|||
is_hi,
|
||||
arg: Arg4CarryIn::new(arg, carry_out, TypedOperand::Reg(overflow_flag)),
|
||||
})),
|
||||
Statement::Instruction(ast::Instruction::MadCC { type_, arg }) => {
|
||||
Statement::Instruction(ast::Instruction::MadCC { type_, is_hi, arg }) => {
|
||||
result.push(Statement::MadCC(MadCCDetails {
|
||||
type_,
|
||||
is_hi,
|
||||
arg: Arg4CarryOut::new(arg, TypedOperand::Reg(overflow_flag)),
|
||||
}))
|
||||
}
|
||||
|
@ -5568,6 +5569,7 @@ impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for MadCD
|
|||
|
||||
/// Lowered form of a PTX `mad.lo.cc` / `mad.hi.cc` instruction.
///
/// Built from `ast::Instruction::MadCC` once the carry-out operand has been
/// made explicit (the hardware-register pass wraps the original four
/// arguments together with the overflow-flag register).
pub(crate) struct MadCCDetails<P: ast::ArgParams> {
|
||||
/// Scalar integer type the instruction operates on.
pub(crate) type_: ast::ScalarType,
|
||||
/// `true` for `mad.hi.cc` (high half of the product is used),
/// `false` for `mad.lo.cc` (low half of the product is used).
pub(crate) is_hi: bool,
|
||||
/// The instruction's four operands plus the explicit carry-out destination.
pub(crate) arg: Arg4CarryOut<P>,
|
||||
}
|
||||
|
||||
|
@ -5578,6 +5580,7 @@ impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for MadCC
|
|||
) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
|
||||
Ok(Statement::MadCC(MadCCDetails {
|
||||
type_: self.type_,
|
||||
is_hi: self.is_hi,
|
||||
arg: self.arg.map(visitor, self.type_)?,
|
||||
}))
|
||||
}
|
||||
|
@ -6486,8 +6489,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
carry_out,
|
||||
arg: arg.map(visitor, &ast::Type::Scalar(type_), false)?,
|
||||
},
|
||||
ast::Instruction::MadCC { type_, arg } => ast::Instruction::MadCC {
|
||||
ast::Instruction::MadCC { type_, arg, is_hi } => ast::Instruction::MadCC {
|
||||
type_,
|
||||
is_hi,
|
||||
arg: arg.map(visitor, &ast::Type::Scalar(type_), false)?,
|
||||
},
|
||||
ast::Instruction::Tex(details, arg) => {
|
||||
|
|
Loading…
Add table
Reference in a new issue