Simplify error handling during ast construction

This commit is contained in:
Andrzej Janik 2020-04-13 01:13:45 +02:00
parent bbe993392b
commit 6f4530fe83
4 changed files with 191 additions and 195 deletions

View file

@ -1,4 +1,5 @@
use std::convert::From;
use std::convert::Into;
use std::error::Error;
use std::mem;
use std::num::ParseIntError;
@ -6,110 +7,42 @@ use std::num::ParseIntError;
quick_error! {
#[derive(Debug)]
pub enum PtxError {
Parse (err: ParseIntError) {
ParseInt (err: ParseIntError) {
from()
display("{}", err)
cause(err)
from()
}
}
}
pub struct WithErrors<T, E> {
pub value: T,
pub errors: Vec<E>,
pub trait UnwrapWithVec<E, To> {
fn unwrap_with(self, errs: &mut Vec<E>) -> To;
}
impl<T, E> WithErrors<T, E> {
pub fn new(t: T) -> Self {
WithErrors {
value: t,
errors: Vec::new(),
}
}
pub fn map<F: FnOnce(T) -> U, U>(self, f: F) -> WithErrors<U, E> {
WithErrors {
value: f(self.value),
errors: self.errors,
}
}
pub fn map2<X, Y, F: FnOnce(X, Y) -> T>(
x: WithErrors<X, E>,
y: WithErrors<Y, E>,
f: F,
) -> Self {
let mut errors = x.errors;
let mut errors_other = y.errors;
if errors.len() < errors_other.len() {
mem::swap(&mut errors, &mut errors_other);
}
errors.extend(errors_other);
WithErrors {
value: f(x.value, y.value),
errors: errors,
}
impl<R: Default, EFrom: std::convert::Into<EInto>, EInto> UnwrapWithVec<EInto, R>
for Result<R, EFrom>
{
fn unwrap_with(self, errs: &mut Vec<EInto>) -> R {
self.unwrap_or_else(|e| {
errs.push(e.into());
R::default()
})
}
}
impl<T:Default, E: Error> WithErrors<T, E> {
pub fn from_results<X: Default, Y: Default, F: FnOnce(X, Y) -> T>(
x: Result<X, E>,
y: Result<Y, E>,
f: F,
) -> Self {
match (x, y) {
(Ok(x), Ok(y)) => WithErrors {
value: f(x, y),
errors: Vec::new(),
},
(Err(e), Ok(y)) => WithErrors {
value: f(X::default(), y),
errors: vec![e],
},
(Ok(x), Err(e)) => WithErrors {
value: f(x, Y::default()),
errors: vec![e],
},
(Err(e1), Err(e2)) => WithErrors {
value: T::default(),
errors: vec![e1, e2],
},
}
}
}
impl<T, E: Error> WithErrors<Vec<T>, E> {
pub fn from_vec(v: Vec<WithErrors<T, E>>) -> Self {
let mut values = Vec::with_capacity(v.len());
let mut errors = Vec::new();
for we in v.into_iter() {
values.push(we.value);
errors.extend(we.errors);
}
WithErrors {
value: values,
errors: errors,
}
}
}
pub trait WithErrorsExt<From, To, E> {
fn with_errors<F: FnOnce(From) -> To>(self, f: F) -> WithErrors<To, E>;
}
impl<From, To: Default, E> WithErrorsExt<From, To, E> for Result<From, E> {
fn with_errors<F: FnOnce(From) -> To>(self, f: F) -> WithErrors<To, E> {
self.map_or_else(
|e| WithErrors {
value: To::default(),
errors: vec![e],
},
|t| WithErrors {
value: f(t),
errors: Vec::new(),
},
)
impl<
R1: Default,
EFrom1: std::convert::Into<EInto>,
R2: Default,
EFrom2: std::convert::Into<EInto>,
EInto,
> UnwrapWithVec<EInto, (R1, R2)> for (Result<R1, EFrom1>, Result<R2, EFrom2>)
{
fn unwrap_with(self, errs: &mut Vec<EInto>) -> (R1, R2) {
let (x, y) = self;
let r1 = x.unwrap_with(errs);
let r2 = y.unwrap_with(errs);
(r1, r2)
}
}
@ -132,6 +65,13 @@ pub struct Argument<'a> {
pub length: u32,
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum Type {
Scalar(ScalarType),
ExtendedScalar(ExtendedScalarType),
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum ScalarType {
B8,
B16,
@ -150,6 +90,12 @@ pub enum ScalarType {
F64,
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum ExtendedScalarType {
F16x2,
Pred,
}
impl Default for ScalarType {
fn default() -> Self {
ScalarType::B8
@ -158,11 +104,25 @@ impl Default for ScalarType {
pub enum Statement<'a> {
Label(&'a str),
Variable(Variable),
Variable(Variable<'a>),
Instruction(Instruction),
}
pub struct Variable {}
pub struct Variable<'a> {
pub space: StateSpace,
pub v_type: Type,
pub name: &'a str,
pub count: Option<u32>,
}
pub enum StateSpace {
Reg,
Sreg,
Const,
Global,
Local,
Shared,
}
pub enum Instruction {
Ld,

View file

@ -1,8 +1,8 @@
use crate::ast;
use crate::ast::{WithErrors, WithErrorsExt};
use crate::ast::UnwrapWithVec;
use crate::without_none;
grammar;
grammar<'a>(errors: &mut Vec<ast::PtxError>);
match {
r"\s+" => { },
@ -16,23 +16,18 @@ match {
_
}
pub Module: WithErrors<ast::Module<'input>, ast::PtxError> = {
pub Module: ast::Module<'input> = {
<v:Version> Target <f:Directive*> => {
let funcs = WithErrors::from_vec(without_none(f));
WithErrors::map2(v, funcs,
|v, funcs| ast::Module { version: v, functions: funcs }
)
ast::Module { version: v, functions: without_none(f) }
}
};
Version: WithErrors<(u8, u8), ast::PtxError> = {
Version: (u8, u8) = {
".version" <v:VersionNumber> => {
let dot = v.find('.').unwrap();
let major = v[..dot].parse::<u8>().map_err(Into::into);
let minor = v[dot+1..].parse::<u8>().map_err(Into::into);
WithErrors::from_results(major, minor,
|major, minor| (major, minor)
)
let major = v[..dot].parse::<u8>();
let minor = v[dot+1..].parse::<u8>();
(major,minor).unwrap_with(errors)
}
}
@ -49,7 +44,7 @@ TargetSpecifier = {
"map_f64_to_f32"
};
Directive: Option<WithErrors<ast::Function<'input>, ast::PtxError>> = {
Directive: Option<ast::Function<'input>> = {
AddressSize => None,
<f:Function> => Some(f),
File => None,
@ -60,11 +55,12 @@ AddressSize = {
".address_size" Num
};
Function: WithErrors<ast::Function<'input>, ast::PtxError> = {
LinkingDirective* <k:IsKernel> <n:ID> "(" <args:Comma<FunctionInput>> ")" <b:FunctionBody> => {
WithErrors::from_vec(args)
.map(|args| ast::Function{kernel: k, name: n, args: args, body: b})
}
Function: ast::Function<'input> = {
LinkingDirective*
<kernel:IsKernel>
<name:ID>
"(" <args:Comma<FunctionInput>> ")"
<body:FunctionBody> => ast::Function{<>}
};
LinkingDirective = {
@ -79,15 +75,14 @@ IsKernel: bool = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
FunctionInput: WithErrors<ast::Argument<'input>, ast::PtxError> = {
FunctionInput: ast::Argument<'input> = {
".param" <_type:ScalarType> <name:ID> => {
WithErrors::new(ast::Argument {a_type: _type, name: name, length: 1 })
ast::Argument {a_type: _type, name: name, length: 1 }
},
".param" <a_type:ScalarType> <name:ID> "[" <length:Num> "]" => {
let length = length.parse::<u32>().map_err(Into::into);
length.with_errors(
|l| ast::Argument { a_type: a_type, name: name, length: l }
)
let length = length.parse::<u32>();
let length = length.unwrap_with(errors);
ast::Argument { a_type: a_type, name: name, length: length }
}
};
@ -95,13 +90,19 @@ FunctionBody: Vec<ast::Statement<'input>> = {
"{" <s:Statement*> "}" => { without_none(s) }
};
StateSpaceSpecifier = {
".reg",
".sreg",
".const",
".global",
".local",
".shared"
StateSpaceSpecifier: ast::StateSpace = {
".reg" => ast::StateSpace::Reg,
".sreg" => ast::StateSpace::Sreg,
".const" => ast::StateSpace::Const,
".global" => ast::StateSpace::Global,
".local" => ast::StateSpace::Local,
".shared" => ast::StateSpace::Shared
};
Type: ast::Type = {
<t:ScalarType> => ast::Type::Scalar(t),
<t:ExtendedScalarType> => ast::Type::ExtendedScalar(t),
};
ScalarType: ast::ScalarType = {
@ -122,12 +123,9 @@ ScalarType: ast::ScalarType = {
".f64" => ast::ScalarType::F64,
};
Type = {
BaseType,
".pred",
".f16",
".f16x2",
ExtendedScalarType: ast::ExtendedScalarType = {
".f16x2" => ast::ExtendedScalarType::F16x2,
".pred" => ast::ExtendedScalarType::Pred,
};
BaseType = {
@ -157,13 +155,24 @@ Label: &'input str = {
<id:ID> ":" => id
};
Variable: ast::Variable = {
StateSpaceSpecifier Type VariableName => ast::Variable {}
Variable: ast::Variable<'input> = {
<s:StateSpaceSpecifier> <t:Type> <v:VariableName> => {
let (name, count) = v;
ast::Variable { space: s, v_type: t, name: name, count: count }
}
};
VariableName = {
ID,
ParametrizedID
VariableName: (&'input str, Option<u32>) = {
<id:ID> => (id, None),
<id:ParametrizedID> => {
let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap();
let count = id[left_angle+1..id.len()-1].parse::<u32>();
let count = match count {
Ok(c) => Some(c),
Err(e) => { errors.push(e.into()); None },
};
(&id[0..left_angle], count)
}
};
Instruction = {

View file

@ -1,12 +1,9 @@
use super::ptx;
fn parse_and_assert(s: &str) {
assert!(
ptx::ModuleParser::new()
.parse(s)
.unwrap()
.errors
.len() == 0);
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
assert!(errors.len() == 0);
}
#[test]
@ -18,4 +15,4 @@ fn empty() {
fn vector_add() {
let vector_add = include_str!("vectorAdd_kernel64.ptx");
parse_and_assert(vector_add);
}
}

View file

@ -1,35 +1,17 @@
use crate::ast;
use rspirv::dr;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
pub struct TranslationError {
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType {
Base(BaseType),
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum BaseType {
Int8,
Int16,
Int32,
Int64,
Uint8,
Uint16,
Uint32,
Uint64,
Float16,
Float32,
Float64,
Base(ast::ScalarType),
}
struct TypeWordMap {
void: spirv::Word,
fn_void: spirv::Word,
complex: HashMap<SpirvType, spirv::Word>
complex: HashMap<SpirvType, spirv::Word>,
}
impl TypeWordMap {
@ -38,34 +20,57 @@ impl TypeWordMap {
TypeWordMap {
void: void,
fn_void: b.type_function(void, vec![]),
complex: HashMap::<SpirvType, spirv::Word>::new()
complex: HashMap::<SpirvType, spirv::Word>::new(),
}
}
fn void(&self) -> spirv::Word { self.void }
fn fn_void(&self) -> spirv::Word { self.fn_void }
fn void(&self) -> spirv::Word {
self.void
}
fn fn_void(&self) -> spirv::Word {
self.fn_void
}
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
*self.complex.entry(t).or_insert_with(|| {
match t {
SpirvType::Base(BaseType::Int8) => b.type_int(8, 1),
SpirvType::Base(BaseType::Int16) => b.type_int(16, 1),
SpirvType::Base(BaseType::Int32) => b.type_int(32, 1),
SpirvType::Base(BaseType::Int64) => b.type_int(64, 1),
SpirvType::Base(BaseType::Uint8) => b.type_int(8, 0),
SpirvType::Base(BaseType::Uint16) => b.type_int(16, 0),
SpirvType::Base(BaseType::Uint32) => b.type_int(32, 0),
SpirvType::Base(BaseType::Uint64) => b.type_int(64, 0),
SpirvType::Base(BaseType::Float16) => b.type_float(16),
SpirvType::Base(BaseType::Float32) => b.type_float(32),
SpirvType::Base(BaseType::Float64) => b.type_float(64),
*self.complex.entry(t).or_insert_with(|| match t {
SpirvType::Base(ast::ScalarType::B8) | SpirvType::Base(ast::ScalarType::U8) => {
b.type_int(8, 0)
}
SpirvType::Base(ast::ScalarType::B16) | SpirvType::Base(ast::ScalarType::U16) => {
b.type_int(16, 0)
}
SpirvType::Base(ast::ScalarType::B32) | SpirvType::Base(ast::ScalarType::U32) => {
b.type_int(32, 0)
}
SpirvType::Base(ast::ScalarType::B64) | SpirvType::Base(ast::ScalarType::U64) => {
b.type_int(64, 0)
}
SpirvType::Base(ast::ScalarType::S8) => b.type_int(8, 1),
SpirvType::Base(ast::ScalarType::S16) => b.type_int(16, 1),
SpirvType::Base(ast::ScalarType::S32) => b.type_int(32, 1),
SpirvType::Base(ast::ScalarType::S64) => b.type_int(64, 1),
SpirvType::Base(ast::ScalarType::F16) => b.type_float(16),
SpirvType::Base(ast::ScalarType::F32) => b.type_float(32),
SpirvType::Base(ast::ScalarType::F64) => b.type_float(64),
})
}
}
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, TranslationError> {
struct IdWordMap<'a>(HashMap<&'a str, spirv::Word>);
impl<'a> IdWordMap<'a> {
fn new() -> Self { IdWordMap(HashMap::new()) }
}
impl<'a> IdWordMap<'a> {
fn get_or_add(&mut self, b: &mut dr::Builder, id: &'a str) -> spirv::Word {
*self.0.entry(id).or_insert_with(|| b.id())
}
}
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
let mut builder = dr::Builder::new();
let mut ids = IdWordMap::new();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 0);
emit_capabilities(&mut builder);
@ -74,9 +79,9 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, TranslationError> {
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
for f in ast.functions {
emit_function(&mut builder, &mut map, &f);
emit_function(&mut builder, &mut map, &mut ids, &f)?;
}
Ok(vec!())
Ok(vec![])
}
fn emit_capabilities(builder: &mut dr::Builder) {
@ -87,21 +92,46 @@ fn emit_capabilities(builder: &mut dr::Builder) {
builder.capability(spirv::Capability::Int8);
}
fn emit_extensions(_: &mut dr::Builder) {
}
fn emit_extensions(_: &mut dr::Builder) {}
fn emit_extended_instruction_sets(builder: &mut dr::Builder) {
builder.ext_inst_import("OpenCL.std");
}
fn emit_memory_model(builder: &mut dr::Builder) {
builder.memory_model(spirv::AddressingModel::Physical64, spirv::MemoryModel::OpenCL);
builder.memory_model(
spirv::AddressingModel::Physical64,
spirv::MemoryModel::OpenCL,
);
}
fn emit_function(builder: &mut dr::Builder, map: &TypeWordMap, f: &ast::Function) {
let func_id = builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, map.fn_void());
builder.ret();
builder.end_function();
}
fn emit_function<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
ids: &mut IdWordMap<'a>,
f: &ast::Function<'a>,
) -> Result<(), rspirv::dr::Error> {
let func_id = builder.begin_function(
map.void(),
None,
spirv::FunctionControl::NONE,
map.fn_void(),
)?;
for arg in f.args.iter() {
let arg_type = map.get_or_add(builder, SpirvType::Base(arg.a_type));
builder.function_parameter(arg_type)?;
}
for s in f.body.iter() {
match s {
ast::Statement::Label(name) => {
let id = ids.get_or_add(builder, name);
builder.begin_block(Some(id))?;
}
ast::Statement::Variable(var) => panic!(),
ast::Statement::Instruction(i) => panic!(),
}
}
builder.ret()?;
builder.end_function()?;
Ok(())
}