From cb64b04f41b39a8b4740fe1c3e9450a05a90d950 Mon Sep 17 00:00:00 2001
From: Andrzej Janik
Date: Sun, 18 Aug 2024 23:27:07 +0200
Subject: [PATCH] Add mul

Add the `mul` instruction to the PTX parser.

ast.rs:
* Add `Instruction::Mul`, carrying `MulDetails`; its destination type
  comes from `MulDetails::dst_type()`.
* Add `MulDetails` and `MulIntControl`.
* Move `MultiVariable`, `Variable` and `PredAt` below the
  `generate_instruction_type!` invocation.
* Spell `super::ScalarType` as the already-imported `ScalarType`.

main.rs:
* Convert `RawMulIntControl` into `ast::MulIntControl`.
* Add `derive_parser!` rules for the integer, f32/f64 and
  f16/f16x2/bf16/bf16x2 forms of `mul`.
* Remove the empty line that followed five existing rules.
---
Review notes:
* Dropped the stray `use std::intrinsics::unreachable;` from ast.rs.
  `std::intrinsics` is an unstable module and the import was unused: the
  code only invokes the `unreachable!()` macro, which needs no import.
  Hunk offsets in ast.rs and the diffstat below are adjusted accordingly.
* Added doc comments to the new `MulDetails` / `MulIntControl` items.
* NOTE(review): line breaks and indentation were reconstructed from a
  whitespace-collapsed copy of this patch, and generic parameters lost in
  that copy (e.g. `Vec>`, `Option,`) are left as found -- verify the
  context lines against the tree before applying.

 ptx_parser/src/ast.rs  | 106 ++++++++++++++++++++++++++-------
 ptx_parser/src/main.rs | 130 +++++++++++++++++++++++++++++++++++++++--
 2 files changed, 209 insertions(+), 27 deletions(-)

diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs
index 2dabf3e..714c9b3 100644
--- a/ptx_parser/src/ast.rs
+++ b/ptx_parser/src/ast.rs
@@ -8,25 +8,6 @@ pub enum Statement {
     Block(Vec>),
 }
 
-pub struct MultiVariable {
-    pub var: Variable,
-    pub count: Option,
-}
-
-#[derive(Clone)]
-pub struct Variable {
-    pub align: Option,
-    pub v_type: Type,
-    pub state_space: StateSpace,
-    pub name: ID,
-    pub array_init: Vec,
-}
-
-pub struct PredAt {
-    pub not: bool,
-    pub label: ID,
-}
-
 gen::generate_instruction_type!(
     pub enum Instruction {
         Mov {
@@ -68,6 +49,18 @@ gen::generate_instruction_type!(
                 src2: T,
             }
         },
+        Mul {
+            type: { data.type_().into() },
+            data: MulDetails,
+            arguments: {
+                dst: {
+                    repr: T,
+                    type: { data.dst_type().into() },
+                },
+                src1: T,
+                src2: T,
+            }
+        },
         Ret {
             data: RetData
         },
@@ -75,6 +68,25 @@ gen::generate_instruction_type!(
     }
 );
 
+pub struct MultiVariable {
+    pub var: Variable,
+    pub count: Option,
+}
+
+#[derive(Clone)]
+pub struct Variable {
+    pub align: Option,
+    pub v_type: Type,
+    pub state_space: StateSpace,
+    pub name: ID,
+    pub array_init: Vec,
+}
+
+pub struct PredAt {
+    pub not: bool,
+    pub label: ID,
+}
+
 pub trait Visitor {
     fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
 }
@@ -185,7 +197,7 @@ pub enum ArithDetails {
 }
 
 impl ArithDetails {
-    pub fn type_(&self) -> super::ScalarType {
+    pub fn type_(&self) -> ScalarType {
         match self {
             ArithDetails::Integer(t) => t.type_,
             ArithDetails::Float(arith) => arith.type_,
@@ -195,13 +207,13 @@ impl ArithDetails {
 
 #[derive(Copy, Clone)]
 pub struct ArithInteger {
-    pub type_: super::ScalarType,
+    pub type_: ScalarType,
     pub saturate: bool,
 }
 
 #[derive(Copy, Clone)]
 pub struct ArithFloat {
-    pub type_: super::ScalarType,
+    pub type_: ScalarType,
     pub rounding: Option,
     pub flush_to_zero: Option,
     pub saturate: bool,
@@ -292,3 +304,53 @@ pub struct Module<'input> {
     pub version: (u8, u8),
     pub directives: Vec>>,
 }
+
+/// Details of a `mul` instruction: an integer multiply together with its
+/// `MulIntControl` mode, or a floating-point multiply.
+#[derive(Copy, Clone)]
+pub enum MulDetails {
+    Integer {
+        type_: ScalarType,
+        control: MulIntControl,
+    },
+    Float(ArithFloat),
+}
+
+impl MulDetails {
+    /// Scalar type carried by either variant.
+    fn type_(&self) -> ScalarType {
+        match self {
+            MulDetails::Integer { type_, .. } => *type_,
+            MulDetails::Float(arith) => arith.type_,
+        }
+    }
+
+    /// Type of the destination operand. For `MulIntControl::Wide` the
+    /// integer type is widened as listed in the inner `match`; every other
+    /// form uses `type_()`.
+    fn dst_type(&self) -> ScalarType {
+        match self {
+            MulDetails::Integer {
+                type_,
+                control: MulIntControl::Wide,
+            } => match type_ {
+                ScalarType::U16 => ScalarType::U32,
+                ScalarType::S16 => ScalarType::S32,
+                ScalarType::U32 => ScalarType::U64,
+                ScalarType::S32 => ScalarType::S64,
+                // The `mul.wide` parser rule only accepts the four types above.
+                _ => unreachable!(),
+            },
+            _ => self.type_(),
+        }
+    }
+}
+
+/// Integer `mul` mode; converted from the parser's `RawMulIntControl`
+/// (`.lo`, `.hi`, `.wide`).
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum MulIntControl {
+    Low,
+    High,
+    Wide,
+}
diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs
index 0ac1260..b087fb9 100644
--- a/ptx_parser/src/main.rs
+++ b/ptx_parser/src/main.rs
@@ -16,6 +16,16 @@ use winnow::{prelude::*, Stateful};
 mod ast;
 pub use ast::*;
 
+impl From<RawMulIntControl> for ast::MulIntControl {
+    fn from(value: RawMulIntControl) -> Self {
+        match value {
+            RawMulIntControl::Lo => ast::MulIntControl::Low,
+            RawMulIntControl::Hi => ast::MulIntControl::High,
+            RawMulIntControl::Wide => ast::MulIntControl::Wide,
+        }
+    }
+}
+
 impl From<RawStCacheOperator> for ast::StCacheOperator {
     fn from(value: RawStCacheOperator) -> Self {
         match value {
@@ -1066,7 +1076,6 @@ derive_parser!(
             arguments: ast::StArgs { src1:a, src2:b }
         }
     }
-
     .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} };
     .level::eviction_priority: EvictionPriority =
         { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate };
@@ -1156,7 +1165,6 @@ derive_parser!(
             arguments: LdArgs { dst:d, src:a }
         }
     }
-
     .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} };
     .cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv };
     .level::eviction_priority: EvictionPriority =
@@ -1199,7 +1207,6 @@ derive_parser!(
             }
         }
     }
