This commit is contained in:
Andrzej Janik 2024-08-20 17:59:39 +02:00
parent c21c55dfc2
commit bc1074ed67
3 changed files with 228 additions and 12 deletions

View file

@ -847,7 +847,11 @@ mod tests {
assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap()));
assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap()));
assert_eq!("LdDetails", to_string(variant.data.unwrap()));
let arguments = variant.arguments.unwrap();
let arguments = if let Some(Arguments::Def(a)) = variant.arguments {
a
} else {
panic!()
};
assert_eq!("P", to_string(arguments.generic));
let mut fields = arguments.fields.into_iter();
let dst = fields.next().unwrap();

View file

@ -1,4 +1,9 @@
use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix};
use std::cmp::Ordering;
use super::{
MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace,
VectorPrefix,
};
use crate::{PtxError, PtxParserState};
use bitflags::bitflags;
@ -147,6 +152,19 @@ gen::generate_instruction_type!(
visit_mut: arguments.visit_mut(data, visitor),
map: Instruction::Call{ arguments: arguments.map(&data, visitor), data }
},
Cvt {
data: CvtDetails,
arguments<T>: {
dst: {
repr: T,
type: { Type::Scalar(data.to) },
},
src: {
repr: T,
type: { Type::Scalar(data.from) },
},
}
},
Ret {
data: RetData
},
@ -284,6 +302,28 @@ impl Type {
}
impl ScalarType {
pub fn size_of(self) -> u8 {
match self {
ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1,
ScalarType::U16
| ScalarType::S16
| ScalarType::B16
| ScalarType::F16
| ScalarType::BF16 => 2,
ScalarType::U32
| ScalarType::S32
| ScalarType::B32
| ScalarType::F32
| ScalarType::U16x2
| ScalarType::S16x2
| ScalarType::F16x2
| ScalarType::BF16x2 => 4,
ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8,
ScalarType::B128 => 16,
ScalarType::Pred => 1,
}
}
pub fn kind(self) -> ScalarKind {
match self {
ScalarType::U8 => ScalarKind::Unsigned,
@ -758,3 +798,148 @@ impl<T: Operand> CallArgs<T> {
}
}
}
pub struct CvtDetails {
from: ScalarType,
to: ScalarType,
mode: CvtMode,
}
pub enum CvtMode {
// int from int
ZeroExtend,
SignExtend,
Truncate,
Bitcast,
// float from float
FPExtend {
flush_to_zero: Option<bool>,
},
FPTruncate {
// float rounding
rounding: RoundingMode,
flush_to_zero: Option<bool>,
},
FPRound {
integer_rounding: Option<RoundingMode>,
flush_to_zero: Option<bool>,
},
// int from float
SignedFromFP {
rounding: RoundingMode,
flush_to_zero: Option<bool>,
}, // integer rounding
UnsignedFromFP {
rounding: RoundingMode,
flush_to_zero: Option<bool>,
}, // integer rounding
// float from int, ftz is allowed in the grammar, but clearly nonsensical
FPFromSigned(RoundingMode), // float rounding
FPFromUnsigned(RoundingMode), // float rounding
}
impl CvtDetails {
pub(crate) fn new(
errors: &mut Vec<PtxError>,
rnd: Option<RawRoundingMode>,
ftz: bool,
saturate: bool,
dst: ScalarType,
src: ScalarType,
) -> Self {
if saturate {
errors.push(PtxError::Todo);
}
// Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results.
let flush_to_zero = match (dst, src) {
(ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz),
_ => {
if ftz {
errors.push(PtxError::NonF32Ftz);
}
None
}
};
let rounding = rnd.map(Into::into);
let mut unwrap_rounding = || match rounding {
Some(rnd) => rnd,
None => {
errors.push(PtxError::SyntaxError);
RoundingMode::NearestEven
}
};
let mode = match (dst.kind(), src.kind()) {
(ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) {
Ordering::Less => CvtMode::FPTruncate {
rounding: unwrap_rounding(),
flush_to_zero,
},
Ordering::Equal => CvtMode::FPRound {
integer_rounding: rounding,
flush_to_zero,
},
Ordering::Greater => {
if rounding.is_some() {
errors.push(PtxError::SyntaxError);
}
CvtMode::FPExtend { flush_to_zero }
}
},
(ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP {
rounding: unwrap_rounding(),
flush_to_zero,
},
(ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP {
rounding: unwrap_rounding(),
flush_to_zero,
},
(ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
(ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
(
ScalarKind::Unsigned | ScalarKind::Signed,
ScalarKind::Unsigned | ScalarKind::Signed,
) => match dst.size_of().cmp(&src.size_of()) {
Ordering::Less => {
if dst.kind() != src.kind() {
errors.push(PtxError::Todo);
}
CvtMode::Truncate
}
Ordering::Equal => CvtMode::Bitcast,
Ordering::Greater => {
if dst.kind() != src.kind() {
errors.push(PtxError::Todo);
}
if src.kind() == ScalarKind::Signed {
CvtMode::SignExtend
} else {
CvtMode::ZeroExtend
}
}
},
(_, _) => {
errors.push(PtxError::SyntaxError);
CvtMode::Bitcast
}
};
CvtDetails {
mode,
to: dst,
from: src,
}
}
}
pub struct CvtIntToIntDesc {
pub dst: ScalarType,
pub src: ScalarType,
pub saturate: bool,
}
pub struct CvtDesc {
pub rounding: Option<RoundingMode>,
pub flush_to_zero: Option<bool>,
pub saturate: bool,
pub dst: ScalarType,
pub src: ScalarType,
}

View file

@ -60,13 +60,13 @@ impl From<RawLdStQualifier> for ast::LdStQualifier {
}
}
impl From<RawFloatRounding> for ast::RoundingMode {
fn from(value: RawFloatRounding) -> Self {
impl From<RawRoundingMode> for ast::RoundingMode {
fn from(value: RawRoundingMode) -> Self {
match value {
RawFloatRounding::Rn => ast::RoundingMode::NearestEven,
RawFloatRounding::Rz => ast::RoundingMode::Zero,
RawFloatRounding::Rm => ast::RoundingMode::NegativeInf,
RawFloatRounding::Rp => ast::RoundingMode::PositiveInf,
RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven,
RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero,
RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf,
RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf,
}
}
}
@ -1380,7 +1380,7 @@ derive_parser!(
}
}
}
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add
@ -1444,7 +1444,7 @@ derive_parser!(
}
}
}
.rnd: RawFloatRounding = { .rn };
.rnd: RawRoundingMode = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
@ -1502,7 +1502,7 @@ derive_parser!(
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
@ -1558,7 +1558,7 @@ derive_parser!(
arguments: MulArgs { dst: d, src1: a, src2: b }
}
}
.rnd: RawFloatRounding = { .rn };
.rnd: RawRoundingMode = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
@ -1626,6 +1626,33 @@ derive_parser!(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call
call <= { call(stream) }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => {
let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype);
let arguments = ast::CvtArgs { dst: d, src: a };
ast::Instruction::Cvt {
data, arguments
}
}
// cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a;
// cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b;
// cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a;
// cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b;
// cvt.rna{.satfinite}.tf32.f32 d, a;
// cvt.frnd2{.relu}.tf32.f32 d, a;
// cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b;
// cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a;
// cvt.rn.{.relu}.f16x2.f8x2type d, a;
.ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi };
.frnd2: RawRoundingMode = { .rn, .rz };
.dtype: ScalarType = { .u8, .u16, .u32, .u64,
.s8, .s16, .s32, .s64,
.bf16, .f16, .f32, .f64 };
.atype: ScalarType = { .u8, .u16, .u32, .u64,
.s8, .s16, .s32, .s64,
.bf16, .f16, .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }