This commit is contained in:
Andrzej Janik 2024-08-18 23:27:07 +02:00
parent 522541d5c5
commit cb64b04f41
2 changed files with 202 additions and 27 deletions

View file

@ -1,3 +1,5 @@
use std::intrinsics::unreachable;
use super::{MemScope, ScalarType, StateSpace, VectorPrefix};
use bitflags::bitflags;
@ -8,25 +10,6 @@ pub enum Statement<P: Operand> {
Block(Vec<Statement<P>>),
}
/// A [`Variable`] declaration paired with an optional count.
// NOTE(review): an identical definition appears again after the
// `generate_instruction_type!` invocation in this view; this copy looks like
// the removed side of a moved block — confirm only one exists in the real file.
pub struct MultiVariable<ID> {
/// The declared variable.
pub var: Variable<ID>,
/// Optional count attached to the declaration.
// NOTE(review): presumably the `N` of a parameterized PTX declaration such as
// `.reg .b32 r<N>` — confirm against the parser.
pub count: Option<u32>,
}
/// A single variable declaration, generic over the identifier type `ID`.
// NOTE(review): an identical definition appears again after the
// `generate_instruction_type!` invocation in this view; this copy looks like
// the removed side of a moved block — confirm only one exists in the real file.
#[derive(Clone)]
pub struct Variable<ID> {
/// Explicit alignment, if one was given.
pub align: Option<u32>,
/// Type of the variable.
pub v_type: Type,
/// State space the variable is declared in.
pub state_space: StateSpace,
/// Identifier of the variable.
pub name: ID,
/// Raw bytes associated with the declaration.
// NOTE(review): presumably the serialized initializer data — confirm.
pub array_init: Vec<u8>,
}
/// A reference to `label`, optionally negated.
// NOTE(review): presumably the `@p` / `@!p` guard predicate of a PTX
// instruction — confirm against the parser.
// NOTE(review): an identical definition appears again after the
// `generate_instruction_type!` invocation in this view; this copy looks like
// the removed side of a moved block — confirm only one exists in the real file.
pub struct PredAt<ID> {
/// Negation flag.
pub not: bool,
/// Identifier being referenced.
pub label: ID,
}
gen::generate_instruction_type!(
pub enum Instruction<T> {
Mov {
@ -68,6 +51,18 @@ gen::generate_instruction_type!(
src2: T,
}
},
// `mul`: one destination (`dst`) and two sources (`src1`, `src2`).
// The instruction-wide `type` is taken from `MulDetails::type_()`; `dst`
// overrides it with `MulDetails::dst_type()`, which differs from `type_()`
// only for the `Wide` integer mode.
// NOTE(review): how the macro consumes `type`/`repr` is not visible here —
// presumably it drives `Visitor::visit`; confirm in the `gen` crate.
Mul {
type: { data.type_().into() },
data: MulDetails,
arguments<T>: {
dst: {
repr: T,
type: { data.dst_type().into() },
},
src1: T,
src2: T,
}
},
Ret {
data: RetData
},
@ -75,6 +70,25 @@ gen::generate_instruction_type!(
}
);
/// A [`Variable`] declaration paired with an optional count.
pub struct MultiVariable<ID> {
/// The declared variable.
pub var: Variable<ID>,
/// Optional count attached to the declaration.
// NOTE(review): presumably the `N` of a parameterized PTX declaration such as
// `.reg .b32 r<N>` — confirm against the parser.
pub count: Option<u32>,
}
/// A single variable declaration, generic over the identifier type `ID`.
#[derive(Clone)]
pub struct Variable<ID> {
/// Explicit alignment, if one was given.
pub align: Option<u32>,
/// Type of the variable.
pub v_type: Type,
/// State space the variable is declared in.
pub state_space: StateSpace,
/// Identifier of the variable.
pub name: ID,
/// Raw bytes associated with the declaration.
// NOTE(review): presumably the serialized initializer data — confirm.
pub array_init: Vec<u8>,
}
/// A reference to `label`, optionally negated.
// NOTE(review): presumably the `@p` / `@!p` guard predicate of a PTX
// instruction — confirm against the parser.
pub struct PredAt<ID> {
/// Negation flag.
pub not: bool,
/// Identifier being referenced.
pub label: ID,
}
/// Callback interface over values of type `T`.
// NOTE(review): presumably driven, once per instruction operand, by the code
// that `generate_instruction_type!` produces — confirm in the `gen` crate.
pub trait Visitor<T> {
/// Visits `args` together with its `type_`, its state `space`, and the
/// `is_dst` flag.
// NOTE(review): `is_dst` is presumably `true` when the operand is written
// rather than read — confirm.
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
}
@ -185,7 +199,7 @@ pub enum ArithDetails {
}
impl ArithDetails {
pub fn type_(&self) -> super::ScalarType {
pub fn type_(&self) -> ScalarType {
match self {
ArithDetails::Integer(t) => t.type_,
ArithDetails::Float(arith) => arith.type_,
@ -195,13 +209,13 @@ impl ArithDetails {
/// Modifier data of an integer arithmetic instruction.
// NOTE(review): the two `type_` lines below look like the before/after sides
// of a diff (`super::ScalarType` -> `ScalarType`). A struct cannot declare the
// same field twice — confirm only one of them exists in the real file.
#[derive(Copy, Clone)]
pub struct ArithInteger {
/// Scalar type of the operation.
pub type_: super::ScalarType,
pub type_: ScalarType,
/// Saturation flag.
// NOTE(review): presumably set by the `.sat` modifier — confirm.
pub saturate: bool,
}
#[derive(Copy, Clone)]
pub struct ArithFloat {
pub type_: super::ScalarType,
pub type_: ScalarType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: Option<bool>,
pub saturate: bool,
@ -292,3 +306,44 @@ pub struct Module<'input> {
pub version: (u8, u8),
pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
}
/// Modifier data attached to a `mul` instruction.
#[derive(Copy, Clone)]
pub enum MulDetails {
/// Integer multiplication.
Integer {
/// Scalar type named in the instruction.
type_: ScalarType,
/// Result-selection mode; see [`MulIntControl`].
control: MulIntControl,
},
/// Floating-point multiplication, described by the same modifier set as
/// other float arithmetic ([`ArithFloat`]).
Float(ArithFloat),
}
impl MulDetails {
    /// Returns the scalar type carried by the instruction.
    ///
    /// For the integer form this is the `type_` field; for the float form it
    /// is the `type_` of the wrapped [`ArithFloat`].
    pub fn type_(&self) -> ScalarType {
        match *self {
            MulDetails::Integer { type_, .. } => type_,
            MulDetails::Float(arith) => arith.type_,
        }
    }

    /// Returns the scalar type of the destination operand.
    ///
    /// This equals [`MulDetails::type_`] except for [`MulIntControl::Wide`],
    /// where the destination is twice as wide as the instruction type
    /// (`u16 -> u32`, `s16 -> s32`, `u32 -> u64`, `s32 -> s64`).
    ///
    /// # Panics
    ///
    /// Panics if `Wide` is combined with any other scalar type. The parser
    /// only accepts those four types for `mul.wide`, so hitting this arm means
    /// the value was constructed by hand with an invalid combination.
    pub fn dst_type(&self) -> ScalarType {
        match *self {
            MulDetails::Integer {
                type_,
                control: MulIntControl::Wide,
            } => match type_ {
                ScalarType::U16 => ScalarType::U32,
                ScalarType::S16 => ScalarType::S32,
                ScalarType::U32 => ScalarType::U64,
                ScalarType::S32 => ScalarType::S64,
                _ => unreachable!("mul.wide is only defined for 16- and 32-bit integer types"),
            },
            // Spelled out instead of `_` so that adding a new control mode or
            // a new `MulDetails` variant forces this function to be revisited.
            MulDetails::Integer {
                type_,
                control: MulIntControl::Low | MulIntControl::High,
            } => type_,
            MulDetails::Float(arith) => arith.type_,
        }
    }
}
/// Result-selection mode of an integer `mul`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum MulIntControl {
/// Built from the `.lo` modifier.
// NOTE(review): presumably keeps the low half of the product, per the PTX
// `mul` documentation — confirm.
Low,
/// Built from the `.hi` modifier.
// NOTE(review): presumably keeps the high half of the product, per the PTX
// `mul` documentation — confirm.
High,
/// Built from the `.wide` modifier; the destination type is twice as wide as
/// the instruction type (see `MulDetails::dst_type`).
Wide,
}

View file

@ -16,6 +16,16 @@ use winnow::{prelude::*, Stateful};
mod ast;
pub use ast::*;
/// Converts the parser-level multiplication mode into its AST counterpart.
impl From<RawMulIntControl> for ast::MulIntControl {
    /// Maps each raw modifier onto the AST variant of the same meaning:
    /// `Lo` to `Low`, `Hi` to `High`, and `Wide` to `Wide`.
    fn from(value: RawMulIntControl) -> Self {
        match value {
            RawMulIntControl::Wide => Self::Wide,
            RawMulIntControl::Hi => Self::High,
            RawMulIntControl::Lo => Self::Low,
        }
    }
}
impl From<RawStCacheOperator> for ast::StCacheOperator {
fn from(value: RawStCacheOperator) -> Self {
match value {
@ -1066,7 +1076,6 @@ derive_parser!(
arguments: ast::StArgs { src1:a, src2:b }
}
}
.ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} };
.level::eviction_priority: EvictionPriority =
{ .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate };
@ -1156,7 +1165,6 @@ derive_parser!(
arguments: LdArgs { dst:d, src:a }
}
}
.ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} };
.cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv };
.level::eviction_priority: EvictionPriority =
@ -1199,7 +1207,6 @@ derive_parser!(
}
}
}
.type: ScalarType = { .u16, .u32, .u64,
.s16, .s64,
.u16x2, .s16x2 };
@ -1236,7 +1243,6 @@ derive_parser!(
}
}
}
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };
@ -1301,10 +1307,124 @@ derive_parser!(
}
}
}
.rnd: RawFloatRounding = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
// Integer `mul` in `.hi` / `.lo` mode. `mode` is turned into an
// `ast::MulIntControl` through `From<RawMulIntControl>`; the `.wide` form has
// its own rule.
mul.mode.type d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Integer {
type_,
control: mode.into()
},
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
.mode: RawMulIntControl = { .hi, .lo };
.type: ScalarType = { .u16, .u32, .u64,
.s16, .s32, .s64 };
// "The .wide suffix is supported only for 16- and 32-bit integer types"
// Restricting `.type` to the four 16-/32-bit integer types below is what keeps
// the `unreachable!` arm of `MulDetails::dst_type` from being hit for parsed
// input.
mul.wide.type d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Integer {
type_,
control: wide.into()
},
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
.type: ScalarType = { .u16, .u32,
.s16, .s32 };
// NOTE(review): single-entry set; presumably this is what gives `wide` the
// type `RawMulIntControl` so that `wide.into()` above resolves — confirm
// against the `derive_parser!` implementation.
RawMulIntControl = { .wide };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// Single-precision `mul`: the optional `.ftz` and `.sat` flags are recorded
// as parsed (`flush_to_zero` is always `Some`).
mul{.rnd}{.ftz}{.sat}.f32 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
saturate: sat,
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
// Double-precision `mul`: the rule accepts neither `.ftz` nor `.sat`, so
// `flush_to_zero` is `None` and `saturate` is always `false`.
mul{.rnd}.f64 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
flush_to_zero: None,
saturate: false,
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
// Half-precision `mul`. The `.f16` and `.f16x2` rules accept the optional
// `.ftz` / `.sat` flags and record them as parsed.
mul{.rnd}{.ftz}{.sat}.f16 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
saturate: sat,
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
saturate: sat,
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
// The `.bf16` and `.bf16x2` rules accept neither `.ftz` nor `.sat`, so
// `flush_to_zero` is `None` and `saturate` is always `false`.
mul{.rnd}.bf16 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
flush_to_zero: None,
saturate: false,
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
mul{.rnd}.bf16x2 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
flush_to_zero: None,
saturate: false,
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
// Only round-to-nearest (`.rn`) is accepted for the half-precision forms.
.rnd: RawFloatRounding = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }
}