Fully parse operands

This commit is contained in:
Andrzej Janik 2024-08-15 03:26:38 +02:00
parent a05bee9ccb
commit 8d7c88c095
5 changed files with 323 additions and 28 deletions

View file

@ -174,9 +174,9 @@ impl SingleOpcodeDefinition {
.chain(self.arguments.0.iter().map(|arg| {
let name = &arg.ident;
if arg.optional {
quote! { #name : Option<&str> }
quote! { #name : Option<ParsedOperand<'input>> }
} else {
quote! { #name : &str }
quote! { #name : ParsedOperand<'input> }
}
}))
}
@ -377,7 +377,7 @@ fn emit_parse_function(
let code_block = &def.code_block.0;
let args = def.function_arguments_declarations();
quote! {
fn #fn_name( #(#args),* ) -> Instruction<ParsedOperand> #code_block
fn #fn_name<'input>( #(#args),* ) -> Instruction<ParsedOperand<'input>> #code_block
}
})
})
@ -452,7 +452,7 @@ fn emit_parse_function(
#(#fns_)*
fn parse_instruction<'input>(stream: &mut (impl winnow::stream::Stream<Token = #type_name<'input>, Slice = &'input [#type_name<'input>]> + winnow::stream::StreamIsPartial)) -> winnow::error::PResult<Instruction<ParsedOperand>>
fn parse_instruction<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> winnow::error::PResult<Instruction<ParsedOperand<'input>>>
{
use winnow::Parser;
use winnow::token::*;
@ -642,9 +642,9 @@ fn emit_definition_parser(
empty
}
};
let ident = {
let operand = {
quote! {
any.verify_map(|t| match t { #token_type::Ident(s) => Some(s), _ => None })
ParsedOperand::parse
}
};
let post_bracket = if arg.post_bracket {
@ -657,7 +657,7 @@ fn emit_definition_parser(
}
};
let parser = quote! {
(#comma, #pre_bracket, #pre_pipe, #can_be_negated, #ident, #post_bracket)
(#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket)
};
let arg_name = &arg.ident;
if arg.optional {

View file

@ -67,7 +67,7 @@ impl GenerateInstructionType {
let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix());
let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
(
quote! { <#type_parameters, To: Operand> },
quote! { <#type_parameters, To> },
quote! { <#short_parameters, To> },
quote! { #type_name<To> },
)

View file

@ -7,3 +7,4 @@ edition = "2021"
logos = "0.14"
winnow = { version = "0.6.18", features = ["debug"] }
gen = { path = "../gen" }
thiserror = "1.0"

16
ptx_parser/src/ast.rs Normal file
View file

@ -0,0 +1,16 @@
#[derive(Clone)]
pub enum ParsedOperand<Ident> {
Reg(Ident),
RegOffset(Ident, i32),
Imm(ImmediateValue),
VecMember(Ident, u8),
VecPack(Vec<Ident>),
}
#[derive(Copy, Clone)]
pub enum ImmediateValue {
U64(u64),
S64(i64),
F32(f32),
F64(f64),
}

View file

@ -1,27 +1,66 @@
use gen::derive_parser;
use logos::Logos;
use std::mem;
use std::num::{ParseFloatError, ParseIntError};
use winnow::combinator::{alt, empty, fail, opt};
use winnow::stream::SliceLen;
use winnow::token::{any, literal};
use winnow::{
error::{ContextError, ParserError},
stream::{Offset, Stream, StreamIsPartial},
PResult,
};
use winnow::{prelude::*, Stateful};
mod ast;
pub trait Operand {}
pub trait Visitor<T: Operand> {
pub trait Visitor<T> {
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMut<T: Operand> {
pub trait VisitorMut<T> {
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMap<From: Operand, To: Operand> {
pub trait VisitorMap<From, To> {
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
}
#[derive(Clone)]
pub struct MovDetails {
pub typ: Type,
pub src_is_address: bool,
// two fields below are in use by member moves
pub dst_width: u8,
pub src_width: u8,
// This is in use by auto-generated movs
pub relaxed_src2_conv: bool,
}
impl MovDetails {
pub fn new(typ: Type) -> Self {
MovDetails {
typ,
src_is_address: false,
dst_width: 0,
src_width: 0,
relaxed_src2_conv: false,
}
}
}
gen::generate_instruction_type!(
enum Instruction<T: Operand> {
enum Instruction<T> {
Mov {
type: { &data.typ },
data: MovDetails,
arguments<T>: {
dst: T,
src: T
}
},
Ld {
type: { &data.typ },
data: LdDetails,
@ -161,9 +200,212 @@ pub struct RetData {
pub uniform: bool,
}
pub struct ParsedOperand {}
type ParserState<'a, 'input> = Stateful<&'a [Token<'input>], Vec<PtxError>>;
impl Operand for ParsedOperand {}
fn ident<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<&'input str> {
any.verify_map(|t| {
if let Token::Ident(text) = t {
Some(text)
} else {
None
}
})
.parse_next(stream)
}
fn num<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<(&'input str, u32, bool)> {
any.verify_map(|t| {
Some(match t {
Token::Hex(s) => {
if s.ends_with('U') {
(&s[2..s.len() - 1], 16, true)
} else {
(&s[2..], 16, false)
}
}
Token::Decimal(s) => {
let radix = if s.starts_with('0') { 8 } else { 10 };
if s.ends_with('U') {
(&s[..s.len() - 1], radix, true)
} else {
(s, radix, false)
}
}
_ => return None,
})
})
.parse_next(stream)
}
fn take_error<'a, 'input: 'a, O, E>(
mut parser: impl Parser<ParserState<'a, 'input>, Result<O, (O, PtxError)>, E>,
) -> impl Parser<ParserState<'a, 'input>, O, E> {
move |input: &mut ParserState<'a, 'input>| {
Ok(match parser.parse_next(input)? {
Ok(x) => x,
Err((x, err)) => {
input.state.push(err);
x
}
})
}
}
fn int_immediate<'a, 'input>(input: &mut ParserState<'a, 'input>) -> PResult<ast::ImmediateValue> {
take_error((opt(Token::Minus), num).map(|(neg, x)| {
let (num, radix, is_unsigned) = x;
if neg.is_some() {
match i64::from_str_radix(num, radix) {
Ok(x) => Ok(ast::ImmediateValue::S64(-x)),
Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))),
}
} else if is_unsigned {
match u64::from_str_radix(num, radix) {
Ok(x) => Ok(ast::ImmediateValue::U64(x)),
Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))),
}
} else {
match i64::from_str_radix(num, radix) {
Ok(x) => Ok(ast::ImmediateValue::S64(x)),
Err(_) => match u64::from_str_radix(num, radix) {
Ok(x) => Ok(ast::ImmediateValue::U64(x)),
Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))),
},
}
}
}))
.parse_next(input)
}
fn f32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<f32> {
take_error(any.verify_map(|t| match t {
Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) {
Ok(x) => Ok(f32::from_bits(x)),
Err(err) => Err((0.0, PtxError::from(err))),
}),
_ => None,
}))
.parse_next(stream)
}
fn f64<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<f64> {
take_error(any.verify_map(|t| match t {
Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) {
Ok(x) => Ok(f64::from_bits(x)),
Err(err) => Err((0.0, PtxError::from(err))),
}),
_ => None,
}))
.parse_next(stream)
}
fn s32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<i32> {
take_error((opt(Token::Minus), num).map(|(sign, x)| {
let (text, radix, _) = x;
match i32::from_str_radix(text, radix) {
Ok(x) => Ok(if sign.is_some() { -x } else { x }),
Err(err) => Err((0, PtxError::from(err))),
}
}))
.parse_next(stream)
}
fn immediate_value<'a, 'input>(
stream: &mut ParserState<'a, 'input>,
) -> PResult<ast::ImmediateValue> {
alt((
int_immediate,
f32.map(ast::ImmediateValue::F32),
f64.map(ast::ImmediateValue::F64),
))
.parse_next(stream)
}
impl<Ident> ast::ParsedOperand<Ident> {
fn parse<'a, 'input>(
stream: &mut ParserState<'a, 'input>,
) -> PResult<ast::ParsedOperand<&'input str>> {
use winnow::combinator::*;
use winnow::token::any;
fn vector_index<'input>(inp: &'input str) -> Result<u8, PtxError> {
match inp {
"x" | "r" => Ok(0),
"y" | "g" => Ok(1),
"z" | "b" => Ok(2),
"w" | "a" => Ok(3),
_ => Err(PtxError::WrongVectorElement),
}
}
fn ident_operands<'a, 'input>(
stream: &mut ParserState<'a, 'input>,
) -> PResult<ast::ParsedOperand<&'input str>> {
let main_ident = ident.parse_next(stream)?;
alt((
preceded(Token::Plus, s32)
.map(move |offset| ast::ParsedOperand::RegOffset(main_ident, offset)),
take_error(preceded(Token::Dot, ident).map(move |suffix| {
let vector_index = vector_index(suffix)
.map_err(move |e| (ast::ParsedOperand::VecMember(main_ident, 0), e))?;
Ok(ast::ParsedOperand::VecMember(main_ident, vector_index))
})),
empty.value(ast::ParsedOperand::Reg(main_ident)),
))
.parse_next(stream)
}
fn vector_operand<'a, 'input>(
stream: &mut ParserState<'a, 'input>,
) -> PResult<Vec<&'input str>> {
let (_, r1, _, r2) =
(Token::LBracket, ident, Token::Comma, ident).parse_next(stream)?;
dispatch! {any;
Token::LBracket => empty.map(|_| vec![r1, r2]),
Token::Comma => (ident, Token::Comma, ident, Token::LBracket).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
_ => fail
}
.parse_next(stream)
}
alt((
ident_operands,
immediate_value.map(ast::ParsedOperand::Imm),
vector_operand.map(ast::ParsedOperand::VecPack),
))
.parse_next(stream)
}
}
#[derive(Debug, thiserror::Error)]
pub enum PtxError {
#[error("{source}")]
ParseInt {
#[from]
source: ParseIntError,
},
#[error("{source}")]
ParseFloat {
#[from]
source: ParseFloatError,
},
#[error("")]
SyntaxError,
#[error("")]
NonF32Ftz,
#[error("")]
WrongArrayType,
#[error("")]
WrongVectorElement,
#[error("")]
MultiArrayVariable,
#[error("")]
ZeroDimensionArray,
#[error("")]
ArrayInitalizer,
#[error("")]
NonExternPointer,
#[error("{start}:{end}")]
UnrecognizedStatement { start: usize, end: usize },
#[error("{start}:{end}")]
UnrecognizedDirective { start: usize, end: usize },
}
#[derive(Debug)]
struct ReverseStream<'a, T>(pub &'a [T]);
@ -180,24 +422,20 @@ where
type Checkpoint = &'i [T];
#[inline(always)]
fn iter_offsets(&self) -> Self::IterOffsets {
self.0.iter().rev().cloned().enumerate()
}
#[inline(always)]
fn eof_offset(&self) -> usize {
self.0.len()
}
#[inline(always)]
fn next_token(&mut self) -> Option<Self::Token> {
let (token, next) = self.0.split_last()?;
self.0 = next;
Some(token.clone())
}
#[inline(always)]
fn offset_for<P>(&self, predicate: P) -> Option<usize>
where
P: Fn(Self::Token) -> bool,
@ -205,7 +443,6 @@ where
self.0.iter().rev().position(|b| predicate(b.clone()))
}
#[inline(always)]
fn offset_at(&self, tokens: usize) -> Result<usize, winnow::error::Needed> {
if let Some(needed) = tokens
.checked_sub(self.0.len())
@ -217,7 +454,6 @@ where
}
}
#[inline(always)]
fn next_slice(&mut self, offset: usize) -> Self::Slice {
let offset = self.0.len() - offset;
let (next, slice) = self.0.split_at(offset);
@ -225,24 +461,20 @@ where
slice
}
#[inline(always)]
fn checkpoint(&self) -> Self::Checkpoint {
self.0
}
#[inline(always)]
fn reset(&mut self, checkpoint: &Self::Checkpoint) {
self.0 = checkpoint;
}
#[inline(always)]
fn raw(&self) -> &dyn std::fmt::Debug {
self
}
}
impl<'a, T> Offset<&'a [T]> for ReverseStream<'a, T> {
#[inline]
fn offset_from(&self, start: &&'a [T]) -> usize {
let fst = start.as_ptr();
let snd = self.0.as_ptr();
@ -267,6 +499,14 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> {
}
}
impl<'input, I: Stream<Token = Self> + StreamIsPartial, E: ParserError<I>> Parser<I, Self, E>
for Token<'input>
{
fn parse_next(&mut self, input: &mut I) -> PResult<Self, E> {
any.parse_next(input)
}
}
// Modifiers are turned into arguments to the blocks, with type:
// * If it is an alternative:
// * If it is mandatory then its type is Foo (as defined by the relevant rule)
@ -275,12 +515,16 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> {
// * If it is mandatory then it is skipped
// * If it is optional then its type is `bool`
type ParsedOperand<'input> = ast::ParsedOperand<&'input str>;
derive_parser!(
#[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)]
#[logos(skip r"\s+")]
enum Token<'input> {
#[token(",")]
Comma,
#[token(".")]
Dot,
#[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)]
Ident(&'input str),
#[token("|")]
@ -293,8 +537,18 @@ derive_parser!(
LBracket,
#[token("]")]
RBracket,
#[regex(r"[0-9]+U?")]
Decimal
#[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| lex.slice())]
F32(&'input str),
#[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())]
F64(&'input str),
#[regex(r"0[xX][0-9a-zA-Z]+U?", |lex| lex.slice())]
Hex(&'input str),
#[regex(r"[0-9]+U?", |lex| lex.slice())]
Decimal(&'input str),
#[token("-")]
Minus,
#[token("+")]
Plus,
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
@ -308,6 +562,20 @@ derive_parser!(
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ScalarType { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov.type d, a => {
Instruction::Mov{
data: MovDetails::new(type_.into()),
arguments: MovArgs { dst: d, src: a }
}
}
.type: ScalarType = { .pred,
.b16, .b32, .b64,
.u16, .u32, .u64,
.s16, .s32, .s64,
.f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st
st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
todo!()
@ -416,8 +684,15 @@ fn main() {
use winnow::token::*;
use winnow::Parser;
println!("{}", mem::size_of::<Token>());
let mut input: &[Token] = &[][..];
let x = opt(any::<_, ContextError>.verify_map(|t| { println!("MAP");Some(true) })).parse_next(&mut input).unwrap();
let x = opt(any::<_, ContextError>.verify_map(|t| {
println!("MAP");
Some(true)
}))
.parse_next(&mut input)
.unwrap();
dbg!(x);
let lexer = Token::lexer(
"
@ -429,7 +704,10 @@ fn main() {
);
let tokens = lexer.map(|t| t.unwrap()).collect::<Vec<_>>();
println!("{:?}", &tokens);
let mut stream = &tokens[..];
let mut stream = ParserState {
input: &tokens[..],
state: Vec::new(),
};
parse_instruction(&mut stream).unwrap();
//parse_prefix(&mut lexer);
let mut parser = &*tokens;