From a420601128ec04d9ba4d2e30fddf78d300070e5a Mon Sep 17 00:00:00 2001 From: Violet Date: Wed, 13 Aug 2025 17:23:51 -0700 Subject: [PATCH] Add test for unrecognized statement error with vector braces (#472) The old code using `take_till_inclusive` assumed that a right brace would be the end of a block and therefore never part of a statement. However, some PTX statements can include vector operands. This meant that any unrecognized statement with a vector operand would backtrace and eventually produce an unhelpful context error rather than an `UnrecognizedStatement` error. This pull request also adds a mechanism for testing parser errors. --- ptx/src/test/mod.rs | 21 +++++++ ptx/src/test/spirv_run/mod.rs | 25 +------- ptx_parser/src/lib.rs | 115 ++++++++++++++++------------------ 3 files changed, 77 insertions(+), 84 deletions(-) diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index d81fdf8..2de24b7 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -3,6 +3,27 @@ use ptx_parser as ast; mod spirv_run; +#[macro_export] +macro_rules! read_test_file { + ($file:expr) => { + { + if cfg!(feature = "ci_build") { + include_str!($file).to_string() + } else { + use std::path::PathBuf; + // CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx). + let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + path.pop(); + path.push(file!()); + path.pop(); + path.push($file); + std::fs::read_to_string(path).unwrap() + } + } + }; +} +pub(crate) use read_test_file; + fn parse_and_assert(ptx_text: &str) { ast::parse_module_checked(ptx_text).unwrap(); } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 4eb4c99..e508a5c 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,3 +1,4 @@ +use super::read_test_file; use crate::pass; use comgr::Comgr; use cuda_types::cuda::CUstream; @@ -10,32 +11,10 @@ use std::fmt::{self, Debug, Display, Formatter}; use std::fs::{self, File}; use std::io::Write; use std::mem; -use std::path::{Path, PathBuf}; +use std::path::Path; use std::ptr; use std::str; -#[cfg(not(feature = "ci_build"))] -macro_rules! read_test_file { - ($file:expr) => { - { - // CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx). - let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - path.pop(); - path.push(file!()); - path.pop(); - path.push($file); - std::fs::read_to_string(path).unwrap() - } - }; -} - -#[cfg(feature = "ci_build")] -macro_rules! read_test_file { - ($file:expr) => { - include_str!($file).to_string() - }; -} - macro_rules! test_ptx_llvm { ($fn_name:ident) => { paste::item! { diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index d788604..e545f49 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -732,79 +732,52 @@ fn statement<'a, 'input>( pragma.map(|_| None), block_statement.map(Some), )), - take_till_inclusive( - |(t, _)| *t == Token::RBrace, - |(t, _)| match t { - Token::Semicolon | Token::Colon => true, - _ => false, - }, - ), + take_till_end_of_statement(), |text| PtxError::UnrecognizedStatement(text.unwrap_or("")), ) .map(Option::flatten) .parse_next(stream) } -fn take_till_inclusive>( - backtrack_token: impl Fn(&I::Token) -> bool, - end_token: impl Fn(&I::Token) -> bool, -) -> impl Parser::Slice, E> { - fn get_offset>( - input: &mut I, - backtrack_token: &impl Fn(&I::Token) -> bool, - end_token: &impl Fn(&I::Token) -> bool, - should_backtrack: &mut bool, - ) -> Result { - *should_backtrack = false; - let mut hit = false; - for (offset, token) in input.iter_offsets() { - if hit { - return Ok(offset); - } else { - if backtrack_token(&token) { - *should_backtrack = true; - return Ok(offset); +fn take_till_end_of_statement< + 'a, + I: Stream, std::ops::Range)>, + E: ParserError, +>() -> impl Parser::Slice, E> { + trace("take_till_end_of_statement", move |stream: &mut I| { + let mut depth = 0; + + let mut iterator = stream.iter_offsets().peekable(); + while let Some((current_offset, (token, _))) = iterator.next() { + match token { + Token::LBrace => { + depth += 1; } - if end_token(&token) { - hit = true; + Token::RBrace => { + if depth == 0 { + return Err(ErrMode::from_error_kind( + stream, + winnow::error::ErrorKind::Token, + )); + } + depth -= 1; } + Token::Semicolon | Token::Colon => { + let offset = if let Some((next_offset, _)) = iterator.peek() { + *next_offset + } else { + current_offset + }; + return Ok(stream.next_slice(offset)); + } + _ => {} } } - Err(ParserError::from_error_kind(input, ErrorKind::Eof)) - } - trace("take_till_inclusive", move |stream: &mut I| { - let mut should_backtrack = false; - let offset = get_offset(stream, &backtrack_token, &end_token, &mut should_backtrack)?; - let result = stream.next_slice(offset); - if should_backtrack { - Err(ErrMode::from_error_kind( - stream, - winnow::error::ErrorKind::Token, - )) - } else { - Ok(result) - } + + Err(ParserError::from_error_kind(stream, ErrorKind::Eof)) }) } -/* -pub fn take_till_or_backtrack_eof( - set: Set, -) -> impl Parser::Slice, Error> -where - Input: StreamIsPartial + Stream, - Set: winnow::stream::ContainsToken<::Token>, - Error: ParserError, -{ - move |stream: &mut Input| { - if stream.eof_offset() == 0 { - return ; - } - take_till(0.., set) - } -} - */ - fn with_recovery<'a, 'input: 'a, T>( mut parser: impl Parser, T, ContextError>, recovery: impl Parser, &'a [(Token<'input>, logos::Span)], ContextError>, @@ -1344,7 +1317,7 @@ impl ast::ParsedOperand { } } -#[derive(Debug, thiserror::Error, strum::AsRefStr)] +#[derive(Debug, thiserror::Error, PartialEq, strum::AsRefStr)] pub enum PtxError<'input> { #[error("{source}")] ParseInt { @@ -3867,6 +3840,26 @@ mod tests { )); } + #[test] + fn report_unknown_instruction_with_braces() { + let text = " + .version 6.5 + .target sm_60 + .address_size 64 + + .visible .entry unrecognized_braces( + ) + { + mov.u32 foo, {} {}; + ret; + }"; + let errors = parse_module_checked(text).err().unwrap(); + assert_eq!( + errors, + vec![PtxError::UnrecognizedStatement("mov.u32 foo, {} {};")] + ); + } + #[test] fn report_unknown_directive() { let text = "