mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Fix bugs in the new middle-end
This commit is contained in:
parent
09be47a919
commit
18e5147fdc
3 changed files with 180 additions and 79 deletions
|
@ -1,25 +1,38 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%1 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %5 "ld_st"
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeInt 64 0
|
||||
%4 = OpTypeFunction %2 %3 %3
|
||||
%19 = OpTypePointer Generic %3
|
||||
%5 = OpFunction %2 None %4
|
||||
%6 = OpFunctionParameter %3
|
||||
%7 = OpFunctionParameter %3
|
||||
%18 = OpLabel
|
||||
%13 = OpCopyObject %3 %6
|
||||
%14 = OpCopyObject %3 %7
|
||||
%15 = OpConvertUToPtr %19 %13
|
||||
%16 = OpLoad %3 %15
|
||||
%17 = OpConvertUToPtr %19 %14
|
||||
OpStore %17 %16
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
; SPIR-V
|
||||
; Version: 1.5
|
||||
; Generator: Khronos SPIR-V Tools Assembler; 0
|
||||
; Bound: 20
|
||||
; Schema: 0
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%1 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %2 "ld_st"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%5 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%2 = OpFunction %void None %5
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%10 = OpLabel
|
||||
%11 = OpVariable %_ptr_Function_ulong Function
|
||||
%12 = OpVariable %_ptr_Function_ulong Function
|
||||
%13 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %11 %8
|
||||
OpStore %12 %9
|
||||
%14 = OpLoad %ulong %11
|
||||
%15 = OpConvertUToPtr %_ptr_Generic_ulong %14
|
||||
%16 = OpLoad %ulong %15
|
||||
OpStore %13 %16
|
||||
%17 = OpLoad %ulong %12
|
||||
%18 = OpLoad %ulong %13
|
||||
%19 = OpConvertUToPtr %_ptr_Generic_ulong %17
|
||||
OpStore %19 %18
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -169,11 +169,17 @@ fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool {
|
|||
if !is_option_equal(&fn1.end, &fn2.end, &mut map, is_instr_equal) {
|
||||
return false;
|
||||
}
|
||||
if fn1.parameters.len() != fn2.parameters.len() {
|
||||
return false;
|
||||
}
|
||||
for (inst1, inst2) in fn1.parameters.iter().zip(fn2.parameters.iter()) {
|
||||
if !is_instr_equal(inst1, inst2, &mut map) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if fn1.blocks.len() != fn2.blocks.len() {
|
||||
return false;
|
||||
}
|
||||
for (b1, b2) in fn1.blocks.iter().zip(fn2.blocks.iter()) {
|
||||
if !is_block_equal(b1, b2, &mut map) {
|
||||
return false;
|
||||
|
@ -186,6 +192,9 @@ fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap<Word, Word>) -> bool
|
|||
if !is_option_equal(&b1.label, &b2.label, map, is_instr_equal) {
|
||||
return false;
|
||||
}
|
||||
if b1.instructions.len() != b2.instructions.len() {
|
||||
return false;
|
||||
}
|
||||
for (inst1, inst2) in b1.instructions.iter().zip(b2.instructions.iter()) {
|
||||
if !is_instr_equal(inst1, inst2, map) {
|
||||
return false;
|
||||
|
@ -205,6 +214,9 @@ fn is_instr_equal(
|
|||
if !is_option_equal(&instr1.result_id, &instr2.result_id, map, is_word_equal) {
|
||||
return false;
|
||||
}
|
||||
if instr1.operands.len() != instr2.operands.len() {
|
||||
return false;
|
||||
}
|
||||
for (o1, o2) in instr1.operands.iter().zip(instr2.operands.iter()) {
|
||||
match (o1, o2) {
|
||||
(Operand::IdMemorySemantics(w1), Operand::IdMemorySemantics(w2)) => {
|
||||
|
|
|
@ -3,7 +3,7 @@ use bit_vec::BitVec;
|
|||
use rspirv::dr;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::{borrow::Cow, fmt, mem};
|
||||
use std::{borrow::Cow, fmt, iter, mem};
|
||||
|
||||
use rspirv::binary::Assemble;
|
||||
|
||||
|
@ -11,7 +11,7 @@ use rspirv::binary::Assemble;
|
|||
enum SpirvType {
|
||||
Base(ast::ScalarType),
|
||||
Extended(ast::ExtendedScalarType),
|
||||
Pointer(ast::ScalarType, spirv::StorageClass),
|
||||
Pointer(ast::Type, spirv::StorageClass),
|
||||
}
|
||||
|
||||
impl From<ast::Type> for SpirvType {
|
||||
|
@ -77,9 +77,12 @@ impl TypeWordMap {
|
|||
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
|
||||
match t {
|
||||
SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
|
||||
SpirvType::Extended(t) => self.get_or_add_extended(b, t),
|
||||
SpirvType::Pointer(scalar, storage) => {
|
||||
let base = self.get_or_add_scalar(b, scalar);
|
||||
SpirvType::Extended(ext) => self.get_or_add_extended(b, ext),
|
||||
SpirvType::Pointer(typ, storage) => {
|
||||
let base = match typ {
|
||||
ast::Type::Scalar(scalar) => self.get_or_add_scalar(b, scalar),
|
||||
ast::Type::ExtendedScalar(ext) => self.get_or_add_extended(b, ext),
|
||||
};
|
||||
*self
|
||||
.complex
|
||||
.entry(t)
|
||||
|
@ -175,7 +178,40 @@ fn to_ssa<'a>(
|
|||
let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut id_def);
|
||||
let expanded_statements = expand_arguments(ssa_statements, &mut id_def);
|
||||
let expanded_statements = insert_implicit_conversions(expanded_statements, &mut id_def);
|
||||
(expanded_statements, id_def.ids_count())
|
||||
let labeled_statements = normalize_labels(expanded_statements, &mut id_def);
|
||||
(labeled_statements, id_def.ids_count())
|
||||
}
|
||||
|
||||
fn normalize_labels(
|
||||
func: Vec<Statement<ExpandedArgs>>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> Vec<Statement<ExpandedArgs>> {
|
||||
let mut labels_in_use = HashSet::new();
|
||||
for s in func.iter() {
|
||||
match s {
|
||||
Statement::Instruction(i) => {
|
||||
if let Some(target) = i.jump_target() {
|
||||
labels_in_use.insert(target);
|
||||
}
|
||||
}
|
||||
Statement::Conditional(cond) => {
|
||||
labels_in_use.insert(cond.if_true);
|
||||
labels_in_use.insert(cond.if_false);
|
||||
}
|
||||
Statement::Variable(_, _, _)
|
||||
| Statement::LoadVar(_, _)
|
||||
| Statement::StoreVar(_, _)
|
||||
| Statement::Converison(_)
|
||||
| Statement::Constant(_)
|
||||
| Statement::Label(_) => (),
|
||||
}
|
||||
}
|
||||
iter::once(Statement::Label(id_def.new_id(None)))
|
||||
.chain(func.into_iter().filter(|s| match s {
|
||||
Statement::Label(i) => labels_in_use.contains(i),
|
||||
_ => true,
|
||||
}))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn normalize_predicates(
|
||||
|
@ -212,7 +248,9 @@ fn normalize_predicates(
|
|||
result.push(Statement::Instruction(Instruction::from_ast(inst)));
|
||||
}
|
||||
}
|
||||
ast::Statement::Variable(var) => result.push(Statement::Variable(var.name, var.v_type)),
|
||||
ast::Statement::Variable(var) => {
|
||||
result.push(Statement::Variable(var.name, var.v_type, var.space))
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
|
@ -225,35 +263,46 @@ fn insert_mem_ssa_statements(
|
|||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func {
|
||||
match s {
|
||||
Statement::Instruction(mut inst) => {
|
||||
let inst_type = inst.get_type();
|
||||
let mut post_statements = Vec::new();
|
||||
inst.visit_id_mut(&mut |is_dst, id| {
|
||||
let inst_type = inst_type.unwrap();
|
||||
let generated_id = id_def.new_id(Some(inst_type));
|
||||
if !is_dst {
|
||||
result.push(Statement::LoadVar(
|
||||
Arg2 {
|
||||
dst: generated_id,
|
||||
src: *id,
|
||||
},
|
||||
inst_type,
|
||||
));
|
||||
} else {
|
||||
post_statements.push(Statement::StoreVar(
|
||||
Arg2St {
|
||||
src1: *id,
|
||||
src2: generated_id,
|
||||
},
|
||||
inst_type,
|
||||
));
|
||||
}
|
||||
*id = generated_id;
|
||||
});
|
||||
result.push(Statement::Instruction(inst));
|
||||
result.append(&mut post_statements);
|
||||
}
|
||||
s @ Statement::Variable(_, _)
|
||||
Statement::Instruction(inst) => match inst {
|
||||
Instruction::Ld(
|
||||
ld @ ast::LdData {
|
||||
state_space: ast::LdStateSpace::Param,
|
||||
..
|
||||
},
|
||||
arg,
|
||||
) => {
|
||||
result.push(Statement::Instruction(Instruction::Ld(ld, arg)));
|
||||
}
|
||||
mut inst => {
|
||||
let inst_type = inst.get_type();
|
||||
let mut post_statements = Vec::new();
|
||||
inst.visit_id_mut(&mut |is_dst, id| {
|
||||
let inst_type = inst_type.unwrap();
|
||||
let generated_id = id_def.new_id(Some(inst_type));
|
||||
if !is_dst {
|
||||
result.push(Statement::LoadVar(
|
||||
Arg2 {
|
||||
dst: generated_id,
|
||||
src: *id,
|
||||
},
|
||||
inst_type,
|
||||
));
|
||||
} else {
|
||||
post_statements.push(Statement::StoreVar(
|
||||
Arg2St {
|
||||
src1: *id,
|
||||
src2: generated_id,
|
||||
},
|
||||
inst_type,
|
||||
));
|
||||
}
|
||||
*id = generated_id;
|
||||
});
|
||||
result.push(Statement::Instruction(inst));
|
||||
result.append(&mut post_statements);
|
||||
}
|
||||
},
|
||||
s @ Statement::Variable(_, _, _)
|
||||
| s @ Statement::Label(_)
|
||||
| s @ Statement::Conditional(_) => result.push(s),
|
||||
Statement::LoadVar(_, _)
|
||||
|
@ -273,9 +322,10 @@ fn expand_arguments(
|
|||
for s in func {
|
||||
match s {
|
||||
Statement::Instruction(inst) => {
|
||||
normalize_insert_instruction(&mut result, id_def, inst);
|
||||
let new_inst = normalize_insert_instruction(&mut result, id_def, inst);
|
||||
result.push(Statement::Instruction(new_inst));
|
||||
}
|
||||
Statement::Variable(id, typ) => result.push(Statement::Variable(id, typ)),
|
||||
Statement::Variable(id, typ, ss) => result.push(Statement::Variable(id, typ, ss)),
|
||||
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
|
||||
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
|
||||
|
@ -467,11 +517,11 @@ fn normalize_expand_mov_operand(
|
|||
which is bitcast to a pointer
|
||||
*/
|
||||
fn insert_implicit_conversions(
|
||||
normalized_ids: Vec<ExpandedStatement>,
|
||||
func: Vec<ExpandedStatement>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> Vec<ExpandedStatement> {
|
||||
let mut result = Vec::with_capacity(normalized_ids.len());
|
||||
for s in normalized_ids.into_iter() {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func.into_iter() {
|
||||
match s {
|
||||
Statement::Instruction(inst) => match inst {
|
||||
Instruction::Ld(ld, mut arg) => {
|
||||
|
@ -515,11 +565,12 @@ fn insert_implicit_conversions(
|
|||
}
|
||||
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
|
||||
},
|
||||
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
|
||||
Statement::Constant(_)
|
||||
| Statement::Variable(_, _)
|
||||
| Statement::LoadVar(_, _)
|
||||
| Statement::StoreVar(_, _) => (),
|
||||
s @ Statement::Conditional(_)
|
||||
| s @ Statement::Label(_)
|
||||
| s @ Statement::Constant(_)
|
||||
| s @ Statement::Variable(_, _, _)
|
||||
| s @ Statement::LoadVar(_, _)
|
||||
| s @ Statement::StoreVar(_, _) => result.push(s),
|
||||
Statement::Converison(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
@ -591,9 +642,22 @@ fn emit_function_body_ops(
|
|||
) -> Result<(), dr::Error> {
|
||||
for s in func {
|
||||
match s {
|
||||
// If block starts with a label it has already been emitted,
|
||||
// all other labels in the block are unused
|
||||
Statement::Label(_) => (),
|
||||
Statement::Label(id) => {
|
||||
if builder.block.is_some() {
|
||||
builder.branch(*id)?;
|
||||
}
|
||||
builder.begin_block(Some(*id))?;
|
||||
}
|
||||
Statement::Variable(id, typ, ss) => {
|
||||
let type_id = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(*typ, spirv::StorageClass::Function),
|
||||
);
|
||||
if *ss != ast::StateSpace::Reg {
|
||||
todo!()
|
||||
}
|
||||
builder.variable(type_id, Some(*id), spirv::StorageClass::Function, None);
|
||||
}
|
||||
Statement::Constant(_) => todo!(),
|
||||
Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
|
||||
Statement::Conditional(bra) => {
|
||||
|
@ -614,7 +678,7 @@ fn emit_function_body_ops(
|
|||
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
|
||||
}
|
||||
ast::LdStateSpace::Param => {
|
||||
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
|
||||
builder.store(arg.dst, arg.src, None, [])?;
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
|
@ -642,7 +706,16 @@ fn emit_function_body_ops(
|
|||
},
|
||||
_ => todo!(),
|
||||
},
|
||||
_ => todo!(),
|
||||
Statement::LoadVar(arg, typ) => {
|
||||
let type_id = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::from(*typ),
|
||||
);
|
||||
builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
|
||||
}
|
||||
Statement::StoreVar(arg, typ) => {
|
||||
builder.store(arg.src1, arg.src2, None, [])?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -672,7 +745,10 @@ fn emit_implicit_conversion(
|
|||
ConversionKind::Ptr => {
|
||||
let dst_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic),
|
||||
SpirvType::Pointer(
|
||||
ast::Type::Scalar(to_type),
|
||||
spirv_headers::StorageClass::Generic,
|
||||
),
|
||||
);
|
||||
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
|
||||
}
|
||||
|
@ -856,7 +932,7 @@ impl NumericIdResolver {
|
|||
}
|
||||
|
||||
enum Statement<A: Args> {
|
||||
Variable(spirv::Word, ast::Type),
|
||||
Variable(spirv::Word, ast::Type, ast::StateSpace),
|
||||
LoadVar(Arg2, ast::Type),
|
||||
StoreVar(Arg2St, ast::Type),
|
||||
Label(u32),
|
||||
|
@ -870,7 +946,7 @@ enum Statement<A: Args> {
|
|||
impl<A: Args> Statement<A> {
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
match self {
|
||||
Statement::Variable(id, _) => f(true, id),
|
||||
Statement::Variable(id, _, _) => f(true, id),
|
||||
Statement::LoadVar(a, _) => a.visit_id_mut(f),
|
||||
Statement::StoreVar(a, _) => a.visit_id_mut(f),
|
||||
Statement::Label(id) => f(false, id),
|
||||
|
|
Loading…
Add table
Reference in a new issue