From 91dbbb372b04c40e0f0ad60cbeda621fb592ee01 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 16 Aug 2024 18:29:13 +0200 Subject: [PATCH] Move all types to a separate module --- gen/src/lib.rs | 10 +- gen_impl/src/lib.rs | 22 +++-- ptx_parser/src/ast.rs | 132 ++++++++++++++++++++++++++- ptx_parser/src/main.rs | 203 ++++------------------------------------- 4 files changed, 168 insertions(+), 199 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 93b31fe..67b276e 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -178,9 +178,9 @@ impl SingleOpcodeDefinition { .chain(self.arguments.0.iter().map(|arg| { let name = &arg.ident; let arg_type = if arg.unified { - quote! { (ParsedOperand<'input>, bool) } + quote! { (ParsedOperandStr<'input>, bool) } } else { - quote! { ParsedOperand<'input> } + quote! { ParsedOperandStr<'input> } }; if arg.optional { quote! { #name : Option<#arg_type> } @@ -415,7 +415,7 @@ fn emit_parse_function( let code_block = &def.code_block.0; let args = def.function_arguments_declarations(); quote! { - fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block + fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block } }) }) @@ -506,7 +506,7 @@ fn emit_parse_function( #(#fns_)* - fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult>> + fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult>> { use winnow::Parser; use winnow::token::*; @@ -747,7 +747,7 @@ fn emit_definition_parser( }; let operand = { quote! { - ParsedOperand::parse + ParsedOperandStr::parse } }; let post_bracket = if arg.post_bracket { diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 6b606af..7160603 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -2,11 +2,13 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; use syn::{ braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, Token, Type, TypeParam, + Visibility, }; pub mod parser; pub struct GenerateInstructionType { + pub visibility: Option, pub name: Ident, pub type_parameters: Punctuated, pub short_parameters: Punctuated, @@ -16,16 +18,17 @@ pub struct GenerateInstructionType { impl GenerateInstructionType { pub fn emit_arg_types(&self, tokens: &mut TokenStream) { for v in self.variants.iter() { - v.emit_type(&self.type_parameters, tokens); + v.emit_type(&self.visibility, &self.type_parameters, tokens); } } pub fn emit_instruction_type(&self, tokens: &mut TokenStream) { + let vis = &self.visibility; let type_name = &self.name; let type_parameters = &self.type_parameters; let variants = self.variants.iter().map(|v| v.emit_variant()); quote! { - enum #type_name<#type_parameters> { + #vis enum #type_name<#type_parameters> { #(#variants),* } } @@ -133,6 +136,11 @@ impl VisitKind { impl Parse for GenerateInstructionType { fn parse(input: syn::parse::ParseStream) -> syn::Result { + let visibility = if !input.peek(Token![enum]) { + Some(input.parse::()?) + } else { + None + }; input.parse::()?; let name = input.parse::()?; input.parse::()?; @@ -146,6 +154,7 @@ impl Parse for GenerateInstructionType { braced!(variants_buffer in input); let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?; Ok(Self { + visibility, name, type_parameters, short_parameters, @@ -262,6 +271,7 @@ impl InstructionVariant { fn emit_type( &self, + vis: &Option, type_parameters: &Punctuated, tokens: &mut TokenStream, ) { @@ -275,9 +285,9 @@ impl InstructionVariant { } else { None }; - let fields = arguments.fields.iter().map(ArgumentField::emit_field); + let fields = arguments.fields.iter().map(|f| f.emit_field(vis)); quote! { - struct #name #type_parameters { + #vis struct #name #type_parameters { #(#fields),* } } @@ -559,11 +569,11 @@ impl ArgumentField { } } - fn emit_field(&self) -> TokenStream { + fn emit_field(&self, vis: &Option) -> TokenStream { let name = &self.name; let type_ = &self.repr; quote! { - #name: #type_ + #vis #name: #type_ } } } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index a471b8e..302aef7 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,4 +1,113 @@ -use super::MemScope; +use super::{MemScope, ScalarType, VectorPrefix, StateSpace}; + +gen::generate_instruction_type!( + pub enum Instruction { + Mov { + type: { &data.typ }, + data: MovDetails, + arguments: { + dst: T, + src: T + } + }, + Ld { + type: { &data.typ }, + data: LdDetails, + arguments: { + dst: T, + src: { + repr: T, + space: { data.state_space }, + } + } + }, + Add { + type: { data.type_().into() }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + St { + type: { &data.typ }, + data: StData, + arguments: { + src1: { + repr: T, + space: { data.state_space }, + }, + src2: T, + } + }, + Ret { + data: RetData + }, + Trap { } + } +); + +pub trait Visitor { + fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMut { + fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMap { + fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; +} + +#[derive(PartialEq, Eq, Clone, Hash)] +pub enum Type { + // .param.b32 foo; + Scalar(ScalarType), + // .param.v2.b32 foo; + Vector(ScalarType, u8), + // .param.b32 foo[4]; + Array(ScalarType, Vec), +} + +impl Type { + pub(crate) fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { + match vector { + Some(VectorPrefix::V2) => Type::Vector(scalar, 2), + Some(VectorPrefix::V4) => Type::Vector(scalar, 4), + None => Type::Scalar(scalar), + } + } +} + +impl From for Type { + fn from(value: ScalarType) -> Self { + Type::Scalar(value) + } +} + +#[derive(Clone)] +pub struct MovDetails { + pub typ: super::Type, + pub src_is_address: bool, + // two fields below are in use by member moves + pub dst_width: u8, + pub src_width: u8, + // This is in use by auto-generated movs + pub relaxed_src2_conv: bool, +} + +impl MovDetails { + pub(crate) fn new(vector: Option, scalar: ScalarType) -> Self { + MovDetails { + typ: Type::maybe_vector(vector, scalar), + src_is_address: false, + dst_width: 0, + src_width: 0, + relaxed_src2_conv: false, + } + } +} #[derive(Clone)] pub enum ParsedOperand { @@ -81,3 +190,24 @@ pub enum RoundingMode { NegativeInf, PositiveInf, } + +pub struct LdDetails { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: LdCacheOperator, + pub typ: Type, + pub non_coherent: bool, +} + + +pub struct StData { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: StCacheOperator, + pub typ: Type, +} + +#[derive(Copy, Clone)] +pub struct RetData { + pub uniform: bool, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index eb137a5..34c27da 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -12,176 +12,7 @@ use winnow::{ use winnow::{prelude::*, Stateful}; mod ast; - -pub trait Operand {} - -pub trait Visitor { - fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); -} - -pub trait VisitorMut { - fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); -} - -pub trait VisitorMap { - fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; -} - -#[derive(Clone)] -pub struct MovDetails { - pub typ: Type, - pub src_is_address: bool, - // two fields below are in use by member moves - pub dst_width: u8, - pub src_width: u8, - // This is in use by auto-generated movs - pub relaxed_src2_conv: bool, -} - -impl MovDetails { - fn new(vector: Option, scalar: ScalarType) -> Self { - MovDetails { - typ: Type::maybe_vector(vector, scalar), - src_is_address: false, - dst_width: 0, - src_width: 0, - relaxed_src2_conv: false, - } - } -} - -gen::generate_instruction_type!( - enum Instruction { - Mov { - type: { &data.typ }, - data: MovDetails, - arguments: { - dst: T, - src: T - } - }, - Ld { - type: { &data.typ }, - data: LdDetails, - arguments: { - dst: T, - src: { - repr: T, - space: { data.state_space }, - } - } - }, - Add { - type: { data.type_().into() }, - data: ast::ArithDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - St { - type: { &data.typ }, - data: StData, - arguments: { - src1: { - repr: T, - space: { data.state_space }, - }, - src2: T, - } - }, - Ret { - data: RetData - }, - Trap { } - } -); - -pub struct LdDetails { - pub qualifier: ast::LdStQualifier, - pub state_space: StateSpace, - pub caching: ast::LdCacheOperator, - pub typ: Type, - pub non_coherent: bool, -} - -#[derive(Copy, Clone)] -pub enum ArithDetails { - Unsigned(ScalarType), - Signed(ArithSInt), - Float(ArithFloat), -} - -impl ArithDetails { - fn type_(&self) -> ScalarType { - match self { - ArithDetails::Unsigned(t) => *t, - ArithDetails::Signed(arith) => arith.typ, - ArithDetails::Float(arith) => arith.typ, - } - } -} - -#[derive(Copy, Clone)] -pub struct ArithSInt { - pub typ: ScalarType, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub struct ArithFloat { - pub typ: ScalarType, - pub rounding: Option, - pub flush_to_zero: Option, - pub saturate: bool, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum RoundingMode { - NearestEven, - Zero, - NegativeInf, - PositiveInf, -} - -#[derive(PartialEq, Eq, Clone, Hash)] -pub enum Type { - // .param.b32 foo; - Scalar(ScalarType), - // .param.v2.b32 foo; - Vector(ScalarType, u8), - // .param.b32 foo[4]; - Array(ScalarType, Vec), -} - -impl Type { - fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { - match vector { - Some(VectorPrefix::V2) => Type::Vector(scalar, 2), - Some(VectorPrefix::V4) => Type::Vector(scalar, 4), - None => Type::Scalar(scalar), - } - } -} - -impl From for Type { - fn from(value: ScalarType) -> Self { - Type::Scalar(value) - } -} - -pub struct StData { - pub qualifier: ast::LdStQualifier, - pub state_space: StateSpace, - pub caching: ast::StCacheOperator, - pub typ: Type, -} - -#[derive(Copy, Clone)] -pub struct RetData { - pub uniform: bool, -} +pub use ast::*; impl From for ast::StCacheOperator { fn from(value: RawStCacheOperator) -> Self { @@ -350,7 +181,7 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult( stream: &mut PtxParser<'a, 'input>, -) -> PResult>>> { +) -> PResult>>> { repeat(3.., terminated(parse_instruction, Token::Semicolon)).parse_next(stream) } @@ -550,7 +381,7 @@ impl<'input, I: Stream + StreamIsPartial, E: ParserError> Parse // * If it is mandatory then it is skipped // * If it is optional then its type is `bool` -type ParsedOperand<'input> = ast::ParsedOperand<&'input str>; +type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; derive_parser!( #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] @@ -601,7 +432,7 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { - data: MovDetails::new(vec, type_), + data: ast::MovDetails::new(vec, type_), arguments: MovArgs { dst: d, src: a }, } } @@ -622,7 +453,7 @@ derive_parser!( qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), state_space: ss.unwrap_or(StateSpace::Generic), caching: cop.unwrap_or(RawStCacheOperator::Wb).into(), - typ: Type::maybe_vector(vec, type_) + typ: ast::Type::maybe_vector(vec, type_) }, arguments: StArgs { src1:a, src2:b } } @@ -633,7 +464,7 @@ derive_parser!( qualifier: volatile.into(), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::StCacheOperator::Writeback, - typ: Type::maybe_vector(vec, type_) + typ: ast::Type::maybe_vector(vec, type_) }, arguments: StArgs { src1:a, src2:b } } @@ -647,7 +478,7 @@ derive_parser!( qualifier: ast::LdStQualifier::Relaxed(scope), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::StCacheOperator::Writeback, - typ: Type::maybe_vector(vec, type_) + typ: ast::Type::maybe_vector(vec, type_) }, arguments: StArgs { src1:a, src2:b } } @@ -661,7 +492,7 @@ derive_parser!( qualifier: ast::LdStQualifier::Release(scope), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::StCacheOperator::Writeback, - typ: Type::maybe_vector(vec, type_) + typ: ast::Type::maybe_vector(vec, type_) }, arguments: StArgs { src1:a, src2:b } } @@ -669,13 +500,13 @@ derive_parser!( st.mmio.relaxed.sys{.global}.type [a], b => { state.push(PtxError::Todo); Instruction::St { - data: StData { + data: ast::StData { qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), state_space: global.unwrap_or(StateSpace::Generic), caching: ast::StCacheOperator::Writeback, typ: type_.into() }, - arguments: StArgs { src1:a, src2:b } + arguments: ast::StArgs { src1:a, src2:b } } } @@ -704,7 +535,7 @@ derive_parser!( qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), state_space: ss.unwrap_or(StateSpace::Generic), caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), - typ: Type::maybe_vector(vec, type_), + typ: ast::Type::maybe_vector(vec, type_), non_coherent: false }, arguments: LdArgs { dst:d, src:a } @@ -719,7 +550,7 @@ derive_parser!( qualifier: volatile.into(), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::LdCacheOperator::Cached, - typ: Type::maybe_vector(vec, type_), + typ: ast::Type::maybe_vector(vec, type_), non_coherent: false }, arguments: LdArgs { dst:d, src:a } @@ -734,7 +565,7 @@ derive_parser!( qualifier: ast::LdStQualifier::Relaxed(scope), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::LdCacheOperator::Cached, - typ: Type::maybe_vector(vec, type_), + typ: ast::Type::maybe_vector(vec, type_), non_coherent: false }, arguments: LdArgs { dst:d, src:a } @@ -749,7 +580,7 @@ derive_parser!( qualifier: ast::LdStQualifier::Acquire(scope), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::LdCacheOperator::Cached, - typ: Type::maybe_vector(vec, type_), + typ: ast::Type::maybe_vector(vec, type_), non_coherent: false }, arguments: LdArgs { dst:d, src:a } @@ -931,7 +762,7 @@ fn main() { println!("{}", mem::size_of::()); let mut input: &[Token] = &[][..]; - let x = opt(any::<_, ContextError>.verify_map(|t| { + let x = opt(any::<_, ContextError>.verify_map(|_| { println!("MAP"); Some(true) })) @@ -948,13 +779,11 @@ fn main() { ); let tokens = lexer.map(|t| t.unwrap()).collect::>(); println!("{:?}", &tokens); - let mut stream = PtxParser { + let stream = PtxParser { input: &tokens[..], state: Vec::new(), }; let fn_body = fn_body.parse(stream).unwrap(); println!("{}", fn_body.len()); - //parse_prefix(&mut lexer); - let mut parser = &*tokens; println!("{}", mem::size_of::()); }