Add rcp, sqrt, rsqrt

This commit is contained in:
Andrzej Janik 2024-08-21 03:38:43 +02:00
parent fc713f2930
commit c16bae32b5
2 changed files with 151 additions and 0 deletions

View file

@ -258,6 +258,30 @@ gen::generate_instruction_type!(
src2: T,
}
},
// Reciprocal (`rcp`) instruction entry.
// Its type is computed from the `type_` field of the attached `RcpData`;
// it takes one destination operand and one source operand.
Rcp {
type: { Type::from(data.type_) },
data: RcpData,
arguments<T>: {
dst: T,
src: T,
}
},
// Square root (`sqrt`) instruction entry.
// Reuses `RcpData` as its payload (same fields as `rcp`: kind, flush-to-zero, type);
// its type is computed from `data.type_`. One destination, one source operand.
Sqrt {
type: { Type::from(data.type_) },
data: RcpData,
arguments<T>: {
dst: T,
src: T,
}
},
// Reciprocal square root (`rsqrt`) instruction entry.
// Carries `RsqrtData` (flush-to-zero flag and scalar type, no kind/rounding field);
// its type is computed from `data.type_`. One destination, one source operand.
Rsqrt {
type: { Type::from(data.type_) },
data: RsqrtData,
arguments<T>: {
dst: T,
src: T,
}
},
Trap { }
}
);
@ -1117,3 +1141,29 @@ pub struct MinMaxFloat {
pub nan: bool,
pub type_: ScalarType,
}
/// Selects which flavour of floating-point division an instruction uses.
///
/// NOTE(review): presumably mirrors the PTX `div.approx` / `div.full` /
/// `div.rnd` forms — confirm against the `div` parser rules.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum DivFloatKind {
    /// Approximate variant.
    Approx,
    /// Full-range variant without an explicit rounding mode.
    Full,
    /// Variant that carries an explicit rounding mode.
    Rounding(RoundingMode),
}
/// Modifier payload attached to the `Rcp` and `Sqrt` instructions.
#[derive(Clone, Copy)]
pub struct RcpData {
    /// Approximate form, or fully rounded form with its rounding mode.
    pub kind: RcpKind,
    /// Flush-to-zero flag as recorded by the parser; `None` when the parsed
    /// form has no `.ftz` modifier slot (the `.rnd.f64` forms).
    pub flush_to_zero: Option<bool>,
    /// Scalar type the instruction operates on; also used to derive the
    /// instruction's overall type.
    pub type_: ScalarType,
}
/// Precision flavour shared by the `rcp` and `sqrt` parser rules.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum RcpKind {
    /// Produced by the `.approx` forms.
    Approx,
    /// Produced by the `.rnd` forms; holds the requested rounding mode.
    Full(RoundingMode),
}
/// Modifier payload attached to the `Rsqrt` instruction.
#[derive(Clone, Copy)]
pub struct RsqrtData {
    /// Flush-to-zero flag as recorded by the parser; `None` where the parser
    /// records no `.ftz` information for the parsed form.
    pub flush_to_zero: Option<bool>,
    /// Scalar type the instruction operates on; also used to derive the
    /// instruction's overall type.
    pub type_: ScalarType,
}

View file

