Count invalid directives instead of returning semi-broken ones

This commit is contained in:
Andrzej Janik 2025-09-04 00:49:55 +00:00
commit 5f9e676e15
5 changed files with 74 additions and 17 deletions

10
Cargo.lock generated
View file

@ -138,12 +138,6 @@ dependencies = [
"syn 2.0.89", "syn 2.0.89",
] ]
[[package]]
name = "bit-vec"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]] [[package]]
name = "bit-vec" name = "bit-vec"
version = "0.8.0" version = "0.8.0"
@ -462,7 +456,7 @@ dependencies = [
name = "dark_api" name = "dark_api"
version = "0.0.0" version = "0.0.0"
dependencies = [ dependencies = [
"bit-vec 0.8.0", "bit-vec",
"cglue", "cglue",
"cuda_types", "cuda_types",
"format", "format",
@ -2586,7 +2580,7 @@ dependencies = [
name = "ptx" name = "ptx"
version = "0.0.0" version = "0.0.0"
dependencies = [ dependencies = [
"bit-vec 0.6.3", "bit-vec",
"bitflags 1.3.2", "bitflags 1.3.2",
"comgr", "comgr",
"cuda_macros", "cuda_macros",

View file

@ -11,7 +11,7 @@ ptx_parser = { path = "../ptx_parser" }
llvm_zluda = { path = "../llvm_zluda" } llvm_zluda = { path = "../llvm_zluda" }
quick-error = "1.2" quick-error = "1.2"
thiserror = "1.0" thiserror = "1.0"
bit-vec = "0.6" bit-vec = "0.8"
half ="1.6" half ="1.6"
bitflags = "1.2" bitflags = "1.2"
rustc-hash = "2.0.0" rustc-hash = "2.0.0"

View file

@ -1492,6 +1492,7 @@ pub enum Directive<'input, O: Operand> {
pub struct Module<'input> { pub struct Module<'input> {
pub version: (u8, u8), pub version: (u8, u8),
pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>, pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
pub invalid_directives: usize,
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]

View file

@ -417,13 +417,16 @@ fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module
version, version,
target, target,
opt(address_size), opt(address_size),
repeat_without_none(directive), repeat_without_none_and_count(directive),
eof, eof,
) )
.map(|(version, _, _, directives, _)| ast::Module { .map(
|(version, _, _, (directives, invalid_directives), _)| ast::Module {
version, version,
directives, directives,
}), invalid_directives,
},
),
) )
.parse_next(stream) .parse_next(stream)
} }
@ -458,7 +461,8 @@ fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option<char>)> {
fn directive<'a, 'input>( fn directive<'a, 'input>(
stream: &mut PtxParser<'a, 'input>, stream: &mut PtxParser<'a, 'input>,
) -> PResult<Option<ast::Directive<'input, ast::ParsedOperand<&'input str>>>> { ) -> PResult<Option<ast::Directive<'input, ast::ParsedOperand<&'input str>>>> {
trace( let errors = stream.state.errors.len();
let directive = trace(
"directive", "directive",
with_recovery( with_recovery(
alt(( alt((
@ -488,7 +492,11 @@ fn directive<'a, 'input>(
) )
.map(Option::flatten), .map(Option::flatten),
) )
.parse_next(stream) .parse_next(stream)?;
if errors != stream.state.errors.len() {
return Ok(None);
}
Ok(directive)
} }
fn module_variable<'a, 'input>( fn module_variable<'a, 'input>(
@ -1266,6 +1274,25 @@ fn repeat_without_none<Input: Stream, Output, Error: ParserError<Input>>(
) )
} }
fn repeat_without_none_and_count<Input: Stream, Output, Error: ParserError<Input>>(
parser: impl Parser<Input, Option<Output>, Error>,
) -> impl Parser<Input, (Vec<Output>, usize), Error> {
trace(
"repeat_without_none_and_count",
repeat(0.., parser).fold(
|| (Vec::new(), 0),
|(mut accumulator, mut nones): (Vec<_>, usize), item| {
if let Some(item) = item {
accumulator.push(item);
} else {
nones += 1;
}
(accumulator, nones)
},
),
)
}
fn ident_literal< fn ident_literal<
'a, 'a,
'input, 'input,
@ -3803,6 +3830,7 @@ derive_parser!(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::first_optional; use crate::first_optional;
use crate::module;
use crate::parse_module_checked; use crate::parse_module_checked;
use crate::section; use crate::section;
use crate::PtxError; use crate::PtxError;
@ -4100,4 +4128,38 @@ mod tests {
assert!(section.parse(stream).is_ok()); assert!(section.parse(stream).is_ok());
assert_eq!(errors.len(), 0); assert_eq!(errors.len(), 0);
} }
#[test]
fn report_unknown_directives() {
let text = "
.version 6.5
.target sm_30
.address_size 64
.global .b32 global[4] = { unknown (1), 2, 3, 4};
.visible .entry func1()
{
st.u64 [out_addr], temp2;
ret;
}
.visible .entry func1()
{
broken_instruction;
ret;
}";
let tokens = Token::lexer(text)
.map(|t| t.map(|t| (t, Span::default())))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let mut errors = Vec::new();
let stream = super::PtxParser {
input: &tokens[..],
state: PtxParserState::new(text, &mut errors),
};
let module = module.parse(stream).unwrap();
assert_eq!(module.directives.len(), 1);
assert_eq!(module.invalid_directives, 2);
}
} }