diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index bd5c277..96e0815 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -545,6 +545,7 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments), + ast::Instruction::Mul24 { data, arguments } => self.emit_mul24(data, arguments), ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments), ast::Instruction::SetpBool { .. } => todo!(), ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments), @@ -2224,6 +2225,20 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_mul24( + &mut self, + data: ast::Mul24Details, + arguments: ast::Mul24Args, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.emit_intrinsic(c"llvm.amdgcn.mul.u24", Some(arguments.dst), &ast::Type::Scalar(data.type_), vec![ + (src1, get_scalar_type(self.context, data.type_)), + (src2, get_scalar_type(self.context, data.type_)), + ])?; + Ok(()) + } + /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index ce5452a..165835f 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -53,6 +53,7 @@ test_ptx!(mov, [1u64], [1u64]); test_ptx!(mul_lo, [1u64], [2u64]); test_ptx!(mul_hi, [u64::max_value()], [1u64]); test_ptx!(add, [1u64], [2u64]); +test_ptx!(mul24, [10u32], [20u32]); test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); test_ptx!(setp_gt, [f32::NAN, 1f32], [1f32]); test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]); diff --git a/ptx/src/test/spirv_run/mul24.ptx b/ptx/src/test/spirv_run/mul24.ptx new file mode 100644 index 0000000..53c1224 --- /dev/null +++ b/ptx/src/test/spirv_run/mul24.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mul24( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp; + .reg .u32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp, [in_addr]; + mul24.lo.u32 temp2, temp, 2; + st.u32 [out_addr], temp2; + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index c5e8e79..3c0db16 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -2,7 +2,7 @@ use super::{ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix, }; -use crate::{PtxError, PtxParserState}; +use crate::{PtxError, PtxParserState, Mul24Control}; use bitflags::bitflags; use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; @@ -87,6 +87,15 @@ ptx_parser_macros::generate_instruction_type!( src2: T, } }, + Mul24 { + type: { Type::from(data.type_) }, + data: Mul24Details, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, Setp { data: SetpData, arguments: { @@ -1178,6 +1187,13 @@ pub enum MulIntControl { Wide, } + +#[derive(Copy, Clone)] +pub struct Mul24Details { + pub type_: ScalarType, + pub control: Mul24Control, +} + pub struct SetpData { pub type_: ScalarType, pub flush_to_zero: Option, diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index f2c376d..12f8a4b 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1652,6 +1652,9 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum AtomSemantics { } + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum Mul24Control { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -3359,6 +3362,19 @@ derive_parser!( Instruction::Ret { data: RetData { uniform: uni } } } + mul24.mode.type d, a, b => { + ast::Instruction::Mul24 { + data: ast::Mul24Details { + control: mode, + type_ + }, + arguments: Mul24Args { dst: d, src1: a, src2: b } + } + } + + .mode: Mul24Control = { .hi, .lo }; + .type: ScalarType = { .u32, .s32 }; + ); #[cfg(test)]