diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 08911ec..4532964 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -847,7 +847,11 @@ mod tests { assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap())); assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap())); assert_eq!("LdDetails", to_string(variant.data.unwrap())); - let arguments = variant.arguments.unwrap(); + let arguments = if let Some(Arguments::Def(a)) = variant.arguments { + a + } else { + panic!() + }; assert_eq!("P", to_string(arguments.generic)); let mut fields = arguments.fields.into_iter(); let dst = fields.next().unwrap(); diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 0dabd5d..daee9da 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,4 +1,9 @@ -use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix}; +use std::cmp::Ordering; + +use super::{ + MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, + VectorPrefix, +}; use crate::{PtxError, PtxParserState}; use bitflags::bitflags; @@ -147,6 +152,19 @@ gen::generate_instruction_type!( visit_mut: arguments.visit_mut(data, visitor), map: Instruction::Call{ arguments: arguments.map(&data, visitor), data } }, + Cvt { + data: CvtDetails, + arguments: { + dst: { + repr: T, + type: { Type::Scalar(data.to) }, + }, + src: { + repr: T, + type: { Type::Scalar(data.from) }, + }, + } + }, Ret { data: RetData }, @@ -284,6 +302,28 @@ impl Type { } impl ScalarType { + pub fn size_of(self) -> u8 { + match self { + ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1, + ScalarType::U16 + | ScalarType::S16 + | ScalarType::B16 + | ScalarType::F16 + | ScalarType::BF16 => 2, + ScalarType::U32 + | ScalarType::S32 + | ScalarType::B32 + | ScalarType::F32 + | ScalarType::U16x2 + | ScalarType::S16x2 + | ScalarType::F16x2 + | ScalarType::BF16x2 => 4, + ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8, + 
ScalarType::B128 => 16,
+            ScalarType::Pred => 1,
+        }
+    }
+
     pub fn kind(self) -> ScalarKind {
         match self {
             ScalarType::U8 => ScalarKind::Unsigned,
@@ -758,3 +798,148 @@ impl CallArgs {
         }
     }
 }
+
+pub struct CvtDetails {
+    from: ScalarType,
+    to: ScalarType,
+    mode: CvtMode,
+}
+
+pub enum CvtMode {
+    // int from int
+    ZeroExtend,
+    SignExtend,
+    Truncate,
+    Bitcast,
+    // float from float
+    FPExtend {
+        flush_to_zero: Option<bool>,
+    },
+    FPTruncate {
+        // float rounding
+        rounding: RoundingMode,
+        flush_to_zero: Option<bool>,
+    },
+    FPRound {
+        integer_rounding: Option<RoundingMode>,
+        flush_to_zero: Option<bool>,
+    },
+    // int from float
+    SignedFromFP {
+        rounding: RoundingMode,
+        flush_to_zero: Option<bool>,
+    }, // integer rounding
+    UnsignedFromFP {
+        rounding: RoundingMode,
+        flush_to_zero: Option<bool>,
+    }, // integer rounding
+    // float from int, ftz is allowed in the grammar, but clearly nonsensical
+    FPFromSigned(RoundingMode), // float rounding
+    FPFromUnsigned(RoundingMode), // float rounding
+}
+
+impl CvtDetails {
+    pub(crate) fn new(
+        errors: &mut Vec<PtxError>,
+        rnd: Option<RawRoundingMode>,
+        ftz: bool,
+        saturate: bool,
+        dst: ScalarType,
+        src: ScalarType,
+    ) -> Self {
+        if saturate {
+            errors.push(PtxError::Todo);
+        }
+        // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results.
+ let flush_to_zero = match (dst, src) { + (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz), + _ => { + if ftz { + errors.push(PtxError::NonF32Ftz); + } + None + } + }; + let rounding = rnd.map(Into::into); + let mut unwrap_rounding = || match rounding { + Some(rnd) => rnd, + None => { + errors.push(PtxError::SyntaxError); + RoundingMode::NearestEven + } + }; + let mode = match (dst.kind(), src.kind()) { + (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { + Ordering::Less => CvtMode::FPTruncate { + rounding: unwrap_rounding(), + flush_to_zero, + }, + Ordering::Equal => CvtMode::FPRound { + integer_rounding: rounding, + flush_to_zero, + }, + Ordering::Greater => { + if rounding.is_some() { + errors.push(PtxError::SyntaxError); + } + CvtMode::FPExtend { flush_to_zero } + } + }, + (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP { + rounding: unwrap_rounding(), + flush_to_zero, + }, + (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP { + rounding: unwrap_rounding(), + flush_to_zero, + }, + (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()), + (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()), + ( + ScalarKind::Unsigned | ScalarKind::Signed, + ScalarKind::Unsigned | ScalarKind::Signed, + ) => match dst.size_of().cmp(&src.size_of()) { + Ordering::Less => { + if dst.kind() != src.kind() { + errors.push(PtxError::Todo); + } + CvtMode::Truncate + } + Ordering::Equal => CvtMode::Bitcast, + Ordering::Greater => { + if dst.kind() != src.kind() { + errors.push(PtxError::Todo); + } + if src.kind() == ScalarKind::Signed { + CvtMode::SignExtend + } else { + CvtMode::ZeroExtend + } + } + }, + (_, _) => { + errors.push(PtxError::SyntaxError); + CvtMode::Bitcast + } + }; + CvtDetails { + mode, + to: dst, + from: src, + } + } +} + +pub struct CvtIntToIntDesc { + pub dst: ScalarType, + pub src: ScalarType, + pub saturate: bool, +} + +pub 
struct CvtDesc {
+    pub rounding: Option<RoundingMode>,
+    pub flush_to_zero: Option<bool>,
+    pub saturate: bool,
+    pub dst: ScalarType,
+    pub src: ScalarType,
+}
diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs
index 2c602d5..68787db 100644
--- a/ptx_parser/src/main.rs
+++ b/ptx_parser/src/main.rs
@@ -60,13 +60,13 @@ impl From<RawLdStQualifier> for ast::LdStQualifier
     }
 }
 
-impl From<RawFloatRounding> for ast::RoundingMode {
-    fn from(value: RawFloatRounding) -> Self {
+impl From<RawRoundingMode> for ast::RoundingMode {
+    fn from(value: RawRoundingMode) -> Self {
         match value {
-            RawFloatRounding::Rn => ast::RoundingMode::NearestEven,
-            RawFloatRounding::Rz => ast::RoundingMode::Zero,
-            RawFloatRounding::Rm => ast::RoundingMode::NegativeInf,
-            RawFloatRounding::Rp => ast::RoundingMode::PositiveInf,
+            RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven,
+            RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero,
+            RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf,
+            RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf,
         }
     }
 }
@@ -1380,7 +1380,7 @@ derive_parser!(
             }
         }
     }
-    .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
+    .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
     ScalarType = { .f32, .f64 };
 
 // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add
@@ -1444,7 +1444,7 @@ derive_parser!(
             }
         }
     }
-    .rnd: RawFloatRounding = { .rn };
+    .rnd: RawRoundingMode = { .rn };
     ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
 
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
@@ -1502,7 +1502,7 @@ derive_parser!(
             arguments: MulArgs { dst: d, src1: a, src2: b }
         }
     }
-    .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
+    .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
     ScalarType = { .f32, .f64 };
 
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
@@ -1558,7 +1558,7 @@ derive_parser!(
             arguments: 
MulArgs { dst: d, src1: a, src2: b } } } - .rnd: RawFloatRounding = { .rn }; + .rnd: RawRoundingMode = { .rn }; ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp @@ -1626,6 +1626,33 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call call <= { call(stream) } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt + cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => { + let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype); + let arguments = ast::CvtArgs { dst: d, src: a }; + ast::Instruction::Cvt { + data, arguments + } + } + // cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b; + // cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; + // cvt.rna{.satfinite}.tf32.f32 d, a; + // cvt.frnd2{.relu}.tf32.f32 d, a; + // cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b; + // cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a; + // cvt.rn.{.relu}.f16x2.f8x2type d, a; + + .ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi }; + .frnd2: RawRoundingMode = { .rn, .rz }; + .dtype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + .atype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } }