Implement setp

This commit is contained in:
Andrzej Janik 2024-08-19 02:23:26 +02:00
commit c08e6a6772
5 changed files with 388 additions and 55 deletions

View file

@ -1,6 +1,5 @@
use std::intrinsics::unreachable;
use super::{MemScope, ScalarType, StateSpace, VectorPrefix};
use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix};
use crate::{PtxError, PtxParserState};
use bitflags::bitflags;
pub enum Statement<P: Operand> {
@ -11,7 +10,7 @@ pub enum Statement<P: Operand> {
}
gen::generate_instruction_type!(
pub enum Instruction<T> {
pub enum Instruction<T: Operand> {
Mov {
type: { &data.typ },
data: MovDetails,
@ -63,6 +62,52 @@ gen::generate_instruction_type!(
src2: T,
}
},
Setp {
data: SetpData,
arguments<T>: {
dst1: {
repr: T,
type: ScalarType::Pred.into()
},
dst2: {
repr: Option<T>,
type: ScalarType::Pred.into()
},
src1: {
repr: T,
type: data.type_.into(),
},
src2: {
repr: T,
type: data.type_.into(),
}
}
},
SetpBool {
data: SetpBoolData,
arguments<T>: {
dst1: {
repr: T,
type: ScalarType::Pred.into()
},
dst2: {
repr: Option<T>,
type: ScalarType::Pred.into()
},
src1: {
repr: T,
type: data.base.type_.into(),
},
src2: {
repr: T,
type: data.base.type_.into(),
},
src3: {
repr: T,
type: ScalarType::Pred.into()
}
}
},
Ret {
data: RetData
},
@ -70,6 +115,66 @@ gen::generate_instruction_type!(
}
);
pub trait Visitor<T> {
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMut<T> {
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMap<From, To> {
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
}
trait VisitOperand {
type Operand;
fn visit(&self, fn_: impl FnOnce(&Self::Operand));
fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand));
}
impl<T: Operand> VisitOperand for T {
type Operand = Self;
fn visit(&self, fn_: impl FnOnce(&Self::Operand)) {
fn_(self)
}
fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) {
fn_(self)
}
}
impl<T: Operand> VisitOperand for Option<T> {
type Operand = T;
fn visit(&self, fn_: impl FnOnce(&Self::Operand)) {
self.as_ref().map(fn_);
}
fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) {
self.as_mut().map(fn_);
}
}
trait MapOperand: Sized {
type Input;
type Output<U>;
fn map<U>(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output<U>;
}
impl<T: Operand> MapOperand for T {
type Input = Self;
type Output<U> = U;
fn map<U>(self, fn_: impl FnOnce(T) -> U) -> U {
fn_(self)
}
}
impl<T: Operand> MapOperand for Option<T> {
type Input = T;
type Output<U> = Option<U>;
fn map<U>(self, fn_: impl FnOnce(T) -> U) -> Option<U> {
self.map(|x| fn_(x))
}
}
pub struct MultiVariable<ID> {
pub var: Variable<ID>,
pub count: Option<u32>,
@ -89,18 +194,6 @@ pub struct PredAt<ID> {
pub label: ID,
}
pub trait Visitor<T> {
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMut<T> {
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMap<From, To> {
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
}
#[derive(PartialEq, Eq, Clone, Hash)]
pub enum Type {
// .param.b32 foo;
@ -121,6 +214,43 @@ impl Type {
}
}
impl ScalarType {
pub fn kind(self) -> ScalarKind {
match self {
ScalarType::U8 => ScalarKind::Unsigned,
ScalarType::U16 => ScalarKind::Unsigned,
ScalarType::U16x2 => ScalarKind::Unsigned,
ScalarType::U32 => ScalarKind::Unsigned,
ScalarType::U64 => ScalarKind::Unsigned,
ScalarType::S8 => ScalarKind::Signed,
ScalarType::S16 => ScalarKind::Signed,
ScalarType::S16x2 => ScalarKind::Signed,
ScalarType::S32 => ScalarKind::Signed,
ScalarType::S64 => ScalarKind::Signed,
ScalarType::B8 => ScalarKind::Bit,
ScalarType::B16 => ScalarKind::Bit,
ScalarType::B32 => ScalarKind::Bit,
ScalarType::B64 => ScalarKind::Bit,
ScalarType::B128 => ScalarKind::Bit,
ScalarType::F16 => ScalarKind::Float,
ScalarType::F16x2 => ScalarKind::Float,
ScalarType::F32 => ScalarKind::Float,
ScalarType::F64 => ScalarKind::Float,
ScalarType::BF16 => ScalarKind::Float,
ScalarType::BF16x2 => ScalarKind::Float,
ScalarType::Pred => ScalarKind::Pred,
}
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum ScalarKind {
Bit,
Unsigned,
Signed,
Float,
Pred,
}
impl From<ScalarType> for Type {
fn from(value: ScalarType) -> Self {
Type::Scalar(value)
@ -347,3 +477,135 @@ pub enum MulIntControl {
High,
Wide,
}
pub struct SetpData {
pub type_: ScalarType,
pub flush_to_zero: Option<bool>,
pub cmp_op: SetpCompareOp,
}
impl SetpData {
pub(crate) fn try_parse(
errors: &mut PtxParserState,
cmp_op: super::RawSetpCompareOp,
ftz: bool,
type_: ScalarType,
) -> Self {
let flush_to_zero = match (ftz, type_) {
(_, ScalarType::F32) => Some(ftz),
_ => {
errors.push(PtxError::NonF32Ftz);
None
}
};
let type_kind = type_.kind();
let cmp_op = if type_kind == ScalarKind::Float {
SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
} else {
match SetpCompareInt::try_from(cmp_op) {
Ok(op) => SetpCompareOp::Integer(op),
Err(err) => {
errors.push(err);
SetpCompareOp::Integer(SetpCompareInt::Eq)
}
}
};
Self {
type_,
flush_to_zero,
cmp_op,
}
}
}
pub struct SetpBoolData {
pub base: SetpData,
pub bool_op: SetpBoolPostOp,
pub negate_src3: bool
}
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum SetpCompareOp {
Integer(SetpCompareInt),
Float(SetpCompareFloat),
}
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum SetpCompareInt {
Eq,
NotEq,
Less,
LessOrEq,
Greater,
GreaterOrEq,
}
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum SetpCompareFloat {
Eq,
NotEq,
Less,
LessOrEq,
Greater,
GreaterOrEq,
NanEq,
NanNotEq,
NanLess,
NanLessOrEq,
NanGreater,
NanGreaterOrEq,
IsNotNan,
IsAnyNan,
}
impl TryFrom<RawSetpCompareOp> for SetpCompareInt {
type Error = PtxError;
fn try_from(value: RawSetpCompareOp) -> Result<Self, PtxError> {
match value {
RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq),
RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq),
RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less),
RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq),
RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater),
RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq),
RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less),
RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq),
RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater),
RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq),
RawSetpCompareOp::Equ => Err(PtxError::WrongType),
RawSetpCompareOp::Neu => Err(PtxError::WrongType),
RawSetpCompareOp::Ltu => Err(PtxError::WrongType),
RawSetpCompareOp::Leu => Err(PtxError::WrongType),
RawSetpCompareOp::Gtu => Err(PtxError::WrongType),
RawSetpCompareOp::Geu => Err(PtxError::WrongType),
RawSetpCompareOp::Num => Err(PtxError::WrongType),
RawSetpCompareOp::Nan => Err(PtxError::WrongType),
}
}
}
impl From<RawSetpCompareOp> for SetpCompareFloat {
fn from(value: RawSetpCompareOp) -> Self {
match value {
RawSetpCompareOp::Eq => SetpCompareFloat::Eq,
RawSetpCompareOp::Ne => SetpCompareFloat::NotEq,
RawSetpCompareOp::Lt => SetpCompareFloat::Less,
RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq,
RawSetpCompareOp::Gt => SetpCompareFloat::Greater,
RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq,
RawSetpCompareOp::Lo => SetpCompareFloat::Less,
RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq,
RawSetpCompareOp::Hi => SetpCompareFloat::Greater,
RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq,
RawSetpCompareOp::Equ => SetpCompareFloat::NanEq,
RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq,
RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess,
RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq,
RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater,
RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq,
RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan,
RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan,
}
}
}

