mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 22:30:41 +00:00
Implement setp
This commit is contained in:
parent
cb64b04f41
commit
c08e6a6772
5 changed files with 388 additions and 55 deletions
|
@ -13,3 +13,4 @@ rustc-hash = "2.0.0"
|
||||||
syn = "2.0.67"
|
syn = "2.0.67"
|
||||||
quote = "1.0"
|
quote = "1.0"
|
||||||
proc-macro2 = "1.0.86"
|
proc-macro2 = "1.0.86"
|
||||||
|
either = "1.13.0"
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use either::Either;
|
||||||
use gen_impl::parser;
|
use gen_impl::parser;
|
||||||
use proc_macro2::{Span, TokenStream};
|
use proc_macro2::{Span, TokenStream};
|
||||||
use quote::{format_ident, quote, ToTokens};
|
use quote::{format_ident, quote, ToTokens};
|
||||||
|
@ -28,7 +29,7 @@ static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"];
|
||||||
|
|
||||||
struct OpcodeDefinitions {
|
struct OpcodeDefinitions {
|
||||||
definitions: Vec<SingleOpcodeDefinition>,
|
definitions: Vec<SingleOpcodeDefinition>,
|
||||||
block_selection: Vec<Vec<(Option<parser::DotModifier>, usize)>>,
|
block_selection: Vec<Vec<(Option<Vec<parser::DotModifier>>, usize)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpcodeDefinitions {
|
impl OpcodeDefinitions {
|
||||||
|
@ -51,33 +52,51 @@ impl OpcodeDefinitions {
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
'check_definitions: for i in unselected.iter().copied() {
|
'check_definitions: for i in unselected.iter().copied() {
|
||||||
// Attempt every modifier
|
let mut candidates = definitions[i]
|
||||||
'check_candidates: for candidate in definitions[i]
|
|
||||||
.unordered_modifiers
|
.unordered_modifiers
|
||||||
.iter()
|
.iter()
|
||||||
.chain(definitions[i].ordered_modifiers.iter())
|
.chain(definitions[i].ordered_modifiers.iter())
|
||||||
{
|
.filter(|modifier| match modifier {
|
||||||
let candidate = if let DotModifierRef::Direct {
|
DotModifierRef::Direct {
|
||||||
optional: false,
|
optional: false, ..
|
||||||
value,
|
}
|
||||||
..
|
| DotModifierRef::Indirect {
|
||||||
} = candidate
|
optional: false, ..
|
||||||
{
|
} => true,
|
||||||
value
|
_ => false,
|
||||||
} else {
|
})
|
||||||
continue;
|
.collect::<Vec<_>>();
|
||||||
};
|
candidates.sort_by_key(|modifier| match modifier {
|
||||||
|
DotModifierRef::Direct { .. } => 1,
|
||||||
|
DotModifierRef::Indirect { value, .. } => value.alternatives.len(),
|
||||||
|
});
|
||||||
|
// Attempt every modifier
|
||||||
|
'check_candidates: for candidate_modifier in candidates {
|
||||||
// check all other unselected patterns
|
// check all other unselected patterns
|
||||||
for j in unselected.iter().copied() {
|
for j in unselected.iter().copied() {
|
||||||
if i == j {
|
if i == j {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if definitions[j].possible_modifiers.contains(candidate) {
|
let candidate_set = match candidate_modifier {
|
||||||
|
DotModifierRef::Direct { value, .. } => Either::Left(iter::once(value)),
|
||||||
|
DotModifierRef::Indirect { value, .. } => {
|
||||||
|
Either::Right(value.alternatives.iter())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for candidate_value in candidate_set {
|
||||||
|
if definitions[j].possible_modifiers.contains(candidate_value) {
|
||||||
continue 'check_candidates;
|
continue 'check_candidates;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// it's unique
|
// it's unique
|
||||||
selections[i] = Some((Some(candidate), generation));
|
let candidate_vec = match candidate_modifier {
|
||||||
|
DotModifierRef::Direct { value, .. } => vec![value.clone()],
|
||||||
|
DotModifierRef::Indirect { value, .. } => {
|
||||||
|
value.alternatives.iter().cloned().collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
selections[i] = Some((Some(candidate_vec), generation));
|
||||||
selected_something = true;
|
selected_something = true;
|
||||||
continue 'check_definitions;
|
continue 'check_definitions;
|
||||||
}
|
}
|
||||||
|
@ -96,9 +115,9 @@ impl OpcodeDefinitions {
|
||||||
let mut current_generation_definitions = Vec::new();
|
let mut current_generation_definitions = Vec::new();
|
||||||
for (idx, selection) in selections.iter_mut().enumerate() {
|
for (idx, selection) in selections.iter_mut().enumerate() {
|
||||||
match selection {
|
match selection {
|
||||||
Some((modifier, generation)) => {
|
Some((modifier_set, generation)) => {
|
||||||
if *generation == current_generation {
|
if *generation == current_generation {
|
||||||
current_generation_definitions.push((modifier.cloned(), idx));
|
current_generation_definitions.push((modifier_set.clone(), idx));
|
||||||
*selection = None;
|
*selection = None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -181,6 +200,8 @@ impl SingleOpcodeDefinition {
|
||||||
let name = &arg.ident;
|
let name = &arg.ident;
|
||||||
let arg_type = if arg.unified {
|
let arg_type = if arg.unified {
|
||||||
quote! { (ParsedOperandStr<'input>, bool) }
|
quote! { (ParsedOperandStr<'input>, bool) }
|
||||||
|
} else if arg.can_be_negated {
|
||||||
|
quote! { (bool, ParsedOperandStr<'input>) }
|
||||||
} else {
|
} else {
|
||||||
quote! { ParsedOperandStr<'input> }
|
quote! { ParsedOperandStr<'input> }
|
||||||
};
|
};
|
||||||
|
@ -222,9 +243,6 @@ impl SingleOpcodeDefinition {
|
||||||
unnamed_rules = FxHashMap::default();
|
unnamed_rules = FxHashMap::default();
|
||||||
}
|
}
|
||||||
let mut possible_modifiers = FxHashSet::default();
|
let mut possible_modifiers = FxHashSet::default();
|
||||||
for (_, options) in named_rules.iter() {
|
|
||||||
possible_modifiers.extend(options.alternatives.iter().cloned());
|
|
||||||
}
|
|
||||||
let parser::OpcodeDecl(instruction, arguments) = opcode_decl;
|
let parser::OpcodeDecl(instruction, arguments) = opcode_decl;
|
||||||
let mut unordered_modifiers = instruction
|
let mut unordered_modifiers = instruction
|
||||||
.modifiers
|
.modifiers
|
||||||
|
@ -232,6 +250,7 @@ impl SingleOpcodeDefinition {
|
||||||
.map(|parser::MaybeDotModifier { optional, modifier }| {
|
.map(|parser::MaybeDotModifier { optional, modifier }| {
|
||||||
match named_rules.get(&modifier) {
|
match named_rules.get(&modifier) {
|
||||||
Some(alts) => {
|
Some(alts) => {
|
||||||
|
possible_modifiers.extend(alts.alternatives.iter().cloned());
|
||||||
if alts.alternatives.len() == 1 && alts.type_.is_none() {
|
if alts.alternatives.len() == 1 && alts.type_.is_none() {
|
||||||
DotModifierRef::Direct {
|
DotModifierRef::Direct {
|
||||||
optional,
|
optional,
|
||||||
|
@ -437,11 +456,10 @@ fn emit_parse_function(
|
||||||
for (selection_key, selected_definition) in selection_layer {
|
for (selection_key, selected_definition) in selection_layer {
|
||||||
let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]);
|
let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]);
|
||||||
match selection_key {
|
match selection_key {
|
||||||
Some(selection_key) => {
|
Some(selection_keys) => {
|
||||||
let selection_key =
|
let selection_keys = selection_keys.iter().map(|k| k.dot_capitalized());
|
||||||
selection_key.dot_capitalized();
|
|
||||||
quote! {
|
quote! {
|
||||||
else if modifiers.contains(& #type_name :: #selection_key) {
|
else if false #(|| modifiers.contains(& #type_name :: #selection_keys))* {
|
||||||
#def_parser
|
#def_parser
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -715,7 +733,7 @@ fn emit_definition_parser(
|
||||||
| DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(),
|
| DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(),
|
||||||
});
|
});
|
||||||
let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| {
|
let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| {
|
||||||
let comma = if idx == 0 {
|
let comma = if idx == 0 || arg.pre_pipe {
|
||||||
quote! { empty }
|
quote! { empty }
|
||||||
} else {
|
} else {
|
||||||
quote! { any.verify(|t| *t == #token_type::Comma).void() }
|
quote! { any.verify(|t| *t == #token_type::Comma).void() }
|
||||||
|
@ -774,10 +792,17 @@ fn emit_definition_parser(
|
||||||
(#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified)
|
(#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified)
|
||||||
};
|
};
|
||||||
let arg_name = &arg.ident;
|
let arg_name = &arg.ident;
|
||||||
|
if arg.unified && arg.can_be_negated {
|
||||||
|
panic!("TODO: argument can't be both prefixed by `!` and suffixed by `.unified`")
|
||||||
|
}
|
||||||
let inner_parser = if arg.unified {
|
let inner_parser = if arg.unified {
|
||||||
quote! {
|
quote! {
|
||||||
#pattern.map(|(_, _, _, _, name, _, unified)| (name, unified))
|
#pattern.map(|(_, _, _, _, name, _, unified)| (name, unified))
|
||||||
}
|
}
|
||||||
|
} else if arg.can_be_negated {
|
||||||
|
quote! {
|
||||||
|
#pattern.map(|(_, _, _, negated, name, _, _)| (negated, name))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
quote! {
|
quote! {
|
||||||
#pattern.map(|(_, _, _, _, name, _, _)| name)
|
#pattern.map(|(_, _, _, _, name, _, _)| name)
|
||||||
|
|
|
@ -70,7 +70,7 @@ impl GenerateInstructionType {
|
||||||
let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix());
|
let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix());
|
||||||
let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
|
let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
|
||||||
(
|
(
|
||||||
quote! { <#type_parameters, To> },
|
quote! { <#type_parameters, To: Operand> },
|
||||||
quote! { <#short_parameters, To> },
|
quote! { <#short_parameters, To> },
|
||||||
quote! { #type_name<To> },
|
quote! { #type_name<To> },
|
||||||
)
|
)
|
||||||
|
@ -514,19 +514,29 @@ impl ArgumentField {
|
||||||
.unwrap_or_else(|| quote! { StateSpace::Reg });
|
.unwrap_or_else(|| quote! { StateSpace::Reg });
|
||||||
let is_dst = self.is_dst;
|
let is_dst = self.is_dst;
|
||||||
let name = &self.name;
|
let name = &self.name;
|
||||||
let arguments_name = if is_mut {
|
let (operand_fn, arguments_name) = if is_mut {
|
||||||
|
(
|
||||||
|
quote! {
|
||||||
|
VisitOperand::visit_mut
|
||||||
|
},
|
||||||
quote! {
|
quote! {
|
||||||
&mut arguments.#name
|
&mut arguments.#name
|
||||||
}
|
},
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
|
(
|
||||||
|
quote! {
|
||||||
|
VisitOperand::visit
|
||||||
|
},
|
||||||
quote! {
|
quote! {
|
||||||
& arguments.#name
|
& arguments.#name
|
||||||
}
|
},
|
||||||
|
)
|
||||||
};
|
};
|
||||||
quote! {{
|
quote! {{
|
||||||
let type_ = #type_;
|
let type_ = #type_;
|
||||||
let space = #space;
|
let space = #space;
|
||||||
visitor.visit(#arguments_name, &type_, space, #is_dst);
|
#operand_fn(#arguments_name, |x| visitor.visit(x, &type_, space, #is_dst));
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -548,7 +558,7 @@ impl ArgumentField {
|
||||||
let #name = {
|
let #name = {
|
||||||
let type_ = #type_;
|
let type_ = #type_;
|
||||||
let space = #space;
|
let space = #space;
|
||||||
visitor.visit(arguments.#name, &type_, space, #is_dst)
|
MapOperand::map(arguments.#name, |x| visitor.visit(x, &type_, space, #is_dst))
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use std::intrinsics::unreachable;
|
use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix};
|
||||||
|
use crate::{PtxError, PtxParserState};
|
||||||
use super::{MemScope, ScalarType, StateSpace, VectorPrefix};
|
|
||||||
use bitflags::bitflags;
|
use bitflags::bitflags;
|
||||||
|
|
||||||
pub enum Statement<P: Operand> {
|
pub enum Statement<P: Operand> {
|
||||||
|
@ -11,7 +10,7 @@ pub enum Statement<P: Operand> {
|
||||||
}
|
}
|
||||||
|
|
||||||
gen::generate_instruction_type!(
|
gen::generate_instruction_type!(
|
||||||
pub enum Instruction<T> {
|
pub enum Instruction<T: Operand> {
|
||||||
Mov {
|
Mov {
|
||||||
type: { &data.typ },
|
type: { &data.typ },
|
||||||
data: MovDetails,
|
data: MovDetails,
|
||||||
|
@ -63,6 +62,52 @@ gen::generate_instruction_type!(
|
||||||
src2: T,
|
src2: T,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
Setp {
|
||||||
|
data: SetpData,
|
||||||
|
arguments<T>: {
|
||||||
|
dst1: {
|
||||||
|
repr: T,
|
||||||
|
type: ScalarType::Pred.into()
|
||||||
|
},
|
||||||
|
dst2: {
|
||||||
|
repr: Option<T>,
|
||||||
|
type: ScalarType::Pred.into()
|
||||||
|
},
|
||||||
|
src1: {
|
||||||
|
repr: T,
|
||||||
|
type: data.type_.into(),
|
||||||
|
},
|
||||||
|
src2: {
|
||||||
|
repr: T,
|
||||||
|
type: data.type_.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
SetpBool {
|
||||||
|
data: SetpBoolData,
|
||||||
|
arguments<T>: {
|
||||||
|
dst1: {
|
||||||
|
repr: T,
|
||||||
|
type: ScalarType::Pred.into()
|
||||||
|
},
|
||||||
|
dst2: {
|
||||||
|
repr: Option<T>,
|
||||||
|
type: ScalarType::Pred.into()
|
||||||
|
},
|
||||||
|
src1: {
|
||||||
|
repr: T,
|
||||||
|
type: data.base.type_.into(),
|
||||||
|
},
|
||||||
|
src2: {
|
||||||
|
repr: T,
|
||||||
|
type: data.base.type_.into(),
|
||||||
|
},
|
||||||
|
src3: {
|
||||||
|
repr: T,
|
||||||
|
type: ScalarType::Pred.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
Ret {
|
Ret {
|
||||||
data: RetData
|
data: RetData
|
||||||
},
|
},
|
||||||
|
@ -70,6 +115,66 @@ gen::generate_instruction_type!(
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/// Read-only operand visitor.
///
/// `visit` receives a shared reference to one operand together with the
/// operand's type, its state space and a flag telling whether the operand is
/// a destination.
pub trait Visitor<T> {
|
||||||
|
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// In-place operand visitor.
///
/// `visit` receives a mutable reference to one operand together with the
/// operand's type, its state space and a flag telling whether the operand is
/// a destination.
pub trait VisitorMut<T> {
|
||||||
|
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converting operand visitor.
///
/// `visit` takes one operand of type `From` by value, together with the
/// operand's type, its state space and a destination flag, and returns the
/// replacement operand of type `To`.
pub trait VisitorMap<From, To> {
|
||||||
|
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Module-private adapter that hands the operand(s) stored in an argument
/// field to a callback, implemented in this file for a bare operand and for
/// `Option` of an operand.
trait VisitOperand {
|
||||||
|
/// The operand type passed to the callback.
type Operand;
|
||||||
|
/// Passes a shared reference to the contained operand to `fn_`.
fn visit(&self, fn_: impl FnOnce(&Self::Operand));
|
||||||
|
/// Passes a mutable reference to the contained operand to `fn_`.
fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A bare operand is its own visit target: the callback runs exactly once.
impl<T: Operand> VisitOperand for T {
    type Operand = Self;

    /// Invokes `callback` with a shared reference to `self`.
    fn visit(&self, callback: impl FnOnce(&Self::Operand)) {
        callback(self);
    }

    /// Invokes `callback` with a mutable reference to `self`.
    fn visit_mut(&mut self, callback: impl FnOnce(&mut Self::Operand)) {
        callback(self);
    }
}
|
||||||
|
|
||||||
|
/// An optional operand is visited only when it is present.
impl<T: Operand> VisitOperand for Option<T> {
    type Operand = T;

    /// Invokes `callback` with the contained operand; does nothing for `None`.
    fn visit(&self, callback: impl FnOnce(&Self::Operand)) {
        if let Some(operand) = self {
            callback(operand);
        }
    }

    /// Invokes `callback` with the contained operand, mutably; does nothing
    /// for `None`.
    fn visit_mut(&mut self, callback: impl FnOnce(&mut Self::Operand)) {
        if let Some(operand) = self {
            callback(operand);
        }
    }
}
|
||||||
|
|
||||||
|
/// Module-private adapter that converts the operand(s) stored in an argument
/// field with a callback, implemented in this file for a bare operand and for
/// `Option` of an operand.
trait MapOperand: Sized {
|
||||||
|
/// The operand type handed to the callback.
type Input;
|
||||||
|
/// The container shape produced when every operand is replaced by a `U`.
type Output<U>;
|
||||||
|
/// Consumes `self`, applies `fn_` to the contained operand and returns
/// the result in the same container shape.
fn map<U>(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output<U>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A bare operand maps directly to the callback's result.
impl<T: Operand> MapOperand for T {
    type Input = Self;
    type Output<U> = U;

    /// Applies `convert` to `self` and returns its result unchanged.
    fn map<U>(self, convert: impl FnOnce(T) -> U) -> U {
        convert(self)
    }
}
|
||||||
|
|
||||||
|
/// An optional operand is converted only when it is present.
impl<T: Operand> MapOperand for Option<T> {
    type Input = T;
    type Output<U> = Option<U>;

    /// Applies `convert` to the contained operand, keeping `None` as `None`.
    fn map<U>(self, convert: impl FnOnce(T) -> U) -> Option<U> {
        // Spelled out as a `match` so it cannot be mistaken for a recursive
        // call to `MapOperand::map`.
        match self {
            Some(operand) => Some(convert(operand)),
            None => None,
        }
    }
}
|
||||||
|
|
||||||
pub struct MultiVariable<ID> {
|
pub struct MultiVariable<ID> {
|
||||||
pub var: Variable<ID>,
|
pub var: Variable<ID>,
|
||||||
pub count: Option<u32>,
|
pub count: Option<u32>,
|
||||||
|
@ -89,18 +194,6 @@ pub struct PredAt<ID> {
|
||||||
pub label: ID,
|
pub label: ID,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Visitor<T> {
|
|
||||||
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait VisitorMut<T> {
|
|
||||||
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait VisitorMap<From, To> {
|
|
||||||
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Hash)]
|
#[derive(PartialEq, Eq, Clone, Hash)]
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
// .param.b32 foo;
|
// .param.b32 foo;
|
||||||
|
@ -121,6 +214,43 @@ impl Type {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ScalarType {
|
||||||
|
pub fn kind(self) -> ScalarKind {
|
||||||
|
match self {
|
||||||
|
ScalarType::U8 => ScalarKind::Unsigned,
|
||||||
|
ScalarType::U16 => ScalarKind::Unsigned,
|
||||||
|
ScalarType::U16x2 => ScalarKind::Unsigned,
|
||||||
|
ScalarType::U32 => ScalarKind::Unsigned,
|
||||||
|
ScalarType::U64 => ScalarKind::Unsigned,
|
||||||
|
ScalarType::S8 => ScalarKind::Signed,
|
||||||
|
ScalarType::S16 => ScalarKind::Signed,
|
||||||
|
ScalarType::S16x2 => ScalarKind::Signed,
|
||||||
|
ScalarType::S32 => ScalarKind::Signed,
|
||||||
|
ScalarType::S64 => ScalarKind::Signed,
|
||||||
|
ScalarType::B8 => ScalarKind::Bit,
|
||||||
|
ScalarType::B16 => ScalarKind::Bit,
|
||||||
|
ScalarType::B32 => ScalarKind::Bit,
|
||||||
|
ScalarType::B64 => ScalarKind::Bit,
|
||||||
|
ScalarType::B128 => ScalarKind::Bit,
|
||||||
|
ScalarType::F16 => ScalarKind::Float,
|
||||||
|
ScalarType::F16x2 => ScalarKind::Float,
|
||||||
|
ScalarType::F32 => ScalarKind::Float,
|
||||||
|
ScalarType::F64 => ScalarKind::Float,
|
||||||
|
ScalarType::BF16 => ScalarKind::Float,
|
||||||
|
ScalarType::BF16x2 => ScalarKind::Float,
|
||||||
|
ScalarType::Pred => ScalarKind::Pred,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Broad category of a `ScalarType`, as returned by `ScalarType::kind`.
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ScalarKind {
|
||||||
|
/// `B8`, `B16`, `B32`, `B64`, `B128`.
Bit,
|
||||||
|
/// `U8`, `U16`, `U16x2`, `U32`, `U64`.
Unsigned,
|
||||||
|
/// `S8`, `S16`, `S16x2`, `S32`, `S64`.
Signed,
|
||||||
|
/// `F16`, `F16x2`, `F32`, `F64`, `BF16`, `BF16x2`.
Float,
|
||||||
|
/// `Pred` only.
Pred,
|
||||||
|
}
|
||||||
impl From<ScalarType> for Type {
|
impl From<ScalarType> for Type {
|
||||||
fn from(value: ScalarType) -> Self {
|
fn from(value: ScalarType) -> Self {
|
||||||
Type::Scalar(value)
|
Type::Scalar(value)
|
||||||
|
@ -347,3 +477,135 @@ pub enum MulIntControl {
|
||||||
High,
|
High,
|
||||||
Wide,
|
Wide,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Decoded modifiers of a `setp` instruction, built by `SetpData::try_parse`.
pub struct SetpData {
|
||||||
|
/// Type given in the instruction; also used as the type of `src1` and `src2`.
pub type_: ScalarType,
|
||||||
|
/// `Some(ftz)` when `type_` is `F32`, `None` for every other type.
pub flush_to_zero: Option<bool>,
|
||||||
|
/// Comparison operator, already resolved to its integer or float flavour.
pub cmp_op: SetpCompareOp,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SetpData {
|
||||||
|
pub(crate) fn try_parse(
|
||||||
|
errors: &mut PtxParserState,
|
||||||
|
cmp_op: super::RawSetpCompareOp,
|
||||||
|
ftz: bool,
|
||||||
|
type_: ScalarType,
|
||||||
|
) -> Self {
|
||||||
|
let flush_to_zero = match (ftz, type_) {
|
||||||
|
(_, ScalarType::F32) => Some(ftz),
|
||||||
|
_ => {
|
||||||
|
errors.push(PtxError::NonF32Ftz);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let type_kind = type_.kind();
|
||||||
|
let cmp_op = if type_kind == ScalarKind::Float {
|
||||||
|
SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
|
||||||
|
} else {
|
||||||
|
match SetpCompareInt::try_from(cmp_op) {
|
||||||
|
Ok(op) => SetpCompareOp::Integer(op),
|
||||||
|
Err(err) => {
|
||||||
|
errors.push(err);
|
||||||
|
SetpCompareOp::Integer(SetpCompareInt::Eq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Self {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
cmp_op,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decoded modifiers of the `setp` form that combines the comparison result
/// with a third predicate operand (`setp.CmpOp.BoolOp`).
pub struct SetpBoolData {
|
||||||
|
/// The modifiers shared with the plain `setp` form.
pub base: SetpData,
|
||||||
|
/// The `.BoolOp` modifier as parsed (`.and`, `.or` or `.xor`).
pub bool_op: SetpBoolPostOp,
|
||||||
|
/// Whether `src3` was written with a leading `!`.
pub negate_src3: bool
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Comparison operator of a `setp` instruction after resolving it against the
/// instruction type in `SetpData::try_parse`.
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||||
|
pub enum SetpCompareOp {
|
||||||
|
/// Used when the type's kind is anything other than `ScalarKind::Float`.
Integer(SetpCompareInt),
|
||||||
|
/// Used when the type's kind is `ScalarKind::Float`.
Float(SetpCompareFloat),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Comparison operators available for non-float `setp` instructions.
///
/// Produced from `RawSetpCompareOp` through `TryFrom`; the `.lo`/`.ls`/`.hi`/`.hs`
/// spellings map to the same variants as `.lt`/`.le`/`.gt`/`.ge`.
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||||
|
pub enum SetpCompareInt {
|
||||||
|
/// From `.eq`.
Eq,
|
||||||
|
/// From `.ne`.
NotEq,
|
||||||
|
/// From `.lt` or `.lo`.
Less,
|
||||||
|
/// From `.le` or `.ls`.
LessOrEq,
|
||||||
|
/// From `.gt` or `.hi`.
Greater,
|
||||||
|
/// From `.ge` or `.hs`.
GreaterOrEq,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Comparison operators available for float `setp` instructions.
///
/// Produced from `RawSetpCompareOp` through `From`. The `Nan*` variants come
/// from the `u`-suffixed spellings (`.equ`, `.neu`, `.ltu`, `.leu`, `.gtu`,
/// `.geu`); `IsNotNan` and `IsAnyNan` come from `.num` and `.nan`.
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||||
|
pub enum SetpCompareFloat {
|
||||||
|
/// From `.eq`.
Eq,
|
||||||
|
/// From `.ne`.
NotEq,
|
||||||
|
/// From `.lt` or `.lo`.
Less,
|
||||||
|
/// From `.le` or `.ls`.
LessOrEq,
|
||||||
|
/// From `.gt` or `.hi`.
Greater,
|
||||||
|
/// From `.ge` or `.hs`.
GreaterOrEq,
|
||||||
|
/// From `.equ`.
NanEq,
|
||||||
|
/// From `.neu`.
NanNotEq,
|
||||||
|
/// From `.ltu`.
NanLess,
|
||||||
|
/// From `.leu`.
NanLessOrEq,
|
||||||
|
/// From `.gtu`.
NanGreater,
|
||||||
|
/// From `.geu`.
NanGreaterOrEq,
|
||||||
|
/// From `.num`.
IsNotNan,
|
||||||
|
/// From `.nan`.
IsAnyNan,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<RawSetpCompareOp> for SetpCompareInt {
|
||||||
|
type Error = PtxError;
|
||||||
|
|
||||||
|
fn try_from(value: RawSetpCompareOp) -> Result<Self, PtxError> {
|
||||||
|
match value {
|
||||||
|
RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq),
|
||||||
|
RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq),
|
||||||
|
RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less),
|
||||||
|
RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq),
|
||||||
|
RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater),
|
||||||
|
RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq),
|
||||||
|
RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less),
|
||||||
|
RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq),
|
||||||
|
RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater),
|
||||||
|
RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq),
|
||||||
|
RawSetpCompareOp::Equ => Err(PtxError::WrongType),
|
||||||
|
RawSetpCompareOp::Neu => Err(PtxError::WrongType),
|
||||||
|
RawSetpCompareOp::Ltu => Err(PtxError::WrongType),
|
||||||
|
RawSetpCompareOp::Leu => Err(PtxError::WrongType),
|
||||||
|
RawSetpCompareOp::Gtu => Err(PtxError::WrongType),
|
||||||
|
RawSetpCompareOp::Geu => Err(PtxError::WrongType),
|
||||||
|
RawSetpCompareOp::Num => Err(PtxError::WrongType),
|
||||||
|
RawSetpCompareOp::Nan => Err(PtxError::WrongType),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<RawSetpCompareOp> for SetpCompareFloat {
|
||||||
|
fn from(value: RawSetpCompareOp) -> Self {
|
||||||
|
match value {
|
||||||
|
RawSetpCompareOp::Eq => SetpCompareFloat::Eq,
|
||||||
|
RawSetpCompareOp::Ne => SetpCompareFloat::NotEq,
|
||||||
|
RawSetpCompareOp::Lt => SetpCompareFloat::Less,
|
||||||
|
RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq,
|
||||||
|
RawSetpCompareOp::Gt => SetpCompareFloat::Greater,
|
||||||
|
RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq,
|
||||||
|
RawSetpCompareOp::Lo => SetpCompareFloat::Less,
|
||||||
|
RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq,
|
||||||
|
RawSetpCompareOp::Hi => SetpCompareFloat::Greater,
|
||||||
|
RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq,
|
||||||
|
RawSetpCompareOp::Equ => SetpCompareFloat::NanEq,
|
||||||
|
RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq,
|
||||||
|
RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess,
|
||||||
|
RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq,
|
||||||
|
RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater,
|
||||||
|
RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq,
|
||||||
|
RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan,
|
||||||
|
RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -769,6 +769,8 @@ pub enum PtxError {
|
||||||
#[error("")]
|
#[error("")]
|
||||||
NonF32Ftz,
|
NonF32Ftz,
|
||||||
#[error("")]
|
#[error("")]
|
||||||
|
WrongType,
|
||||||
|
#[error("")]
|
||||||
WrongArrayType,
|
WrongArrayType,
|
||||||
#[error("")]
|
#[error("")]
|
||||||
WrongVectorElement,
|
WrongVectorElement,
|
||||||
|
@ -996,6 +998,9 @@ derive_parser!(
|
||||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum ScalarType { }
|
pub enum ScalarType { }
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum SetpBoolPostOp { }
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
||||||
mov{.vec}.type d, a => {
|
mov{.vec}.type d, a => {
|
||||||
Instruction::Mov {
|
Instruction::Mov {
|
||||||
|
@ -1424,6 +1429,38 @@ derive_parser!(
|
||||||
.rnd: RawFloatRounding = { .rn };
|
.rnd: RawFloatRounding = { .rn };
|
||||||
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
|
||||||
|
setp.CmpOp{.ftz}.type p[|q], a, b => {
|
||||||
|
let data = ast::SetpData::try_parse(state, cmpop, ftz, type_);
|
||||||
|
ast::Instruction::Setp {
|
||||||
|
data,
|
||||||
|
arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => {
|
||||||
|
let (negate_src3, c) = c;
|
||||||
|
let base = ast::SetpData::try_parse(state, cmpop, ftz, type_);
|
||||||
|
let data = ast::SetpBoolData {
|
||||||
|
base,
|
||||||
|
bool_op: boolop,
|
||||||
|
negate_src3
|
||||||
|
};
|
||||||
|
ast::Instruction::SetpBool {
|
||||||
|
data,
|
||||||
|
arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge,
|
||||||
|
.lo, .ls, .hi, .hs, // unsigned
|
||||||
|
.equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only
|
||||||
|
.BoolOp: SetpBoolPostOp = { .and, .or, .xor };
|
||||||
|
.type: ScalarType = { .b16, .b32, .b64,
|
||||||
|
.u16, .u32, .u64,
|
||||||
|
.s16, .s32, .s64,
|
||||||
|
.f32, .f64,
|
||||||
|
.f16, .f16x2, .bf16, .bf16x2 };
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
|
||||||
ret{.uni} => {
|
ret{.uni} => {
|
||||||
Instruction::Ret { data: RetData { uniform: uni } }
|
Instruction::Ret { data: RetData { uniform: uni } }
|
||||||
|
@ -1432,8 +1469,6 @@ derive_parser!(
|
||||||
);
|
);
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
use winnow::combinator::*;
|
|
||||||
use winnow::token::*;
|
|
||||||
use winnow::Parser;
|
use winnow::Parser;
|
||||||
|
|
||||||
let lexer = Token::lexer(
|
let lexer = Token::lexer(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue