Wire new parser into spvtxt tests

This commit is contained in:
Andrzej Janik 2024-08-30 17:01:47 +02:00
commit 2e5ad8ebdf
8 changed files with 41 additions and 47 deletions

View file

@ -19,7 +19,7 @@ pub(crate) fn run(
}, },
} if fn_defs.fns.contains_key(&src_reg) => { } if fn_defs.fns.contains_key(&src_reg) => {
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
return Err(TranslateError::MismatchedType); return Err(error_mismatched_type());
} }
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
dst: dst_reg, dst: dst_reg,
@ -68,7 +68,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
// mov.u32 foobar, {a,b}; // mov.u32 foobar, {a,b};
let scalar_t = match typ { let scalar_t = match typ {
ast::Type::Vector(scalar_t, _) => *scalar_t, ast::Type::Vector(scalar_t, _) => *scalar_t,
_ => return Err(TranslateError::MismatchedType), _ => return Err(error_mismatched_type()),
}; };
let temp_vec = self let temp_vec = self
.id_def .id_def
@ -115,7 +115,7 @@ impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, Transl
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x), ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
ast::ParsedOperand::VecPack(vec) => { ast::ParsedOperand::VecPack(vec) => {
let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?; let (type_, space) = type_space.ok_or(error_mismatched_type())?;
TypedOperand::Reg(self.convert_vector( TypedOperand::Reg(self.convert_vector(
is_dst, is_dst,
relaxed_type_check, relaxed_type_check,

View file

@ -17,7 +17,7 @@ pub(super) fn run<'input>(
HashMap<u8, (spirv::FPDenormMode, isize)>, HashMap<u8, (spirv::FPDenormMode, isize)>,
>, >,
directives: Vec<Directive<'input>>, directives: Vec<Directive<'input>>,
) -> Result<(), TranslateError> { ) -> Result<(dr::Module, HashMap<String, KernelInfo>, CString), TranslateError> {
builder.set_version(1, 3); builder.set_version(1, 3);
emit_capabilities(&mut builder); emit_capabilities(&mut builder);
emit_extensions(&mut builder); emit_extensions(&mut builder);
@ -39,7 +39,8 @@ pub(super) fn run<'input>(
globals_use_map, globals_use_map,
directives, directives,
&mut kernel_info, &mut kernel_info,
) )?;
Ok((builder.module(), kernel_info, build_options))
} }
fn emit_capabilities(builder: &mut dr::Builder) { fn emit_capabilities(builder: &mut dr::Builder) {
@ -942,7 +943,7 @@ fn emit_function_body_ops<'input>(
builder.constant_true(bool_type, Some(cnst.dst.0)); builder.constant_true(bool_type, Some(cnst.dst.0));
} }
} }
_ => return Err(TranslateError::MismatchedType), _ => return Err(error_mismatched_type()),
} }
} }
Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
@ -2646,7 +2647,7 @@ fn emit_load_var(
Some((index, Some(width))) => { Some((index, Some(width))) => {
let vector_type = match details.typ { let vector_type = match details.typ {
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
_ => return Err(TranslateError::MismatchedType), _ => return Err(error_mismatched_type()),
}; };
let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
let vector_temp = builder.load( let vector_temp = builder.load(

View file

@ -66,11 +66,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg { if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg {
let (reg_type, reg_space) = self.id_def.get_typed(reg)?; let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
if !state_is_compatible(reg_space, ast::StateSpace::Reg) { if !state_is_compatible(reg_space, ast::StateSpace::Reg) {
return Err(TranslateError::MismatchedType); return Err(error_mismatched_type());
} }
let reg_scalar_type = match reg_type { let reg_scalar_type = match reg_type {
ast::Type::Scalar(underlying_type) => underlying_type, ast::Type::Scalar(underlying_type) => underlying_type,
_ => return Err(TranslateError::MismatchedType), _ => return Err(error_mismatched_type()),
}; };
let id_constant_stmt = self let id_constant_stmt = self
.id_def .id_def

View file

@ -58,7 +58,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) { if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
if is_dst { if is_dst {
return Err(TranslateError::MismatchedType); return Err(error_mismatched_type());
} }
let input_arguments = match (vector_index, sreg.get_function_input_type()) { let input_arguments = match (vector_index, sreg.get_function_input_type()) {
(Some(idx), Some(inp_type)) => { (Some(idx), Some(inp_type)) => {
@ -81,7 +81,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
)] )]
} }
(None, None) => Vec::new(), (None, None) => Vec::new(),
_ => return Err(TranslateError::MismatchedType), _ => return Err(error_mismatched_type()),
}; };
let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
let return_type = sreg.get_function_return_type(); let return_type = sreg.get_function_return_type();

View file

@ -168,7 +168,7 @@ fn default_implicit_conversion_space(
| ast::StateSpace::Const | ast::StateSpace::Const
| ast::StateSpace::Local | ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(TranslateError::MismatchedType), _ => Err(error_mismatched_type()),
}, },
ast::Type::Scalar(ast::ScalarType::B32) ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32) | ast::Type::Scalar(ast::ScalarType::U32)
@ -176,9 +176,9 @@ fn default_implicit_conversion_space(
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr)) Ok(Some(ConversionKind::BitToPtr))
} }
_ => Err(TranslateError::MismatchedType), _ => Err(error_mismatched_type()),
}, },
_ => Err(TranslateError::MismatchedType), _ => Err(error_mismatched_type()),
} }
} else if state_is_compatible(instruction_space, ast::StateSpace::Reg) { } else if state_is_compatible(instruction_space, ast::StateSpace::Reg) {
match instruction_type { match instruction_type {
@ -191,10 +191,10 @@ fn default_implicit_conversion_space(
Ok(None) Ok(None)
} }
} }
_ => Err(TranslateError::MismatchedType), _ => Err(error_mismatched_type()),
} }
} else { } else {
Err(TranslateError::MismatchedType) Err(error_mismatched_type())
} }
} }
@ -208,7 +208,7 @@ fn default_implicit_conversion_type(
if should_bitcast(instruction_type, operand_type) { if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default)) Ok(Some(ConversionKind::Default))
} else { } else {
Err(TranslateError::MismatchedType) Err(error_mismatched_type())
} }
} else { } else {
Ok(Some(ConversionKind::PtrToPtr)) Ok(Some(ConversionKind::PtrToPtr))
@ -265,14 +265,14 @@ fn should_convert_relaxed_dst_wrapper(
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if !state_is_compatible(operand_space, instruction_space) { if !state_is_compatible(operand_space, instruction_space) {
return Err(TranslateError::MismatchedType); return Err(error_mismatched_type());
} }
if operand_type == instruction_type { if operand_type == instruction_type {
return Ok(None); return Ok(None);
} }
match should_convert_relaxed_dst(operand_type, instruction_type) { match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv), conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType), None => Err(error_mismatched_type()),
} }
} }
@ -342,14 +342,14 @@ fn should_convert_relaxed_src_wrapper(
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if !state_is_compatible(operand_space, instruction_space) { if !state_is_compatible(operand_space, instruction_space) {
return Err(TranslateError::MismatchedType); return Err(error_mismatched_type());
} }
if operand_type == instruction_type { if operand_type == instruction_type {
return Ok(None); return Ok(None);
} }
match should_convert_relaxed_src(operand_type, instruction_type) { match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv), conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType), None => Err(error_mismatched_type()),
} }
} }

