Track spans alongside tokens

This commit is contained in:
Andrzej Janik 2024-10-31 14:55:14 +01:00
parent 3870a96592
commit 6f5d20af71

View file

@ -127,10 +127,11 @@ impl<'a, 'input> Debug for PtxParserState<'a, 'input> {
}
}
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>;
type PtxParser<'a, 'input> =
Stateful<&'a [(Token<'input>, logos::Span)], PtxParserState<'a, 'input>>;
fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
any.verify_map(|t| {
any.verify_map(|(t, _)| {
if let Token::Ident(text) = t {
Some(text)
} else if let Some(text) = t.opcode_text() {
@ -143,7 +144,7 @@ fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str>
}
fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
any.verify_map(|t| {
any.verify_map(|(t, _)| {
if let Token::DotIdent(text) = t {
Some(text)
} else {
@ -154,7 +155,7 @@ fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input
}
fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> {
any.verify_map(|t| {
any.verify_map(|(t, _)| {
Some(match t {
Token::Hex(s) => {
if s.ends_with('U') {
@ -218,7 +219,7 @@ fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast::
}
fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f32> {
take_error(any.verify_map(|t| match t {
take_error(any.verify_map(|(t, _)| match t {
Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) {
Ok(x) => Ok(f32::from_bits(x)),
Err(err) => Err((0.0, PtxError::from(err))),
@ -229,7 +230,7 @@ fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f32> {
}
fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f64> {
take_error(any.verify_map(|t| match t {
take_error(any.verify_map(|(t, _)| match t {
Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) {
Ok(x) => Ok(f64::from_bits(x)),
Err(err) => Err((0.0, PtxError::from(err))),
@ -283,7 +284,7 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as
pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'input>> {
let lexer = Token::lexer(text);
let input = lexer.collect::<Result<Vec<_>, _>>().ok()?;
let mut input = lex_with_span(text).ok()?;
let mut errors = Vec::new();
let state = PtxParserState::new(&mut errors);
let parser = PtxParser {
@ -310,7 +311,7 @@ pub fn parse_module_checked<'input>(
None => break,
};
match maybe_token {
Ok(token) => tokens.push(token),
Ok(token) => tokens.push((token, lexer.span())),
Err(mut err) => {
err.0 = lexer.span();
errors.push(PtxError::from(err))
@ -340,6 +341,17 @@ pub fn parse_module_checked<'input>(
}
}
/// Tokenizes `text`, pairing every token with the span the lexer reports for it.
///
/// Returns the `(token, span)` pairs in the order the lexer yields them, or
/// the first lexing error encountered.
fn lex_with_span<'input>(
    text: &'input str,
) -> Result<Vec<(Token<'input>, logos::Span)>, TokenError> {
    // Collecting into `Result` stops at the first `Err`, just like `?` in a loop.
    Token::lexer(text)
        .spanned()
        .map(|(maybe_token, span)| maybe_token.map(|token| (token, span)))
        .collect()
}
fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {
(
version,
@ -487,9 +499,9 @@ fn linking_directives<'a, 'input>(
repeat(
0..,
dispatch! { any;
Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN),
Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE),
Token::DotWeak => empty.value(ast::LinkingDirective::WEAK),
(Token::DotExtern, _) => empty.value(ast::LinkingDirective::EXTERN),
(Token::DotVisible, _) => empty.value(ast::LinkingDirective::VISIBLE),
(Token::DotWeak, _) => empty.value(ast::LinkingDirective::WEAK),
_ => fail
},
)
@ -501,10 +513,10 @@ fn tuning_directive<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<ast::TuningDirective> {
dispatch! {any;
Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg),
Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)),
Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)),
Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm),
(Token::DotMaxnreg, _) => u32.map(ast::TuningDirective::MaxNReg),
(Token::DotMaxntid, _) => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)),
(Token::DotReqntid, _) => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)),
(Token::DotMinnctapersm, _) => u32.map(ast::TuningDirective::MinNCtaPerSm),
_ => fail
}
.parse_next(stream)
@ -514,10 +526,10 @@ fn method_declaration<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<ast::MethodDeclaration<'input, &'input str>> {
dispatch! {any;
Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{
(Token::DotEntry, _) => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{
return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None
}),
Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| {
(Token::DotFunc, _) => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| {
let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
let name = ast::MethodName::Func(name);
ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
@ -557,8 +569,8 @@ fn kernel_input<'a, 'input>(
fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
dispatch! { any;
Token::DotParam => method_parameter(StateSpace::Param),
Token::DotReg => method_parameter(StateSpace::Reg),
(Token::DotParam, _) => method_parameter(StateSpace::Param),
(Token::DotReg, _) => method_parameter(StateSpace::Reg),
_ => fail
}
.parse_next(stream)
@ -606,8 +618,8 @@ fn function_body<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<Option<Vec<ast::Statement<ParsedOperandStr<'input>>>>> {
dispatch! {any;
Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some),
Token::Semicolon => empty.map(|_| None),
(Token::LBrace, _) => terminated(repeat_without_none(statement), Token::RBrace).map(Some),
(Token::Semicolon, _) => empty.map(|_| None),
_ => fail
}
.parse_next(stream)
@ -616,22 +628,56 @@ fn function_body<'a, 'input>(
fn statement<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<Option<Statement<ParsedOperandStr<'input>>>> {
alt((
label.map(Some),
debug_directive.map(|_| None),
terminated(
method_space
.flat_map(|space| multi_variable(false, space))
.map(|var| Some(Statement::Variable(var))),
Token::Semicolon,
),
predicated_instruction.map(Some),
pragma.map(|_| None),
block_statement.map(Some),
))
with_recovery(
alt((
label.map(Some),
debug_directive.map(|_| None),
terminated(
method_space
.flat_map(|space| multi_variable(false, space))
.map(|var| Some(Statement::Variable(var))),
Token::Semicolon,
),
predicated_instruction.map(Some),
pragma.map(|_| None),
block_statement.map(Some),
)),
fail,
)
.map(Option::flatten)
.parse_next(stream)
}
/// Wraps `parser` so that a failure is recorded in the parser state as an
/// unrecognized statement.
///
/// When `parser` succeeds, its value is returned wrapped in `Some`. When it
/// fails, `recovery` is run on the stream; if `recovery` fails too, its error
/// is returned and nothing is recorded. Otherwise a
/// `PtxError::UnrecognizedStatement` is pushed onto `stream.state.errors` and
/// the original error from `parser` is returned.
fn with_recovery<'a, 'input: 'a, T>(
    mut parser: impl Parser<PtxParser<'a, 'input>, T, ContextError>,
    mut recovery: impl Parser<PtxParser<'a, 'input>, (), ContextError>,
) -> impl Parser<PtxParser<'a, 'input>, Option<T>, ContextError> {
    move |stream: &mut PtxParser<'a, 'input>| {
        // Span of the first pending token before `parser` consumes anything.
        let span_before = stream.input.first().map(|(_, span)| span.clone());
        let failure = match parser.parse_next(stream) {
            Ok(value) => return Ok(Some(value)),
            Err(failure) => failure,
        };
        recovery.parse_next(stream)?;
        // Span of the first token still pending once `recovery` has run.
        let span_after = stream.input.first().map(|(_, span)| span.clone());
        // We could handle `(Some(start), None)`, but this whole error recovery is to
        // recover from unknown instructions, so we don't care about early end of stream
        let range = span_before
            .zip(span_after)
            .map(|(first, last)| first.start..last.end);
        stream
            .state
            .errors
            .push(PtxError::UnrecognizedStatement(range));
        // NOTE(review): the original error is still returned after a successful
        // recovery, so this wrapper never yields `Ok(None)` — confirm this is intended.
        Err(failure)
    }
}
fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
(Token::DotPragma, Token::String, Token::Semicolon)
.void()
@ -935,7 +981,7 @@ fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Opti
}
fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> {
any.verify_map(|t| {
any.verify_map(|(t, _)| {
Some(match t {
Token::DotS8 => ScalarType::S8,
Token::DotS16 => ScalarType::S16,
@ -1001,8 +1047,8 @@ fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()
ident_literal("function_name"),
ident,
dispatch! { any;
Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(),
Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(),
(Token::Comma, _) => (ident_literal("inlined_at"), u32, u32, u32).void(),
(Token::Plus, _) => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(),
_ => fail
},
)),
@ -1033,13 +1079,14 @@ fn repeat_without_none<Input: Stream, Output, Error: ParserError<Input>>(
fn ident_literal<
'a,
'input,
I: Stream<Token = Token<'input>> + StreamIsPartial,
X,
I: Stream<Token = (Token<'input>, X)> + StreamIsPartial,
E: ParserError<I>,
>(
s: &'input str,
) -> impl Parser<I, (), E> + 'input {
move |stream: &mut I| {
any.verify(|t| matches!(t, Token::Ident(text) if *text == s))
any.verify(|(t, _)| matches!(t, Token::Ident(text) if *text == s))
.void()
.parse_next(stream)
}
@ -1086,8 +1133,8 @@ impl<Ident> ast::ParsedOperand<Ident> {
let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?;
// TODO: parse .v8 literals
dispatch! {any;
Token::RBrace => empty.map(|_| vec![r1, r2]),
Token::Comma => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
(Token::RBrace, _) => empty.map(|_| vec![r1, r2]),
(Token::Comma, _) => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
_ => fail
}
.parse_next(stream)
@ -1146,8 +1193,8 @@ pub enum PtxError {
ArrayInitalizer,
#[error("")]
NonExternPointer,
#[error("{start}:{end}")]
UnrecognizedStatement { start: usize, end: usize },
#[error("")]
UnrecognizedStatement(Option<std::ops::Range<usize>>),
#[error("{start}:{end}")]
UnrecognizedDirective { start: usize, end: usize },
}
@ -1244,11 +1291,11 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> {
}
}
impl<'input, I: Stream<Token = Self> + StreamIsPartial, E: ParserError<I>> Parser<I, Self, E>
for Token<'input>
impl<'input, X, I: Stream<Token = (Self, X)> + StreamIsPartial, E: ParserError<I>>
Parser<I, (Self, X), E> for Token<'input>
{
fn parse_next(&mut self, input: &mut I) -> PResult<Self, E> {
any.verify(|t| t == self).parse_next(input)
fn parse_next(&mut self, input: &mut I) -> PResult<(Self, X), E> {
any.verify(|(t, _)| t == self).parse_next(input)
}
}
@ -1257,7 +1304,7 @@ fn bra<'a, 'input>(
) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> {
preceded(
opt(Token::DotUni),
any.verify_map(|t| match t {
any.verify_map(|(t, _)| match t {
Token::Ident(ident) => Some(ast::Instruction::Bra {
arguments: BraArgs { src: ident },
}),