Implement setp

This commit is contained in:
Andrzej Janik 2024-08-19 02:23:26 +02:00
commit c08e6a6772
5 changed files with 388 additions and 55 deletions

View file

@ -13,3 +13,4 @@ rustc-hash = "2.0.0"
syn = "2.0.67"
quote = "1.0"
proc-macro2 = "1.0.86"
either = "1.13.0"

View file

@ -1,3 +1,4 @@
use either::Either;
use gen_impl::parser;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
@ -28,7 +29,7 @@ static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"];
struct OpcodeDefinitions {
definitions: Vec<SingleOpcodeDefinition>,
block_selection: Vec<Vec<(Option<parser::DotModifier>, usize)>>,
block_selection: Vec<Vec<(Option<Vec<parser::DotModifier>>, usize)>>,
}
impl OpcodeDefinitions {
@ -51,33 +52,51 @@ impl OpcodeDefinitions {
_ => {}
}
'check_definitions: for i in unselected.iter().copied() {
// Attempt every modifier
'check_candidates: for candidate in definitions[i]
let mut candidates = definitions[i]
.unordered_modifiers
.iter()
.chain(definitions[i].ordered_modifiers.iter())
{
let candidate = if let DotModifierRef::Direct {
optional: false,
value,
..
} = candidate
{
value
} else {
continue;
};
.filter(|modifier| match modifier {
DotModifierRef::Direct {
optional: false, ..
}
| DotModifierRef::Indirect {
optional: false, ..
} => true,
_ => false,
})
.collect::<Vec<_>>();
candidates.sort_by_key(|modifier| match modifier {
DotModifierRef::Direct { .. } => 1,
DotModifierRef::Indirect { value, .. } => value.alternatives.len(),
});
// Attempt every modifier
'check_candidates: for candidate_modifier in candidates {
// check all other unselected patterns
for j in unselected.iter().copied() {
if i == j {
continue;
}
if definitions[j].possible_modifiers.contains(candidate) {
continue 'check_candidates;
let candidate_set = match candidate_modifier {
DotModifierRef::Direct { value, .. } => Either::Left(iter::once(value)),
DotModifierRef::Indirect { value, .. } => {
Either::Right(value.alternatives.iter())
}
};
for candidate_value in candidate_set {
if definitions[j].possible_modifiers.contains(candidate_value) {
continue 'check_candidates;
}
}
}
// it's unique
selections[i] = Some((Some(candidate), generation));
let candidate_vec = match candidate_modifier {
DotModifierRef::Direct { value, .. } => vec![value.clone()],
DotModifierRef::Indirect { value, .. } => {
value.alternatives.iter().cloned().collect::<Vec<_>>()
}
};
selections[i] = Some((Some(candidate_vec), generation));
selected_something = true;
continue 'check_definitions;
}
@ -96,9 +115,9 @@ impl OpcodeDefinitions {
let mut current_generation_definitions = Vec::new();
for (idx, selection) in selections.iter_mut().enumerate() {
match selection {
Some((modifier, generation)) => {
Some((modifier_set, generation)) => {
if *generation == current_generation {
current_generation_definitions.push((modifier.cloned(), idx));
current_generation_definitions.push((modifier_set.clone(), idx));
*selection = None;
}
}
@ -181,6 +200,8 @@ impl SingleOpcodeDefinition {
let name = &arg.ident;
let arg_type = if arg.unified {
quote! { (ParsedOperandStr<'input>, bool) }
} else if arg.can_be_negated {
quote! { (bool, ParsedOperandStr<'input>) }
} else {
quote! { ParsedOperandStr<'input> }
};
@ -222,9 +243,6 @@ impl SingleOpcodeDefinition {
unnamed_rules = FxHashMap::default();
}
let mut possible_modifiers = FxHashSet::default();
for (_, options) in named_rules.iter() {
possible_modifiers.extend(options.alternatives.iter().cloned());
}
let parser::OpcodeDecl(instruction, arguments) = opcode_decl;
let mut unordered_modifiers = instruction
.modifiers
@ -232,6 +250,7 @@ impl SingleOpcodeDefinition {
.map(|parser::MaybeDotModifier { optional, modifier }| {
match named_rules.get(&modifier) {
Some(alts) => {
possible_modifiers.extend(alts.alternatives.iter().cloned());
if alts.alternatives.len() == 1 && alts.type_.is_none() {
DotModifierRef::Direct {
optional,
@ -437,11 +456,10 @@ fn emit_parse_function(
for (selection_key, selected_definition) in selection_layer {
let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]);
match selection_key {
Some(selection_key) => {
let selection_key =
selection_key.dot_capitalized();
Some(selection_keys) => {
let selection_keys = selection_keys.iter().map(|k| k.dot_capitalized());
quote! {
else if modifiers.contains(& #type_name :: #selection_key) {
else if false #(|| modifiers.contains(& #type_name :: #selection_keys))* {
#def_parser
}
}
@ -715,7 +733,7 @@ fn emit_definition_parser(
| DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(),
});
let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| {
let comma = if idx == 0 {
let comma = if idx == 0 || arg.pre_pipe {
quote! { empty }
} else {
quote! { any.verify(|t| *t == #token_type::Comma).void() }
@ -774,10 +792,17 @@ fn emit_definition_parser(
(#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified)
};
let arg_name = &arg.ident;
if arg.unified && arg.can_be_negated {
panic!("TODO: argument can't be both prefixed by `!` and suffixed by `.unified`")
}
let inner_parser = if arg.unified {
quote! {
#pattern.map(|(_, _, _, _, name, _, unified)| (name, unified))
}
} else if arg.can_be_negated {
quote! {
#pattern.map(|(_, _, _, negated, name, _, _)| (negated, name))
}
} else {
quote! {
#pattern.map(|(_, _, _, _, name, _, _)| name)

View file

@ -70,7 +70,7 @@ impl GenerateInstructionType {
let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix());
let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
(
quote! { <#type_parameters, To> },
quote! { <#type_parameters, To: Operand> },
quote! { <#short_parameters, To> },
quote! { #type_name<To> },
)
@ -514,19 +514,29 @@ impl ArgumentField {
.unwrap_or_else(|| quote! { StateSpace::Reg });
let is_dst = self.is_dst;
let name = &self.name;
let arguments_name = if is_mut {
quote! {
&mut arguments.#name
}
let (operand_fn, arguments_name) = if is_mut {
(
quote! {
VisitOperand::visit_mut
},
quote! {
&mut arguments.#name
},
)
} else {
quote! {
& arguments.#name
}
(
quote! {
VisitOperand::visit
},
quote! {
& arguments.#name
},
)
};
quote! {{
let type_ = #type_;
let space = #space;
visitor.visit(#arguments_name, &type_, space, #is_dst);
#operand_fn(#arguments_name, |x| visitor.visit(x, &type_, space, #is_dst));
}}
}
@ -548,7 +558,7 @@ impl ArgumentField {
let #name = {
let type_ = #type_;
let space = #space;
visitor.visit(arguments.#name, &type_, space, #is_dst)
MapOperand::map(arguments.#name, |x| visitor.visit(x, &type_, space, #is_dst))
};
}
}

View file

@ -1,6 +1,5 @@
use std::intrinsics::unreachable;
use super::{MemScope, ScalarType, StateSpace, VectorPrefix};
use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix};
use crate::{PtxError, PtxParserState};
use bitflags::bitflags;
pub enum Statement<P: Operand> {
@ -11,7 +10,7 @@ pub enum Statement<P: Operand> {
}
gen::generate_instruction_type!(
pub enum Instruction<T> {
pub enum Instruction<T: Operand> {
Mov {
type: { &data.typ },
data: MovDetails,
@ -63,6 +62,52 @@ gen::generate_instruction_type!(
src2: T,
}
},
Setp {
data: SetpData,
arguments<T>: {
dst1: {
repr: T,
type: ScalarType::Pred.into()
},
dst2: {
repr: Option<T>,
type: ScalarType::Pred.into()
},
src1: {
repr: T,
type: data.type_.into(),
},
src2: {
repr: T,
type: data.type_.into(),
}
}
},
SetpBool {
data: SetpBoolData,
arguments<T>: {
dst1: {
repr: T,
type: ScalarType::Pred.into()
},
dst2: {
repr: Option<T>,
type: ScalarType::Pred.into()
},
src1: {
repr: T,
type: data.base.type_.into(),
},
src2: {
repr: T,
type: data.base.type_.into(),
},
src3: {
repr: T,
type: ScalarType::Pred.into()
}
}
},
Ret {
data: RetData
},
@ -70,6 +115,66 @@ gen::generate_instruction_type!(
}
);
/// Read-only operand callback.
///
/// `visit` receives one operand by shared reference together with the type,
/// state space and destination flag declared for that argument.
pub trait Visitor<T> {
/// Inspects a single operand.
///
/// * `args` - the operand being visited.
/// * `type_` - the type declared for this argument.
/// * `space` - the state space declared for this argument.
/// * `is_dst` - `true` when the argument is a destination.
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
}
/// In-place operand callback.
///
/// Same shape as a read-only visitor, but `visit` receives the operand by
/// exclusive reference and may modify it.
pub trait VisitorMut<T> {
/// Inspects, and possibly rewrites, a single operand in place.
///
/// * `args` - the operand being visited.
/// * `type_` - the type declared for this argument.
/// * `space` - the state space declared for this argument.
/// * `is_dst` - `true` when the argument is a destination.
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
}
/// Converting operand callback.
///
/// `visit` consumes an operand of type `From` and returns its replacement of
/// type `To`.
pub trait VisitorMap<From, To> {
/// Converts a single operand.
///
/// * `args` - the operand being consumed.
/// * `type_` - the type declared for this argument.
/// * `space` - the state space declared for this argument.
/// * `is_dst` - `true` when the argument is a destination.
///
/// Returns the converted operand.
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
}
/// Uniform access to the operand stored in an instruction argument slot.
///
/// Implemented both for a bare operand and for an `Option` of an operand, so
/// the same callback can be applied to mandatory and optional arguments.
trait VisitOperand {
    /// Operand type handed to the callback.
    type Operand;

    /// Runs `fn_` on a shared reference to the stored operand, if there is one.
    fn visit(&self, fn_: impl FnOnce(&Self::Operand));

    /// Runs `fn_` on an exclusive reference to the stored operand, if there is one.
    fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand));
}

/// A bare operand always holds exactly one value: the callback runs once.
impl<T: Operand> VisitOperand for T {
    type Operand = Self;

    fn visit(&self, fn_: impl FnOnce(&Self::Operand)) {
        fn_(self);
    }

    fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) {
        fn_(self);
    }
}

/// An optional operand runs the callback only when a value is present.
impl<T: Operand> VisitOperand for Option<T> {
    type Operand = T;

    fn visit(&self, fn_: impl FnOnce(&Self::Operand)) {
        if let Some(operand) = self {
            fn_(operand);
        }
    }

    fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) {
        if let Some(operand) = self {
            fn_(operand);
        }
    }
}
/// By-value conversion of the operand stored in an instruction argument slot.
///
/// Implemented both for a bare operand and for an `Option` of an operand; the
/// shape of the result follows the shape of the input.
trait MapOperand: Sized {
    /// Operand type handed to the callback.
    type Input;
    /// Result container for a callback returning `U`.
    type Output<U>;

    /// Consumes `self`, applying `fn_` to the stored operand (if any).
    fn map<U>(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output<U>;
}

/// A bare operand maps straight to the callback's result.
impl<T: Operand> MapOperand for T {
    type Input = Self;
    type Output<U> = U;

    fn map<U>(self, fn_: impl FnOnce(T) -> U) -> U {
        fn_(self)
    }
}

/// An optional operand maps to an optional result; `None` stays `None`.
impl<T: Operand> MapOperand for Option<T> {
    type Input = T;
    type Output<U> = Option<U>;

    fn map<U>(self, fn_: impl FnOnce(T) -> U) -> Option<U> {
        // Method-call syntax picks the inherent `Option::map`, not this trait method.
        self.map(fn_)
    }
}
pub struct MultiVariable<ID> {
pub var: Variable<ID>,
pub count: Option<u32>,
@ -89,18 +194,6 @@ pub struct PredAt<ID> {
pub label: ID,
}
/// Read-only operand callback.
///
/// `visit` receives one operand by shared reference together with the type,
/// state space and destination flag declared for that argument.
pub trait Visitor<T> {
/// Inspects a single operand; `is_dst` is `true` for destination arguments.
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
}
/// In-place operand callback.
///
/// `visit` receives the operand by exclusive reference and may modify it.
pub trait VisitorMut<T> {
/// Inspects, and possibly rewrites, a single operand in place; `is_dst` is
/// `true` for destination arguments.
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
}
/// Converting operand callback.
///
/// `visit` consumes an operand of type `From` and returns its replacement of
/// type `To`.
pub trait VisitorMap<From, To> {
/// Converts a single operand; `is_dst` is `true` for destination arguments.
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
}
#[derive(PartialEq, Eq, Clone, Hash)]
pub enum Type {
// .param.b32 foo;
@ -121,6 +214,43 @@ impl Type {
}
}
impl ScalarType {
pub fn kind(self) -> ScalarKind {
match self {
ScalarType::U8 => ScalarKind::Unsigned,
ScalarType::U16 => ScalarKind::Unsigned,
ScalarType::U16x2 => ScalarKind::Unsigned,
ScalarType::U32 => ScalarKind::Unsigned,
ScalarType::U64 => ScalarKind::Unsigned,
ScalarType::S8 => ScalarKind::Signed,
ScalarType::S16 => ScalarKind::Signed,
ScalarType::S16x2 => ScalarKind::Signed,
ScalarType::S32 => ScalarKind::Signed,
ScalarType::S64 => ScalarKind::Signed,
ScalarType::B8 => ScalarKind::Bit,
ScalarType::B16 => ScalarKind::Bit,
ScalarType::B32 => ScalarKind::Bit,
ScalarType::B64 => ScalarKind::Bit,
ScalarType::B128 => ScalarKind::Bit,
ScalarType::F16 => ScalarKind::Float,
ScalarType::F16x2 => ScalarKind::Float,
ScalarType::F32 => ScalarKind::Float,
ScalarType::F64 => ScalarKind::Float,
ScalarType::BF16 => ScalarKind::Float,
ScalarType::BF16x2 => ScalarKind::Float,
ScalarType::Pred => ScalarKind::Pred,
}
}
}
/// Broad category of a scalar type, ignoring its width.
///
/// Derives `Debug` so values can be printed with `{:?}` and compared in
/// `assert_eq!` diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarKind {
    /// Untyped bits (`B8` .. `B128`).
    Bit,
    /// Unsigned integers (`U8` .. `U64`, `U16x2`).
    Unsigned,
    /// Signed integers (`S8` .. `S64`, `S16x2`).
    Signed,
    /// Floating point (`F16`, `F16x2`, `F32`, `F64`, `BF16`, `BF16x2`).
    Float,
    /// Predicate.
    Pred,
}
impl From<ScalarType> for Type {
fn from(value: ScalarType) -> Self {
Type::Scalar(value)
@ -347,3 +477,135 @@ pub enum MulIntControl {
High,
Wide,
}
/// Modifiers of a `setp` instruction after validation.
pub struct SetpData {
/// Type of the two compared source operands.
pub type_: ScalarType,
/// `.ftz` flag; `Some(..)` only when `type_` is `F32`, `None` for every
/// other type.
pub flush_to_zero: Option<bool>,
/// Comparison operator, resolved to its integer or floating-point form.
pub cmp_op: SetpCompareOp,
}
impl SetpData {
    /// Validates the raw `setp` modifiers and builds a [`SetpData`].
    ///
    /// Never fails outright: problems are recorded in `errors` and a
    /// best-effort value is still returned so parsing can continue.
    ///
    /// * `errors` - parser state that collects diagnostics.
    /// * `cmp_op` - comparison operator exactly as written in the source.
    /// * `ftz` - whether the `.ftz` modifier was present.
    /// * `type_` - operand type of the comparison.
    ///
    /// Diagnostics recorded:
    /// * `PtxError::NonF32Ftz` when `.ftz` is present but `type_` is not `F32`.
    /// * `PtxError::WrongType` when a float-only operator is used with a
    ///   non-float type; the operator then falls back to integer `Eq`.
    pub(crate) fn try_parse(
        errors: &mut PtxParserState,
        cmp_op: super::RawSetpCompareOp,
        ftz: bool,
        type_: ScalarType,
    ) -> Self {
        let flush_to_zero = match (ftz, type_) {
            (_, ScalarType::F32) => Some(ftz),
            // Only an explicit `.ftz` on a non-f32 type is an error. A plain
            // `setp.eq.s32` carries no `.ftz` and must not be reported.
            // NOTE(review): the half-precision `setp` forms also document
            // `.ftz`; confirm whether f16/f16x2 should be accepted here.
            (true, _) => {
                errors.push(PtxError::NonF32Ftz);
                None
            }
            (false, _) => None,
        };
        let cmp_op = if type_.kind() == ScalarKind::Float {
            SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
        } else {
            match SetpCompareInt::try_from(cmp_op) {
                Ok(op) => SetpCompareOp::Integer(op),
                Err(err) => {
                    errors.push(err);
                    // Keep going with a harmless placeholder operator.
                    SetpCompareOp::Integer(SetpCompareInt::Eq)
                }
            }
        };
        Self {
            type_,
            flush_to_zero,
            cmp_op,
        }
    }
}
/// Modifiers of the `setp` form that combines the comparison result with a
/// third predicate operand.
pub struct SetpBoolData {
/// Modifiers shared with the plain `setp` form.
pub base: SetpData,
/// Boolean operator applied between the comparison result and the third
/// source operand.
pub bool_op: SetpBoolPostOp,
/// `true` when the third source operand was written with a leading `!`.
pub negate_src3: bool
}
/// Comparison operator of a `setp` instruction, split by operand category.
///
/// All three comparison enums derive `Debug` so they can be printed with
/// `{:?}` and used in `assert_eq!`.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum SetpCompareOp {
    /// Operator applied to non-float operands.
    Integer(SetpCompareInt),
    /// Operator applied to floating-point operands.
    Float(SetpCompareFloat),
}

/// Comparison operators available for non-float operands.
///
/// The signedness of the comparison is not recorded here; it has to be taken
/// from the operand type.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum SetpCompareInt {
    /// `.eq`
    Eq,
    /// `.ne`
    NotEq,
    /// `.lt` / `.lo`
    Less,
    /// `.le` / `.ls`
    LessOrEq,
    /// `.gt` / `.hi`
    Greater,
    /// `.ge` / `.hs`
    GreaterOrEq,
}

/// Comparison operators available for floating-point operands.
///
/// The `Nan*` variants correspond to the `u`-suffixed operators in the source
/// (`.equ`, `.neu`, ...).
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum SetpCompareFloat {
    /// `.eq`
    Eq,
    /// `.ne`
    NotEq,
    /// `.lt`
    Less,
    /// `.le`
    LessOrEq,
    /// `.gt`
    Greater,
    /// `.ge`
    GreaterOrEq,
    /// `.equ`
    NanEq,
    /// `.neu`
    NanNotEq,
    /// `.ltu`
    NanLess,
    /// `.leu`
    NanLessOrEq,
    /// `.gtu`
    NanGreater,
    /// `.geu`
    NanGreaterOrEq,
    /// `.num`
    IsNotNan,
    /// `.nan`
    IsAnyNan,
}
impl TryFrom<RawSetpCompareOp> for SetpCompareInt {
type Error = PtxError;
fn try_from(value: RawSetpCompareOp) -> Result<Self, PtxError> {
match value {
RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq),
RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq),
RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less),
RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq),
RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater),
RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq),
RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less),
RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq),
RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater),
RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq),
RawSetpCompareOp::Equ => Err(PtxError::WrongType),
RawSetpCompareOp::Neu => Err(PtxError::WrongType),
RawSetpCompareOp::Ltu => Err(PtxError::WrongType),
RawSetpCompareOp::Leu => Err(PtxError::WrongType),
RawSetpCompareOp::Gtu => Err(PtxError::WrongType),
RawSetpCompareOp::Geu => Err(PtxError::WrongType),
RawSetpCompareOp::Num => Err(PtxError::WrongType),
RawSetpCompareOp::Nan => Err(PtxError::WrongType),
}
}
}
impl From<RawSetpCompareOp> for SetpCompareFloat {
fn from(value: RawSetpCompareOp) -> Self {
match value {
RawSetpCompareOp::Eq => SetpCompareFloat::Eq,
RawSetpCompareOp::Ne => SetpCompareFloat::NotEq,
RawSetpCompareOp::Lt => SetpCompareFloat::Less,
RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq,
RawSetpCompareOp::Gt => SetpCompareFloat::Greater,
RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq,
RawSetpCompareOp::Lo => SetpCompareFloat::Less,
RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq,
RawSetpCompareOp::Hi => SetpCompareFloat::Greater,
RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq,
RawSetpCompareOp::Equ => SetpCompareFloat::NanEq,
RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq,
RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess,
RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq,
RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater,
RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq,
RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan,
RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan,
}
}
}

View file

@ -769,6 +769,8 @@ pub enum PtxError {
#[error("")]
NonF32Ftz,
#[error("")]
WrongType,
#[error("")]
WrongArrayType,
#[error("")]
WrongVectorElement,
@ -996,6 +998,9 @@ derive_parser!(
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ScalarType { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum SetpBoolPostOp { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => {
Instruction::Mov {
@ -1424,6 +1429,38 @@ derive_parser!(
.rnd: RawFloatRounding = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
setp.CmpOp{.ftz}.type p[|q], a, b => {
let data = ast::SetpData::try_parse(state, cmpop, ftz, type_);
ast::Instruction::Setp {
data,
arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b }
}
}
setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => {
let (negate_src3, c) = c;
let base = ast::SetpData::try_parse(state, cmpop, ftz, type_);
let data = ast::SetpBoolData {
base,
bool_op: boolop,
negate_src3
};
ast::Instruction::SetpBool {
data,
arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c }
}
}
.CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge,
.lo, .ls, .hi, .hs, // unsigned
.equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only
.BoolOp: SetpBoolPostOp = { .and, .or, .xor };
.type: ScalarType = { .b16, .b32, .b64,
.u16, .u32, .u64,
.s16, .s32, .s64,
.f32, .f64,
.f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }
@ -1432,8 +1469,6 @@ derive_parser!(
);
fn main() {
use winnow::combinator::*;
use winnow::token::*;
use winnow::Parser;
let lexer = Token::lexer(