View file

@ -769,6 +769,8 @@ pub enum PtxError {
#[error("")]
NonF32Ftz,
#[error("")]
WrongType,
#[error("")]
WrongArrayType,
#[error("")]
WrongVectorElement,
@ -996,6 +998,9 @@ derive_parser!(
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ScalarType { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum SetpBoolPostOp { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => {
Instruction::Mov {
@ -1424,6 +1429,38 @@ derive_parser!(
.rnd: RawFloatRounding = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
setp.CmpOp{.ftz}.type p[|q], a, b => {
let data = ast::SetpData::try_parse(state, cmpop, ftz, type_);
ast::Instruction::Setp {
data,
arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b }
}
}
setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => {
let (negate_src3, c) = c;
let base = ast::SetpData::try_parse(state, cmpop, ftz, type_);
let data = ast::SetpBoolData {
base,
bool_op: boolop,
negate_src3
};
ast::Instruction::SetpBool {
data,
arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c }
}
}
.CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge,
.lo, .ls, .hi, .hs, // signed
.equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only
.BoolOp: SetpBoolPostOp = { .and, .or, .xor };
.type: ScalarType = { .b16, .b32, .b64,
.u16, .u32, .u64,
.s16, .s32, .s64,
.f32, .f64,
.f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }
@ -1432,8 +1469,6 @@ derive_parser!(
);
fn main() {
use winnow::combinator::*;
use winnow::token::*;
use winnow::Parser;
let lexer = Token::lexer(