From 8d7c88c095a013261cca1c6e5cbfb1acaac05624 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Aug 2024 03:26:38 +0200 Subject: [PATCH] Fully parse operands --- gen/src/lib.rs | 14 +- gen_impl/src/lib.rs | 2 +- ptx_parser/Cargo.toml | 1 + ptx_parser/src/ast.rs | 16 +++ ptx_parser/src/main.rs | 318 ++++++++++++++++++++++++++++++++++++++--- 5 files changed, 323 insertions(+), 28 deletions(-) create mode 100644 ptx_parser/src/ast.rs diff --git a/gen/src/lib.rs b/gen/src/lib.rs index f39150f..30e4595 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -174,9 +174,9 @@ impl SingleOpcodeDefinition { .chain(self.arguments.0.iter().map(|arg| { let name = &arg.ident; if arg.optional { - quote! { #name : Option<&str> } + quote! { #name : Option> } } else { - quote! { #name : &str } + quote! { #name : ParsedOperand<'input> } } })) } @@ -377,7 +377,7 @@ fn emit_parse_function( let code_block = &def.code_block.0; let args = def.function_arguments_declarations(); quote! { - fn #fn_name( #(#args),* ) -> Instruction #code_block + fn #fn_name<'input>( #(#args),* ) -> Instruction> #code_block } }) }) @@ -452,7 +452,7 @@ fn emit_parse_function( #(#fns_)* - fn parse_instruction<'input>(stream: &mut (impl winnow::stream::Stream, Slice = &'input [#type_name<'input>]> + winnow::stream::StreamIsPartial)) -> winnow::error::PResult> + fn parse_instruction<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> winnow::error::PResult>> { use winnow::Parser; use winnow::token::*; @@ -642,9 +642,9 @@ fn emit_definition_parser( empty } }; - let ident = { + let operand = { quote! { - any.verify_map(|t| match t { #token_type::Ident(s) => Some(s), _ => None }) + ParsedOperand::parse } }; let post_bracket = if arg.post_bracket { @@ -657,7 +657,7 @@ fn emit_definition_parser( } }; let parser = quote! { - (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #ident, #post_bracket) + (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket) }; let arg_name = &arg.ident; if arg.optional { diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 4c7f2ab..6b606af 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -67,7 +67,7 @@ impl GenerateInstructionType { let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix()); let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { ( - quote! { <#type_parameters, To: Operand> }, + quote! { <#type_parameters, To> }, quote! { <#short_parameters, To> }, quote! { #type_name }, ) diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index d5e3d5d..951d508 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" logos = "0.14" winnow = { version = "0.6.18", features = ["debug"] } gen = { path = "../gen" } +thiserror = "1.0" diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs new file mode 100644 index 0000000..ae4eaba --- /dev/null +++ b/ptx_parser/src/ast.rs @@ -0,0 +1,16 @@ +#[derive(Clone)] +pub enum ParsedOperand { + Reg(Ident), + RegOffset(Ident, i32), + Imm(ImmediateValue), + VecMember(Ident, u8), + VecPack(Vec), +} + +#[derive(Copy, Clone)] +pub enum ImmediateValue { + U64(u64), + S64(i64), + F32(f32), + F64(f64), +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 8af8ede..4f3ed41 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1,27 +1,66 @@ use gen::derive_parser; use logos::Logos; use std::mem; +use std::num::{ParseFloatError, ParseIntError}; +use winnow::combinator::{alt, empty, fail, opt}; +use winnow::stream::SliceLen; +use winnow::token::{any, literal}; use winnow::{ error::{ContextError, ParserError}, stream::{Offset, Stream, StreamIsPartial}, + PResult, }; +use winnow::{prelude::*, Stateful}; + +mod ast; pub trait Operand {} -pub trait Visitor { +pub trait Visitor { fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); } -pub trait VisitorMut { +pub trait VisitorMut { fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); } -pub trait VisitorMap { +pub trait VisitorMap { fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; } +#[derive(Clone)] +pub struct MovDetails { + pub typ: Type, + pub src_is_address: bool, + // two fields below are in use by member moves + pub dst_width: u8, + pub src_width: u8, + // This is in use by auto-generated movs + pub relaxed_src2_conv: bool, +} + +impl MovDetails { + pub fn new(typ: Type) -> Self { + MovDetails { + typ, + src_is_address: false, + dst_width: 0, + src_width: 0, + relaxed_src2_conv: false, + } + } +} + gen::generate_instruction_type!( - enum Instruction { + enum Instruction { + Mov { + type: { &data.typ }, + data: MovDetails, + arguments: { + dst: T, + src: T + } + }, Ld { type: { &data.typ }, data: LdDetails, @@ -161,9 +200,212 @@ pub struct RetData { pub uniform: bool, } -pub struct ParsedOperand {} +type ParserState<'a, 'input> = Stateful<&'a [Token<'input>], Vec>; -impl Operand for ParsedOperand {} +fn ident<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::Ident(text) = t { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + +fn num<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<(&'input str, u32, bool)> { + any.verify_map(|t| { + Some(match t { + Token::Hex(s) => { + if s.ends_with('U') { + (&s[2..s.len() - 1], 16, true) + } else { + (&s[2..], 16, false) + } + } + Token::Decimal(s) => { + let radix = if s.starts_with('0') { 8 } else { 10 }; + if s.ends_with('U') { + (&s[..s.len() - 1], radix, true) + } else { + (s, radix, false) + } + } + _ => return None, + }) + }) + .parse_next(stream) +} + +fn take_error<'a, 'input: 'a, O, E>( + mut parser: impl Parser, Result, E>, +) -> impl Parser, O, E> { + move |input: &mut ParserState<'a, 'input>| { + Ok(match parser.parse_next(input)? { + Ok(x) => x, + Err((x, err)) => { + input.state.push(err); + x + } + }) + } +} + +fn int_immediate<'a, 'input>(input: &mut ParserState<'a, 'input>) -> PResult { + take_error((opt(Token::Minus), num).map(|(neg, x)| { + let (num, radix, is_unsigned) = x; + if neg.is_some() { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(-x)), + Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))), + } + } else if is_unsigned { + match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + } + } else { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(x)), + Err(_) => match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + }, + } + } + })) + .parse_next(input) +} + +fn f32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { + take_error(any.verify_map(|t| match t { + Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f32::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn f64<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { + take_error(any.verify_map(|t| match t { + Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f64::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn s32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { + take_error((opt(Token::Minus), num).map(|(sign, x)| { + let (text, radix, _) = x; + match i32::from_str_radix(text, radix) { + Ok(x) => Ok(if sign.is_some() { -x } else { x }), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn immediate_value<'a, 'input>( + stream: &mut ParserState<'a, 'input>, +) -> PResult { + alt(( + int_immediate, + f32.map(ast::ImmediateValue::F32), + f64.map(ast::ImmediateValue::F64), + )) + .parse_next(stream) +} + +impl ast::ParsedOperand { + fn parse<'a, 'input>( + stream: &mut ParserState<'a, 'input>, + ) -> PResult> { + use winnow::combinator::*; + use winnow::token::any; + fn vector_index<'input>(inp: &'input str) -> Result { + match inp { + "x" | "r" => Ok(0), + "y" | "g" => Ok(1), + "z" | "b" => Ok(2), + "w" | "a" => Ok(3), + _ => Err(PtxError::WrongVectorElement), + } + } + fn ident_operands<'a, 'input>( + stream: &mut ParserState<'a, 'input>, + ) -> PResult> { + let main_ident = ident.parse_next(stream)?; + alt(( + preceded(Token::Plus, s32) + .map(move |offset| ast::ParsedOperand::RegOffset(main_ident, offset)), + take_error(preceded(Token::Dot, ident).map(move |suffix| { + let vector_index = vector_index(suffix) + .map_err(move |e| (ast::ParsedOperand::VecMember(main_ident, 0), e))?; + Ok(ast::ParsedOperand::VecMember(main_ident, vector_index)) + })), + empty.value(ast::ParsedOperand::Reg(main_ident)), + )) + .parse_next(stream) + } + fn vector_operand<'a, 'input>( + stream: &mut ParserState<'a, 'input>, + ) -> PResult> { + let (_, r1, _, r2) = + (Token::LBracket, ident, Token::Comma, ident).parse_next(stream)?; + dispatch! {any; + Token::LBracket => empty.map(|_| vec![r1, r2]), + Token::Comma => (ident, Token::Comma, ident, Token::LBracket).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + _ => fail + } + .parse_next(stream) + } + alt(( + ident_operands, + immediate_value.map(ast::ParsedOperand::Imm), + vector_operand.map(ast::ParsedOperand::VecPack), + )) + .parse_next(stream) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum PtxError { + #[error("{source}")] + ParseInt { + #[from] + source: ParseIntError, + }, + #[error("{source}")] + ParseFloat { + #[from] + source: ParseFloatError, + }, + #[error("")] + SyntaxError, + #[error("")] + NonF32Ftz, + #[error("")] + WrongArrayType, + #[error("")] + WrongVectorElement, + #[error("")] + MultiArrayVariable, + #[error("")] + ZeroDimensionArray, + #[error("")] + ArrayInitalizer, + #[error("")] + NonExternPointer, + #[error("{start}:{end}")] + UnrecognizedStatement { start: usize, end: usize }, + #[error("{start}:{end}")] + UnrecognizedDirective { start: usize, end: usize }, +} #[derive(Debug)] struct ReverseStream<'a, T>(pub &'a [T]); @@ -180,24 +422,20 @@ where type Checkpoint = &'i [T]; - #[inline(always)] fn iter_offsets(&self) -> Self::IterOffsets { self.0.iter().rev().cloned().enumerate() } - #[inline(always)] fn eof_offset(&self) -> usize { self.0.len() } - #[inline(always)] fn next_token(&mut self) -> Option { let (token, next) = self.0.split_last()?; self.0 = next; Some(token.clone()) } - #[inline(always)] fn offset_for

(&self, predicate: P) -> Option where P: Fn(Self::Token) -> bool, @@ -205,7 +443,6 @@ where self.0.iter().rev().position(|b| predicate(b.clone())) } - #[inline(always)] fn offset_at(&self, tokens: usize) -> Result { if let Some(needed) = tokens .checked_sub(self.0.len()) @@ -217,7 +454,6 @@ where } } - #[inline(always)] fn next_slice(&mut self, offset: usize) -> Self::Slice { let offset = self.0.len() - offset; let (next, slice) = self.0.split_at(offset); @@ -225,24 +461,20 @@ where slice } - #[inline(always)] fn checkpoint(&self) -> Self::Checkpoint { self.0 } - #[inline(always)] fn reset(&mut self, checkpoint: &Self::Checkpoint) { self.0 = checkpoint; } - #[inline(always)] fn raw(&self) -> &dyn std::fmt::Debug { self } } impl<'a, T> Offset<&'a [T]> for ReverseStream<'a, T> { - #[inline] fn offset_from(&self, start: &&'a [T]) -> usize { let fst = start.as_ptr(); let snd = self.0.as_ptr(); @@ -267,6 +499,14 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { } } +impl<'input, I: Stream + StreamIsPartial, E: ParserError> Parser + for Token<'input> +{ + fn parse_next(&mut self, input: &mut I) -> PResult { + any.parse_next(input) + } +} + // Modifiers are turned into arguments to the blocks, with type: // * If it is an alternative: // * If it is mandatory then its type is Foo (as defined by the relevant rule) @@ -275,12 +515,16 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { // * If it is mandatory then it is skipped // * If it is optional then its type is `bool` +type ParsedOperand<'input> = ast::ParsedOperand<&'input str>; + derive_parser!( #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] #[logos(skip r"\s+")] enum Token<'input> { #[token(",")] Comma, + #[token(".")] + Dot, #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] Ident(&'input str), #[token("|")] @@ -293,8 +537,18 @@ derive_parser!( LBracket, #[token("]")] RBracket, - #[regex(r"[0-9]+U?")] - Decimal + #[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| lex.slice())] + F32(&'input str), + #[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())] + F64(&'input str), + #[regex(r"0[xX][0-9a-zA-Z]+U?", |lex| lex.slice())] + Hex(&'input str), + #[regex(r"[0-9]+U?", |lex| lex.slice())] + Decimal(&'input str), + #[token("-")] + Minus, + #[token("+")] + Plus, } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -308,6 +562,20 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum ScalarType { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov + + mov.type d, a => { + Instruction::Mov{ + data: MovDetails::new(type_.into()), + arguments: MovArgs { dst: d, src: a } + } + } + .type: ScalarType = { .pred, + .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { todo!() @@ -416,8 +684,15 @@ fn main() { use winnow::token::*; use winnow::Parser; + println!("{}", mem::size_of::()); + let mut input: &[Token] = &[][..]; - let x = opt(any::<_, ContextError>.verify_map(|t| { println!("MAP");Some(true) })).parse_next(&mut input).unwrap(); + let x = opt(any::<_, ContextError>.verify_map(|t| { + println!("MAP"); + Some(true) + })) + .parse_next(&mut input) + .unwrap(); dbg!(x); let lexer = Token::lexer( " @@ -429,7 +704,10 @@ fn main() { ); let tokens = lexer.map(|t| t.unwrap()).collect::>(); println!("{:?}", &tokens); - let mut stream = &tokens[..]; + let mut stream = ParserState { + input: &tokens[..], + state: Vec::new(), + }; parse_instruction(&mut stream).unwrap(); //parse_prefix(&mut lexer); let mut parser = &*tokens;