ir: Improve read lane folding pass

This commit is contained in:
IndecisiveTurtle 2024-12-05 01:59:07 +02:00
parent a86b855b38
commit 6244d00795
2 changed files with 43 additions and 7 deletions

View file

@ -233,15 +233,51 @@ void FoldCmpClass(IR::Block& block, IR::Inst& inst) {
}
}
void FoldReadLane(IR::Inst& inst) {
void FoldReadLane(IR::Block& block, IR::Inst& inst) {
const u32 lane = inst.Arg(1).U32();
IR::Inst* prod = inst.Arg(0).InstRecursive();
while (prod->GetOpcode() == IR::Opcode::WriteLane) {
if (prod->Arg(2).U32() == lane) {
inst.ReplaceUsesWithAndRemove(prod->Arg(1));
const auto search_chain = [lane](const IR::Inst* prod) -> IR::Value {
while (prod->GetOpcode() == IR::Opcode::WriteLane) {
if (prod->Arg(2).U32() == lane) {
return prod->Arg(1);
}
prod = prod->Arg(0).InstRecursive();
}
return {};
};
if (prod->GetOpcode() == IR::Opcode::WriteLane) {
if (const IR::Value value = search_chain(prod); !value.IsEmpty()) {
inst.ReplaceUsesWith(value);
}
return;
}
if (prod->GetOpcode() == IR::Opcode::Phi) {
boost::container::small_vector<IR::Value, 2> phi_args;
for (size_t arg_index = 0; arg_index < prod->NumArgs(); ++arg_index) {
const IR::Inst* arg{prod->Arg(arg_index).InstRecursive()};
if (arg->GetOpcode() != IR::Opcode::WriteLane) {
return;
}
const IR::Value value = search_chain(arg);
if (value.IsEmpty()) {
continue;
}
phi_args.emplace_back(value);
}
if (std::ranges::all_of(phi_args, [&](IR::Value value) { return value == phi_args[0]; })) {
inst.ReplaceUsesWith(phi_args[0]);
return;
}
prod = prod->Arg(0).InstRecursive();
const auto insert_point = IR::Block::InstructionList::s_iterator_to(*prod);
IR::Inst* const new_phi{&*block.PrependNewInst(insert_point, IR::Opcode::Phi)};
new_phi->SetFlags(IR::Type::U32);
for (size_t arg_index = 0; arg_index < phi_args.size(); arg_index++) {
new_phi->AddPhiOperand(prod->PhiBlock(arg_index), phi_args[arg_index]);
}
inst.ReplaceUsesWith(IR::Value{new_phi});
}
}
@ -291,7 +327,7 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
case IR::Opcode::SelectF64:
return FoldSelect(inst);
case IR::Opcode::ReadLane:
return FoldReadLane(inst);
return FoldReadLane(block, inst);
case IR::Opcode::FPNeg32:
FoldWhenAllImmediates(inst, [](f32 a) { return -a; });
return;

View file

@ -129,7 +129,7 @@ public:
Inst& operator=(Inst&&) = delete;
Inst(Inst&&) = delete;
IR::Block* GetParent() {
IR::Block* GetParent() const {
ASSERT(parent);
return parent;
}