View file

@ -199,7 +199,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
var_type = ast::Type::Scalar(scalar_t); var_type = ast::Type::Scalar(scalar_t);
width width
} }
_ => return Err(TranslateError::MismatchedType), _ => return Err(error_mismatched_type()),
}; };
Some(( Some((
idx, idx,

View file

@ -55,26 +55,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
})?; })?;
normalize_variable_decls(&mut directives); normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives); let denorm_information = compute_denorm_information(&directives);
emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives); let (spirv, kernel_info, build_options) = emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives)?;
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
todo!()
/*
let (build_options, should_flush_denorms) =
emit_denorm_build_string(&call_map, &denorm_information);
let (directives, globals_use_map) = get_globals_use_map(directives);
emit_directives(
&mut builder,
&mut map,
&id_defs,
opencl_id,
should_flush_denorms,
&call_map,
globals_use_map,
directives,
&mut kernel_info,
)?;
let spirv = builder.module();
Ok(Module { Ok(Module {
spirv, spirv,
kernel_info, kernel_info,
@ -85,7 +66,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
}, },
build_options, build_options,
}) })
*/
} }
fn translate_directive<'input, 'a>( fn translate_directive<'input, 'a>(
@ -629,10 +609,24 @@ fn error_unreachable() -> TranslateError {
TranslateError::Unreachable TranslateError::Unreachable
} }
fn error_unknown_symbol() -> TranslateError {
panic!()
}
#[cfg(not(debug_assertions))]
fn error_unknown_symbol() -> TranslateError { fn error_unknown_symbol() -> TranslateError {
TranslateError::UnknownSymbol TranslateError::UnknownSymbol
} }
fn error_mismatched_type() -> TranslateError {
panic!()
}
#[cfg(not(debug_assertions))]
fn error_mismatched_type() -> TranslateError {
TranslateError::MismatchedType
}
pub struct GlobalFnDeclResolver<'input, 'a> { pub struct GlobalFnDeclResolver<'input, 'a> {
fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>, fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>,
} }

View file

@ -1,3 +1,4 @@
use crate::pass;
use crate::ptx; use crate::ptx;
use crate::translate; use crate::translate;
use hip_runtime_sys::hipError_t; use hip_runtime_sys::hipError_t;
@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>(
spirv_txt: &'a [u8], spirv_txt: &'a [u8],
spirv_file_name: &'a str, spirv_file_name: &'a str,
) -> Result<(), Box<dyn error::Error + 'a>> { ) -> Result<(), Box<dyn error::Error + 'a>> {
let mut errors = Vec::new(); let ast = ptx_parser::parse_module_unchecked(ptx_txt).unwrap();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?; let spirv_module = pass::to_spirv_module(ast)?;
assert!(errors.len() == 0);
let spirv_module = translate::to_spirv_module(ast)?;
let spv_context = let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) }; unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut()); assert!(spv_context != ptr::null_mut());