Fix bugs in the new middle-end

This commit is contained in:
Andrzej Janik 2020-07-24 02:09:50 +02:00
parent 09be47a919
commit 18e5147fdc
3 changed files with 180 additions and 79 deletions

View file

@ -1,25 +1,38 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %5 "ld_st"
%2 = OpTypeVoid
%3 = OpTypeInt 64 0
%4 = OpTypeFunction %2 %3 %3
%19 = OpTypePointer Generic %3
%5 = OpFunction %2 None %4
%6 = OpFunctionParameter %3
%7 = OpFunctionParameter %3
%18 = OpLabel
%13 = OpCopyObject %3 %6
%14 = OpCopyObject %3 %7
%15 = OpConvertUToPtr %19 %13
%16 = OpLoad %3 %15
%17 = OpConvertUToPtr %19 %14
OpStore %17 %16
OpReturn
OpFunctionEnd
; SPIR-V
; Version: 1.5
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 20
; Schema: 0
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %2 "ld_st"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%5 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%2 = OpFunction %void None %5
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%10 = OpLabel
%11 = OpVariable %_ptr_Function_ulong Function
%12 = OpVariable %_ptr_Function_ulong Function
%13 = OpVariable %_ptr_Function_ulong Function
OpStore %11 %8
OpStore %12 %9
%14 = OpLoad %ulong %11
%15 = OpConvertUToPtr %_ptr_Generic_ulong %14
%16 = OpLoad %ulong %15
OpStore %13 %16
%17 = OpLoad %ulong %12
%18 = OpLoad %ulong %13
%19 = OpConvertUToPtr %_ptr_Generic_ulong %17
OpStore %19 %18
OpReturn
OpFunctionEnd

View file