-
     .type: ScalarType = { .u16, .u32, .u64,
                           .s16, .s64,
                           .u16x2, .s16x2 };
@@ -1236,7 +1243,6 @@ derive_parser!(
             }
         }
     }
-
     .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
     ScalarType = { .f32, .f64 };
 
@@ -1301,10 +1307,124 @@ derive_parser!(
             }
         }
     }
-
     .rnd: RawFloatRounding = { .rn };
     ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
 
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
+    mul.mode.type d, a, b => {
+        ast::Instruction::Mul {
+            data: ast::MulDetails::Integer {
+                type_,
+                control: mode.into()
+            },
+            arguments: MulArgs { dst: d, src1: a, src2: b }
+        }
+    }
+    .mode: RawMulIntControl = { .hi, .lo };
+    .type: ScalarType = { .u16, .u32, .u64,
+                          .s16, .s32, .s64 };
+    // "The .wide suffix is supported only for 16- and 32-bit integer types"
+    mul.wide.type d, a, b => {
+        ast::Instruction::Mul {
+            data: ast::MulDetails::Integer {
+                type_,
+                control: wide.into()
+            },
+            arguments: MulArgs { dst: d, src1: a, src2: b }
+        }
+    }
+    .type: ScalarType = { .u16, .u32,
+                          .s16, .s32 };
+    RawMulIntControl = { .wide };
+
+
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
+    mul{.rnd}{.ftz}{.sat}.f32 d, a, b => {
+        ast::Instruction::Mul {
+            data: ast::MulDetails::Float (
+                ArithFloat {
+                    type_: f32,
+                    rounding: rnd.map(Into::into),
+                    flush_to_zero: Some(ftz),
+                    saturate: sat,
+                }
+            ),
+            arguments: MulArgs { dst: d, src1: a, src2: b }
+        }
+    }
+    mul{.rnd}.f64 d, a, b => {
+        ast::Instruction::Mul {
+            data: ast::MulDetails::Float (
+                ArithFloat {
+                    type_: f64,
+                    rounding: rnd.map(Into::into),
+                    flush_to_zero: None,
+                    saturate: false,
+                }
+            ),
+            arguments: MulArgs { dst: d, src1: a, src2: b }
+        }
+    }
+    .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
+    ScalarType = { .f32, .f64 };
+
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
+    mul{.rnd}{.ftz}{.sat}.f16 d, a, b => {
+        ast::Instruction::Mul {
+            data: ast::MulDetails::Float (
+                ArithFloat {
+                    type_: f16,
+                    rounding: rnd.map(Into::into),
+                    flush_to_zero: Some(ftz),
+                    saturate: sat,
+                }
+            ),
+            arguments: MulArgs { dst: d, src1: a, src2: b }
+        }
+    }
+    mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
+        ast::Instruction::Mul {
+            data: ast::MulDetails::Float (
+                ArithFloat {
+                    type_: f16x2,
+                    rounding: rnd.map(Into::into),
+                    flush_to_zero: Some(ftz),
+                    saturate: sat,
+                }
+            ),
+            arguments: MulArgs { dst: d, src1: a, src2: b }
+        }
+    }
+    mul{.rnd}.bf16 d, a, b => {
+        ast::Instruction::Mul {
+            data: ast::MulDetails::Float (
+                ArithFloat {
+                    type_: bf16,
+                    rounding: rnd.map(Into::into),
+                    flush_to_zero: None,
+                    saturate: false,
+                }
+            ),
+            arguments: MulArgs { dst: d, src1: a, src2: b }
+        }
+    }
+    mul{.rnd}.bf16x2 d, a, b => {
+        ast::Instruction::Mul {
+            data: ast::MulDetails::Float (
+                ArithFloat {
+                    type_: bf16x2,
+                    rounding: rnd.map(Into::into),
+                    flush_to_zero: None,
+                    saturate: false,
+                }
+            ),
+            arguments: MulArgs { dst: d, src1: a, src2: b }
+        }
+    }
+    .rnd: RawFloatRounding = { .rn };
+    ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
+
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
     ret{.uni} => {
         Instruction::Ret { data: RetData { uniform: uni } }
     }