@ -2244,6 +2244,107 @@ derive_parser!(
}
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp-approx-ftz-f64
// Approximate reciprocal for either scalar type accepted by `.type`.
// The optional `.ftz` flag is always recorded as `Some(..)`, including for f64.
rcp.approx{.ftz}.type d, a => {
ast::Instruction::Rcp {
data: ast::RcpData {
kind: ast::RcpKind::Approx,
flush_to_zero: Some(ftz),
type_
},
arguments: RcpArgs { dst: d, src: a }
}
}
// Rounded f32 reciprocal: the parsed `.rnd` modifier is converted (via `into()`)
// and wrapped in `RcpKind::Full`; the optional `.ftz` flag is recorded as `Some(..)`.
rcp.rnd{.ftz}.f32 d, a => {
ast::Instruction::Rcp {
data: ast::RcpData {
kind: ast::RcpKind::Full(rnd.into()),
flush_to_zero: Some(ftz),
type_: f32
},
arguments: RcpArgs { dst: d, src: a }
}
}
// Rounded f64 reciprocal: this pattern has no `.ftz` slot, so
// `flush_to_zero` is stored as `None`.
rcp.rnd.f64 d, a => {
ast::Instruction::Rcp {
data: ast::RcpData {
kind: ast::RcpKind::Full(rnd.into()),
flush_to_zero: None,
type_: f64
},
arguments: RcpArgs { dst: d, src: a }
}
}
// Modifier value sets for the `rcp` rules: `.type` is `.f32` or `.f64`,
// `.rnd` is one of the four rounding suffixes listed below.
.type: ScalarType = { .f32, .f64 };
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
// NOTE(review): presumably declares the `ScalarType` tokens these rules spell
// literally (`.f32`, `.f64`) — confirm against the `derive_parser!` grammar.
ScalarType = { .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt
// Approximate f32 square root. Builds the `Sqrt` instruction with the shared
// `RcpData` payload; the optional `.ftz` flag is recorded as `Some(..)`.
sqrt.approx{.ftz}.f32 d, a => {
ast::Instruction::Sqrt {
data: ast::RcpData {
kind: ast::RcpKind::Approx,
flush_to_zero: Some(ftz),
type_: f32
},
arguments: SqrtArgs { dst: d, src: a }
}
}
// Rounded f32 square root: the `.rnd` modifier is converted (via `into()`) and
// wrapped in `RcpKind::Full`; the optional `.ftz` flag is recorded as `Some(..)`.
sqrt.rnd{.ftz}.f32 d, a => {
ast::Instruction::Sqrt {
data: ast::RcpData {
kind: ast::RcpKind::Full(rnd.into()),
flush_to_zero: Some(ftz),
type_: f32
},
arguments: SqrtArgs { dst: d, src: a }
}
}
// Rounded f64 square root: this pattern has no `.ftz` slot, so
// `flush_to_zero` is stored as `None`.
sqrt.rnd.f64 d, a => {
ast::Instruction::Sqrt {
data: ast::RcpData {
kind: ast::RcpKind::Full(rnd.into()),
flush_to_zero: None,
type_: f64
},
arguments: SqrtArgs { dst: d, src: a }
}
}
// Modifier value sets for the `sqrt` rules: `.rnd` is one of the four
// rounding suffixes listed below.
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
// NOTE(review): presumably declares the `ScalarType` tokens these rules spell
// literally (`.f32`, `.f64`) — confirm against the `derive_parser!` grammar.
ScalarType = { .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64
// Approximate f32 reciprocal square root; the optional `.ftz` flag is
// recorded as `Some(..)`.
rsqrt.approx{.ftz}.f32 d, a => {
ast::Instruction::Rsqrt {
data: ast::RsqrtData {
flush_to_zero: Some(ftz),
type_: f32
},
arguments: RsqrtArgs { dst: d, src: a }
}
}
// Approximate f64 reciprocal square root without `.ftz`: no flush-to-zero
// information is available, so the field is `None`.
rsqrt.approx.f64 d, a => {
ast::Instruction::Rsqrt {
data: ast::RsqrtData {
flush_to_zero: None,
type_: f64
},
arguments: RsqrtArgs { dst: d, src: a }
}
}
// Approximate f64 reciprocal square root with an explicit `.ftz` suffix.
// `.ftz` is part of the matched opcode here, so it is recorded as `Some(true)`
// rather than dropped; this keeps the form distinguishable from plain
// `rsqrt.approx.f64` and matches how `rcp.approx{.ftz}.type` records `.ftz`
// for f64.
rsqrt.approx.ftz.f64 d, a => {
    ast::Instruction::Rsqrt {
        data: ast::RsqrtData {
            flush_to_zero: Some(true),
            type_: f64
        },
        arguments: RsqrtArgs { dst: d, src: a }
    }
}
// NOTE(review): presumably declares the `ScalarType` tokens the `rsqrt` rules
// spell literally (`.f32`, `.f64`) — confirm against the `derive_parser!` grammar.
ScalarType = { .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }