shader_recompiler: Separate thread bit scalars

* We can assume guest shader never mixes them with normal sgprs. This helps avoid errors where ssa could view an sgpr write dominating a thread bit read, due to how control flow is structurized, even though its not possible in actual control flow
This commit is contained in:
IndecisiveTurtle 2024-09-06 15:43:59 +03:00
parent 122a87dd23
commit fd6611ed54
2 changed files with 35 additions and 12 deletions

View file

@ -147,6 +147,7 @@ public:
/// Intrusively store the value of a register in the block.
std::array<Value, NumScalarRegs> ssa_sreg_values;
std::array<Value, NumScalarRegs> ssa_sbit_values;
std::array<Value, NumVectorRegs> ssa_vreg_values;
bool has_multiple_predecessors{false};

View file

@ -44,8 +44,17 @@ struct GotoVariable : FlagTag {
u32 index;
};
using Variant = std::variant<IR::ScalarReg, IR::VectorReg, GotoVariable, SccFlagTag, ExecFlagTag,
VccFlagTag, VccLoTag, VccHiTag, M0Tag>;
struct ThreadBitScalar : FlagTag {
ThreadBitScalar() = default;
explicit ThreadBitScalar(IR::ScalarReg sgpr_) : sgpr{sgpr_} {}
auto operator<=>(const ThreadBitScalar&) const noexcept = default;
IR::ScalarReg sgpr;
};
using Variant = std::variant<IR::ScalarReg, IR::VectorReg, GotoVariable, ThreadBitScalar,
SccFlagTag, ExecFlagTag, VccFlagTag, VccLoTag, VccHiTag, M0Tag>;
using ValueMap = std::unordered_map<IR::Block*, IR::Value>;
struct DefTable {
@ -70,6 +79,13 @@ struct DefTable {
goto_vars[variable.index].insert_or_assign(block, value);
}
const IR::Value& Def(IR::Block* block, ThreadBitScalar variable) {
return block->ssa_sreg_values[RegIndex(variable.sgpr)];
}
void SetDef(IR::Block* block, ThreadBitScalar variable, const IR::Value& value) {
block->ssa_sreg_values[RegIndex(variable.sgpr)] = value;
}
const IR::Value& Def(IR::Block* block, SccFlagTag) {
return scc_flag[block];
}
@ -173,7 +189,7 @@ public:
}
template <typename Type>
IR::Value ReadVariable(Type variable, IR::Block* root_block, bool is_thread_bit = false) {
IR::Value ReadVariable(Type variable, IR::Block* root_block) {
boost::container::small_vector<ReadState<Type>, 64> stack{
ReadState<Type>(nullptr),
ReadState<Type>(root_block),
@ -201,7 +217,7 @@ public:
} else if (!block->IsSsaSealed()) {
// Incomplete CFG
IR::Inst* phi{&*block->PrependNewInst(block->begin(), IR::Opcode::Phi)};
phi->SetFlags(is_thread_bit ? IR::Type::U1 : IR::TypeOf(UndefOpcode(variable)));
phi->SetFlags(IR::TypeOf(UndefOpcode(variable)));
incomplete_phis[block].insert_or_assign(variable, phi);
stack.back().result = IR::Value{&*phi};
@ -214,7 +230,7 @@ public:
} else {
// Break potential cycles with operandless phi
IR::Inst* const phi{&*block->PrependNewInst(block->begin(), IR::Opcode::Phi)};
phi->SetFlags(is_thread_bit ? IR::Type::U1 : IR::TypeOf(UndefOpcode(variable)));
phi->SetFlags(IR::TypeOf(UndefOpcode(variable)));
WriteVariable(variable, block, IR::Value{phi});
@ -263,9 +279,7 @@ private:
template <typename Type>
IR::Value AddPhiOperands(Type variable, IR::Inst& phi, IR::Block* block) {
for (IR::Block* const imm_pred : block->ImmPredecessors()) {
const bool is_thread_bit =
std::is_same_v<Type, IR::ScalarReg> && phi.Flags<IR::Type>() == IR::Type::U1;
phi.AddPhiOperand(imm_pred, ReadVariable(variable, imm_pred, is_thread_bit));
phi.AddPhiOperand(imm_pred, ReadVariable(variable, imm_pred));
}
return TryRemoveTrivialPhi(phi, block, UndefOpcode(variable));
}
@ -313,7 +327,11 @@ private:
void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
const IR::Opcode opcode{inst.GetOpcode()};
switch (opcode) {
case IR::Opcode::SetThreadBitScalarReg:
case IR::Opcode::SetThreadBitScalarReg: {
const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
pass.WriteVariable(ThreadBitScalar{reg}, block, inst.Arg(1));
break;
}
case IR::Opcode::SetScalarRegister: {
const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
pass.WriteVariable(reg, block, inst.Arg(1));
@ -345,11 +363,15 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
case IR::Opcode::SetM0:
pass.WriteVariable(M0Tag{}, block, inst.Arg(0));
break;
case IR::Opcode::GetThreadBitScalarReg:
case IR::Opcode::GetThreadBitScalarReg: {
const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
const IR::Value value = pass.ReadVariable(ThreadBitScalar{reg}, block);
inst.ReplaceUsesWith(value);
break;
}
case IR::Opcode::GetScalarRegister: {
const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
const bool thread_bit = opcode == IR::Opcode::GetThreadBitScalarReg;
const IR::Value value = pass.ReadVariable(reg, block, thread_bit);
const IR::Value value = pass.ReadVariable(reg, block);
inst.ReplaceUsesWith(value);
break;
}