Compare commits

...

2 commits

Author SHA1 Message Date
Andrzej Janik
a99111720e
Merge 0da45ea7d8 into 872054ae40 2024-08-15 20:25:06 +00:00
Andrzej Janik
0da45ea7d8 Add parsing of st, allow associating type with a non-alternative modifier 2024-08-15 22:24:53 +02:00
4 changed files with 239 additions and 63 deletions

View file

@ -3,7 +3,9 @@ use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{collections::hash_map, hash::Hash, rc::Rc};
use syn::{parse_macro_input, punctuated::Punctuated, Ident, ItemEnum, Token, TypePath, Variant};
use syn::{
parse_macro_input, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath, Variant,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
@ -46,7 +48,7 @@ impl OpcodeDefinitions {
_ => {}
}
'check_definitions: for i in unselected.iter().copied() {
// just pick the first alternative and attempt every modifier
// Attempt every modifier
'check_candidates: for candidate in definitions[i]
.unordered_modifiers
.iter()
@ -203,32 +205,31 @@ impl SingleOpcodeDefinition {
output: &mut FxHashMap<Ident, Vec<SingleOpcodeDefinition>>,
parser::OpcodeDefinition(pattern_seq, rules): parser::OpcodeDefinition,
) {
let mut rules = rules
.into_iter()
.map(|r| (r.modifier.clone(), Rc::new(r)))
.collect::<FxHashMap<_, _>>();
let (mut named_rules, mut unnamed_rules) = gather_rules(rules);
let mut last_opcode = pattern_seq.0.last().unwrap().0 .0.name.clone();
for (opcode_decl, code_block) in pattern_seq.0.into_iter().rev() {
let current_opcode = opcode_decl.0.name.clone();
if last_opcode != current_opcode {
rules = FxHashMap::default();
named_rules = FxHashMap::default();
unnamed_rules = FxHashMap::default();
}
let mut possible_modifiers = FxHashSet::default();
for (_, options) in rules.iter() {
for (_, options) in named_rules.iter() {
possible_modifiers.extend(options.alternatives.iter().cloned());
}
let parser::OpcodeDecl(instruction, arguments) = opcode_decl;
let mut unordered_modifiers = instruction
.modifiers
.into_iter()
.map(
|parser::MaybeDotModifier { optional, modifier }| match rules.get(&modifier) {
.map(|parser::MaybeDotModifier { optional, modifier }| {
match named_rules.get(&modifier) {
Some(alts) => {
if alts.alternatives.len() == 1 && alts.type_.is_none() {
DotModifierRef::Direct {
optional,
value: alts.alternatives[0].clone(),
name: modifier,
type_: alts.type_.clone(),
}
} else {
DotModifierRef::Indirect {
@ -239,15 +240,17 @@ impl SingleOpcodeDefinition {
}
}
None => {
let type_ = unnamed_rules.get(&modifier).cloned();
possible_modifiers.insert(modifier.clone());
DotModifierRef::Direct {
optional,
value: modifier.clone(),
name: modifier,
type_,
}
}
},
)
}
})
.collect::<Vec<_>>();
let ordered_modifiers = Self::extract_ordered_modifiers(&mut unordered_modifiers);
let entry = Self {
@ -293,6 +296,29 @@ impl SingleOpcodeDefinition {
}
}
fn gather_rules(
rules: Vec<parser::Rule>,
) -> (
FxHashMap<parser::DotModifier, Rc<parser::Rule>>,
FxHashMap<parser::DotModifier, Type>,
) {
let mut named = FxHashMap::default();
let mut unnamed = FxHashMap::default();
for rule in rules {
match rule.modifier {
Some(ref modifier) => {
named.insert(modifier.clone(), Rc::new(rule));
}
None => unnamed.extend(
rule.alternatives
.into_iter()
.map(|alt| (alt, rule.type_.as_ref().unwrap().clone())),
),
}
}
(named, unnamed)
}
#[proc_macro]
pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
let parse_definitions = parse_macro_input!(tokens as gen_impl::parser::ParseDefinitions);
@ -512,7 +538,7 @@ fn emit_definition_parser(
let ordered_parse = definition.ordered_modifiers.iter().rev().map(|modifier| {
let arg_name = modifier.ident();
match modifier {
DotModifierRef::Direct { optional, value, .. } => {
DotModifierRef::Direct { optional, value, type_: None, .. } => {
let variant = value.dot_capitalized();
if *optional {
quote! {
@ -524,6 +550,7 @@ fn emit_definition_parser(
}
}
}
DotModifierRef::Direct { type_: Some(_), .. } => { todo!() }
DotModifierRef::Indirect { optional, value, .. } => {
let variants = value.alternatives.iter().map(|alt| {
let type_ = value.type_.as_ref().unwrap();
@ -566,7 +593,12 @@ fn emit_definition_parser(
.unordered_modifiers
.iter()
.map(|modifier| match modifier {
DotModifierRef::Direct { name, value, .. } => {
DotModifierRef::Direct {
name,
value,
type_: None,
..
} => {
let name = name.ident();
let token_variant = value.dot_capitalized();
quote! {
@ -578,6 +610,24 @@ fn emit_definition_parser(
}
}
}
DotModifierRef::Direct {
name,
value,
type_: Some(type_),
..
} => {
let variable = name.ident();
let token_variant = value.dot_capitalized();
let enum_variant = value.variant_capitalized();
quote! {
#token_type :: #token_variant => {
if #variable.is_some() {
#return_error_ref;
}
#variable = Some(#type_ :: #enum_variant);
}
}
}
DotModifierRef::Indirect { value, name, .. } => {
let variable = name.ident();
let type_ = value.type_.as_ref().unwrap();
@ -606,6 +656,7 @@ fn emit_definition_parser(
DotModifierRef::Direct {
optional: false,
name,
type_: None,
..
} => {
let variable = name.ident();
@ -615,7 +666,20 @@ fn emit_definition_parser(
}
}
}
DotModifierRef::Direct { optional: true, .. } => TokenStream::new(),
DotModifierRef::Direct {
optional: false,
name,
type_: Some(type_),
..
} => {
let variable = name.ident();
quote! {
let #variable = match #variable {
Some(x) => x,
None => #return_error
};
}
}
DotModifierRef::Indirect {
optional: false,
name,
@ -629,7 +693,8 @@ fn emit_definition_parser(
};
}
}
DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(),
DotModifierRef::Direct { optional: true, .. }
| DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(),
});
let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| {
let comma = if idx == 0 {
@ -772,6 +837,7 @@ enum DotModifierRef {
optional: bool,
value: parser::DotModifier,
name: parser::DotModifier,
type_: Option<Type>,
},
Indirect {
optional: bool,
@ -790,10 +856,26 @@ impl DotModifierRef {
fn type_of(&self) -> Option<syn::Type> {
Some(match self {
DotModifierRef::Direct { optional: true, .. } => syn::parse_quote! { bool },
DotModifierRef::Direct {
optional: false, ..
optional: true,
type_: None,
..
} => syn::parse_quote! { bool },
DotModifierRef::Direct {
optional: false,
type_: None,
..
} => return None,
DotModifierRef::Direct {
optional: true,
type_: Some(type_),
..
} => syn::parse_quote! { Option<#type_> },
DotModifierRef::Direct {
optional: false,
type_: Some(type_),
..
} => type_.clone(),
DotModifierRef::Indirect {
optional, value, ..
} => {
@ -812,7 +894,10 @@ impl DotModifierRef {
fn type_of_check(&self) -> syn::Type {
match self {
DotModifierRef::Direct { .. } => syn::parse_quote! { bool },
DotModifierRef::Direct { type_: None, .. } => syn::parse_quote! { bool },
DotModifierRef::Direct {
type_: Some(type_), ..
} => syn::parse_quote! { Option<#type_> },
DotModifierRef::Indirect { value, .. } => {
let type_ = value
.type_

View file

@ -97,7 +97,7 @@ impl Parse for CodeBlock {
}
pub struct Rule {
pub modifier: DotModifier,
pub modifier: Option<DotModifier>,
pub type_: Option<Type>,
pub alternatives: Vec<DotModifier>,
}
@ -105,6 +105,7 @@ pub struct Rule {
impl Rule {
fn peek(input: syn::parse::ParseStream) -> bool {
DotModifier::peek(input)
|| (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>]))
}
fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result<Vec<DotModifier>> {
@ -181,12 +182,16 @@ impl Parse for IdentOrTypeSuffix {
impl Parse for Rule {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let modifier = input.parse::<DotModifier>()?;
let type_ = if input.peek(Token![:]) {
input.parse::<Token![:]>()?;
Some(input.parse::<Type>()?)
let (modifier, type_) = if DotModifier::peek(input) {
let modifier = Some(input.parse::<DotModifier>()?);
if input.peek(Token![:]) {
input.parse::<Token![:]>()?;
(modifier, Some(input.parse::<Type>()?))
} else {
(modifier, None)
}
} else {
None
(None, Some(input.parse::<Type>()?))
};
input.parse::<Token![=]>()?;
let content;

View file

@ -1,3 +1,5 @@
use super::MemScope;
#[derive(Clone)]
pub enum ParsedOperand<Ident> {
Reg(Ident),
@ -14,3 +16,20 @@ pub enum ImmediateValue {
F32(f32),
F64(f64),
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum StCacheOperator {
Writeback,
L2Only,
Streaming,
Writethrough,
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdStQualifier {
Weak,
Volatile,
Relaxed(MemScope),
Acquire(MemScope),
Release(MemScope),
}

View file

@ -39,9 +39,9 @@ pub struct MovDetails {
}
impl MovDetails {
pub fn new(typ: Type) -> Self {
fn new(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
MovDetails {
typ,
typ: Type::maybe_vector(vector, scalar),
src_is_address: false,
dst_width: 0,
src_width: 0,
@ -99,7 +99,7 @@ gen::generate_instruction_type!(
);
pub struct LdDetails {
pub qualifier: LdStQualifier,
pub qualifier: ast::LdStQualifier,
pub state_space: StateSpace,
pub caching: LdCacheOperator,
pub typ: Type,
@ -164,41 +164,54 @@ pub enum Type {
Array(ScalarType, Vec<u32>),
}
impl Type {
fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
match vector {
Some(VectorPrefix::V2) => Type::Vector(scalar, 2),
Some(VectorPrefix::V4) => Type::Vector(scalar, 4),
None => Type::Scalar(scalar),
}
}
}
impl From<ScalarType> for Type {
fn from(value: ScalarType) -> Self {
Type::Scalar(value)
}
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdStQualifier {
Weak,
Volatile,
Relaxed(MemScope),
Acquire(MemScope),
Release(MemScope),
}
pub struct StData {
pub qualifier: LdStQualifier,
pub qualifier: ast::LdStQualifier,
pub state_space: StateSpace,
pub caching: StCacheOperator,
pub caching: ast::StCacheOperator,
pub typ: Type,
}
#[derive(PartialEq, Eq)]
pub enum StCacheOperator {
Writeback,
L2Only,
Streaming,
Writethrough,
}
#[derive(Copy, Clone)]
pub struct RetData {
pub uniform: bool,
}
impl From<RawStCacheOperator> for ast::StCacheOperator {
fn from(value: RawStCacheOperator) -> Self {
match value {
RawStCacheOperator::Wb => ast::StCacheOperator::Writeback,
RawStCacheOperator::Cg => ast::StCacheOperator::L2Only,
RawStCacheOperator::Cs => ast::StCacheOperator::Streaming,
RawStCacheOperator::Wt => ast::StCacheOperator::Writethrough,
}
}
}
impl From<RawLdStQualifier> for ast::LdStQualifier {
fn from(value: RawLdStQualifier) -> Self {
match value {
RawLdStQualifier::Weak => ast::LdStQualifier::Weak,
RawLdStQualifier::Volatile => ast::LdStQualifier::Volatile,
}
}
}
type PtxParserState = Vec<PtxError>;
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>;
@ -312,9 +325,7 @@ fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<i32> {
.parse_next(stream)
}
fn immediate_value<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<ast::ImmediateValue> {
fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> {
alt((
int_immediate,
f32.map(ast::ImmediateValue::F32),
@ -388,6 +399,8 @@ pub enum PtxError {
source: ParseFloatError,
},
#[error("")]
Todo,
#[error("")]
SyntaxError,
#[error("")]
NonF32Ftz,
@ -555,7 +568,8 @@ derive_parser!(
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum StateSpace {
Reg
Reg,
Generic,
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
@ -565,33 +579,84 @@ derive_parser!(
pub enum ScalarType { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov.type d, a => {
mov{.vec}.type d, a => {
Instruction::Mov {
data: MovDetails::new(type_.into()),
data: MovDetails::new(vec, type_),
arguments: MovArgs { dst: d, src: a },
}
}
.type: ScalarType = { .pred,
.b16, .b32, .b64,
.u16, .u32, .u64,
.s16, .s32, .s64,
.f32, .f64 };
.vec: VectorPrefix = { .v2, .v4 };
.type: ScalarType = { .pred,
.b16, .b32, .b64,
.u16, .u32, .u64,
.s16, .s32, .s64,
.f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st
st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
todo!()
if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
state.push(PtxError::Todo);
}
Instruction::St {
data: StData {
qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: cop.unwrap_or(RawStCacheOperator::Wb).into(),
typ: Type::maybe_vector(vec, type_)
},
arguments: StArgs { src1:a, src2:b }
}
}
st.volatile{.ss}{.vec}.type [a], b => {
todo!()
Instruction::St {
data: StData {
qualifier: volatile.into(),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::StCacheOperator::Writeback,
typ: Type::maybe_vector(vec, type_)
},
arguments: StArgs { src1:a, src2:b }
}
}
st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
todo!()
if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
state.push(PtxError::Todo);
}
Instruction::St {
data: StData {
qualifier: ast::LdStQualifier::Relaxed(scope),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::StCacheOperator::Writeback,
typ: Type::maybe_vector(vec, type_)
},
arguments: StArgs { src1:a, src2:b }
}
}
st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
todo!()
if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
state.push(PtxError::Todo);
}
Instruction::St {
data: StData {
qualifier: ast::LdStQualifier::Release(scope),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::StCacheOperator::Writeback,
typ: Type::maybe_vector(vec, type_)
},
arguments: StArgs { src1:a, src2:b }
}
}
st.mmio.relaxed.sys{.global}.type [a], b => {
todo!()
state.push(PtxError::Todo);
Instruction::St {
data: StData {
qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys),
state_space: global.unwrap_or(StateSpace::Generic),
caching: ast::StCacheOperator::Writeback,
typ: type_.into()
},
arguments: StArgs { src1:a, src2:b }
}
}
.ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} };
@ -605,6 +670,8 @@ derive_parser!(
.u8, .u16, .u32, .u64,
.s8, .s16, .s32, .s64,
.f32, .f64 };
RawLdStQualifier = { .weak, .volatile };
StateSpace = { .global };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld
ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => {