diff --git a/llvm_zluda/build.rs b/llvm_zluda/build.rs index 9660b4d..ab66e1b 100644 --- a/llvm_zluda/build.rs +++ b/llvm_zluda/build.rs @@ -16,6 +16,8 @@ fn main() { let mut cmake = Config::new(r"../ext/llvm-project/llvm"); try_use_ninja(&mut cmake); cmake + // It's not like we can do anything about the warnings + .define("LLVM_ENABLE_WARNINGS", "OFF") // For some reason Rust always links to release MSVCRT .define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreadedDLL") .define("LLVM_ENABLE_TERMINFO", "OFF") diff --git a/llvm_zluda/src/lib.cpp b/llvm_zluda/src/lib.cpp index 886aa0d..072f773 100644 --- a/llvm_zluda/src/lib.cpp +++ b/llvm_zluda/src/lib.cpp @@ -1,7 +1,7 @@ #include -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Instructions.h" +#include +#include +#include using namespace llvm; diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index f0d3fbe..c5e8e79 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1264,9 +1264,9 @@ pub enum SetpCompareFloat { } impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt { - type Error = PtxError; + type Error = PtxError<'static>; - fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result { + fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result> { match (value, kind) { (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq), (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq), diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 2d77fd1..4f3f2ac 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -86,14 +86,16 @@ impl VectorPrefix { } struct PtxParserState<'a, 'input> { - errors: &'a mut Vec, + text: &'input str, + errors: &'a mut Vec>, function_declarations: FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, } impl<'a, 'input> PtxParserState<'a, 'input> { - fn new(errors: &'a mut Vec) -> Self { + fn new(text: &'input str, errors: &'a mut Vec>) -> Self { Self { + text, errors, function_declarations: FxHashMap::default(), } @@ -179,7 +181,7 @@ fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, } fn take_error<'a, 'input: 'a, O, E>( - mut parser: impl Parser, Result, E>, + mut parser: impl Parser, Result)>, E>, ) -> impl Parser, O, E> { move |input: &mut PtxParser<'a, 'input>| { Ok(match parser.parse_next(input)? { @@ -285,7 +287,7 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult(text: &'input str) -> Option> { let input = lex_with_span(text).ok()?; let mut errors = Vec::new(); - let state = PtxParserState::new(&mut errors); + let state = PtxParserState::new(text, &mut errors); let parser = PtxParser { state, input: &input[..], @@ -321,7 +323,7 @@ pub fn parse_module_checked<'input>( return Err(errors); } let parse_result = { - let state = PtxParserState::new(&mut errors); + let state = PtxParserState::new(text, &mut errors); let parser = PtxParser { state, input: &tokens[..], @@ -641,28 +643,81 @@ fn statement<'a, 'input>( pragma.map(|_| None), block_statement.map(Some), )), - fail, + take_till_inclusive( + |(t, _)| *t == Token::RBrace, + |(t, _)| match t { + Token::Semicolon | Token::Colon => true, + _ => false, + }, + ), /* + take_till(0.., |(t, _)| match t { + Token::Semicolon | Token::Colon => true, + _ => false, + }) + */ ) .map(Option::flatten) .parse_next(stream) } +fn take_till_inclusive>( + backtrack_token: impl Fn(&I::Token) -> bool, + end_token: impl Fn(&I::Token) -> bool, +) -> impl Parser::Slice, E> { + fn get_offset( + input: &mut I, + backtrack_token: &impl Fn(&I::Token) -> bool, + end_token: &impl Fn(&I::Token) -> bool, + should_backtrack: &mut bool, + ) -> usize { + *should_backtrack = false; + let mut hit = false; + for (offset, token) in input.iter_offsets() { + if hit { + return offset; + } else { + if backtrack_token(&token) { + *should_backtrack = true; + return offset; + } + if end_token(&token) { + hit = true; + } + } + } + input.eof_offset() + } + move |stream: &mut I| { + let mut should_backtrack = false; + let offset = get_offset(stream, &backtrack_token, &end_token, &mut should_backtrack); + let result = stream.next_slice(offset); + if should_backtrack { + Err(ErrMode::from_error_kind( + stream, + winnow::error::ErrorKind::Token, + )) + } else { + Ok(result) + } + } +} + fn with_recovery<'a, 'input: 'a, T>( mut parser: impl Parser, T, ContextError>, - mut recovery: impl Parser, (), ContextError>, + mut recovery: impl Parser, &'a [(Token<'input>, logos::Span)], ContextError>, ) -> impl Parser, Option, ContextError> { move |stream: &mut PtxParser<'a, 'input>| { let input_start = stream.input.first().map(|(_, s)| s).cloned(); + let stream_start = stream.checkpoint(); match parser.parse_next(stream) { Ok(value) => Ok(Some(value)), - Err(err) => { - recovery.parse_next(stream)?; - let input_end = stream.input.first().map(|(_, s)| s).cloned(); - let range = match (input_start, input_end) { - (Some(start), Some(end)) => Some(std::ops::Range { - start: start.start, - end: end.end, - }), + Err(ErrMode::Backtrack(_)) => { + stream.reset(&stream_start); + let tokens = recovery.parse_next(stream)?; + let range = match input_start { + Some(start) => { + Some(&stream.state.text[start.start..tokens.last().unwrap().1.end]) + } // We could handle `(Some(start), None)``, but this whole error recovery is to // recover from unknown instructions, so we don't care about early end of stream _ => None, @@ -671,8 +726,9 @@ fn with_recovery<'a, 'input: 'a, T>( .state .errors .push(PtxError::UnrecognizedStatement(range)); - Err(err) + Ok(None) } + Err(err) => Err(err), } } } @@ -1148,7 +1204,7 @@ impl ast::ParsedOperand { } #[derive(Debug, thiserror::Error)] -pub enum PtxError { +pub enum PtxError<'input> { #[error("{source}")] ParseInt { #[from] @@ -1192,10 +1248,8 @@ pub enum PtxError { ArrayInitalizer, #[error("")] NonExternPointer, - #[error("")] - UnrecognizedStatement(Option>), - #[error("{start}:{end}")] - UnrecognizedDirective { start: usize, end: usize }, + #[error("{0:?}")] + UnrecognizedStatement(Option<&'input str>), } #[derive(Debug)] @@ -3270,21 +3324,27 @@ derive_parser!( #[cfg(test)] mod tests { + use crate::parse_module_checked; + use crate::PtxError; + use super::target; use super::PtxParserState; use super::Token; use logos::Logos; + use logos::Span; use winnow::prelude::*; #[test] fn sm_11() { - let tokens = Token::lexer(".target sm_11") + let text = ".target sm_11"; + let tokens = Token::lexer(text) + .map(|t| t.map(|t| (t, Span::default()))) .collect::, _>>() .unwrap(); let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(&mut errors), + state: PtxParserState::new(text, &mut errors), }; assert_eq!(target.parse(stream).unwrap(), (11, None)); assert_eq!(errors.len(), 0); @@ -3292,13 +3352,15 @@ mod tests { #[test] fn sm_90a() { - let tokens = Token::lexer(".target sm_90a") + let text = ".target sm_90a"; + let tokens = Token::lexer(text) + .map(|t| t.map(|t| (t, Span::default()))) .collect::, _>>() .unwrap(); let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(&mut errors), + state: PtxParserState::new(text, &mut errors), }; assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); assert_eq!(errors.len(), 0); @@ -3306,15 +3368,56 @@ mod tests { #[test] fn sm_90ab() { - let tokens = Token::lexer(".target sm_90ab") + let text = ".target sm_90ab"; + let tokens = Token::lexer(text) + .map(|t| t.map(|t| (t, Span::default()))) .collect::, _>>() .unwrap(); let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(&mut errors), + state: PtxParserState::new(text, &mut errors), }; assert!(target.parse(stream).is_err()); assert_eq!(errors.len(), 0); } + + #[test] + fn report_unknown_intruction() { + let text = " + .version 6.5 + .target sm_30 + .address_size 64 + + .visible .entry add( + .param .u64 input, + .param .u64 output + ) + { + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + unknown_op1.asdf foobar; + add.u64 temp2, temp, 1; + unknown_op2 temp2, temp; + st.u64 [out_addr], temp2; + ret; + }"; + let errors = parse_module_checked(text).err().unwrap(); + assert_eq!(errors.len(), 2); + assert!(matches!( + errors[0], + PtxError::UnrecognizedStatement(Some("unknown_op1.asdf foobar;")) + )); + assert!(matches!( + errors[1], + PtxError::UnrecognizedStatement(Some("unknown_op2 temp2, temp;")) + )); + } }