Change the parser, attempt #1

This commit is contained in:
Andrzej Janik 2024-08-22 04:24:07 +02:00
parent 71e025845c
commit 90e65a46f1
9 changed files with 792 additions and 4141 deletions

View file

@ -2,13 +2,12 @@
name = "ptx"
version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2018"
edition = "2021"
[lib]
[dependencies]
lalrpop-util = "0.19"
regex = "1"
ptx_parser = { path = "../ptx_parser" }
rspirv = "0.7"
spirv_headers = "1.5"
quick-error = "1.2"
@ -17,10 +16,6 @@ bit-vec = "0.6"
half ="1.6"
bitflags = "1.2"
[build-dependencies.lalrpop]
version = "0.19"
features = ["lexer"]
[dev-dependencies]
hip_runtime-sys = { path = "../hip_runtime-sys" }
tempfile = "3"

View file

@ -1,5 +0,0 @@
extern crate lalrpop;
fn main() {
lalrpop::process_root().unwrap();
}

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,7 @@
#[cfg(test)]
extern crate paste;
#[macro_use]
extern crate lalrpop_util;
#[macro_use]
extern crate quick_error;
extern crate bit_vec;
extern crate half;
#[cfg(test)]
@ -18,168 +15,12 @@ extern crate spirv_tools_sys as spirv_tools;
#[macro_use]
extern crate bitflags;
lalrpop_mod!(
#[allow(warnings)]
ptx
);
pub mod ast;
#[cfg(test)]
mod test;
mod translate;
use std::fmt;
pub use crate::ptx::ModuleParser;
use ast::PtxError;
pub use lalrpop_util::lexer::Token;
pub use lalrpop_util::ParseError;
pub use rspirv::dr::Error as SpirvError;
pub use translate::to_spirv_module;
pub use translate::KernelInfo;
pub use translate::TranslateError;
pub trait ModuleParserExt {
fn parse_checked<'input>(
txt: &'input str,
) -> Result<ast::Module<'input>, Vec<ParseError<usize, Token<'input>, ast::PtxError>>>;
// Returned AST might be malformed. Some users, like logger, want to look at
// malformed AST to record information - list of kernels or such
fn parse_unchecked<'input>(
txt: &'input str,
) -> (
ast::Module<'input>,
Vec<ParseError<usize, Token<'input>, ast::PtxError>>,
);
}
impl ModuleParserExt for ModuleParser {
fn parse_checked<'input>(
txt: &'input str,
) -> Result<ast::Module<'input>, Vec<ParseError<usize, Token<'input>, ast::PtxError>>> {
let mut errors = Vec::new();
let maybe_ast = ptx::ModuleParser::new().parse(&mut errors, txt);
match (&*errors, maybe_ast) {
(&[], Ok(ast)) => Ok(ast),
(_, Err(unrecoverable)) => {
errors.push(unrecoverable);
Err(errors)
}
(_, Ok(_)) => Err(errors),
}
}
fn parse_unchecked<'input>(
txt: &'input str,
) -> (
ast::Module<'input>,
Vec<ParseError<usize, Token<'input>, ast::PtxError>>,
) {
let mut errors = Vec::new();
let maybe_ast = ptx::ModuleParser::new().parse(&mut errors, txt);
let ast = match maybe_ast {
Ok(ast) => ast,
Err(unrecoverable_err) => {
errors.push(unrecoverable_err);
ast::Module {
version: (0, 0),
directives: Vec::new(),
}
}
};
(ast, errors)
}
}
pub struct DisplayParseError<'a, Loc, Tok, Err>(&'a str, &'a ParseError<Loc, Tok, Err>);
impl<'a, Loc: fmt::Display + Into<usize> + Copy, Tok, Err> DisplayParseError<'a, Loc, Tok, Err> {
// unsafe because there's no guarantee that the input str is the one that this error was created from
pub unsafe fn new(error: &'a ParseError<Loc, Tok, Err>, text: &'a str) -> Self {
Self(text, error)
}
}
impl<'a, Loc, Tok> fmt::Display for DisplayParseError<'a, Loc, Tok, PtxError>
where
Loc: fmt::Display,
Tok: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.1 {
ParseError::User {
error: PtxError::UnrecognizedStatement { start, end },
} => self.fmt_unrecognized(f, *start, *end, "statement"),
ParseError::User {
error: PtxError::UnrecognizedDirective { start, end },
} => self.fmt_unrecognized(f, *start, *end, "directive"),
_ => self.1.fmt(f),
}
}
}
impl<'a, Loc, Tok, Err> DisplayParseError<'a, Loc, Tok, Err> {
fn fmt_unrecognized(
&self,
f: &mut fmt::Formatter,
start: usize,
end: usize,
kind: &'static str,
) -> fmt::Result {
let full_substring = unsafe { self.0.get_unchecked(start..end) };
write!(
f,
"Unrecognized {} `{}` found at {}:{}",
kind, full_substring, start, end
)
}
}
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
x.into_iter().filter_map(|x| x).collect()
}
pub(crate) fn vector_index<'input>(
inp: &'input str,
) -> Result<u8, ParseError<usize, lalrpop_util::lexer::Token<'input>, ast::PtxError>> {
match inp {
"x" | "r" => Ok(0),
"y" | "g" => Ok(1),
"z" | "b" => Ok(2),
"w" | "a" => Ok(3),
_ => Err(ParseError::User {
error: ast::PtxError::WrongVectorElement,
}),
}
}
#[cfg(test)]
mod tests {
use crate::{DisplayParseError, ModuleParser, ModuleParserExt};
#[test]
fn error_report_unknown_instructions() {
let module = r#"
.version 6.5
.target sm_30
.address_size 64
.visible .entry add(
.param .u64 input,
)
{
.reg .u64 x;
does_not_exist.u64 x, x;
ret;
}"#;
let errors = match ModuleParser::parse_checked(module) {
Err(e) => e,
Ok(_) => panic!(),
};
assert_eq!(errors.len(), 1);
let reporter = DisplayParseError(module, &errors[0]);
let build_log_string = format!("{}", reporter);
assert!(build_log_string.contains("does_not_exist"));
}
}
use ptx_parser as ast;

