From bcb749cdd913cb32c988f786982772e9b9b33bcb Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 18 Sep 2020 18:08:40 +0200 Subject: [PATCH] Continue working on a better addressable support --- ptx/src/ast.rs | 38 +- ptx/src/lib.rs | 1 + ptx/src/ptx.lalrpop | 2 +- ptx/src/test/mod.rs | 7 +- ptx/src/translate.rs | 1437 ++++++++++++++++++++++++------------------ 5 files changed, 862 insertions(+), 623 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 7ac9d18..3a5022d 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -164,8 +164,8 @@ pub enum MethodDecl<'a, P: ArgParams> { Kernel(&'a str, Vec>), } -pub type FnArgument = Variable; -pub type KernelArgument = Variable; +pub type FnArgument

= Variable; +pub type KernelArgument

= Variable; pub struct Function<'a, P: ArgParams, S> { pub func_directive: MethodDecl<'a, P>, @@ -316,7 +316,7 @@ pub struct PredAt { pub enum Instruction { Ld(LdData, Arg2

), - Mov(MovType, Arg2

), + Mov(MovType, Arg2Mov

), MovVector(MovVectorDetails, Arg2Vec

), Mul(MulDetails, Arg3

), Add(AddDetails, Arg3

), @@ -354,7 +354,7 @@ pub struct CallInst { pub trait ArgParams { type ID; type Operand; - type MemoryOperand; + type MovOperand; type CallOperand; type VecOperand; } @@ -366,7 +366,7 @@ pub struct ParsedArgParams<'a> { impl<'a> ArgParams for ParsedArgParams<'a> { type ID = &'a str; type Operand = Operand<&'a str>; - type MemoryOperand = Operand<&'a str>; + type MovOperand = MovOperand<&'a str>; type CallOperand = CallOperand<&'a str>; type VecOperand = (&'a str, u8); } @@ -380,13 +380,27 @@ pub struct Arg2 { pub src: P::Operand, } -pub struct Arg2Ld { +pub struct Arg2Mov { pub dst: P::ID, - pub src: P::MemoryOperand, + pub src: P::MovOperand, +} + +impl<'input> From>> for Arg2Mov> { + fn from(a: Arg2>) -> Arg2Mov> { + let new_src = match a.src { + Operand::Reg(r) => MovOperand::Reg(r), + Operand::RegOffset(r, imm) => MovOperand::RegOffset(r, imm), + Operand::Imm(x) => MovOperand::Imm(x), + }; + Arg2Mov { + dst: a.dst, + src: new_src, + } + } } pub struct Arg2St { - pub src1: P::MemoryOperand, + pub src1: P::Operand, pub src2: P::Operand, } @@ -419,6 +433,14 @@ pub struct Arg5 { pub src3: P::Operand, } +#[derive(Copy, Clone)] +pub enum MovOperand { + Reg(ID), + Address(ID), + RegOffset(ID, i32), + AddressOffset(ID, i32), + Imm(u32), +} #[derive(Copy, Clone)] pub enum Operand { Reg(ID), diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 5e12579..8ae1c6d 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -31,6 +31,7 @@ pub use crate::ptx::ModuleParser; pub use lalrpop_util::lexer::Token; pub use lalrpop_util::ParseError; pub use rspirv::dr::Error as SpirvError; +pub use translate::TranslateError as TranslateError; pub use translate::to_spirv; pub(crate) fn without_none(x: Vec>) -> Vec { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 44f29a5..46d0b48 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -496,7 +496,7 @@ LdCacheOperator: ast::LdCacheOperator = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov InstMov: ast::Instruction> = { "mov" => { - ast::Instruction::Mov(t, a) + ast::Instruction::Mov(t, a.into()) }, "mov" => { ast::Instruction::MovVector(ast::MovVectorDetails{typ: t, length: 0}, a) diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index f40fc02..d251884 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -1,4 +1,5 @@ use super::ptx; +use super::TranslateError; mod spirv_run; @@ -8,7 +9,7 @@ fn parse_and_assert(s: &str) { assert!(errors.len() == 0); } -fn compile_and_assert(s: &str) -> Result<(), rspirv::dr::Error> { +fn compile_and_assert(s: &str) -> Result<(), TranslateError> { let mut errors = Vec::new(); let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); crate::to_spirv(ast)?; @@ -28,14 +29,14 @@ fn operands_ptx() { #[test] #[allow(non_snake_case)] -fn vectorAdd_kernel64_ptx() -> Result<(), rspirv::dr::Error> { +fn vectorAdd_kernel64_ptx() -> Result<(), TranslateError> { let vector_add = include_str!("vectorAdd_kernel64.ptx"); compile_and_assert(vector_add) } #[test] #[allow(non_snake_case)] -fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), rspirv::dr::Error> { +fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), TranslateError> { let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx"); compile_and_assert(vector_add) } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 45372f1..0617cbe 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -5,6 +5,22 @@ use std::{borrow::Cow, iter, mem}; use rspirv::binary::Assemble; +quick_error! { + #[derive(Debug)] + pub enum TranslateError { + UnknownSymbol {} + UntypedSymbol {} + MismatchedType {} + Spirv (err: rspirv::dr::Error) { + from() + display("{}", err) + cause(err) + } + Unreachable {} + Todo {} + } +} + #[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), @@ -184,13 +200,13 @@ impl TypeWordMap { } } -pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result { +pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result { let mut id_defs = GlobalStringIdResolver::new(1); let ssa_functions = ast .functions .into_iter() .map(|f| to_ssa_function(&mut id_defs, f)) - .collect::>(); + .collect::, _>>()?; let mut builder = dr::Builder::new(); builder.reserve_ids(id_defs.current_id()); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module @@ -217,11 +233,11 @@ fn emit_function_header<'a>( map: &mut TypeWordMap, global: &GlobalStringIdResolver<'a>, func_directive: ast::MethodDecl, -) -> Result<(), dr::Error> { +) -> Result<(), TranslateError> { let (ret_type, func_type) = get_function_type(builder, map, &func_directive); let fn_id = match func_directive { ast::MethodDecl::Kernel(name, _) => { - let fn_id = global.get_id(name); + let fn_id = global.get_id(name)?; builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]); fn_id } @@ -246,7 +262,7 @@ fn emit_function_header<'a>( Ok(()) } -pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result, dr::Error> { +pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result, TranslateError> { let module = to_spirv_module(ast)?; Ok(module.assemble()) } @@ -276,7 +292,7 @@ fn emit_memory_model(builder: &mut dr::Builder) { fn to_ssa_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, f: ast::ParsedFunction<'a>, -) -> ExpandedFunction<'a> { +) -> Result, TranslateError> { let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive); to_ssa(str_resolver, fn_resolver, fn_decl, f.body) } @@ -316,25 +332,26 @@ fn to_ssa<'input, 'b>( fn_defs: GlobalFnDeclResolver<'input, 'b>, f_args: ast::MethodDecl<'input, ExpandedArgParams>, f_body: Option>>>, -) -> ExpandedFunction<'input> { +) -> Result, TranslateError> { let f_body = match f_body { Some(vec) => vec, None => { - return ExpandedFunction { + return Ok(ExpandedFunction { func_directive: f_args, body: None, - } + }) } }; - let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body); + let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; let mut numeric_id_defs = id_defs.finish(); let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); let unadorned_statements = - add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs); + add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?; + let mut numeric_id_defs = numeric_id_defs.finish(); + let (f_args, ssa_statements) = + insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args)?; todo!() /* - let (f_args, ssa_statements) = - insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs); let expanded_statements = insert_implicit_conversions(expanded_statements, &mut numeric_id_defs); @@ -359,33 +376,104 @@ fn add_types_to_statements( func: Vec, fn_defs: &GlobalFnDeclResolver, id_defs: &NumericIdResolver, -) -> Vec { +) -> Result, TranslateError> { func.into_iter() .map(|s| { match s { Statement::Instruction(ast::Instruction::Call(call)) => { // TODO: error out if lengths don't match - let fn_def = fn_defs.get_fn_decl(call.func); + let fn_def = fn_defs.get_fn_decl(call.func)?; let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals); let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params); - let resolved_call: ResolvedCall = ResolvedCall { + let resolved_call = ResolvedCall { uniform: call.uniform, ret_params, func: call.func, param_list, }; - Statement::Call(resolved_call) + Ok(Statement::Call(resolved_call)) } - Statement::Instruction(ast::Instruction::Ld(d, arg)) => { - todo!() + // Supported ld/st: + // global: only compatible with reg b64/u64/s64 source/dest + // generic: compatible with global/local sources + // param: compiled as mov + // local compiled as mov + // We would like to convert ld/st local/param to movs here, + // but they have different semantics for implicit conversions + // For now, we convert generic ld from local params to ld.local. + // This way, we can rely on further stages of the compilation on + // ld.generic & ld.global having bytes address source + // One complication: immediate address is only allowed in local, + // It is not supported in generic ld + // ld.local foo, [1]; + Statement::Instruction(ast::Instruction::Ld(mut d, arg)) => { + match arg.src.underlying() { + None => return Ok(Statement::Instruction(ast::Instruction::Ld(d, arg))), + Some(u) => { + let (ss, typ) = id_defs.get_typed(*u)?; + match (d.state_space, ss) { + (ast::LdStateSpace::Generic, StateSpace::Local) => { + d.state_space = ast::LdStateSpace::Local; + } + _ => (), + }; + } + }; + + Ok(Statement::Instruction(ast::Instruction::Ld(d, arg))) } - Statement::Instruction(ast::Instruction::MovVector(dets, args)) => { - todo!() + Statement::Instruction(ast::Instruction::St(mut d, arg)) => { + match arg.src1.underlying() { + None => return Ok(Statement::Instruction(ast::Instruction::St(d, arg))), + Some(u) => { + let (ss, typ) = id_defs.get_typed(*u)?; + match (d.state_space, ss) { + (ast::StStateSpace::Generic, StateSpace::Local) => { + d.state_space = ast::StStateSpace::Local; + } + _ => (), + }; + } + }; + Ok(Statement::Instruction(ast::Instruction::St(d, arg))) } - s => todo!(), + Statement::Instruction(ast::Instruction::Mov(d, mut arg)) => { + arg.src = match arg.src { + ast::MovOperand::Reg(id) => { + let (ss, typ) = id_defs.get_typed(id)?; + match ss { + StateSpace::Reg => ast::MovOperand::Reg(id), + StateSpace::Const + | StateSpace::Global + | StateSpace::Local + | StateSpace::Shared + | StateSpace::Param + | StateSpace::ParamReg => ast::MovOperand::Address(id), + } + } + ast::MovOperand::RegOffset(id, imm) => { + let (ss, typ) = id_defs.get_typed(id)?; + match ss { + StateSpace::Reg => ast::MovOperand::RegOffset(id, imm), + StateSpace::Const + | StateSpace::Global + | StateSpace::Local + | StateSpace::Shared + | StateSpace::Param + | StateSpace::ParamReg => ast::MovOperand::AddressOffset(id, imm), + } + } + a @ ast::MovOperand::Imm(_) => a, + ast::MovOperand::Address(_) | ast::MovOperand::AddressOffset(_, _) => { + unreachable!() + } + }; + Ok(Statement::Instruction(ast::Instruction::Mov(d, arg))) + } + s => Ok(s), } }) - .collect() + .collect::, _>>() } fn to_resolved_fn_args( @@ -478,18 +566,21 @@ fn normalize_predicates( fn insert_mem_ssa_statements<'a, 'b>( func: Vec, - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, mut f_args: ast::MethodDecl<'a, ExpandedArgParams>, -) -> ( - ast::MethodDecl<'a, ExpandedArgParams>, - Vec, -) { +) -> Result< + ( + ast::MethodDecl<'a, ExpandedArgParams>, + Vec, + ), + TranslateError, +> { let mut result = Vec::with_capacity(func.len()); let out_param = match &mut f_args { ast::MethodDecl::Kernel(_, in_params) => { for p in in_params.iter_mut() { let typ = ast::Type::from(p.v_type); - let new_id = id_def.new_id(Some((StateSpace::Param, typ))); + let new_id = id_def.new_id(typ); result.push(Statement::Variable(ast::Variable { align: p.align, v_type: ast::VariableType::Param(p.v_type), @@ -508,12 +599,8 @@ fn insert_mem_ssa_statements<'a, 'b>( } ast::MethodDecl::Func(out_params, _, in_params) => { for p in in_params.iter_mut() { - let ss = match p.v_type { - ast::FnArgumentType::Reg(_) => StateSpace::Reg, - ast::FnArgumentType::Param(_) => StateSpace::Param, - }; let typ = ast::Type::from(p.v_type); - let new_id = id_def.new_id(Some((ss, typ))); + let new_id = id_def.new_id(typ); let var_typ = ast::VariableType::from(p.v_type); result.push(Statement::Variable(ast::Variable { align: p.align, @@ -545,31 +632,28 @@ fn insert_mem_ssa_statements<'a, 'b>( }; for s in func { match s { - Statement::Call(call) => insert_mem_ssa_statement_default(id_def, &mut result, call), + Statement::Call(call) => insert_mem_ssa_statement_default(id_def, &mut result, call)?, Statement::Instruction(inst) => match inst { ast::Instruction::Ret(d) => { if let Some(out_param) = out_param { - let typ = id_def.get_type(out_param); + let typ = id_def.get_typed(out_param)?; let new_id = id_def.new_id(typ); result.push(Statement::LoadVar( ast::Arg2 { dst: new_id, src: out_param, }, - typ.unwrap().1, + typ, )); result.push(Statement::RetValue(d, new_id)); } else { result.push(Statement::Instruction(ast::Instruction::Ret(d))) } } - inst => insert_mem_ssa_statement_default(id_def, &mut result, inst), + inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, }, Statement::Conditional(mut bra) => { - let generated_id = id_def.new_id(Some(( - StateSpace::Reg, - ast::Type::Scalar(ast::ScalarType::Pred), - ))); + let generated_id = id_def.new_id(ast::Type::Scalar(ast::ScalarType::Pred)); result.push(Statement::LoadVar( Arg2 { dst: generated_id, @@ -589,41 +673,45 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Composite(_) => todo!(), } } - (f_args, result) + Ok((f_args, result)) } trait VisitVariable: Sized { fn visit_variable< 'a, - F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + F: FnMut( + ArgumentDescriptor, + Option, + ) -> Result, >( self, f: &mut F, - ) -> UnadornedStatement; + ) -> Result; } trait VisitVariableExpanded { fn visit_variable_extended< - F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + F: FnMut( + ArgumentDescriptor, + Option, + ) -> Result, >( self, f: &mut F, - ) -> ExpandedStatement; + ) -> Result; } fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, result: &mut Vec, stmt: F, -) { +) -> Result<(), TranslateError> { let mut post_statements = Vec::new(); let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor, _| { - let id_type = match (id_def.get_type(desc.op), desc.sema) { - (Some((_, t)), ArgumentSemantics::ParamPtr) - | (Some((_, t)), ArgumentSemantics::Default) => t, - (Some((_, t)), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64), - (None, _) => return desc.op, + let id_type = match (id_def.get_typed(desc.op)?, desc.sema) { + (t, ArgumentSemantics::ParamPtr) | (t, ArgumentSemantics::Default) => t, + (t, ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64), }; - let generated_id = id_def.new_id(Some((StateSpace::Reg, id_type))); + let generated_id = id_def.new_id(id_type); if !desc.is_dst { result.push(Statement::LoadVar( Arg2 { @@ -641,12 +729,14 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( id_type, )); } - generated_id - }); + Ok(generated_id) + })?; result.push(new_statement); result.append(&mut post_statements); + Ok(()) } +/* fn expand_arguments<'a, 'b>( func: Vec, id_def: &'b mut NumericIdResolver<'a>, @@ -656,7 +746,7 @@ fn expand_arguments<'a, 'b>( match s { Statement::Call(call) => { let mut visitor = FlattenArguments::new(&mut result, id_def); - let (new_call, post_stmts) = (call.map(&mut visitor), visitor.post_stmts); + let (new_call, post_stmts) = (call.map(&mut visitor)?, visitor.post_stmts); result.push(Statement::Call(new_call)); result.extend(post_stmts); } @@ -687,6 +777,7 @@ fn expand_arguments<'a, 'b>( } result } +*/ struct FlattenArguments<'a, 'b> { func: &'b mut Vec, @@ -711,15 +802,15 @@ impl<'a, 'b> ArgumentMapVisitor &mut self, desc: ArgumentDescriptor, _: Option, - ) -> spirv::Word { - desc.op + ) -> Result { + Ok(desc.op) } fn operand( &mut self, desc: ArgumentDescriptor>, typ: ast::Type, - ) -> spirv::Word { + ) -> Result { match desc.op { ast::Operand::Reg(r) => self.variable(desc.new_op(r), Some(typ)), ast::Operand::Imm(x) => { @@ -736,77 +827,74 @@ impl<'a, 'b> ArgumentMapVisitor typ: scalar_t, value: x as i64, })); - id + Ok(id) } - ast::Operand::RegOffset(reg, offset) => { - match desc.sema { - ArgumentSemantics::Default => { - let scalar_t = if let ast::Type::Scalar(scalar) = typ { - scalar - } else { - todo!() - }; - let id_constant_stmt = self - .id_def - .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t)))); - let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ))); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: scalar_t, - value: offset as i64, - })); - let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - ast::AddDetails::Int(ast::AddIntDesc { - typ: int_type, - saturate: false, - }), - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - result_id - } - ArgumentSemantics::Ptr => { - let scalar_t = ast::ScalarType::U64; - let id_constant_stmt = self - .id_def - .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t)))); - let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ))); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: scalar_t, - value: offset as i64, - })); - let int_type = ast::IntType::U64; - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - ast::AddDetails::Int(ast::AddIntDesc { - typ: int_type, - saturate: false, - }), - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - result_id - } - ArgumentSemantics::ParamPtr => { - if offset == 0 { - return reg; - } - // Will be needed for arrays + ast::Operand::RegOffset(reg, offset) => match desc.sema { + ArgumentSemantics::Default => { + let scalar_t = if let ast::Type::Scalar(scalar) = typ { + scalar + } else { todo!() - } + }; + let id_constant_stmt = self + .id_def + .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t)))); + let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ))); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: scalar_t, + value: offset as i64, + })); + let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); + self.func.push(Statement::Instruction( + ast::Instruction::::Add( + ast::AddDetails::Int(ast::AddIntDesc { + typ: int_type, + saturate: false, + }), + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + Ok(result_id) } - } + ArgumentSemantics::Ptr => { + let scalar_t = ast::ScalarType::U64; + let id_constant_stmt = self + .id_def + .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t)))); + let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ))); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: scalar_t, + value: offset as i64, + })); + let int_type = ast::IntType::U64; + self.func.push(Statement::Instruction( + ast::Instruction::::Add( + ast::AddDetails::Int(ast::AddIntDesc { + typ: int_type, + saturate: false, + }), + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + Ok(result_id) + } + ArgumentSemantics::ParamPtr => { + if offset == 0 { + return Ok(reg); + } + todo!() + } + }, } } @@ -814,7 +902,7 @@ impl<'a, 'b> ArgumentMapVisitor &mut self, desc: ArgumentDescriptor>, typ: ast::Type, - ) -> spirv::Word { + ) -> Result { match desc.op { ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg), Some(typ)), ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x)), typ), @@ -825,7 +913,7 @@ impl<'a, 'b> ArgumentMapVisitor &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, (scalar_type, vec_len): (ast::MovVectorType, u8), - ) -> spirv::Word { + ) -> Result { let new_id = self.id_def.new_id(Some(( StateSpace::Reg, ast::Type::Vector(scalar_type.into(), vec_len), @@ -836,15 +924,15 @@ impl<'a, 'b> ArgumentMapVisitor src_composite: desc.op.0, src_index: desc.op.1 as u32, })); - new_id + Ok(new_id) } fn mov_operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: ast::Type, - ) -> spirv::Word { - self.operand(desc, typ) + ) -> Result { + todo!() } } @@ -862,9 +950,10 @@ impl<'a, 'b> ArgumentMapVisitor - generic/global st: for instruction `st [x], y`, x must be of type b64/u64/s64, which is bitcast to a pointer */ +/* fn insert_implicit_conversions( func: Vec, - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, ) -> Vec { let mut result = Vec::with_capacity(func.len()); for s in func.into_iter() { @@ -936,7 +1025,7 @@ fn insert_implicit_conversions( let mut did_vector_implicit = false; let mut post_conv = None; if inst_typ_is_bit { - let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!()).1; + let src_type = id_def.get_typed(arg.src)?; if let ast::Type::Vector(_, _) = src_type { arg.src = insert_conversion_src( &mut result, @@ -948,7 +1037,7 @@ fn insert_implicit_conversions( ); did_vector_implicit = true; } - let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!()).1; + let dst_type = id_def.get_typed(arg.dst)?; if let ast::Type::Vector(_, _) = src_type { post_conv = Some(get_conversion_dst( id_def, @@ -988,6 +1077,7 @@ fn insert_implicit_conversions( } result } +*/ fn get_function_type( builder: &mut dr::Builder, @@ -1600,12 +1690,11 @@ fn emit_implicit_conversion( Ok(()) } -// TODO: support scopes fn normalize_identifiers<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, fn_defs: &GlobalFnDeclResolver<'a, 'b>, func: Vec>>, -) -> Vec { +) -> Result, TranslateError> { for s in func.iter() { match s { ast::Statement::Label(id) => { @@ -1616,9 +1705,9 @@ fn normalize_identifiers<'a, 'b>( } let mut result = Vec::new(); for s in func { - expand_map_variables(id_defs, fn_defs, &mut result, s); + expand_map_variables(id_defs, fn_defs, &mut result, s)?; } - result + Ok(result) } fn expand_map_variables<'a, 'b>( @@ -1626,19 +1715,20 @@ fn expand_map_variables<'a, 'b>( fn_defs: &GlobalFnDeclResolver<'a, 'b>, result: &mut Vec, s: ast::Statement>, -) { +) -> Result<(), TranslateError> { match s { ast::Statement::Block(block) => { id_defs.start_block(); for s in block { - expand_map_variables(id_defs, fn_defs, result, s); + expand_map_variables(id_defs, fn_defs, result, s)?; } id_defs.end_block(); } - ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name))), + ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)), ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( - p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))), - i.map_variable(&mut |id| id_defs.get_id(id)), + p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))) + .transpose()?, + i.map_variable(&mut |id| id_defs.get_id(id))?, ))), ast::Statement::Variable(var) => { let ss = match var.var.v_type { @@ -1666,16 +1756,16 @@ fn expand_map_variables<'a, 'b>( } } } - } + }; + Ok(()) } -#[derive(Ord, PartialOrd, Eq, PartialEq, Hash)] +#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { Tid, Ntid, Ctaid, Nctaid, - Gridid, } impl PtxSpecialRegister { @@ -1685,10 +1775,27 @@ impl PtxSpecialRegister { "%ntid" => Some(Self::Ntid), "%ctaid" => Some(Self::Ctaid), "%nctaid" => Some(Self::Nctaid), - "%gridid" => Some(Self::Gridid), _ => None, } } + + fn get_type(self) -> ast::Type { + match self { + PtxSpecialRegister::Tid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Ntid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Ctaid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4), + } + } + + fn get_builtin(self) -> spirv::BuiltIn { + match self { + PtxSpecialRegister::Tid => spirv::BuiltIn::GlobalInvocationId, + PtxSpecialRegister::Ntid => spirv::BuiltIn::GlobalSize, + PtxSpecialRegister::Ctaid => spirv::BuiltIn::WorkgroupId, + PtxSpecialRegister::Nctaid => spirv::BuiltIn::NumWorkgroups, + } + } } struct GlobalStringIdResolver<'input> { @@ -1725,8 +1832,11 @@ impl<'a> GlobalStringIdResolver<'a> { } } - fn get_id(&self, id: &str) -> spirv::Word { - self.variables[id] + fn get_id(&self, id: &str) -> Result { + self.variables + .get(id) + .copied() + .ok_or(TranslateError::UnknownSymbol) } fn current_id(&self) -> spirv::Word { @@ -1741,7 +1851,7 @@ impl<'a> GlobalStringIdResolver<'a> { GlobalFnDeclResolver<'a, 'b>, ast::MethodDecl<'a, ExpandedArgParams>, ) { - // In case a function decl was inserted eearlier we want to use its id + // In case a function decl was inserted earlier we want to use its id let name_id = self.get_or_add_def(header.name()); let mut fn_resolver = FnStringIdResolver { current_id: &mut self.current_id, @@ -1784,12 +1894,15 @@ pub struct GlobalFnDeclResolver<'input, 'a> { } impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { - fn get_fn_decl(&self, id: spirv::Word) -> &FnDecl { - &self.fns[&id] + fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> { + self.fns.get(&id).ok_or(TranslateError::UnknownSymbol) } - fn get_fn_decl_str(&self, id: &str) -> &'a FnDecl { - &self.fns[&self.variables[id]] + fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> { + match self.variables.get(id).map(|var_id| self.fns.get(var_id)) { + Some(Some(fn_d)) => Ok(fn_d), + _ => Err(TranslateError::UnknownSymbol), + } } } @@ -1798,7 +1911,7 @@ struct FnStringIdResolver<'input, 'b> { global_variables: &'b HashMap, spirv::Word>, special_registers: &'b mut HashMap, variables: Vec, spirv::Word>>, - type_check: HashMap, + type_check: HashMap>, } impl<'a, 'b> FnStringIdResolver<'a, 'b> { @@ -1806,6 +1919,11 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { NumericIdResolver { current_id: self.current_id, type_check: self.type_check, + special_registers: self + .special_registers + .iter() + .map(|(reg, id)| (*id, *reg)) + .collect(), } } @@ -1817,24 +1935,25 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { self.variables.pop(); } - fn get_id(&mut self, id: &str) -> spirv::Word { + fn get_id(&mut self, id: &str) -> Result { for scope in self.variables.iter().rev() { match scope.get(id) { - Some(id) => return *id, + Some(id) => return Ok(*id), None => continue, } } match self.global_variables.get(id) { - Some(id) => *id, + Some(id) => Ok(*id), None => { - let sreg = PtxSpecialRegister::try_parse(id).unwrap_or_else(|| todo!()); + let sreg = + PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?; match self.special_registers.entry(sreg) { - hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Occupied(e) => Ok(*e.get()), hash_map::Entry::Vacant(e) => { let numeric_id = *self.current_id; *self.current_id += 1; e.insert(numeric_id); - numeric_id + Ok(numeric_id) } } } @@ -1847,9 +1966,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut() .unwrap() .insert(Cow::Borrowed(id), numeric_id); - if let Some(typ) = typ { - self.type_check.insert(numeric_id, typ); - } + self.type_check.insert(numeric_id, typ); *self.current_id += 1; numeric_id } @@ -1868,7 +1985,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut() .unwrap() .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); - self.type_check.insert(numeric_id + i, (ss, typ)); + self.type_check.insert(numeric_id + i, Some((ss, typ))); } *self.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) @@ -1877,24 +1994,48 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, - type_check: HashMap, + type_check: HashMap>, + special_registers: HashMap, } impl<'b> NumericIdResolver<'b> { - fn get_type(&self, id: spirv::Word) -> Option<(StateSpace, ast::Type)> { - self.type_check.get(&id).map(|x| *x) + fn finish(self) -> MutableNumericIdResolver<'b> { + MutableNumericIdResolver { base: self } + } + + fn get_typed(&self, id: spirv::Word) -> Result<(StateSpace, ast::Type), TranslateError> { + match self.type_check.get(&id) { + Some(Some(x)) => Ok(*x), + Some(None) => Err(TranslateError::UntypedSymbol), + None => match self.special_registers.get(&id) { + Some(x) => Ok((StateSpace::Reg, x.get_type())), + None => Err(TranslateError::UntypedSymbol), + }, + } } fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word { let new_id = *self.current_id; - if let Some(typ) = typ { - self.type_check.insert(new_id, typ); - } + self.type_check.insert(new_id, typ); *self.current_id += 1; new_id } } +struct MutableNumericIdResolver<'b> { + base: NumericIdResolver<'b>, +} + +impl<'b> MutableNumericIdResolver<'b> { + fn get_typed(&self, id: spirv::Word) -> Result { + self.base.get_typed(id).map(|(_, t)| t) + } + + fn new_id(&mut self, typ: ast::Type) -> spirv::Word { + self.base.new_id(Some((StateSpace::Reg, typ))) + } +} + enum Statement { Label(u32), Variable(ast::Variable), @@ -1921,11 +2062,11 @@ impl> ResolvedCall { fn map, V: ArgumentMapVisitor>( self, visitor: &mut V, - ) -> ResolvedCall { + ) -> Result, TranslateError> { let ret_params = self .ret_params .into_iter() - .map(|(id, typ)| { + .map::, _>(|(id, typ)| { let new_id = visitor.variable( ArgumentDescriptor { op: id, @@ -1933,10 +2074,10 @@ impl> ResolvedCall { sema: ArgumentSemantics::Default, }, Some(typ.into()), - ); - (new_id, typ) + )?; + Ok((new_id, typ)) }) - .collect(); + .collect::, _>>()?; let func = visitor.variable( ArgumentDescriptor { op: self.func, @@ -1944,11 +2085,11 @@ impl> ResolvedCall { sema: ArgumentSemantics::Default, }, None, - ); + )?; let param_list = self .param_list .into_iter() - .map(|(id, typ)| { + .map::, _>(|(id, typ)| { let new_id = visitor.src_call_operand( ArgumentDescriptor { op: id, @@ -1956,48 +2097,60 @@ impl> ResolvedCall { sema: ArgumentSemantics::Default, }, typ.into(), - ); - (new_id, typ) + )?; + Ok((new_id, typ)) }) - .collect(); - ResolvedCall { + .collect::, _>>()?; + Ok(ResolvedCall { uniform: self.uniform, ret_params, func, param_list, - } + }) } } impl VisitVariable for ResolvedCall { fn visit_variable< 'a, - F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + F: FnMut( + ArgumentDescriptor, + Option, + ) -> Result, >( self, f: &mut F, - ) -> UnadornedStatement { - Statement::Call(self.map(f)) + ) -> Result { + Ok(Statement::Call(self.map(f)?)) } } impl VisitVariableExpanded for ResolvedCall { fn visit_variable_extended< - F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + F: FnMut( + ArgumentDescriptor, + Option, + ) -> Result, >( self, f: &mut F, - ) -> ExpandedStatement { - Statement::Call(self.map(f)) + ) -> Result { + Ok(Statement::Call(self.map(f)?)) } } pub trait ArgParamsEx: ast::ArgParams { - fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl; + fn get_fn_decl<'x, 'b>( + id: &Self::ID, + decl: &'b GlobalFnDeclResolver<'x, 'b>, + ) -> Result<&'b FnDecl, TranslateError>; } impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { - fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl { + fn get_fn_decl<'x, 'b>( + id: &Self::ID, + decl: &'b GlobalFnDeclResolver<'x, 'b>, + ) -> Result<&'b FnDecl, TranslateError> { decl.get_fn_decl_str(id) } } @@ -2015,23 +2168,16 @@ type UnadornedStatement = Statement, Norma impl ast::ArgParams for NormalizedArgParams { type ID = spirv::Word; type Operand = ast::Operand; - type MemoryOperand = ast::Operand; + type MovOperand = ast::MovOperand; type CallOperand = ast::CallOperand; type VecOperand = (spirv::Word, u8); } -enum TypedArgParams {} -impl ast::ArgParams for TypedArgParams { - type ID = spirv::Word; - type Operand = ast::Operand; - type MemoryOperand = MemoryOperand; - type CallOperand = ast::CallOperand; - type VecOperand = (spirv::Word, u8); -} -type TypedStatement = Statement, TypedArgParams>; - impl ArgParamsEx for NormalizedArgParams { - fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl { + fn get_fn_decl<'a, 'b>( + id: &Self::ID, + decl: &'b GlobalFnDeclResolver<'a, 'b>, + ) -> Result<&'b FnDecl, TranslateError> { decl.get_fn_decl(*id) } } @@ -2039,7 +2185,6 @@ impl ArgParamsEx for NormalizedArgParams { #[derive(Copy, Clone)] pub enum StateSpace { Reg, - Sreg, Const, Global, Local, @@ -2048,15 +2193,6 @@ pub enum StateSpace { ParamReg, } -#[derive(Copy, Clone)] -pub enum MemoryOperand { - Reg(spirv::Word), - Address(spirv::Word), - RegOffset(spirv::Word, i32), - AddressOffset(spirv::Word, i32), - Imm(u32), -} - enum ExpandedArgParams {} type ExpandedStatement = Statement, ExpandedArgParams>; type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>; @@ -2064,54 +2200,76 @@ type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStateme impl ast::ArgParams for ExpandedArgParams { type ID = spirv::Word; type Operand = spirv::Word; - type MemoryOperand = spirv::Word; + type MovOperand = spirv::Word; type CallOperand = spirv::Word; type VecOperand = spirv::Word; } impl ArgParamsEx for ExpandedArgParams { - fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl { + fn get_fn_decl<'a, 'b>( + id: &Self::ID, + decl: &'b GlobalFnDeclResolver<'a, 'b>, + ) -> Result<&'b FnDecl, TranslateError> { decl.get_fn_decl(*id) } } trait ArgumentMapVisitor { - fn variable(&mut self, desc: ArgumentDescriptor, typ: Option) -> U::ID; - fn operand(&mut self, desc: ArgumentDescriptor, typ: ast::Type) -> U::Operand; + fn variable( + &mut self, + desc: ArgumentDescriptor, + typ: Option, + ) -> Result; + fn operand( + &mut self, + desc: ArgumentDescriptor, + typ: ast::Type, + ) -> Result; fn mov_operand( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, typ: ast::Type, - ) -> U::MemoryOperand; + ) -> Result; fn src_call_operand( &mut self, desc: ArgumentDescriptor, typ: ast::Type, - ) -> U::CallOperand; + ) -> Result; fn src_vec_operand( &mut self, desc: ArgumentDescriptor, typ: (ast::MovVectorType, u8), - ) -> U::VecOperand; + ) -> Result; } impl ArgumentMapVisitor for T where - T: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + T: FnMut( + ArgumentDescriptor, + Option, + ) -> Result, { fn variable( &mut self, desc: ArgumentDescriptor, t: Option, - ) -> spirv::Word { + ) -> Result { self(desc, t) } - fn operand(&mut self, desc: ArgumentDescriptor, t: ast::Type) -> spirv::Word { + fn operand( + &mut self, + desc: ArgumentDescriptor, + t: ast::Type, + ) -> Result { self(desc, Some(t)) } - fn mov_operand(&mut self, desc: ArgumentDescriptor, t: ast::Type) -> spirv::Word { + fn mov_operand( + &mut self, + desc: ArgumentDescriptor, + t: ast::Type, + ) -> Result { self(desc, Some(t)) } @@ -2119,7 +2277,7 @@ where &mut self, desc: ArgumentDescriptor, t: ast::Type, - ) -> spirv::Word { + ) -> Result { self(desc, Some(t)) } @@ -2127,7 +2285,7 @@ where &mut self, desc: ArgumentDescriptor, (scalar_type, vec_len): (ast::MovVectorType, u8), - ) -> spirv::Word { + ) -> Result { self( desc.new_op(desc.op), Some(ast::Type::Vector(scalar_type.into(), vec_len)), @@ -2137,9 +2295,13 @@ where impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> for T where - T: FnMut(&str) -> spirv::Word, + T: FnMut(&str) -> Result, { - fn variable(&mut self, desc: ArgumentDescriptor<&str>, _: Option) -> spirv::Word { + fn variable( + &mut self, + desc: ArgumentDescriptor<&str>, + _: Option, + ) -> Result { self(desc.op) } @@ -2147,11 +2309,11 @@ where &mut self, desc: ArgumentDescriptor>, _: ast::Type, - ) -> ast::Operand { + ) -> Result, TranslateError> { match desc.op { - ast::Operand::Reg(id) => ast::Operand::Reg(self(id)), - ast::Operand::Imm(imm) => ast::Operand::Imm(imm), - ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id), imm), + ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)), + ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), + ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)), } } @@ -2159,10 +2321,10 @@ where &mut self, desc: ArgumentDescriptor>, _: ast::Type, - ) -> ast::CallOperand { + ) -> Result, TranslateError> { match desc.op { - ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(id)), - ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm), + ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)), + ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)), } } @@ -2170,16 +2332,16 @@ where &mut self, desc: ArgumentDescriptor<(&str, u8)>, _: (ast::MovVectorType, u8), - ) -> (spirv::Word, u8) { - (self(desc.op.0), desc.op.1) + ) -> Result<(spirv::Word, u8), TranslateError> { + Ok((self(desc.op.0)?, desc.op.1)) } fn mov_operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: ast::Type, - ) -> ast::Operand { - self.operand(desc, typ) + ) -> Result, TranslateError> { + todo!() } } @@ -2210,41 +2372,41 @@ impl ast::Instruction { fn map>( self, visitor: &mut V, - ) -> ast::Instruction { - match self { + ) -> Result, TranslateError> { + Ok(match self { ast::Instruction::MovVector(t, a) => { - ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length))) + ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length))?) } ast::Instruction::Abs(d, arg) => { - ast::Instruction::Abs(d, arg.map(visitor, ast::Type::Scalar(d.typ))) + ast::Instruction::Abs(d, arg.map(visitor, ast::Type::Scalar(d.typ))?) } // Call instruction is converted to a call statement early on ast::Instruction::Call(_) => unreachable!(), ast::Instruction::Ld(d, a) => { let inst_type = d.typ; let is_param = d.state_space == ast::LdStateSpace::Param; - ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param)) + ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param)?) } ast::Instruction::Mov(mov_type, a) => { - ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into())) + ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into())?) } ast::Instruction::Mul(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type)) + ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type)?) } ast::Instruction::Add(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type)) + ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type)?) } ast::Instruction::Setp(d, a) => { let inst_type = d.typ; - ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type))) + ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type))?) } ast::Instruction::SetpBool(d, a) => { let inst_type = d.typ; - ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type))) + ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type))?) } - ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, t.to_type())), + ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, t.to_type())?), ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { ast::CvtDetails::FloatFromFloat(desc) => ( @@ -2264,47 +2426,53 @@ impl ast::Instruction { ast::Type::Scalar(desc.src.into()), ), }; - ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t)) + ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t)?) } ast::Instruction::Shl(t, a) => { - ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())) + ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())?) } ast::Instruction::St(d, a) => { let inst_type = d.typ; let is_param = d.state_space == ast::StStateSpace::Param; - ast::Instruction::St(d, a.map(visitor, inst_type, is_param)) + ast::Instruction::St(d, a.map(visitor, inst_type, is_param)?) } - ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)), + ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), ast::Instruction::Cvta(d, a) => { let inst_type = ast::Type::Scalar(ast::ScalarType::B64); - ast::Instruction::Cvta(d, a.map(visitor, inst_type)) + ast::Instruction::Cvta(d, a.map(visitor, inst_type)?) } - } + }) } } impl VisitVariable for ast::Instruction { fn visit_variable< 'a, - F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + F: FnMut( + ArgumentDescriptor, + Option, + ) -> Result, >( self, f: &mut F, - ) -> UnadornedStatement { - Statement::Instruction(self.map(f)) + ) -> Result { + Ok(Statement::Instruction(self.map(f)?)) } } impl ArgumentMapVisitor for T where - T: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + T: FnMut( + ArgumentDescriptor, + Option, + ) -> Result, { fn variable( &mut self, desc: ArgumentDescriptor, t: Option, - ) -> spirv::Word { + ) -> Result { self(desc, t) } @@ -2312,13 +2480,14 @@ where &mut self, desc: ArgumentDescriptor>, t: ast::Type, - ) -> ast::Operand { + ) -> Result, TranslateError> { match desc.op { - ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id), Some(t))), - ast::Operand::Imm(imm) => ast::Operand::Imm(imm), - ast::Operand::RegOffset(id, imm) => { - ast::Operand::RegOffset(self(desc.new_op(id), Some(t)), imm) - } + ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)), + ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), + ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset( + self(desc.new_op(id), Some(t))?, + imm, + )), } } @@ -2326,10 +2495,10 @@ where &mut self, desc: ArgumentDescriptor>, t: ast::Type, - ) -> ast::CallOperand { + ) -> Result, TranslateError> { match desc.op { - ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id), Some(t))), - ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm), + ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)), + ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)), } } @@ -2337,24 +2506,22 @@ where &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, (scalar_type, vector_len): (ast::MovVectorType, u8), - ) -> (spirv::Word, u8) { - ( + ) -> Result<(spirv::Word, u8), TranslateError> { + Ok(( self( desc.new_op(desc.op.0), Some(ast::Type::Vector(scalar_type.into(), vector_len)), - ), + )?, desc.op.1, - ) + )) } fn mov_operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor>, typ: ast::Type, - ) -> ast::Operand { - >::operand( - self, desc, typ, - ) + ) -> Result, TranslateError> { + todo!() } } @@ -2439,12 +2606,15 @@ impl ast::Instruction { impl VisitVariableExpanded for ast::Instruction { fn visit_variable_extended< - F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + F: FnMut( + ArgumentDescriptor, + Option, + ) -> Result, >( self, f: &mut F, - ) -> ExpandedStatement { - Statement::Instruction(self.map(f)) + ) -> Result { + Ok(Statement::Instruction(self.map(f)?)) } } @@ -2488,32 +2658,40 @@ enum ConversionKind { } impl ast::PredAt { - fn map_variable U>(self, f: &mut F) -> ast::PredAt { - ast::PredAt { + fn map_variable Result>( + self, + f: &mut F, + ) -> Result, TranslateError> { + let new_label = f(self.label)?; + Ok(ast::PredAt { not: self.not, - label: f(self.label), - } + label: new_label, + }) } } impl<'a> ast::Instruction> { - fn map_variable spirv::Word>( + fn map_variable Result>( self, f: &mut F, - ) -> ast::Instruction { + ) -> Result, TranslateError> { match self { ast::Instruction::Call(call) => { let call_inst = ast::CallInst { uniform: call.uniform, - ret_params: call.ret_params.into_iter().map(|p| f(p)).collect(), - func: f(call.func), + ret_params: call + .ret_params + .into_iter() + .map(|p| f(p)) + .collect::>()?, + func: f(call.func)?, param_list: call .param_list .into_iter() .map(|p| p.map_variable(f)) - .collect(), + .collect::>()?, }; - ast::Instruction::Call(call_inst) + Ok(ast::Instruction::Call(call_inst)) } i => i.map(f), } @@ -2525,17 +2703,16 @@ impl ast::Arg1 { self, visitor: &mut V, t: Option, - ) -> ast::Arg1 { - ast::Arg1 { - src: visitor.variable( - ArgumentDescriptor { - op: self.src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - } + ) -> Result, TranslateError> { + let new_src = visitor.variable( + ArgumentDescriptor { + op: self.src, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + Ok(ast::Arg1 { src: new_src }) } } @@ -2544,25 +2721,27 @@ impl ast::Arg2 { self, visitor: &mut V, t: ast::Type, - ) -> ast::Arg2 { - ast::Arg2 { - dst: visitor.variable( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(t), - ), - src: visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - } + ) -> Result, TranslateError> { + let new_dst = visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(t), + )?; + let new_src = visitor.operand( + ArgumentDescriptor { + op: self.src, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + Ok(ast::Arg2 { + dst: new_dst, + src: new_src, + }) } fn map_ld>( @@ -2570,29 +2749,28 @@ impl ast::Arg2 { visitor: &mut V, t: ast::Type, is_param: bool, - ) -> ast::Arg2 { - ast::Arg2 { - dst: visitor.variable( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: ArgumentSemantics::Default, + ) -> Result, TranslateError> { + let dst = visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(t), + )?; + let src = visitor.operand( + ArgumentDescriptor { + op: self.src, + is_dst: false, + sema: if is_param { + ArgumentSemantics::ParamPtr + } else { + ArgumentSemantics::Ptr }, - Some(t), - ), - src: visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - sema: if is_param { - ArgumentSemantics::ParamPtr - } else { - ArgumentSemantics::Ptr - }, - }, - t, - ), - } + }, + t, + )?; + Ok(ast::Arg2 { dst, src }) } fn map_cvt>( @@ -2600,25 +2778,50 @@ impl ast::Arg2 { visitor: &mut V, dst_t: ast::Type, src_t: ast::Type, - ) -> ast::Arg2 { - ast::Arg2 { - dst: visitor.variable( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(dst_t), - ), - src: visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - src_t, - ), - } + ) -> Result, TranslateError> { + let dst = visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(dst_t), + )?; + let src = visitor.operand( + ArgumentDescriptor { + op: self.src, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + src_t, + )?; + Ok(ast::Arg2 { dst, src }) + } +} + +impl ast::Arg2Mov { + fn map>( + self, + visitor: &mut V, + t: ast::Type, + ) -> Result, TranslateError> { + let dst = visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(t), + )?; + let src = visitor.mov_operand( + ArgumentDescriptor { + op: self.src, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + Ok(ast::Arg2Mov { dst, src }) } } @@ -2628,29 +2831,28 @@ impl ast::Arg2St { visitor: &mut V, t: ast::Type, is_param: bool, - ) -> ast::Arg2St { - ast::Arg2St { - src1: visitor.mov_operand( - ArgumentDescriptor { - op: self.src1, - is_dst: is_param, - sema: if is_param { - ArgumentSemantics::ParamPtr - } else { - ArgumentSemantics::Ptr - }, + ) -> Result, TranslateError> { + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: is_param, + sema: if is_param { + ArgumentSemantics::ParamPtr + } else { + ArgumentSemantics::Ptr }, - t, - ), - src2: visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - } + }, + t, + )?; + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + Ok(ast::Arg2St { src1, src2 }) } } @@ -2667,84 +2869,81 @@ impl ast::Arg2Vec { self, visitor: &mut V, (scalar_type, vec_len): (ast::MovVectorType, u8), - ) -> ast::Arg2Vec { + ) -> Result, TranslateError> { match self { - ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => ast::Arg2Vec::Dst( - ( - visitor.variable( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(ast::Type::Scalar(scalar_type.into())), - ), - len, - ), - visitor.variable( - ArgumentDescriptor { - op: composite_src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(ast::Type::Scalar(scalar_type.into())), - ), - visitor.variable( - ArgumentDescriptor { - op: scalar_src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(ast::Type::Scalar(scalar_type.into())), - ), - ), - ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src( - visitor.variable( + ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => { + let dst = visitor.variable( ArgumentDescriptor { op: dst, is_dst: true, sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(scalar_type.into())), - ), - visitor.src_vec_operand( - ArgumentDescriptor { - op: src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - (scalar_type, vec_len), - ), - ), - ast::Arg2Vec::Both((dst, len), composite_src, src) => ast::Arg2Vec::Both( - ( - visitor.variable( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(ast::Type::Scalar(scalar_type.into())), - ), - len, - ), - visitor.variable( + )?; + let src1 = visitor.variable( ArgumentDescriptor { op: composite_src, is_dst: false, sema: ArgumentSemantics::Default, }, Some(ast::Type::Scalar(scalar_type.into())), - ), - visitor.src_vec_operand( + )?; + let src2 = visitor.variable( + ArgumentDescriptor { + op: scalar_src, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + Some(ast::Type::Scalar(scalar_type.into())), + )?; + Ok(ast::Arg2Vec::Dst((dst, len), src1, src2)) + } + ast::Arg2Vec::Src(dst, src) => { + let dst = visitor.variable( + ArgumentDescriptor { + op: dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(ast::Type::Scalar(scalar_type.into())), + )?; + let src = visitor.src_vec_operand( ArgumentDescriptor { op: src, is_dst: false, sema: ArgumentSemantics::Default, }, (scalar_type, vec_len), - ), - ), + )?; + Ok(ast::Arg2Vec::Src(dst, src)) + } + ast::Arg2Vec::Both((dst, len), composite_src, src) => { + let dst = visitor.variable( + ArgumentDescriptor { + op: dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(ast::Type::Scalar(scalar_type.into())), + )?; + let composite_src = visitor.variable( + ArgumentDescriptor { + op: composite_src, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + Some(ast::Type::Scalar(scalar_type.into())), + )?; + let src = visitor.src_vec_operand( + ArgumentDescriptor { + op: src, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + (scalar_type, vec_len), + )?; + Ok(ast::Arg2Vec::Both((dst, len), composite_src, src)) + } } } } @@ -2754,66 +2953,64 @@ impl ast::Arg3 { self, visitor: &mut V, t: ast::Type, - ) -> ast::Arg3 { - ast::Arg3 { - dst: visitor.variable( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(t), - ), - src1: visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - src2: visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - } + ) -> Result, TranslateError> { + let dst = visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(t), + )?; + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + Ok(ast::Arg3 { dst, src1, src2 }) } fn map_shift>( self, visitor: &mut V, t: ast::Type, - ) -> ast::Arg3 { - ast::Arg3 { - dst: visitor.variable( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(t), - ), - src1: visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - src2: visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - ast::Type::Scalar(ast::ScalarType::U32), - ), - } + ) -> Result, TranslateError> { + let dst = visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(t), + )?; + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + ast::Type::Scalar(ast::ScalarType::U32), + )?; + Ok(ast::Arg3 { dst, src1, src2 }) } } @@ -2822,17 +3019,18 @@ impl ast::Arg4 { self, visitor: &mut V, t: ast::Type, - ) -> ast::Arg4 { - ast::Arg4 { - dst1: visitor.variable( - ArgumentDescriptor { - op: self.dst1, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(ast::Type::Scalar(ast::ScalarType::Pred)), - ), - dst2: self.dst2.map(|dst2| { + ) -> Result, TranslateError> { + let dst1 = visitor.variable( + ArgumentDescriptor { + op: self.dst1, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(ast::Type::Scalar(ast::ScalarType::Pred)), + )?; + let dst2 = self + .dst2 + .map(|dst2| { visitor.variable( ArgumentDescriptor { op: dst2, @@ -2841,24 +3039,30 @@ impl ast::Arg4 { }, Some(ast::Type::Scalar(ast::ScalarType::Pred)), ) - }), - src1: visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - src2: visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - } + }) + .transpose()?; + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + Ok(ast::Arg4 { + dst1, + dst2, + src1, + src2, + }) } } @@ -2867,17 +3071,18 @@ impl ast::Arg5 { self, visitor: &mut V, t: ast::Type, - ) -> ast::Arg5 { - ast::Arg5 { - dst1: visitor.variable( - ArgumentDescriptor { - op: self.dst1, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(ast::Type::Scalar(ast::ScalarType::Pred)), - ), - dst2: self.dst2.map(|dst2| { + ) -> Result, TranslateError> { + let dst1 = visitor.variable( + ArgumentDescriptor { + op: self.dst1, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(ast::Type::Scalar(ast::ScalarType::Pred)), + )?; + let dst2 = self + .dst2 + .map(|dst2| { visitor.variable( ArgumentDescriptor { op: dst2, @@ -2886,40 +3091,47 @@ impl ast::Arg5 { }, Some(ast::Type::Scalar(ast::ScalarType::Pred)), ) - }), - src1: visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - src2: visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - t, - ), - src3: visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - ast::Type::Scalar(ast::ScalarType::Pred), - ), - } + }) + .transpose()?; + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + t, + )?; + let src3 = visitor.operand( + ArgumentDescriptor { + op: self.src3, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + ast::Type::Scalar(ast::ScalarType::Pred), + )?; + Ok(ast::Arg5 { + dst1, + dst2, + src1, + src2, + src3, + }) } } impl ast::CallOperand { - fn map_variable U>(self, f: &mut F) -> ast::CallOperand { + fn map_variable Result>(self, f: &mut F) -> Result, TranslateError> { match self { - ast::CallOperand::Reg(id) => ast::CallOperand::Reg(f(id)), - ast::CallOperand::Imm(x) => ast::CallOperand::Imm(x), + ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(f(id)?)), + ast::CallOperand::Imm(x) => Ok(ast::CallOperand::Imm(x)), } } } @@ -3195,37 +3407,37 @@ fn insert_with_conversions_pre_conv( fn get_implicit_conversions_ld_dst< ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option, >( - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, instr_type: ast::Type, dst: spirv::Word, should_convert: ShouldConvert, in_reverse: bool, -) -> Option { - let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!()).1; +) -> Result, TranslateError> { + let dst_type = id_def.get_typed(dst)?; if let Some(conv) = should_convert(dst_type, instr_type) { - Some(ImplicitConversion { + Ok(Some(ImplicitConversion { src: u32::max_value(), dst: u32::max_value(), from: if !in_reverse { dst_type } else { instr_type }, to: if !in_reverse { instr_type } else { dst_type }, kind: conv, - }) + })) } else { - None + Ok(None) } } fn get_implicit_conversions_ld_src( - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, instr_type: ast::Type, state_space: ast::LdStateSpace, src: spirv::Word, -) -> Vec { - let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1; +) -> Result, TranslateError> { + let src_type = id_def.get_typed(src)?; match state_space { ast::LdStateSpace::Param => { if src_type != instr_type { - vec![ + Ok(vec![ ImplicitConversion { src: u32::max_value(), dst: u32::max_value(), @@ -3234,9 +3446,9 @@ fn get_implicit_conversions_ld_src( kind: ConversionKind::Default, }; 1 - ] + ]) } else { - Vec::new() + Ok(Vec::new()) } } ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { @@ -3268,12 +3480,12 @@ fn get_implicit_conversions_ld_src( kind: ConversionKind::Ptr(state_space), }); if result.len() == 2 { - let new_id = id_def.new_id(Some((StateSpace::Reg, new_src_type))); + let new_id = id_def.new_id(new_src_type); result[0].dst = new_id; result[1].src = new_id; result[1].from = new_src_type; } - result + Ok(result) } _ => todo!(), } @@ -3281,10 +3493,10 @@ fn get_implicit_conversions_ld_src( fn insert_implicit_conversions_ld_src( func: &mut Vec, instr_type: ast::Type, - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, state_space: ast::LdStateSpace, src: spirv::Word, -) -> spirv::Word { +) -> Result { match state_space { ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl( func, @@ -3304,15 +3516,15 @@ fn insert_implicit_conversions_ld_src( new_src_type, src, should_convert_ld_generic_src_to_bitcast, - ); - insert_conversion_src( + )?; + Ok(insert_conversion_src( func, id_def, new_src, new_src_type, instr_type, ConversionKind::Ptr(state_space), - ) + )) } _ => todo!(), } @@ -3322,16 +3534,18 @@ fn insert_implicit_conversions_ld_src_impl< ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option, >( func: &mut Vec, - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, instr_type: ast::Type, src: spirv::Word, should_convert: ShouldConvert, -) -> spirv::Word { - let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1; +) -> Result { + let src_type = id_def.get_typed(src)?; if let Some(conv) = should_convert(src_type, instr_type) { - insert_conversion_src(func, id_def, src, src_type, instr_type, conv) + Ok(insert_conversion_src( + func, id_def, src, src_type, instr_type, conv, + )) } else { - src + Ok(src) } } @@ -3363,13 +3577,13 @@ fn should_convert_ld_generic_src_to_bitcast( #[must_use] fn insert_conversion_src( func: &mut Vec, - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, src: spirv::Word, src_type: ast::Type, instr_type: ast::Type, conv: ConversionKind, ) -> spirv::Word { - let temp_src = id_def.new_id(Some((StateSpace::Reg, instr_type))); + let temp_src = id_def.new_id(instr_type); func.push(Statement::Conversion(ImplicitConversion { src: src, dst: temp_src, @@ -3408,14 +3622,14 @@ fn insert_with_implicit_conversion_dst< #[must_use] fn get_conversion_dst( - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, dst: &mut spirv::Word, instr_type: ast::Type, dst_type: ast::Type, kind: ConversionKind, ) -> ExpandedStatement { let original_dst = *dst; - let temp_dst = id_def.new_id(Some((StateSpace::Reg, instr_type))); + let temp_dst = id_def.new_id(instr_type); *dst = temp_dst; Statement::Conversion(ImplicitConversion { src: temp_dst, @@ -3525,17 +3739,17 @@ fn should_convert_relaxed_dst( fn insert_implicit_bitcasts( func: &mut Vec, - id_def: &mut NumericIdResolver, + id_def: &mut MutableNumericIdResolver, stmt: impl VisitVariableExpanded, -) { +) -> Result<(), TranslateError> { let mut dst_coercion = None; let instr = stmt.visit_variable_extended(&mut |mut desc, typ| { let id_type_from_instr = match typ { Some(t) => t, - None => return desc.op, + None => return Ok(desc.op), }; - let id_actual_type = id_def.get_type(desc.op).unwrap().1; - if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap().1) { + let id_actual_type = id_def.get_typed(desc.op)?; + if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) { if desc.is_dst { dst_coercion = Some(get_conversion_dst( id_def, @@ -3544,25 +3758,26 @@ fn insert_implicit_bitcasts( id_actual_type, ConversionKind::Default, )); - desc.op + Ok(desc.op) } else { - insert_conversion_src( + Ok(insert_conversion_src( func, id_def, desc.op, id_actual_type, id_type_from_instr, ConversionKind::Default, - ) + )) } } else { - desc.op + Ok(desc.op) } - }); + })?; func.push(instr); if let Some(cond) = dst_coercion { func.push(cond); } + Ok(()) } impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> { fn name(&self) -> &'a str {