From 6f4530fe839012daca1936102f54d5c5f8184c7f Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 13 Apr 2020 01:13:45 +0200 Subject: [PATCH] Simplify error handling during ast construction --- ptx/src/ast.rs | 150 ++++++++++++++++--------------------------- ptx/src/ptx.lalrpop | 95 ++++++++++++++------------- ptx/src/test/mod.rs | 11 ++-- ptx/src/translate.rs | 130 ++++++++++++++++++++++--------------- 4 files changed, 191 insertions(+), 195 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 3bb142d..70550b2 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,4 +1,5 @@ use std::convert::From; +use std::convert::Into; use std::error::Error; use std::mem; use std::num::ParseIntError; @@ -6,110 +7,42 @@ use std::num::ParseIntError; quick_error! { #[derive(Debug)] pub enum PtxError { - Parse (err: ParseIntError) { + ParseInt (err: ParseIntError) { + from() display("{}", err) cause(err) - from() } } } -pub struct WithErrors { - pub value: T, - pub errors: Vec, +pub trait UnwrapWithVec { + fn unwrap_with(self, errs: &mut Vec) -> To; } -impl WithErrors { - pub fn new(t: T) -> Self { - WithErrors { - value: t, - errors: Vec::new(), - } - } - - pub fn map U, U>(self, f: F) -> WithErrors { - WithErrors { - value: f(self.value), - errors: self.errors, - } - } - - pub fn map2 T>( - x: WithErrors, - y: WithErrors, - f: F, - ) -> Self { - let mut errors = x.errors; - let mut errors_other = y.errors; - if errors.len() < errors_other.len() { - mem::swap(&mut errors, &mut errors_other); - } - errors.extend(errors_other); - WithErrors { - value: f(x.value, y.value), - errors: errors, - } +impl, EInto> UnwrapWithVec + for Result +{ + fn unwrap_with(self, errs: &mut Vec) -> R { + self.unwrap_or_else(|e| { + errs.push(e.into()); + R::default() + }) } } -impl WithErrors { - pub fn from_results T>( - x: Result, - y: Result, - f: F, - ) -> Self { - match (x, y) { - (Ok(x), Ok(y)) => WithErrors { - value: f(x, y), - errors: Vec::new(), - }, - (Err(e), Ok(y)) => WithErrors { - value: f(X::default(), y), - errors: vec![e], - }, - (Ok(x), Err(e)) => WithErrors { - value: f(x, Y::default()), - errors: vec![e], - }, - (Err(e1), Err(e2)) => WithErrors { - value: T::default(), - errors: vec![e1, e2], - }, - } - } -} - -impl WithErrors, E> { - pub fn from_vec(v: Vec>) -> Self { - let mut values = Vec::with_capacity(v.len()); - let mut errors = Vec::new(); - for we in v.into_iter() { - values.push(we.value); - errors.extend(we.errors); - } - WithErrors { - value: values, - errors: errors, - } - } -} - -pub trait WithErrorsExt { - fn with_errors To>(self, f: F) -> WithErrors; -} - -impl WithErrorsExt for Result { - fn with_errors To>(self, f: F) -> WithErrors { - self.map_or_else( - |e| WithErrors { - value: To::default(), - errors: vec![e], - }, - |t| WithErrors { - value: f(t), - errors: Vec::new(), - }, - ) +impl< + R1: Default, + EFrom1: std::convert::Into, + R2: Default, + EFrom2: std::convert::Into, + EInto, + > UnwrapWithVec for (Result, Result) +{ + fn unwrap_with(self, errs: &mut Vec) -> (R1, R2) { + let (x, y) = self; + let r1 = x.unwrap_with(errs); + let r2 = y.unwrap_with(errs); + (r1, r2) } } @@ -132,6 +65,13 @@ pub struct Argument<'a> { pub length: u32, } +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +pub enum Type { + Scalar(ScalarType), + ExtendedScalar(ExtendedScalarType), +} + +#[derive(PartialEq, Eq, Hash, Clone, Copy)] pub enum ScalarType { B8, B16, @@ -150,6 +90,12 @@ pub enum ScalarType { F64, } +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +pub enum ExtendedScalarType { + F16x2, + Pred, +} + impl Default for ScalarType { fn default() -> Self { ScalarType::B8 @@ -158,11 +104,25 @@ impl Default for ScalarType { pub enum Statement<'a> { Label(&'a str), - Variable(Variable), + Variable(Variable<'a>), Instruction(Instruction), } -pub struct Variable {} +pub struct Variable<'a> { + pub space: StateSpace, + pub v_type: Type, + pub name: &'a str, + pub count: Option, +} + +pub enum StateSpace { + Reg, + Sreg, + Const, + Global, + Local, + Shared, +} pub enum Instruction { Ld, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index b646d68..3ff5d9c 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -1,8 +1,8 @@ use crate::ast; -use crate::ast::{WithErrors, WithErrorsExt}; +use crate::ast::UnwrapWithVec; use crate::without_none; -grammar; +grammar<'a>(errors: &mut Vec); match { r"\s+" => { }, @@ -16,23 +16,18 @@ match { _ } -pub Module: WithErrors, ast::PtxError> = { +pub Module: ast::Module<'input> = { Target => { - let funcs = WithErrors::from_vec(without_none(f)); - WithErrors::map2(v, funcs, - |v, funcs| ast::Module { version: v, functions: funcs } - ) + ast::Module { version: v, functions: without_none(f) } } }; -Version: WithErrors<(u8, u8), ast::PtxError> = { +Version: (u8, u8) = { ".version" => { let dot = v.find('.').unwrap(); - let major = v[..dot].parse::().map_err(Into::into); - let minor = v[dot+1..].parse::().map_err(Into::into); - WithErrors::from_results(major, minor, - |major, minor| (major, minor) - ) + let major = v[..dot].parse::(); + let minor = v[dot+1..].parse::(); + (major,minor).unwrap_with(errors) } } @@ -49,7 +44,7 @@ TargetSpecifier = { "map_f64_to_f32" }; -Directive: Option, ast::PtxError>> = { +Directive: Option> = { AddressSize => None, => Some(f), File => None, @@ -60,11 +55,12 @@ AddressSize = { ".address_size" Num }; -Function: WithErrors, ast::PtxError> = { - LinkingDirective* "(" > ")" => { - WithErrors::from_vec(args) - .map(|args| ast::Function{kernel: k, name: n, args: args, body: b}) - } +Function: ast::Function<'input> = { + LinkingDirective* + + + "(" > ")" + => ast::Function{<>} }; LinkingDirective = { @@ -79,15 +75,14 @@ IsKernel: bool = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -FunctionInput: WithErrors, ast::PtxError> = { +FunctionInput: ast::Argument<'input> = { ".param" <_type:ScalarType> => { - WithErrors::new(ast::Argument {a_type: _type, name: name, length: 1 }) + ast::Argument {a_type: _type, name: name, length: 1 } }, ".param" "[" "]" => { - let length = length.parse::().map_err(Into::into); - length.with_errors( - |l| ast::Argument { a_type: a_type, name: name, length: l } - ) + let length = length.parse::(); + let length = length.unwrap_with(errors); + ast::Argument { a_type: a_type, name: name, length: length } } }; @@ -95,13 +90,19 @@ FunctionBody: Vec> = { "{" "}" => { without_none(s) } }; -StateSpaceSpecifier = { - ".reg", - ".sreg", - ".const", - ".global", - ".local", - ".shared" +StateSpaceSpecifier: ast::StateSpace = { + ".reg" => ast::StateSpace::Reg, + ".sreg" => ast::StateSpace::Sreg, + ".const" => ast::StateSpace::Const, + ".global" => ast::StateSpace::Global, + ".local" => ast::StateSpace::Local, + ".shared" => ast::StateSpace::Shared +}; + + +Type: ast::Type = { + => ast::Type::Scalar(t), + => ast::Type::ExtendedScalar(t), }; ScalarType: ast::ScalarType = { @@ -122,12 +123,9 @@ ScalarType: ast::ScalarType = { ".f64" => ast::ScalarType::F64, }; - -Type = { - BaseType, - ".pred", - ".f16", - ".f16x2", +ExtendedScalarType: ast::ExtendedScalarType = { + ".f16x2" => ast::ExtendedScalarType::F16x2, + ".pred" => ast::ExtendedScalarType::Pred, }; BaseType = { @@ -157,13 +155,24 @@ Label: &'input str = { ":" => id }; -Variable: ast::Variable = { - StateSpaceSpecifier Type VariableName => ast::Variable {} +Variable: ast::Variable<'input> = { + => { + let (name, count) = v; + ast::Variable { space: s, v_type: t, name: name, count: count } + } }; -VariableName = { - ID, - ParametrizedID +VariableName: (&'input str, Option) = { + => (id, None), + => { + let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap(); + let count = id[left_angle+1..id.len()-1].parse::(); + let count = match count { + Ok(c) => Some(c), + Err(e) => { errors.push(e.into()); None }, + }; + (&id[0..left_angle], count) + } }; Instruction = { diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index 9a07271..1de55bb 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -1,12 +1,9 @@ use super::ptx; fn parse_and_assert(s: &str) { - assert!( - ptx::ModuleParser::new() - .parse(s) - .unwrap() - .errors - .len() == 0); + let mut errors = Vec::new(); + let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); + assert!(errors.len() == 0); } #[test] @@ -18,4 +15,4 @@ fn empty() { fn vector_add() { let vector_add = include_str!("vectorAdd_kernel64.ptx"); parse_and_assert(vector_add); -} \ No newline at end of file +} diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 6039c55..f3abaf0 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,35 +1,17 @@ use crate::ast; use rspirv::dr; +use std::collections::hash_map::Entry; use std::collections::HashMap; -pub struct TranslationError { - -} - #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { - Base(BaseType), -} - -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -enum BaseType { - Int8, - Int16, - Int32, - Int64, - Uint8, - Uint16, - Uint32, - Uint64, - Float16, - Float32, - Float64, + Base(ast::ScalarType), } struct TypeWordMap { void: spirv::Word, fn_void: spirv::Word, - complex: HashMap + complex: HashMap, } impl TypeWordMap { @@ -38,34 +20,57 @@ impl TypeWordMap { TypeWordMap { void: void, fn_void: b.type_function(void, vec![]), - complex: HashMap::::new() + complex: HashMap::::new(), } } - fn void(&self) -> spirv::Word { self.void } - fn fn_void(&self) -> spirv::Word { self.fn_void } + fn void(&self) -> spirv::Word { + self.void + } + fn fn_void(&self) -> spirv::Word { + self.fn_void + } fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { - *self.complex.entry(t).or_insert_with(|| { - match t { - SpirvType::Base(BaseType::Int8) => b.type_int(8, 1), - SpirvType::Base(BaseType::Int16) => b.type_int(16, 1), - SpirvType::Base(BaseType::Int32) => b.type_int(32, 1), - SpirvType::Base(BaseType::Int64) => b.type_int(64, 1), - SpirvType::Base(BaseType::Uint8) => b.type_int(8, 0), - SpirvType::Base(BaseType::Uint16) => b.type_int(16, 0), - SpirvType::Base(BaseType::Uint32) => b.type_int(32, 0), - SpirvType::Base(BaseType::Uint64) => b.type_int(64, 0), - SpirvType::Base(BaseType::Float16) => b.type_float(16), - SpirvType::Base(BaseType::Float32) => b.type_float(32), - SpirvType::Base(BaseType::Float64) => b.type_float(64), + *self.complex.entry(t).or_insert_with(|| match t { + SpirvType::Base(ast::ScalarType::B8) | SpirvType::Base(ast::ScalarType::U8) => { + b.type_int(8, 0) } + SpirvType::Base(ast::ScalarType::B16) | SpirvType::Base(ast::ScalarType::U16) => { + b.type_int(16, 0) + } + SpirvType::Base(ast::ScalarType::B32) | SpirvType::Base(ast::ScalarType::U32) => { + b.type_int(32, 0) + } + SpirvType::Base(ast::ScalarType::B64) | SpirvType::Base(ast::ScalarType::U64) => { + b.type_int(64, 0) + } + SpirvType::Base(ast::ScalarType::S8) => b.type_int(8, 1), + SpirvType::Base(ast::ScalarType::S16) => b.type_int(16, 1), + SpirvType::Base(ast::ScalarType::S32) => b.type_int(32, 1), + SpirvType::Base(ast::ScalarType::S64) => b.type_int(64, 1), + SpirvType::Base(ast::ScalarType::F16) => b.type_float(16), + SpirvType::Base(ast::ScalarType::F32) => b.type_float(32), + SpirvType::Base(ast::ScalarType::F64) => b.type_float(64), }) } } -pub fn to_spirv(ast: ast::Module) -> Result, TranslationError> { +struct IdWordMap<'a>(HashMap<&'a str, spirv::Word>); + +impl<'a> IdWordMap<'a> { + fn new() -> Self { IdWordMap(HashMap::new()) } +} + +impl<'a> IdWordMap<'a> { + fn get_or_add(&mut self, b: &mut dr::Builder, id: &'a str) -> spirv::Word { + *self.0.entry(id).or_insert_with(|| b.id()) + } +} + +pub fn to_spirv(ast: ast::Module) -> Result, rspirv::dr::Error> { let mut builder = dr::Builder::new(); + let mut ids = IdWordMap::new(); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module builder.set_version(1, 0); emit_capabilities(&mut builder); @@ -74,9 +79,9 @@ pub fn to_spirv(ast: ast::Module) -> Result, TranslationError> { emit_memory_model(&mut builder); let mut map = TypeWordMap::new(&mut builder); for f in ast.functions { - emit_function(&mut builder, &mut map, &f); + emit_function(&mut builder, &mut map, &mut ids, &f)?; } - Ok(vec!()) + Ok(vec![]) } fn emit_capabilities(builder: &mut dr::Builder) { @@ -87,21 +92,46 @@ fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::Int8); } -fn emit_extensions(_: &mut dr::Builder) { - -} +fn emit_extensions(_: &mut dr::Builder) {} fn emit_extended_instruction_sets(builder: &mut dr::Builder) { builder.ext_inst_import("OpenCL.std"); } fn emit_memory_model(builder: &mut dr::Builder) { - builder.memory_model(spirv::AddressingModel::Physical64, spirv::MemoryModel::OpenCL); + builder.memory_model( + spirv::AddressingModel::Physical64, + spirv::MemoryModel::OpenCL, + ); } -fn emit_function(builder: &mut dr::Builder, map: &TypeWordMap, f: &ast::Function) { - let func_id = builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, map.fn_void()); - - builder.ret(); - builder.end_function(); -} \ No newline at end of file +fn emit_function<'a>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + ids: &mut IdWordMap<'a>, + f: &ast::Function<'a>, +) -> Result<(), rspirv::dr::Error> { + let func_id = builder.begin_function( + map.void(), + None, + spirv::FunctionControl::NONE, + map.fn_void(), + )?; + for arg in f.args.iter() { + let arg_type = map.get_or_add(builder, SpirvType::Base(arg.a_type)); + builder.function_parameter(arg_type)?; + } + for s in f.body.iter() { + match s { + ast::Statement::Label(name) => { + let id = ids.get_or_add(builder, name); + builder.begin_block(Some(id))?; + } + ast::Statement::Variable(var) => panic!(), + ast::Statement::Instruction(i) => panic!(), + } + } + builder.ret()?; + builder.end_function()?; + Ok(()) +}