View file

@ -1,4 +1,3 @@
use super::ptx;
use super::TranslateError;
mod spirv_run;

File diff suppressed because it is too large Load diff

View file

@ -4,6 +4,8 @@ version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2021"
[lib]
[dependencies]
logos = "0.14"
winnow = { version = "0.6.18" }
@ -11,3 +13,4 @@ ptx_parser_macros = { path = "../ptx_parser_macros" }
thiserror = "1.0"
bitflags = "1.2"
rustc-hash = "2.0.0"
derive_more = { version = "1", features = ["display"] }

View file

@ -717,7 +717,7 @@ impl<Ident: Copy> Operand for ParsedOperand<Ident> {
type Ident = Ident;
}
pub trait Operand {
pub trait Operand: Sized {
type Ident: Copy;
}

View file

@ -1,8 +1,8 @@
use derive_more::Display;
use logos::Logos;
use ptx_parser_macros::derive_parser;
use rustc_hash::FxHashMap;
use std::fmt::Debug;
use std::mem;
use std::num::{ParseFloatError, ParseIntError};
use winnow::ascii::dec_uint;
use winnow::combinator::*;
@ -81,16 +81,16 @@ impl VectorPrefix {
}
}
struct PtxParserState<'input> {
errors: Vec<PtxError>,
struct PtxParserState<'a, 'input> {
errors: &'a mut Vec<PtxError>,
function_declarations:
FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>,
}
impl<'input> PtxParserState<'input> {
fn new() -> Self {
impl<'a, 'input> PtxParserState<'a, 'input> {
fn new(errors: &'a mut Vec<PtxError>) -> Self {
Self {
errors: Vec::new(),
errors,
function_declarations: FxHashMap::default(),
}
}
@ -115,7 +115,7 @@ impl<'input> PtxParserState<'input> {
}
}
impl<'input> Debug for PtxParserState<'input> {
impl<'a, 'input> Debug for PtxParserState<'a, 'input> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PtxParserState")
.field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */
@ -123,7 +123,7 @@ impl<'input> Debug for PtxParserState<'input> {
}
}
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>;
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>;
fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
any.verify_map(|t| {
@ -277,6 +277,18 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as
.parse_next(stream)
}
pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'input>> {
let lexer = Token::lexer(text);
let input = lexer.collect::<Result<Vec<_>, _>>().ok()?;
let mut errors = Vec::new();
let state = PtxParserState::new(&mut errors);
let parser = PtxParser {
state,
input: &input[..],
};
module.parse(parser).ok()
}
fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {
(
version,
@ -818,6 +830,8 @@ pub enum PtxError {
source: ParseFloatError,
},
#[error("")]
Lexer(#[from] TokenError),
#[error("")]
Todo,
#[error("")]
SyntaxError,
@ -1042,9 +1056,15 @@ fn empty_call<'input>(
type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>;
#[derive(Clone, PartialEq, Default, Debug, Display)]
pub struct TokenError;
impl std::error::Error for TokenError {}
derive_parser!(
#[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)]
#[logos(skip r"\s+")]
#[logos(error = TokenError)]
enum Token<'input> {
#[token(",")]
Comma,
@ -2825,57 +2845,6 @@ derive_parser!(
);
fn main() {
use winnow::Parser;
let lexer = Token::lexer(
"
.version 6.5
.target sm_30
.address_size 64
.const .align 8 .b32 constparams;
.visible .entry const(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b16 temp1;
.reg .b16 temp2;
.reg .b16 temp3;
.reg .b16 temp4;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.const.b16 temp1, [constparams];
ld.const.b16 temp2, [constparams+2];
ld.const.b16 temp3, [constparams+4];
ld.const.b16 temp4, [constparams+6];
st.u16 [out_addr], temp1;
st.u16 [out_addr+2], temp2;
st.u16 [out_addr+4], temp3;
st.u16 [out_addr+6], temp4;
ret;
}
",
);
let tokens = lexer.clone().collect::<Vec<_>>();
println!("{:?}", &tokens);
let tokens = lexer.map(|t| t.unwrap()).collect::<Vec<_>>();
println!("{:?}", &tokens);
let stream = PtxParser {
input: &tokens[..],
state: PtxParserState::new(),
};
let _module = module.parse(stream).unwrap();
println!("{}", mem::size_of::<Token>());
}
#[cfg(test)]
mod tests {
use super::target;