Move all types to a separate module

This commit is contained in:
Andrzej Janik 2024-08-16 18:29:13 +02:00
commit 91dbbb372b
4 changed files with 168 additions and 199 deletions

View file

@ -1,4 +1,113 @@
use super::MemScope;
use super::{MemScope, ScalarType, VectorPrefix, StateSpace};
gen::generate_instruction_type!(
pub enum Instruction<T> {
Mov {
type: { &data.typ },
data: MovDetails,
arguments<T>: {
dst: T,
src: T
}
},
Ld {
type: { &data.typ },
data: LdDetails,
arguments<T>: {
dst: T,
src: {
repr: T,
space: { data.state_space },
}
}
},
Add {
type: { data.type_().into() },
data: ArithDetails,
arguments<T>: {
dst: T,
src1: T,
src2: T,
}
},
St {
type: { &data.typ },
data: StData,
arguments<T>: {
src1: {
repr: T,
space: { data.state_space },
},
src2: T,
}
},
Ret {
data: RetData
},
Trap { }
}
);
pub trait Visitor<T> {
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMut<T> {
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMap<From, To> {
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
}
#[derive(PartialEq, Eq, Clone, Hash)]
pub enum Type {
// .param.b32 foo;
Scalar(ScalarType),
// .param.v2.b32 foo;
Vector(ScalarType, u8),
// .param.b32 foo[4];
Array(ScalarType, Vec<u32>),
}
impl Type {
pub(crate) fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
match vector {
Some(VectorPrefix::V2) => Type::Vector(scalar, 2),
Some(VectorPrefix::V4) => Type::Vector(scalar, 4),
None => Type::Scalar(scalar),
}
}
}
impl From<ScalarType> for Type {
fn from(value: ScalarType) -> Self {
Type::Scalar(value)
}
}
#[derive(Clone)]
pub struct MovDetails {
pub typ: super::Type,
pub src_is_address: bool,
// two fields below are in use by member moves
pub dst_width: u8,
pub src_width: u8,
// This is in use by auto-generated movs
pub relaxed_src2_conv: bool,
}
impl MovDetails {
pub(crate) fn new(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
MovDetails {
typ: Type::maybe_vector(vector, scalar),
src_is_address: false,
dst_width: 0,
src_width: 0,
relaxed_src2_conv: false,
}
}
}
#[derive(Clone)]
pub enum ParsedOperand<Ident> {
@ -81,3 +190,24 @@ pub enum RoundingMode {
NegativeInf,
PositiveInf,
}
pub struct LdDetails {
pub qualifier: LdStQualifier,
pub state_space: StateSpace,
pub caching: LdCacheOperator,
pub typ: Type,
pub non_coherent: bool,
}
pub struct StData {
pub qualifier: LdStQualifier,
pub state_space: StateSpace,
pub caching: StCacheOperator,
pub typ: Type,
}
#[derive(Copy, Clone)]
pub struct RetData {
pub uniform: bool,
}

View file

@ -12,176 +12,7 @@ use winnow::{
use winnow::{prelude::*, Stateful};
mod ast;
pub trait Operand {}
pub trait Visitor<T> {
fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMut<T> {
fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool);
}
pub trait VisitorMap<From, To> {
fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To;
}
#[derive(Clone)]
pub struct MovDetails {
pub typ: Type,
pub src_is_address: bool,
// two fields below are in use by member moves
pub dst_width: u8,
pub src_width: u8,
// This is in use by auto-generated movs
pub relaxed_src2_conv: bool,
}
impl MovDetails {
fn new(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
MovDetails {
typ: Type::maybe_vector(vector, scalar),
src_is_address: false,
dst_width: 0,
src_width: 0,
relaxed_src2_conv: false,
}
}
}
gen::generate_instruction_type!(
enum Instruction<T> {
Mov {
type: { &data.typ },
data: MovDetails,
arguments<T>: {
dst: T,
src: T
}
},
Ld {
type: { &data.typ },
data: LdDetails,
arguments<T>: {
dst: T,
src: {
repr: T,
space: { data.state_space },
}
}
},
Add {
type: { data.type_().into() },
data: ast::ArithDetails,
arguments<T>: {
dst: T,
src1: T,
src2: T,
}
},
St {
type: { &data.typ },
data: StData,
arguments<T>: {
src1: {
repr: T,
space: { data.state_space },
},
src2: T,
}
},
Ret {
data: RetData
},
Trap { }
}
);
pub struct LdDetails {
pub qualifier: ast::LdStQualifier,
pub state_space: StateSpace,
pub caching: ast::LdCacheOperator,
pub typ: Type,
pub non_coherent: bool,
}
#[derive(Copy, Clone)]
pub enum ArithDetails {
Unsigned(ScalarType),
Signed(ArithSInt),
Float(ArithFloat),
}
impl ArithDetails {
fn type_(&self) -> ScalarType {
match self {
ArithDetails::Unsigned(t) => *t,
ArithDetails::Signed(arith) => arith.typ,
ArithDetails::Float(arith) => arith.typ,
}
}
}
#[derive(Copy, Clone)]
pub struct ArithSInt {
pub typ: ScalarType,
pub saturate: bool,
}
#[derive(Copy, Clone)]
pub struct ArithFloat {
pub typ: ScalarType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: Option<bool>,
pub saturate: bool,
}
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum RoundingMode {
NearestEven,
Zero,
NegativeInf,
PositiveInf,
}
#[derive(PartialEq, Eq, Clone, Hash)]
pub enum Type {
// .param.b32 foo;
Scalar(ScalarType),
// .param.v2.b32 foo;
Vector(ScalarType, u8),
// .param.b32 foo[4];
Array(ScalarType, Vec<u32>),
}
impl Type {
fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
match vector {
Some(VectorPrefix::V2) => Type::Vector(scalar, 2),
Some(VectorPrefix::V4) => Type::Vector(scalar, 4),
None => Type::Scalar(scalar),
}
}
}
impl From<ScalarType> for Type {
fn from(value: ScalarType) -> Self {
Type::Scalar(value)
}
}
pub struct StData {
pub qualifier: ast::LdStQualifier,
pub state_space: StateSpace,
pub caching: ast::StCacheOperator,
pub typ: Type,
}
#[derive(Copy, Clone)]
pub struct RetData {
pub uniform: bool,
}
pub use ast::*;
impl From<RawStCacheOperator> for ast::StCacheOperator {
fn from(value: RawStCacheOperator) -> Self {
@ -350,7 +181,7 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as
fn fn_body<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<Vec<Instruction<ParsedOperand<'input>>>> {
) -> PResult<Vec<Instruction<ParsedOperandStr<'input>>>> {
repeat(3.., terminated(parse_instruction, Token::Semicolon)).parse_next(stream)
}
@ -550,7 +381,7 @@ impl<'input, I: Stream<Token = Self> + StreamIsPartial, E: ParserError<I>> Parse
// * If it is mandatory then it is skipped
// * If it is optional then its type is `bool`
type ParsedOperand<'input> = ast::ParsedOperand<&'input str>;
type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>;
derive_parser!(
#[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)]
@ -601,7 +432,7 @@ derive_parser!(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => {
Instruction::Mov {
data: MovDetails::new(vec, type_),
data: ast::MovDetails::new(vec, type_),
arguments: MovArgs { dst: d, src: a },
}
}
@ -622,7 +453,7 @@ derive_parser!(
qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: cop.unwrap_or(RawStCacheOperator::Wb).into(),
typ: Type::maybe_vector(vec, type_)
typ: ast::Type::maybe_vector(vec, type_)
},
arguments: StArgs { src1:a, src2:b }
}
@ -633,7 +464,7 @@ derive_parser!(
qualifier: volatile.into(),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::StCacheOperator::Writeback,
typ: Type::maybe_vector(vec, type_)
typ: ast::Type::maybe_vector(vec, type_)
},
arguments: StArgs { src1:a, src2:b }
}
@ -647,7 +478,7 @@ derive_parser!(
qualifier: ast::LdStQualifier::Relaxed(scope),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::StCacheOperator::Writeback,
typ: Type::maybe_vector(vec, type_)
typ: ast::Type::maybe_vector(vec, type_)
},
arguments: StArgs { src1:a, src2:b }
}
@ -661,7 +492,7 @@ derive_parser!(
qualifier: ast::LdStQualifier::Release(scope),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::StCacheOperator::Writeback,
typ: Type::maybe_vector(vec, type_)
typ: ast::Type::maybe_vector(vec, type_)
},
arguments: StArgs { src1:a, src2:b }
}
@ -669,13 +500,13 @@ derive_parser!(
st.mmio.relaxed.sys{.global}.type [a], b => {
state.push(PtxError::Todo);
Instruction::St {
data: StData {
data: ast::StData {
qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys),
state_space: global.unwrap_or(StateSpace::Generic),
caching: ast::StCacheOperator::Writeback,
typ: type_.into()
},
arguments: StArgs { src1:a, src2:b }
arguments: ast::StArgs { src1:a, src2:b }
}
}
@ -704,7 +535,7 @@ derive_parser!(
qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(),
typ: Type::maybe_vector(vec, type_),
typ: ast::Type::maybe_vector(vec, type_),
non_coherent: false
},
arguments: LdArgs { dst:d, src:a }
@ -719,7 +550,7 @@ derive_parser!(
qualifier: volatile.into(),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::LdCacheOperator::Cached,
typ: Type::maybe_vector(vec, type_),
typ: ast::Type::maybe_vector(vec, type_),
non_coherent: false
},
arguments: LdArgs { dst:d, src:a }
@ -734,7 +565,7 @@ derive_parser!(
qualifier: ast::LdStQualifier::Relaxed(scope),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::LdCacheOperator::Cached,
typ: Type::maybe_vector(vec, type_),
typ: ast::Type::maybe_vector(vec, type_),
non_coherent: false
},
arguments: LdArgs { dst:d, src:a }
@ -749,7 +580,7 @@ derive_parser!(
qualifier: ast::LdStQualifier::Acquire(scope),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::LdCacheOperator::Cached,
typ: Type::maybe_vector(vec, type_),
typ: ast::Type::maybe_vector(vec, type_),
non_coherent: false
},
arguments: LdArgs { dst:d, src:a }
@ -931,7 +762,7 @@ fn main() {
println!("{}", mem::size_of::<Token>());
let mut input: &[Token] = &[][..];
let x = opt(any::<_, ContextError>.verify_map(|t| {
let x = opt(any::<_, ContextError>.verify_map(|_| {
println!("MAP");
Some(true)
}))
@ -948,13 +779,11 @@ fn main() {
);
let tokens = lexer.map(|t| t.unwrap()).collect::<Vec<_>>();
println!("{:?}", &tokens);
let mut stream = PtxParser {
let stream = PtxParser {
input: &tokens[..],
state: Vec::new(),
};
let fn_body = fn_body.parse(stream).unwrap();
println!("{}", fn_body.len());
//parse_prefix(&mut lexer);
let mut parser = &*tokens;
println!("{}", mem::size_of::<Token>());
}