mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Parse simplest vector add kernel
This commit is contained in:
parent
91dbbb372b
commit
77de5c7a15
5 changed files with 629 additions and 17 deletions
|
@ -11,15 +11,17 @@ use syn::{
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-floating-point-data-types
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-integer-data-types
|
||||
#[rustfmt::skip]
|
||||
static POSTFIX_MODIFIERS: &[&str] = &[
|
||||
".v2", ".v4",
|
||||
".s8", ".s16", ".s32", ".s64",
|
||||
".u8", ".u16", ".u32", ".u64",
|
||||
".s8", ".s16", ".s16x2", ".s32", ".s64",
|
||||
".u8", ".u16", ".u16x2", ".u32", ".u64",
|
||||
".f16", ".f16x2", ".f32", ".f64",
|
||||
".b8", ".b16", ".b32", ".b64", ".b128",
|
||||
".pred",
|
||||
".bf16", ".e4m3", ".e5m2", ".tf32",
|
||||
".bf16", ".bf16x2", ".e4m3", ".e5m2", ".tf32",
|
||||
];
|
||||
|
||||
static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"];
|
||||
|
|
|
@ -332,7 +332,8 @@ impl DotModifier {
|
|||
capitalize = true;
|
||||
continue;
|
||||
}
|
||||
let c = if capitalize {
|
||||
// Special hack to emit `BF16`` instead of `Bf16``
|
||||
let c = if capitalize || c == 'f' && result.ends_with('B') {
|
||||
capitalize = false;
|
||||
c.to_ascii_uppercase()
|
||||
} else {
|
||||
|
|
|
@ -8,3 +8,4 @@ logos = "0.14"
|
|||
winnow = { version = "0.6.18", features = ["debug"] }
|
||||
gen = { path = "../gen" }
|
||||
thiserror = "1.0"
|
||||
bitflags = "1.2"
|
||||
|
|
|
@ -1,4 +1,31 @@
|
|||
use super::{MemScope, ScalarType, VectorPrefix, StateSpace};
|
||||
use super::{MemScope, ScalarType, StateSpace, VectorPrefix};
|
||||
use bitflags::bitflags;
|
||||
|
||||
pub enum Statement<P: Operand> {
|
||||
Label(P::Ident),
|
||||
Variable(MultiVariable<P::Ident>),
|
||||
Instruction(Option<PredAt<P::Ident>>, Instruction<P>),
|
||||
Block(Vec<Statement<P>>),
|
||||
}
|
||||
|
||||
pub struct MultiVariable<ID> {
|
||||
pub var: Variable<ID>,
|
||||
pub count: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Variable<ID> {
|
||||
pub align: Option<u32>,
|
||||
pub v_type: Type,
|
||||
pub state_space: StateSpace,
|
||||
pub name: ID,
|
||||
pub array_init: Vec<u8>,
|
||||
}
|
||||
|
||||
pub struct PredAt<ID> {
|
||||
pub not: bool,
|
||||
pub label: ID,
|
||||
}
|
||||
|
||||
gen::generate_instruction_type!(
|
||||
pub enum Instruction<T> {
|
||||
|
@ -118,6 +145,14 @@ pub enum ParsedOperand<Ident> {
|
|||
VecPack(Vec<Ident>),
|
||||
}
|
||||
|
||||
impl<Ident> Operand for ParsedOperand<Ident> {
|
||||
type Ident = Ident;
|
||||
}
|
||||
|
||||
pub trait Operand {
|
||||
type Ident;
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum ImmediateValue {
|
||||
U64(u64),
|
||||
|
@ -143,8 +178,6 @@ pub enum LdCacheOperator {
|
|||
Uncached,
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum ArithDetails {
|
||||
Integer(ArithInteger),
|
||||
|
@ -199,7 +232,6 @@ pub struct LdDetails {
|
|||
pub non_coherent: bool,
|
||||
}
|
||||
|
||||
|
||||
pub struct StData {
|
||||
pub qualifier: LdStQualifier,
|
||||
pub state_space: StateSpace,
|
||||
|
@ -211,3 +243,52 @@ pub struct StData {
|
|||
pub struct RetData {
|
||||
pub uniform: bool,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum TuningDirective {
|
||||
MaxNReg(u32),
|
||||
MaxNtid(u32, u32, u32),
|
||||
ReqNtid(u32, u32, u32),
|
||||
MinNCtaPerSm(u32),
|
||||
}
|
||||
|
||||
pub struct MethodDeclaration<'input, ID> {
|
||||
pub return_arguments: Vec<Variable<ID>>,
|
||||
pub name: MethodName<'input, ID>,
|
||||
pub input_arguments: Vec<Variable<ID>>,
|
||||
pub shared_mem: Option<ID>,
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
|
||||
pub enum MethodName<'input, ID> {
|
||||
Kernel(&'input str),
|
||||
Func(ID),
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
pub struct LinkingDirective: u8 {
|
||||
const NONE = 0b000;
|
||||
const EXTERN = 0b001;
|
||||
const VISIBLE = 0b10;
|
||||
const WEAK = 0b100;
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Function<'a, ID, S> {
|
||||
pub func_directive: MethodDeclaration<'a, ID>,
|
||||
pub tuning: Vec<TuningDirective>,
|
||||
pub body: Option<Vec<S>>,
|
||||
}
|
||||
|
||||
pub enum Directive<'input, O: Operand> {
|
||||
Variable(LinkingDirective, Variable<O::Ident>),
|
||||
Method(
|
||||
LinkingDirective,
|
||||
Function<'input, &'input str, Statement<O>>,
|
||||
),
|
||||
}
|
||||
|
||||
pub struct Module<'input> {
|
||||
pub version: (u8, u8),
|
||||
pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
|
||||
}
|
||||
|
|
|
@ -2,7 +2,10 @@ use gen::derive_parser;
|
|||
use logos::Logos;
|
||||
use std::mem;
|
||||
use std::num::{ParseFloatError, ParseIntError};
|
||||
use winnow::ascii::{dec_uint, digit1};
|
||||
use winnow::combinator::*;
|
||||
use winnow::error::ErrMode;
|
||||
use winnow::stream::Accumulate;
|
||||
use winnow::token::any;
|
||||
use winnow::{
|
||||
error::{ContextError, ParserError},
|
||||
|
@ -170,6 +173,28 @@ fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<i32> {
|
|||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn u8<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u8> {
|
||||
take_error(num.map(|x| {
|
||||
let (text, radix, _) = x;
|
||||
match u8::from_str_radix(text, radix) {
|
||||
Ok(x) => Ok(x),
|
||||
Err(err) => Err((0, PtxError::from(err))),
|
||||
}
|
||||
}))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> {
|
||||
take_error(num.map(|x| {
|
||||
let (text, radix, _) = x;
|
||||
match u32::from_str_radix(text, radix) {
|
||||
Ok(x) => Ok(x),
|
||||
Err(err) => Err((0, PtxError::from(err))),
|
||||
}
|
||||
}))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> {
|
||||
alt((
|
||||
int_immediate,
|
||||
|
@ -179,10 +204,402 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as
|
|||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn fn_body<'a, 'input>(
|
||||
fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {
|
||||
(
|
||||
version,
|
||||
target,
|
||||
opt(address_size),
|
||||
repeat_without_none(directive),
|
||||
)
|
||||
.map(|(version, _, _, directives)| ast::Module {
|
||||
version,
|
||||
directives,
|
||||
})
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn address_size<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
|
||||
(Token::DotAddressSize, u8_literal(64))
|
||||
.void()
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn version<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u8, u8)> {
|
||||
(Token::DotVersion, u8, Token::Dot, u8)
|
||||
.map(|(_, major, _, minor)| (major, minor))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn target<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, Option<char>)> {
|
||||
preceded(Token::DotTarget, ident.and_then(shader_model)).parse_next(stream)
|
||||
}
|
||||
|
||||
fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option<char>)> {
|
||||
(
|
||||
"sm_",
|
||||
dec_uint,
|
||||
opt(any.verify(|c: &char| c.is_ascii_lowercase())),
|
||||
eof,
|
||||
)
|
||||
.map(|(_, digits, arch_variant, _)| (digits, arch_variant))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn directive<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<Vec<Instruction<ParsedOperandStr<'input>>>> {
|
||||
repeat(3.., terminated(parse_instruction, Token::Semicolon)).parse_next(stream)
|
||||
) -> PResult<Option<ast::Directive<'input, ast::ParsedOperand<&'input str>>>> {
|
||||
(function.map(|f| {
|
||||
let (linking, func) = f;
|
||||
Some(ast::Directive::Method(linking, func))
|
||||
}))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn function<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<(
|
||||
ast::LinkingDirective,
|
||||
ast::Function<'input, &'input str, ast::Statement<ParsedOperand<&'input str>>>,
|
||||
)> {
|
||||
(
|
||||
linking_directives,
|
||||
method_declaration,
|
||||
repeat(0.., tuning_directive),
|
||||
function_body,
|
||||
)
|
||||
.map(|(linking, func_directive, tuning, body)| {
|
||||
(
|
||||
linking,
|
||||
ast::Function {
|
||||
func_directive,
|
||||
tuning,
|
||||
body,
|
||||
},
|
||||
)
|
||||
})
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn linking_directives<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::LinkingDirective> {
|
||||
dispatch! { any;
|
||||
Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN),
|
||||
Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE),
|
||||
Token::DotWeak => empty.value(ast::LinkingDirective::WEAK),
|
||||
_ => fail
|
||||
}
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn tuning_directive<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::TuningDirective> {
|
||||
dispatch! {any;
|
||||
Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg),
|
||||
Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)),
|
||||
Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)),
|
||||
Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm),
|
||||
_ => fail
|
||||
}
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn method_declaration<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::MethodDeclaration<'input, &'input str>> {
|
||||
dispatch! {any;
|
||||
Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{
|
||||
return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None
|
||||
}),
|
||||
Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| {
|
||||
let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
|
||||
let name = ast::MethodName::Func(name);
|
||||
ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
|
||||
}),
|
||||
_ => fail
|
||||
}
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn fn_arguments<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<Vec<ast::Variable<&'input str>>> {
|
||||
delimited(
|
||||
Token::LParen,
|
||||
separated(0.., fn_input, Token::Comma),
|
||||
Token::RParen,
|
||||
)
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn kernel_arguments<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<Vec<ast::Variable<&'input str>>> {
|
||||
delimited(
|
||||
Token::LParen,
|
||||
separated(0.., kernel_input, Token::Comma),
|
||||
Token::RParen,
|
||||
)
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn kernel_input<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::Variable<&'input str>> {
|
||||
preceded(
|
||||
Token::DotParam,
|
||||
variable_scalar_or_vector(StateSpace::Param),
|
||||
)
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
|
||||
dispatch! { any;
|
||||
Token::DotParam => variable_scalar_or_vector(StateSpace::Param),
|
||||
Token::DotReg => variable_scalar_or_vector(StateSpace::Reg),
|
||||
_ => fail
|
||||
}
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, u32, u32)> {
|
||||
struct Tuple3AccumulateU32 {
|
||||
index: usize,
|
||||
value: (u32, u32, u32),
|
||||
}
|
||||
|
||||
impl Accumulate<u32> for Tuple3AccumulateU32 {
|
||||
fn initial(_: Option<usize>) -> Self {
|
||||
Self {
|
||||
index: 0,
|
||||
value: (1, 1, 1),
|
||||
}
|
||||
}
|
||||
|
||||
fn accumulate(&mut self, value: u32) {
|
||||
match self.index {
|
||||
0 => {
|
||||
self.value = (value, self.value.1, self.value.2);
|
||||
self.index = 1;
|
||||
}
|
||||
1 => {
|
||||
self.value = (self.value.0, value, self.value.2);
|
||||
self.index = 2;
|
||||
}
|
||||
2 => {
|
||||
self.value = (self.value.0, self.value.1, value);
|
||||
self.index = 3;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..3, u32, Token::Comma)
|
||||
.map(|acc| acc.value)
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn function_body<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<Option<Vec<ast::Statement<ParsedOperandStr<'input>>>>> {
|
||||
dispatch! {any;
|
||||
Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some),
|
||||
Token::Semicolon => empty.map(|_| None),
|
||||
_ => fail
|
||||
}
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn statement<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<Option<Statement<ParsedOperandStr<'input>>>> {
|
||||
alt((
|
||||
label.map(Some),
|
||||
debug_directive.map(|_| None),
|
||||
multi_variable.map(Some),
|
||||
predicated_instruction.map(Some),
|
||||
pragma.map(|_| None),
|
||||
block_statement.map(Some),
|
||||
))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
|
||||
(Token::DotPragma, Token::String, Token::Semicolon)
|
||||
.void()
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn multi_variable<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::Statement<ParsedOperandStr<'input>>> {
|
||||
(
|
||||
variable,
|
||||
opt(delimited(Token::Lt, u32, Token::Gt)),
|
||||
Token::Semicolon,
|
||||
)
|
||||
.map(|(var, count, _)| ast::Statement::Variable(ast::MultiVariable { var, count }))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn variable<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
|
||||
dispatch! {any;
|
||||
Token::DotReg => variable_scalar_or_vector(StateSpace::Reg),
|
||||
Token::DotLocal => variable_scalar_or_vector(StateSpace::Local),
|
||||
Token::DotParam => variable_scalar_or_vector(StateSpace::Param),
|
||||
Token::DotShared => variable_scalar_or_vector(StateSpace::Shared),
|
||||
_ => fail
|
||||
}
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn variable_scalar_or_vector<'a, 'input: 'a>(
|
||||
state_space: StateSpace,
|
||||
) -> impl Parser<PtxParser<'a, 'input>, ast::Variable<&'input str>, ContextError> {
|
||||
move |stream: &mut PtxParser<'a, 'input>| {
|
||||
(opt(align), scalar_vector_type, ident)
|
||||
.map(|(align, v_type, name)| ast::Variable {
|
||||
align,
|
||||
v_type,
|
||||
state_space,
|
||||
name,
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.parse_next(stream)
|
||||
}
|
||||
}
|
||||
|
||||
fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> {
|
||||
preceded(Token::DotAlign, u32).parse_next(stream)
|
||||
}
|
||||
|
||||
fn scalar_vector_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Type> {
|
||||
(
|
||||
opt(alt((
|
||||
Token::DotV2.value(VectorPrefix::V2),
|
||||
Token::DotV4.value(VectorPrefix::V4),
|
||||
))),
|
||||
scalar_type,
|
||||
)
|
||||
.map(|(prefix, scalar)| ast::Type::maybe_vector(prefix, scalar))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> {
|
||||
any.verify_map(|t| {
|
||||
Some(match t {
|
||||
Token::DotS8 => ScalarType::S8,
|
||||
Token::DotS16 => ScalarType::S16,
|
||||
Token::DotS16x2 => ScalarType::S16x2,
|
||||
Token::DotS32 => ScalarType::S32,
|
||||
Token::DotS64 => ScalarType::S64,
|
||||
Token::DotU8 => ScalarType::U8,
|
||||
Token::DotU16 => ScalarType::U16,
|
||||
Token::DotU16x2 => ScalarType::U16x2,
|
||||
Token::DotU32 => ScalarType::U32,
|
||||
Token::DotU64 => ScalarType::U64,
|
||||
Token::DotB8 => ScalarType::B8,
|
||||
Token::DotB16 => ScalarType::B16,
|
||||
Token::DotB32 => ScalarType::B32,
|
||||
Token::DotB64 => ScalarType::B64,
|
||||
Token::DotB128 => ScalarType::B128,
|
||||
Token::DotPred => ScalarType::Pred,
|
||||
Token::DotF16 => ScalarType::F16,
|
||||
Token::DotF16x2 => ScalarType::F16x2,
|
||||
Token::DotF32 => ScalarType::F32,
|
||||
Token::DotF64 => ScalarType::F64,
|
||||
Token::DotBF16 => ScalarType::BF16,
|
||||
Token::DotBF16x2 => ScalarType::BF16x2,
|
||||
_ => return None,
|
||||
})
|
||||
})
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn predicated_instruction<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::Statement<ParsedOperandStr<'input>>> {
|
||||
(opt(pred_at), parse_instruction, Token::Semicolon)
|
||||
.map(|(p, i, _)| ast::Statement::Instruction(p, i))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn pred_at<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::PredAt<&'input str>> {
|
||||
(Token::At, opt(Token::Not), ident)
|
||||
.map(|(_, not, label)| ast::PredAt {
|
||||
not: not.is_some(),
|
||||
label,
|
||||
})
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn label<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::Statement<ParsedOperandStr<'input>>> {
|
||||
terminated(ident, Token::Colon)
|
||||
.map(|l| ast::Statement::Label(l))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
|
||||
(
|
||||
Token::DotLoc,
|
||||
u32,
|
||||
u32,
|
||||
u32,
|
||||
opt((
|
||||
Token::Comma,
|
||||
ident_literal("function_name"),
|
||||
ident,
|
||||
dispatch! { any;
|
||||
Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(),
|
||||
Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(),
|
||||
_ => fail
|
||||
},
|
||||
)),
|
||||
)
|
||||
.void()
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn block_statement<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::Statement<ParsedOperandStr<'input>>> {
|
||||
delimited(Token::LBrace, repeat_without_none(statement), Token::RBrace)
|
||||
.map(|s| ast::Statement::Block(s))
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn repeat_without_none<Input: Stream, Output, Error: ParserError<Input>>(
|
||||
parser: impl Parser<Input, Option<Output>, Error>,
|
||||
) -> impl Parser<Input, Vec<Output>, Error> {
|
||||
repeat(0.., parser).fold(Vec::new, |mut acc: Vec<_>, item| {
|
||||
if let Some(item) = item {
|
||||
acc.push(item);
|
||||
}
|
||||
acc
|
||||
})
|
||||
}
|
||||
|
||||
fn ident_literal<
|
||||
'a,
|
||||
'input,
|
||||
I: Stream<Token = Token<'input>> + StreamIsPartial,
|
||||
E: ParserError<I>,
|
||||
>(
|
||||
s: &'input str,
|
||||
) -> impl Parser<I, (), E> + 'input {
|
||||
move |stream: &mut I| {
|
||||
any.verify(|t| matches!(t, Token::Ident(text) if *text == s))
|
||||
.void()
|
||||
.parse_next(stream)
|
||||
}
|
||||
}
|
||||
|
||||
fn u8_literal<'a, 'input>(x: u8) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> {
|
||||
move |stream: &mut PtxParser| u8.verify(|t| *t == x).void().parse_next(stream)
|
||||
}
|
||||
|
||||
impl<Ident> ast::ParsedOperand<Ident> {
|
||||
|
@ -391,18 +808,36 @@ derive_parser!(
|
|||
Comma,
|
||||
#[token(".")]
|
||||
Dot,
|
||||
#[token(":")]
|
||||
Colon,
|
||||
#[token(";")]
|
||||
Semicolon,
|
||||
#[token("@")]
|
||||
At,
|
||||
#[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)]
|
||||
Ident(&'input str),
|
||||
#[regex(r#""[^"]*""#)]
|
||||
String,
|
||||
#[token("|")]
|
||||
Or,
|
||||
#[token("!")]
|
||||
Not,
|
||||
#[token("(")]
|
||||
LParen,
|
||||
#[token(")")]
|
||||
RParen,
|
||||
#[token("[")]
|
||||
LBracket,
|
||||
#[token("]")]
|
||||
RBracket,
|
||||
#[token("{")]
|
||||
LBrace,
|
||||
#[token("}")]
|
||||
RBrace,
|
||||
#[token("<")]
|
||||
Lt,
|
||||
#[token(">")]
|
||||
Gt,
|
||||
#[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| lex.slice())]
|
||||
F32(&'input str),
|
||||
#[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())]
|
||||
|
@ -415,6 +850,36 @@ derive_parser!(
|
|||
Minus,
|
||||
#[token("+")]
|
||||
Plus,
|
||||
#[token(".version")]
|
||||
DotVersion,
|
||||
#[token(".loc")]
|
||||
DotLoc,
|
||||
#[token(".reg")]
|
||||
DotReg,
|
||||
#[token(".align")]
|
||||
DotAlign,
|
||||
#[token(".pragma")]
|
||||
DotPragma,
|
||||
#[token(".maxnreg")]
|
||||
DotMaxnreg,
|
||||
#[token(".maxntid")]
|
||||
DotMaxntid,
|
||||
#[token(".reqntid")]
|
||||
DotReqntid,
|
||||
#[token(".minnctapersm")]
|
||||
DotMinnctapersm,
|
||||
#[token(".entry")]
|
||||
DotEntry,
|
||||
#[token(".func")]
|
||||
DotFunc,
|
||||
#[token(".extern")]
|
||||
DotExtern,
|
||||
#[token(".visible")]
|
||||
DotVisible,
|
||||
#[token(".target")]
|
||||
DotTarget,
|
||||
#[token(".address_size")]
|
||||
DotAddressSize
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
|
@ -771,10 +1236,29 @@ fn main() {
|
|||
dbg!(x);
|
||||
let lexer = Token::lexer(
|
||||
"
|
||||
ld.u64 temp, [in_addr];
|
||||
add.u64 temp2, temp, 1;
|
||||
st.u64 [out_addr], temp2;
|
||||
ret;
|
||||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry add(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp;
|
||||
.reg .u64 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.u64 temp, [in_addr];
|
||||
add.u64 temp2, temp, 1;
|
||||
st.u64 [out_addr], temp2;
|
||||
ret;
|
||||
}
|
||||
|
||||
",
|
||||
);
|
||||
let tokens = lexer.map(|t| t.unwrap()).collect::<Vec<_>>();
|
||||
|
@ -783,7 +1267,50 @@ fn main() {
|
|||
input: &tokens[..],
|
||||
state: Vec::new(),
|
||||
};
|
||||
let fn_body = fn_body.parse(stream).unwrap();
|
||||
println!("{}", fn_body.len());
|
||||
let module_ = module.parse(stream).unwrap();
|
||||
println!("{}", mem::size_of::<Token>());
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::target;
|
||||
use super::Token;
|
||||
use logos::Logos;
|
||||
use winnow::prelude::*;
|
||||
|
||||
#[test]
|
||||
fn sm_11() {
|
||||
let tokens = Token::lexer(".target sm_11")
|
||||
.collect::<Result<Vec<_>, ()>>()
|
||||
.unwrap();
|
||||
let stream = super::PtxParser {
|
||||
input: &tokens[..],
|
||||
state: Vec::new(),
|
||||
};
|
||||
assert_eq!(target.parse(stream).unwrap(), (11, None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sm_90a() {
|
||||
let tokens = Token::lexer(".target sm_90a")
|
||||
.collect::<Result<Vec<_>, ()>>()
|
||||
.unwrap();
|
||||
let stream = super::PtxParser {
|
||||
input: &tokens[..],
|
||||
state: Vec::new(),
|
||||
};
|
||||
assert_eq!(target.parse(stream).unwrap(), (90, Some('a')));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sm_90ab() {
|
||||
let tokens = Token::lexer(".target sm_90ab")
|
||||
.collect::<Result<Vec<_>, ()>>()
|
||||
.unwrap();
|
||||
let stream = super::PtxParser {
|
||||
input: &tokens[..],
|
||||
state: Vec::new(),
|
||||
};
|
||||
assert!(target.parse(stream).is_err());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue