mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add cvt
This commit is contained in:
parent
c21c55dfc2
commit
bc1074ed67
3 changed files with 228 additions and 12 deletions
|
@ -847,7 +847,11 @@ mod tests {
|
|||
assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap()));
|
||||
assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap()));
|
||||
assert_eq!("LdDetails", to_string(variant.data.unwrap()));
|
||||
let arguments = variant.arguments.unwrap();
|
||||
let arguments = if let Some(Arguments::Def(a)) = variant.arguments {
|
||||
a
|
||||
} else {
|
||||
panic!()
|
||||
};
|
||||
assert_eq!("P", to_string(arguments.generic));
|
||||
let mut fields = arguments.fields.into_iter();
|
||||
let dst = fields.next().unwrap();
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix};
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use super::{
|
||||
MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace,
|
||||
VectorPrefix,
|
||||
};
|
||||
use crate::{PtxError, PtxParserState};
|
||||
use bitflags::bitflags;
|
||||
|
||||
|
@ -147,6 +152,19 @@ gen::generate_instruction_type!(
|
|||
visit_mut: arguments.visit_mut(data, visitor),
|
||||
map: Instruction::Call{ arguments: arguments.map(&data, visitor), data }
|
||||
},
|
||||
Cvt {
|
||||
data: CvtDetails,
|
||||
arguments<T>: {
|
||||
dst: {
|
||||
repr: T,
|
||||
type: { Type::Scalar(data.to) },
|
||||
},
|
||||
src: {
|
||||
repr: T,
|
||||
type: { Type::Scalar(data.from) },
|
||||
},
|
||||
}
|
||||
},
|
||||
Ret {
|
||||
data: RetData
|
||||
},
|
||||
|
@ -284,6 +302,28 @@ impl Type {
|
|||
}
|
||||
|
||||
impl ScalarType {
|
||||
pub fn size_of(self) -> u8 {
|
||||
match self {
|
||||
ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1,
|
||||
ScalarType::U16
|
||||
| ScalarType::S16
|
||||
| ScalarType::B16
|
||||
| ScalarType::F16
|
||||
| ScalarType::BF16 => 2,
|
||||
ScalarType::U32
|
||||
| ScalarType::S32
|
||||
| ScalarType::B32
|
||||
| ScalarType::F32
|
||||
| ScalarType::U16x2
|
||||
| ScalarType::S16x2
|
||||
| ScalarType::F16x2
|
||||
| ScalarType::BF16x2 => 4,
|
||||
ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8,
|
||||
ScalarType::B128 => 16,
|
||||
ScalarType::Pred => 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn kind(self) -> ScalarKind {
|
||||
match self {
|
||||
ScalarType::U8 => ScalarKind::Unsigned,
|
||||
|
@ -758,3 +798,148 @@ impl<T: Operand> CallArgs<T> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CvtDetails {
|
||||
from: ScalarType,
|
||||
to: ScalarType,
|
||||
mode: CvtMode,
|
||||
}
|
||||
|
||||
pub enum CvtMode {
|
||||
// int from int
|
||||
ZeroExtend,
|
||||
SignExtend,
|
||||
Truncate,
|
||||
Bitcast,
|
||||
// float from float
|
||||
FPExtend {
|
||||
flush_to_zero: Option<bool>,
|
||||
},
|
||||
FPTruncate {
|
||||
// float rounding
|
||||
rounding: RoundingMode,
|
||||
flush_to_zero: Option<bool>,
|
||||
},
|
||||
FPRound {
|
||||
integer_rounding: Option<RoundingMode>,
|
||||
flush_to_zero: Option<bool>,
|
||||
},
|
||||
// int from float
|
||||
SignedFromFP {
|
||||
rounding: RoundingMode,
|
||||
flush_to_zero: Option<bool>,
|
||||
}, // integer rounding
|
||||
UnsignedFromFP {
|
||||
rounding: RoundingMode,
|
||||
flush_to_zero: Option<bool>,
|
||||
}, // integer rounding
|
||||
// float from int, ftz is allowed in the grammar, but clearly nonsensical
|
||||
FPFromSigned(RoundingMode), // float rounding
|
||||
FPFromUnsigned(RoundingMode), // float rounding
|
||||
}
|
||||
|
||||
impl CvtDetails {
|
||||
pub(crate) fn new(
|
||||
errors: &mut Vec<PtxError>,
|
||||
rnd: Option<RawRoundingMode>,
|
||||
ftz: bool,
|
||||
saturate: bool,
|
||||
dst: ScalarType,
|
||||
src: ScalarType,
|
||||
) -> Self {
|
||||
if saturate {
|
||||
errors.push(PtxError::Todo);
|
||||
}
|
||||
// Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results.
|
||||
let flush_to_zero = match (dst, src) {
|
||||
(ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz),
|
||||
_ => {
|
||||
if ftz {
|
||||
errors.push(PtxError::NonF32Ftz);
|
||||
}
|
||||
None
|
||||
}
|
||||
};
|
||||
let rounding = rnd.map(Into::into);
|
||||
let mut unwrap_rounding = || match rounding {
|
||||
Some(rnd) => rnd,
|
||||
None => {
|
||||
errors.push(PtxError::SyntaxError);
|
||||
RoundingMode::NearestEven
|
||||
}
|
||||
};
|
||||
let mode = match (dst.kind(), src.kind()) {
|
||||
(ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) {
|
||||
Ordering::Less => CvtMode::FPTruncate {
|
||||
rounding: unwrap_rounding(),
|
||||
flush_to_zero,
|
||||
},
|
||||
Ordering::Equal => CvtMode::FPRound {
|
||||
integer_rounding: rounding,
|
||||
flush_to_zero,
|
||||
},
|
||||
Ordering::Greater => {
|
||||
if rounding.is_some() {
|
||||
errors.push(PtxError::SyntaxError);
|
||||
}
|
||||
CvtMode::FPExtend { flush_to_zero }
|
||||
}
|
||||
},
|
||||
(ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP {
|
||||
rounding: unwrap_rounding(),
|
||||
flush_to_zero,
|
||||
},
|
||||
(ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP {
|
||||
rounding: unwrap_rounding(),
|
||||
flush_to_zero,
|
||||
},
|
||||
(ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
|
||||
(ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
|
||||
(
|
||||
ScalarKind::Unsigned | ScalarKind::Signed,
|
||||
ScalarKind::Unsigned | ScalarKind::Signed,
|
||||
) => match dst.size_of().cmp(&src.size_of()) {
|
||||
Ordering::Less => {
|
||||
if dst.kind() != src.kind() {
|
||||
errors.push(PtxError::Todo);
|
||||
}
|
||||
CvtMode::Truncate
|
||||
}
|
||||
Ordering::Equal => CvtMode::Bitcast,
|
||||
Ordering::Greater => {
|
||||
if dst.kind() != src.kind() {
|
||||
errors.push(PtxError::Todo);
|
||||
}
|
||||
if src.kind() == ScalarKind::Signed {
|
||||
CvtMode::SignExtend
|
||||
} else {
|
||||
CvtMode::ZeroExtend
|
||||
}
|
||||
}
|
||||
},
|
||||
(_, _) => {
|
||||
errors.push(PtxError::SyntaxError);
|
||||
CvtMode::Bitcast
|
||||
}
|
||||
};
|
||||
CvtDetails {
|
||||
mode,
|
||||
to: dst,
|
||||
from: src,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CvtIntToIntDesc {
|
||||
pub dst: ScalarType,
|
||||
pub src: ScalarType,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
pub struct CvtDesc {
|
||||
pub rounding: Option<RoundingMode>,
|
||||
pub flush_to_zero: Option<bool>,
|
||||
pub saturate: bool,
|
||||
pub dst: ScalarType,
|
||||
pub src: ScalarType,
|
||||
}
|
||||
|
|
|
@ -60,13 +60,13 @@ impl From<RawLdStQualifier> for ast::LdStQualifier {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<RawFloatRounding> for ast::RoundingMode {
|
||||
fn from(value: RawFloatRounding) -> Self {
|
||||
impl From<RawRoundingMode> for ast::RoundingMode {
|
||||
fn from(value: RawRoundingMode) -> Self {
|
||||
match value {
|
||||
RawFloatRounding::Rn => ast::RoundingMode::NearestEven,
|
||||
RawFloatRounding::Rz => ast::RoundingMode::Zero,
|
||||
RawFloatRounding::Rm => ast::RoundingMode::NegativeInf,
|
||||
RawFloatRounding::Rp => ast::RoundingMode::PositiveInf,
|
||||
RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven,
|
||||
RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero,
|
||||
RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf,
|
||||
RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1380,7 +1380,7 @@ derive_parser!(
|
|||
}
|
||||
}
|
||||
}
|
||||
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
|
||||
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
|
||||
ScalarType = { .f32, .f64 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add
|
||||
|
@ -1444,7 +1444,7 @@ derive_parser!(
|
|||
}
|
||||
}
|
||||
}
|
||||
.rnd: RawFloatRounding = { .rn };
|
||||
.rnd: RawRoundingMode = { .rn };
|
||||
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
|
||||
|
@ -1502,7 +1502,7 @@ derive_parser!(
|
|||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
|
||||
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
|
||||
ScalarType = { .f32, .f64 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
|
||||
|
@ -1558,7 +1558,7 @@ derive_parser!(
|
|||
arguments: MulArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
.rnd: RawFloatRounding = { .rn };
|
||||
.rnd: RawRoundingMode = { .rn };
|
||||
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
|
||||
|
@ -1626,6 +1626,33 @@ derive_parser!(
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call
|
||||
call <= { call(stream) }
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
|
||||
cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => {
|
||||
let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype);
|
||||
let arguments = ast::CvtArgs { dst: d, src: a };
|
||||
ast::Instruction::Cvt {
|
||||
data, arguments
|
||||
}
|
||||
}
|
||||
// cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a;
|
||||
// cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b;
|
||||
// cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a;
|
||||
// cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b;
|
||||
// cvt.rna{.satfinite}.tf32.f32 d, a;
|
||||
// cvt.frnd2{.relu}.tf32.f32 d, a;
|
||||
// cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b;
|
||||
// cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a;
|
||||
// cvt.rn.{.relu}.f16x2.f8x2type d, a;
|
||||
|
||||
.ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi };
|
||||
.frnd2: RawRoundingMode = { .rn, .rz };
|
||||
.dtype: ScalarType = { .u8, .u16, .u32, .u64,
|
||||
.s8, .s16, .s32, .s64,
|
||||
.bf16, .f16, .f32, .f64 };
|
||||
.atype: ScalarType = { .u8, .u16, .u32, .u64,
|
||||
.s8, .s16, .s32, .s64,
|
||||
.bf16, .f16, .f32, .f64 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
|
||||
ret{.uni} => {
|
||||
Instruction::Ret { data: RetData { uniform: uni } }
|
||||
|
|
Loading…
Add table
Reference in a new issue