Add abs, mad

This commit is contained in:
Andrzej Janik 2024-08-21 02:45:52 +02:00
parent 588d66b236
commit 6cd18bfdb8
3 changed files with 237 additions and 9 deletions

View file

@ -73,7 +73,7 @@ pub struct OpcodeDecl(pub Instruction, pub Arguments);
impl OpcodeDecl {
fn peek(input: syn::parse::ParseStream) -> bool {
Instruction::peek(input)
Instruction::peek(input) && !input.peek2(Token![=])
}
}
@ -106,7 +106,7 @@ impl Parse for CodeBlock {
} else {
return Err(lookahead.error());
};
Ok(Self{special, code})
Ok(Self { special, code })
}
}

View file

@ -200,6 +200,27 @@ gen::generate_instruction_type!(
src: T,
}
},
Abs {
data: AbsDetails,
type: { Type::Scalar(data.type_) },
arguments<T>: {
dst: T,
src: T,
}
},
Mad {
type: { Type::from(data.type_()) },
data: MadDetails,
arguments<T>: {
dst: {
repr: T,
type: { Type::from(data.dst_type()) },
},
src1: T,
src2: T,
src3: T,
}
},
Trap { }
}
);
@ -588,16 +609,14 @@ pub enum MulDetails {
}
impl MulDetails {
#[allow(unused)] // Used by generated code
fn type_(&self) -> ScalarType {
pub fn type_(&self) -> ScalarType {
match self {
MulDetails::Integer { type_, .. } => *type_,
MulDetails::Float(arith) => arith.type_,
}
}
#[allow(unused)] // Used by generated code
fn dst_type(&self) -> ScalarType {
pub fn dst_type(&self) -> ScalarType {
match self {
MulDetails::Integer {
type_,
@ -995,3 +1014,45 @@ pub enum CvtaDirection {
GenericToExplicit,
ExplicitToGeneric,
}
#[derive(Copy, Clone)]
pub struct AbsDetails {
pub flush_to_zero: Option<bool>,
pub type_: ScalarType,
}
#[derive(Copy, Clone)]
pub enum MadDetails {
Integer {
control: MulIntControl,
saturate: bool,
type_: ScalarType,
},
Float(ArithFloat),
}
impl MadDetails {
pub fn dst_type(&self) -> ScalarType {
match self {
MadDetails::Integer {
type_,
control: MulIntControl::Wide,
..
} => match type_ {
ScalarType::U16 => ScalarType::U32,
ScalarType::S16 => ScalarType::S32,
ScalarType::U32 => ScalarType::U64,
ScalarType::S32 => ScalarType::S64,
_ => unreachable!(),
},
_ => self.type_(),
}
}
fn type_(&self) -> ScalarType {
match self {
MadDetails::Integer { type_, .. } => *type_,
MadDetails::Float(arith) => arith.type_,
}
}
}

View file

@ -1450,6 +1450,8 @@ derive_parser!(
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
mul.mode.type d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Integer {
@ -1476,8 +1478,6 @@ derive_parser!(
.s16, .s32 };
RawMulIntControl = { .wide };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
mul{.rnd}{.ftz}{.sat}.f32 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
@ -1507,7 +1507,6 @@ derive_parser!(
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
mul{.rnd}{.ftz}{.sat}.f16 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
@ -1706,6 +1705,174 @@ derive_parser!(
.space: StateSpace = { .const, .global, .local, .shared{::cta, ::cluster}, .param{::entry} };
.size: ScalarType = { .u32, .u64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-abs
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs
abs.type d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: None,
type_
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs{.ftz}.f32 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: Some(ftz),
type_: f32
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs.f64 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: None,
type_: f64
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs{.ftz}.f16 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: Some(ftz),
type_: f16
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs{.ftz}.f16x2 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: Some(ftz),
type_: f16x2
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs.bf16 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: None,
type_: bf16
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs.bf16x2 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: None,
type_: bf16x2
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
.type: ScalarType = { .s16, .s32, .s64 };
ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
mad.mode.type d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Integer {
type_,
control: mode.into(),
saturate: false
},
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
.type: ScalarType = { .u16, .u32, .u64,
.s16, .s32, .s64 };
.mode: RawMulIntControl = { .hi, .lo };
// The .wide suffix is supported only for 16-bit and 32-bit integer types.
mad.wide.type d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Integer {
type_,
control: wide.into(),
saturate: false
},
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
.type: ScalarType = { .u16, .u32,
.s16, .s32 };
RawMulIntControl = { .wide };
mad.hi.sat.s32 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Integer {
type_: s32,
control: hi.into(),
saturate: true
},
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
RawMulIntControl = { .hi };
ScalarType = { .s32 };
mad{.ftz}{.sat}.f32 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Float(
ArithFloat {
type_: f32,
rounding: None,
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
mad.rnd{.ftz}{.sat}.f32 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Float(
ArithFloat {
type_: f32,
rounding: Some(rnd.into()),
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
mad.rnd.f64 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Float(
ArithFloat {
type_: f64,
rounding: Some(rnd.into()),
flush_to_zero: None,
saturate: false
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}}
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }