shader recompiler: more 64-bit work

- remove the bit_size parameter from Get[Scalar/Vector]Register
- add BitwiseOr64
- add SetDst64 as a replacement for SetScalarReg64 & SetVectorReg64
- add GetSrc64 for 64-bit values
This commit is contained in:
Vinicius Rangel 2024-07-25 18:49:44 -03:00
parent 11d9fbd20e
commit 4d224d6bce
No known key found for this signature in database
GPG key ID: A5B154D904B761D9
9 changed files with 122 additions and 100 deletions

View file

@ -272,6 +272,7 @@ Id EmitShiftRightArithmetic32(EmitContext& ctx, Id base, Id shift);
Id EmitShiftRightArithmetic64(EmitContext& ctx, Id base, Id shift);
Id EmitBitwiseAnd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b);
Id EmitBitwiseOr32(EmitContext& ctx, IR::Inst* inst, Id a, Id b);
Id EmitBitwiseOr64(EmitContext& ctx, IR::Inst* inst, Id a, Id b);
Id EmitBitwiseXor32(EmitContext& ctx, IR::Inst* inst, Id a, Id b);
Id EmitBitFieldInsert(EmitContext& ctx, Id base, Id insert, Id offset, Id count);
Id EmitBitFieldSExtract(EmitContext& ctx, IR::Inst* inst, Id base, Id offset, Id count);

View file

@ -146,6 +146,13 @@ Id EmitBitwiseOr32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
return result;
}
// Emits a bitwise OR of `a` and `b` with result type ctx.U64, passes the result to
// SetZeroFlag and SetSignFlag for `inst`, and returns it.
Id EmitBitwiseOr64(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
    const Id or_result = ctx.OpBitwiseOr(ctx.U64, a, b);
    SetZeroFlag(ctx, inst, or_result);
    SetSignFlag(ctx, inst, or_result);
    return or_result;
}
Id EmitBitwiseXor32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
const Id result{ctx.OpBitwiseXor(ctx.U32[1], a, b)};
SetZeroFlag(ctx, inst, result);

View file

