SPU LLVM: Allow swapped FMA and multiplications args in match context

This commit is contained in:
Elad.Ash 2024-02-07 13:45:29 +02:00 committed by GitHub
parent 96b7e4c67f
commit 30e8c3e951
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 89 additions and 4 deletions

View file

@ -685,6 +685,18 @@ struct llvm_add
return std::tuple_cat(r1, r2);
}
}
v1 = i->getOperand(0);
v2 = i->getOperand(1);
// Argument order does not matter here, try when swapped
if (auto r1 = a1.match(v2, _m); v2)
{
if (auto r2 = a2.match(v1, _m); v1)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
@ -858,6 +870,18 @@ struct llvm_mul
return std::tuple_cat(r1, r2);
}
}
v1 = i->getOperand(0);
v2 = i->getOperand(1);
// Argument order does not matter here, try when swapped
if (auto r1 = a1.match(v2, _m); v2)
{
if (auto r2 = a2.match(v1, _m); v1)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
@ -2239,6 +2263,18 @@ struct llvm_add_sat
return std::tuple_cat(r1, r2);
}
}
v1 = i->getOperand(0);
v2 = i->getOperand(1);
// Argument order does not matter here, try when swapped
if (auto r1 = a1.match(v2, _m); v2)
{
if (auto r2 = a2.match(v1, _m); v1)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
@ -2868,6 +2904,22 @@ struct llvm_fmuladd
}
}
}
v1 = i->getOperand(0);
v2 = i->getOperand(1);
v3 = i->getOperand(2);
// With multiplication args swapped
if (auto r1 = a1.match(v2, _m); v2)
{
if (auto r2 = a2.match(v1, _m); v1)
{
if (auto r3 = a3.match(v3, _m); v3)
{
return std::tuple_cat(r1, r2, r3);
}
}
}
}
value = nullptr;
@ -2884,6 +2936,18 @@ struct llvm_calli
std::tuple<llvm_expr_t<A>...> a;
std::array<usz, sizeof...(A)> order_equality_hint = []()
{
std::array<usz, sizeof...(A)> r{};
for (usz i = 0; i < r.size(); i++)
{
r[i] = i;
}
return r;
}();
llvm::Value*(*c)(llvm::Value**, llvm::IRBuilder<>*){};
llvm::Value* eval(llvm::IRBuilder<>* ir) const
@ -2917,6 +2981,13 @@ struct llvm_calli
return *this;
}
// Replaces the per-operand ordering hint with the given values, one per
// operand (the requires-clause enforces the count), each converted to usz.
// Returns *this so the call can be chained.
template <typename... Args> requires (sizeof...(Args) == sizeof...(A))
llvm_calli& set_order_equality_hint(Args... args)
{
	const std::array<usz, sizeof...(A)> hint{static_cast<usz>(args)...};
	order_equality_hint = hint;
	return *this;
}
llvm_match_tuple<A...> match(llvm::Value*& value, llvm::Module* _m) const
{
return match(value, _m, std::make_index_sequence<sizeof...(A)>());
@ -2939,6 +3010,20 @@ struct llvm_calli
{
return std::tuple_cat(std::get<I>(r)...);
}
if constexpr (sizeof...(A) >= 2)
{
if (order_equality_hint[0] == order_equality_hint[1])
{
// Test if it works with the first pair swapped
((v[I <= 1 ? I ^ 1 : I] = i->getOperand(I)), ...);
if (((std::get<I>(r) = std::get<I>(a).match(v[I], _m), v[I]) && ...))
{
return std::tuple_cat(std::get<I>(r)...);
}
}
}
}
}

View file

@ -5582,7 +5582,7 @@ public:
// Builds the llvm_calli expression named "spu_fm" from two forwarded operands.
// The hint (1, 1) assigns both operands the same value; the llvm_calli match
// code in this diff retries with the first two operands swapped when the
// first two hint entries are equal.
template <typename T, typename U>
static llvm_calli<f32[4], T, U> fm(T&& a, U&& b)
{
// NOTE(review): diff markers are not shown in this view; this is presumably the line removed by the commit.
return {"spu_fm", {std::forward<T>(a), std::forward<U>(b)}};
// NOTE(review): presumably the replacement line added by the commit.
return llvm_calli<f32[4], T, U>{"spu_fm", {std::forward<T>(a), std::forward<U>(b)}}.set_order_equality_hint(1, 1);
}
void FM(spu_opcode_t op)
@ -5933,7 +5933,7 @@ public:
// Builds the llvm_calli expression named "spu_fnms" from three forwarded operands.
// The hint (1, 1, 0) assigns the first two operands the same value and the
// third a different one; the llvm_calli match code in this diff retries with
// the first two operands swapped when the first two hint entries are equal.
template <typename T, typename U, typename V>
static llvm_calli<f32[4], T, U, V> fnms(T&& a, U&& b, V&& c)
{
// NOTE(review): diff markers are not shown in this view; this is presumably the line removed by the commit.
return {"spu_fnms", {std::forward<T>(a), std::forward<U>(b), std::forward<V>(c)}};
// NOTE(review): presumably the replacement line added by the commit.
return llvm_calli<f32[4], T, U, V>{"spu_fnms", {std::forward<T>(a), std::forward<U>(b), std::forward<V>(c)}}.set_order_equality_hint(1, 1, 0);
}
void FNMS(spu_opcode_t op)
@ -5968,7 +5968,7 @@ public:
// Builds the llvm_calli expression named "spu_fma" from three forwarded operands.
// The hint (1, 1, 0) assigns the first two operands the same value and the
// third a different one; the llvm_calli match code in this diff retries with
// the first two operands swapped when the first two hint entries are equal.
template <typename T, typename U, typename V>
static llvm_calli<f32[4], T, U, V> fma(T&& a, U&& b, V&& c)
{
// NOTE(review): diff markers are not shown in this view; this is presumably the line removed by the commit.
return {"spu_fma", {std::forward<T>(a), std::forward<U>(b), std::forward<V>(c)}};
// NOTE(review): presumably the replacement line added by the commit.
return llvm_calli<f32[4], T, U, V>{"spu_fma", {std::forward<T>(a), std::forward<U>(b), std::forward<V>(c)}}.set_order_equality_hint(1, 1, 0);
}
template <typename T>
@ -6217,7 +6217,7 @@ public:
// Builds the llvm_calli expression named "spu_fms" from three forwarded operands.
// The hint (1, 1, 0) assigns the first two operands the same value and the
// third a different one; the llvm_calli match code in this diff retries with
// the first two operands swapped when the first two hint entries are equal.
template <typename T, typename U, typename V>
static llvm_calli<f32[4], T, U, V> fms(T&& a, U&& b, V&& c)
{
// NOTE(review): diff markers are not shown in this view; this is presumably the line removed by the commit.
return {"spu_fms", {std::forward<T>(a), std::forward<U>(b), std::forward<V>(c)}};
// NOTE(review): presumably the replacement line added by the commit.
return llvm_calli<f32[4], T, U, V>{"spu_fms", {std::forward<T>(a), std::forward<U>(b), std::forward<V>(c)}}.set_order_equality_hint(1, 1, 0);
}
void FMS(spu_opcode_t op)