mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Simplify error handling during ast construction
This commit is contained in:
parent
bbe993392b
commit
6f4530fe83
4 changed files with 191 additions and 195 deletions
150
ptx/src/ast.rs
150
ptx/src/ast.rs
|
@ -1,4 +1,5 @@
|
|||
use std::convert::From;
|
||||
use std::convert::Into;
|
||||
use std::error::Error;
|
||||
use std::mem;
|
||||
use std::num::ParseIntError;
|
||||
|
@ -6,110 +7,42 @@ use std::num::ParseIntError;
|
|||
quick_error! {
|
||||
#[derive(Debug)]
|
||||
pub enum PtxError {
|
||||
Parse (err: ParseIntError) {
|
||||
ParseInt (err: ParseIntError) {
|
||||
from()
|
||||
display("{}", err)
|
||||
cause(err)
|
||||
from()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WithErrors<T, E> {
|
||||
pub value: T,
|
||||
pub errors: Vec<E>,
|
||||
pub trait UnwrapWithVec<E, To> {
|
||||
fn unwrap_with(self, errs: &mut Vec<E>) -> To;
|
||||
}
|
||||
|
||||
impl<T, E> WithErrors<T, E> {
|
||||
pub fn new(t: T) -> Self {
|
||||
WithErrors {
|
||||
value: t,
|
||||
errors: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map<F: FnOnce(T) -> U, U>(self, f: F) -> WithErrors<U, E> {
|
||||
WithErrors {
|
||||
value: f(self.value),
|
||||
errors: self.errors,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map2<X, Y, F: FnOnce(X, Y) -> T>(
|
||||
x: WithErrors<X, E>,
|
||||
y: WithErrors<Y, E>,
|
||||
f: F,
|
||||
) -> Self {
|
||||
let mut errors = x.errors;
|
||||
let mut errors_other = y.errors;
|
||||
if errors.len() < errors_other.len() {
|
||||
mem::swap(&mut errors, &mut errors_other);
|
||||
}
|
||||
errors.extend(errors_other);
|
||||
WithErrors {
|
||||
value: f(x.value, y.value),
|
||||
errors: errors,
|
||||
}
|
||||
impl<R: Default, EFrom: std::convert::Into<EInto>, EInto> UnwrapWithVec<EInto, R>
|
||||
for Result<R, EFrom>
|
||||
{
|
||||
fn unwrap_with(self, errs: &mut Vec<EInto>) -> R {
|
||||
self.unwrap_or_else(|e| {
|
||||
errs.push(e.into());
|
||||
R::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T:Default, E: Error> WithErrors<T, E> {
|
||||
pub fn from_results<X: Default, Y: Default, F: FnOnce(X, Y) -> T>(
|
||||
x: Result<X, E>,
|
||||
y: Result<Y, E>,
|
||||
f: F,
|
||||
) -> Self {
|
||||
match (x, y) {
|
||||
(Ok(x), Ok(y)) => WithErrors {
|
||||
value: f(x, y),
|
||||
errors: Vec::new(),
|
||||
},
|
||||
(Err(e), Ok(y)) => WithErrors {
|
||||
value: f(X::default(), y),
|
||||
errors: vec![e],
|
||||
},
|
||||
(Ok(x), Err(e)) => WithErrors {
|
||||
value: f(x, Y::default()),
|
||||
errors: vec![e],
|
||||
},
|
||||
(Err(e1), Err(e2)) => WithErrors {
|
||||
value: T::default(),
|
||||
errors: vec![e1, e2],
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, E: Error> WithErrors<Vec<T>, E> {
|
||||
pub fn from_vec(v: Vec<WithErrors<T, E>>) -> Self {
|
||||
let mut values = Vec::with_capacity(v.len());
|
||||
let mut errors = Vec::new();
|
||||
for we in v.into_iter() {
|
||||
values.push(we.value);
|
||||
errors.extend(we.errors);
|
||||
}
|
||||
WithErrors {
|
||||
value: values,
|
||||
errors: errors,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WithErrorsExt<From, To, E> {
|
||||
fn with_errors<F: FnOnce(From) -> To>(self, f: F) -> WithErrors<To, E>;
|
||||
}
|
||||
|
||||
impl<From, To: Default, E> WithErrorsExt<From, To, E> for Result<From, E> {
|
||||
fn with_errors<F: FnOnce(From) -> To>(self, f: F) -> WithErrors<To, E> {
|
||||
self.map_or_else(
|
||||
|e| WithErrors {
|
||||
value: To::default(),
|
||||
errors: vec![e],
|
||||
},
|
||||
|t| WithErrors {
|
||||
value: f(t),
|
||||
errors: Vec::new(),
|
||||
},
|
||||
)
|
||||
impl<
|
||||
R1: Default,
|
||||
EFrom1: std::convert::Into<EInto>,
|
||||
R2: Default,
|
||||
EFrom2: std::convert::Into<EInto>,
|
||||
EInto,
|
||||
> UnwrapWithVec<EInto, (R1, R2)> for (Result<R1, EFrom1>, Result<R2, EFrom2>)
|
||||
{
|
||||
fn unwrap_with(self, errs: &mut Vec<EInto>) -> (R1, R2) {
|
||||
let (x, y) = self;
|
||||
let r1 = x.unwrap_with(errs);
|
||||
let r2 = y.unwrap_with(errs);
|
||||
(r1, r2)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -132,6 +65,13 @@ pub struct Argument<'a> {
|
|||
pub length: u32,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum Type {
|
||||
Scalar(ScalarType),
|
||||
ExtendedScalar(ExtendedScalarType),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum ScalarType {
|
||||
B8,
|
||||
B16,
|
||||
|
@ -150,6 +90,12 @@ pub enum ScalarType {
|
|||
F64,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum ExtendedScalarType {
|
||||
F16x2,
|
||||
Pred,
|
||||
}
|
||||
|
||||
impl Default for ScalarType {
|
||||
fn default() -> Self {
|
||||
ScalarType::B8
|
||||
|
@ -158,11 +104,25 @@ impl Default for ScalarType {
|
|||
|
||||
pub enum Statement<'a> {
|
||||
Label(&'a str),
|
||||
Variable(Variable),
|
||||
Variable(Variable<'a>),
|
||||
Instruction(Instruction),
|
||||
}
|
||||
|
||||
pub struct Variable {}
|
||||
pub struct Variable<'a> {
|
||||
pub space: StateSpace,
|
||||
pub v_type: Type,
|
||||
pub name: &'a str,
|
||||
pub count: Option<u32>,
|
||||
}
|
||||
|
||||
pub enum StateSpace {
|
||||
Reg,
|
||||
Sreg,
|
||||
Const,
|
||||
Global,
|
||||
Local,
|
||||
Shared,
|
||||
}
|
||||
|
||||
pub enum Instruction {
|
||||
Ld,
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use crate::ast;
|
||||
use crate::ast::{WithErrors, WithErrorsExt};
|
||||
use crate::ast::UnwrapWithVec;
|
||||
use crate::without_none;
|
||||
|
||||
grammar;
|
||||
grammar<'a>(errors: &mut Vec<ast::PtxError>);
|
||||
|
||||
match {
|
||||
r"\s+" => { },
|
||||
|
@ -16,23 +16,18 @@ match {
|
|||
_
|
||||
}
|
||||
|
||||
pub Module: WithErrors<ast::Module<'input>, ast::PtxError> = {
|
||||
pub Module: ast::Module<'input> = {
|
||||
<v:Version> Target <f:Directive*> => {
|
||||
let funcs = WithErrors::from_vec(without_none(f));
|
||||
WithErrors::map2(v, funcs,
|
||||
|v, funcs| ast::Module { version: v, functions: funcs }
|
||||
)
|
||||
ast::Module { version: v, functions: without_none(f) }
|
||||
}
|
||||
};
|
||||
|
||||
Version: WithErrors<(u8, u8), ast::PtxError> = {
|
||||
Version: (u8, u8) = {
|
||||
".version" <v:VersionNumber> => {
|
||||
let dot = v.find('.').unwrap();
|
||||
let major = v[..dot].parse::<u8>().map_err(Into::into);
|
||||
let minor = v[dot+1..].parse::<u8>().map_err(Into::into);
|
||||
WithErrors::from_results(major, minor,
|
||||
|major, minor| (major, minor)
|
||||
)
|
||||
let major = v[..dot].parse::<u8>();
|
||||
let minor = v[dot+1..].parse::<u8>();
|
||||
(major,minor).unwrap_with(errors)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,7 +44,7 @@ TargetSpecifier = {
|
|||
"map_f64_to_f32"
|
||||
};
|
||||
|
||||
Directive: Option<WithErrors<ast::Function<'input>, ast::PtxError>> = {
|
||||
Directive: Option<ast::Function<'input>> = {
|
||||
AddressSize => None,
|
||||
<f:Function> => Some(f),
|
||||
File => None,
|
||||
|
@ -60,11 +55,12 @@ AddressSize = {
|
|||
".address_size" Num
|
||||
};
|
||||
|
||||
Function: WithErrors<ast::Function<'input>, ast::PtxError> = {
|
||||
LinkingDirective* <k:IsKernel> <n:ID> "(" <args:Comma<FunctionInput>> ")" <b:FunctionBody> => {
|
||||
WithErrors::from_vec(args)
|
||||
.map(|args| ast::Function{kernel: k, name: n, args: args, body: b})
|
||||
}
|
||||
Function: ast::Function<'input> = {
|
||||
LinkingDirective*
|
||||
<kernel:IsKernel>
|
||||
<name:ID>
|
||||
"(" <args:Comma<FunctionInput>> ")"
|
||||
<body:FunctionBody> => ast::Function{<>}
|
||||
};
|
||||
|
||||
LinkingDirective = {
|
||||
|
@ -79,15 +75,14 @@ IsKernel: bool = {
|
|||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||
FunctionInput: WithErrors<ast::Argument<'input>, ast::PtxError> = {
|
||||
FunctionInput: ast::Argument<'input> = {
|
||||
".param" <_type:ScalarType> <name:ID> => {
|
||||
WithErrors::new(ast::Argument {a_type: _type, name: name, length: 1 })
|
||||
ast::Argument {a_type: _type, name: name, length: 1 }
|
||||
},
|
||||
".param" <a_type:ScalarType> <name:ID> "[" <length:Num> "]" => {
|
||||
let length = length.parse::<u32>().map_err(Into::into);
|
||||
length.with_errors(
|
||||
|l| ast::Argument { a_type: a_type, name: name, length: l }
|
||||
)
|
||||
let length = length.parse::<u32>();
|
||||
let length = length.unwrap_with(errors);
|
||||
ast::Argument { a_type: a_type, name: name, length: length }
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -95,13 +90,19 @@ FunctionBody: Vec<ast::Statement<'input>> = {
|
|||
"{" <s:Statement*> "}" => { without_none(s) }
|
||||
};
|
||||
|
||||
StateSpaceSpecifier = {
|
||||
".reg",
|
||||
".sreg",
|
||||
".const",
|
||||
".global",
|
||||
".local",
|
||||
".shared"
|
||||
StateSpaceSpecifier: ast::StateSpace = {
|
||||
".reg" => ast::StateSpace::Reg,
|
||||
".sreg" => ast::StateSpace::Sreg,
|
||||
".const" => ast::StateSpace::Const,
|
||||
".global" => ast::StateSpace::Global,
|
||||
".local" => ast::StateSpace::Local,
|
||||
".shared" => ast::StateSpace::Shared
|
||||
};
|
||||
|
||||
|
||||
Type: ast::Type = {
|
||||
<t:ScalarType> => ast::Type::Scalar(t),
|
||||
<t:ExtendedScalarType> => ast::Type::ExtendedScalar(t),
|
||||
};
|
||||
|
||||
ScalarType: ast::ScalarType = {
|
||||
|
@ -122,12 +123,9 @@ ScalarType: ast::ScalarType = {
|
|||
".f64" => ast::ScalarType::F64,
|
||||
};
|
||||
|
||||
|
||||
Type = {
|
||||
BaseType,
|
||||
".pred",
|
||||
".f16",
|
||||
".f16x2",
|
||||
ExtendedScalarType: ast::ExtendedScalarType = {
|
||||
".f16x2" => ast::ExtendedScalarType::F16x2,
|
||||
".pred" => ast::ExtendedScalarType::Pred,
|
||||
};
|
||||
|
||||
BaseType = {
|
||||
|
@ -157,13 +155,24 @@ Label: &'input str = {
|
|||
<id:ID> ":" => id
|
||||
};
|
||||
|
||||
Variable: ast::Variable = {
|
||||
StateSpaceSpecifier Type VariableName => ast::Variable {}
|
||||
Variable: ast::Variable<'input> = {
|
||||
<s:StateSpaceSpecifier> <t:Type> <v:VariableName> => {
|
||||
let (name, count) = v;
|
||||
ast::Variable { space: s, v_type: t, name: name, count: count }
|
||||
}
|
||||
};
|
||||
|
||||
VariableName = {
|
||||
ID,
|
||||
ParametrizedID
|
||||
VariableName: (&'input str, Option<u32>) = {
|
||||
<id:ID> => (id, None),
|
||||
<id:ParametrizedID> => {
|
||||
let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap();
|
||||
let count = id[left_angle+1..id.len()-1].parse::<u32>();
|
||||
let count = match count {
|
||||
Ok(c) => Some(c),
|
||||
Err(e) => { errors.push(e.into()); None },
|
||||
};
|
||||
(&id[0..left_angle], count)
|
||||
}
|
||||
};
|
||||
|
||||
Instruction = {
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
use super::ptx;
|
||||
|
||||
fn parse_and_assert(s: &str) {
|
||||
assert!(
|
||||
ptx::ModuleParser::new()
|
||||
.parse(s)
|
||||
.unwrap()
|
||||
.errors
|
||||
.len() == 0);
|
||||
let mut errors = Vec::new();
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
|
||||
assert!(errors.len() == 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -18,4 +15,4 @@ fn empty() {
|
|||
fn vector_add() {
|
||||
let vector_add = include_str!("vectorAdd_kernel64.ptx");
|
||||
parse_and_assert(vector_add);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,35 +1,17 @@
|
|||
use crate::ast;
|
||||
use rspirv::dr;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub struct TranslationError {
|
||||
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
enum SpirvType {
|
||||
Base(BaseType),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
enum BaseType {
|
||||
Int8,
|
||||
Int16,
|
||||
Int32,
|
||||
Int64,
|
||||
Uint8,
|
||||
Uint16,
|
||||
Uint32,
|
||||
Uint64,
|
||||
Float16,
|
||||
Float32,
|
||||
Float64,
|
||||
Base(ast::ScalarType),
|
||||
}
|
||||
|
||||
struct TypeWordMap {
|
||||
void: spirv::Word,
|
||||
fn_void: spirv::Word,
|
||||
complex: HashMap<SpirvType, spirv::Word>
|
||||
complex: HashMap<SpirvType, spirv::Word>,
|
||||
}
|
||||
|
||||
impl TypeWordMap {
|
||||
|
@ -38,34 +20,57 @@ impl TypeWordMap {
|
|||
TypeWordMap {
|
||||
void: void,
|
||||
fn_void: b.type_function(void, vec![]),
|
||||
complex: HashMap::<SpirvType, spirv::Word>::new()
|
||||
complex: HashMap::<SpirvType, spirv::Word>::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn void(&self) -> spirv::Word { self.void }
|
||||
fn fn_void(&self) -> spirv::Word { self.fn_void }
|
||||
fn void(&self) -> spirv::Word {
|
||||
self.void
|
||||
}
|
||||
fn fn_void(&self) -> spirv::Word {
|
||||
self.fn_void
|
||||
}
|
||||
|
||||
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
|
||||
*self.complex.entry(t).or_insert_with(|| {
|
||||
match t {
|
||||
SpirvType::Base(BaseType::Int8) => b.type_int(8, 1),
|
||||
SpirvType::Base(BaseType::Int16) => b.type_int(16, 1),
|
||||
SpirvType::Base(BaseType::Int32) => b.type_int(32, 1),
|
||||
SpirvType::Base(BaseType::Int64) => b.type_int(64, 1),
|
||||
SpirvType::Base(BaseType::Uint8) => b.type_int(8, 0),
|
||||
SpirvType::Base(BaseType::Uint16) => b.type_int(16, 0),
|
||||
SpirvType::Base(BaseType::Uint32) => b.type_int(32, 0),
|
||||
SpirvType::Base(BaseType::Uint64) => b.type_int(64, 0),
|
||||
SpirvType::Base(BaseType::Float16) => b.type_float(16),
|
||||
SpirvType::Base(BaseType::Float32) => b.type_float(32),
|
||||
SpirvType::Base(BaseType::Float64) => b.type_float(64),
|
||||
*self.complex.entry(t).or_insert_with(|| match t {
|
||||
SpirvType::Base(ast::ScalarType::B8) | SpirvType::Base(ast::ScalarType::U8) => {
|
||||
b.type_int(8, 0)
|
||||
}
|
||||
SpirvType::Base(ast::ScalarType::B16) | SpirvType::Base(ast::ScalarType::U16) => {
|
||||
b.type_int(16, 0)
|
||||
}
|
||||
SpirvType::Base(ast::ScalarType::B32) | SpirvType::Base(ast::ScalarType::U32) => {
|
||||
b.type_int(32, 0)
|
||||
}
|
||||
SpirvType::Base(ast::ScalarType::B64) | SpirvType::Base(ast::ScalarType::U64) => {
|
||||
b.type_int(64, 0)
|
||||
}
|
||||
SpirvType::Base(ast::ScalarType::S8) => b.type_int(8, 1),
|
||||
SpirvType::Base(ast::ScalarType::S16) => b.type_int(16, 1),
|
||||
SpirvType::Base(ast::ScalarType::S32) => b.type_int(32, 1),
|
||||
SpirvType::Base(ast::ScalarType::S64) => b.type_int(64, 1),
|
||||
SpirvType::Base(ast::ScalarType::F16) => b.type_float(16),
|
||||
SpirvType::Base(ast::ScalarType::F32) => b.type_float(32),
|
||||
SpirvType::Base(ast::ScalarType::F64) => b.type_float(64),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, TranslationError> {
|
||||
struct IdWordMap<'a>(HashMap<&'a str, spirv::Word>);
|
||||
|
||||
impl<'a> IdWordMap<'a> {
|
||||
fn new() -> Self { IdWordMap(HashMap::new()) }
|
||||
}
|
||||
|
||||
impl<'a> IdWordMap<'a> {
|
||||
fn get_or_add(&mut self, b: &mut dr::Builder, id: &'a str) -> spirv::Word {
|
||||
*self.0.entry(id).or_insert_with(|| b.id())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
|
||||
let mut builder = dr::Builder::new();
|
||||
let mut ids = IdWordMap::new();
|
||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||
builder.set_version(1, 0);
|
||||
emit_capabilities(&mut builder);
|
||||
|
@ -74,9 +79,9 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, TranslationError> {
|
|||
emit_memory_model(&mut builder);
|
||||
let mut map = TypeWordMap::new(&mut builder);
|
||||
for f in ast.functions {
|
||||
emit_function(&mut builder, &mut map, &f);
|
||||
emit_function(&mut builder, &mut map, &mut ids, &f)?;
|
||||
}
|
||||
Ok(vec!())
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
fn emit_capabilities(builder: &mut dr::Builder) {
|
||||
|
@ -87,21 +92,46 @@ fn emit_capabilities(builder: &mut dr::Builder) {
|
|||
builder.capability(spirv::Capability::Int8);
|
||||
}
|
||||
|
||||
fn emit_extensions(_: &mut dr::Builder) {
|
||||
|
||||
}
|
||||
fn emit_extensions(_: &mut dr::Builder) {}
|
||||
|
||||
fn emit_extended_instruction_sets(builder: &mut dr::Builder) {
|
||||
builder.ext_inst_import("OpenCL.std");
|
||||
}
|
||||
|
||||
fn emit_memory_model(builder: &mut dr::Builder) {
|
||||
builder.memory_model(spirv::AddressingModel::Physical64, spirv::MemoryModel::OpenCL);
|
||||
builder.memory_model(
|
||||
spirv::AddressingModel::Physical64,
|
||||
spirv::MemoryModel::OpenCL,
|
||||
);
|
||||
}
|
||||
|
||||
fn emit_function(builder: &mut dr::Builder, map: &TypeWordMap, f: &ast::Function) {
|
||||
let func_id = builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, map.fn_void());
|
||||
|
||||
builder.ret();
|
||||
builder.end_function();
|
||||
}
|
||||
fn emit_function<'a>(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
ids: &mut IdWordMap<'a>,
|
||||
f: &ast::Function<'a>,
|
||||
) -> Result<(), rspirv::dr::Error> {
|
||||
let func_id = builder.begin_function(
|
||||
map.void(),
|
||||
None,
|
||||
spirv::FunctionControl::NONE,
|
||||
map.fn_void(),
|
||||
)?;
|
||||
for arg in f.args.iter() {
|
||||
let arg_type = map.get_or_add(builder, SpirvType::Base(arg.a_type));
|
||||
builder.function_parameter(arg_type)?;
|
||||
}
|
||||
for s in f.body.iter() {
|
||||
match s {
|
||||
ast::Statement::Label(name) => {
|
||||
let id = ids.get_or_add(builder, name);
|
||||
builder.begin_block(Some(id))?;
|
||||
}
|
||||
ast::Statement::Variable(var) => panic!(),
|
||||
ast::Statement::Instruction(i) => panic!(),
|
||||
}
|
||||
}
|
||||
builder.ret()?;
|
||||
builder.end_function()?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue