From aca61fdcf9031c2aaf7c6f995a7940e4075e9797 Mon Sep 17 00:00:00 2001 From: Nekotekina Date: Wed, 24 Apr 2019 23:53:49 +0300 Subject: [PATCH] LLVM DSL: implement expression matching (preview) Only literal match for binary ops implemented. --- rpcs3/Emu/CPU/CPUTranslator.h | 468 ++++++++++++++++++++++++++++++---- 1 file changed, 413 insertions(+), 55 deletions(-) diff --git a/rpcs3/Emu/CPU/CPUTranslator.h b/rpcs3/Emu/CPU/CPUTranslator.h index 0288c81f1f..ea5476064a 100644 --- a/rpcs3/Emu/CPU/CPUTranslator.h +++ b/rpcs3/Emu/CPU/CPUTranslator.h @@ -60,6 +60,16 @@ struct llvm_value_t return value; } + std::tuple<> match(llvm::Value*& value) const + { + if (value != this->value) + { + value = nullptr; + } + + return {}; + } + llvm::Value* value; // llvm_value_t() = default; @@ -361,6 +371,9 @@ struct is_llvm_expr_of::type, typena template using llvm_common_t = std::enable_if_t<(is_llvm_expr_of::ok && ...), typename is_llvm_expr::type>; +template +using llvm_match_tuple = decltype(std::tuple_cat(std::declval&>().match(std::declval())...)); + template >> struct llvm_match_t { @@ -377,6 +390,38 @@ struct llvm_match_t { return value; } + + std::tuple<> match(llvm::Value*& value) const + { + if (value != this->value) + { + value = nullptr; + } + + return {}; + } +}; + +template >> +struct llvm_placeholder_t +{ + using type = T; + + llvm::Value* eval(llvm::IRBuilder<>* ir) const + { + return nullptr; + } + + std::tuple> match(llvm::Value*& value) const + { + if (value && value->getType() == llvm_value_t::get_type(value->getContext())) + { + return {value}; + } + + value = nullptr; + return {}; + } }; template @@ -394,6 +439,17 @@ struct llvm_const_int return llvm::ConstantInt::get(llvm_value_t::get_type(ir->getContext()), val, ForceSigned || llvm_value_t::is_sint); } + + std::tuple<> match(llvm::Value*& value) const + { + if (value && value == llvm::ConstantInt::get(llvm_value_t::get_type(value->getContext()), val, ForceSigned || llvm_value_t::is_sint)) + { + return {}; + } + + value = nullptr; + return {}; + } }; template @@ -411,6 +467,17 @@ struct llvm_const_float return llvm::ConstantFP::get(llvm_value_t::get_type(ir->getContext()), val); } + + std::tuple<> match(llvm::Value*& value) const + { + if (value && value == llvm::ConstantFP::get(llvm_value_t::get_type(value->getContext()), val)) + { + return {}; + } + + value = nullptr; + return {}; + } }; template @@ -428,6 +495,17 @@ struct llvm_const_vector return llvm::ConstantDataVector::get(ir->getContext(), data); } + + std::tuple<> match(llvm::Value*& value) const + { + if (value && value == llvm::ConstantDataVector::get(value->getContext(), data)) + { + return {}; + } + + value = nullptr; + return {}; + } }; template > @@ -440,20 +518,36 @@ struct llvm_add static_assert(llvm_value_t::is_sint || llvm_value_t::is_uint || llvm_value_t::is_float, "llvm_add<>: invalid type"); + static constexpr auto opc = llvm_value_t::is_float ? llvm::Instruction::FAdd : llvm::Instruction::Add; + llvm::Value* eval(llvm::IRBuilder<>* ir) const { const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateBinOp(opc, v1, v2); + } - if constexpr (llvm_value_t::is_int) + llvm_match_tuple match(llvm::Value*& value) const + { + llvm::Value* v1 = {}; + llvm::Value* v2 = {}; + + if (auto i = llvm::dyn_cast_or_null(value); i && i->getOpcode() == opc) { - return ir->CreateAdd(v1, v2); + v1 = i->getOperand(0); + v2 = i->getOperand(1); + + if (auto r1 = a1.match(v1); v1) + { + if (auto r2 = a2.match(v2); v2) + { + return std::tuple_cat(r1, r2); + } + } } - if constexpr (llvm_value_t::is_float) - { - return ir->CreateFAdd(v1, v2); - } + value = nullptr; + return {}; } }; @@ -485,11 +579,13 @@ struct llvm_sum const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); const auto v3 = a3.eval(ir); + return ir->CreateAdd(ir->CreateAdd(v1, v2), v3); + } - if constexpr (llvm_value_t::is_int) - { - return ir->CreateAdd(ir->CreateAdd(v1, v2), v3); - } + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; } }; @@ -506,20 +602,36 @@ struct llvm_sub static_assert(llvm_value_t::is_sint || llvm_value_t::is_uint || llvm_value_t::is_float, "llvm_sub<>: invalid type"); + static constexpr auto opc = llvm_value_t::is_float ? llvm::Instruction::FSub : llvm::Instruction::Sub; + llvm::Value* eval(llvm::IRBuilder<>* ir) const { const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateBinOp(opc, v1, v2); + } - if constexpr (llvm_value_t::is_int) + llvm_match_tuple match(llvm::Value*& value) const + { + llvm::Value* v1 = {}; + llvm::Value* v2 = {}; + + if (auto i = llvm::dyn_cast_or_null(value); i && i->getOpcode() == opc) { - return ir->CreateSub(v1, v2); + v1 = i->getOperand(0); + v2 = i->getOperand(1); + + if (auto r1 = a1.match(v1); v1) + { + if (auto r2 = a2.match(v2); v2) + { + return std::tuple_cat(r1, r2); + } + } } - if constexpr (llvm_value_t::is_float) - { - return ir->CreateFSub(v1, v2); - } + value = nullptr; + return {}; } }; @@ -551,20 +663,36 @@ struct llvm_mul static_assert(llvm_value_t::is_sint || llvm_value_t::is_uint || llvm_value_t::is_float, "llvm_mul<>: invalid type"); + static constexpr auto opc = llvm_value_t::is_float ? llvm::Instruction::FMul : llvm::Instruction::Mul; + llvm::Value* eval(llvm::IRBuilder<>* ir) const { const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateBinOp(opc, v1, v2); + } - if constexpr (llvm_value_t::is_int) + llvm_match_tuple match(llvm::Value*& value) const + { + llvm::Value* v1 = {}; + llvm::Value* v2 = {}; + + if (auto i = llvm::dyn_cast_or_null(value); i && i->getOpcode() == opc) { - return ir->CreateMul(v1, v2); + v1 = i->getOperand(0); + v2 = i->getOperand(1); + + if (auto r1 = a1.match(v1); v1) + { + if (auto r2 = a2.match(v2); v2) + { + return std::tuple_cat(r1, r2); + } + } } - if constexpr (llvm_value_t::is_float) - { - return ir->CreateFMul(v1, v2); - } + value = nullptr; + return {}; } }; @@ -584,25 +712,38 @@ struct llvm_div static_assert(llvm_value_t::is_sint || llvm_value_t::is_uint || llvm_value_t::is_float, "llvm_div<>: invalid type"); + static constexpr auto opc = + llvm_value_t::is_float ? llvm::Instruction::FDiv : + llvm_value_t::is_uint ? llvm::Instruction::UDiv : llvm::Instruction::SDiv; + llvm::Value* eval(llvm::IRBuilder<>* ir) const { const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateBinOp(opc, v1, v2); + } - if constexpr (llvm_value_t::is_sint) + llvm_match_tuple match(llvm::Value*& value) const + { + llvm::Value* v1 = {}; + llvm::Value* v2 = {}; + + if (auto i = llvm::dyn_cast_or_null(value); i && i->getOpcode() == opc) { - return ir->CreateSDiv(v1, v2); + v1 = i->getOperand(0); + v2 = i->getOperand(1); + + if (auto r1 = a1.match(v1); v1) + { + if (auto r2 = a2.match(v2); v2) + { + return std::tuple_cat(r1, r2); + } + } } - if constexpr (llvm_value_t::is_uint) - { - return ir->CreateUDiv(v1, v2); - } - - if constexpr (llvm_value_t::is_float) - { - return ir->CreateFDiv(v1, v2); - } + value = nullptr; + return {}; } }; @@ -621,6 +762,8 @@ struct llvm_neg static_assert(llvm_value_t::is_sint || llvm_value_t::is_uint || llvm_value_t::is_float, "llvm_neg<>: invalid type"); + static constexpr auto opc = llvm_value_t::is_float ? llvm::Instruction::FSub : llvm::Instruction::Sub; + llvm::Value* eval(llvm::IRBuilder<>* ir) const { const auto v1 = a1.eval(ir); @@ -635,6 +778,12 @@ struct llvm_neg return ir->CreateFNeg(v1); } } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template @@ -657,16 +806,30 @@ struct llvm_shl { const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateShl(v1, v2); + } - if constexpr (llvm_value_t::is_sint) + llvm_match_tuple match(llvm::Value*& value) const + { + llvm::Value* v1 = {}; + llvm::Value* v2 = {}; + + if (auto i = llvm::dyn_cast_or_null(value); i && i->getOpcode() == llvm::Instruction::Shl) { - return ir->CreateShl(v1, v2); + v1 = i->getOperand(0); + v2 = i->getOperand(1); + + if (auto r1 = a1.match(v1); v1) + { + if (auto r2 = a2.match(v2); v2) + { + return std::tuple_cat(r1, r2); + } + } } - if constexpr (llvm_value_t::is_uint) - { - return ir->CreateShl(v1, v2); - } + value = nullptr; + return {}; } }; @@ -692,20 +855,36 @@ struct llvm_shr static_assert(llvm_value_t::is_sint || llvm_value_t::is_uint, "llvm_shr<>: invalid type"); + static constexpr auto opc = llvm_value_t::is_uint ? llvm::Instruction::LShr : llvm::Instruction::AShr; + llvm::Value* eval(llvm::IRBuilder<>* ir) const { const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateBinOp(opc, v1, v2); + } - if constexpr (llvm_value_t::is_sint) + llvm_match_tuple match(llvm::Value*& value) const + { + llvm::Value* v1 = {}; + llvm::Value* v2 = {}; + + if (auto i = llvm::dyn_cast_or_null(value); i && i->getOpcode() == opc) { - return ir->CreateAShr(v1, v2); + v1 = i->getOperand(0); + v2 = i->getOperand(1); + + if (auto r1 = a1.match(v1); v1) + { + if (auto r2 = a2.match(v2); v2) + { + return std::tuple_cat(r1, r2); + } + } } - if constexpr (llvm_value_t::is_uint) - { - return ir->CreateLShr(v1, v2); - } + value = nullptr; + return {}; } }; @@ -763,6 +942,12 @@ struct llvm_fshl return ir->CreateCall(get_fshl(ir), {v1, v2, v3}); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -807,6 +992,12 @@ struct llvm_fshr return ir->CreateCall(get_fshr(ir), {v1, v2, v3}); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -833,6 +1024,12 @@ struct llvm_rol return ir->CreateCall(llvm_fshl::get_fshl(ir), {v1, v1, v2}); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -849,11 +1046,30 @@ struct llvm_and { const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateAnd(v1, v2); + } - if constexpr (llvm_value_t::is_int) + llvm_match_tuple match(llvm::Value*& value) const + { + llvm::Value* v1 = {}; + llvm::Value* v2 = {}; + + if (auto i = llvm::dyn_cast_or_null(value); i && i->getOpcode() == llvm::Instruction::And) { - return ir->CreateAnd(v1, v2); + v1 = i->getOperand(0); + v2 = i->getOperand(1); + + if (auto r1 = a1.match(v1); v1) + { + if (auto r2 = a2.match(v2); v2) + { + return std::tuple_cat(r1, r2); + } + } } + + value = nullptr; + return {}; } }; @@ -883,11 +1099,30 @@ struct llvm_or { const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateOr(v1, v2); + } - if constexpr (llvm_value_t::is_int) + llvm_match_tuple match(llvm::Value*& value) const + { + llvm::Value* v1 = {}; + llvm::Value* v2 = {}; + + if (auto i = llvm::dyn_cast_or_null(value); i && i->getOpcode() == llvm::Instruction::Or) { - return ir->CreateOr(v1, v2); + v1 = i->getOperand(0); + v2 = i->getOperand(1); + + if (auto r1 = a1.match(v1); v1) + { + if (auto r2 = a2.match(v2); v2) + { + return std::tuple_cat(r1, r2); + } + } } + + value = nullptr; + return {}; } }; @@ -917,11 +1152,30 @@ struct llvm_xor { const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateXor(v1, v2); + } - if constexpr (llvm_value_t::is_int) + llvm_match_tuple match(llvm::Value*& value) const + { + llvm::Value* v1 = {}; + llvm::Value* v2 = {}; + + if (auto i = llvm::dyn_cast_or_null(value); i && i->getOpcode() == llvm::Instruction::Xor) { - return ir->CreateXor(v1, v2); + v1 = i->getOperand(0); + v2 = i->getOperand(1); + + if (auto r1 = a1.match(v1); v1) + { + if (auto r2 = a2.match(v2); v2) + { + return std::tuple_cat(r1, r2); + } + } } + + value = nullptr; + return {}; } }; @@ -970,11 +1224,13 @@ struct llvm_cmp const auto v1 = a1.eval(ir); const auto v2 = a2.eval(ir); + return ir->CreateICmp(pred, v1, v2); + } - if constexpr (llvm_value_t::is_int) - { - return ir->CreateICmp(pred, v1, v2); - } + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; } }; @@ -1013,6 +1269,12 @@ struct llvm_ord const auto v2 = cmp.a2.eval(ir); return ir->CreateFCmp(pred, v1, v2); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template @@ -1043,6 +1305,12 @@ struct llvm_uno const auto v2 = cmp.a2.eval(ir); return ir->CreateFCmp(pred, v1, v2); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template @@ -1143,6 +1411,12 @@ struct llvm_noncast // No operation required return a1.eval(ir); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1188,6 +1462,12 @@ struct llvm_bitcast return ir->CreateBitCast(v1, rt); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1212,6 +1492,12 @@ struct llvm_trunc { return ir->CreateTrunc(a1.eval(ir), llvm_value_t::get_type(ir->getContext())); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1236,6 +1522,12 @@ struct llvm_sext { return ir->CreateSExt(a1.eval(ir), llvm_value_t::get_type(ir->getContext())); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1260,6 +1552,12 @@ struct llvm_zext { return ir->CreateZExt(a1.eval(ir), llvm_value_t::get_type(ir->getContext())); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template , typename U = llvm_common_t> @@ -1282,6 +1580,12 @@ struct llvm_select { return ir->CreateSelect(cond.eval(ir), a2.eval(ir), a3.eval(ir)); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1311,6 +1615,12 @@ struct llvm_min return ir->CreateSelect(ir->CreateICmpULT(v1, v2), v1, v2); } } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1338,6 +1648,12 @@ struct llvm_max return ir->CreateSelect(ir->CreateICmpULT(v1, v2), v2, v1); } } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1398,6 +1714,12 @@ struct llvm_add_sat return ir->CreateCall(get_uadd_sat(ir), {v1, v2}); } } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1457,6 +1779,12 @@ struct llvm_sub_sat return ir->CreateCall(get_usub_sat(ir), {v1, v2}); } } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template , typename U = llvm_common_t> @@ -1480,6 +1808,12 @@ struct llvm_extract return ir->CreateExtractElement(v1, v2); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template , typename U = llvm_common_t, typename V = llvm_common_t> @@ -1506,6 +1840,12 @@ struct llvm_insert return ir->CreateInsertElement(v1, v3, v2); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1530,6 +1870,12 @@ struct llvm_splat return ir->CreateVectorSplat(llvm_value_t::is_vector, v1); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1550,6 +1896,12 @@ struct llvm_zshuffle return ir->CreateShuffleVector(v1, llvm::ConstantAggregateZero::get(v1->getType()), index_array); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; template > @@ -1572,6 +1924,12 @@ struct llvm_shuffle2 return ir->CreateShuffleVector(v1, v2, index_array); } + + llvm_match_tuple match(llvm::Value*& value) const + { + value = nullptr; + return {}; + } }; class cpu_translator