From 6cd18bfdb8926e374a7a060e4acc20bfadfacfd0 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 21 Aug 2024 02:45:52 +0200 Subject: [PATCH] Add abs, mad --- gen_impl/src/parser.rs | 4 +- ptx_parser/src/ast.rs | 69 +++++++++++++++- ptx_parser/src/main.rs | 173 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 237 insertions(+), 9 deletions(-) diff --git a/gen_impl/src/parser.rs b/gen_impl/src/parser.rs index ea5070d..f1cd738 100644 --- a/gen_impl/src/parser.rs +++ b/gen_impl/src/parser.rs @@ -73,7 +73,7 @@ pub struct OpcodeDecl(pub Instruction, pub Arguments); impl OpcodeDecl { fn peek(input: syn::parse::ParseStream) -> bool { - Instruction::peek(input) + Instruction::peek(input) && !input.peek2(Token![=]) } } @@ -106,7 +106,7 @@ impl Parse for CodeBlock { } else { return Err(lookahead.error()); }; - Ok(Self{special, code}) + Ok(Self { special, code }) } } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 98583a8..248a6f3 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -200,6 +200,27 @@ gen::generate_instruction_type!( src: T, } }, + Abs { + data: AbsDetails, + type: { Type::Scalar(data.type_) }, + arguments: { + dst: T, + src: T, + } + }, + Mad { + type: { Type::from(data.type_()) }, + data: MadDetails, + arguments: { + dst: { + repr: T, + type: { Type::from(data.dst_type()) }, + }, + src1: T, + src2: T, + src3: T, + } + }, Trap { } } ); @@ -588,16 +609,14 @@ pub enum MulDetails { } impl MulDetails { - #[allow(unused)] // Used by generated code - fn type_(&self) -> ScalarType { + pub fn type_(&self) -> ScalarType { match self { MulDetails::Integer { type_, .. } => *type_, MulDetails::Float(arith) => arith.type_, } } - #[allow(unused)] // Used by generated code - fn dst_type(&self) -> ScalarType { + pub fn dst_type(&self) -> ScalarType { match self { MulDetails::Integer { type_, @@ -995,3 +1014,45 @@ pub enum CvtaDirection { GenericToExplicit, ExplicitToGeneric, } + +#[derive(Copy, Clone)] +pub struct AbsDetails { + pub flush_to_zero: Option, + pub type_: ScalarType, +} + +#[derive(Copy, Clone)] +pub enum MadDetails { + Integer { + control: MulIntControl, + saturate: bool, + type_: ScalarType, + }, + Float(ArithFloat), +} + +impl MadDetails { + pub fn dst_type(&self) -> ScalarType { + match self { + MadDetails::Integer { + type_, + control: MulIntControl::Wide, + .. + } => match type_ { + ScalarType::U16 => ScalarType::U32, + ScalarType::S16 => ScalarType::S32, + ScalarType::U32 => ScalarType::U64, + ScalarType::S32 => ScalarType::S64, + _ => unreachable!(), + }, + _ => self.type_(), + } + } + + fn type_(&self) -> ScalarType { + match self { + MadDetails::Integer { type_, .. } => *type_, + MadDetails::Float(arith) => arith.type_, + } + } +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 03360f3..ce1f56d 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1450,6 +1450,8 @@ derive_parser!( ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul mul.mode.type d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Integer { @@ -1476,8 +1478,6 @@ derive_parser!( .s16, .s32 }; RawMulIntControl = { .wide }; - - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul mul{.rnd}{.ftz}{.sat}.f32 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( @@ -1507,7 +1507,6 @@ derive_parser!( .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; ScalarType = { .f32, .f64 }; - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul mul{.rnd}{.ftz}{.sat}.f16 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( @@ -1706,6 +1705,174 @@ derive_parser!( .space: StateSpace = { .const, .global, .local, .shared{::cta, ::cluster}, .param{::entry} }; .size: ScalarType = { .u32, .u64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-abs + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs + abs.type d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: None, + type_ + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f32 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.f64 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: None, + type_: f64 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: Some(ftz), + type_: f16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16x2 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: Some(ftz), + type_: f16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: None, + type_: bf16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16x2 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: None, + type_: bf16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + .type: ScalarType = { .s16, .s32, .s64 }; + ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad + mad.mode.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: mode.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + .mode: RawMulIntControl = { .hi, .lo }; + + // The .wide suffix is supported only for 16-bit and 32-bit integer types. + mad.wide.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: wide.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .u16, .u32, + .s16, .s32 }; + RawMulIntControl = { .wide }; + + mad.hi.sat.s32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_: s32, + control: hi.into(), + saturate: true + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + RawMulIntControl = { .hi }; + ScalarType = { .s32 }; + + mad{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ArithFloat { + type_: f32, + rounding: None, + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd.f64 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + }} + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } }