mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 14:19:57 +00:00
Add mul
This commit is contained in:
parent
522541d5c5
commit
cb64b04f41
2 changed files with 202 additions and 27 deletions
|
@ -1,3 +1,5 @@
|
||||||
|
use std::intrinsics::unreachable;
|
||||||
|
|
||||||
use super::{MemScope, ScalarType, StateSpace, VectorPrefix};
|
use super::{MemScope, ScalarType, StateSpace, VectorPrefix};
|
||||||
use bitflags::bitflags;
|
use bitflags::bitflags;
|
||||||
|
|
||||||
|
@ -8,25 +10,6 @@ pub enum Statement<P: Operand> {
|
||||||
Block(Vec<Statement<P>>),
|
Block(Vec<Statement<P>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct MultiVariable<ID> {
|
|
||||||
pub var: Variable<ID>,
|
|
||||||
pub count: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Variable<ID> {
|
|
||||||
pub align: Option<u32>,
|
|
||||||
pub v_type: Type,
|
|
||||||
pub state_space: StateSpace,
|
|
||||||
pub name: ID,
|
|
||||||
pub array_init: Vec<u8>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct PredAt<ID> {
|
|
||||||
pub not: bool,
|
|
||||||
pub label: ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
gen::generate_instruction_type!(
|
gen::generate_instruction_type!(
|
||||||
pub enum Instruction<T> {
|
pub enum Instruction<T> {
|
||||||
Mov {
|
Mov {
|
||||||
|
@ -68,6 +51,18 @@ gen::generate_instruction_type!(
|
||||||
src2: T,
|
src2: T,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
Mul {
|
||||||
|
type: { data.type_().into() },
|
||||||
|
data: MulDetails,
|
||||||
|
arguments<T>: {
|
||||||
|
dst: {
|
||||||
|
repr: T,
|
||||||
|
type: { data.dst_type().into() },
|
||||||
|
},
|
||||||
|
src1: T,
|
||||||
|
src2: T,
|
||||||
|
}
|
||||||
|
},
|
||||||
Ret {
|
Ret {
|
||||||
data: RetData
|
data: RetData
|
||||||
},
|
},
|
||||||
|
@ -75,6 +70,25 @@ gen::generate_instruction_type!(
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
pub struct MultiVariable<ID> {
|
||||||
|
pub var: Variable<ID>,
|
||||||
|
pub count: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Variable<ID> {
|
||||||
|
pub align: Option<u32>,
|
||||||
|
pub v_type: Type,
|
||||||
|
pub state_space: StateSpace,
|
||||||
|
pub name: ID,
|
||||||
|
pub array_init: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PredAt<ID> {
|
||||||
|
pub not: bool,
|
||||||
|
pub label: ID,
|
||||||
|
}
|
||||||
|
|
||||||
pub trait Visitor<T> {
|
pub trait Visitor<T> {
|
||||||
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
|
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
|
||||||
}
|
}
|
||||||
|
@ -185,7 +199,7 @@ pub enum ArithDetails {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ArithDetails {
|
impl ArithDetails {
|
||||||
pub fn type_(&self) -> super::ScalarType {
|
pub fn type_(&self) -> ScalarType {
|
||||||
match self {
|
match self {
|
||||||
ArithDetails::Integer(t) => t.type_,
|
ArithDetails::Integer(t) => t.type_,
|
||||||
ArithDetails::Float(arith) => arith.type_,
|
ArithDetails::Float(arith) => arith.type_,
|
||||||
|
@ -195,13 +209,13 @@ impl ArithDetails {
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct ArithInteger {
|
pub struct ArithInteger {
|
||||||
pub type_: super::ScalarType,
|
pub type_: ScalarType,
|
||||||
pub saturate: bool,
|
pub saturate: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct ArithFloat {
|
pub struct ArithFloat {
|
||||||
pub type_: super::ScalarType,
|
pub type_: ScalarType,
|
||||||
pub rounding: Option<RoundingMode>,
|
pub rounding: Option<RoundingMode>,
|
||||||
pub flush_to_zero: Option<bool>,
|
pub flush_to_zero: Option<bool>,
|
||||||
pub saturate: bool,
|
pub saturate: bool,
|
||||||
|
@ -292,3 +306,44 @@ pub struct Module<'input> {
|
||||||
pub version: (u8, u8),
|
pub version: (u8, u8),
|
||||||
pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
|
pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub enum MulDetails {
|
||||||
|
Integer {
|
||||||
|
type_: ScalarType,
|
||||||
|
control: MulIntControl,
|
||||||
|
},
|
||||||
|
Float(ArithFloat),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MulDetails {
|
||||||
|
fn type_(&self) -> ScalarType {
|
||||||
|
match self {
|
||||||
|
MulDetails::Integer { type_, .. } => *type_,
|
||||||
|
MulDetails::Float(arith) => arith.type_,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dst_type(&self) -> ScalarType {
|
||||||
|
match self {
|
||||||
|
MulDetails::Integer {
|
||||||
|
type_,
|
||||||
|
control: MulIntControl::Wide,
|
||||||
|
} => match type_ {
|
||||||
|
ScalarType::U16 => ScalarType::U32,
|
||||||
|
ScalarType::S16 => ScalarType::S32,
|
||||||
|
ScalarType::U32 => ScalarType::U64,
|
||||||
|
ScalarType::S32 => ScalarType::S64,
|
||||||
|
_ => unreachable!(),
|
||||||
|
},
|
||||||
|
_ => self.type_(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||||
|
pub enum MulIntControl {
|
||||||
|
Low,
|
||||||
|
High,
|
||||||
|
Wide,
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,16 @@ use winnow::{prelude::*, Stateful};
|
||||||
mod ast;
|
mod ast;
|
||||||
pub use ast::*;
|
pub use ast::*;
|
||||||
|
|
||||||
|
impl From<RawMulIntControl> for ast::MulIntControl {
|
||||||
|
fn from(value: RawMulIntControl) -> Self {
|
||||||
|
match value {
|
||||||
|
RawMulIntControl::Lo => ast::MulIntControl::Low,
|
||||||
|
RawMulIntControl::Hi => ast::MulIntControl::High,
|
||||||
|
RawMulIntControl::Wide => ast::MulIntControl::Wide,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<RawStCacheOperator> for ast::StCacheOperator {
|
impl From<RawStCacheOperator> for ast::StCacheOperator {
|
||||||
fn from(value: RawStCacheOperator) -> Self {
|
fn from(value: RawStCacheOperator) -> Self {
|
||||||
match value {
|
match value {
|
||||||
|
@ -1066,7 +1076,6 @@ derive_parser!(
|
||||||
arguments: ast::StArgs { src1:a, src2:b }
|
arguments: ast::StArgs { src1:a, src2:b }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
.ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} };
|
.ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} };
|
||||||
.level::eviction_priority: EvictionPriority =
|
.level::eviction_priority: EvictionPriority =
|
||||||
{ .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate };
|
{ .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate };
|
||||||
|
@ -1156,7 +1165,6 @@ derive_parser!(
|
||||||
arguments: LdArgs { dst:d, src:a }
|
arguments: LdArgs { dst:d, src:a }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
.ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} };
|
.ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} };
|
||||||
.cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv };
|
.cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv };
|
||||||
.level::eviction_priority: EvictionPriority =
|
.level::eviction_priority: EvictionPriority =
|
||||||
|
@ -1199,7 +1207,6 @@ derive_parser!(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
.type: ScalarType = { .u16, .u32, .u64,
|
.type: ScalarType = { .u16, .u32, .u64,
|
||||||
.s16, .s64,
|
.s16, .s64,
|
||||||
.u16x2, .s16x2 };
|
.u16x2, .s16x2 };
|
||||||
|
@ -1236,7 +1243,6 @@ derive_parser!(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
|
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
|
||||||
ScalarType = { .f32, .f64 };
|
ScalarType = { .f32, .f64 };
|
||||||
|
|
||||||
|
@ -1301,10 +1307,124 @@ derive_parser!(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
.rnd: RawFloatRounding = { .rn };
|
.rnd: RawFloatRounding = { .rn };
|
||||||
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
|
||||||
|
mul.mode.type d, a, b => {
|
||||||
|
ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Integer {
|
||||||
|
type_,
|
||||||
|
control: mode.into()
|
||||||
|
},
|
||||||
|
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.mode: RawMulIntControl = { .hi, .lo };
|
||||||
|
.type: ScalarType = { .u16, .u32, .u64,
|
||||||
|
.s16, .s32, .s64 };
|
||||||
|
// "The .wide suffix is supported only for 16- and 32-bit integer types"
|
||||||
|
mul.wide.type d, a, b => {
|
||||||
|
ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Integer {
|
||||||
|
type_,
|
||||||
|
control: wide.into()
|
||||||
|
},
|
||||||
|
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.type: ScalarType = { .u16, .u32,
|
||||||
|
.s16, .s32 };
|
||||||
|
RawMulIntControl = { .wide };
|
||||||
|
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
|
||||||
|
mul{.rnd}{.ftz}{.sat}.f32 d, a, b => {
|
||||||
|
ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Float (
|
||||||
|
ArithFloat {
|
||||||
|
type_: f32,
|
||||||
|
rounding: rnd.map(Into::into),
|
||||||
|
flush_to_zero: Some(ftz),
|
||||||
|
saturate: sat,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mul{.rnd}.f64 d, a, b => {
|
||||||
|
ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Float (
|
||||||
|
ArithFloat {
|
||||||
|
type_: f64,
|
||||||
|
rounding: rnd.map(Into::into),
|
||||||
|
flush_to_zero: None,
|
||||||
|
saturate: false,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
|
||||||
|
ScalarType = { .f32, .f64 };
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
|
||||||
|
mul{.rnd}{.ftz}{.sat}.f16 d, a, b => {
|
||||||
|
ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Float (
|
||||||
|
ArithFloat {
|
||||||
|
type_: f16,
|
||||||
|
rounding: rnd.map(Into::into),
|
||||||
|
flush_to_zero: Some(ftz),
|
||||||
|
saturate: sat,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
|
||||||
|
ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Float (
|
||||||
|
ArithFloat {
|
||||||
|
type_: f16x2,
|
||||||
|
rounding: rnd.map(Into::into),
|
||||||
|
flush_to_zero: Some(ftz),
|
||||||
|
saturate: sat,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mul{.rnd}.bf16 d, a, b => {
|
||||||
|
ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Float (
|
||||||
|
ArithFloat {
|
||||||
|
type_: bf16,
|
||||||
|
rounding: rnd.map(Into::into),
|
||||||
|
flush_to_zero: None,
|
||||||
|
saturate: false,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mul{.rnd}.bf16x2 d, a, b => {
|
||||||
|
ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Float (
|
||||||
|
ArithFloat {
|
||||||
|
type_: bf16x2,
|
||||||
|
rounding: rnd.map(Into::into),
|
||||||
|
flush_to_zero: None,
|
||||||
|
saturate: false,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.rnd: RawFloatRounding = { .rn };
|
||||||
|
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
|
||||||
ret{.uni} => {
|
ret{.uni} => {
|
||||||
Instruction::Ret { data: RetData { uniform: uni } }
|
Instruction::Ret { data: RetData { uniform: uni } }
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue