diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 13764e7..6dedbbb 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1492,6 +1492,46 @@ pub struct TokenError(std::ops::Range); impl std::error::Error for TokenError {} +fn first_optional< + 'a, + 'input, + Input: Stream, + OptionalOutput, + RequiredOutput, + Error, + ParseOptional, + ParseRequired, +>( + mut optional: ParseOptional, + mut required: ParseRequired, +) -> impl Parser, RequiredOutput), Error> +where + ParseOptional: Parser, + ParseRequired: Parser, + Error: ParserError, +{ + move |input: &mut Input| -> Result<(Option, RequiredOutput), ErrMode> { + let start = input.checkpoint(); + + let parsed_optional = match optional.parse_next(input) { + Ok(v) => Some(v), + Err(ErrMode::Backtrack(_)) => { + input.reset(&start); + None + }, + Err(e) => return Err(e) + }; + + match required.parse_next(input) { + Ok(v) => return Ok((parsed_optional, v)), + Err(ErrMode::Backtrack(_)) => input.reset(&start), + Err(e) => return Err(e) + }; + + Ok((None, required.parse_next(input)?)) + } +} + // This macro is responsible for generating parser code for instruction parser. // Instruction parsing is by far the most complex part of parsing PTX code: // * There are tens of instruction kinds, each with slightly different parsing rules @@ -3413,6 +3453,7 @@ derive_parser!( #[cfg(test)] mod tests { + use crate::first_optional; use crate::parse_module_checked; use crate::PtxError; @@ -3423,6 +3464,55 @@ mod tests { use logos::Span; use winnow::prelude::*; + #[test] + fn first_optional_present() { + let text = "AB"; + let result = first_optional::<_, _, _, (), _, _>('A', 'B').parse(text); + assert_eq!(result, Ok((Some('A'), 'B'))); + } + + #[test] + fn first_optional_absent() { + let text = "B"; + let result = first_optional::<_, _, _, (), _, _>('A', 'B').parse(text); + assert_eq!(result, Ok((None, 'B'))); + } + + #[test] + fn first_optional_repeated_absent() { + let text = "A"; + let result = first_optional::<_, _, _, (), _, _>('A', 'A').parse(text); + assert_eq!(result, Ok((None, 'A'))); + } + + #[test] + fn first_optional_repeated_present() { + let text = "AA"; + let result = first_optional::<_, _, _, (), _, _>('A', 'A').parse(text); + assert_eq!(result, Ok((Some('A'), 'A'))); + } + + #[test] + fn first_optional_sequence_absent() { + let text = "AA"; + let result = ('A', first_optional::<_, _, _, (), _, _>('A', 'A')).parse(text); + assert_eq!(result, Ok(('A', (None, 'A')))); + } + + #[test] + fn first_optional_sequence_present() { + let text = "AAA"; + let result = ('A', first_optional::<_, _, _, (), _, _>('A', 'A')).parse(text); + assert_eq!(result, Ok(('A', (Some('A'), 'A')))); + } + + #[test] + fn first_optional_no_match() { + let text = "C"; + let result = first_optional::<_, _, _, (), _, _>('A', 'B').parse(text); + assert!(result.is_err()); + } + #[test] fn sm_11() { let text = ".target sm_11"; diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index 5728c8c..f88395d 100644 --- a/ptx_parser_macros/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -757,12 +757,13 @@ fn emit_definition_parser( DotModifierRef::Direct { optional: true, .. } | DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), }); - let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| { + let (arguments_pattern, arguments_parser) = definition.arguments.0.iter().enumerate().rfold((quote! { () }, quote! { empty }), |(emitted_pattern, emitted_parser), (idx, arg)| { let comma = if idx == 0 || arg.pre_pipe { quote! { empty } } else { quote! { any.verify(|(t, _)| *t == #token_type::Comma).void() } }; + let pre_bracket = if arg.pre_bracket { quote! { any.verify(|(t, _)| *t == #token_type::LBracket).void() @@ -833,16 +834,20 @@ fn emit_definition_parser( #pattern.map(|(_, _, _, _, name, _, _)| name) } }; - if arg.optional { - quote! { - let #arg_name = opt(#inner_parser).parse_next(stream)?; - } + + let parser = if arg.optional { + quote! { first_optional(#inner_parser, #emitted_parser) } } else { - quote! { - let #arg_name = #inner_parser.parse_next(stream)?; - } - } + quote! { (#inner_parser, #emitted_parser) } + }; + + let pattern = quote! { ( #arg_name, #emitted_pattern ) }; + + (pattern, parser) }); + + let arguments_parse = quote! { let #arguments_pattern = ( #arguments_parser ).parse_next(stream)?; }; + let fn_args = definition.function_arguments(); let fn_name = format_ident!("{}_{}", opcode, fn_idx); let fn_call = quote! { @@ -863,7 +868,7 @@ fn emit_definition_parser( } } #(#unordered_parse_validations)* - #(#arguments_parse)* + #arguments_parse #fn_call } }