@ -186,105 +186,126 @@ IR::F32 Translator::GetSrc(const InstOperand& operand, bool) {
// Specialization of GetSrc64 that returns an IR::U64F64 for a 64-bit source operand.
//
// NOTE(review): this span is a rendered diff hunk. Statements from the previous and the new
// implementation appear side by side (e.g. `value = ...` next to `value_lo = ...`, and two
// consecutive `return` lines), so it does not compile as shown.
//
// Flow visible in the new-side statements:
//  - ScalarGPR / VectorGPR / VccLo operands are read as two 32-bit halves (value_lo, value_hi).
//  - Inline-constant and literal operands set `immediate` and store one Imm64 in value_lo.
//  - After the switch, an immediate is used as-is; integer halves are joined with
//    CompositeConstruct + PackUint2x32; a non-immediate float operand throws
//    NotImplementedException.
//  - The abs / neg input modifiers are applied only when is_float is true.
template <>
IR::U64F64 Translator::GetSrc64(const InstOperand& operand, bool force_flt) {
IR::U64F64 value{};
IR::Value value_hi{};
IR::Value value_lo{};
bool immediate = false;
bool is_float = operand.type == ScalarType::Float64 || force_flt;
switch (operand.field) {
case OperandField::ScalarGPR:
if (is_float) {
value = ir.GetScalarReg<IR::F64>(IR::ScalarReg(operand.code));
value_lo = ir.GetScalarReg<IR::F32>(IR::ScalarReg(operand.code));
value_hi = ir.GetScalarReg<IR::F32>(IR::ScalarReg(operand.code + 1));
} else if (operand.type == ScalarType::Uint64 || operand.type == ScalarType::Sint64) {
value = ir.GetScalarReg<IR::U64>(IR::ScalarReg(operand.code));
value_lo = ir.GetScalarReg<IR::U32>(IR::ScalarReg(operand.code));
value_hi = ir.GetScalarReg<IR::U32>(IR::ScalarReg(operand.code + 1));
} else {
UNREACHABLE();
}
break;
case OperandField::VectorGPR:
if (is_float) {
value = ir.GetVectorReg<IR::F64>(IR::VectorReg(operand.code));
value_lo = ir.GetVectorReg<IR::F32>(IR::VectorReg(operand.code));
value_hi = ir.GetVectorReg<IR::F32>(IR::VectorReg(operand.code + 1));
} else if (operand.type == ScalarType::Uint64 || operand.type == ScalarType::Sint64) {
value = ir.GetVectorReg<IR::U64>(IR::VectorReg(operand.code));
value_lo = ir.GetVectorReg<IR::U32>(IR::VectorReg(operand.code));
value_hi = ir.GetVectorReg<IR::U32>(IR::VectorReg(operand.code + 1));
} else {
UNREACHABLE();
}
break;
case OperandField::ConstZero:
immediate = true;
if (force_flt) {
value = ir.Imm64(0.0);
value_lo = ir.Imm64(0.0);
} else {
value = ir.Imm64(u64(0U));
value_lo = ir.Imm64(u64(0U));
}
break;
case OperandField::SignedConstIntPos:
ASSERT(!force_flt);
value = ir.Imm64(s64(operand.code) - SignedConstIntPosMin + 1);
immediate = true;
value_lo = ir.Imm64(s64(operand.code) - SignedConstIntPosMin + 1);
break;
case OperandField::SignedConstIntNeg:
ASSERT(!force_flt);
value = ir.Imm64(-s64(operand.code) + SignedConstIntNegMin - 1);
immediate = true;
value_lo = ir.Imm64(-s64(operand.code) + SignedConstIntNegMin - 1);
break;
case OperandField::LiteralConst:
immediate = true;
if (force_flt) {
UNREACHABLE(); // There is a literal double?
} else {
value = ir.Imm64(u64(operand.code));
value_lo = ir.Imm64(u64(operand.code));
}
break;
case OperandField::ConstFloatPos_1_0:
immediate = true;
if (force_flt) {
value = ir.Imm64(1.0);
value_lo = ir.Imm64(1.0);
} else {
value = ir.Imm64(std::bit_cast<u64>(double(1.0)));
value_lo = ir.Imm64(std::bit_cast<u64>(f64(1.0)));
}
break;
case OperandField::ConstFloatPos_0_5:
value = ir.Imm64(0.5);
immediate = true;
value_lo = ir.Imm64(0.5);
break;
case OperandField::ConstFloatPos_2_0:
value = ir.Imm64(2.0);
immediate = true;
value_lo = ir.Imm64(2.0);
break;
case OperandField::ConstFloatPos_4_0:
value = ir.Imm64(4.0);
immediate = true;
value_lo = ir.Imm64(4.0);
break;
case OperandField::ConstFloatNeg_0_5:
value = ir.Imm64(-0.5);
immediate = true;
value_lo = ir.Imm64(-0.5);
break;
case OperandField::ConstFloatNeg_1_0:
value = ir.Imm64(-1.0);
immediate = true;
value_lo = ir.Imm64(-1.0);
break;
case OperandField::ConstFloatNeg_2_0:
value = ir.Imm64(-2.0);
immediate = true;
value_lo = ir.Imm64(-2.0);
break;
case OperandField::ConstFloatNeg_4_0:
value = ir.Imm64(-4.0);
break;
case OperandField::VccLo:
if (force_flt) {
value = ir.BitCast<IR::F64>(IR::U64(ir.UConvert(64, ir.GetVccLo())));
} else {
value = ir.UConvert(64, ir.GetVccLo());
}
immediate = true;
value_lo = ir.Imm64(-4.0);
break;
// New side: a VccLo operand supplies both halves, low from GetVccLo and high from GetVccHi.
case OperandField::VccLo: {
value_lo = ir.GetVccLo();
value_hi = ir.GetVccHi();
} break;
case OperandField::VccHi:
if (force_flt) {
value = ir.BitCast<IR::F64>(IR::U64(ir.UConvert(64, ir.GetVccHi())));
} else {
value = ir.UConvert(64, ir.GetVccHi());
}
break;
UNREACHABLE();
default:
UNREACHABLE();
}
// Combine the pieces gathered by the switch into a single 64-bit value.
IR::Value value;
if (immediate) {
value = value_lo;
} else if (is_float) {
throw NotImplementedException("required OpPackDouble2x32 implementation");
} else {
IR::Value packed = ir.CompositeConstruct(value_lo, value_hi);
value = ir.PackUint2x32(packed);
}
if (is_float) {
if (operand.input_modifier.abs) {
value = ir.FPAbs(value);
value = ir.FPAbs(IR::F32F64(value));
}
if (operand.input_modifier.neg) {
value = ir.FPNeg(value);
value = ir.FPNeg(IR::F32F64(value));
}
}
return value;
return IR::U64F64(value);
}
template <>
@ -320,6 +341,42 @@ void Translator::SetDst(const InstOperand& operand, const IR::U32F32& value) {
}
}
// Stores a 64-bit value into the register pair selected by `operand`.
//
// For float values the output modifiers are applied first: the multiplier (when non-zero),
// then the clamp. The value is then bit-cast to U64 if it is a float, split with
// UnpackUint2x32, and the two 32-bit words are written: element 0 to register `operand.code`,
// element 1 to register `operand.code + 1`.
//
// operand:   destination operand. Only ScalarGPR and VectorGPR are written; M0 is accepted but
//            nothing is stored; VccLo, VccHi and every other field hit UNREACHABLE().
// value_raw: value to store, before output modifiers are applied.
void Translator::SetDst64(const InstOperand& operand, const IR::U64F64& value_raw) {
    IR::U64F64 value_untyped = value_raw;

    const bool is_float = value_raw.Type() == IR::Type::F64 || value_raw.Type() == IR::Type::F32;
    if (is_float) {
        if (operand.output_modifier.multiplier != 0.f) {
            value_untyped =
                ir.FPMul(value_untyped, ir.Imm64(f64(operand.output_modifier.multiplier)));
        }
        if (operand.output_modifier.clamp) {
            // Saturate the running value rather than value_raw; saturating value_raw would
            // silently drop the multiplier applied just above.
            value_untyped = ir.FPSaturate(value_untyped);
        }
    }

    const IR::U64 value =
        is_float ? ir.BitCast<IR::U64>(IR::F64{value_untyped}) : IR::U64{value_untyped};

    // Split into two 32-bit words so each half can go through the 32-bit register setters.
    const IR::Value unpacked{ir.UnpackUint2x32(value)};
    const IR::U32 lo{ir.CompositeExtract(unpacked, 0U)};
    const IR::U32 hi{ir.CompositeExtract(unpacked, 1U)};

    switch (operand.field) {
    case OperandField::ScalarGPR:
        ir.SetScalarReg(IR::ScalarReg(operand.code + 1), hi);
        return ir.SetScalarReg(IR::ScalarReg(operand.code), lo);
    case OperandField::VectorGPR:
        ir.SetVectorReg(IR::VectorReg(operand.code + 1), hi);
        return ir.SetVectorReg(IR::VectorReg(operand.code), lo);
    case OperandField::VccLo:
        UNREACHABLE();
    case OperandField::VccHi:
        UNREACHABLE();
    case OperandField::M0:
        // NOTE(review): an M0 destination is accepted but the value is not stored anywhere.
        break;
    default:
        UNREACHABLE();
    }
}
void Translator::EmitFetch(const GcnInst& inst) {
// Read the pointer to the fetch shader assembly.
const u32 sgpr_base = inst.src[0].code;

View file

@ -193,6 +193,7 @@ private:
template <typename T = IR::U64F64>
[[nodiscard]] T GetSrc64(const InstOperand& operand, bool flt_zero = false);
void SetDst(const InstOperand& operand, const IR::U32F32& value);
void SetDst64(const InstOperand& operand, const IR::U64F64& value_raw);
private:
IR::IREmitter ir;

View file

@ -67,7 +67,8 @@ void Translator::V_OR_B32(bool is_xor, const GcnInst& inst) {
const IR::U32 src0{GetSrc(inst.src[0])};
const IR::U32 src1{ir.GetVectorReg(IR::VectorReg(inst.src[1].code))};
const IR::VectorReg dst_reg{inst.dst[0].code};
ir.SetVectorReg(dst_reg, is_xor ? ir.BitwiseXor(src0, src1) : ir.BitwiseOr(src0, src1));
ir.SetVectorReg(dst_reg,
is_xor ? ir.BitwiseXor(src0, src1) : IR::U32(ir.BitwiseOr(src0, src1)));
}
void Translator::V_AND_B32(const GcnInst& inst) {
@ -328,8 +329,7 @@ void Translator::V_MAD_U64_U32(const GcnInst& inst) {
result = ir.IMul(src0, src1);
result = ir.IAdd(ir.UConvert(64, result), src2);
const IR::VectorReg dst_reg{inst.dst[0].code};
ir.SetVectorReg64(dst_reg, result);
SetDst64(inst.dst[0], result);
}
void Translator::V_CMP_U32(ConditionOp op, bool is_signed, bool set_exec, const GcnInst& inst) {

View file

@ -145,7 +145,7 @@ void IREmitter::SetThreadBitScalarReg(IR::ScalarReg reg, const U1& value) {
template <>
U32 IREmitter::GetScalarReg(IR::ScalarReg reg) {
return Inst<U32>(Opcode::GetScalarRegister, reg, Imm32(32));
return Inst<U32>(Opcode::GetScalarRegister, reg);
}
template <>
@ -153,19 +153,9 @@ F32 IREmitter::GetScalarReg(IR::ScalarReg reg) {
return BitCast<F32>(GetScalarReg<U32>(reg));
}
template <>
U64 IREmitter::GetScalarReg(IR::ScalarReg reg) {
return Inst<U64>(Opcode::GetScalarRegister, reg, Imm32(64));
}
template <>
F64 IREmitter::GetScalarReg(IR::ScalarReg reg) {
return BitCast<F64>(GetScalarReg<U64>(reg));
}
template <>
U32 IREmitter::GetVectorReg(IR::VectorReg reg) {
return Inst<U32>(Opcode::GetVectorRegister, reg, Imm32(32));
return Inst<U32>(Opcode::GetVectorRegister, reg);
}
template <>
@ -173,36 +163,16 @@ F32 IREmitter::GetVectorReg(IR::VectorReg reg) {
return BitCast<F32>(GetVectorReg<U32>(reg));
}
template <>
U64 IREmitter::GetVectorReg(IR::VectorReg reg) {
return Inst<U64>(Opcode::GetVectorRegister, reg, Imm32(64));
}
template <>
F64 IREmitter::GetVectorReg(IR::VectorReg reg) {
return BitCast<F64>(GetVectorReg<U64>(reg));
}
// Emits a SetScalarRegister instruction for `reg`. The stored operand is always a U32:
// an F32 value is bit-cast to U32 first, any other value is converted with U32{value}.
void IREmitter::SetScalarReg(IR::ScalarReg reg, const U32F32& value) {
    if (value.Type() == Type::F32) {
        Inst(Opcode::SetScalarRegister, reg, BitCast<U32>(F32{value}));
    } else {
        Inst(Opcode::SetScalarRegister, reg, U32{value});
    }
}
void IREmitter::SetScalarReg64(IR::ScalarReg reg, const U64F64& value) {
const U64 value_typed = value.Type() == Type::F64 ? BitCast<U64>(F64{value}) : U64{value};
Inst(Opcode::SetScalarRegister, reg, value_typed);
}
// Emits a SetVectorRegister instruction for `reg`. The stored operand is always a U32:
// an F32 value is bit-cast to U32 first, any other value is converted with U32{value}.
void IREmitter::SetVectorReg(IR::VectorReg reg, const U32F32& value) {
    if (value.Type() == Type::F32) {
        Inst(Opcode::SetVectorRegister, reg, BitCast<U32>(F32{value}));
    } else {
        Inst(Opcode::SetVectorRegister, reg, U32{value});
    }
}
void IREmitter::SetVectorReg64(IR::VectorReg reg, const U64F64& value) {
const U64 value_typed = value.Type() == Type::F64 ? BitCast<U64>(F64{value}) : U64{value};
Inst(Opcode::SetVectorRegister, reg, value_typed);
}
// Emits a GetGotoVariable instruction for `id` and returns its U1 result.
U1 IREmitter::GetGotoVariable(u32 id) {
    U1 goto_value{Inst<U1>(Opcode::GetGotoVariable, id)};
    return goto_value;
}
@ -243,7 +213,6 @@ U1 IREmitter::GetExec() {
}
U1 IREmitter::GetVcc() {
// FIXME Should it be a thread bit?
return Inst<U1>(Opcode::GetVcc);
}
@ -1065,8 +1034,18 @@ U32 IREmitter::BitwiseAnd(const U32& a, const U32& b) {
return Inst<U32>(Opcode::BitwiseAnd32, a, b);
}
U32 IREmitter::BitwiseOr(const U32& a, const U32& b) {
return Inst<U32>(Opcode::BitwiseOr32, a, b);
// Emits a bitwise OR of `a` and `b`, choosing the opcode from the operand type:
// BitwiseOr32 for U32 and BitwiseOr64 for U64. Operands of differing types hit
// UNREACHABLE_MSG; any other operand type is passed to ThrowInvalidType.
U32U64 IREmitter::BitwiseOr(const U32U64& a, const U32U64& b) {
    const auto operand_type = a.Type();
    if (operand_type != b.Type()) {
        UNREACHABLE_MSG("Mismatching types {} and {}", a.Type(), b.Type());
    }
    if (operand_type == Type::U32) {
        return Inst<U32>(Opcode::BitwiseOr32, a, b);
    }
    if (operand_type == Type::U64) {
        return Inst<U64>(Opcode::BitwiseOr64, a, b);
    }
    ThrowInvalidType(operand_type);
}
U32 IREmitter::BitwiseXor(const U32& a, const U32& b) {

View file

@ -57,9 +57,7 @@ public:
template <typename T = U32>
[[nodiscard]] T GetVectorReg(IR::VectorReg reg);
void SetScalarReg(IR::ScalarReg reg, const U32F32& value);
void SetScalarReg64(IR::ScalarReg reg, const U64F64& value);
void SetVectorReg(IR::VectorReg reg, const U32F32& value);
void SetVectorReg64(IR::VectorReg reg, const U64F64& value);
[[nodiscard]] U1 GetGotoVariable(u32 id);
void SetGotoVariable(u32 id, const U1& value);
@ -169,7 +167,7 @@ public:
[[nodiscard]] U32U64 ShiftRightLogical(const U32U64& base, const U32& shift);
[[nodiscard]] U32U64 ShiftRightArithmetic(const U32U64& base, const U32& shift);
[[nodiscard]] U32 BitwiseAnd(const U32& a, const U32& b);
[[nodiscard]] U32 BitwiseOr(const U32& a, const U32& b);
[[nodiscard]] U32U64 BitwiseOr(const U32U64& a, const U32U64& b);
[[nodiscard]] U32 BitwiseXor(const U32& a, const U32& b);
[[nodiscard]] U32 BitFieldInsert(const U32& base, const U32& insert, const U32& offset,
const U32& count);

View file

@ -43,9 +43,9 @@ OPCODE(WriteSharedU128, Void, U32,
OPCODE(GetUserData, U32, ScalarReg, )
OPCODE(GetThreadBitScalarReg, U1, ScalarReg, )
OPCODE(SetThreadBitScalarReg, Void, ScalarReg, U1, )
OPCODE(GetScalarRegister, U32, ScalarReg, U32, )
OPCODE(GetScalarRegister, U32, ScalarReg, )
OPCODE(SetScalarRegister, Void, ScalarReg, U32, )
OPCODE(GetVectorRegister, U32, VectorReg, U32, )
OPCODE(GetVectorRegister, U32, VectorReg, )
OPCODE(SetVectorRegister, Void, VectorReg, U32, )
OPCODE(GetGotoVariable, U1, U32, )
OPCODE(SetGotoVariable, Void, U32, U1, )
@ -243,6 +243,7 @@ OPCODE(ShiftRightArithmetic32, U32, U32,
OPCODE(ShiftRightArithmetic64, U64, U64, U32, )
OPCODE(BitwiseAnd32, U32, U32, U32, )
OPCODE(BitwiseOr32, U32, U32, U32, )
OPCODE(BitwiseOr64, U64, U64, U64, )
OPCODE(BitwiseXor32, U32, U32, U32, )
OPCODE(BitFieldInsert, U32, U32, U32, U32, U32, )
OPCODE(BitFieldSExtract, U32, U32, U32, U32, )

View file

@ -351,34 +351,12 @@ void VisitInst(Pass& pass, IR::Block* block, const IR::Block::iterator& iter) {
const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
bool thread_bit = opcode == IR::Opcode::GetThreadBitScalarReg;
IR::Value value = pass.ReadVariable(reg, block, thread_bit);
if (!thread_bit) {
size_t bit_size{inst.Arg(1).U32()};
if (bit_size == 32 && value.Type() == IR::Type::U64) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})};
value = IR::U32{IR::Value{&*it}};
} else if (bit_size == 64 && value.Type() == IR::Type::U32) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})};
value = IR::U64{IR::Value{&*it}};
}
}
inst.ReplaceUsesWith(value);
break;
}
case IR::Opcode::GetVectorRegister: {
const IR::VectorReg reg{inst.Arg(0).VectorReg()};
IR::Value value = pass.ReadVariable(reg, block);
size_t bit_size{inst.Arg(1).U32()};
if (bit_size == 32 && value.Type() == IR::Type::U64) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})};
value = IR::U32{IR::Value{&*it}};
} else if (bit_size == 64 && value.Type() == IR::Type::U32) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})};
value = IR::U64{IR::Value{&*it}};
}
inst.ReplaceUsesWith(value);
break;
}