From 7cdab7abc28508f7e6840837eab2795c1fb00532 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=ABlle=20van=20Essen?= <39169351+JoelleJS@users.noreply.github.com> Date: Tue, 8 Apr 2025 12:27:19 +0200 Subject: [PATCH] Implement mul24 (#351) --- ptx/src/pass/emit_llvm.rs | 43 +- ptx/src/test/ll/mul24.ll | 34 - ptx/src/test/ll/mul24_hi_s32.ll | 46 ++ ptx/src/test/ll/mul24_hi_u32.ll | 42 ++ ptx/src/test/ll/mul24_lo_s32.ll | 39 + ptx/src/test/ll/mul24_lo_u32.ll | 35 + ptx/src/test/spirv_run/mod.rs | 21 +- ptx/src/test/spirv_run/mul24_hi_s32.ptx | 24 + .../spirv_run/{mul24.ptx => mul24_hi_u32.ptx} | 4 +- ptx/src/test/spirv_run/mul24_lo_s32.ptx | 24 + ptx/src/test/spirv_run/mul24_lo_u32.ptx | 22 + ptx_parser/src/ast.rs | 694 +++++++++--------- 12 files changed, 641 insertions(+), 387 deletions(-) delete mode 100644 ptx/src/test/ll/mul24.ll create mode 100644 ptx/src/test/ll/mul24_hi_s32.ll create mode 100644 ptx/src/test/ll/mul24_hi_u32.ll create mode 100644 ptx/src/test/ll/mul24_lo_s32.ll create mode 100644 ptx/src/test/ll/mul24_lo_u32.ll create mode 100644 ptx/src/test/spirv_run/mul24_hi_s32.ptx rename ptx/src/test/spirv_run/{mul24.ptx => mul24_hi_u32.ptx} (84%) create mode 100644 ptx/src/test/spirv_run/mul24_lo_s32.ptx create mode 100644 ptx/src/test/spirv_run/mul24_lo_u32.ptx diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 5a5dd80..0f432ca 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -36,6 +36,7 @@ use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; use llvm_zluda::{core::*, *}; use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; +use ptx_parser::Mul24Control; const LLVM_UNNAMED: &CStr = c""; // https://llvm.org/docs/AMDGPUUsage.html#address-spaces @@ -2281,15 +2282,51 @@ impl<'a> MethodEmitContext<'a> { ) -> Result<(), TranslateError> { let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; - self.emit_intrinsic( - c"llvm.amdgcn.mul.u24", - Some(arguments.dst), + let name_lo = match data.type_ { + ast::ScalarType::U32 => c"llvm.amdgcn.mul.u24", + ast::ScalarType::S32 => c"llvm.amdgcn.mul.i24", + _ => return Err(error_unreachable()), + }; + let res_lo = self.emit_intrinsic( + name_lo, + if data.control == Mul24Control::Lo { Some(arguments.dst) } else { None }, Some(&ast::Type::Scalar(data.type_)), vec![ (src1, get_scalar_type(self.context, data.type_)), (src2, get_scalar_type(self.context, data.type_)), ], )?; + if data.control == Mul24Control::Hi { + // There is an important difference between NVIDIA's mul24.hi and AMD's mulhi.[ui]24. + // NVIDIA: Returns bits 47..16 of the 64-bit result + // AMD: Returns bits 63..32 of the 64-bit result + // Hence we need to compute both hi and lo, shift the results and add them together to replicate NVIDIA's mul24 + let name_hi = match data.type_ { + ast::ScalarType::U32 => c"llvm.amdgcn.mulhi.u24", + ast::ScalarType::S32 => c"llvm.amdgcn.mulhi.i24", + _ => return Err(error_unreachable()), + }; + let res_hi = self.emit_intrinsic( + name_hi, + None, + Some(&ast::Type::Scalar(data.type_)), + vec![ + (src1, get_scalar_type(self.context, data.type_)), + (src2, get_scalar_type(self.context, data.type_)), + ], + )?; + let shift_number = unsafe { LLVMConstInt(LLVMInt32TypeInContext(self.context), 16, 0) }; + let res_lo_shr = unsafe { + LLVMBuildLShr(self.builder, res_lo, shift_number, LLVM_UNNAMED.as_ptr()) + }; + let res_hi_shl = + unsafe { LLVMBuildShl(self.builder, res_hi, shift_number, LLVM_UNNAMED.as_ptr()) }; + + self.resolver + .with_result(arguments.dst, |dst: *const i8| unsafe { + LLVMBuildOr(self.builder, res_lo_shr, res_hi_shl, dst) + }); + } Ok(()) } diff --git a/ptx/src/test/ll/mul24.ll b/ptx/src/test/ll/mul24.ll deleted file mode 100644 index f65aa94..0000000 --- a/ptx/src/test/ll/mul24.ll +++ /dev/null @@ -1,34 +0,0 @@ -define amdgpu_kernel void @mul24(ptr addrspace(4) byref(i64) %"32", ptr addrspace(4) byref(i64) %"33") #0 { - %"34" = alloca i64, align 8, addrspace(5) - %"35" = alloca i64, align 8, addrspace(5) - %"36" = alloca i32, align 4, addrspace(5) - %"37" = alloca i32, align 4, addrspace(5) - br label %1 - -1: ; preds = %0 - br label %"31" - -"31": ; preds = %1 - %"38" = load i64, ptr addrspace(4) %"32", align 4 - store i64 %"38", ptr addrspace(5) %"34", align 4 - %"39" = load i64, ptr addrspace(4) %"33", align 4 - store i64 %"39", ptr addrspace(5) %"35", align 4 - %"41" = load i64, ptr addrspace(5) %"34", align 4 - %"46" = inttoptr i64 %"41" to ptr - %"40" = load i32, ptr %"46", align 4 - store i32 %"40", ptr addrspace(5) %"36", align 4 - %"43" = load i32, ptr addrspace(5) %"36", align 4 - %"42" = call i32 @llvm.amdgcn.mul.u24(i32 %"43", i32 2) - store i32 %"42", ptr addrspace(5) %"37", align 4 - %"44" = load i64, ptr addrspace(5) %"35", align 4 - %"45" = load i32, ptr addrspace(5) %"37", align 4 - %"47" = inttoptr i64 %"44" to ptr - store i32 %"45", ptr %"47", align 4 - ret void -} - -; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) -declare i32 @llvm.amdgcn.mul.u24(i32, i32) #1 - -attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } -attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file diff --git a/ptx/src/test/ll/mul24_hi_s32.ll b/ptx/src/test/ll/mul24_hi_s32.ll new file mode 100644 index 0000000..20e32ed --- /dev/null +++ b/ptx/src/test/ll/mul24_hi_s32.ll @@ -0,0 +1,46 @@ +define amdgpu_kernel void @mul24_hi_s32(ptr addrspace(4) byref(i64) %"32", ptr addrspace(4) byref(i64) %"33") #0 { + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i64, align 8, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + %"37" = alloca i32, align 4, addrspace(5) + %"38" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"31" + +"31": ; preds = %1 + %"39" = load i64, ptr addrspace(4) %"32", align 4 + store i64 %"39", ptr addrspace(5) %"34", align 4 + %"40" = load i64, ptr addrspace(4) %"33", align 4 + store i64 %"40", ptr addrspace(5) %"35", align 4 + %"42" = load i64, ptr addrspace(5) %"34", align 4 + %"50" = inttoptr i64 %"42" to ptr + %"41" = load i32, ptr %"50", align 4 + store i32 %"41", ptr addrspace(5) %"36", align 4 + %"44" = load i32, ptr addrspace(5) %"36", align 4 + %"43" = sub i32 0, %"44" + store i32 %"43", ptr addrspace(5) %"37", align 4 + %"46" = load i32, ptr addrspace(5) %"37", align 4 + %"47" = load i32, ptr addrspace(5) %"36", align 4 + %2 = call i32 @llvm.amdgcn.mul.i24(i32 %"46", i32 %"47") + %3 = call i32 @llvm.amdgcn.mulhi.i24(i32 %"46", i32 %"47") + %4 = lshr i32 %2, 16 + %5 = shl i32 %3, 16 + %"45" = or i32 %4, %5 + store i32 %"45", ptr addrspace(5) %"38", align 4 + %"48" = load i64, ptr addrspace(5) %"35", align 4 + %"49" = load i32, ptr addrspace(5) %"38", align 4 + %"51" = inttoptr i64 %"48" to ptr + store i32 %"49", ptr %"51", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.amdgcn.mul.i24(i32, i32) #1 + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.amdgcn.mulhi.i24(i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file diff --git a/ptx/src/test/ll/mul24_hi_u32.ll b/ptx/src/test/ll/mul24_hi_u32.ll new file mode 100644 index 0000000..427adb6 --- /dev/null +++ b/ptx/src/test/ll/mul24_hi_u32.ll @@ -0,0 +1,42 @@ +define amdgpu_kernel void @mul24_hi_u32(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 { + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i32, align 4, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"30" + +"30": ; preds = %1 + %"37" = load i64, ptr addrspace(4) %"31", align 4 + store i64 %"37", ptr addrspace(5) %"33", align 4 + %"38" = load i64, ptr addrspace(4) %"32", align 4 + store i64 %"38", ptr addrspace(5) %"34", align 4 + %"40" = load i64, ptr addrspace(5) %"33", align 4 + %"46" = inttoptr i64 %"40" to ptr + %"39" = load i32, ptr %"46", align 4 + store i32 %"39", ptr addrspace(5) %"35", align 4 + %"42" = load i32, ptr addrspace(5) %"35", align 4 + %"43" = load i32, ptr addrspace(5) %"35", align 4 + %2 = call i32 @llvm.amdgcn.mul.u24(i32 %"42", i32 %"43") + %3 = call i32 @llvm.amdgcn.mulhi.u24(i32 %"42", i32 %"43") + %4 = lshr i32 %2, 16 + %5 = shl i32 %3, 16 + %"41" = or i32 %4, %5 + store i32 %"41", ptr addrspace(5) %"36", align 4 + %"44" = load i64, ptr addrspace(5) %"34", align 4 + %"45" = load i32, ptr addrspace(5) %"36", align 4 + %"47" = inttoptr i64 %"44" to ptr + store i32 %"45", ptr %"47", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.amdgcn.mul.u24(i32, i32) #1 + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.amdgcn.mulhi.u24(i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file diff --git a/ptx/src/test/ll/mul24_lo_s32.ll b/ptx/src/test/ll/mul24_lo_s32.ll new file mode 100644 index 0000000..06a8b3b --- /dev/null +++ b/ptx/src/test/ll/mul24_lo_s32.ll @@ -0,0 +1,39 @@ +define amdgpu_kernel void @mul24_lo_s32(ptr addrspace(4) byref(i64) %"32", ptr addrspace(4) byref(i64) %"33") #0 { + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i64, align 8, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + %"37" = alloca i32, align 4, addrspace(5) + %"38" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"31" + +"31": ; preds = %1 + %"39" = load i64, ptr addrspace(4) %"32", align 4 + store i64 %"39", ptr addrspace(5) %"34", align 4 + %"40" = load i64, ptr addrspace(4) %"33", align 4 + store i64 %"40", ptr addrspace(5) %"35", align 4 + %"42" = load i64, ptr addrspace(5) %"34", align 4 + %"50" = inttoptr i64 %"42" to ptr + %"41" = load i32, ptr %"50", align 4 + store i32 %"41", ptr addrspace(5) %"36", align 4 + %"44" = load i32, ptr addrspace(5) %"36", align 4 + %"43" = sub i32 0, %"44" + store i32 %"43", ptr addrspace(5) %"37", align 4 + %"46" = load i32, ptr addrspace(5) %"37", align 4 + %"47" = load i32, ptr addrspace(5) %"36", align 4 + %"45" = call i32 @llvm.amdgcn.mul.i24(i32 %"46", i32 %"47") + store i32 %"45", ptr addrspace(5) %"38", align 4 + %"48" = load i64, ptr addrspace(5) %"35", align 4 + %"49" = load i32, ptr addrspace(5) %"38", align 4 + %"51" = inttoptr i64 %"48" to ptr + store i32 %"49", ptr %"51", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.amdgcn.mul.i24(i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file diff --git a/ptx/src/test/ll/mul24_lo_u32.ll b/ptx/src/test/ll/mul24_lo_u32.ll new file mode 100644 index 0000000..47c26c4 --- /dev/null +++ b/ptx/src/test/ll/mul24_lo_u32.ll @@ -0,0 +1,35 @@ +define amdgpu_kernel void @mul24_lo_u32(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 { + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i32, align 4, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"30" + +"30": ; preds = %1 + %"37" = load i64, ptr addrspace(4) %"31", align 4 + store i64 %"37", ptr addrspace(5) %"33", align 4 + %"38" = load i64, ptr addrspace(4) %"32", align 4 + store i64 %"38", ptr addrspace(5) %"34", align 4 + %"40" = load i64, ptr addrspace(5) %"33", align 4 + %"46" = inttoptr i64 %"40" to ptr + %"39" = load i32, ptr %"46", align 4 + store i32 %"39", ptr addrspace(5) %"35", align 4 + %"42" = load i32, ptr addrspace(5) %"35", align 4 + %"43" = load i32, ptr addrspace(5) %"35", align 4 + %"41" = call i32 @llvm.amdgcn.mul.u24(i32 %"42", i32 %"43") + store i32 %"41", ptr addrspace(5) %"36", align 4 + %"44" = load i64, ptr addrspace(5) %"34", align 4 + %"45" = load i32, ptr addrspace(5) %"36", align 4 + %"47" = inttoptr i64 %"44" to ptr + store i32 %"45", ptr %"47", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.amdgcn.mul.u24(i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index cafa480..27df227 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -53,7 +53,26 @@ test_ptx!(mov, [1u64], [1u64]); test_ptx!(mul_lo, [1u64], [2u64]); test_ptx!(mul_hi, [u64::max_value()], [1u64]); test_ptx!(add, [1u64], [2u64]); -test_ptx!(mul24, [10u32], [20u32]); +test_ptx!( + mul24_lo_u32, + [0b01110101_01010101_01010101u32], + [0b00011100_00100011_10001110_00111001u32] +); +test_ptx!( + mul24_hi_u32, + [0b01110101_01010101_01010101u32], + [0b00110101_11000111_00011100_00100011u32] +); +test_ptx!( + mul24_lo_s32, + [0b01110101_01010101_01010101i32], + [-0b0011100_00100011_10001110_00111001i32] +); +test_ptx!( + mul24_hi_s32, + [0b01110101_01010101_01010101i32], + [-0b0110101_11000111_00011100_00100100i32] +); test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); test_ptx!(setp_gt, [f32::NAN, 1f32], [1f32]); test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]); diff --git a/ptx/src/test/spirv_run/mul24_hi_s32.ptx b/ptx/src/test/spirv_run/mul24_hi_s32.ptx new file mode 100644 index 0000000..7212214 --- /dev/null +++ b/ptx/src/test/spirv_run/mul24_hi_s32.ptx @@ -0,0 +1,24 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mul24_hi_s32( + .param .s64 input, + .param .s64 output +) +{ + .reg .s64 in_addr; + .reg .s64 out_addr; + .reg .s32 temp; + .reg .s32 temp2; + .reg .s32 temp3; + + ld.param.s64 in_addr, [input]; + ld.param.s64 out_addr, [output]; + + ld.s32 temp, [in_addr]; + neg.s32 temp2, temp; + mul24.hi.s32 temp3, temp2, temp; + st.s32 [out_addr], temp3; + ret; +} diff --git a/ptx/src/test/spirv_run/mul24.ptx b/ptx/src/test/spirv_run/mul24_hi_u32.ptx similarity index 84% rename from ptx/src/test/spirv_run/mul24.ptx rename to ptx/src/test/spirv_run/mul24_hi_u32.ptx index 53c1224..958f0fe 100644 --- a/ptx/src/test/spirv_run/mul24.ptx +++ b/ptx/src/test/spirv_run/mul24_hi_u32.ptx @@ -2,7 +2,7 @@ .target sm_30 .address_size 64 -.visible .entry mul24( +.visible .entry mul24_hi_u32( .param .u64 input, .param .u64 output ) @@ -16,7 +16,7 @@ ld.param.u64 out_addr, [output]; ld.u32 temp, [in_addr]; - mul24.lo.u32 temp2, temp, 2; + mul24.hi.u32 temp2, temp, temp; st.u32 [out_addr], temp2; ret; } diff --git a/ptx/src/test/spirv_run/mul24_lo_s32.ptx b/ptx/src/test/spirv_run/mul24_lo_s32.ptx new file mode 100644 index 0000000..3be571f --- /dev/null +++ b/ptx/src/test/spirv_run/mul24_lo_s32.ptx @@ -0,0 +1,24 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mul24_lo_s32( + .param .s64 input, + .param .s64 output +) +{ + .reg .s64 in_addr; + .reg .s64 out_addr; + .reg .s32 temp; + .reg .s32 temp2; + .reg .s32 temp3; + + ld.param.s64 in_addr, [input]; + ld.param.s64 out_addr, [output]; + + ld.s32 temp, [in_addr]; + neg.s32 temp2, temp; + mul24.lo.s32 temp3, temp2, temp; + st.s32 [out_addr], temp3; + ret; +} diff --git a/ptx/src/test/spirv_run/mul24_lo_u32.ptx b/ptx/src/test/spirv_run/mul24_lo_u32.ptx new file mode 100644 index 0000000..28c8902 --- /dev/null +++ b/ptx/src/test/spirv_run/mul24_lo_u32.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mul24_lo_u32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp; + .reg .u32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp, [in_addr]; + mul24.lo.u32 temp2, temp, temp; + st.u32 [out_addr], temp2; + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 55b950a..e4c3c87 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -30,14 +30,195 @@ pub enum Statement { // This information is then available to a visitor. ptx_parser_macros::generate_instruction_type!( pub enum Instruction { - Mov { - type: { &data.typ }, - data: MovDetails, + Abs { + data: TypeFtz, + type: { Type::Scalar(data.type_) }, + arguments: { + dst: T, + src: T, + } + }, + Activemask { + type: Type::Scalar(ScalarType::B32), + arguments: { + dst: T + } + }, + Add { + type: { Type::from(data.type_()) }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + And { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Atom { + type: &data.type_, + data: AtomDetails, + arguments: { + dst: T, + src1: { + repr: T, + space: { data.space }, + }, + src2: T, + } + }, + AtomCas { + type: Type::Scalar(data.type_), + data: AtomCasDetails, + arguments: { + dst: T, + src1: { + repr: T, + space: { data.space }, + }, + src2: T, + src3: T, + } + }, + Bar { + type: Type::Scalar(ScalarType::U32), + data: BarData, + arguments: { + src1: T, + src2: Option, + } + }, + Bfe { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src3: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + } + }, + Bfi { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T, + src3: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src4: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + } + }, + Bra { + type: !, + arguments: { + src: T + } + }, + Brev { + type: Type::Scalar(data.clone()), + data: ScalarType, arguments: { dst: T, src: T } }, + Call { + data: CallDetails, + arguments: CallArgs, + visit: arguments.visit(data, visitor)?, + visit_mut: arguments.visit_mut(data, visitor)?, + map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data } + }, + Clz { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src: T + } + }, + Cos { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Cvt { + data: CvtDetails, + arguments: { + dst: { + repr: T, + type: { Type::Scalar(data.to) }, + // TODO: double check + relaxed_type_check: true, + }, + src: { + repr: T, + type: { Type::Scalar(data.from) }, + relaxed_type_check: true, + }, + } + }, + Cvta { + data: CvtaDetails, + type: { Type::Scalar(ScalarType::B64) }, + arguments: { + dst: T, + src: T, + } + }, + Div { + type: Type::Scalar(data.type_()), + data: DivDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Ex2 { + type: Type::Scalar(ScalarType::F32), + data: TypeFtz, + arguments: { + dst: T, + src: T + } + }, + Fma { + type: { Type::from(data.type_) }, + data: ArithFloat, + arguments: { + dst: T, + src1: T, + src2: T, + src3: T, + } + }, Ld { type: { &data.typ }, data: LdDetails, @@ -52,27 +233,54 @@ ptx_parser_macros::generate_instruction_type!( } } }, - Add { + Lg2 { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Mad { type: { Type::from(data.type_()) }, - data: ArithDetails, + data: MadDetails, + arguments: { + dst: { + repr: T, + type: { Type::from(data.dst_type()) }, + }, + src1: T, + src2: T, + src3: T, + } + }, + Max { + type: { Type::from(data.type_()) }, + data: MinMaxDetails, arguments: { dst: T, src1: T, src2: T, } }, - St { - type: { &data.typ }, - data: StData, + Membar { + data: MemScope + }, + Min { + type: { Type::from(data.type_()) }, + data: MinMaxDetails, arguments: { - src1: { - repr: T, - space: { data.state_space }, - }, - src2: { - repr: T, - relaxed_type_check: true, - } + dst: T, + src1: T, + src2: T, + } + }, + Mov { + type: { &data.typ }, + data: MovDetails, + arguments: { + dst: T, + src: T } }, Mul { @@ -96,6 +304,101 @@ ptx_parser_macros::generate_instruction_type!( src2: T, } }, + Neg { + type: Type::Scalar(data.type_), + data: TypeFtz, + arguments: { + dst: T, + src: T + } + }, + Not { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src: T, + } + }, + Or { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Popc { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src: T + } + }, + Prmt { + type: Type::Scalar(ScalarType::B32), + data: u16, + arguments: { + dst: T, + src1: T, + src2: T + } + }, + PrmtSlow { + type: Type::Scalar(ScalarType::U32), + arguments: { + dst: T, + src1: T, + src2: T, + src3: T + } + }, + Rcp { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + Rem { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T + } + }, + Ret { + data: RetData + }, + Rsqrt { + type: { Type::from(data.type_) }, + data: TypeFtz, + arguments: { + dst: T, + src: T, + } + }, + Selp { + type: { Type::Scalar(data.clone()) }, + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T, + src3: { + repr: T, + type: Type::Scalar(ScalarType::Pred) + }, + } + }, Setp { data: SetpData, arguments: { @@ -142,58 +445,15 @@ ptx_parser_macros::generate_instruction_type!( } } }, - Not { - data: ScalarType, - type: { Type::Scalar(data.clone()) }, - arguments: { - dst: T, - src: T, - } - }, - Or { + Shl { data: ScalarType, type: { Type::Scalar(data.clone()) }, arguments: { dst: T, src1: T, - src2: T, - } - }, - And { - data: ScalarType, - type: { Type::Scalar(data.clone()) }, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Bra { - type: !, - arguments: { - src: T - } - }, - Call { - data: CallDetails, - arguments: CallArgs, - visit: arguments.visit(data, visitor)?, - visit_mut: arguments.visit_mut(data, visitor)?, - map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data } - }, - Cvt { - data: CvtDetails, - arguments: { - dst: { + src2: { repr: T, - type: { Type::Scalar(data.to) }, - // TODO: double check - relaxed_type_check: true, - }, - src: { - repr: T, - type: { Type::Scalar(data.from) }, - relaxed_type_check: true, + type: { Type::Scalar(ScalarType::U32) }, }, } }, @@ -209,58 +469,34 @@ ptx_parser_macros::generate_instruction_type!( }, } }, - Shl { - data: ScalarType, - type: { Type::Scalar(data.clone()) }, + Sin { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, arguments: { dst: T, - src1: T, + src: T + } + }, + Sqrt { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + St { + type: { &data.typ }, + data: StData, + arguments: { + src1: { + repr: T, + space: { data.state_space }, + }, src2: { repr: T, - type: { Type::Scalar(ScalarType::U32) }, - }, - } - }, - Ret { - data: RetData - }, - Cvta { - data: CvtaDetails, - type: { Type::Scalar(ScalarType::B64) }, - arguments: { - dst: T, - src: T, - } - }, - Abs { - data: TypeFtz, - type: { Type::Scalar(data.type_) }, - arguments: { - dst: T, - src: T, - } - }, - Mad { - type: { Type::from(data.type_()) }, - data: MadDetails, - arguments: { - dst: { - repr: T, - type: { Type::from(data.dst_type()) }, - }, - src1: T, - src2: T, - src3: T, - } - }, - Fma { - type: { Type::from(data.type_) }, - data: ArithFloat, - arguments: { - dst: T, - src1: T, - src2: T, - src3: T, + relaxed_type_check: true, + } } }, Sub { @@ -272,173 +508,7 @@ ptx_parser_macros::generate_instruction_type!( src2: T, } }, - Min { - type: { Type::from(data.type_()) }, - data: MinMaxDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Max { - type: { Type::from(data.type_()) }, - data: MinMaxDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Rcp { - type: { Type::from(data.type_) }, - data: RcpData, - arguments: { - dst: T, - src: T, - } - }, - Sqrt { - type: { Type::from(data.type_) }, - data: RcpData, - arguments: { - dst: T, - src: T, - } - }, - Rsqrt { - type: { Type::from(data.type_) }, - data: TypeFtz, - arguments: { - dst: T, - src: T, - } - }, - Selp { - type: { Type::Scalar(data.clone()) }, - data: ScalarType, - arguments: { - dst: T, - src1: T, - src2: T, - src3: { - repr: T, - type: Type::Scalar(ScalarType::Pred) - }, - } - }, - Bar { - type: Type::Scalar(ScalarType::U32), - data: BarData, - arguments: { - src1: T, - src2: Option, - } - }, - Atom { - type: &data.type_, - data: AtomDetails, - arguments: { - dst: T, - src1: { - repr: T, - space: { data.space }, - }, - src2: T, - } - }, - AtomCas { - type: Type::Scalar(data.type_), - data: AtomCasDetails, - arguments: { - dst: T, - src1: { - repr: T, - space: { data.space }, - }, - src2: T, - src3: T, - } - }, - Div { - type: Type::Scalar(data.type_()), - data: DivDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Neg { - type: Type::Scalar(data.type_), - data: TypeFtz, - arguments: { - dst: T, - src: T - } - }, - Sin { - type: Type::Scalar(ScalarType::F32), - data: FlushToZero, - arguments: { - dst: T, - src: T - } - }, - Cos { - type: Type::Scalar(ScalarType::F32), - data: FlushToZero, - arguments: { - dst: T, - src: T - } - }, - Lg2 { - type: Type::Scalar(ScalarType::F32), - data: FlushToZero, - arguments: { - dst: T, - src: T - } - }, - Ex2 { - type: Type::Scalar(ScalarType::F32), - data: TypeFtz, - arguments: { - dst: T, - src: T - } - }, - Clz { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - src: T - } - }, - Brev { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src: T - } - }, - Popc { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - src: T - } - }, + Trap { }, Xor { type: Type::Scalar(data.clone()), data: ScalarType, @@ -448,76 +518,6 @@ ptx_parser_macros::generate_instruction_type!( src2: T } }, - Rem { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src1: T, - src2: T - } - }, - Bfe { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src1: T, - src2: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - src3: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - } - }, - Bfi { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src1: T, - src2: T, - src3: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - src4: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - } - }, - PrmtSlow { - type: Type::Scalar(ScalarType::U32), - arguments: { - dst: T, - src1: T, - src2: T, - src3: T - } - }, - Prmt { - type: Type::Scalar(ScalarType::B32), - data: u16, - arguments: { - dst: T, - src1: T, - src2: T - } - }, - Activemask { - type: Type::Scalar(ScalarType::B32), - arguments: { - dst: T - } - }, - Membar { - data: MemScope - }, - Trap { } } );