diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index b49503b..5932748 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -127,10 +127,11 @@ impl<'a, 'input> Debug for PtxParserState<'a, 'input> { } } -type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>; +type PtxParser<'a, 'input> = + Stateful<&'a [(Token<'input>, logos::Span)], PtxParserState<'a, 'input>>; fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { - any.verify_map(|t| { + any.verify_map(|(t, _)| { if let Token::Ident(text) = t { Some(text) } else if let Some(text) = t.opcode_text() { @@ -143,7 +144,7 @@ fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> } fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { - any.verify_map(|t| { + any.verify_map(|(t, _)| { if let Token::DotIdent(text) = t { Some(text) } else { @@ -154,7 +155,7 @@ fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input } fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { - any.verify_map(|t| { + any.verify_map(|(t, _)| { Some(match t { Token::Hex(s) => { if s.ends_with('U') { @@ -218,7 +219,7 @@ fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult(stream: &mut PtxParser<'a, 'input>) -> PResult { - take_error(any.verify_map(|t| match t { + take_error(any.verify_map(|(t, _)| match t { Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) { Ok(x) => Ok(f32::from_bits(x)), Err(err) => Err((0.0, PtxError::from(err))), @@ -229,7 +230,7 @@ fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { } fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { - take_error(any.verify_map(|t| match t { + take_error(any.verify_map(|(t, _)| match t { Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) { Ok(x) => Ok(f64::from_bits(x)), Err(err) => Err((0.0, PtxError::from(err))), @@ -283,7 +284,7 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult(text: &'input str) -> Option> { let lexer = Token::lexer(text); - let input = lexer.collect::, _>>().ok()?; + let mut input = lex_with_span(text).ok()?; let mut errors = Vec::new(); let state = PtxParserState::new(&mut errors); let parser = PtxParser { @@ -310,7 +311,7 @@ pub fn parse_module_checked<'input>( None => break, }; match maybe_token { - Ok(token) => tokens.push(token), + Ok(token) => tokens.push((token, lexer.span())), Err(mut err) => { err.0 = lexer.span(); errors.push(PtxError::from(err)) @@ -340,6 +341,17 @@ pub fn parse_module_checked<'input>( } } +fn lex_with_span<'input>( + text: &'input str, +) -> Result, logos::Span)>, TokenError> { + let lexer = Token::lexer(text); + let mut result = Vec::new(); + for (token, span) in lexer.spanned() { + result.push((token?, span)); + } + Ok(result) +} + fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { ( version, @@ -487,9 +499,9 @@ fn linking_directives<'a, 'input>( repeat( 0.., dispatch! { any; - Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), - Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), - Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), + (Token::DotExtern, _) => empty.value(ast::LinkingDirective::EXTERN), + (Token::DotVisible, _) => empty.value(ast::LinkingDirective::VISIBLE), + (Token::DotWeak, _) => empty.value(ast::LinkingDirective::WEAK), _ => fail }, ) @@ -501,10 +513,10 @@ fn tuning_directive<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult { dispatch! {any; - Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg), - Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)), - Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)), - Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm), + (Token::DotMaxnreg, _) => u32.map(ast::TuningDirective::MaxNReg), + (Token::DotMaxntid, _) => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)), + (Token::DotReqntid, _) => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)), + (Token::DotMinnctapersm, _) => u32.map(ast::TuningDirective::MinNCtaPerSm), _ => fail } .parse_next(stream) @@ -514,10 +526,10 @@ fn method_declaration<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult> { dispatch! {any; - Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{ + (Token::DotEntry, _) => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{ return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None }), - Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| { + (Token::DotFunc, _) => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| { let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); let name = ast::MethodName::Func(name); ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } @@ -557,8 +569,8 @@ fn kernel_input<'a, 'input>( fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { dispatch! { any; - Token::DotParam => method_parameter(StateSpace::Param), - Token::DotReg => method_parameter(StateSpace::Reg), + (Token::DotParam, _) => method_parameter(StateSpace::Param), + (Token::DotReg, _) => method_parameter(StateSpace::Reg), _ => fail } .parse_next(stream) @@ -606,8 +618,8 @@ fn function_body<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult>>>> { dispatch! {any; - Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some), - Token::Semicolon => empty.map(|_| None), + (Token::LBrace, _) => terminated(repeat_without_none(statement), Token::RBrace).map(Some), + (Token::Semicolon, _) => empty.map(|_| None), _ => fail } .parse_next(stream) @@ -616,22 +628,56 @@ fn function_body<'a, 'input>( fn statement<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult>>> { - alt(( - label.map(Some), - debug_directive.map(|_| None), - terminated( - method_space - .flat_map(|space| multi_variable(false, space)) - .map(|var| Some(Statement::Variable(var))), - Token::Semicolon, - ), - predicated_instruction.map(Some), - pragma.map(|_| None), - block_statement.map(Some), - )) + with_recovery( + alt(( + label.map(Some), + debug_directive.map(|_| None), + terminated( + method_space + .flat_map(|space| multi_variable(false, space)) + .map(|var| Some(Statement::Variable(var))), + Token::Semicolon, + ), + predicated_instruction.map(Some), + pragma.map(|_| None), + block_statement.map(Some), + )), + fail, + ) + .map(Option::flatten) .parse_next(stream) } +fn with_recovery<'a, 'input: 'a, T>( + mut parser: impl Parser, T, ContextError>, + mut recovery: impl Parser, (), ContextError>, +) -> impl Parser, Option, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let input_start = stream.input.first().map(|(_, s)| s).cloned(); + match parser.parse_next(stream) { + Ok(value) => Ok(Some(value)), + Err(err) => { + recovery.parse_next(stream)?; + let input_end = stream.input.first().map(|(_, s)| s).cloned(); + let range = match (input_start, input_end) { + (Some(start), Some(end)) => Some(std::ops::Range { + start: start.start, + end: end.end, + }), + // We could handle `(Some(start), None)``, but this whole error recovery is to + // recover from unknown instructions, so we don't care about early end of stream + _ => None, + }; + stream + .state + .errors + .push(PtxError::UnrecognizedStatement(range)); + Err(err) + } + } + } +} + fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { (Token::DotPragma, Token::String, Token::Semicolon) .void() @@ -935,7 +981,7 @@ fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult(stream: &mut PtxParser<'a, 'input>) -> PResult { - any.verify_map(|t| { + any.verify_map(|(t, _)| { Some(match t { Token::DotS8 => ScalarType::S8, Token::DotS16 => ScalarType::S16, @@ -1001,8 +1047,8 @@ fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<() ident_literal("function_name"), ident, dispatch! { any; - Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(), - Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(), + (Token::Comma, _) => (ident_literal("inlined_at"), u32, u32, u32).void(), + (Token::Plus, _) => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(), _ => fail }, )), @@ -1033,13 +1079,14 @@ fn repeat_without_none>( fn ident_literal< 'a, 'input, - I: Stream> + StreamIsPartial, + X, + I: Stream, X)> + StreamIsPartial, E: ParserError, >( s: &'input str, ) -> impl Parser + 'input { move |stream: &mut I| { - any.verify(|t| matches!(t, Token::Ident(text) if *text == s)) + any.verify(|(t, _)| matches!(t, Token::Ident(text) if *text == s)) .void() .parse_next(stream) } @@ -1086,8 +1133,8 @@ impl ast::ParsedOperand { let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?; // TODO: parse .v8 literals dispatch! {any; - Token::RBrace => empty.map(|_| vec![r1, r2]), - Token::Comma => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + (Token::RBrace, _) => empty.map(|_| vec![r1, r2]), + (Token::Comma, _) => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), _ => fail } .parse_next(stream) @@ -1146,8 +1193,8 @@ pub enum PtxError { ArrayInitalizer, #[error("")] NonExternPointer, - #[error("{start}:{end}")] - UnrecognizedStatement { start: usize, end: usize }, + #[error("")] + UnrecognizedStatement(Option>), #[error("{start}:{end}")] UnrecognizedDirective { start: usize, end: usize }, } @@ -1244,11 +1291,11 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { } } -impl<'input, I: Stream + StreamIsPartial, E: ParserError> Parser - for Token<'input> +impl<'input, X, I: Stream + StreamIsPartial, E: ParserError> + Parser for Token<'input> { - fn parse_next(&mut self, input: &mut I) -> PResult { - any.verify(|t| t == self).parse_next(input) + fn parse_next(&mut self, input: &mut I) -> PResult<(Self, X), E> { + any.verify(|(t, _)| t == self).parse_next(input) } } @@ -1257,7 +1304,7 @@ fn bra<'a, 'input>( ) -> PResult>> { preceded( opt(Token::DotUni), - any.verify_map(|t| match t { + any.verify_map(|(t, _)| match t { Token::Ident(ident) => Some(ast::Instruction::Bra { arguments: BraArgs { src: ident }, }),