From 4d224d6bce2a754541a19afc34572568ff348170 Mon Sep 17 00:00:00 2001 From: Vinicius Rangel Date: Thu, 25 Jul 2024 18:49:44 -0300 Subject: [PATCH] shader recompiler: more 64-bit work - removed bit_size parameter from Get[Scalar/Vector]Register - add BitwiseOr64 - add SetDst64 as a replacement for SetScalarReg64 & SetVectorReg64 - add GetSrc64 for 64-bit value --- .../backend/spirv/emit_spirv_instructions.h | 1 + .../backend/spirv/emit_spirv_integer.cpp | 7 + .../frontend/translate/translate.cpp | 127 +++++++++++++----- .../frontend/translate/translate.h | 1 + .../frontend/translate/vector_alu.cpp | 6 +- src/shader_recompiler/ir/ir_emitter.cpp | 49 ++----- src/shader_recompiler/ir/ir_emitter.h | 4 +- src/shader_recompiler/ir/opcodes.inc | 5 +- .../ir/passes/ssa_rewrite_pass.cpp | 22 --- 9 files changed, 122 insertions(+), 100 deletions(-) diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h index 6c7a551ff..bb533e9cf 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h +++ b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h @@ -272,6 +272,7 @@ Id EmitShiftRightArithmetic32(EmitContext& ctx, Id base, Id shift); Id EmitShiftRightArithmetic64(EmitContext& ctx, Id base, Id shift); Id EmitBitwiseAnd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b); Id EmitBitwiseOr32(EmitContext& ctx, IR::Inst* inst, Id a, Id b); +Id EmitBitwiseOr64(EmitContext& ctx, IR::Inst* inst, Id a, Id b); Id EmitBitwiseXor32(EmitContext& ctx, IR::Inst* inst, Id a, Id b); Id EmitBitFieldInsert(EmitContext& ctx, Id base, Id insert, Id offset, Id count); Id EmitBitFieldSExtract(EmitContext& ctx, IR::Inst* inst, Id base, Id offset, Id count); diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp index e2c73286f..e238a693e 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp +++ 
b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp @@ -146,6 +146,13 @@ Id EmitBitwiseOr32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { return result; } +Id EmitBitwiseOr64(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { + const Id result{ctx.OpBitwiseOr(ctx.U64, a, b)}; + SetZeroFlag(ctx, inst, result); + SetSignFlag(ctx, inst, result); + return result; +} + Id EmitBitwiseXor32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { const Id result{ctx.OpBitwiseXor(ctx.U32[1], a, b)}; SetZeroFlag(ctx, inst, result); diff --git a/src/shader_recompiler/frontend/translate/translate.cpp b/src/shader_recompiler/frontend/translate/translate.cpp index 691b93d6f..b3bc871c4 100644 --- a/src/shader_recompiler/frontend/translate/translate.cpp +++ b/src/shader_recompiler/frontend/translate/translate.cpp @@ -186,105 +186,126 @@ IR::F32 Translator::GetSrc(const InstOperand& operand, bool) { template <> IR::U64F64 Translator::GetSrc64(const InstOperand& operand, bool force_flt) { - IR::U64F64 value{}; + IR::Value value_hi{}; + IR::Value value_lo{}; + bool immediate = false; bool is_float = operand.type == ScalarType::Float64 || force_flt; switch (operand.field) { case OperandField::ScalarGPR: if (is_float) { - value = ir.GetScalarReg(IR::ScalarReg(operand.code)); + value_lo = ir.GetScalarReg(IR::ScalarReg(operand.code)); + value_hi = ir.GetScalarReg(IR::ScalarReg(operand.code + 1)); } else if (operand.type == ScalarType::Uint64 || operand.type == ScalarType::Sint64) { - value = ir.GetScalarReg(IR::ScalarReg(operand.code)); + value_lo = ir.GetScalarReg(IR::ScalarReg(operand.code)); + value_hi = ir.GetScalarReg(IR::ScalarReg(operand.code + 1)); } else { UNREACHABLE(); } break; case OperandField::VectorGPR: if (is_float) { - value = ir.GetVectorReg(IR::VectorReg(operand.code)); + value_lo = ir.GetVectorReg(IR::VectorReg(operand.code)); + value_hi = ir.GetVectorReg(IR::VectorReg(operand.code + 1)); } else if (operand.type == ScalarType::Uint64 || operand.type == 
ScalarType::Sint64) { - value = ir.GetVectorReg(IR::VectorReg(operand.code)); + value_lo = ir.GetVectorReg(IR::VectorReg(operand.code)); + value_hi = ir.GetVectorReg(IR::VectorReg(operand.code + 1)); } else { UNREACHABLE(); } break; case OperandField::ConstZero: + immediate = true; if (force_flt) { - value = ir.Imm64(0.0); + value_lo = ir.Imm64(0.0); } else { - value = ir.Imm64(u64(0U)); + value_lo = ir.Imm64(u64(0U)); } break; case OperandField::SignedConstIntPos: ASSERT(!force_flt); - value = ir.Imm64(s64(operand.code) - SignedConstIntPosMin + 1); + immediate = true; + value_lo = ir.Imm64(s64(operand.code) - SignedConstIntPosMin + 1); break; case OperandField::SignedConstIntNeg: ASSERT(!force_flt); - value = ir.Imm64(-s64(operand.code) + SignedConstIntNegMin - 1); + immediate = true; + value_lo = ir.Imm64(-s64(operand.code) + SignedConstIntNegMin - 1); break; case OperandField::LiteralConst: + immediate = true; if (force_flt) { UNREACHABLE(); // There is a literal double? } else { - value = ir.Imm64(u64(operand.code)); + value_lo = ir.Imm64(u64(operand.code)); } break; case OperandField::ConstFloatPos_1_0: + immediate = true; if (force_flt) { - value = ir.Imm64(1.0); + value_lo = ir.Imm64(1.0); } else { - value = ir.Imm64(std::bit_cast(double(1.0))); + value_lo = ir.Imm64(std::bit_cast(f64(1.0))); } break; case OperandField::ConstFloatPos_0_5: - value = ir.Imm64(0.5); + immediate = true; + value_lo = ir.Imm64(0.5); break; case OperandField::ConstFloatPos_2_0: - value = ir.Imm64(2.0); + immediate = true; + value_lo = ir.Imm64(2.0); break; case OperandField::ConstFloatPos_4_0: - value = ir.Imm64(4.0); + immediate = true; + value_lo = ir.Imm64(4.0); break; case OperandField::ConstFloatNeg_0_5: - value = ir.Imm64(-0.5); + immediate = true; + value_lo = ir.Imm64(-0.5); break; case OperandField::ConstFloatNeg_1_0: - value = ir.Imm64(-1.0); + immediate = true; + value_lo = ir.Imm64(-1.0); break; case OperandField::ConstFloatNeg_2_0: - value = ir.Imm64(-2.0); + immediate 
= true; + value_lo = ir.Imm64(-2.0); break; case OperandField::ConstFloatNeg_4_0: - value = ir.Imm64(-4.0); - break; - case OperandField::VccLo: - if (force_flt) { - value = ir.BitCast(IR::U64(ir.UConvert(64, ir.GetVccLo()))); - } else { - value = ir.UConvert(64, ir.GetVccLo()); - } + immediate = true; + value_lo = ir.Imm64(-4.0); break; + case OperandField::VccLo: { + value_lo = ir.GetVccLo(); + value_hi = ir.GetVccHi(); + } break; case OperandField::VccHi: - if (force_flt) { - value = ir.BitCast(IR::U64(ir.UConvert(64, ir.GetVccHi()))); - } else { - value = ir.UConvert(64, ir.GetVccHi()); - } - break; + UNREACHABLE(); default: UNREACHABLE(); } + IR::Value value; + + if (immediate) { + value = value_lo; + } else if (is_float) { + throw NotImplementedException("required OpPackDouble2x32 implementation"); + } else { + IR::Value packed = ir.CompositeConstruct(value_lo, value_hi); + value = ir.PackUint2x32(packed); + } + if (is_float) { if (operand.input_modifier.abs) { - value = ir.FPAbs(value); + value = ir.FPAbs(IR::F32F64(value)); } if (operand.input_modifier.neg) { - value = ir.FPNeg(value); + value = ir.FPNeg(IR::F32F64(value)); } } - return value; + return IR::U64F64(value); } template <> @@ -320,6 +341,42 @@ void Translator::SetDst(const InstOperand& operand, const IR::U32F32& value) { } } +void Translator::SetDst64(const InstOperand& operand, const IR::U64F64& value_raw) { + IR::U64F64 value_untyped = value_raw; + + bool is_float = value_raw.Type() == IR::Type::F64 || value_raw.Type() == IR::Type::F32; + if (is_float) { + if (operand.output_modifier.multiplier != 0.f) { + value_untyped = + ir.FPMul(value_untyped, ir.Imm64(f64(operand.output_modifier.multiplier))); + } + if (operand.output_modifier.clamp) { + value_untyped = ir.FPSaturate(value_untyped); + } + } + IR::U64 value = is_float ? 
ir.BitCast(IR::F64{value_untyped}) : IR::U64{value_untyped}; + + IR::Value unpacked{ir.UnpackUint2x32(value)}; + IR::U32 lo{ir.CompositeExtract(unpacked, 0U)}; + IR::U32 hi{ir.CompositeExtract(unpacked, 1U)}; + switch (operand.field) { + case OperandField::ScalarGPR: + ir.SetScalarReg(IR::ScalarReg(operand.code + 1), hi); + return ir.SetScalarReg(IR::ScalarReg(operand.code), lo); + case OperandField::VectorGPR: + ir.SetVectorReg(IR::VectorReg(operand.code + 1), hi); + return ir.SetVectorReg(IR::VectorReg(operand.code), lo); + case OperandField::VccLo: + UNREACHABLE(); + case OperandField::VccHi: + UNREACHABLE(); + case OperandField::M0: + break; + default: + UNREACHABLE(); + } +} + void Translator::EmitFetch(const GcnInst& inst) { // Read the pointer to the fetch shader assembly. const u32 sgpr_base = inst.src[0].code; diff --git a/src/shader_recompiler/frontend/translate/translate.h b/src/shader_recompiler/frontend/translate/translate.h index 5feec8f71..3203ad730 100644 --- a/src/shader_recompiler/frontend/translate/translate.h +++ b/src/shader_recompiler/frontend/translate/translate.h @@ -193,6 +193,7 @@ private: template [[nodiscard]] T GetSrc64(const InstOperand& operand, bool flt_zero = false); void SetDst(const InstOperand& operand, const IR::U32F32& value); + void SetDst64(const InstOperand& operand, const IR::U64F64& value_raw); private: IR::IREmitter ir; diff --git a/src/shader_recompiler/frontend/translate/vector_alu.cpp b/src/shader_recompiler/frontend/translate/vector_alu.cpp index d7ca64796..6e73c980f 100644 --- a/src/shader_recompiler/frontend/translate/vector_alu.cpp +++ b/src/shader_recompiler/frontend/translate/vector_alu.cpp @@ -67,7 +67,8 @@ void Translator::V_OR_B32(bool is_xor, const GcnInst& inst) { const IR::U32 src0{GetSrc(inst.src[0])}; const IR::U32 src1{ir.GetVectorReg(IR::VectorReg(inst.src[1].code))}; const IR::VectorReg dst_reg{inst.dst[0].code}; - ir.SetVectorReg(dst_reg, is_xor ? 
ir.BitwiseXor(src0, src1) : ir.BitwiseOr(src0, src1)); + ir.SetVectorReg(dst_reg, + is_xor ? ir.BitwiseXor(src0, src1) : IR::U32(ir.BitwiseOr(src0, src1))); } void Translator::V_AND_B32(const GcnInst& inst) { @@ -328,8 +329,7 @@ void Translator::V_MAD_U64_U32(const GcnInst& inst) { result = ir.IMul(src0, src1); result = ir.IAdd(ir.UConvert(64, result), src2); - const IR::VectorReg dst_reg{inst.dst[0].code}; - ir.SetVectorReg64(dst_reg, result); + SetDst64(inst.dst[0], result); } void Translator::V_CMP_U32(ConditionOp op, bool is_signed, bool set_exec, const GcnInst& inst) { diff --git a/src/shader_recompiler/ir/ir_emitter.cpp b/src/shader_recompiler/ir/ir_emitter.cpp index 5e64c7827..d203bd616 100644 --- a/src/shader_recompiler/ir/ir_emitter.cpp +++ b/src/shader_recompiler/ir/ir_emitter.cpp @@ -145,7 +145,7 @@ void IREmitter::SetThreadBitScalarReg(IR::ScalarReg reg, const U1& value) { template <> U32 IREmitter::GetScalarReg(IR::ScalarReg reg) { - return Inst(Opcode::GetScalarRegister, reg, Imm32(32)); + return Inst(Opcode::GetScalarRegister, reg); } template <> @@ -153,19 +153,9 @@ F32 IREmitter::GetScalarReg(IR::ScalarReg reg) { return BitCast(GetScalarReg(reg)); } -template <> -U64 IREmitter::GetScalarReg(IR::ScalarReg reg) { - return Inst(Opcode::GetScalarRegister, reg, Imm32(64)); -} - -template <> -F64 IREmitter::GetScalarReg(IR::ScalarReg reg) { - return BitCast(GetScalarReg(reg)); -} - template <> U32 IREmitter::GetVectorReg(IR::VectorReg reg) { - return Inst(Opcode::GetVectorRegister, reg, Imm32(32)); + return Inst(Opcode::GetVectorRegister, reg); } template <> @@ -173,36 +163,16 @@ F32 IREmitter::GetVectorReg(IR::VectorReg reg) { return BitCast(GetVectorReg(reg)); } -template <> -U64 IREmitter::GetVectorReg(IR::VectorReg reg) { - return Inst(Opcode::GetVectorRegister, reg, Imm32(64)); -} - -template <> -F64 IREmitter::GetVectorReg(IR::VectorReg reg) { - return BitCast(GetVectorReg(reg)); -} - void IREmitter::SetScalarReg(IR::ScalarReg reg, const U32F32& 
value) { const U32 value_typed = value.Type() == Type::F32 ? BitCast(F32{value}) : U32{value}; Inst(Opcode::SetScalarRegister, reg, value_typed); } -void IREmitter::SetScalarReg64(IR::ScalarReg reg, const U64F64& value) { - const U64 value_typed = value.Type() == Type::F64 ? BitCast(F64{value}) : U64{value}; - Inst(Opcode::SetScalarRegister, reg, value_typed); -} - void IREmitter::SetVectorReg(IR::VectorReg reg, const U32F32& value) { const U32 value_typed = value.Type() == Type::F32 ? BitCast(F32{value}) : U32{value}; Inst(Opcode::SetVectorRegister, reg, value_typed); } -void IREmitter::SetVectorReg64(IR::VectorReg reg, const U64F64& value) { - const U64 value_typed = value.Type() == Type::F64 ? BitCast(F64{value}) : U64{value}; - Inst(Opcode::SetVectorRegister, reg, value_typed); -} - U1 IREmitter::GetGotoVariable(u32 id) { return Inst(Opcode::GetGotoVariable, id); } @@ -243,7 +213,6 @@ U1 IREmitter::GetExec() { } U1 IREmitter::GetVcc() { - // FIXME Should it be a thread bit? return Inst(Opcode::GetVcc); } @@ -1065,8 +1034,18 @@ U32 IREmitter::BitwiseAnd(const U32& a, const U32& b) { return Inst(Opcode::BitwiseAnd32, a, b); } -U32 IREmitter::BitwiseOr(const U32& a, const U32& b) { - return Inst(Opcode::BitwiseOr32, a, b); +U32U64 IREmitter::BitwiseOr(const U32U64& a, const U32U64& b) { + if (a.Type() != b.Type()) { + UNREACHABLE_MSG("Mismatching types {} and {}", a.Type(), b.Type()); + } + switch (a.Type()) { + case Type::U32: + return Inst(Opcode::BitwiseOr32, a, b); + case Type::U64: + return Inst(Opcode::BitwiseOr64, a, b); + default: + ThrowInvalidType(a.Type()); + } } U32 IREmitter::BitwiseXor(const U32& a, const U32& b) { diff --git a/src/shader_recompiler/ir/ir_emitter.h b/src/shader_recompiler/ir/ir_emitter.h index 4ce973ae1..423f7d59d 100644 --- a/src/shader_recompiler/ir/ir_emitter.h +++ b/src/shader_recompiler/ir/ir_emitter.h @@ -57,9 +57,7 @@ public: template [[nodiscard]] T GetVectorReg(IR::VectorReg reg); void SetScalarReg(IR::ScalarReg reg, const 
U32F32& value); - void SetScalarReg64(IR::ScalarReg reg, const U64F64& value); void SetVectorReg(IR::VectorReg reg, const U32F32& value); - void SetVectorReg64(IR::VectorReg reg, const U64F64& value); [[nodiscard]] U1 GetGotoVariable(u32 id); void SetGotoVariable(u32 id, const U1& value); @@ -169,7 +167,7 @@ public: [[nodiscard]] U32U64 ShiftRightLogical(const U32U64& base, const U32& shift); [[nodiscard]] U32U64 ShiftRightArithmetic(const U32U64& base, const U32& shift); [[nodiscard]] U32 BitwiseAnd(const U32& a, const U32& b); - [[nodiscard]] U32 BitwiseOr(const U32& a, const U32& b); + [[nodiscard]] U32U64 BitwiseOr(const U32U64& a, const U32U64& b); [[nodiscard]] U32 BitwiseXor(const U32& a, const U32& b); [[nodiscard]] U32 BitFieldInsert(const U32& base, const U32& insert, const U32& offset, const U32& count); diff --git a/src/shader_recompiler/ir/opcodes.inc b/src/shader_recompiler/ir/opcodes.inc index 11d146f72..eadb6e7b5 100644 --- a/src/shader_recompiler/ir/opcodes.inc +++ b/src/shader_recompiler/ir/opcodes.inc @@ -43,9 +43,9 @@ OPCODE(WriteSharedU128, Void, U32, OPCODE(GetUserData, U32, ScalarReg, ) OPCODE(GetThreadBitScalarReg, U1, ScalarReg, ) OPCODE(SetThreadBitScalarReg, Void, ScalarReg, U1, ) -OPCODE(GetScalarRegister, U32, ScalarReg, U32, ) +OPCODE(GetScalarRegister, U32, ScalarReg, ) OPCODE(SetScalarRegister, Void, ScalarReg, U32, ) -OPCODE(GetVectorRegister, U32, VectorReg, U32, ) +OPCODE(GetVectorRegister, U32, VectorReg, ) OPCODE(SetVectorRegister, Void, VectorReg, U32, ) OPCODE(GetGotoVariable, U1, U32, ) OPCODE(SetGotoVariable, Void, U32, U1, ) @@ -243,6 +243,7 @@ OPCODE(ShiftRightArithmetic32, U32, U32, OPCODE(ShiftRightArithmetic64, U64, U64, U32, ) OPCODE(BitwiseAnd32, U32, U32, U32, ) OPCODE(BitwiseOr32, U32, U32, U32, ) +OPCODE(BitwiseOr64, U64, U64, U64, ) OPCODE(BitwiseXor32, U32, U32, U32, ) OPCODE(BitFieldInsert, U32, U32, U32, U32, U32, ) OPCODE(BitFieldSExtract, U32, U32, U32, U32, ) diff --git 
a/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp index 6a686fadb..4cf73c75d 100644 --- a/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp +++ b/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp @@ -351,34 +351,12 @@ void VisitInst(Pass& pass, IR::Block* block, const IR::Block::iterator& iter) { const IR::ScalarReg reg{inst.Arg(0).ScalarReg()}; bool thread_bit = opcode == IR::Opcode::GetThreadBitScalarReg; IR::Value value = pass.ReadVariable(reg, block, thread_bit); - - if (!thread_bit) { - size_t bit_size{inst.Arg(1).U32()}; - if (bit_size == 32 && value.Type() == IR::Type::U64) { - auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})}; - value = IR::U32{IR::Value{&*it}}; - } else if (bit_size == 64 && value.Type() == IR::Type::U32) { - auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})}; - value = IR::U64{IR::Value{&*it}}; - } - } - inst.ReplaceUsesWith(value); break; } case IR::Opcode::GetVectorRegister: { const IR::VectorReg reg{inst.Arg(0).VectorReg()}; IR::Value value = pass.ReadVariable(reg, block); - - size_t bit_size{inst.Arg(1).U32()}; - if (bit_size == 32 && value.Type() == IR::Type::U64) { - auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})}; - value = IR::U32{IR::Value{&*it}}; - } else if (bit_size == 64 && value.Type() == IR::Type::U32) { - auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})}; - value = IR::U64{IR::Value{&*it}}; - } - inst.ReplaceUsesWith(value); break; }