From ba17906de8381482241dc151d4891845a84bc71e Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Aug 2024 19:30:09 +0200 Subject: [PATCH] Pass parser state to instruction callbacks --- gen/src/lib.rs | 6 +++--- ptx_parser/src/main.rs | 35 ++++++++++++++++++----------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 3ab5e43..6bea2df 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -383,7 +383,7 @@ fn emit_parse_function( let code_block = &def.code_block.0; let args = def.function_arguments_declarations(); quote! { - fn #fn_name<'input>( #(#args),* ) -> Instruction> #code_block + fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block } }) }) @@ -473,7 +473,7 @@ fn emit_parse_function( #(#fns_)* - fn parse_instruction<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> winnow::error::PResult>> + fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult>> { use winnow::Parser; use winnow::token::*; @@ -695,7 +695,7 @@ fn emit_definition_parser( let fn_args = definition.function_arguments(); let fn_name = format_ident!("{}_{}", opcode, fn_idx); let fn_call = quote! { - #fn_name( #(#fn_args),* ) + #fn_name(&mut stream.state, #(#fn_args),* ) }; quote! { #(#unordered_parse_declarations)* diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 96f08b6..7786deb 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -199,9 +199,10 @@ pub struct RetData { pub uniform: bool, } -type ParserState<'a, 'input> = Stateful<&'a [Token<'input>], Vec>; +type PtxParserState = Vec; +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>; -fn ident<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<&'input str> { +fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { any.verify_map(|t| { if let Token::Ident(text) = t { Some(text) @@ -214,7 +215,7 @@ fn ident<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<&'input st .parse_next(stream) } -fn num<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<(&'input str, u32, bool)> { +fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { any.verify_map(|t| { Some(match t { Token::Hex(s) => { @@ -239,9 +240,9 @@ fn num<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<(&'input str } fn take_error<'a, 'input: 'a, O, E>( - mut parser: impl Parser, Result, E>, -) -> impl Parser, O, E> { - move |input: &mut ParserState<'a, 'input>| { + mut parser: impl Parser, Result, E>, +) -> impl Parser, O, E> { + move |input: &mut PtxParser<'a, 'input>| { Ok(match parser.parse_next(input)? { Ok(x) => x, Err((x, err)) => { @@ -252,7 +253,7 @@ fn take_error<'a, 'input: 'a, O, E>( } } -fn int_immediate<'a, 'input>(input: &mut ParserState<'a, 'input>) -> PResult { +fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult { take_error((opt(Token::Minus), num).map(|(neg, x)| { let (num, radix, is_unsigned) = x; if neg.is_some() { @@ -278,7 +279,7 @@ fn int_immediate<'a, 'input>(input: &mut ParserState<'a, 'input>) -> PResult(stream: &mut ParserState<'a, 'input>) -> PResult { +fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { take_error(any.verify_map(|t| match t { Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) { Ok(x) => Ok(f32::from_bits(x)), @@ -289,7 +290,7 @@ fn f32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { .parse_next(stream) } -fn f64<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { +fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { take_error(any.verify_map(|t| match t { Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) { Ok(x) => Ok(f64::from_bits(x)), @@ -300,7 +301,7 @@ fn f64<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { .parse_next(stream) } -fn s32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { +fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { take_error((opt(Token::Minus), num).map(|(sign, x)| { let (text, radix, _) = x; match i32::from_str_radix(text, radix) { @@ -312,7 +313,7 @@ fn s32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { } fn immediate_value<'a, 'input>( - stream: &mut ParserState<'a, 'input>, + stream: &mut PtxParser<'a, 'input>, ) -> PResult { alt(( int_immediate, @@ -324,7 +325,7 @@ fn immediate_value<'a, 'input>( impl ast::ParsedOperand { fn parse<'a, 'input>( - stream: &mut ParserState<'a, 'input>, + stream: &mut PtxParser<'a, 'input>, ) -> PResult> { use winnow::combinator::*; use winnow::token::any; @@ -338,7 +339,7 @@ impl ast::ParsedOperand { } } fn ident_operands<'a, 'input>( - stream: &mut ParserState<'a, 'input>, + stream: &mut PtxParser<'a, 'input>, ) -> PResult> { let main_ident = ident.parse_next(stream)?; alt(( @@ -354,7 +355,7 @@ impl ast::ParsedOperand { .parse_next(stream) } fn vector_operand<'a, 'input>( - stream: &mut ParserState<'a, 'input>, + stream: &mut PtxParser<'a, 'input>, ) -> PResult> { let (_, r1, _, r2) = (Token::LBracket, ident, Token::Comma, ident).parse_next(stream)?; @@ -565,9 +566,9 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov.type d, a => { - Instruction::Mov{ + Instruction::Mov { data: MovDetails::new(type_.into()), - arguments: MovArgs { dst: d, src: a } + arguments: MovArgs { dst: d, src: a }, } } .type: ScalarType = { .pred, @@ -704,7 +705,7 @@ fn main() { ); let tokens = lexer.map(|t| t.unwrap()).collect::>(); println!("{:?}", &tokens); - let mut stream = ParserState { + let mut stream = PtxParser { input: &tokens[..], state: Vec::new(), };