From 18e5147fdcd6d32294f3349fd5849de43cf01800 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 24 Jul 2020 02:09:50 +0200 Subject: [PATCH] Fix bugs in the new middle-end --- ptx/src/test/spirv_run/ld_st.spvtxt | 63 ++++++---- ptx/src/test/spirv_run/mod.rs | 12 ++ ptx/src/translate.rs | 184 ++++++++++++++++++++-------- 3 files changed, 180 insertions(+), 79 deletions(-) diff --git a/ptx/src/test/spirv_run/ld_st.spvtxt b/ptx/src/test/spirv_run/ld_st.spvtxt index 33bd251..1cb7094 100644 --- a/ptx/src/test/spirv_run/ld_st.spvtxt +++ b/ptx/src/test/spirv_run/ld_st.spvtxt @@ -1,25 +1,38 @@ -OpCapability GenericPointer -OpCapability Linkage -OpCapability Addresses -OpCapability Kernel -OpCapability Int64 -OpCapability Int8 -%1 = OpExtInstImport "OpenCL.std" -OpMemoryModel Physical64 OpenCL -OpEntryPoint Kernel %5 "ld_st" -%2 = OpTypeVoid -%3 = OpTypeInt 64 0 -%4 = OpTypeFunction %2 %3 %3 -%19 = OpTypePointer Generic %3 -%5 = OpFunction %2 None %4 -%6 = OpFunctionParameter %3 -%7 = OpFunctionParameter %3 -%18 = OpLabel -%13 = OpCopyObject %3 %6 -%14 = OpCopyObject %3 %7 -%15 = OpConvertUToPtr %19 %13 -%16 = OpLoad %3 %15 -%17 = OpConvertUToPtr %19 %14 -OpStore %17 %16 -OpReturn -OpFunctionEnd +; SPIR-V +; Version: 1.5 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 20 +; Schema: 0 + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %2 "ld_st" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %5 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %2 = OpFunction %void None %5 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %10 = OpLabel + %11 = OpVariable %_ptr_Function_ulong Function + %12 = OpVariable %_ptr_Function_ulong Function + %13 = OpVariable %_ptr_Function_ulong Function + OpStore %11 %8 + OpStore %12 %9 + %14 = OpLoad %ulong %11 + %15 = OpConvertUToPtr %_ptr_Generic_ulong %14 + %16 = OpLoad %ulong %15 + OpStore %13 %16 + %17 = OpLoad %ulong %12 + %18 = OpLoad %ulong %13 + %19 = OpConvertUToPtr %_ptr_Generic_ulong %17 + OpStore %19 %18 + OpReturn + OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index b374324..394e757 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -169,11 +169,17 @@ fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool { if !is_option_equal(&fn1.end, &fn2.end, &mut map, is_instr_equal) { return false; } + if fn1.parameters.len() != fn2.parameters.len() { + return false; + } for (inst1, inst2) in fn1.parameters.iter().zip(fn2.parameters.iter()) { if !is_instr_equal(inst1, inst2, &mut map) { return false; } } + if fn1.blocks.len() != fn2.blocks.len() { + return false; + } for (b1, b2) in fn1.blocks.iter().zip(fn2.blocks.iter()) { if !is_block_equal(b1, b2, &mut map) { return false; @@ -186,6 +192,9 @@ fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap) -> bool if !is_option_equal(&b1.label, &b2.label, map, is_instr_equal) { return false; } + if b1.instructions.len() != b2.instructions.len() { + return false; + } for (inst1, inst2) in b1.instructions.iter().zip(b2.instructions.iter()) { if !is_instr_equal(inst1, inst2, map) { return false; @@ -205,6 +214,9 @@ fn is_instr_equal( if !is_option_equal(&instr1.result_id, &instr2.result_id, map, is_word_equal) { return false; } + if instr1.operands.len() != instr2.operands.len() { + return false; + } for (o1, o2) in instr1.operands.iter().zip(instr2.operands.iter()) { match (o1, o2) { (Operand::IdMemorySemantics(w1), Operand::IdMemorySemantics(w2)) => { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7cce63c..1ad077c 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -3,7 +3,7 @@ use bit_vec::BitVec; use rspirv::dr; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap, HashSet}; -use std::{borrow::Cow, fmt, mem}; +use std::{borrow::Cow, fmt, iter, mem}; use rspirv::binary::Assemble; @@ -11,7 +11,7 @@ use rspirv::binary::Assemble; enum SpirvType { Base(ast::ScalarType), Extended(ast::ExtendedScalarType), - Pointer(ast::ScalarType, spirv::StorageClass), + Pointer(ast::Type, spirv::StorageClass), } impl From for SpirvType { @@ -77,9 +77,12 @@ impl TypeWordMap { fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { match t { SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar), - SpirvType::Extended(t) => self.get_or_add_extended(b, t), - SpirvType::Pointer(scalar, storage) => { - let base = self.get_or_add_scalar(b, scalar); + SpirvType::Extended(ext) => self.get_or_add_extended(b, ext), + SpirvType::Pointer(typ, storage) => { + let base = match typ { + ast::Type::Scalar(scalar) => self.get_or_add_scalar(b, scalar), + ast::Type::ExtendedScalar(ext) => self.get_or_add_extended(b, ext), + }; *self .complex .entry(t) @@ -175,7 +178,40 @@ fn to_ssa<'a>( let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut id_def); let expanded_statements = expand_arguments(ssa_statements, &mut id_def); let expanded_statements = insert_implicit_conversions(expanded_statements, &mut id_def); - (expanded_statements, id_def.ids_count()) + let labeled_statements = normalize_labels(expanded_statements, &mut id_def); + (labeled_statements, id_def.ids_count()) +} + +fn normalize_labels( + func: Vec>, + id_def: &mut NumericIdResolver, +) -> Vec> { + let mut labels_in_use = HashSet::new(); + for s in func.iter() { + match s { + Statement::Instruction(i) => { + if let Some(target) = i.jump_target() { + labels_in_use.insert(target); + } + } + Statement::Conditional(cond) => { + labels_in_use.insert(cond.if_true); + labels_in_use.insert(cond.if_false); + } + Statement::Variable(_, _, _) + | Statement::LoadVar(_, _) + | Statement::StoreVar(_, _) + | Statement::Converison(_) + | Statement::Constant(_) + | Statement::Label(_) => (), + } + } + iter::once(Statement::Label(id_def.new_id(None))) + .chain(func.into_iter().filter(|s| match s { + Statement::Label(i) => labels_in_use.contains(i), + _ => true, + })) + .collect::>() } fn normalize_predicates( @@ -212,7 +248,9 @@ fn normalize_predicates( result.push(Statement::Instruction(Instruction::from_ast(inst))); } } - ast::Statement::Variable(var) => result.push(Statement::Variable(var.name, var.v_type)), + ast::Statement::Variable(var) => { + result.push(Statement::Variable(var.name, var.v_type, var.space)) + } } } result @@ -225,35 +263,46 @@ fn insert_mem_ssa_statements( let mut result = Vec::with_capacity(func.len()); for s in func { match s { - Statement::Instruction(mut inst) => { - let inst_type = inst.get_type(); - let mut post_statements = Vec::new(); - inst.visit_id_mut(&mut |is_dst, id| { - let inst_type = inst_type.unwrap(); - let generated_id = id_def.new_id(Some(inst_type)); - if !is_dst { - result.push(Statement::LoadVar( - Arg2 { - dst: generated_id, - src: *id, - }, - inst_type, - )); - } else { - post_statements.push(Statement::StoreVar( - Arg2St { - src1: *id, - src2: generated_id, - }, - inst_type, - )); - } - *id = generated_id; - }); - result.push(Statement::Instruction(inst)); - result.append(&mut post_statements); - } - s @ Statement::Variable(_, _) + Statement::Instruction(inst) => match inst { + Instruction::Ld( + ld @ ast::LdData { + state_space: ast::LdStateSpace::Param, + .. + }, + arg, + ) => { + result.push(Statement::Instruction(Instruction::Ld(ld, arg))); + } + mut inst => { + let inst_type = inst.get_type(); + let mut post_statements = Vec::new(); + inst.visit_id_mut(&mut |is_dst, id| { + let inst_type = inst_type.unwrap(); + let generated_id = id_def.new_id(Some(inst_type)); + if !is_dst { + result.push(Statement::LoadVar( + Arg2 { + dst: generated_id, + src: *id, + }, + inst_type, + )); + } else { + post_statements.push(Statement::StoreVar( + Arg2St { + src1: *id, + src2: generated_id, + }, + inst_type, + )); + } + *id = generated_id; + }); + result.push(Statement::Instruction(inst)); + result.append(&mut post_statements); + } + }, + s @ Statement::Variable(_, _, _) | s @ Statement::Label(_) | s @ Statement::Conditional(_) => result.push(s), Statement::LoadVar(_, _) @@ -273,9 +322,10 @@ fn expand_arguments( for s in func { match s { Statement::Instruction(inst) => { - normalize_insert_instruction(&mut result, id_def, inst); + let new_inst = normalize_insert_instruction(&mut result, id_def, inst); + result.push(Statement::Instruction(new_inst)); } - Statement::Variable(id, typ) => result.push(Statement::Variable(id, typ)), + Statement::Variable(id, typ, ss) => result.push(Statement::Variable(id, typ, ss)), Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), @@ -467,11 +517,11 @@ fn normalize_expand_mov_operand( which is bitcast to a pointer */ fn insert_implicit_conversions( - normalized_ids: Vec, + func: Vec, id_def: &mut NumericIdResolver, ) -> Vec { - let mut result = Vec::with_capacity(normalized_ids.len()); - for s in normalized_ids.into_iter() { + let mut result = Vec::with_capacity(func.len()); + for s in func.into_iter() { match s { Statement::Instruction(inst) => match inst { Instruction::Ld(ld, mut arg) => { @@ -515,11 +565,12 @@ fn insert_implicit_conversions( } inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst), }, - s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s), - Statement::Constant(_) - | Statement::Variable(_, _) - | Statement::LoadVar(_, _) - | Statement::StoreVar(_, _) => (), + s @ Statement::Conditional(_) + | s @ Statement::Label(_) + | s @ Statement::Constant(_) + | s @ Statement::Variable(_, _, _) + | s @ Statement::LoadVar(_, _) + | s @ Statement::StoreVar(_, _) => result.push(s), Statement::Converison(_) => unreachable!(), } } @@ -591,9 +642,22 @@ fn emit_function_body_ops( ) -> Result<(), dr::Error> { for s in func { match s { - // If block starts with a label it has already been emitted, - // all other labels in the block are unused - Statement::Label(_) => (), + Statement::Label(id) => { + if builder.block.is_some() { + builder.branch(*id)?; + } + builder.begin_block(Some(*id))?; + } + Statement::Variable(id, typ, ss) => { + let type_id = map.get_or_add( + builder, + SpirvType::Pointer(*typ, spirv::StorageClass::Function), + ); + if *ss != ast::StateSpace::Reg { + todo!() + } + builder.variable(type_id, Some(*id), spirv::StorageClass::Function, None); + } Statement::Constant(_) => todo!(), Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?, Statement::Conditional(bra) => { @@ -614,7 +678,7 @@ fn emit_function_body_ops( builder.load(result_type, Some(arg.dst), arg.src, None, [])?; } ast::LdStateSpace::Param => { - builder.copy_object(result_type, Some(arg.dst), arg.src)?; + builder.store(arg.dst, arg.src, None, [])?; } _ => todo!(), } @@ -642,7 +706,16 @@ fn emit_function_body_ops( }, _ => todo!(), }, - _ => todo!(), + Statement::LoadVar(arg, typ) => { + let type_id = map.get_or_add( + builder, + SpirvType::from(*typ), + ); + builder.load(type_id, Some(arg.dst), arg.src, None, [])?; + } + Statement::StoreVar(arg, typ) => { + builder.store(arg.src1, arg.src2, None, [])?; + } } } Ok(()) @@ -672,7 +745,10 @@ fn emit_implicit_conversion( ConversionKind::Ptr => { let dst_type = map.get_or_add( builder, - SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic), + SpirvType::Pointer( + ast::Type::Scalar(to_type), + spirv_headers::StorageClass::Generic, + ), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } @@ -856,7 +932,7 @@ impl NumericIdResolver { } enum Statement { - Variable(spirv::Word, ast::Type), + Variable(spirv::Word, ast::Type, ast::StateSpace), LoadVar(Arg2, ast::Type), StoreVar(Arg2St, ast::Type), Label(u32), @@ -870,7 +946,7 @@ enum Statement { impl Statement { fn visit_id_mut(&mut self, f: &mut F) { match self { - Statement::Variable(id, _) => f(true, id), + Statement::Variable(id, _, _) => f(true, id), Statement::LoadVar(a, _) => a.visit_id_mut(f), Statement::StoreVar(a, _) => a.visit_id_mut(f), Statement::Label(id) => f(false, id),