diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 67b276e..ebddf03 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -11,15 +11,17 @@ use syn::{ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-floating-point-data-types +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-integer-data-types #[rustfmt::skip] static POSTFIX_MODIFIERS: &[&str] = &[ ".v2", ".v4", - ".s8", ".s16", ".s32", ".s64", - ".u8", ".u16", ".u32", ".u64", + ".s8", ".s16", ".s16x2", ".s32", ".s64", + ".u8", ".u16", ".u16x2", ".u32", ".u64", ".f16", ".f16x2", ".f32", ".f64", ".b8", ".b16", ".b32", ".b64", ".b128", ".pred", - ".bf16", ".e4m3", ".e5m2", ".tf32", + ".bf16", ".bf16x2", ".e4m3", ".e5m2", ".tf32", ]; static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"]; diff --git a/gen_impl/src/parser.rs b/gen_impl/src/parser.rs index 6834cbc..519bf12 100644 --- a/gen_impl/src/parser.rs +++ b/gen_impl/src/parser.rs @@ -332,7 +332,8 @@ impl DotModifier { capitalize = true; continue; } - let c = if capitalize { + // Special hack to emit `BF16`` instead of `Bf16`` + let c = if capitalize || c == 'f' && result.ends_with('B') { capitalize = false; c.to_ascii_uppercase() } else { diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index 951d508..4f32860 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -8,3 +8,4 @@ logos = "0.14" winnow = { version = "0.6.18", features = ["debug"] } gen = { path = "../gen" } thiserror = "1.0" +bitflags = "1.2" diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 302aef7..2dabf3e 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,4 +1,31 @@ -use super::{MemScope, ScalarType, VectorPrefix, StateSpace}; +use super::{MemScope, ScalarType, StateSpace, VectorPrefix}; +use bitflags::bitflags; + +pub enum Statement { + Label(P::Ident), + Variable(MultiVariable), + Instruction(Option>, Instruction

), + Block(Vec>), +} + +pub struct MultiVariable { + pub var: Variable, + pub count: Option, +} + +#[derive(Clone)] +pub struct Variable { + pub align: Option, + pub v_type: Type, + pub state_space: StateSpace, + pub name: ID, + pub array_init: Vec, +} + +pub struct PredAt { + pub not: bool, + pub label: ID, +} gen::generate_instruction_type!( pub enum Instruction { @@ -118,6 +145,14 @@ pub enum ParsedOperand { VecPack(Vec), } +impl Operand for ParsedOperand { + type Ident = Ident; +} + +pub trait Operand { + type Ident; +} + #[derive(Copy, Clone)] pub enum ImmediateValue { U64(u64), @@ -143,8 +178,6 @@ pub enum LdCacheOperator { Uncached, } - - #[derive(Copy, Clone)] pub enum ArithDetails { Integer(ArithInteger), @@ -199,7 +232,6 @@ pub struct LdDetails { pub non_coherent: bool, } - pub struct StData { pub qualifier: LdStQualifier, pub state_space: StateSpace, @@ -211,3 +243,52 @@ pub struct StData { pub struct RetData { pub uniform: bool, } + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum TuningDirective { + MaxNReg(u32), + MaxNtid(u32, u32, u32), + ReqNtid(u32, u32, u32), + MinNCtaPerSm(u32), +} + +pub struct MethodDeclaration<'input, ID> { + pub return_arguments: Vec>, + pub name: MethodName<'input, ID>, + pub input_arguments: Vec>, + pub shared_mem: Option, +} + +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +pub enum MethodName<'input, ID> { + Kernel(&'input str), + Func(ID), +} + +bitflags! { + pub struct LinkingDirective: u8 { + const NONE = 0b000; + const EXTERN = 0b001; + const VISIBLE = 0b10; + const WEAK = 0b100; + } +} + +pub struct Function<'a, ID, S> { + pub func_directive: MethodDeclaration<'a, ID>, + pub tuning: Vec, + pub body: Option>, +} + +pub enum Directive<'input, O: Operand> { + Variable(LinkingDirective, Variable), + Method( + LinkingDirective, + Function<'input, &'input str, Statement>, + ), +} + +pub struct Module<'input> { + pub version: (u8, u8), + pub directives: Vec>>, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 34c27da..7a29e63 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -2,7 +2,10 @@ use gen::derive_parser; use logos::Logos; use std::mem; use std::num::{ParseFloatError, ParseIntError}; +use winnow::ascii::{dec_uint, digit1}; use winnow::combinator::*; +use winnow::error::ErrMode; +use winnow::stream::Accumulate; use winnow::token::any; use winnow::{ error::{ContextError, ParserError}, @@ -170,6 +173,28 @@ fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { .parse_next(stream) } +fn u8<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + take_error(num.map(|x| { + let (text, radix, _) = x; + match u8::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + take_error(num.map(|x| { + let (text, radix, _) = x; + match u32::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { alt(( int_immediate, @@ -179,10 +204,402 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult( +fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + ( + version, + target, + opt(address_size), + repeat_without_none(directive), + ) + .map(|(version, _, _, directives)| ast::Module { + version, + directives, + }) + .parse_next(stream) +} + +fn address_size<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotAddressSize, u8_literal(64)) + .void() + .parse_next(stream) +} + +fn version<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u8, u8)> { + (Token::DotVersion, u8, Token::Dot, u8) + .map(|(_, major, _, minor)| (major, minor)) + .parse_next(stream) +} + +fn target<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, Option)> { + preceded(Token::DotTarget, ident.and_then(shader_model)).parse_next(stream) +} + +fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option)> { + ( + "sm_", + dec_uint, + opt(any.verify(|c: &char| c.is_ascii_lowercase())), + eof, + ) + .map(|(_, digits, arch_variant, _)| (digits, arch_variant)) + .parse_next(stream) +} + +fn directive<'a, 'input>( stream: &mut PtxParser<'a, 'input>, -) -> PResult>>> { - repeat(3.., terminated(parse_instruction, Token::Semicolon)).parse_next(stream) +) -> PResult>>> { + (function.map(|f| { + let (linking, func) = f; + Some(ast::Directive::Method(linking, func)) + })) + .parse_next(stream) +} + +fn function<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<( + ast::LinkingDirective, + ast::Function<'input, &'input str, ast::Statement>>, +)> { + ( + linking_directives, + method_declaration, + repeat(0.., tuning_directive), + function_body, + ) + .map(|(linking, func_directive, tuning, body)| { + ( + linking, + ast::Function { + func_directive, + tuning, + body, + }, + ) + }) + .parse_next(stream) +} + +fn linking_directives<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult { + dispatch! { any; + Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), + Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), + Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), + _ => fail + } + .parse_next(stream) +} + +fn tuning_directive<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult { + dispatch! {any; + Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg), + Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)), + Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)), + Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm), + _ => fail + } + .parse_next(stream) +} + +fn method_declaration<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult> { + dispatch! {any; + Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{ + return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None + }), + Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| { + let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); + let name = ast::MethodName::Func(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } + }), + _ => fail + } + .parse_next(stream) +} + +fn fn_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + delimited( + Token::LParen, + separated(0.., fn_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + delimited( + Token::LParen, + separated(0.., kernel_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_input<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult> { + preceded( + Token::DotParam, + variable_scalar_or_vector(StateSpace::Param), + ) + .parse_next(stream) +} + +fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + dispatch! { any; + Token::DotParam => variable_scalar_or_vector(StateSpace::Param), + Token::DotReg => variable_scalar_or_vector(StateSpace::Reg), + _ => fail + } + .parse_next(stream) +} + +fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, u32, u32)> { + struct Tuple3AccumulateU32 { + index: usize, + value: (u32, u32, u32), + } + + impl Accumulate for Tuple3AccumulateU32 { + fn initial(_: Option) -> Self { + Self { + index: 0, + value: (1, 1, 1), + } + } + + fn accumulate(&mut self, value: u32) { + match self.index { + 0 => { + self.value = (value, self.value.1, self.value.2); + self.index = 1; + } + 1 => { + self.value = (self.value.0, value, self.value.2); + self.index = 2; + } + 2 => { + self.value = (self.value.0, self.value.1, value); + self.index = 3; + } + _ => unreachable!(), + } + } + } + + separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..3, u32, Token::Comma) + .map(|acc| acc.value) + .parse_next(stream) +} + +fn function_body<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>>> { + dispatch! {any; + Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some), + Token::Semicolon => empty.map(|_| None), + _ => fail + } + .parse_next(stream) +} + +fn statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>> { + alt(( + label.map(Some), + debug_directive.map(|_| None), + multi_variable.map(Some), + predicated_instruction.map(Some), + pragma.map(|_| None), + block_statement.map(Some), + )) + .parse_next(stream) +} + +fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotPragma, Token::String, Token::Semicolon) + .void() + .parse_next(stream) +} + +fn multi_variable<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + ( + variable, + opt(delimited(Token::Lt, u32, Token::Gt)), + Token::Semicolon, + ) + .map(|(var, count, _)| ast::Statement::Variable(ast::MultiVariable { var, count })) + .parse_next(stream) +} + +fn variable<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + dispatch! {any; + Token::DotReg => variable_scalar_or_vector(StateSpace::Reg), + Token::DotLocal => variable_scalar_or_vector(StateSpace::Local), + Token::DotParam => variable_scalar_or_vector(StateSpace::Param), + Token::DotShared => variable_scalar_or_vector(StateSpace::Shared), + _ => fail + } + .parse_next(stream) +} + +fn variable_scalar_or_vector<'a, 'input: 'a>( + state_space: StateSpace, +) -> impl Parser, ast::Variable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + (opt(align), scalar_vector_type, ident) + .map(|(align, v_type, name)| ast::Variable { + align, + v_type, + state_space, + name, + array_init: Vec::new(), + }) + .parse_next(stream) + } +} + +fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + preceded(Token::DotAlign, u32).parse_next(stream) +} + +fn scalar_vector_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + ( + opt(alt(( + Token::DotV2.value(VectorPrefix::V2), + Token::DotV4.value(VectorPrefix::V4), + ))), + scalar_type, + ) + .map(|(prefix, scalar)| ast::Type::maybe_vector(prefix, scalar)) + .parse_next(stream) +} + +fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + any.verify_map(|t| { + Some(match t { + Token::DotS8 => ScalarType::S8, + Token::DotS16 => ScalarType::S16, + Token::DotS16x2 => ScalarType::S16x2, + Token::DotS32 => ScalarType::S32, + Token::DotS64 => ScalarType::S64, + Token::DotU8 => ScalarType::U8, + Token::DotU16 => ScalarType::U16, + Token::DotU16x2 => ScalarType::U16x2, + Token::DotU32 => ScalarType::U32, + Token::DotU64 => ScalarType::U64, + Token::DotB8 => ScalarType::B8, + Token::DotB16 => ScalarType::B16, + Token::DotB32 => ScalarType::B32, + Token::DotB64 => ScalarType::B64, + Token::DotB128 => ScalarType::B128, + Token::DotPred => ScalarType::Pred, + Token::DotF16 => ScalarType::F16, + Token::DotF16x2 => ScalarType::F16x2, + Token::DotF32 => ScalarType::F32, + Token::DotF64 => ScalarType::F64, + Token::DotBF16 => ScalarType::BF16, + Token::DotBF16x2 => ScalarType::BF16x2, + _ => return None, + }) + }) + .parse_next(stream) +} + +fn predicated_instruction<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + (opt(pred_at), parse_instruction, Token::Semicolon) + .map(|(p, i, _)| ast::Statement::Instruction(p, i)) + .parse_next(stream) +} + +fn pred_at<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + (Token::At, opt(Token::Not), ident) + .map(|(_, not, label)| ast::PredAt { + not: not.is_some(), + label, + }) + .parse_next(stream) +} + +fn label<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + terminated(ident, Token::Colon) + .map(|l| ast::Statement::Label(l)) + .parse_next(stream) +} + +fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotLoc, + u32, + u32, + u32, + opt(( + Token::Comma, + ident_literal("function_name"), + ident, + dispatch! { any; + Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(), + Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(), + _ => fail + }, + )), + ) + .void() + .parse_next(stream) +} + +fn block_statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + delimited(Token::LBrace, repeat_without_none(statement), Token::RBrace) + .map(|s| ast::Statement::Block(s)) + .parse_next(stream) +} + +fn repeat_without_none>( + parser: impl Parser, Error>, +) -> impl Parser, Error> { + repeat(0.., parser).fold(Vec::new, |mut acc: Vec<_>, item| { + if let Some(item) = item { + acc.push(item); + } + acc + }) +} + +fn ident_literal< + 'a, + 'input, + I: Stream> + StreamIsPartial, + E: ParserError, +>( + s: &'input str, +) -> impl Parser + 'input { + move |stream: &mut I| { + any.verify(|t| matches!(t, Token::Ident(text) if *text == s)) + .void() + .parse_next(stream) + } +} + +fn u8_literal<'a, 'input>(x: u8) -> impl Parser, (), ContextError> { + move |stream: &mut PtxParser| u8.verify(|t| *t == x).void().parse_next(stream) } impl ast::ParsedOperand { @@ -391,18 +808,36 @@ derive_parser!( Comma, #[token(".")] Dot, + #[token(":")] + Colon, #[token(";")] Semicolon, + #[token("@")] + At, #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] Ident(&'input str), + #[regex(r#""[^"]*""#)] + String, #[token("|")] Or, #[token("!")] Not, + #[token("(")] + LParen, + #[token(")")] + RParen, #[token("[")] LBracket, #[token("]")] RBracket, + #[token("{")] + LBrace, + #[token("}")] + RBrace, + #[token("<")] + Lt, + #[token(">")] + Gt, #[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| lex.slice())] F32(&'input str), #[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())] @@ -415,6 +850,36 @@ derive_parser!( Minus, #[token("+")] Plus, + #[token(".version")] + DotVersion, + #[token(".loc")] + DotLoc, + #[token(".reg")] + DotReg, + #[token(".align")] + DotAlign, + #[token(".pragma")] + DotPragma, + #[token(".maxnreg")] + DotMaxnreg, + #[token(".maxntid")] + DotMaxntid, + #[token(".reqntid")] + DotReqntid, + #[token(".minnctapersm")] + DotMinnctapersm, + #[token(".entry")] + DotEntry, + #[token(".func")] + DotFunc, + #[token(".extern")] + DotExtern, + #[token(".visible")] + DotVisible, + #[token(".target")] + DotTarget, + #[token(".address_size")] + DotAddressSize } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -771,10 +1236,29 @@ fn main() { dbg!(x); let lexer = Token::lexer( " - ld.u64 temp, [in_addr]; - add.u64 temp2, temp, 1; - st.u64 [out_addr], temp2; - ret; + .version 6.5 + .target sm_30 + .address_size 64 + + .visible .entry add( + .param .u64 input, + .param .u64 output + ) + { + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; + } + ", ); let tokens = lexer.map(|t| t.unwrap()).collect::>(); @@ -783,7 +1267,50 @@ fn main() { input: &tokens[..], state: Vec::new(), }; - let fn_body = fn_body.parse(stream).unwrap(); - println!("{}", fn_body.len()); + let module_ = module.parse(stream).unwrap(); println!("{}", mem::size_of::()); } + +#[cfg(test)] +mod tests { + use super::target; + use super::Token; + use logos::Logos; + use winnow::prelude::*; + + #[test] + fn sm_11() { + let tokens = Token::lexer(".target sm_11") + .collect::, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: Vec::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (11, None)); + } + + #[test] + fn sm_90a() { + let tokens = Token::lexer(".target sm_90a") + .collect::, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: Vec::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); + } + + #[test] + fn sm_90ab() { + let tokens = Token::lexer(".target sm_90ab") + .collect::, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: Vec::new(), + }; + assert!(target.parse(stream).is_err()); + } +}