mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add mul
This commit is contained in:
parent
522541d5c5
commit
cb64b04f41
2 changed files with 202 additions and 27 deletions
|
@ -1,3 +1,5 @@
|
|||
use std::intrinsics::unreachable;
|
||||
|
||||
use super::{MemScope, ScalarType, StateSpace, VectorPrefix};
|
||||
use bitflags::bitflags;
|
||||
|
||||
|
@ -8,25 +10,6 @@ pub enum Statement<P: Operand> {
|
|||
Block(Vec<Statement<P>>),
|
||||
}
|
||||
|
||||
pub struct MultiVariable<ID> {
|
||||
pub var: Variable<ID>,
|
||||
pub count: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Variable<ID> {
|
||||
pub align: Option<u32>,
|
||||
pub v_type: Type,
|
||||
pub state_space: StateSpace,
|
||||
pub name: ID,
|
||||
pub array_init: Vec<u8>,
|
||||
}
|
||||
|
||||
pub struct PredAt<ID> {
|
||||
pub not: bool,
|
||||
pub label: ID,
|
||||
}
|
||||
|
||||
gen::generate_instruction_type!(
|
||||
pub enum Instruction<T> {
|
||||
Mov {
|
||||
|
@ -68,6 +51,18 @@ gen::generate_instruction_type!(
|
|||
src2: T,
|
||||
}
|
||||
},
|
||||
Mul {
|
||||
type: { data.type_().into() },
|
||||
data: MulDetails,
|
||||
arguments<T>: {
|
||||
dst: {
|
||||
repr: T,
|
||||
type: { data.dst_type().into() },
|
||||
},
|
||||
src1: T,
|
||||
src2: T,
|
||||
}
|
||||
},
|
||||
Ret {
|
||||
data: RetData
|
||||
},
|
||||
|
@ -75,6 +70,25 @@ gen::generate_instruction_type!(
|
|||
}
|
||||
);
|
||||
|
||||
pub struct MultiVariable<ID> {
|
||||
pub var: Variable<ID>,
|
||||
pub count: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Variable<ID> {
|
||||
pub align: Option<u32>,
|
||||
pub v_type: Type,
|
||||
pub state_space: StateSpace,
|
||||
pub name: ID,
|
||||
pub array_init: Vec<u8>,
|
||||
}
|
||||
|
||||
/// A reference to a predicate identifier together with a negation flag.
///
/// NOTE(review): presumably models the `@p` / `@!p` guard in front of an
/// instruction — confirm against the parser.
pub struct PredAt<ID> {
    /// Negation flag for the referenced predicate.
    pub not: bool,
    /// Identifier being referenced.
    pub label: ID,
}
|
||||
|
||||
pub trait Visitor<T> {
|
||||
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
|
||||
}
|
||||
|
@ -185,7 +199,7 @@ pub enum ArithDetails {
|
|||
}
|
||||
|
||||
impl ArithDetails {
|
||||
pub fn type_(&self) -> super::ScalarType {
|
||||
pub fn type_(&self) -> ScalarType {
|
||||
match self {
|
||||
ArithDetails::Integer(t) => t.type_,
|
||||
ArithDetails::Float(arith) => arith.type_,
|
||||
|
@ -195,13 +209,13 @@ impl ArithDetails {
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ArithInteger {
|
||||
pub type_: super::ScalarType,
|
||||
pub type_: ScalarType,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ArithFloat {
|
||||
pub type_: super::ScalarType,
|
||||
pub type_: ScalarType,
|
||||
pub rounding: Option<RoundingMode>,
|
||||
pub flush_to_zero: Option<bool>,
|
||||
pub saturate: bool,
|
||||
|
@ -292,3 +306,44 @@ pub struct Module<'input> {
|
|||
pub version: (u8, u8),
|
||||
pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum MulDetails {
|
||||
Integer {
|
||||
type_: ScalarType,
|
||||
control: MulIntControl,
|
||||
},
|
||||
Float(ArithFloat),
|
||||
}
|
||||
|
||||
impl MulDetails {
|
||||
fn type_(&self) -> ScalarType {
|
||||
match self {
|
||||
MulDetails::Integer { type_, .. } => *type_,
|
||||
MulDetails::Float(arith) => arith.type_,
|
||||
}
|
||||
}
|
||||
|
||||
fn dst_type(&self) -> ScalarType {
|
||||
match self {
|
||||
MulDetails::Integer {
|
||||
type_,
|
||||
control: MulIntControl::Wide,
|
||||
} => match type_ {
|
||||
ScalarType::U16 => ScalarType::U32,
|
||||
ScalarType::S16 => ScalarType::S32,
|
||||
ScalarType::U32 => ScalarType::U64,
|
||||
ScalarType::S32 => ScalarType::S64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
_ => self.type_(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Mode of an integer `mul`, as selected by the instruction suffix.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum MulIntControl {
    /// Selected by the `.lo` suffix.
    Low,
    /// Selected by the `.hi` suffix.
    High,
    /// Selected by the `.wide` suffix; the destination type is twice as wide
    /// as the operand type.
    Wide,
}
|
||||
|
|
|
@ -16,6 +16,16 @@ use winnow::{prelude::*, Stateful};
|
|||
mod ast;
|
||||
pub use ast::*;
|
||||
|
||||
impl From<RawMulIntControl> for ast::MulIntControl {
|
||||
fn from(value: RawMulIntControl) -> Self {
|
||||
match value {
|
||||
RawMulIntControl::Lo => ast::MulIntControl::Low,
|
||||
RawMulIntControl::Hi => ast::MulIntControl::High,
|
||||
RawMulIntControl::Wide => ast::MulIntControl::Wide,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RawStCacheOperator> for ast::StCacheOperator {
|
||||
fn from(value: RawStCacheOperator) -> Self {
|
||||
match value {
|
||||
|
@ -1066,7 +1076,6 @@ derive_parser!(
|
|||
arguments: ast::StArgs { src1:a, src2:b }
|
||||
}
|
||||
}
|
||||
|
||||
.ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} };
|
||||
.level::eviction_priority: EvictionPriority =
|
||||
{ .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate };
|
||||
|
@ -1156,7 +1165,6 @@ derive_parser!(
|
|||
arguments: LdArgs { dst:d, src:a }
|
||||
}
|
||||
}
|
||||
|
||||
.ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} };
|
||||
.cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv };
|
||||
.level::eviction_priority: EvictionPriority =
|
||||
|
@ -1199,7 +1207,6 @@ derive_parser!(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
.type: ScalarType = { .u16, .u32, .u64,
|
||||
.s16, .s64,
|
||||
.u16x2, .s16x2 };
|
||||
|
@ -1236,7 +1243,6 @@ derive_parser!(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
|
||||
ScalarType = { .f32, .f64 };
|
||||
|
||||
|
@ -1301,10 +1307,124 @@ derive_parser!(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
.rnd: RawFloatRounding = { .rn };
|
||||
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
|
||||
mul.mode.type d, a, b => {
|
||||
ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Integer {
|
||||
type_,
|
||||
control: mode.into()
|
||||
},
|
||||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
.mode: RawMulIntControl = { .hi, .lo };
|
||||
.type: ScalarType = { .u16, .u32, .u64,
|
||||
.s16, .s32, .s64 };
|
||||
// "The .wide suffix is supported only for 16- and 32-bit integer types"
|
||||
mul.wide.type d, a, b => {
|
||||
ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Integer {
|
||||
type_,
|
||||
control: wide.into()
|
||||
},
|
||||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
.type: ScalarType = { .u16, .u32,
|
||||
.s16, .s32 };
|
||||
RawMulIntControl = { .wide };
|
||||
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
|
||||
mul{.rnd}{.ftz}{.sat}.f32 d, a, b => {
|
||||
ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Float (
|
||||
ArithFloat {
|
||||
type_: f32,
|
||||
rounding: rnd.map(Into::into),
|
||||
flush_to_zero: Some(ftz),
|
||||
saturate: sat,
|
||||
}
|
||||
),
|
||||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
mul{.rnd}.f64 d, a, b => {
|
||||
ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Float (
|
||||
ArithFloat {
|
||||
type_: f64,
|
||||
rounding: rnd.map(Into::into),
|
||||
flush_to_zero: None,
|
||||
saturate: false,
|
||||
}
|
||||
),
|
||||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
|
||||
ScalarType = { .f32, .f64 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
|
||||
mul{.rnd}{.ftz}{.sat}.f16 d, a, b => {
|
||||
ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Float (
|
||||
ArithFloat {
|
||||
type_: f16,
|
||||
rounding: rnd.map(Into::into),
|
||||
flush_to_zero: Some(ftz),
|
||||
saturate: sat,
|
||||
}
|
||||
),
|
||||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
|
||||
ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Float (
|
||||
ArithFloat {
|
||||
type_: f16x2,
|
||||
rounding: rnd.map(Into::into),
|
||||
flush_to_zero: Some(ftz),
|
||||
saturate: sat,
|
||||
}
|
||||
),
|
||||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
mul{.rnd}.bf16 d, a, b => {
|
||||
ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Float (
|
||||
ArithFloat {
|
||||
type_: bf16,
|
||||
rounding: rnd.map(Into::into),
|
||||
flush_to_zero: None,
|
||||
saturate: false,
|
||||
}
|
||||
),
|
||||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
mul{.rnd}.bf16x2 d, a, b => {
|
||||
ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Float (
|
||||
ArithFloat {
|
||||
type_: bf16x2,
|
||||
rounding: rnd.map(Into::into),
|
||||
flush_to_zero: None,
|
||||
saturate: false,
|
||||
}
|
||||
),
|
||||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
.rnd: RawFloatRounding = { .rn };
|
||||
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
|
||||
ret{.uni} => {
|
||||
Instruction::Ret { data: RetData { uniform: uni } }
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue