diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 7a29e63..0ac1260 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -2,9 +2,8 @@ use gen::derive_parser; use logos::Logos; use std::mem; use std::num::{ParseFloatError, ParseIntError}; -use winnow::ascii::{dec_uint, digit1}; +use winnow::ascii::dec_uint; use winnow::combinator::*; -use winnow::error::ErrMode; use winnow::stream::Accumulate; use winnow::token::any; use winnow::{ @@ -76,6 +75,17 @@ fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> .parse_next(stream) } +fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::DotIdent(text) = t { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { any.verify_map(|t| { Some(match t { @@ -210,8 +220,9 @@ fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult(stream: &mut &str) -> PResult<(u32, Option)> { fn directive<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult>>> { - (function.map(|f| { - let (linking, func) = f; - Some(ast::Directive::Method(linking, func)) - })) + alt(( + function.map(|(linking, func)| Some(ast::Directive::Method(linking, func))), + file.map(|_| None), + section.map(|_| None), + (module_variable, Token::Semicolon) + .map(|((linking, var), _)| Some(ast::Directive::Variable(linking, var))), + )) .parse_next(stream) } +fn module_variable<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> { + ( + linking_directives, + module_variable_state_space.flat_map(variable_scalar_or_vector), + ) + .parse_next(stream) +} + +fn module_variable_state_space<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult { + alt(( + Token::DotConst.value(StateSpace::Const), + Token::DotGlobal.value(StateSpace::Global), + Token::DotShared.value(StateSpace::Shared), + )) + .parse_next(stream) +} + +fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotFile, + u32, + Token::String, + opt((Token::Comma, u32, Token::Comma, u32)), + ) + .void() + .parse_next(stream) +} + +fn section<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotSection.void(), + dot_ident.void(), + Token::LBrace.void(), + repeat::<_, _, (), _, _>(0.., section_dwarf_line), + Token::RBrace.void(), + ) + .void() + .parse_next(stream) +} + +fn section_dwarf_line<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt(( + (section_label, Token::Colon).void(), + (Token::DotB32, section_label, opt((Token::Add, u32))).void(), + (Token::DotB64, section_label, opt((Token::Add, u32))).void(), + ( + any_bit_type, + separated::<_, _, (), _, _, _, _>(1.., u32, Token::Comma), + ) + .void(), + )) + .parse_next(stream) +} + +fn any_bit_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((Token::DotB8, Token::DotB16, Token::DotB32, Token::DotB64)) + .void() + .parse_next(stream) +} + +fn section_label<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((ident, dot_ident)).void().parse_next(stream) +} + fn function<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult<( @@ -283,12 +365,16 @@ fn function<'a, 'input>( fn linking_directives<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult { - dispatch! { any; - Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), - Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), - Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), - _ => fail - } + repeat( + 0.., + dispatch! { any; + Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), + Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), + Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), + _ => fail + }, + ) + .fold(|| ast::LinkingDirective::NONE, |x, y| x | y) .parse_next(stream) } @@ -816,6 +902,8 @@ derive_parser!( At, #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] Ident(&'input str), + #[regex(r"\.[a-zA-Z][a-zA-Z0-9_$]*|\.[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] + DotIdent(&'input str), #[regex(r#""[^"]*""#)] String, #[token("|")] @@ -879,7 +967,11 @@ derive_parser!( #[token(".target")] DotTarget, #[token(".address_size")] - DotAddressSize + DotAddressSize, + #[token(".action")] + DotSection, + #[token(".file")] + DotFile } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -1224,50 +1316,51 @@ fn main() { use winnow::token::*; use winnow::Parser; - println!("{}", mem::size_of::()); - - let mut input: &[Token] = &[][..]; - let x = opt(any::<_, ContextError>.verify_map(|_| { - println!("MAP"); - Some(true) - })) - .parse_next(&mut input) - .unwrap(); - dbg!(x); let lexer = Token::lexer( " .version 6.5 .target sm_30 .address_size 64 - .visible .entry add( + .const .align 8 .b32 constparams; + + .visible .entry const( .param .u64 input, .param .u64 output ) { .reg .u64 in_addr; .reg .u64 out_addr; - .reg .u64 temp; - .reg .u64 temp2; + .reg .b16 temp1; + .reg .b16 temp2; + .reg .b16 temp3; + .reg .b16 temp4; ld.param.u64 in_addr, [input]; ld.param.u64 out_addr, [output]; - ld.u64 temp, [in_addr]; - add.u64 temp2, temp, 1; - st.u64 [out_addr], temp2; + ld.const.b16 temp1, [constparams]; + ld.const.b16 temp2, [constparams+2]; + ld.const.b16 temp3, [constparams+4]; + ld.const.b16 temp4, [constparams+6]; + st.u16 [out_addr], temp1; + st.u16 [out_addr+2], temp2; + st.u16 [out_addr+4], temp3; + st.u16 [out_addr+6], temp4; ret; } ", ); + let tokens = lexer.clone().collect::>(); + println!("{:?}", &tokens); let tokens = lexer.map(|t| t.unwrap()).collect::>(); println!("{:?}", &tokens); let stream = PtxParser { input: &tokens[..], state: Vec::new(), }; - let module_ = module.parse(stream).unwrap(); + let _module = module.parse(stream).unwrap(); println!("{}", mem::size_of::()); }