From a05bee9ccba93fcd4b6e9d6adb864829ba8768c6 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 14 Aug 2024 11:38:54 +0200 Subject: [PATCH] Start rewriting PTX parser --- Cargo.toml | 5 + gen/Cargo.toml | 15 + gen/src/lib.rs | 860 +++++++++++++++++++++++++++++++++++++++++ gen_impl/Cargo.toml | 12 + gen_impl/src/lib.rs | 718 ++++++++++++++++++++++++++++++++++ gen_impl/src/parser.rs | 793 +++++++++++++++++++++++++++++++++++++ ptx_parser/Cargo.toml | 9 + ptx_parser/src/main.rs | 437 +++++++++++++++++++++ 8 files changed, 2849 insertions(+) create mode 100644 gen/Cargo.toml create mode 100644 gen/src/lib.rs create mode 100644 gen_impl/Cargo.toml create mode 100644 gen_impl/src/lib.rs create mode 100644 gen_impl/src/parser.rs create mode 100644 ptx_parser/Cargo.toml create mode 100644 ptx_parser/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index 6371981..7f38976 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,7 @@ [workspace] +resolver = "2" + members = [ "cuda_base", "cuda_types", @@ -15,6 +17,9 @@ members = [ "zluda_redirect", "zluda_ml", "ptx", + "gen", + "gen_impl", + "ptx_parser" ] default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"] diff --git a/gen/Cargo.toml b/gen/Cargo.toml new file mode 100644 index 0000000..e24be0f --- /dev/null +++ b/gen/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "gen" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +gen_impl = { path = "../gen_impl" } +convert_case = "0.6.0" +rustc-hash = "2.0.0" +syn = "2.0.67" +quote = "1.0" +proc-macro2 = "1.0.86" diff --git a/gen/src/lib.rs b/gen/src/lib.rs new file mode 100644 index 0000000..f39150f --- /dev/null +++ b/gen/src/lib.rs @@ -0,0 +1,860 @@ +use gen_impl::parser; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, ToTokens}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::{collections::hash_map, hash::Hash, rc::Rc}; +use syn::{parse_macro_input, punctuated::Punctuated, Ident, ItemEnum, 
Token, TypePath, Variant}; + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats +#[rustfmt::skip] +static POSTFIX_MODIFIERS: &[&str] = &[ + ".v2", ".v4", + ".s8", ".s16", ".s32", ".s64", + ".u8", ".u16", ".u32", ".u64", + ".f16", ".f16x2", ".f32", ".f64", + ".b8", ".b16", ".b32", ".b64", ".b128", + ".pred", + ".bf16", ".e4m3", ".e5m2", ".tf32", +]; + +static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"]; + +struct OpcodeDefinitions { + definitions: Vec, + block_selection: Vec, usize)>>, +} + +impl OpcodeDefinitions { + fn new(opcode: &Ident, definitions: Vec) -> Self { + let mut selections = vec![None; definitions.len()]; + let mut generation = 0usize; + loop { + let mut selected_something = false; + let unselected = selections + .iter() + .enumerate() + .filter_map(|(idx, s)| if s.is_none() { Some(idx) } else { None }) + .collect::>(); + match &*unselected { + [] => break, + [remaining] => { + selections[*remaining] = Some((None, generation)); + break; + } + _ => {} + } + 'check_definitions: for i in unselected.iter().copied() { + // just pick the first alternative and attempt every modifier + 'check_candidates: for candidate in definitions[i] + .unordered_modifiers + .iter() + .chain(definitions[i].ordered_modifiers.iter()) + { + let candidate = if let DotModifierRef::Direct { + optional: false, + value, + } = candidate + { + value + } else { + continue; + }; + // check all other unselected patterns + for j in unselected.iter().copied() { + if i == j { + continue; + } + if definitions[j].possible_modifiers.contains(candidate) { + continue 'check_candidates; + } + } + // it's unique + selections[i] = Some((Some(candidate), generation)); + selected_something = true; + continue 'check_definitions; + } + } + if !selected_something { + 
panic!( + "Failed to generate pattern selection for `{}`. State: {:?}", + opcode, + selections.into_iter().rev().collect::>() + ); + } + generation += 1; + } + let mut block_selection = Vec::new(); + for current_generation in 0usize.. { + let mut current_generation_definitions = Vec::new(); + for (idx, selection) in selections.iter_mut().enumerate() { + match selection { + Some((modifier, generation)) => { + if *generation == current_generation { + current_generation_definitions.push((modifier.cloned(), idx)); + *selection = None; + } + } + None => {} + } + } + if current_generation_definitions.is_empty() { + break; + } + block_selection.push(current_generation_definitions); + } + #[cfg(debug_assertions)] + { + let selected = block_selection + .iter() + .map(|x| x.len()) + .reduce(|x, y| x + y) + .unwrap(); + if selected != definitions.len() { + panic!( + "Internal error when generating pattern selection for `{}`: {:?}", + opcode, &block_selection + ); + } + } + Self { + definitions, + block_selection, + } + } + + fn get_enum_types( + parse_definitions: &[parser::OpcodeDefinition], + ) -> FxHashMap> { + let mut result = FxHashMap::default(); + for parser::OpcodeDefinition(_, rules) in parse_definitions.iter() { + for rule in rules { + let type_ = match rule.type_ { + Some(ref type_) => type_.clone(), + None => continue, + }; + let insert_values = |set: &mut FxHashSet<_>| { + for value in rule.alternatives.iter().cloned() { + set.insert(value); + } + }; + match result.entry(type_) { + hash_map::Entry::Occupied(mut entry) => insert_values(entry.get_mut()), + hash_map::Entry::Vacant(entry) => { + insert_values(entry.insert(FxHashSet::default())) + } + }; + } + } + result + } +} + +struct SingleOpcodeDefinition { + possible_modifiers: FxHashSet, + unordered_modifiers: Vec, + ordered_modifiers: Vec, + arguments: parser::Arguments, + code_block: parser::CodeBlock, +} + +impl SingleOpcodeDefinition { + fn function_arguments_declarations(&self) -> impl Iterator + '_ { + 
self.unordered_modifiers + .iter() + .chain(self.ordered_modifiers.iter()) + .filter_map(|modf| { + let type_ = modf.type_of(); + type_.map(|t| { + let name = modf.ident(); + quote! { #name : #t } + }) + }) + .chain(self.arguments.0.iter().map(|arg| { + let name = &arg.ident; + if arg.optional { + quote! { #name : Option<&str> } + } else { + quote! { #name : &str } + } + })) + } + + fn function_arguments(&self) -> impl Iterator + '_ { + self.unordered_modifiers + .iter() + .chain(self.ordered_modifiers.iter()) + .filter_map(|modf| { + let type_ = modf.type_of(); + type_.map(|_| { + let name = modf.ident(); + quote! { #name } + }) + }) + .chain(self.arguments.0.iter().map(|arg| { + let name = &arg.ident; + quote! { #name } + })) + } + + fn extract_and_insert( + output: &mut FxHashMap>, + parser::OpcodeDefinition(pattern_seq, rules): parser::OpcodeDefinition, + ) { + let mut rules = rules + .into_iter() + .map(|r| (r.modifier.clone(), Rc::new(r))) + .collect::>(); + let mut last_opcode = pattern_seq.0.last().unwrap().0 .0.name.clone(); + for (opcode_decl, code_block) in pattern_seq.0.into_iter().rev() { + let current_opcode = opcode_decl.0.name.clone(); + if last_opcode != current_opcode { + rules = FxHashMap::default(); + } + let mut possible_modifiers = FxHashSet::default(); + for (_, options) in rules.iter() { + possible_modifiers.extend(options.alternatives.iter().cloned()); + } + let parser::OpcodeDecl(instruction, arguments) = opcode_decl; + let mut unordered_modifiers = instruction + .modifiers + .into_iter() + .map( + |parser::MaybeDotModifier { optional, modifier }| match rules.get(&modifier) { + Some(alts) => { + if alts.alternatives.len() == 1 && alts.type_.is_none() { + DotModifierRef::Direct { + optional, + value: alts.alternatives[0].clone(), + } + } else { + DotModifierRef::Indirect { + optional, + value: alts.clone(), + } + } + } + None => { + possible_modifiers.insert(modifier.clone()); + DotModifierRef::Direct { + optional, + value: modifier, + } + 
} + }, + ) + .collect::>(); + let ordered_modifiers = Self::extract_ordered_modifiers(&mut unordered_modifiers); + let entry = Self { + possible_modifiers, + unordered_modifiers, + ordered_modifiers, + arguments, + code_block, + }; + multihash_extend(output, current_opcode.clone(), entry); + last_opcode = current_opcode; + } + } + + fn extract_ordered_modifiers( + unordered_modifiers: &mut Vec, + ) -> Vec { + let mut result = Vec::new(); + loop { + let is_ordered = match unordered_modifiers.last() { + Some(DotModifierRef::Direct { value, .. }) => { + let name = value.to_string(); + POSTFIX_MODIFIERS.contains(&&*name) + } + Some(DotModifierRef::Indirect { value, .. }) => { + let type_ = value.type_.to_token_stream().to_string(); + //panic!("{} {}", type_, POSTFIX_TYPES.contains(&&*type_)); + POSTFIX_TYPES.contains(&&*type_) + } + None => break, + }; + if is_ordered { + result.push(unordered_modifiers.pop().unwrap()); + } else { + break; + } + } + if unordered_modifiers.len() == 1 { + result.push(unordered_modifiers.pop().unwrap()); + } + result.reverse(); + result + } +} + +#[proc_macro] +pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let parse_definitions = parse_macro_input!(tokens as gen_impl::parser::ParseDefinitions); + let mut definitions = FxHashMap::default(); + let types = OpcodeDefinitions::get_enum_types(&parse_definitions.definitions); + let enum_types_tokens = emit_enum_types(types, parse_definitions.additional_enums); + for definition in parse_definitions.definitions.into_iter() { + SingleOpcodeDefinition::extract_and_insert(&mut definitions, definition); + } + let definitions = definitions + .into_iter() + .map(|(k, v)| { + let v = OpcodeDefinitions::new(&k, v); + (k, v) + }) + .collect::>(); + let mut token_enum = parse_definitions.token_type; + let (_, all_modifier) = write_definitions_into_tokens(&definitions, &mut token_enum.variants); + let token_impl = emit_parse_function(&token_enum.ident, &definitions, 
all_modifier); + let tokens = quote! { + #enum_types_tokens + + #token_enum + + #token_impl + }; + tokens.into() +} + +fn emit_enum_types( + types: FxHashMap>, + mut existing_enums: FxHashMap, +) -> TokenStream { + let token_types = types.into_iter().filter_map(|(type_, variants)| { + match type_ { + syn::Type::Path(TypePath { + qself: None, + ref path, + }) => { + if let Some(ident) = path.get_ident() { + if let Some(enum_) = existing_enums.get_mut(ident) { + enum_.variants.extend(variants.into_iter().map(|modifier| { + let ident = modifier.variant_capitalized(); + let variant: syn::Variant = syn::parse_quote! { + #ident + }; + variant + })); + return None; + } + } + } + _ => {} + } + let variants = variants.iter().map(|v| v.variant_capitalized()); + Some(quote! { + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + enum #type_ { + #(#variants),* + } + }) + }); + let mut result = TokenStream::new(); + for tokens in token_types { + tokens.to_tokens(&mut result); + } + for (_, enum_) in existing_enums { + quote! { #enum_ }.to_tokens(&mut result); + } + result +} + +fn emit_parse_function( + type_name: &Ident, + defs: &FxHashMap, + all_modifier: FxHashSet<&parser::DotModifier>, +) -> TokenStream { + use std::fmt::Write; + let fns_ = defs + .iter() + .map(|(opcode, defs)| { + defs.definitions.iter().enumerate().map(|(idx, def)| { + let mut fn_name = opcode.to_string(); + write!(&mut fn_name, "_{}", idx).ok(); + let fn_name = Ident::new(&fn_name, Span::call_site()); + let code_block = &def.code_block.0; + let args = def.function_arguments_declarations(); + quote! { + fn #fn_name( #(#args),* ) -> Instruction #code_block + } + }) + }) + .flatten(); + let selectors = defs.iter().map(|(opcode, def)| { + let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span()); + let mut result = TokenStream::new(); + let mut selectors = TokenStream::new(); + quote! 
{ + if false { + unsafe { std::hint::unreachable_unchecked() } + } + } + .to_tokens(&mut selectors); + let mut has_default_selector = false; + for selection_layer in def.block_selection.iter() { + for (selection_key, selected_definition) in selection_layer { + let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]); + match selection_key { + Some(selection_key) => { + let selection_key = + selection_key.dot_capitalized(); + quote! { + else if modifiers.contains(& #type_name :: #selection_key) { + #def_parser + } + } + .to_tokens(&mut selectors); + } + None => { + has_default_selector = true; + quote! { + else { + #def_parser + } + } + .to_tokens(&mut selectors); + } + } + } + } + if !has_default_selector { + quote! { + else { + return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + } + } + .to_tokens(&mut selectors); + } + quote! { + #opcode_variant => { + let modifers_start = stream.checkpoint(); + let modifiers = take_while(0.., Token::modifier).parse_next(stream)?; + #selectors + } + } + .to_tokens(&mut result); + result + }); + let modifier_names = all_modifier.iter().map(|m| m.dot_capitalized()); + quote! 
{ + impl<'input> #type_name<'input> { + fn modifier(self) -> bool { + match self { + #( + #type_name :: #modifier_names => true, + )* + _ => false + } + } + } + + #(#fns_)* + + fn parse_instruction<'input>(stream: &mut (impl winnow::stream::Stream, Slice = &'input [#type_name<'input>]> + winnow::stream::StreamIsPartial)) -> winnow::error::PResult> + { + use winnow::Parser; + use winnow::token::*; + use winnow::combinator::*; + let opcode = any.parse_next(stream)?; + let modifiers_start = stream.checkpoint(); + Ok(match opcode { + #( + #type_name :: #selectors + )* + _ => return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + }) + } + } +} + +fn emit_definition_parser( + token_type: &Ident, + (opcode, fn_idx): (&Ident, usize), + definition: &SingleOpcodeDefinition, +) -> TokenStream { + let return_error_ref = quote! { + return Err(winnow::error::ErrMode::from_error_kind(&stream, winnow::error::ErrorKind::Token)) + }; + let return_error = quote! { + return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + }; + let ordered_parse_declarations = definition.ordered_modifiers.iter().map(|modifier| { + modifier.type_of().map(|type_| { + let name = modifier.ident(); + quote! { + let #name : #type_; + } + }) + }); + let ordered_parse = definition.ordered_modifiers.iter().rev().map(|modifier| { + let arg_name = modifier.ident(); + let arg_type = modifier.type_of(); + match modifier { + DotModifierRef::Direct { optional, value } => { + let variant = value.dot_capitalized(); + if *optional { + quote! { + #arg_name = opt(any.verify(|t| *t == #token_type :: #variant)).parse_next(&mut stream)?.is_some(); + } + } else { + quote! 
{ + any.verify(|t| *t == #token_type :: #variant).parse_next(&mut stream)?; + } + } + } + DotModifierRef::Indirect { optional, value } => { + let variants = value.alternatives.iter().map(|alt| { + let type_ = value.type_.as_ref().unwrap(); + let token_variant = alt.dot_capitalized(); + let parsed_variant = alt.variant_capitalized(); + quote! { + #token_type :: #token_variant => #type_ :: #parsed_variant, + } + }); + if *optional { + quote! { + #arg_name = opt(any.verify_map(|tok| { + Some(match tok { + #(#variants)* + _ => return None + }) + })).parse_next(&mut stream)?; + } + } else { + quote! { + #arg_name = any.verify_map(|tok| { + Some(match tok { + #(#variants)* + _ => return None + }) + }).parse_next(&mut stream)?; + } + } + } + } + }); + let unordered_parse_declarations = definition.unordered_modifiers.iter().map(|modifier| { + let name = modifier.ident(); + let type_ = modifier.type_of_check(); + quote! { + let mut #name : #type_ = std::default::Default::default(); + } + }); + let unordered_parse = definition + .unordered_modifiers + .iter() + .map(|modifier| match modifier { + DotModifierRef::Direct { value, .. } => { + let name = value.ident(); + let token_variant = value.dot_capitalized(); + quote! { + #token_type :: #token_variant => { + if #name { + #return_error_ref; + } + #name = true; + } + } + } + DotModifierRef::Indirect { value, .. } => { + let variable = value.modifier.ident(); + let type_ = value.type_.as_ref().unwrap(); + let alternatives = value.alternatives.iter().map(|alt| { + let token_variant = alt.dot_capitalized(); + let enum_variant = alt.variant_capitalized(); + quote! { + #token_type :: #token_variant => { + if #variable.is_some() { + #return_error_ref; + } + #variable = Some(#type_ :: #enum_variant); + } + } + }); + quote! 
{ + #(#alternatives)* + } + } + }); + let unordered_parse_validations = + definition + .unordered_modifiers + .iter() + .map(|modifier| match modifier { + DotModifierRef::Direct { + optional: false, + value, + } => { + let variable = value.ident(); + quote! { + if !#variable { + #return_error; + } + } + } + DotModifierRef::Direct { optional: true, .. } => TokenStream::new(), + DotModifierRef::Indirect { + optional: false, + value, + } => { + let variable = value.modifier.ident(); + quote! { + let #variable = match #variable { + Some(x) => x, + None => #return_error + }; + } + } + DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), + }); + let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| { + let comma = if idx == 0 { + quote! { empty } + } else { + quote! { any.verify(|t| *t == #token_type::Comma) } + }; + let pre_bracket = if arg.pre_bracket { + quote! { + any.verify(|t| *t == #token_type::LBracket).map(|_| ()) + } + } else { + quote! { + empty + } + }; + let pre_pipe = if arg.pre_pipe { + quote! { + any.verify(|t| *t == #token_type::Or).map(|_| ()) + } + } else { + quote! { + empty + } + }; + let can_be_negated = if arg.can_be_negated { + quote! { + opt(any.verify(|t| *t == #token_type::Not)).map(|o| o.is_some()) + } + } else { + quote! { + empty + } + }; + let ident = { + quote! { + any.verify_map(|t| match t { #token_type::Ident(s) => Some(s), _ => None }) + } + }; + let post_bracket = if arg.post_bracket { + quote! { + any.verify(|t| *t == #token_type::RBracket).map(|_| ()) + } + } else { + quote! { + empty + } + }; + let parser = quote! { + (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #ident, #post_bracket) + }; + let arg_name = &arg.ident; + if arg.optional { + quote! { + let #arg_name = opt(#parser.map(|(_, _, _, _, name, _)| name)).parse_next(stream)?; + } + } else { + quote! 
{ + let #arg_name = #parser.map(|(_, _, _, _, name, _)| name).parse_next(stream)?; + } + } + }); + let fn_args = definition.function_arguments(); + let fn_name = format_ident!("{}_{}", opcode, fn_idx); + let fn_call = quote! { + #fn_name( #(#fn_args),* ) + }; + quote! { + #(#unordered_parse_declarations)* + #(#ordered_parse_declarations)* + { + let mut stream = ReverseStream(modifiers); + #(#ordered_parse)* + let mut stream: &[#token_type] = stream.0; + for token in stream.iter().copied() { + match token { + #(#unordered_parse)* + _ => #return_error_ref + } + } + } + #(#unordered_parse_validations)* + #(#arguments_parse)* + #fn_call + } +} + +fn write_definitions_into_tokens<'a>( + defs: &'a FxHashMap, + variants: &mut Punctuated, +) -> (Vec<&'a Ident>, FxHashSet<&'a parser::DotModifier>) { + let mut all_opcodes = Vec::new(); + let mut all_modifiers = FxHashSet::default(); + for (opcode, definitions) in defs.iter() { + all_opcodes.push(opcode); + let opcode_as_string = opcode.to_string(); + let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span()); + let arg: Variant = syn::parse_quote! { + #[token(#opcode_as_string)] + #variant_name + }; + variants.push(arg); + for definition in definitions.definitions.iter() { + for modifier in definition.possible_modifiers.iter() { + all_modifiers.insert(modifier); + } + } + } + for modifier in all_modifiers.iter() { + let modifier_as_string = modifier.to_string(); + let variant_name = modifier.dot_capitalized(); + let arg: Variant = syn::parse_quote! 
{ + #[token(#modifier_as_string)] + #variant_name + }; + variants.push(arg); + } + (all_opcodes, all_modifiers) +} + +fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +fn multihash_extend(multimap: &mut FxHashMap>, k: K, v: V) { + match multimap.entry(k) { + hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(v), + hash_map::Entry::Vacant(entry) => { + entry.insert(vec![v]); + } + } +} + +enum DotModifierRef { + Direct { + optional: bool, + value: parser::DotModifier, + }, + Indirect { + optional: bool, + value: Rc, + }, +} + +impl DotModifierRef { + fn ident(&self) -> Ident { + match self { + DotModifierRef::Direct { value, .. } => value.ident(), + DotModifierRef::Indirect { value, .. } => value.modifier.ident(), + } + } + + fn type_of(&self) -> Option { + Some(match self { + DotModifierRef::Direct { optional: true, .. } => syn::parse_quote! { bool }, + DotModifierRef::Direct { + optional: false, .. + } => return None, + DotModifierRef::Indirect { optional, value } => { + let type_ = value + .type_ + .as_ref() + .expect("Indirect modifer must have a type"); + if *optional { + syn::parse_quote! { Option<#type_> } + } else { + type_.clone() + } + } + }) + } + + fn type_of_check(&self) -> syn::Type { + match self { + DotModifierRef::Direct { .. } => syn::parse_quote! { bool }, + DotModifierRef::Indirect { value, .. } => { + let type_ = value + .type_ + .as_ref() + .expect("Indirect modifer must have a type"); + syn::parse_quote! 
{ Option<#type_> } + } + } + } +} + +impl Hash for DotModifierRef { + fn hash(&self, state: &mut H) { + match self { + DotModifierRef::Direct { optional, value } => { + optional.hash(state); + value.hash(state); + } + DotModifierRef::Indirect { optional, value } => { + optional.hash(state); + (value.as_ref() as *const parser::Rule).hash(state); + } + } + } +} + +impl PartialEq for DotModifierRef { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Self::Direct { + optional: l_optional, + value: l_value, + }, + Self::Direct { + optional: r_optional, + value: r_value, + }, + ) => l_optional == r_optional && l_value == r_value, + ( + Self::Indirect { + optional: l_optional, + value: l_value, + }, + Self::Indirect { + optional: r_optional, + value: r_value, + }, + ) => { + l_optional == r_optional + && l_value.as_ref() as *const parser::Rule + == r_value.as_ref() as *const parser::Rule + } + _ => false, + } + } +} + +impl Eq for DotModifierRef {} + +#[proc_macro] +pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(tokens as gen_impl::GenerateInstructionType); + let mut result = proc_macro2::TokenStream::new(); + input.emit_arg_types(&mut result); + input.emit_instruction_type(&mut result); + input.emit_visit(&mut result); + input.emit_visit_mut(&mut result); + input.emit_visit_map(&mut result); + result.into() +} diff --git a/gen_impl/Cargo.toml b/gen_impl/Cargo.toml new file mode 100644 index 0000000..ff93f98 --- /dev/null +++ b/gen_impl/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "gen_impl" +version = "0.1.0" +edition = "2021" + +[lib] + +[dependencies] +syn = { version = "2.0.67", features = ["extra-traits", "full"] } +quote = "1.0" +proc-macro2 = "1.0.86" +rustc-hash = "2.0.0" diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs new file mode 100644 index 0000000..4c7f2ab --- /dev/null +++ b/gen_impl/src/lib.rs @@ -0,0 +1,718 @@ +use proc_macro2::TokenStream; +use 
quote::{format_ident, quote, ToTokens}; +use syn::{ + braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, Token, Type, TypeParam, +}; + +pub mod parser; + +pub struct GenerateInstructionType { + pub name: Ident, + pub type_parameters: Punctuated, + pub short_parameters: Punctuated, + pub variants: Punctuated, +} + +impl GenerateInstructionType { + pub fn emit_arg_types(&self, tokens: &mut TokenStream) { + for v in self.variants.iter() { + v.emit_type(&self.type_parameters, tokens); + } + } + + pub fn emit_instruction_type(&self, tokens: &mut TokenStream) { + let type_name = &self.name; + let type_parameters = &self.type_parameters; + let variants = self.variants.iter().map(|v| v.emit_variant()); + quote! { + enum #type_name<#type_parameters> { + #(#variants),* + } + } + .to_tokens(tokens); + } + + pub fn emit_visit(&self, tokens: &mut TokenStream) { + self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit) + } + + pub fn emit_visit_mut(&self, tokens: &mut TokenStream) { + self.emit_visit_impl( + VisitKind::RefMut, + tokens, + InstructionVariant::emit_visit_mut, + ) + } + + pub fn emit_visit_map(&self, tokens: &mut TokenStream) { + self.emit_visit_impl(VisitKind::Map, tokens, InstructionVariant::emit_visit_map) + } + + fn emit_visit_impl( + &self, + kind: VisitKind, + tokens: &mut TokenStream, + mut fn_: impl FnMut(&InstructionVariant, &Ident, &mut TokenStream), + ) { + let type_name = &self.name; + let type_parameters = &self.type_parameters; + let short_parameters = &self.short_parameters; + let mut inner_tokens = TokenStream::new(); + for v in self.variants.iter() { + fn_(v, type_name, &mut inner_tokens); + } + let visit_ref = kind.reference(); + let visitor_type = format_ident!("Visitor{}", kind.type_suffix()); + let visit_fn = format_ident!("visit{}", kind.fn_suffix()); + let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix()); + let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { 
+ ( + quote! { <#type_parameters, To: Operand> }, + quote! { <#short_parameters, To> }, + quote! { #type_name }, + ) + } else { + ( + quote! { <#type_parameters> }, + quote! { <#short_parameters> }, + quote! { () }, + ) + }; + quote! { + fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type { + match i { + #inner_tokens + } + } + }.to_tokens(tokens); + if kind == VisitKind::Map { + return; + } + quote! { + fn #visit_slice_fn #type_parameters (instructions: #visit_ref [#type_name<#short_parameters>], visitor: &mut impl #visitor_type #visitor_parameters) { + for i in instructions { + #visit_fn(i, visitor) + } + } + }.to_tokens(tokens); + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum VisitKind { + Ref, + RefMut, + Map, +} + +impl VisitKind { + fn fn_suffix(self) -> &'static str { + match self { + VisitKind::Ref => "", + VisitKind::RefMut => "_mut", + VisitKind::Map => "_map", + } + } + + fn type_suffix(self) -> &'static str { + match self { + VisitKind::Ref => "", + VisitKind::RefMut => "Mut", + VisitKind::Map => "Map", + } + } + + fn reference(self) -> Option { + match self { + VisitKind::Ref => Some(quote! { & }), + VisitKind::RefMut => Some(quote! 
{ &mut }), + VisitKind::Map => None, + } + } +} + +impl Parse for GenerateInstructionType { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + let name = input.parse::()?; + input.parse::()?; + let type_parameters = Punctuated::parse_separated_nonempty(input)?; + let short_parameters = type_parameters + .iter() + .map(|p: &TypeParam| p.ident.clone()) + .collect(); + input.parse::]>()?; + let variants_buffer; + braced!(variants_buffer in input); + let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?; + Ok(Self { + name, + type_parameters, + short_parameters, + variants, + }) + } +} + +pub struct InstructionVariant { + pub name: Ident, + pub type_: Option, + pub space: Option, + pub data: Option, + pub arguments: Option, +} + +impl InstructionVariant { + fn args_name(&self) -> Ident { + format_ident!("{}Args", self.name) + } + + fn emit_variant(&self) -> TokenStream { + let name = &self.name; + let data = match &self.data { + None => { + quote! {} + } + Some(data_type) => { + quote! { + data: #data_type, + } + } + }; + let arguments = match &self.arguments { + None => { + quote! {} + } + Some(args) => { + let args_name = self.args_name(); + match &args.generic { + None => { + quote! { + arguments: #args_name, + } + } + Some(generics) => { + quote! { + arguments: #args_name <#generics>, + } + } + } + } + }; + quote! { + #name { #data #arguments } + } + } + + fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) { + self.emit_visit_impl(enum_, tokens, InstructionArguments::emit_visit) + } + + fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) { + self.emit_visit_impl(enum_, tokens, InstructionArguments::emit_visit_mut) + } + + fn emit_visit_impl( + &self, + enum_: &Ident, + tokens: &mut TokenStream, + mut fn_: impl FnMut(&InstructionArguments, &Option, &Option) -> TokenStream, + ) { + let name = &self.name; + let arguments = match &self.arguments { + None => { + quote! 
{ + #enum_ :: #name { .. } => { } + } + .to_tokens(tokens); + return; + } + Some(args) => args, + }; + let arg_calls = fn_(arguments, &self.type_, &self.space); + quote! { + #enum_ :: #name { arguments, data } => { + #arg_calls + } + } + .to_tokens(tokens); + } + + fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) { + let name = &self.name; + let arguments = &self.arguments.as_ref().map(|_| quote! { arguments,}); + let data = &self.data.as_ref().map(|_| quote! { data,}); + let mut arg_calls = None; + let arguments_init = self.arguments.as_ref().map(|arguments| { + let arg_type = self.args_name(); + arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space)); + let arg_names = arguments.fields.iter().map(|arg| &arg.name); + quote! { + arguments: #arg_type { #(#arg_names),* } + } + }); + quote! { + #enum_ :: #name { #data #arguments } => { + #arg_calls + #enum_ :: #name { #data #arguments_init } + } + } + .to_tokens(tokens); + } + + fn emit_type( + &self, + type_parameters: &Punctuated, + tokens: &mut TokenStream, + ) { + let arguments = match self.arguments { + Some(ref a) => a, + None => return, + }; + let name = self.args_name(); + let type_parameters = if arguments.generic.is_some() { + Some(quote! { <#type_parameters> }) + } else { + None + }; + let fields = arguments.fields.iter().map(ArgumentField::emit_field); + quote! 
{ + struct #name #type_parameters { + #(#fields),* + } + } + .to_tokens(tokens); + } +} + +impl Parse for InstructionVariant { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name = input.parse::()?; + let properties_buffer; + braced!(properties_buffer in input); + let properties = properties_buffer.parse_terminated(VariantProperty::parse, Token![,])?; + let mut type_ = None; + let mut space = None; + let mut data = None; + let mut arguments = None; + for property in properties { + match property { + VariantProperty::Type(t) => type_ = Some(t), + VariantProperty::Space(s) => space = Some(s), + VariantProperty::Data(d) => data = Some(d), + VariantProperty::Arguments(a) => arguments = Some(a), + } + } + Ok(Self { + name, + type_, + space, + data, + arguments, + }) + } +} + +enum VariantProperty { + Type(Expr), + Space(Expr), + Data(Type), + Arguments(InstructionArguments), +} + +impl VariantProperty { + pub fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + Ok(if lookahead.peek(Token![type]) { + input.parse::()?; + input.parse::()?; + VariantProperty::Type(input.parse::()?) + } else if lookahead.peek(Ident) { + let key = input.parse::()?; + match &*key.to_string() { + "data" => { + input.parse::()?; + VariantProperty::Data(input.parse::()?) + } + "space" => { + input.parse::()?; + VariantProperty::Space(input.parse::()?) + } + "arguments" => { + let generics = if input.peek(Token![<]) { + input.parse::()?; + let gen_params = + Punctuated::::parse_separated_nonempty(input)?; + input.parse::]>()?; + Some(gen_params) + } else { + None + }; + input.parse::()?; + let fields; + braced!(fields in input); + VariantProperty::Arguments(InstructionArguments::parse(generics, &fields)?) + } + x => { + return Err(syn::Error::new( + key.span(), + format!( + "Unexpected key `{}`. 
Expected `type`, `data` or `arguments`.", + x + ), + )) + } + } + } else { + return Err(lookahead.error()); + }) + } +} + +pub struct InstructionArguments { + pub generic: Option>, + pub fields: Punctuated, +} + +impl InstructionArguments { + pub fn parse( + generic: Option>, + input: syn::parse::ParseStream, + ) -> syn::Result { + let fields = Punctuated::::parse_terminated_with( + input, + ArgumentField::parse, + )?; + Ok(Self { generic, fields }) + } + + fn emit_visit(&self, parent_type: &Option, parent_space: &Option) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit) + } + + fn emit_visit_mut( + &self, + parent_type: &Option, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut) + } + + fn emit_visit_map( + &self, + parent_type: &Option, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map) + } + + fn emit_visit_impl( + &self, + parent_type: &Option, + parent_space: &Option, + mut fn_: impl FnMut(&ArgumentField, &Option, &Option) -> TokenStream, + ) -> TokenStream { + let field_calls = self + .fields + .iter() + .map(|f| fn_(f, parent_type, parent_space)); + quote! { + #(#field_calls)* + } + } +} + +pub struct ArgumentField { + pub name: Ident, + pub is_dst: bool, + pub repr: Type, + pub space: Option, + pub type_: Option, +} + +impl ArgumentField { + fn parse_block( + input: syn::parse::ParseStream, + ) -> syn::Result<(Type, Option, Option)> { + let content; + braced!(content in input); + let all_fields = + Punctuated::::parse_terminated_with(&content, |content| { + let lookahead = content.lookahead1(); + Ok(if lookahead.peek(Token![type]) { + content.parse::()?; + content.parse::()?; + ExprOrPath::Type(content.parse::()?) 
+                } else if lookahead.peek(Ident) {
+                    let name_ident = content.parse::<Ident>()?;
+                    content.parse::<Token![:]>()?;
+                    match &*name_ident.to_string() {
+                        "repr" => ExprOrPath::Repr(content.parse::<Type>()?),
+                        "space" => ExprOrPath::Space(content.parse::<Expr>()?),
+                        name => {
+                            return Err(syn::Error::new(
+                                name_ident.span(),
+                                format!("Unexpected key `{}`, expected `repr` or `space`", name),
+                            ))
+                        }
+                    }
+                } else {
+                    return Err(lookahead.error());
+                })
+            })?;
+        let mut repr = None;
+        let mut type_ = None;
+        let mut space = None;
+        for exp_or_path in all_fields {
+            match exp_or_path {
+                ExprOrPath::Repr(r) => repr = Some(r),
+                ExprOrPath::Type(t) => type_ = Some(t),
+                ExprOrPath::Space(s) => space = Some(s),
+            }
+        }
+        Ok((repr.unwrap(), type_, space))
+    }
+
+    fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result<Type> {
+        input.parse::<Type>()
+    }
+
+    fn emit_visit(&self, parent_type: &Option<Expr>, parent_space: &Option<Expr>) -> TokenStream {
+        self.emit_visit_impl(parent_type, parent_space, false)
+    }
+
+    fn emit_visit_mut(
+        &self,
+        parent_type: &Option<Expr>,
+        parent_space: &Option<Expr>,
+    ) -> TokenStream {
+        self.emit_visit_impl(parent_type, parent_space, true)
+    }
+
+    fn emit_visit_impl(
+        &self,
+        parent_type: &Option<Expr>,
+        parent_space: &Option<Expr>,
+        is_mut: bool,
+    ) -> TokenStream {
+        let type_ = self.type_.as_ref().or(parent_type.as_ref()).unwrap();
+        let space = self
+            .space
+            .as_ref()
+            .or(parent_space.as_ref())
+            .map(|space| quote! { #space })
+            .unwrap_or_else(|| quote! { StateSpace::Reg });
+        let is_dst = self.is_dst;
+        let name = &self.name;
+        let arguments_name = if is_mut {
+            quote! {
+                &mut arguments.#name
+            }
+        } else {
+            quote! {
+                & arguments.#name
+            }
+        };
+        quote! 
{{ + let type_ = #type_; + let space = #space; + visitor.visit(#arguments_name, &type_, space, #is_dst); + }} + } + + fn emit_visit_map( + &self, + parent_type: &Option, + parent_space: &Option, + ) -> TokenStream { + let type_ = self.type_.as_ref().or(parent_type.as_ref()).unwrap(); + let space = self + .space + .as_ref() + .or(parent_space.as_ref()) + .map(|space| quote! { #space }) + .unwrap_or_else(|| quote! { StateSpace::Reg }); + let is_dst = self.is_dst; + let name = &self.name; + quote! { + let #name = { + let type_ = #type_; + let space = #space; + visitor.visit(arguments.#name, &type_, space, #is_dst) + }; + } + } + + fn is_dst(name: &Ident) -> syn::Result { + if name.to_string().starts_with("dst") { + Ok(true) + } else if name.to_string().starts_with("src") { + Ok(false) + } else { + return Err(syn::Error::new( + name.span(), + format!( + "Could not guess if `{}` is a read or write argument. Name should start with `dst` or `src`", + name + ), + )); + } + } + + fn emit_field(&self) -> TokenStream { + let name = &self.name; + let type_ = &self.repr; + quote! { + #name: #type_ + } + } +} + +impl Parse for ArgumentField { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name = input.parse::()?; + let is_dst = Self::is_dst(&name)?; + input.parse::()?; + let lookahead = input.lookahead1(); + let (repr, type_, space) = if lookahead.peek(token::Brace) { + Self::parse_block(input)? + } else if lookahead.peek(syn::Ident) { + (Self::parse_basic(input)?, None, None) + } else { + return Err(lookahead.error()); + }; + Ok(Self { + name, + is_dst, + repr, + type_, + space, + }) + } +} + +enum ExprOrPath { + Repr(Type), + Type(Expr), + Space(Expr), +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use quote::{quote, ToTokens}; + + fn to_string(x: impl ToTokens) -> String { + quote! { #x }.to_string() + } + + #[test] + fn parse_argument_field_basic() { + let input = quote! 
{ + dst: P::Operand + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("P :: Operand", to_string(arg.repr)); + assert!(matches!(arg.type_, None)); + } + + #[test] + fn parse_argument_field_block() { + let input = quote! { + dst: { + type: ScalarType::U32, + space: StateSpace::Global, + repr: P::Operand, + } + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("ScalarType :: U32", to_string(arg.type_.unwrap())); + assert_eq!("StateSpace :: Global", to_string(arg.space.unwrap())); + assert_eq!("P :: Operand", to_string(arg.repr)); + } + + #[test] + fn parse_argument_field_block_untyped() { + let input = quote! { + dst: { + repr: P::Operand, + } + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("P :: Operand", to_string(arg.repr)); + assert!(matches!(arg.type_, None)); + } + + #[test] + fn parse_variant_complex() { + let input = quote! { + Ld { + type: ScalarType::U32, + space: StateSpace::Global, + data: LdDetails, + arguments

: { + dst: { + repr: P::Operand, + type: ScalarType::U32, + space: StateSpace::Shared, + }, + src: P::Operand, + }, + } + }; + let variant = syn::parse2::(input).unwrap(); + assert_eq!("Ld", variant.name.to_string()); + assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap())); + assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap())); + assert_eq!("LdDetails", to_string(variant.data.unwrap())); + let arguments = variant.arguments.unwrap(); + assert_eq!("P", to_string(arguments.generic)); + let mut fields = arguments.fields.into_iter(); + let dst = fields.next().unwrap(); + assert_eq!("P :: Operand", to_string(dst.repr)); + assert_eq!("ScalarType :: U32", to_string(dst.type_)); + assert_eq!("StateSpace :: Shared", to_string(dst.space)); + let src = fields.next().unwrap(); + assert_eq!("P :: Operand", to_string(src.repr)); + assert!(matches!(src.type_, None)); + assert!(matches!(src.space, None)); + } + + #[test] + fn visit_variant() { + let input = quote! { + Ld { + type: ScalarType::U32, + data: LdDetails, + arguments

: { + dst: { + repr: P::Operand, + type: ScalarType::U32 + }, + src: P::Operand, + }, + } + }; + let variant = syn::parse2::(input).unwrap(); + let mut output = TokenStream::new(); + variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output); + assert_eq!(output.to_string(), "Instruction :: Ld { arguments , data } => { { let type_ = ScalarType :: U32 ; let space = StateSpace :: Reg ; visitor . visit (& arguments . dst , & type_ , space , true) ; } { let type_ = ScalarType :: U32 ; let space = StateSpace :: Reg ; visitor . visit (& arguments . src , & type_ , space , false) ; } }"); + } + + #[test] + fn visit_variant_empty() { + let input = quote! { + Ret { + data: RetData + } + }; + let variant = syn::parse2::(input).unwrap(); + let mut output = TokenStream::new(); + variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output); + assert_eq!(output.to_string(), "Instruction :: Ret { .. } => { }"); + } +} diff --git a/gen_impl/src/parser.rs b/gen_impl/src/parser.rs new file mode 100644 index 0000000..c8da61d --- /dev/null +++ b/gen_impl/src/parser.rs @@ -0,0 +1,793 @@ +use proc_macro2::Span; +use proc_macro2::TokenStream; +use quote::quote; +use quote::ToTokens; +use rustc_hash::FxHashMap; +use std::fmt::Write; +use syn::bracketed; +use syn::parse::Peek; +use syn::punctuated::Punctuated; +use syn::LitInt; +use syn::Type; +use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token}; + +pub struct ParseDefinitions { + pub token_type: ItemEnum, + pub additional_enums: FxHashMap, + pub definitions: Vec, +} + +impl Parse for ParseDefinitions { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let token_type = input.parse::()?; + let mut additional_enums = FxHashMap::default(); + while input.peek(Token![#]) { + let enum_ = input.parse::()?; + additional_enums.insert(enum_.ident.clone(), enum_); + } + let mut definitions = Vec::new(); + while !input.is_empty() { + definitions.push(input.parse::()?); + } + Ok(Self { + 
token_type, + additional_enums, + definitions, + }) + } +} + +pub struct OpcodeDefinition(pub Patterns, pub Vec); + +impl Parse for OpcodeDefinition { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let patterns = input.parse::()?; + let mut rules = Vec::new(); + while Rule::peek(input) { + rules.push(input.parse::()?); + input.parse::()?; + } + Ok(Self(patterns, rules)) + } +} + +pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>); + +impl Parse for Patterns { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut result = Vec::new(); + loop { + if !OpcodeDecl::peek(input) { + break; + } + let decl = input.parse::()?; + let code_block = input.parse::()?; + result.push((decl, code_block)) + } + Ok(Self(result)) + } +} + +pub struct OpcodeDecl(pub Instruction, pub Arguments); + +impl OpcodeDecl { + fn peek(input: syn::parse::ParseStream) -> bool { + Instruction::peek(input) + } +} + +impl Parse for OpcodeDecl { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(Self( + input.parse::()?, + input.parse::()?, + )) + } +} + +pub struct CodeBlock(pub proc_macro2::Group); + +impl Parse for CodeBlock { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::]>()?; + let group = input.parse::()?; + Ok(Self(group)) + } +} + +pub struct Rule { + pub modifier: DotModifier, + pub type_: Option, + pub alternatives: Vec, +} + +impl Rule { + fn peek(input: syn::parse::ParseStream) -> bool { + DotModifier::peek(input) + } + + fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result> { + let mut result = Vec::new(); + Self::parse_with_alternative(input, &mut result)?; + loop { + if !input.peek(Token![,]) { + break; + } + input.parse::()?; + Self::parse_with_alternative(input, &mut result)?; + } + Ok(result) + } + + fn parse_with_alternative( + input: &syn::parse::ParseBuffer, + result: &mut Vec, + ) -> Result<(), syn::Error> { + input.parse::()?; + let part1 = input.parse::()?; + if input.peek(token::Brace) 
{ + result.push(DotModifier { + part1: part1.clone(), + part2: None, + }); + let suffix_content; + braced!(suffix_content in input); + let suffixes = Punctuated::::parse_separated_nonempty( + &suffix_content, + )?; + for part2 in suffixes { + result.push(DotModifier { + part1: part1.clone(), + part2: Some(part2), + }); + } + } else if IdentOrTypeSuffix::peek(input) { + let part2 = Some(IdentOrTypeSuffix::parse(input)?); + result.push(DotModifier { part1, part2 }); + } else { + result.push(DotModifier { part1, part2: None }); + } + Ok(()) + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +struct IdentOrTypeSuffix(IdentLike); + +impl IdentOrTypeSuffix { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Token![::]) + } +} + +impl ToTokens for IdentOrTypeSuffix { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.0; + quote! { :: #ident }.to_tokens(tokens) + } +} + +impl Parse for IdentOrTypeSuffix { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + Ok(Self(input.parse::()?)) + } +} + +impl Parse for Rule { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let modifier = input.parse::()?; + let type_ = if input.peek(Token![:]) { + input.parse::()?; + Some(input.parse::()?) 
+ } else { + None + }; + input.parse::()?; + let content; + braced!(content in input); + let alternatives = Self::parse_alternatives(&content)?; + Ok(Self { + modifier, + type_, + alternatives, + }) + } +} + +pub struct Instruction { + pub name: Ident, + pub modifiers: Vec, +} +impl Instruction { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Ident) + } +} + +impl Parse for Instruction { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let instruction = input.parse::()?; + let mut modifiers = Vec::new(); + loop { + if !MaybeDotModifier::peek(input) { + break; + } + modifiers.push(MaybeDotModifier::parse(input)?); + } + Ok(Self { + name: instruction, + modifiers, + }) + } +} + +pub struct MaybeDotModifier { + pub optional: bool, + pub modifier: DotModifier, +} + +impl MaybeDotModifier { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(token::Brace) || DotModifier::peek(input) + } +} + +impl Parse for MaybeDotModifier { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(if input.peek(token::Brace) { + let content; + braced!(content in input); + let modifier = DotModifier::parse(&content)?; + Self { + modifier, + optional: true, + } + } else { + let modifier = DotModifier::parse(input)?; + Self { + modifier, + optional: false, + } + }) + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +pub struct DotModifier { + part1: IdentLike, + part2: Option, +} + +impl std::fmt::Display for DotModifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, ".")?; + self.part1.fmt(f)?; + if let Some(ref part2) = self.part2 { + write!(f, "::")?; + part2.0.fmt(f)?; + } + Ok(()) + } +} + +impl std::fmt::Debug for DotModifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self, f) + } +} + +impl DotModifier { + pub fn ident(&self) -> Ident { + let mut result = String::new(); + write!(&mut result, "{}", self.part1).unwrap(); + if let Some(ref 
part2) = self.part2 { + write!(&mut result, "_{}", part2.0).unwrap(); + } else { + match self.part1 { + IdentLike::Type | IdentLike::Const => result.push('_'), + IdentLike::Ident(_) | IdentLike::Integer(_) => {} + } + } + Ident::new(&result.to_ascii_lowercase(), Span::call_site()) + } + + pub fn variant_capitalized(&self) -> Ident { + self.capitalized_impl(String::new()) + } + + pub fn dot_capitalized(&self) -> Ident { + self.capitalized_impl("Dot".to_string()) + } + + fn capitalized_impl(&self, prefix: String) -> Ident { + let mut temp = String::new(); + write!(&mut temp, "{}", &self.part1).unwrap(); + if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 { + write!(&mut temp, "_{}", part2).unwrap(); + } + let mut result = prefix; + let mut capitalize = true; + for c in temp.chars() { + if c == '_' { + capitalize = true; + continue; + } + let c = if capitalize { + capitalize = false; + c.to_ascii_uppercase() + } else { + c + }; + result.push(c); + } + Ident::new(&result, Span::call_site()) + } + + pub fn tokens(&self) -> TokenStream { + let part1 = &self.part1; + let part2 = &self.part2; + match self.part2 { + None => quote! { . #part1 }, + Some(_) => quote! { . 
#part1 #part2 }, + } + } + + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Token![.]) + } +} + +impl Parse for DotModifier { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + let part1 = input.parse::()?; + if IdentOrTypeSuffix::peek(input) { + let part2 = Some(IdentOrTypeSuffix::parse(input)?); + Ok(Self { part1, part2 }) + } else { + Ok(Self { part1, part2: None }) + } + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +enum IdentLike { + Type, + Const, + Ident(Ident), + Integer(LitInt), +} + +impl std::fmt::Display for IdentLike { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IdentLike::Type => f.write_str("type"), + IdentLike::Const => f.write_str("const"), + IdentLike::Ident(ident) => write!(f, "{}", ident), + IdentLike::Integer(integer) => write!(f, "{}", integer), + } + } +} + +impl ToTokens for IdentLike { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + IdentLike::Type => quote! { type }.to_tokens(tokens), + IdentLike::Const => quote! { const }.to_tokens(tokens), + IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens), + IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens), + } + } +} + +impl Parse for IdentLike { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + Ok(if lookahead.peek(Token![const]) { + input.parse::()?; + IdentLike::Const + } else if lookahead.peek(Token![type]) { + input.parse::()?; + IdentLike::Type + } else if lookahead.peek(Ident) { + IdentLike::Ident(input.parse::()?) + } else if lookahead.peek(LitInt) { + IdentLike::Integer(input.parse::()?) 
+        } else {
+            return Err(lookahead.error());
+        })
+    }
+}
+
+// Arguments declaration can look like this:
+// a{, b}
+// That's why we don't parse Arguments as Punctuated
+#[derive(PartialEq, Eq)]
+pub struct Arguments(pub Vec<Argument>);
+
+impl Parse for Arguments {
+    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+        let mut result = Vec::new();
+        loop {
+            if input.peek(Token![,]) {
+                input.parse::<Token![,]>()?;
+            }
+            let mut optional = false;
+            let mut can_be_negated = false;
+            let mut pre_pipe = false;
+            let ident;
+            let lookahead = input.lookahead1();
+            if lookahead.peek(token::Brace) {
+                let content;
+                braced!(content in input);
+                let lookahead = content.lookahead1();
+                if lookahead.peek(Token![!]) {
+                    content.parse::<Token![!]>()?;
+                    can_be_negated = true;
+                    ident = input.parse::<Ident>()?;
+                } else if lookahead.peek(Token![,]) {
+                    optional = true;
+                    content.parse::<Token![,]>()?;
+                    ident = content.parse::<Ident>()?;
+                } else {
+                    return Err(lookahead.error());
+                }
+            } else if lookahead.peek(token::Bracket) {
+                let bracketed;
+                bracketed!(bracketed in input);
+                if bracketed.peek(Token![|]) {
+                    optional = true;
+                    bracketed.parse::<Token![|]>()?;
+                    pre_pipe = true;
+                    ident = bracketed.parse::<Ident>()?;
+                } else {
+                    let mut sub_args = Self::parse(&bracketed)?;
+                    sub_args.0.first_mut().unwrap().pre_bracket = true;
+                    sub_args.0.last_mut().unwrap().post_bracket = true;
+                    if peek_brace_token(input, Token![.]) {
+                        let optional_suffix;
+                        braced!(optional_suffix in input);
+                        optional_suffix.parse::<Token![.]>()?;
+                        let unified_ident = optional_suffix.parse::<Ident>()?;
+                        if unified_ident.to_string() != "unified" {
+                            return Err(syn::Error::new(
+                                unified_ident.span(),
+                                format!("Expected `unified`, got `{}`", unified_ident),
+                            ))
+                        }
+                        for a in sub_args.0.iter_mut() {
+                            a.unified = true;
+                        }
+                    }
+                    result.extend(sub_args.0);
+                    continue;
+                }
+            } else if lookahead.peek(Ident) {
+                ident = input.parse::<Ident>()?;
+            } else if lookahead.peek(Token![|]) {
+                input.parse::<Token![|]>()?;
+                pre_pipe = true;
+                ident = input.parse::<Ident>()?;
+            } else {
+                break;
+            }
+            
result.push(Argument { + optional, + pre_pipe, + can_be_negated, + pre_bracket: false, + ident, + post_bracket: false, + unified: false, + }); + } + Ok(Self(result)) + } +} + +// This is effectively input.peek(token::Brace) && input.peek2(Token![.]) +// input.peek2 is supposed to skip over next token, but it skips over whole +// braced token group. Not sure if it's a bug +fn peek_brace_token(input: syn::parse::ParseStream, _t: T) -> bool { + use syn::token::Token; + let cursor = input.cursor(); + cursor + .group(proc_macro2::Delimiter::Brace) + .map_or(false, |(content, ..)| T::Token::peek(content)) +} + +#[derive(PartialEq, Eq)] +pub struct Argument { + pub optional: bool, + pub pre_bracket: bool, + pub pre_pipe: bool, + pub can_be_negated: bool, + pub ident: Ident, + pub post_bracket: bool, + pub unified: bool, +} + +#[cfg(test)] +mod tests { + use super::{Arguments, DotModifier, MaybeDotModifier}; + use quote::{quote, ToTokens}; + + #[test] + fn parse_modifier_complex() { + let input = quote! { + .level::eviction_priority + }; + let modifier = syn::parse2::(input).unwrap(); + assert_eq!( + ". level :: eviction_priority", + modifier.tokens().to_string() + ); + } + + #[test] + fn parse_modifier_optional() { + let input = quote! { + { .level::eviction_priority } + }; + let maybe_modifider = syn::parse2::(input).unwrap(); + assert_eq!( + ". level :: eviction_priority", + maybe_modifider.modifier.tokens().to_string() + ); + assert!(maybe_modifider.optional); + } + + #[test] + fn parse_type_token() { + let input = quote! { + . type + }; + let maybe_modifier = syn::parse2::(input).unwrap(); + assert_eq!(". type", maybe_modifier.modifier.tokens().to_string()); + assert!(!maybe_modifier.optional); + } + + #[test] + fn arguments_memory() { + let input = quote! 
{ + [a], b + }; + let arguments = syn::parse2::(input).unwrap(); + let a = &arguments.0[0]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(a.post_bracket); + assert!(!a.can_be_negated); + let b = &arguments.0[1]; + assert!(!b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + } + + #[test] + fn arguments_optional() { + let input = quote! { + b{, cache_policy} + }; + let arguments = syn::parse2::(input).unwrap(); + let b = &arguments.0[0]; + assert!(!b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + let cache_policy = &arguments.0[1]; + assert!(cache_policy.optional); + assert_eq!("cache_policy", cache_policy.ident.to_string()); + assert!(!cache_policy.pre_bracket); + assert!(!cache_policy.pre_pipe); + assert!(!cache_policy.post_bracket); + assert!(!cache_policy.can_be_negated); + } + + #[test] + fn arguments_optional_pred() { + let input = quote! { + p[|q], a + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 3); + let p = &arguments.0[0]; + assert!(!p.optional); + assert_eq!("p", p.ident.to_string()); + assert!(!p.pre_bracket); + assert!(!p.pre_pipe); + assert!(!p.post_bracket); + assert!(!p.can_be_negated); + let q = &arguments.0[1]; + assert!(q.optional); + assert_eq!("q", q.ident.to_string()); + assert!(!q.pre_bracket); + assert!(q.pre_pipe); + assert!(!q.post_bracket); + assert!(!q.can_be_negated); + let a = &arguments.0[2]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(!a.pre_bracket); + assert!(!a.pre_pipe); + assert!(!a.post_bracket); + assert!(!a.can_be_negated); + } + + #[test] + fn arguments_optional_with_negate() { + let input = quote! 
{ + b, {!}c + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 2); + let b = &arguments.0[0]; + assert!(!b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + let c = &arguments.0[1]; + assert!(!c.optional); + assert_eq!("c", c.ident.to_string()); + assert!(!c.pre_bracket); + assert!(!c.pre_pipe); + assert!(!c.post_bracket); + assert!(c.can_be_negated); + } + + #[test] + fn arguments_tex() { + let input = quote! { + d[|p], [a{, b}, c], dpdx, dpdy {, e} + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 8); + { + let d = &arguments.0[0]; + assert!(!d.optional); + assert_eq!("d", d.ident.to_string()); + assert!(!d.pre_bracket); + assert!(!d.pre_pipe); + assert!(!d.post_bracket); + assert!(!d.can_be_negated); + } + { + let p = &arguments.0[1]; + assert!(p.optional); + assert_eq!("p", p.ident.to_string()); + assert!(!p.pre_bracket); + assert!(p.pre_pipe); + assert!(!p.post_bracket); + assert!(!p.can_be_negated); + } + { + let a = &arguments.0[2]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(!a.post_bracket); + assert!(!a.can_be_negated); + } + { + let b = &arguments.0[3]; + assert!(b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + } + { + let c = &arguments.0[4]; + assert!(!c.optional); + assert_eq!("c", c.ident.to_string()); + assert!(!c.pre_bracket); + assert!(!c.pre_pipe); + assert!(c.post_bracket); + assert!(!c.can_be_negated); + } + { + let dpdx = &arguments.0[5]; + assert!(!dpdx.optional); + assert_eq!("dpdx", dpdx.ident.to_string()); + assert!(!dpdx.pre_bracket); + assert!(!dpdx.pre_pipe); + assert!(!dpdx.post_bracket); + assert!(!dpdx.can_be_negated); + } + { + let dpdy = &arguments.0[6]; + 
assert!(!dpdy.optional); + assert_eq!("dpdy", dpdy.ident.to_string()); + assert!(!dpdy.pre_bracket); + assert!(!dpdy.pre_pipe); + assert!(!dpdy.post_bracket); + assert!(!dpdy.can_be_negated); + } + { + let e = &arguments.0[7]; + assert!(e.optional); + assert_eq!("e", e.ident.to_string()); + assert!(!e.pre_bracket); + assert!(!e.pre_pipe); + assert!(!e.post_bracket); + assert!(!e.can_be_negated); + } + } + + #[test] + fn rule_multi() { + let input = quote! { + .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} } + }; + let rule = syn::parse2::(input).unwrap(); + assert_eq!(". ss", rule.modifier.tokens().to_string()); + assert_eq!( + "StateSpace", + rule.type_.unwrap().to_token_stream().to_string() + ); + let alts = rule + .alternatives + .iter() + .map(|m| m.tokens().to_string()) + .collect::>(); + assert_eq!( + vec![ + ". global", + ". local", + ". param", + ". param :: func", + ". shared", + ". shared :: cta", + ". shared :: cluster" + ], + alts + ); + } + + #[test] + fn rule_multi2() { + let input = quote! { + .cop: StCacheOperator = { .wb, .cg, .cs, .wt } + }; + let rule = syn::parse2::(input).unwrap(); + assert_eq!(". cop", rule.modifier.tokens().to_string()); + assert_eq!( + "StCacheOperator", + rule.type_.unwrap().to_token_stream().to_string() + ); + let alts = rule + .alternatives + .iter() + .map(|m| m.tokens().to_string()) + .collect::>(); + assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts); + } + + #[test] + fn args_unified() { + let input = quote! 
{ + d, [a]{.unified}{, cache_policy} + }; + let args = syn::parse2::(input).unwrap(); + let a = &args.0[1]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(a.post_bracket); + assert!(!a.can_be_negated); + assert!(a.unified); + } +} diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml new file mode 100644 index 0000000..d5e3d5d --- /dev/null +++ b/ptx_parser/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "ptx_parser" +version = "0.1.0" +edition = "2021" + +[dependencies] +logos = "0.14" +winnow = { version = "0.6.18", features = ["debug"] } +gen = { path = "../gen" } diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs new file mode 100644 index 0000000..8af8ede --- /dev/null +++ b/ptx_parser/src/main.rs @@ -0,0 +1,437 @@ +use gen::derive_parser; +use logos::Logos; +use std::mem; +use winnow::{ + error::{ContextError, ParserError}, + stream::{Offset, Stream, StreamIsPartial}, +}; + +pub trait Operand {} + +pub trait Visitor { + fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMut { + fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMap { + fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; +} + +gen::generate_instruction_type!( + enum Instruction { + Ld { + type: { &data.typ }, + data: LdDetails, + arguments: { + dst: T, + src: { + repr: T, + space: { data.state_space }, + } + } + }, + Add { + type: { data.type_().into() }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + St { + type: { &data.typ }, + data: StData, + arguments: { + src1: { + repr: T, + space: { data.state_space }, + }, + src2: T, + } + }, + Ret { + data: RetData + }, + Trap { } + } +); + +pub struct LdDetails { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: LdCacheOperator, + pub typ: Type, + pub non_coherent: 
bool, +} + +#[derive(Copy, Clone)] +pub enum ArithDetails { + Unsigned(ScalarType), + Signed(ArithSInt), + Float(ArithFloat), +} + +impl ArithDetails { + fn type_(&self) -> ScalarType { + match self { + ArithDetails::Unsigned(t) => *t, + ArithDetails::Signed(arith) => arith.typ, + ArithDetails::Float(arith) => arith.typ, + } + } +} + +#[derive(Copy, Clone)] +pub struct ArithSInt { + pub typ: ScalarType, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub struct ArithFloat { + pub typ: ScalarType, + pub rounding: Option, + pub flush_to_zero: Option, + pub saturate: bool, +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum RoundingMode { + NearestEven, + Zero, + NegativeInf, + PositiveInf, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdCacheOperator { + Cached, + L2Only, + Streaming, + LastUse, + Uncached, +} + +#[derive(PartialEq, Eq, Clone, Hash)] +pub enum Type { + // .param.b32 foo; + Scalar(ScalarType), + // .param.v2.b32 foo; + Vector(ScalarType, u8), + // .param.b32 foo[4]; + Array(ScalarType, Vec), +} + +impl From for Type { + fn from(value: ScalarType) -> Self { + Type::Scalar(value) + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdStQualifier { + Weak, + Volatile, + Relaxed(MemScope), + Acquire(MemScope), + Release(MemScope), +} + +pub struct StData { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: StCacheOperator, + pub typ: Type, +} + +#[derive(PartialEq, Eq)] +pub enum StCacheOperator { + Writeback, + L2Only, + Streaming, + Writethrough, +} + +#[derive(Copy, Clone)] +pub struct RetData { + pub uniform: bool, +} + +pub struct ParsedOperand {} + +impl Operand for ParsedOperand {} + +#[derive(Debug)] +struct ReverseStream<'a, T>(pub &'a [T]); + +impl<'i, T> Stream for ReverseStream<'i, T> +where + T: Clone + ::std::fmt::Debug, +{ + type Token = T; + type Slice = &'i [T]; + + type IterOffsets = + std::iter::Enumerate>>>; + + type Checkpoint = &'i [T]; + + #[inline(always)] + fn iter_offsets(&self) -> 
Self::IterOffsets { + self.0.iter().rev().cloned().enumerate() + } + + #[inline(always)] + fn eof_offset(&self) -> usize { + self.0.len() + } + + #[inline(always)] + fn next_token(&mut self) -> Option { + let (token, next) = self.0.split_last()?; + self.0 = next; + Some(token.clone()) + } + + #[inline(always)] + fn offset_for

(&self, predicate: P) -> Option + where + P: Fn(Self::Token) -> bool, + { + self.0.iter().rev().position(|b| predicate(b.clone())) + } + + #[inline(always)] + fn offset_at(&self, tokens: usize) -> Result { + if let Some(needed) = tokens + .checked_sub(self.0.len()) + .and_then(std::num::NonZeroUsize::new) + { + Err(winnow::error::Needed::Size(needed)) + } else { + Ok(tokens) + } + } + + #[inline(always)] + fn next_slice(&mut self, offset: usize) -> Self::Slice { + let offset = self.0.len() - offset; + let (next, slice) = self.0.split_at(offset); + self.0 = next; + slice + } + + #[inline(always)] + fn checkpoint(&self) -> Self::Checkpoint { + self.0 + } + + #[inline(always)] + fn reset(&mut self, checkpoint: &Self::Checkpoint) { + self.0 = checkpoint; + } + + #[inline(always)] + fn raw(&self) -> &dyn std::fmt::Debug { + self + } +} + +impl<'a, T> Offset<&'a [T]> for ReverseStream<'a, T> { + #[inline] + fn offset_from(&self, start: &&'a [T]) -> usize { + let fst = start.as_ptr(); + let snd = self.0.as_ptr(); + + debug_assert!( + snd <= fst, + "`Offset::offset_from({snd:?}, {fst:?})` only accepts slices of `self`" + ); + (fst as usize - snd as usize) / std::mem::size_of::() + } +} + +impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { + type PartialState = (); + + fn complete(&mut self) -> Self::PartialState {} + + fn restore_partial(&mut self, _state: Self::PartialState) {} + + fn is_partial_supported() -> bool { + false + } +} + +// Modifiers are turned into arguments to the blocks, with type: +// * If it is an alternative: +// * If it is mandatory then its type is Foo (as defined by the relevant rule) +// * If it is optional then its type is Option +// * Otherwise: +// * If it is mandatory then it is skipped +// * If it is optional then its type is `bool` + +derive_parser!( + #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] + #[logos(skip r"\s+")] + enum Token<'input> { + #[token(",")] + Comma, + #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| 
lex.slice(), priority = 0)] + Ident(&'input str), + #[token("|")] + Or, + #[token("!")] + Not, + #[token(";")] + Semicolon, + #[token("[")] + LBracket, + #[token("]")] + RBracket, + #[regex(r"[0-9]+U?")] + Decimal + } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum StateSpace { + Reg + } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum MemScope { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum ScalarType { } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st + st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + todo!() + } + st.volatile{.ss}{.vec}.type [a], b => { + todo!() + } + st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + todo!() + } + st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + todo!() + } + st.mmio.relaxed.sys{.global}.type [a], b => { + todo!() + } + + .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .cop: RawStCacheOperator = { .wb, .cg, .cs, .wt }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld + ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { + todo!() + } + ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => { + todo!() + } + 
ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + todo!() + } + ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + todo!() + } + ld.mmio.relaxed.sys{.global}.type d, [a] => { + todo!() + } + + .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} }; + .cop: RawCacheOp = { .ca, .cg, .cs, .lu, .cv }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add + add.type d, a, b => { + todo!() + } + add{.sat}.s32 d, a, b => { + todo!() + } + + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s64, + .u16x2, .s16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f32 d, a, b => { + todo!() + } + add{.rnd}.f64 d, a, b => { + todo!() + } + + .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f16 d, a, b => { + todo!() + } + add{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + todo!() + } + add{.rnd}.bf16 d, a, b => { + todo!() + } + add{.rnd}.bf16x2 d, a, b => { + todo!() + } + + .rnd: RawFloatRounding = { .rn }; + + ret => { + todo!() + } + +); + +fn main() { + use winnow::combinator::*; + use winnow::token::*; + use winnow::Parser; + + let mut input: &[Token] = 
&[][..]; + let x = opt(any::<_, ContextError>.verify_map(|t| { println!("MAP");Some(true) })).parse_next(&mut input).unwrap(); + dbg!(x); + let lexer = Token::lexer( + " + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; + ", + ); + let tokens = lexer.map(|t| t.unwrap()).collect::>(); + println!("{:?}", &tokens); + let mut stream = &tokens[..]; + parse_instruction(&mut stream).unwrap(); + //parse_prefix(&mut lexer); + let mut parser = &*tokens; + println!("{}", mem::size_of::()); +}