diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs index 7ff5290..2342ad5 100644 --- a/ptx/src/pass/convert_to_typed.rs +++ b/ptx/src/pass/convert_to_typed.rs @@ -19,7 +19,7 @@ pub(crate) fn run( }, } if fn_defs.fns.contains_key(&src_reg) => { if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { dst: dst_reg, @@ -68,7 +68,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { // mov.u32 foobar, {a,b}; let scalar_t = match typ { ast::Type::Vector(scalar_t, _) => *scalar_t, - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; let temp_vec = self .id_def @@ -115,7 +115,7 @@ impl<'a, 'b> ast::VisitorMap, TypedOperand, Transl ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x), ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), ast::ParsedOperand::VecPack(vec) => { - let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?; + let (type_, space) = type_space.ok_or(error_mismatched_type())?; TypedOperand::Reg(self.convert_vector( is_dst, relaxed_type_check, diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index 9dff12e..e2e6a3b 100644 --- a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -17,7 +17,7 @@ pub(super) fn run<'input>( HashMap, >, directives: Vec>, -) -> Result<(), TranslateError> { +) -> Result<(dr::Module, HashMap, CString), TranslateError> { builder.set_version(1, 3); emit_capabilities(&mut builder); emit_extensions(&mut builder); @@ -39,7 +39,8 @@ pub(super) fn run<'input>( globals_use_map, directives, &mut kernel_info, - ) + )?; + Ok((builder.module(), kernel_info, build_options)) } fn emit_capabilities(builder: &mut dr::Builder) { @@ -942,7 +943,7 @@ fn emit_function_body_ops<'input>( builder.constant_true(bool_type, Some(cnst.dst.0)); } } - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), } } Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, @@ -2646,7 +2647,7 @@ fn emit_load_var( Some((index, Some(width))) => { let vector_type = match details.typ { ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); let vector_temp = builder.load( diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs index eb03866..3680005 100644 --- a/ptx/src/pass/expand_arguments.rs +++ b/ptx/src/pass/expand_arguments.rs @@ -66,11 +66,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg { let (reg_type, reg_space) = self.id_def.get_typed(reg)?; if !state_is_compatible(reg_space, ast::StateSpace::Reg) { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } let reg_scalar_type = match reg_type { ast::Type::Scalar(underlying_type) => underlying_type, - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; let id_constant_stmt = self .id_def diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs index 304bc61..c029016 100644 --- a/ptx/src/pass/fix_special_registers.rs +++ b/ptx/src/pass/fix_special_registers.rs @@ -58,7 +58,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { ) -> Result { if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) { if is_dst { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } let input_arguments = match (vector_index, sreg.get_function_input_type()) { (Some(idx), Some(inp_type)) => { @@ -81,7 +81,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { )] } (None, None) => Vec::new(), - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); let return_type = sreg.get_function_return_type(); diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index 4a0dc8e..baf3453 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -168,7 +168,7 @@ fn default_implicit_conversion_space( | ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), - _ => Err(TranslateError::MismatchedType), + _ => Err(error_mismatched_type()), }, ast::Type::Scalar(ast::ScalarType::B32) | ast::Type::Scalar(ast::ScalarType::U32) @@ -176,9 +176,9 @@ fn default_implicit_conversion_space( ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { Ok(Some(ConversionKind::BitToPtr)) } - _ => Err(TranslateError::MismatchedType), + _ => Err(error_mismatched_type()), }, - _ => Err(TranslateError::MismatchedType), + _ => Err(error_mismatched_type()), } } else if state_is_compatible(instruction_space, ast::StateSpace::Reg) { match instruction_type { @@ -191,10 +191,10 @@ fn default_implicit_conversion_space( Ok(None) } } - _ => Err(TranslateError::MismatchedType), + _ => Err(error_mismatched_type()), } } else { - Err(TranslateError::MismatchedType) + Err(error_mismatched_type()) } } @@ -208,7 +208,7 @@ fn default_implicit_conversion_type( if should_bitcast(instruction_type, operand_type) { Ok(Some(ConversionKind::Default)) } else { - Err(TranslateError::MismatchedType) + Err(error_mismatched_type()) } } else { Ok(Some(ConversionKind::PtrToPtr)) @@ -265,14 +265,14 @@ fn should_convert_relaxed_dst_wrapper( (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { if !state_is_compatible(operand_space, instruction_space) { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } if operand_type == instruction_type { return Ok(None); } match should_convert_relaxed_dst(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), - None => Err(TranslateError::MismatchedType), + None => Err(error_mismatched_type()), } } @@ -342,14 +342,14 @@ fn should_convert_relaxed_src_wrapper( (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { if !state_is_compatible(operand_space, instruction_space) { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } if operand_type == instruction_type { return Ok(None); } match should_convert_relaxed_src(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), - None => Err(TranslateError::MismatchedType), + None => Err(error_mismatched_type()), } } diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs index 6ab19bd..7369cdb 100644 --- a/ptx/src/pass/insert_mem_ssa_statements.rs +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -199,7 +199,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { var_type = ast::Type::Scalar(scalar_t); width } - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; Some(( idx, diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 8923718..2825017 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -55,26 +55,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result(ast: ast::Module<'input>) -> Result( @@ -629,10 +609,24 @@ fn error_unreachable() -> TranslateError { TranslateError::Unreachable } +fn error_unknown_symbol() -> TranslateError { + panic!() +} + +#[cfg(not(debug_assertions))] fn error_unknown_symbol() -> TranslateError { TranslateError::UnknownSymbol } +fn error_mismatched_type() -> TranslateError { + panic!() +} + +#[cfg(not(debug_assertions))] +fn error_mismatched_type() -> TranslateError { + TranslateError::MismatchedType +} + pub struct GlobalFnDeclResolver<'input, 'a> { fns: &'a HashMap>, } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f5dfa64..62dba04 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,3 +1,4 @@ +use crate::pass; use crate::ptx; use crate::translate; use hip_runtime_sys::hipError_t; @@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>( spirv_txt: &'a [u8], spirv_file_name: &'a str, ) -> Result<(), Box> { - let mut errors = Vec::new(); - let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?; - assert!(errors.len() == 0); - let spirv_module = translate::to_spirv_module(ast)?; + let ast = ptx_parser::parse_module_unchecked(ptx_txt).unwrap(); + let spirv_module = pass::to_spirv_module(ast)?; let spv_context = unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) }; assert!(spv_context != ptr::null_mut());