mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-08 09:09:49 +00:00
Wire new parser into spvtxt tests
This commit is contained in:
parent
790fe18579
commit
2e5ad8ebdf
8 changed files with 41 additions and 47 deletions
|
@ -19,7 +19,7 @@ pub(crate) fn run(
|
||||||
},
|
},
|
||||||
} if fn_defs.fns.contains_key(&src_reg) => {
|
} if fn_defs.fns.contains_key(&src_reg) => {
|
||||||
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||||
return Err(TranslateError::MismatchedType);
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
|
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
|
||||||
dst: dst_reg,
|
dst: dst_reg,
|
||||||
|
@ -68,7 +68,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
|
||||||
// mov.u32 foobar, {a,b};
|
// mov.u32 foobar, {a,b};
|
||||||
let scalar_t = match typ {
|
let scalar_t = match typ {
|
||||||
ast::Type::Vector(scalar_t, _) => *scalar_t,
|
ast::Type::Vector(scalar_t, _) => *scalar_t,
|
||||||
_ => return Err(TranslateError::MismatchedType),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let temp_vec = self
|
let temp_vec = self
|
||||||
.id_def
|
.id_def
|
||||||
|
@ -115,7 +115,7 @@ impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, Transl
|
||||||
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
|
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
|
||||||
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
|
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
|
||||||
ast::ParsedOperand::VecPack(vec) => {
|
ast::ParsedOperand::VecPack(vec) => {
|
||||||
let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?;
|
let (type_, space) = type_space.ok_or(error_mismatched_type())?;
|
||||||
TypedOperand::Reg(self.convert_vector(
|
TypedOperand::Reg(self.convert_vector(
|
||||||
is_dst,
|
is_dst,
|
||||||
relaxed_type_check,
|
relaxed_type_check,
|
||||||
|
|
|
@ -17,7 +17,7 @@ pub(super) fn run<'input>(
|
||||||
HashMap<u8, (spirv::FPDenormMode, isize)>,
|
HashMap<u8, (spirv::FPDenormMode, isize)>,
|
||||||
>,
|
>,
|
||||||
directives: Vec<Directive<'input>>,
|
directives: Vec<Directive<'input>>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(dr::Module, HashMap<String, KernelInfo>, CString), TranslateError> {
|
||||||
builder.set_version(1, 3);
|
builder.set_version(1, 3);
|
||||||
emit_capabilities(&mut builder);
|
emit_capabilities(&mut builder);
|
||||||
emit_extensions(&mut builder);
|
emit_extensions(&mut builder);
|
||||||
|
@ -39,7 +39,8 @@ pub(super) fn run<'input>(
|
||||||
globals_use_map,
|
globals_use_map,
|
||||||
directives,
|
directives,
|
||||||
&mut kernel_info,
|
&mut kernel_info,
|
||||||
)
|
)?;
|
||||||
|
Ok((builder.module(), kernel_info, build_options))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn emit_capabilities(builder: &mut dr::Builder) {
|
fn emit_capabilities(builder: &mut dr::Builder) {
|
||||||
|
@ -942,7 +943,7 @@ fn emit_function_body_ops<'input>(
|
||||||
builder.constant_true(bool_type, Some(cnst.dst.0));
|
builder.constant_true(bool_type, Some(cnst.dst.0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => return Err(TranslateError::MismatchedType),
|
_ => return Err(error_mismatched_type()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
|
Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
|
||||||
|
@ -2646,7 +2647,7 @@ fn emit_load_var(
|
||||||
Some((index, Some(width))) => {
|
Some((index, Some(width))) => {
|
||||||
let vector_type = match details.typ {
|
let vector_type = match details.typ {
|
||||||
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
|
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
|
||||||
_ => return Err(TranslateError::MismatchedType),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
|
let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
|
||||||
let vector_temp = builder.load(
|
let vector_temp = builder.load(
|
||||||
|
|
|
@ -66,11 +66,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
||||||
if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg {
|
if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg {
|
||||||
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
|
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
|
||||||
if !state_is_compatible(reg_space, ast::StateSpace::Reg) {
|
if !state_is_compatible(reg_space, ast::StateSpace::Reg) {
|
||||||
return Err(TranslateError::MismatchedType);
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
let reg_scalar_type = match reg_type {
|
let reg_scalar_type = match reg_type {
|
||||||
ast::Type::Scalar(underlying_type) => underlying_type,
|
ast::Type::Scalar(underlying_type) => underlying_type,
|
||||||
_ => return Err(TranslateError::MismatchedType),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let id_constant_stmt = self
|
let id_constant_stmt = self
|
||||||
.id_def
|
.id_def
|
||||||
|
|
|
@ -58,7 +58,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
|
if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
|
||||||
if is_dst {
|
if is_dst {
|
||||||
return Err(TranslateError::MismatchedType);
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||||
(Some(idx), Some(inp_type)) => {
|
(Some(idx), Some(inp_type)) => {
|
||||||
|
@ -81,7 +81,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
|
||||||
)]
|
)]
|
||||||
}
|
}
|
||||||
(None, None) => Vec::new(),
|
(None, None) => Vec::new(),
|
||||||
_ => return Err(TranslateError::MismatchedType),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
||||||
let return_type = sreg.get_function_return_type();
|
let return_type = sreg.get_function_return_type();
|
||||||
|
|
|
@ -168,7 +168,7 @@ fn default_implicit_conversion_space(
|
||||||
| ast::StateSpace::Const
|
| ast::StateSpace::Const
|
||||||
| ast::StateSpace::Local
|
| ast::StateSpace::Local
|
||||||
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||||
_ => Err(TranslateError::MismatchedType),
|
_ => Err(error_mismatched_type()),
|
||||||
},
|
},
|
||||||
ast::Type::Scalar(ast::ScalarType::B32)
|
ast::Type::Scalar(ast::ScalarType::B32)
|
||||||
| ast::Type::Scalar(ast::ScalarType::U32)
|
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||||
|
@ -176,9 +176,9 @@ fn default_implicit_conversion_space(
|
||||||
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||||
Ok(Some(ConversionKind::BitToPtr))
|
Ok(Some(ConversionKind::BitToPtr))
|
||||||
}
|
}
|
||||||
_ => Err(TranslateError::MismatchedType),
|
_ => Err(error_mismatched_type()),
|
||||||
},
|
},
|
||||||
_ => Err(TranslateError::MismatchedType),
|
_ => Err(error_mismatched_type()),
|
||||||
}
|
}
|
||||||
} else if state_is_compatible(instruction_space, ast::StateSpace::Reg) {
|
} else if state_is_compatible(instruction_space, ast::StateSpace::Reg) {
|
||||||
match instruction_type {
|
match instruction_type {
|
||||||
|
@ -191,10 +191,10 @@ fn default_implicit_conversion_space(
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => Err(TranslateError::MismatchedType),
|
_ => Err(error_mismatched_type()),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Err(TranslateError::MismatchedType)
|
Err(error_mismatched_type())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,7 +208,7 @@ fn default_implicit_conversion_type(
|
||||||
if should_bitcast(instruction_type, operand_type) {
|
if should_bitcast(instruction_type, operand_type) {
|
||||||
Ok(Some(ConversionKind::Default))
|
Ok(Some(ConversionKind::Default))
|
||||||
} else {
|
} else {
|
||||||
Err(TranslateError::MismatchedType)
|
Err(error_mismatched_type())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Ok(Some(ConversionKind::PtrToPtr))
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
@ -265,14 +265,14 @@ fn should_convert_relaxed_dst_wrapper(
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if !state_is_compatible(operand_space, instruction_space) {
|
if !state_is_compatible(operand_space, instruction_space) {
|
||||||
return Err(TranslateError::MismatchedType);
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
if operand_type == instruction_type {
|
if operand_type == instruction_type {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||||
conv @ Some(_) => Ok(conv),
|
conv @ Some(_) => Ok(conv),
|
||||||
None => Err(TranslateError::MismatchedType),
|
None => Err(error_mismatched_type()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -342,14 +342,14 @@ fn should_convert_relaxed_src_wrapper(
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if !state_is_compatible(operand_space, instruction_space) {
|
if !state_is_compatible(operand_space, instruction_space) {
|
||||||
return Err(TranslateError::MismatchedType);
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
if operand_type == instruction_type {
|
if operand_type == instruction_type {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
match should_convert_relaxed_src(operand_type, instruction_type) {
|
match should_convert_relaxed_src(operand_type, instruction_type) {
|
||||||
conv @ Some(_) => Ok(conv),
|
conv @ Some(_) => Ok(conv),
|
||||||
None => Err(TranslateError::MismatchedType),
|
None => Err(error_mismatched_type()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -199,7 +199,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
var_type = ast::Type::Scalar(scalar_t);
|
var_type = ast::Type::Scalar(scalar_t);
|
||||||
width
|
width
|
||||||
}
|
}
|
||||||
_ => return Err(TranslateError::MismatchedType),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
Some((
|
Some((
|
||||||
idx,
|
idx,
|
||||||
|
|
|
@ -55,26 +55,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
|
||||||
})?;
|
})?;
|
||||||
normalize_variable_decls(&mut directives);
|
normalize_variable_decls(&mut directives);
|
||||||
let denorm_information = compute_denorm_information(&directives);
|
let denorm_information = compute_denorm_information(&directives);
|
||||||
emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives);
|
let (spirv, kernel_info, build_options) = emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives)?;
|
||||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
|
||||||
|
|
||||||
todo!()
|
|
||||||
/*
|
|
||||||
let (build_options, should_flush_denorms) =
|
|
||||||
emit_denorm_build_string(&call_map, &denorm_information);
|
|
||||||
let (directives, globals_use_map) = get_globals_use_map(directives);
|
|
||||||
emit_directives(
|
|
||||||
&mut builder,
|
|
||||||
&mut map,
|
|
||||||
&id_defs,
|
|
||||||
opencl_id,
|
|
||||||
should_flush_denorms,
|
|
||||||
&call_map,
|
|
||||||
globals_use_map,
|
|
||||||
directives,
|
|
||||||
&mut kernel_info,
|
|
||||||
)?;
|
|
||||||
let spirv = builder.module();
|
|
||||||
Ok(Module {
|
Ok(Module {
|
||||||
spirv,
|
spirv,
|
||||||
kernel_info,
|
kernel_info,
|
||||||
|
@ -85,7 +66,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
|
||||||
},
|
},
|
||||||
build_options,
|
build_options,
|
||||||
})
|
})
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn translate_directive<'input, 'a>(
|
fn translate_directive<'input, 'a>(
|
||||||
|
@ -629,10 +609,24 @@ fn error_unreachable() -> TranslateError {
|
||||||
TranslateError::Unreachable
|
TranslateError::Unreachable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn error_unknown_symbol() -> TranslateError {
|
||||||
|
panic!()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(debug_assertions))]
|
||||||
fn error_unknown_symbol() -> TranslateError {
|
fn error_unknown_symbol() -> TranslateError {
|
||||||
TranslateError::UnknownSymbol
|
TranslateError::UnknownSymbol
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn error_mismatched_type() -> TranslateError {
|
||||||
|
panic!()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(debug_assertions))]
|
||||||
|
fn error_mismatched_type() -> TranslateError {
|
||||||
|
TranslateError::MismatchedType
|
||||||
|
}
|
||||||
|
|
||||||
pub struct GlobalFnDeclResolver<'input, 'a> {
|
pub struct GlobalFnDeclResolver<'input, 'a> {
|
||||||
fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>,
|
fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::pass;
|
||||||
use crate::ptx;
|
use crate::ptx;
|
||||||
use crate::translate;
|
use crate::translate;
|
||||||
use hip_runtime_sys::hipError_t;
|
use hip_runtime_sys::hipError_t;
|
||||||
|
@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>(
|
||||||
spirv_txt: &'a [u8],
|
spirv_txt: &'a [u8],
|
||||||
spirv_file_name: &'a str,
|
spirv_file_name: &'a str,
|
||||||
) -> Result<(), Box<dyn error::Error + 'a>> {
|
) -> Result<(), Box<dyn error::Error + 'a>> {
|
||||||
let mut errors = Vec::new();
|
let ast = ptx_parser::parse_module_unchecked(ptx_txt).unwrap();
|
||||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
|
let spirv_module = pass::to_spirv_module(ast)?;
|
||||||
assert!(errors.len() == 0);
|
|
||||||
let spirv_module = translate::to_spirv_module(ast)?;
|
|
||||||
let spv_context =
|
let spv_context =
|
||||||
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
|
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
|
||||||
assert!(spv_context != ptr::null_mut());
|
assert!(spv_context != ptr::null_mut());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue