shader recompiler: auto cast between u32 and u64 during ssa pass

This commit is contained in:
Vinicius Rangel 2024-07-24 13:05:34 -03:00
parent 21ce67e8a0
commit 09946f15a2
No known key found for this signature in database
GPG key ID: A5B154D904B761D9
3 changed files with 46 additions and 12 deletions

View file

@ -145,7 +145,7 @@ void IREmitter::SetThreadBitScalarReg(IR::ScalarReg reg, const U1& value) {
template <>
U32 IREmitter::GetScalarReg(IR::ScalarReg reg) {
return Inst<U32>(Opcode::GetScalarRegister, reg);
return Inst<U32>(Opcode::GetScalarRegister, reg, Imm32(32));
}
template <>
@ -155,7 +155,7 @@ F32 IREmitter::GetScalarReg(IR::ScalarReg reg) {
template <>
U64 IREmitter::GetScalarReg(IR::ScalarReg reg) {
return Inst<U64>(Opcode::GetScalarRegister, reg);
return Inst<U64>(Opcode::GetScalarRegister, reg, Imm32(64));
}
template <>
@ -165,7 +165,7 @@ F64 IREmitter::GetScalarReg(IR::ScalarReg reg) {
template <>
U32 IREmitter::GetVectorReg(IR::VectorReg reg) {
return Inst<U32>(Opcode::GetVectorRegister, reg);
return Inst<U32>(Opcode::GetVectorRegister, reg, Imm32(32));
}
template <>
@ -175,7 +175,7 @@ F32 IREmitter::GetVectorReg(IR::VectorReg reg) {
template <>
U64 IREmitter::GetVectorReg(IR::VectorReg reg) {
return Inst<U64>(Opcode::GetVectorRegister, reg);
return Inst<U64>(Opcode::GetVectorRegister, reg, Imm32(64));
}
template <>
@ -1278,6 +1278,13 @@ U16U32U64 IREmitter::UConvert(size_t result_bitsize, const U16U32U64& value) {
default:
ThrowInvalidType(value.Type());
}
case 32:
switch (value.Type()) {
case Type::U64:
return Inst<U32>(Opcode::ConvertU32U64, value);
default:
ThrowInvalidType(value.Type());
}
case 64:
switch (value.Type()) {
case Type::U32:

View file

@ -43,9 +43,9 @@ OPCODE(WriteSharedU128, Void, U32,
OPCODE(GetUserData, U32, ScalarReg, )
OPCODE(GetThreadBitScalarReg, U1, ScalarReg, )
OPCODE(SetThreadBitScalarReg, Void, ScalarReg, U1, )
OPCODE(GetScalarRegister, U32, ScalarReg, )
OPCODE(GetScalarRegister, U32, ScalarReg, U32, )
OPCODE(SetScalarRegister, Void, ScalarReg, U32, )
OPCODE(GetVectorRegister, U32, VectorReg, )
OPCODE(GetVectorRegister, U32, VectorReg, U32, )
OPCODE(SetVectorRegister, Void, VectorReg, U32, )
OPCODE(GetGotoVariable, U1, U32, )
OPCODE(SetGotoVariable, Void, U32, U1, )
@ -291,6 +291,7 @@ OPCODE(ConvertF64U32, F64, U32,
OPCODE(ConvertF32U16, F32, U16, )
OPCODE(ConvertU16U32, U16, U32, )
OPCODE(ConvertU64U32, U64, U32, )
OPCODE(ConvertU32U64, U32, U64, )
OPCODE(ConvertU64F32, U64, F32, )
// Image operations

View file

@ -310,7 +310,8 @@ private:
DefTable current_def;
};
void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
void VisitInst(Pass& pass, IR::Block* block, const IR::Block::iterator& iter) {
auto& inst{*iter};
const IR::Opcode opcode{inst.GetOpcode()};
switch (opcode) {
case IR::Opcode::SetThreadBitScalarReg:
@ -348,13 +349,37 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
case IR::Opcode::GetThreadBitScalarReg:
case IR::Opcode::GetScalarRegister: {
const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
inst.ReplaceUsesWith(
pass.ReadVariable(reg, block, opcode == IR::Opcode::GetThreadBitScalarReg));
bool thread_bit = opcode == IR::Opcode::GetThreadBitScalarReg;
IR::Value value = pass.ReadVariable(reg, block, thread_bit);
if (!thread_bit) {
size_t bit_size{inst.Arg(1).U32()};
if (bit_size == 32 && value.Type() == IR::Type::U64) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})};
value = IR::U32{IR::Value{&*it}};
} else if (bit_size == 64 && value.Type() == IR::Type::U32) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})};
value = IR::U64{IR::Value{&*it}};
}
}
inst.ReplaceUsesWith(value);
break;
}
case IR::Opcode::GetVectorRegister: {
const IR::VectorReg reg{inst.Arg(0).VectorReg()};
inst.ReplaceUsesWith(pass.ReadVariable(reg, block));
IR::Value value = pass.ReadVariable(reg, block);
size_t bit_size{inst.Arg(1).U32()};
if (bit_size == 32 && value.Type() == IR::Type::U64) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})};
value = IR::U32{IR::Value{&*it}};
} else if (bit_size == 64 && value.Type() == IR::Type::U32) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})};
value = IR::U64{IR::Value{&*it}};
}
inst.ReplaceUsesWith(value);
break;
}
case IR::Opcode::GetGotoVariable:
@ -384,8 +409,9 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
}
void VisitBlock(Pass& pass, IR::Block* block) {
for (IR::Inst& inst : block->Instructions()) {
VisitInst(pass, block, inst);
const auto end{block->end()};
for (auto iter = block->begin(); iter != end; ++iter) {
VisitInst(pass, block, iter);
}
pass.SealBlock(block);
}