Wire new parser into spvtxt tests

This commit is contained in:
Andrzej Janik 2024-08-30 17:01:47 +02:00
parent 790fe18579
commit 2e5ad8ebdf
8 changed files with 41 additions and 47 deletions

View file

@ -19,7 +19,7 @@ pub(crate) fn run(
},
} if fn_defs.fns.contains_key(&src_reg) => {
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
return Err(TranslateError::MismatchedType);
return Err(error_mismatched_type());
}
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
dst: dst_reg,
@ -68,7 +68,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
// mov.u32 foobar, {a,b};
let scalar_t = match typ {
ast::Type::Vector(scalar_t, _) => *scalar_t,
_ => return Err(TranslateError::MismatchedType),
_ => return Err(error_mismatched_type()),
};
let temp_vec = self
.id_def
@ -115,7 +115,7 @@ impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, Transl
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
ast::ParsedOperand::VecPack(vec) => {
let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?;
let (type_, space) = type_space.ok_or(error_mismatched_type())?;
TypedOperand::Reg(self.convert_vector(
is_dst,
relaxed_type_check,

View file

@ -17,7 +17,7 @@ pub(super) fn run<'input>(
HashMap<u8, (spirv::FPDenormMode, isize)>,
>,
directives: Vec<Directive<'input>>,
) -> Result<(), TranslateError> {
) -> Result<(dr::Module, HashMap<String, KernelInfo>, CString), TranslateError> {
builder.set_version(1, 3);
emit_capabilities(&mut builder);
emit_extensions(&mut builder);
@ -39,7 +39,8 @@ pub(super) fn run<'input>(
globals_use_map,
directives,
&mut kernel_info,
)
)?;
Ok((builder.module(), kernel_info, build_options))
}
fn emit_capabilities(builder: &mut dr::Builder) {
@ -942,7 +943,7 @@ fn emit_function_body_ops<'input>(
builder.constant_true(bool_type, Some(cnst.dst.0));
}
}
_ => return Err(TranslateError::MismatchedType),
_ => return Err(error_mismatched_type()),
}
}
Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
@ -2646,7 +2647,7 @@ fn emit_load_var(
Some((index, Some(width))) => {
let vector_type = match details.typ {
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
_ => return Err(TranslateError::MismatchedType),
_ => return Err(error_mismatched_type()),
};
let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
let vector_temp = builder.load(

View file

@ -66,11 +66,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg {
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
if !state_is_compatible(reg_space, ast::StateSpace::Reg) {
return Err(TranslateError::MismatchedType);
return Err(error_mismatched_type());
}
let reg_scalar_type = match reg_type {
ast::Type::Scalar(underlying_type) => underlying_type,
_ => return Err(TranslateError::MismatchedType),
_ => return Err(error_mismatched_type()),
};
let id_constant_stmt = self
.id_def

View file

@ -58,7 +58,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
) -> Result<SpirvWord, TranslateError> {
if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
if is_dst {
return Err(TranslateError::MismatchedType);
return Err(error_mismatched_type());
}
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
(Some(idx), Some(inp_type)) => {
@ -81,7 +81,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
)]
}
(None, None) => Vec::new(),
_ => return Err(TranslateError::MismatchedType),
_ => return Err(error_mismatched_type()),
};
let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
let return_type = sreg.get_function_return_type();

View file

@ -168,7 +168,7 @@ fn default_implicit_conversion_space(
| ast::StateSpace::Const
| ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(TranslateError::MismatchedType),
_ => Err(error_mismatched_type()),
},
ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32)
@ -176,9 +176,9 @@ fn default_implicit_conversion_space(
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr))
}
_ => Err(TranslateError::MismatchedType),
_ => Err(error_mismatched_type()),
},
_ => Err(TranslateError::MismatchedType),
_ => Err(error_mismatched_type()),
}
} else if state_is_compatible(instruction_space, ast::StateSpace::Reg) {
match instruction_type {
@ -191,10 +191,10 @@ fn default_implicit_conversion_space(
Ok(None)
}
}
_ => Err(TranslateError::MismatchedType),
_ => Err(error_mismatched_type()),
}
} else {
Err(TranslateError::MismatchedType)
Err(error_mismatched_type())
}
}
@ -208,7 +208,7 @@ fn default_implicit_conversion_type(
if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default))
} else {
Err(TranslateError::MismatchedType)
Err(error_mismatched_type())
}
} else {
Ok(Some(ConversionKind::PtrToPtr))
@ -265,14 +265,14 @@ fn should_convert_relaxed_dst_wrapper(
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if !state_is_compatible(operand_space, instruction_space) {
return Err(TranslateError::MismatchedType);
return Err(error_mismatched_type());
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
None => Err(error_mismatched_type()),
}
}
@ -342,14 +342,14 @@ fn should_convert_relaxed_src_wrapper(
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if !state_is_compatible(operand_space, instruction_space) {
return Err(TranslateError::MismatchedType);
return Err(error_mismatched_type());
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
None => Err(error_mismatched_type()),
}
}

View file

@ -199,7 +199,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
var_type = ast::Type::Scalar(scalar_t);
width
}
_ => return Err(TranslateError::MismatchedType),
_ => return Err(error_mismatched_type()),
};
Some((
idx,

View file

@ -55,26 +55,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
})?;
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
todo!()
/*
let (build_options, should_flush_denorms) =
emit_denorm_build_string(&call_map, &denorm_information);
let (directives, globals_use_map) = get_globals_use_map(directives);
emit_directives(
&mut builder,
&mut map,
&id_defs,
opencl_id,
should_flush_denorms,
&call_map,
globals_use_map,
directives,
&mut kernel_info,
)?;
let spirv = builder.module();
let (spirv, kernel_info, build_options) = emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives)?;
Ok(Module {
spirv,
kernel_info,
@ -85,7 +66,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
},
build_options,
})
*/
}
fn translate_directive<'input, 'a>(
@ -629,10 +609,24 @@ fn error_unreachable() -> TranslateError {
TranslateError::Unreachable
}
fn error_unknown_symbol() -> TranslateError {
panic!()
}
#[cfg(not(debug_assertions))]
fn error_unknown_symbol() -> TranslateError {
TranslateError::UnknownSymbol
}
fn error_mismatched_type() -> TranslateError {
panic!()
}
#[cfg(not(debug_assertions))]
fn error_mismatched_type() -> TranslateError {
TranslateError::MismatchedType
}
pub struct GlobalFnDeclResolver<'input, 'a> {
fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>,
}

View file

@ -1,3 +1,4 @@
use crate::pass;
use crate::ptx;
use crate::translate;
use hip_runtime_sys::hipError_t;
@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>(
spirv_txt: &'a [u8],
spirv_file_name: &'a str,
) -> Result<(), Box<dyn error::Error + 'a>> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
assert!(errors.len() == 0);
let spirv_module = translate::to_spirv_module(ast)?;
let ast = ptx_parser::parse_module_unchecked(ptx_txt).unwrap();
let spirv_module = pass::to_spirv_module(ast)?;
let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut());