Track spans alongside tokens

Andrzej Janik 2024-10-31 14:55:14 +01:00
commit 6f5d20af71
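
In short: the lexer output changes from Vec<Token> to Vec<(Token, logos::Span)>, every token-level parser now destructures a (token, span) pair instead of a bare token, and the recorded spans feed the reworked UnrecognizedStatement(Option<Range<usize>>) error so a statement that fails to parse can be reported as a byte range in the PTX source. A minimal, standalone sketch of the span-tracking side (illustrative only, not ZLUDA code; assumes logos 0.13+ where the lexer yields Result<Token, _>):

use logos::Logos;

// Toy token type, made up for illustration.
#[derive(Logos, Debug)]
#[logos(skip r"[ \t\r\n]+")]
enum Tok {
    #[regex("[a-zA-Z_][a-zA-Z0-9_$]*")]
    Ident,
    #[token(";")]
    Semicolon,
}

fn main() {
    // `spanned()` pairs each token with the `start..end` byte range it was
    // lexed from; collecting these pairs is what the parser's input becomes.
    let pairs: Vec<(Result<Tok, ()>, logos::Span)> =
        Tok::lexer("bra label0 ;").spanned().collect();
    for (tok, span) in &pairs {
        println!("{:?} at {:?}", tok, span);
    }
}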


@@ -127,10 +127,11 @@ impl<'a, 'input> Debug for PtxParserState<'a, 'input> {
     }
 }
-type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>;
+type PtxParser<'a, 'input> =
+    Stateful<&'a [(Token<'input>, logos::Span)], PtxParserState<'a, 'input>>;
 fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
-    any.verify_map(|t| {
+    any.verify_map(|(t, _)| {
         if let Token::Ident(text) = t {
             Some(text)
         } else if let Some(text) = t.opcode_text() {
@@ -143,7 +144,7 @@ fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str>
 }
 fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
-    any.verify_map(|t| {
+    any.verify_map(|(t, _)| {
         if let Token::DotIdent(text) = t {
             Some(text)
         } else {
@@ -154,7 +155,7 @@ fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input
 }
 fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> {
-    any.verify_map(|t| {
+    any.verify_map(|(t, _)| {
         Some(match t {
             Token::Hex(s) => {
                 if s.ends_with('U') {
@@ -218,7 +219,7 @@ fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast::
 }
 fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f32> {
-    take_error(any.verify_map(|t| match t {
+    take_error(any.verify_map(|(t, _)| match t {
         Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) {
             Ok(x) => Ok(f32::from_bits(x)),
             Err(err) => Err((0.0, PtxError::from(err))),
@@ -229,7 +230,7 @@ fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f32> {
 }
 fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f64> {
-    take_error(any.verify_map(|t| match t {
+    take_error(any.verify_map(|(t, _)| match t {
         Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) {
             Ok(x) => Ok(f64::from_bits(x)),
             Err(err) => Err((0.0, PtxError::from(err))),
@@ -283,7 +284,7 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as
 pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'input>> {
     let lexer = Token::lexer(text);
-    let input = lexer.collect::<Result<Vec<_>, _>>().ok()?;
+    let mut input = lex_with_span(text).ok()?;
     let mut errors = Vec::new();
     let state = PtxParserState::new(&mut errors);
     let parser = PtxParser {
@@ -310,7 +311,7 @@ pub fn parse_module_checked<'input>(
             None => break,
         };
         match maybe_token {
-            Ok(token) => tokens.push(token),
+            Ok(token) => tokens.push((token, lexer.span())),
             Err(mut err) => {
                 err.0 = lexer.span();
                 errors.push(PtxError::from(err))
@@ -340,6 +341,17 @@ pub fn parse_module_checked<'input>(
     }
 }
+fn lex_with_span<'input>(
+    text: &'input str,
+) -> Result<Vec<(Token<'input>, logos::Span)>, TokenError> {
+    let lexer = Token::lexer(text);
+    let mut result = Vec::new();
+    for (token, span) in lexer.spanned() {
+        result.push((token?, span));
+    }
+    Ok(result)
+}
 fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {
     (
         version,
@@ -487,9 +499,9 @@ fn linking_directives<'a, 'input>(
     repeat(
         0..,
         dispatch! { any;
-            Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN),
-            Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE),
-            Token::DotWeak => empty.value(ast::LinkingDirective::WEAK),
+            (Token::DotExtern, _) => empty.value(ast::LinkingDirective::EXTERN),
+            (Token::DotVisible, _) => empty.value(ast::LinkingDirective::VISIBLE),
+            (Token::DotWeak, _) => empty.value(ast::LinkingDirective::WEAK),
             _ => fail
         },
     )
@@ -501,10 +513,10 @@ fn tuning_directive<'a, 'input>(
     stream: &mut PtxParser<'a, 'input>,
 ) -> PResult<ast::TuningDirective> {
     dispatch! {any;
-        Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg),
-        Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)),
-        Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)),
-        Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm),
+        (Token::DotMaxnreg, _) => u32.map(ast::TuningDirective::MaxNReg),
+        (Token::DotMaxntid, _) => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)),
+        (Token::DotReqntid, _) => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)),
+        (Token::DotMinnctapersm, _) => u32.map(ast::TuningDirective::MinNCtaPerSm),
         _ => fail
     }
     .parse_next(stream)
@@ -514,10 +526,10 @@ fn method_declaration<'a, 'input>(
     stream: &mut PtxParser<'a, 'input>,
 ) -> PResult<ast::MethodDeclaration<'input, &'input str>> {
     dispatch! {any;
-        Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{
+        (Token::DotEntry, _) => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{
             return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None
         }),
-        Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| {
+        (Token::DotFunc, _) => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| {
             let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
             let name = ast::MethodName::Func(name);
             ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
@@ -557,8 +569,8 @@ fn kernel_input<'a, 'input>(
 fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
     dispatch! { any;
-        Token::DotParam => method_parameter(StateSpace::Param),
-        Token::DotReg => method_parameter(StateSpace::Reg),
+        (Token::DotParam, _) => method_parameter(StateSpace::Param),
+        (Token::DotReg, _) => method_parameter(StateSpace::Reg),
         _ => fail
     }
     .parse_next(stream)
@@ -606,8 +618,8 @@ fn function_body<'a, 'input>(
     stream: &mut PtxParser<'a, 'input>,
 ) -> PResult<Option<Vec<ast::Statement<ParsedOperandStr<'input>>>>> {
     dispatch! {any;
-        Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some),
-        Token::Semicolon => empty.map(|_| None),
+        (Token::LBrace, _) => terminated(repeat_without_none(statement), Token::RBrace).map(Some),
+        (Token::Semicolon, _) => empty.map(|_| None),
         _ => fail
     }
     .parse_next(stream)
@@ -616,6 +628,7 @@ fn function_body<'a, 'input>(
 fn statement<'a, 'input>(
     stream: &mut PtxParser<'a, 'input>,
 ) -> PResult<Option<Statement<ParsedOperandStr<'input>>>> {
+    with_recovery(
     alt((
         label.map(Some),
         debug_directive.map(|_| None),
@@ -628,10 +641,43 @@ fn statement<'a, 'input>(
         predicated_instruction.map(Some),
         pragma.map(|_| None),
         block_statement.map(Some),
-    ))
+    )),
+        fail,
+    )
+    .map(Option::flatten)
     .parse_next(stream)
 }
+fn with_recovery<'a, 'input: 'a, T>(
+    mut parser: impl Parser<PtxParser<'a, 'input>, T, ContextError>,
+    mut recovery: impl Parser<PtxParser<'a, 'input>, (), ContextError>,
+) -> impl Parser<PtxParser<'a, 'input>, Option<T>, ContextError> {
+    move |stream: &mut PtxParser<'a, 'input>| {
+        let input_start = stream.input.first().map(|(_, s)| s).cloned();
+        match parser.parse_next(stream) {
+            Ok(value) => Ok(Some(value)),
+            Err(err) => {
+                recovery.parse_next(stream)?;
+                let input_end = stream.input.first().map(|(_, s)| s).cloned();
+                let range = match (input_start, input_end) {
+                    (Some(start), Some(end)) => Some(std::ops::Range {
+                        start: start.start,
+                        end: end.end,
+                    }),
+                    // We could handle `(Some(start), None)`, but this whole error recovery is to
+                    // recover from unknown instructions, so we don't care about early end of stream
+                    _ => None,
+                };
+                stream
+                    .state
+                    .errors
+                    .push(PtxError::UnrecognizedStatement(range));
+                Err(err)
+            }
+        }
+    }
+}
 fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
     (Token::DotPragma, Token::String, Token::Semicolon)
         .void()
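
A quick note on the span bookkeeping in with_recovery above: the reported range simply fuses the start of the first token of the failed statement with the end of the token at which recovery left the stream, and that Option<Range<usize>> is what the reworked UnrecognizedStatement variant carries. A self-contained sketch of just that arithmetic (plain std; the helper name is made up for illustration):

use std::ops::Range;

// Hypothetical helper mirroring the match in `with_recovery`: merge the span
// of the first skipped token with the span of the token where recovery
// stopped, or report nothing if the stream ended early.
fn unrecognized_range(
    start: Option<Range<usize>>,
    end: Option<Range<usize>>,
) -> Option<Range<usize>> {
    match (start, end) {
        (Some(start), Some(end)) => Some(start.start..end.end),
        _ => None,
    }
}

fn main() {
    // A statement whose first token spans 27..45 and whose parse was abandoned
    // at a token spanning 46..50 is reported as the range 27..50.
    assert_eq!(unrecognized_range(Some(27..45), Some(46..50)), Some(27..50));
    // Early end of stream: no range is reported.
    assert_eq!(unrecognized_range(Some(27..45), None), None);
}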
@@ -935,7 +981,7 @@ fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Opti
 }
 fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> {
-    any.verify_map(|t| {
+    any.verify_map(|(t, _)| {
         Some(match t {
             Token::DotS8 => ScalarType::S8,
             Token::DotS16 => ScalarType::S16,
@@ -1001,8 +1047,8 @@ fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()
             ident_literal("function_name"),
             ident,
             dispatch! { any;
-                Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(),
-                Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(),
+                (Token::Comma, _) => (ident_literal("inlined_at"), u32, u32, u32).void(),
+                (Token::Plus, _) => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(),
                 _ => fail
             },
         )),
@@ -1033,13 +1079,14 @@ fn repeat_without_none<Input: Stream, Output, Error: ParserError<Input>>(
 fn ident_literal<
     'a,
     'input,
-    I: Stream<Token = Token<'input>> + StreamIsPartial,
+    X,
+    I: Stream<Token = (Token<'input>, X)> + StreamIsPartial,
     E: ParserError<I>,
 >(
     s: &'input str,
 ) -> impl Parser<I, (), E> + 'input {
     move |stream: &mut I| {
-        any.verify(|t| matches!(t, Token::Ident(text) if *text == s))
+        any.verify(|(t, _)| matches!(t, Token::Ident(text) if *text == s))
             .void()
             .parse_next(stream)
     }
@@ -1086,8 +1133,8 @@ impl<Ident> ast::ParsedOperand<Ident> {
         let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?;
         // TODO: parse .v8 literals
         dispatch! {any;
-            Token::RBrace => empty.map(|_| vec![r1, r2]),
-            Token::Comma => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
+            (Token::RBrace, _) => empty.map(|_| vec![r1, r2]),
+            (Token::Comma, _) => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
            _ => fail
        }
        .parse_next(stream)
@@ -1146,8 +1193,8 @@ pub enum PtxError {
     ArrayInitalizer,
     #[error("")]
     NonExternPointer,
-    #[error("{start}:{end}")]
-    UnrecognizedStatement { start: usize, end: usize },
+    #[error("")]
+    UnrecognizedStatement(Option<std::ops::Range<usize>>),
     #[error("{start}:{end}")]
     UnrecognizedDirective { start: usize, end: usize },
 }
@@ -1244,11 +1291,11 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> {
     }
 }
-impl<'input, I: Stream<Token = Self> + StreamIsPartial, E: ParserError<I>> Parser<I, Self, E>
-    for Token<'input>
+impl<'input, X, I: Stream<Token = (Self, X)> + StreamIsPartial, E: ParserError<I>>
+    Parser<I, (Self, X), E> for Token<'input>
 {
-    fn parse_next(&mut self, input: &mut I) -> PResult<Self, E> {
-        any.verify(|t| t == self).parse_next(input)
+    fn parse_next(&mut self, input: &mut I) -> PResult<(Self, X), E> {
+        any.verify(|(t, _)| t == self).parse_next(input)
     }
 }
@@ -1257,7 +1304,7 @@ fn bra<'a, 'input>(
 ) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> {
     preceded(
         opt(Token::DotUni),
-        any.verify_map(|t| match t {
+        any.verify_map(|(t, _)| match t {
             Token::Ident(ident) => Some(ast::Instruction::Bra {
                 arguments: BraArgs { src: ident },
             }),