From 0112880f2742183558bbfd27022f080a8e8817fd Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 16 Aug 2024 16:02:26 +0200 Subject: [PATCH] Parse ld, add, ret --- gen/src/lib.rs | 76 +++++++++---- ptx_parser/src/ast.rs | 48 +++++++++ ptx_parser/src/main.rs | 239 +++++++++++++++++++++++++++++++++++------ 3 files changed, 313 insertions(+), 50 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 6ea0136..93b31fe 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -2,9 +2,10 @@ use gen_impl::parser; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; use rustc_hash::{FxHashMap, FxHashSet}; -use std::{collections::hash_map, hash::Hash, rc::Rc}; +use std::{collections::hash_map, hash::Hash, iter, rc::Rc}; use syn::{ - parse_macro_input, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath, Variant, + parse_macro_input, parse_quote, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath, + Variant, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors @@ -176,10 +177,15 @@ impl SingleOpcodeDefinition { }) .chain(self.arguments.0.iter().map(|arg| { let name = &arg.ident; - if arg.optional { - quote! { #name : Option> } + let arg_type = if arg.unified { + quote! { (ParsedOperand<'input>, bool) } } else { - quote! { #name : ParsedOperand<'input> } + quote! { ParsedOperand<'input> } + }; + if arg.optional { + quote! { #name : Option<#arg_type> } + } else { + quote! { #name : #arg_type } } })) } @@ -477,7 +483,8 @@ fn emit_parse_function( #type_name :: #variant => Some(#value), } }); - let modifier_names = all_modifier.iter().map(|m| m.dot_capitalized()); + let modifier_names = iter::once(Ident::new("DotUnified", Span::call_site())) + .chain(all_modifier.iter().map(|m| m.dot_capitalized())); quote! { impl<'input> #type_name<'input> { fn opcode_text(self) -> Option<&'static str> { @@ -550,7 +557,16 @@ fn emit_definition_parser( } } } - DotModifierRef::Direct { type_: Some(_), .. } => { todo!() } + DotModifierRef::Direct { optional: false, type_: Some(type_), name, value } => { + let variable = name.ident(); + let variant = value.dot_capitalized(); + let parsed_variant = value.variant_capitalized(); + quote! { + any.verify(|t| *t == #token_type :: #variant).parse_next(&mut stream)?; + #variable = #type_ :: #parsed_variant; + } + } + DotModifierRef::Direct { optional: true, type_: Some(_), .. } => { todo!() } DotModifierRef::Indirect { optional, value, .. } => { let variants = value.alternatives.iter().map(|alt| { let type_ = value.type_.as_ref().unwrap(); @@ -669,7 +685,7 @@ fn emit_definition_parser( DotModifierRef::Direct { optional: false, name, - type_: Some(type_), + type_: Some(_), .. } => { let variable = name.ident(); @@ -700,11 +716,11 @@ fn emit_definition_parser( let comma = if idx == 0 { quote! { empty } } else { - quote! { any.verify(|t| *t == #token_type::Comma) } + quote! { any.verify(|t| *t == #token_type::Comma).void() } }; let pre_bracket = if arg.pre_bracket { quote! { - any.verify(|t| *t == #token_type::LBracket).map(|_| ()) + any.verify(|t| *t == #token_type::LBracket).void() } } else { quote! { @@ -713,7 +729,7 @@ fn emit_definition_parser( }; let pre_pipe = if arg.pre_pipe { quote! { - any.verify(|t| *t == #token_type::Or).map(|_| ()) + any.verify(|t| *t == #token_type::Or).void() } } else { quote! { @@ -736,24 +752,42 @@ fn emit_definition_parser( }; let post_bracket = if arg.post_bracket { quote! { - any.verify(|t| *t == #token_type::RBracket).map(|_| ()) + any.verify(|t| *t == #token_type::RBracket).void() } } else { quote! { empty } }; - let parser = quote! { - (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket) - }; - let arg_name = &arg.ident; - if arg.optional { + let unified = if arg.unified { quote! { - let #arg_name = opt(#parser.map(|(_, _, _, _, name, _)| name)).parse_next(stream)?; + opt(any.verify(|t| *t == #token_type::DotUnified).void()).map(|u| u.is_some()) } } else { quote! { - let #arg_name = #parser.map(|(_, _, _, _, name, _)| name).parse_next(stream)?; + empty + } + }; + let pattern = quote! { + (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified) + }; + let arg_name = &arg.ident; + let inner_parser = if arg.unified { + quote! { + #pattern.map(|(_, _, _, _, name, _, unified)| (name, unified)) + } + } else { + quote! { + #pattern.map(|(_, _, _, _, name, _, _)| name) + } + }; + if arg.optional { + quote! { + let #arg_name = opt(#inner_parser).parse_next(stream)?; + } + } else { + quote! { + let #arg_name = #inner_parser.parse_next(stream)?; } } }); @@ -812,6 +846,10 @@ fn write_definitions_into_tokens<'a>( }; variants.push(arg); } + variants.push(parse_quote! { + #[token(".unified")] + DotUnified + }); (all_opcodes, all_modifiers) } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index c45a241..a471b8e 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -25,6 +25,46 @@ pub enum StCacheOperator { Writethrough, } +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdCacheOperator { + Cached, + L2Only, + Streaming, + LastUse, + Uncached, +} + + + +#[derive(Copy, Clone)] +pub enum ArithDetails { + Integer(ArithInteger), + Float(ArithFloat), +} + +impl ArithDetails { + pub fn type_(&self) -> super::ScalarType { + match self { + ArithDetails::Integer(t) => t.type_, + ArithDetails::Float(arith) => arith.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct ArithInteger { + pub type_: super::ScalarType, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub struct ArithFloat { + pub type_: super::ScalarType, + pub rounding: Option, + pub flush_to_zero: Option, + pub saturate: bool, +} + #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStQualifier { Weak, @@ -33,3 +73,11 @@ pub enum LdStQualifier { Acquire(MemScope), Release(MemScope), } + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum RoundingMode { + NearestEven, + Zero, + NegativeInf, + PositiveInf, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index dd9e6d2..eb137a5 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -73,7 +73,7 @@ gen::generate_instruction_type!( }, Add { type: { data.type_().into() }, - data: ArithDetails, + data: ast::ArithDetails, arguments: { dst: T, src1: T, @@ -101,7 +101,7 @@ gen::generate_instruction_type!( pub struct LdDetails { pub qualifier: ast::LdStQualifier, pub state_space: StateSpace, - pub caching: LdCacheOperator, + pub caching: ast::LdCacheOperator, pub typ: Type, pub non_coherent: bool, } @@ -145,15 +145,6 @@ pub enum RoundingMode { PositiveInf, } -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdCacheOperator { - Cached, - L2Only, - Streaming, - LastUse, - Uncached, -} - #[derive(PartialEq, Eq, Clone, Hash)] pub enum Type { // .param.b32 foo; @@ -203,6 +194,18 @@ impl From for ast::StCacheOperator { } } +impl From for ast::LdCacheOperator { + fn from(value: RawLdCacheOperator) -> Self { + match value { + RawLdCacheOperator::Ca => ast::LdCacheOperator::Cached, + RawLdCacheOperator::Cg => ast::LdCacheOperator::L2Only, + RawLdCacheOperator::Cs => ast::LdCacheOperator::Streaming, + RawLdCacheOperator::Lu => ast::LdCacheOperator::LastUse, + RawLdCacheOperator::Cv => ast::LdCacheOperator::Uncached, + } + } +} + impl From for ast::LdStQualifier { fn from(value: RawLdStQualifier) -> Self { match value { @@ -212,6 +215,17 @@ impl From for ast::LdStQualifier { } } +impl From for ast::RoundingMode { + fn from(value: RawFloatRounding) -> Self { + match value { + RawFloatRounding::Rn => ast::RoundingMode::NearestEven, + RawFloatRounding::Rz => ast::RoundingMode::Zero, + RawFloatRounding::Rm => ast::RoundingMode::NegativeInf, + RawFloatRounding::Rp => ast::RoundingMode::PositiveInf, + } + } +} + type PtxParserState = Vec; type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>; @@ -334,6 +348,12 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>> { + repeat(3.., terminated(parse_instruction, Token::Semicolon)).parse_next(stream) +} + impl ast::ParsedOperand { fn parse<'a, 'input>( stream: &mut PtxParser<'a, 'input>, @@ -518,7 +538,7 @@ impl<'input, I: Stream + StreamIsPartial, E: ParserError> Parse for Token<'input> { fn parse_next(&mut self, input: &mut I) -> PResult { - any.parse_next(input) + any.verify(|t| t == self).parse_next(input) } } @@ -540,14 +560,14 @@ derive_parser!( Comma, #[token(".")] Dot, + #[token(";")] + Semicolon, #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] Ident(&'input str), #[token("|")] Or, #[token("!")] Not, - #[token(";")] - Semicolon, #[token("[")] LBracket, #[token("]")] @@ -675,23 +695,82 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { - todo!() + let (a, unified) = a; + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() { + state.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), + typ: Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => { - todo!() + if level_prefetch_size.is_some() { + state.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: volatile.into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { - todo!() + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { - todo!() + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Acquire(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } ld.mmio.relaxed.sys{.global}.type d, [a] => { - todo!() + state.push(PtxError::Todo); + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), + state_space: global.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: type_.into(), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} }; - .cop: RawCacheOp = { .ca, .cg, .cs, .lu, .cv }; + .cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv }; .level::eviction_priority: EvictionPriority = { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; .level::cache_hint = { .L2::cache_hint }; @@ -702,47 +781,144 @@ derive_parser!( .u8, .u16, .u32, .u64, .s8, .s16, .s32, .s64, .f32, .f64 }; + RawLdStQualifier = { .weak, .volatile }; + StateSpace = { .global }; // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add add.type d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.sat}.s32 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_: s32, + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } .type: ScalarType = { .u16, .u32, .u64, .s16, .s64, .u16x2, .s16x2 }; + ScalarType = { .s32 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add add{.rnd}{.ftz}{.sat}.f32 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.rnd}.f64 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add add{.rnd}{.ftz}{.sat}.f16 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.rnd}.bf16 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.rnd}.bf16x2 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } .rnd: RawFloatRounding = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; - ret => { - todo!() + ret{.uni} => { + Instruction::Ret { data: RetData { uniform: uni } } } ); @@ -776,7 +952,8 @@ fn main() { input: &tokens[..], state: Vec::new(), }; - parse_instruction(&mut stream).unwrap(); + let fn_body = fn_body.parse(stream).unwrap(); + println!("{}", fn_body.len()); //parse_prefix(&mut lexer); let mut parser = &*tokens; println!("{}", mem::size_of::());