@ -169,11 +169,17 @@ fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool {
if !is_option_equal(&fn1.end, &fn2.end, &mut map, is_instr_equal) {
return false;
}
if fn1.parameters.len() != fn2.parameters.len() {
return false;
}
for (inst1, inst2) in fn1.parameters.iter().zip(fn2.parameters.iter()) {
if !is_instr_equal(inst1, inst2, &mut map) {
return false;
}
}
if fn1.blocks.len() != fn2.blocks.len() {
return false;
}
for (b1, b2) in fn1.blocks.iter().zip(fn2.blocks.iter()) {
if !is_block_equal(b1, b2, &mut map) {
return false;
@ -186,6 +192,9 @@ fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap<Word, Word>) -> bool
if !is_option_equal(&b1.label, &b2.label, map, is_instr_equal) {
return false;
}
if b1.instructions.len() != b2.instructions.len() {
return false;
}
for (inst1, inst2) in b1.instructions.iter().zip(b2.instructions.iter()) {
if !is_instr_equal(inst1, inst2, map) {
return false;
@ -205,6 +214,9 @@ fn is_instr_equal(
if !is_option_equal(&instr1.result_id, &instr2.result_id, map, is_word_equal) {
return false;
}
if instr1.operands.len() != instr2.operands.len() {
return false;
}
for (o1, o2) in instr1.operands.iter().zip(instr2.operands.iter()) {
match (o1, o2) {
(Operand::IdMemorySemantics(w1), Operand::IdMemorySemantics(w2)) => {

View file

@ -3,7 +3,7 @@ use bit_vec::BitVec;
use rspirv::dr;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::{borrow::Cow, fmt, mem};
use std::{borrow::Cow, fmt, iter, mem};
use rspirv::binary::Assemble;
@ -11,7 +11,7 @@ use rspirv::binary::Assemble;
enum SpirvType {
Base(ast::ScalarType),
Extended(ast::ExtendedScalarType),
Pointer(ast::ScalarType, spirv::StorageClass),
Pointer(ast::Type, spirv::StorageClass),
}
impl From<ast::Type> for SpirvType {
@ -77,9 +77,12 @@ impl TypeWordMap {
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
match t {
SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
SpirvType::Extended(t) => self.get_or_add_extended(b, t),
SpirvType::Pointer(scalar, storage) => {
let base = self.get_or_add_scalar(b, scalar);
SpirvType::Extended(ext) => self.get_or_add_extended(b, ext),
SpirvType::Pointer(typ, storage) => {
let base = match typ {
ast::Type::Scalar(scalar) => self.get_or_add_scalar(b, scalar),
ast::Type::ExtendedScalar(ext) => self.get_or_add_extended(b, ext),
};
*self
.complex
.entry(t)
@ -175,7 +178,40 @@ fn to_ssa<'a>(
let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut id_def);
let expanded_statements = expand_arguments(ssa_statements, &mut id_def);
let expanded_statements = insert_implicit_conversions(expanded_statements, &mut id_def);
(expanded_statements, id_def.ids_count())
let labeled_statements = normalize_labels(expanded_statements, &mut id_def);
(labeled_statements, id_def.ids_count())
}
fn normalize_labels(
func: Vec<Statement<ExpandedArgs>>,
id_def: &mut NumericIdResolver,
) -> Vec<Statement<ExpandedArgs>> {
let mut labels_in_use = HashSet::new();
for s in func.iter() {
match s {
Statement::Instruction(i) => {
if let Some(target) = i.jump_target() {
labels_in_use.insert(target);
}
}
Statement::Conditional(cond) => {
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
Statement::Variable(_, _, _)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Converison(_)
| Statement::Constant(_)
| Statement::Label(_) => (),
}
}
iter::once(Statement::Label(id_def.new_id(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
}))
.collect::<Vec<_>>()
}
fn normalize_predicates(
@ -212,7 +248,9 @@ fn normalize_predicates(
result.push(Statement::Instruction(Instruction::from_ast(inst)));
}
}
ast::Statement::Variable(var) => result.push(Statement::Variable(var.name, var.v_type)),
ast::Statement::Variable(var) => {
result.push(Statement::Variable(var.name, var.v_type, var.space))
}
}
}
result
@ -225,35 +263,46 @@ fn insert_mem_ssa_statements(
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
Statement::Instruction(mut inst) => {
let inst_type = inst.get_type();
let mut post_statements = Vec::new();
inst.visit_id_mut(&mut |is_dst, id| {
let inst_type = inst_type.unwrap();
let generated_id = id_def.new_id(Some(inst_type));
if !is_dst {
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
src: *id,
},
inst_type,
));
} else {
post_statements.push(Statement::StoreVar(
Arg2St {
src1: *id,
src2: generated_id,
},
inst_type,
));
}
*id = generated_id;
});
result.push(Statement::Instruction(inst));
result.append(&mut post_statements);
}
s @ Statement::Variable(_, _)
Statement::Instruction(inst) => match inst {
Instruction::Ld(
ld @ ast::LdData {
state_space: ast::LdStateSpace::Param,
..
},
arg,
) => {
result.push(Statement::Instruction(Instruction::Ld(ld, arg)));
}
mut inst => {
let inst_type = inst.get_type();
let mut post_statements = Vec::new();
inst.visit_id_mut(&mut |is_dst, id| {
let inst_type = inst_type.unwrap();
let generated_id = id_def.new_id(Some(inst_type));
if !is_dst {
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
src: *id,
},
inst_type,
));
} else {
post_statements.push(Statement::StoreVar(
Arg2St {
src1: *id,
src2: generated_id,
},
inst_type,
));
}
*id = generated_id;
});
result.push(Statement::Instruction(inst));
result.append(&mut post_statements);
}
},
s @ Statement::Variable(_, _, _)
| s @ Statement::Label(_)
| s @ Statement::Conditional(_) => result.push(s),
Statement::LoadVar(_, _)
@ -273,9 +322,10 @@ fn expand_arguments(
for s in func {
match s {
Statement::Instruction(inst) => {
normalize_insert_instruction(&mut result, id_def, inst);
let new_inst = normalize_insert_instruction(&mut result, id_def, inst);
result.push(Statement::Instruction(new_inst));
}
Statement::Variable(id, typ) => result.push(Statement::Variable(id, typ)),
Statement::Variable(id, typ, ss) => result.push(Statement::Variable(id, typ, ss)),
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
@ -467,11 +517,11 @@ fn normalize_expand_mov_operand(
which is bitcast to a pointer
*/
fn insert_implicit_conversions(
normalized_ids: Vec<ExpandedStatement>,
func: Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
) -> Vec<ExpandedStatement> {
let mut result = Vec::with_capacity(normalized_ids.len());
for s in normalized_ids.into_iter() {
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
Statement::Instruction(inst) => match inst {
Instruction::Ld(ld, mut arg) => {
@ -515,11 +565,12 @@ fn insert_implicit_conversions(
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
},
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
Statement::Constant(_)
| Statement::Variable(_, _)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _) => (),
s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_, _, _)
| s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _) => result.push(s),
Statement::Converison(_) => unreachable!(),
}
}
@ -591,9 +642,22 @@ fn emit_function_body_ops(
) -> Result<(), dr::Error> {
for s in func {
match s {
// If block starts with a label it has already been emitted,
// all other labels in the block are unused
Statement::Label(_) => (),
Statement::Label(id) => {
if builder.block.is_some() {
builder.branch(*id)?;
}
builder.begin_block(Some(*id))?;
}
Statement::Variable(id, typ, ss) => {
let type_id = map.get_or_add(
builder,
SpirvType::Pointer(*typ, spirv::StorageClass::Function),
);
if *ss != ast::StateSpace::Reg {
todo!()
}
builder.variable(type_id, Some(*id), spirv::StorageClass::Function, None);
}
Statement::Constant(_) => todo!(),
Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
Statement::Conditional(bra) => {
@ -614,7 +678,7 @@ fn emit_function_body_ops(
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
}
ast::LdStateSpace::Param => {
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
builder.store(arg.dst, arg.src, None, [])?;
}
_ => todo!(),
}
@ -642,7 +706,16 @@ fn emit_function_body_ops(
},
_ => todo!(),
},
_ => todo!(),
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(
builder,
SpirvType::from(*typ),
);
builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
}
Statement::StoreVar(arg, typ) => {
builder.store(arg.src1, arg.src2, None, [])?;
}
}
}
Ok(())
@ -672,7 +745,10 @@ fn emit_implicit_conversion(
ConversionKind::Ptr => {
let dst_type = map.get_or_add(
builder,
SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic),
SpirvType::Pointer(
ast::Type::Scalar(to_type),
spirv_headers::StorageClass::Generic,
),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
@ -856,7 +932,7 @@ impl NumericIdResolver {
}
enum Statement<A: Args> {
Variable(spirv::Word, ast::Type),
Variable(spirv::Word, ast::Type, ast::StateSpace),
LoadVar(Arg2, ast::Type),
StoreVar(Arg2St, ast::Type),
Label(u32),
@ -870,7 +946,7 @@ enum Statement<A: Args> {
impl<A: Args> Statement<A> {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
match self {
Statement::Variable(id, _) => f(true, id),
Statement::Variable(id, _, _) => f(true, id),
Statement::LoadVar(a, _) => a.visit_id_mut(f),
Statement::StoreVar(a, _) => a.visit_id_mut(f),
Statement::Label(id) => f(false, id),