mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-26 04:08:12 +00:00
Implement setp
This commit is contained in:
parent
cb64b04f41
commit
c08e6a6772
5 changed files with 388 additions and 55 deletions
|
@ -1,6 +1,5 @@
|
|||
use std::intrinsics::unreachable;
|
||||
|
||||
use super::{MemScope, ScalarType, StateSpace, VectorPrefix};
|
||||
use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix};
|
||||
use crate::{PtxError, PtxParserState};
|
||||
use bitflags::bitflags;
|
||||
|
||||
pub enum Statement<P: Operand> {
|
||||
|
@ -11,7 +10,7 @@ pub enum Statement<P: Operand> {
|
|||
}
|
||||
|
||||
gen::generate_instruction_type!(
|
||||
pub enum Instruction<T> {
|
||||
pub enum Instruction<T: Operand> {
|
||||
Mov {
|
||||
type: { &data.typ },
|
||||
data: MovDetails,
|
||||
|
@ -63,6 +62,52 @@ gen::generate_instruction_type!(
|
|||
src2: T,
|
||||
}
|
||||
},
|
||||
Setp {
|
||||
data: SetpData,
|
||||
arguments<T>: {
|
||||
dst1: {
|
||||
repr: T,
|
||||
type: ScalarType::Pred.into()
|
||||
},
|
||||
dst2: {
|
||||
repr: Option<T>,
|
||||
type: ScalarType::Pred.into()
|
||||
},
|
||||
src1: {
|
||||
repr: T,
|
||||
type: data.type_.into(),
|
||||
},
|
||||
src2: {
|
||||
repr: T,
|
||||
type: data.type_.into(),
|
||||
}
|
||||
}
|
||||
},
|
||||
SetpBool {
|
||||
data: SetpBoolData,
|
||||
arguments<T>: {
|
||||
dst1: {
|
||||
repr: T,
|
||||
type: ScalarType::Pred.into()
|
||||
},
|
||||
dst2: {
|
||||
repr: Option<T>,
|
||||
type: ScalarType::Pred.into()
|
||||
},
|
||||
src1: {
|
||||
repr: T,
|
||||
type: data.base.type_.into(),
|
||||
},
|
||||
src2: {
|
||||
repr: T,
|
||||
type: data.base.type_.into(),
|
||||
},
|
||||
src3: {
|
||||
repr: T,
|
||||
type: ScalarType::Pred.into()
|
||||
}
|
||||
}
|
||||
},
|
||||
Ret {
|
||||
data: RetData
|
||||
},
|
||||
|
@ -70,6 +115,66 @@ gen::generate_instruction_type!(
|
|||
}
|
||||
);
|
||||
|
||||
pub trait Visitor<T> {
|
||||
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
|
||||
}
|
||||
|
||||
pub trait VisitorMut<T> {
|
||||
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
|
||||
}
|
||||
|
||||
pub trait VisitorMap<From, To> {
|
||||
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
|
||||
}
|
||||
|
||||
trait VisitOperand {
|
||||
type Operand;
|
||||
fn visit(&self, fn_: impl FnOnce(&Self::Operand));
|
||||
fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand));
|
||||
}
|
||||
|
||||
impl<T: Operand> VisitOperand for T {
|
||||
type Operand = Self;
|
||||
fn visit(&self, fn_: impl FnOnce(&Self::Operand)) {
|
||||
fn_(self)
|
||||
}
|
||||
fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) {
|
||||
fn_(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Operand> VisitOperand for Option<T> {
|
||||
type Operand = T;
|
||||
fn visit(&self, fn_: impl FnOnce(&Self::Operand)) {
|
||||
self.as_ref().map(fn_);
|
||||
}
|
||||
fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) {
|
||||
self.as_mut().map(fn_);
|
||||
}
|
||||
}
|
||||
|
||||
trait MapOperand: Sized {
|
||||
type Input;
|
||||
type Output<U>;
|
||||
fn map<U>(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output<U>;
|
||||
}
|
||||
|
||||
impl<T: Operand> MapOperand for T {
|
||||
type Input = Self;
|
||||
type Output<U> = U;
|
||||
fn map<U>(self, fn_: impl FnOnce(T) -> U) -> U {
|
||||
fn_(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Operand> MapOperand for Option<T> {
|
||||
type Input = T;
|
||||
type Output<U> = Option<U>;
|
||||
fn map<U>(self, fn_: impl FnOnce(T) -> U) -> Option<U> {
|
||||
self.map(|x| fn_(x))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MultiVariable<ID> {
|
||||
pub var: Variable<ID>,
|
||||
pub count: Option<u32>,
|
||||
|
@ -89,18 +194,6 @@ pub struct PredAt<ID> {
|
|||
pub label: ID,
|
||||
}
|
||||
|
||||
pub trait Visitor<T> {
|
||||
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
|
||||
}
|
||||
|
||||
pub trait VisitorMut<T> {
|
||||
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
|
||||
}
|
||||
|
||||
pub trait VisitorMap<From, To> {
|
||||
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Hash)]
|
||||
pub enum Type {
|
||||
// .param.b32 foo;
|
||||
|
@ -121,6 +214,43 @@ impl Type {
|
|||
}
|
||||
}
|
||||
|
||||
impl ScalarType {
|
||||
pub fn kind(self) -> ScalarKind {
|
||||
match self {
|
||||
ScalarType::U8 => ScalarKind::Unsigned,
|
||||
ScalarType::U16 => ScalarKind::Unsigned,
|
||||
ScalarType::U16x2 => ScalarKind::Unsigned,
|
||||
ScalarType::U32 => ScalarKind::Unsigned,
|
||||
ScalarType::U64 => ScalarKind::Unsigned,
|
||||
ScalarType::S8 => ScalarKind::Signed,
|
||||
ScalarType::S16 => ScalarKind::Signed,
|
||||
ScalarType::S16x2 => ScalarKind::Signed,
|
||||
ScalarType::S32 => ScalarKind::Signed,
|
||||
ScalarType::S64 => ScalarKind::Signed,
|
||||
ScalarType::B8 => ScalarKind::Bit,
|
||||
ScalarType::B16 => ScalarKind::Bit,
|
||||
ScalarType::B32 => ScalarKind::Bit,
|
||||
ScalarType::B64 => ScalarKind::Bit,
|
||||
ScalarType::B128 => ScalarKind::Bit,
|
||||
ScalarType::F16 => ScalarKind::Float,
|
||||
ScalarType::F16x2 => ScalarKind::Float,
|
||||
ScalarType::F32 => ScalarKind::Float,
|
||||
ScalarType::F64 => ScalarKind::Float,
|
||||
ScalarType::BF16 => ScalarKind::Float,
|
||||
ScalarType::BF16x2 => ScalarKind::Float,
|
||||
ScalarType::Pred => ScalarKind::Pred,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ScalarKind {
|
||||
Bit,
|
||||
Unsigned,
|
||||
Signed,
|
||||
Float,
|
||||
Pred,
|
||||
}
|
||||
impl From<ScalarType> for Type {
|
||||
fn from(value: ScalarType) -> Self {
|
||||
Type::Scalar(value)
|
||||
|
@ -347,3 +477,135 @@ pub enum MulIntControl {
|
|||
High,
|
||||
Wide,
|
||||
}
|
||||
|
||||
pub struct SetpData {
|
||||
pub type_: ScalarType,
|
||||
pub flush_to_zero: Option<bool>,
|
||||
pub cmp_op: SetpCompareOp,
|
||||
}
|
||||
|
||||
impl SetpData {
|
||||
pub(crate) fn try_parse(
|
||||
errors: &mut PtxParserState,
|
||||
cmp_op: super::RawSetpCompareOp,
|
||||
ftz: bool,
|
||||
type_: ScalarType,
|
||||
) -> Self {
|
||||
let flush_to_zero = match (ftz, type_) {
|
||||
(_, ScalarType::F32) => Some(ftz),
|
||||
_ => {
|
||||
errors.push(PtxError::NonF32Ftz);
|
||||
None
|
||||
}
|
||||
};
|
||||
let type_kind = type_.kind();
|
||||
let cmp_op = if type_kind == ScalarKind::Float {
|
||||
SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
|
||||
} else {
|
||||
match SetpCompareInt::try_from(cmp_op) {
|
||||
Ok(op) => SetpCompareOp::Integer(op),
|
||||
Err(err) => {
|
||||
errors.push(err);
|
||||
SetpCompareOp::Integer(SetpCompareInt::Eq)
|
||||
}
|
||||
}
|
||||
};
|
||||
Self {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
cmp_op,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SetpBoolData {
|
||||
pub base: SetpData,
|
||||
pub bool_op: SetpBoolPostOp,
|
||||
pub negate_src3: bool
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
pub enum SetpCompareOp {
|
||||
Integer(SetpCompareInt),
|
||||
Float(SetpCompareFloat),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
pub enum SetpCompareInt {
|
||||
Eq,
|
||||
NotEq,
|
||||
Less,
|
||||
LessOrEq,
|
||||
Greater,
|
||||
GreaterOrEq,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
pub enum SetpCompareFloat {
|
||||
Eq,
|
||||
NotEq,
|
||||
Less,
|
||||
LessOrEq,
|
||||
Greater,
|
||||
GreaterOrEq,
|
||||
NanEq,
|
||||
NanNotEq,
|
||||
NanLess,
|
||||
NanLessOrEq,
|
||||
NanGreater,
|
||||
NanGreaterOrEq,
|
||||
IsNotNan,
|
||||
IsAnyNan,
|
||||
}
|
||||
|
||||
impl TryFrom<RawSetpCompareOp> for SetpCompareInt {
|
||||
type Error = PtxError;
|
||||
|
||||
fn try_from(value: RawSetpCompareOp) -> Result<Self, PtxError> {
|
||||
match value {
|
||||
RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq),
|
||||
RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq),
|
||||
RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less),
|
||||
RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq),
|
||||
RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater),
|
||||
RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq),
|
||||
RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less),
|
||||
RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq),
|
||||
RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater),
|
||||
RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq),
|
||||
RawSetpCompareOp::Equ => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Neu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Ltu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Leu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Gtu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Geu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Num => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Nan => Err(PtxError::WrongType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RawSetpCompareOp> for SetpCompareFloat {
|
||||
fn from(value: RawSetpCompareOp) -> Self {
|
||||
match value {
|
||||
RawSetpCompareOp::Eq => SetpCompareFloat::Eq,
|
||||
RawSetpCompareOp::Ne => SetpCompareFloat::NotEq,
|
||||
RawSetpCompareOp::Lt => SetpCompareFloat::Less,
|
||||
RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq,
|
||||
RawSetpCompareOp::Gt => SetpCompareFloat::Greater,
|
||||
RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq,
|
||||
RawSetpCompareOp::Lo => SetpCompareFloat::Less,
|
||||
RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq,
|
||||
RawSetpCompareOp::Hi => SetpCompareFloat::Greater,
|
||||
RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq,
|
||||
RawSetpCompareOp::Equ => SetpCompareFloat::NanEq,
|
||||
RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq,
|
||||
RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess,
|
||||
RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq,
|
||||
RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater,
|
||||
RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq,
|
||||
RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan,
|
||||
RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -769,6 +769,8 @@ pub enum PtxError {
|
|||
#[error("")]
|
||||
NonF32Ftz,
|
||||
#[error("")]
|
||||
WrongType,
|
||||
#[error("")]
|
||||
WrongArrayType,
|
||||
#[error("")]
|
||||
WrongVectorElement,
|
||||
|
@ -996,6 +998,9 @@ derive_parser!(
|
|||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum ScalarType { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum SetpBoolPostOp { }
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
||||
mov{.vec}.type d, a => {
|
||||
Instruction::Mov {
|
||||
|
@ -1424,6 +1429,38 @@ derive_parser!(
|
|||
.rnd: RawFloatRounding = { .rn };
|
||||
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
|
||||
setp.CmpOp{.ftz}.type p[|q], a, b => {
|
||||
let data = ast::SetpData::try_parse(state, cmpop, ftz, type_);
|
||||
ast::Instruction::Setp {
|
||||
data,
|
||||
arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => {
|
||||
let (negate_src3, c) = c;
|
||||
let base = ast::SetpData::try_parse(state, cmpop, ftz, type_);
|
||||
let data = ast::SetpBoolData {
|
||||
base,
|
||||
bool_op: boolop,
|
||||
negate_src3
|
||||
};
|
||||
ast::Instruction::SetpBool {
|
||||
data,
|
||||
arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c }
|
||||
}
|
||||
}
|
||||
.CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge,
|
||||
.lo, .ls, .hi, .hs, // signed
|
||||
.equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only
|
||||
.BoolOp: SetpBoolPostOp = { .and, .or, .xor };
|
||||
.type: ScalarType = { .b16, .b32, .b64,
|
||||
.u16, .u32, .u64,
|
||||
.s16, .s32, .s64,
|
||||
.f32, .f64,
|
||||
.f16, .f16x2, .bf16, .bf16x2 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
|
||||
ret{.uni} => {
|
||||
Instruction::Ret { data: RetData { uniform: uni } }
|
||||
|
@ -1432,8 +1469,6 @@ derive_parser!(
|
|||
);
|
||||
|
||||
fn main() {
|
||||
use winnow::combinator::*;
|
||||
use winnow::token::*;
|
||||
use winnow::Parser;
|
||||
|
||||
let lexer = Token::lexer(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue