Add constant folding and algebraic simplification optimizations, nits

This commit is contained in:
gdkchan 2019-03-29 21:01:37 -03:00
parent 4a7dfadb0f
commit af82b1214e
13 changed files with 614 additions and 21 deletions

View file

@ -1,7 +1,6 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using System;
using System.Globalization;
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.TypeConversion;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
@ -24,7 +23,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
return OperandManager.GetAttributeName(context, operand);
case OperandType.Constant:
return "0x" + operand.Value.ToString("X8", CultureInfo.InvariantCulture);
return NumberFormatter.FormatInt(operand.Value);
case OperandType.ConstantBuffer:
return OperandManager.GetConstantBufferName(context.ShaderType, operand);
@ -415,7 +414,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
private static string GetSoureExpr(CodeGenContext context, IAstNode node, VariableType dstType)
{
return ReinterpretCast(GetExpression(context, node), OperandManager.GetNodeDestType(node), dstType);
return ReinterpretCast(context, node, OperandManager.GetNodeDestType(node), dstType);
}
}
}

View file

@ -0,0 +1,87 @@
using Ryujinx.Graphics.Shader.StructuredIr;
using System;
using System.Globalization;
namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
static class NumberFormatter
{
private const int MaxDecimal = 256;
public static bool TryFormat(int value, VariableType dstType, out string formatted)
{
if (dstType == VariableType.F32)
{
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
}
else if (dstType == VariableType.S32)
{
formatted = FormatInt(value);
}
else if (dstType == VariableType.U32)
{
formatted = FormatUint((uint)value);
}
else if (dstType == VariableType.Bool)
{
formatted = value != 0 ? "true" : "false";
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
return true;
}
public static bool TryFormatFloat(float value, out string formatted)
{
if (float.IsNaN(value) || float.IsInfinity(value))
{
formatted = null;
return false;
}
formatted = value.ToString("G9", CultureInfo.InvariantCulture);
return true;
}
public static string FormatInt(int value, VariableType dstType)
{
if (dstType == VariableType.S32)
{
return FormatInt(value);
}
else if (dstType == VariableType.U32)
{
return FormatUint((uint)value);
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
}
public static string FormatInt(int value)
{
if (value <= MaxDecimal && value >= -MaxDecimal)
{
return value.ToString(CultureInfo.InvariantCulture);
}
return "0x" + value.ToString("X", CultureInfo.InvariantCulture);
}
public static string FormatUint(uint value)
{
if (value <= MaxDecimal && value >= 0)
{
return value.ToString(CultureInfo.InvariantCulture) + "u";
}
return "0x" + value.ToString("X", CultureInfo.InvariantCulture) + "u";
}
}
}

View file

@ -6,6 +6,25 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
static class TypeConversion
{
public static string ReinterpretCast(
CodeGenContext context,
IAstNode node,
VariableType srcType,
VariableType dstType)
{
if (node is AstOperand operand && operand.Type == OperandType.Constant)
{
if (NumberFormatter.TryFormat(operand.Value, dstType, out string formatted))
{
return formatted;
}
}
string expr = Instructions.GetExpression(context, node);
return ReinterpretCast(expr, srcType, dstType);
}
public static string ReinterpretCast(string expr, VariableType srcType, VariableType dstType)
{
if (srcType == dstType)
@ -25,13 +44,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
switch (srcType)
{
case VariableType.S32: return $"intBitsToFloat({expr})";
case VariableType.U32: return $"uintBitsToFloat({expr})";
case VariableType.Bool: return $"intBitsToFloat({ReinterpretBoolToInt(expr, VariableType.S32)})";
case VariableType.S32: return $"intBitsToFloat({expr})";
case VariableType.U32: return $"uintBitsToFloat({expr})";
}
}
else if (srcType == VariableType.Bool)
{
return $"(({expr}) ? {IrConsts.True} : {IrConsts.False})";
return ReinterpretBoolToInt(expr, dstType);
}
else if (dstType == VariableType.Bool)
{
@ -48,5 +68,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
throw new ArgumentException($"Invalid reinterpret cast from \"{srcType}\" to \"{dstType}\".");
}
private static string ReinterpretBoolToInt(string expr, VariableType dstType)
{
string trueExpr = NumberFormatter.FormatInt(IrConsts.True, dstType);
string falseExpr = NumberFormatter.FormatInt(IrConsts.False, dstType);
return $"(({expr}) ? {trueExpr} : {falseExpr})";
}
}
}

View file

@ -1,13 +1,14 @@
using Ryujinx.Graphics.Shader.Decoders;
using System;
using System.Collections.Generic;
namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
{
class Operand
{
private const int AttrSlotBits = 5;
private const int AttrSlotLsb = 32 - AttrSlotBits;
private const int AttrSlotMask = (1 << AttrSlotBits) - 1;
private const int CbufSlotBits = 5;
private const int CbufSlotLsb = 32 - CbufSlotBits;
private const int CbufSlotMask = (1 << CbufSlotBits) - 1;
public OperandType Type { get; }
@ -47,7 +48,7 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
private static int PackCbufInfo(int slot, int offset)
{
return (slot << AttrSlotLsb) | offset;
return (slot << CbufSlotLsb) | offset;
}
private static int PackRegInfo(int index, RegisterType type)
@ -57,17 +58,22 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
public int GetCbufSlot()
{
return (Value >> AttrSlotLsb) & AttrSlotMask;
return (Value >> CbufSlotLsb) & CbufSlotMask;
}
public int GetCbufOffset()
{
return Value & ~(AttrSlotMask << AttrSlotLsb);
return Value & ~(CbufSlotMask << CbufSlotLsb);
}
public Register GetRegister()
{
return new Register(Value & 0xffffff, (RegisterType)(Value >> 24));
}
public float AsFloat()
{
return BitConverter.Int32BitsToSingle(Value);
}
}
}

View file

@ -2,7 +2,7 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
{
class Operation : INode
{
public Instruction Inst { get; }
public Instruction Inst { get; private set; }
private Operand _dest;
@ -57,5 +57,22 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
_sources[index] = operand;
}
public void TurnIntoCopy(Operand source)
{
Inst = Instruction.Copy;
foreach (Operand oldSrc in _sources)
{
if (oldSrc.Type == OperandType.LocalVariable)
{
oldSrc.UseOps.Remove(this);
}
}
source.UseOps.Add(this);
_sources = new Operand[] { source };
}
}
}

View file

@ -3,7 +3,7 @@ using System.Collections.Generic;
namespace Ryujinx.Graphics.Shader.StructuredIr
{
static class PhiFunction
static class PhiFunctions
{
public static void Remove(BasicBlock[] blocks)
{

View file

@ -6,7 +6,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{
public static StructuredProgramInfo MakeStructuredProgram(BasicBlock[] blocks)
{
PhiFunction.Remove(blocks);
PhiFunctions.Remove(blocks);
StructuredProgramContext context = new StructuredProgramContext(blocks);

View file

@ -61,11 +61,11 @@ namespace Ryujinx.Graphics.Shader.Translation
entry.ImmediateDominator = entry;
bool changed;
bool modified;
do
{
changed = false;
modified = false;
for (int blkIndex = postOrderBlocks.Count - 2; blkIndex >= 0; blkIndex--)
{
@ -92,11 +92,11 @@ namespace Ryujinx.Graphics.Shader.Translation
{
block.ImmediateDominator = newIDom;
changed = true;
modified = true;
}
}
}
while (changed);
while (modified);
}
public static void FindDominanceFrontiers(BasicBlock[] blocks)

View file

@ -0,0 +1,302 @@
using Ryujinx.Graphics.Shader.Decoders;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System;
using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
namespace Ryujinx.Graphics.Shader.Translation.Optimizations
{
static class ConstantFolding
{
public static void FoldOperation(Operation operation)
{
if (!AreAllSourcesConstant(operation))
{
return;
}
switch (operation.Inst)
{
case Instruction.Add:
EvaluateBinary(operation, (x, y) => x + y);
break;
case Instruction.BitwiseAnd:
EvaluateBinary(operation, (x, y) => x & y);
break;
case Instruction.BitwiseExclusiveOr:
EvaluateBinary(operation, (x, y) => x ^ y);
break;
case Instruction.BitwiseNot:
EvaluateUnary(operation, (x) => ~x);
break;
case Instruction.BitwiseOr:
EvaluateBinary(operation, (x, y) => x | y);
break;
case Instruction.BitfieldExtractS32:
BitfieldExtractS32(operation);
break;
case Instruction.BitfieldExtractU32:
BitfieldExtractU32(operation);
break;
case Instruction.Clamp:
EvaluateTernary(operation, (x, y, z) => Math.Clamp(x, y, z));
break;
case Instruction.ClampU32:
EvaluateTernary(operation, (x, y, z) => (int)Math.Clamp((uint)x, (uint)y, (uint)z));
break;
case Instruction.CompareEqual:
EvaluateBinary(operation, (x, y) => x == y);
break;
case Instruction.CompareGreater:
EvaluateBinary(operation, (x, y) => x > y);
break;
case Instruction.CompareGreaterOrEqual:
EvaluateBinary(operation, (x, y) => x >= y);
break;
case Instruction.CompareGreaterOrEqualU32:
EvaluateBinary(operation, (x, y) => (uint)x >= (uint)y);
break;
case Instruction.CompareGreaterU32:
EvaluateBinary(operation, (x, y) => (uint)x > (uint)y);
break;
case Instruction.CompareLess:
EvaluateBinary(operation, (x, y) => x < y);
break;
case Instruction.CompareLessOrEqual:
EvaluateBinary(operation, (x, y) => x <= y);
break;
case Instruction.CompareLessOrEqualU32:
EvaluateBinary(operation, (x, y) => (uint)x <= (uint)y);
break;
case Instruction.CompareLessU32:
EvaluateBinary(operation, (x, y) => (uint)x < (uint)y);
break;
case Instruction.Divide:
EvaluateBinary(operation, (x, y) => y != 0 ? x / y : 0);
break;
case Instruction.FP | Instruction.Add:
EvaluateFPBinary(operation, (x, y) => x + y);
break;
case Instruction.FP | Instruction.Clamp:
EvaluateFPTernary(operation, (x, y, z) => Math.Clamp(x, y, z));
break;
case Instruction.FP | Instruction.CompareEqual:
EvaluateFPBinary(operation, (x, y) => x == y);
break;
case Instruction.FP | Instruction.CompareGreater:
EvaluateFPBinary(operation, (x, y) => x > y);
break;
case Instruction.FP | Instruction.CompareGreaterOrEqual:
EvaluateFPBinary(operation, (x, y) => x >= y);
break;
case Instruction.FP | Instruction.CompareLess:
EvaluateFPBinary(operation, (x, y) => x < y);
break;
case Instruction.FP | Instruction.CompareLessOrEqual:
EvaluateFPBinary(operation, (x, y) => x <= y);
break;
case Instruction.FP | Instruction.Divide:
EvaluateFPBinary(operation, (x, y) => x / y);
break;
case Instruction.FP | Instruction.Multiply:
EvaluateFPBinary(operation, (x, y) => x * y);
break;
case Instruction.FP | Instruction.Negate:
EvaluateFPUnary(operation, (x) => -x);
break;
case Instruction.FP | Instruction.Subtract:
EvaluateFPBinary(operation, (x, y) => x - y);
break;
case Instruction.IsNan:
EvaluateFPUnary(operation, (x) => float.IsNaN(x));
break;
case Instruction.Maximum:
EvaluateBinary(operation, (x, y) => Math.Max(x, y));
break;
case Instruction.MaximumU32:
EvaluateBinary(operation, (x, y) => (int)Math.Max((uint)x, (uint)y));
break;
case Instruction.Minimum:
EvaluateBinary(operation, (x, y) => Math.Min(x, y));
break;
case Instruction.MinimumU32:
EvaluateBinary(operation, (x, y) => (int)Math.Min((uint)x, (uint)y));
break;
case Instruction.Multiply:
EvaluateBinary(operation, (x, y) => x * y);
break;
case Instruction.Negate:
EvaluateUnary(operation, (x) => -x);
break;
case Instruction.ShiftLeft:
EvaluateBinary(operation, (x, y) => x << y);
break;
case Instruction.ShiftRightS32:
EvaluateBinary(operation, (x, y) => x >> y);
break;
case Instruction.ShiftRightU32:
EvaluateBinary(operation, (x, y) => (int)((uint)x >> y));
break;
case Instruction.Subtract:
EvaluateBinary(operation, (x, y) => x - y);
break;
}
}
private static bool AreAllSourcesConstant(Operation operation)
{
for (int index = 0; index < operation.SourcesCount; index++)
{
if (operation.GetSource(index).Type != OperandType.Constant)
{
return false;
}
}
return true;
}
private static void BitfieldExtractS32(Operation operation)
{
int value = GetBitfieldExtractValue(operation);
int shift = 32 - operation.GetSource(2).Value;
value = (value << shift) >> shift;
operation.TurnIntoCopy(Const(value));
}
private static void BitfieldExtractU32(Operation operation)
{
operation.TurnIntoCopy(Const(GetBitfieldExtractValue(operation)));
}
private static int GetBitfieldExtractValue(Operation operation)
{
int value = operation.GetSource(0).Value;
int lsb = operation.GetSource(1).Value;
int length = operation.GetSource(2).Value;
return value.Extract(lsb, length);
}
private static void FPNegate(Operation operation)
{
float value = operation.GetSource(0).AsFloat();
operation.TurnIntoCopy(ConstF(-value));
}
private static void EvaluateUnary(Operation operation, Func<int, int> op)
{
int x = operation.GetSource(0).Value;
operation.TurnIntoCopy(Const(op(x)));
}
private static void EvaluateFPUnary(Operation operation, Func<float, float> op)
{
float x = operation.GetSource(0).AsFloat();
operation.TurnIntoCopy(ConstF(op(x)));
}
private static void EvaluateFPUnary(Operation operation, Func<float, bool> op)
{
float x = operation.GetSource(0).AsFloat();
operation.TurnIntoCopy(Const(op(x) ? IrConsts.True : IrConsts.False));
}
private static void EvaluateBinary(Operation operation, Func<int, int, int> op)
{
int x = operation.GetSource(0).Value;
int y = operation.GetSource(1).Value;
operation.TurnIntoCopy(Const(op(x, y)));
}
private static void EvaluateBinary(Operation operation, Func<int, int, bool> op)
{
int x = operation.GetSource(0).Value;
int y = operation.GetSource(1).Value;
operation.TurnIntoCopy(Const(op(x, y) ? IrConsts.True : IrConsts.False));
}
private static void EvaluateFPBinary(Operation operation, Func<float, float, float> op)
{
float x = operation.GetSource(0).AsFloat();
float y = operation.GetSource(1).AsFloat();
operation.TurnIntoCopy(ConstF(op(x, y)));
}
private static void EvaluateFPBinary(Operation operation, Func<float, float, bool> op)
{
float x = operation.GetSource(0).AsFloat();
float y = operation.GetSource(1).AsFloat();
operation.TurnIntoCopy(Const(op(x, y) ? IrConsts.True : IrConsts.False));
}
private static void EvaluateTernary(Operation operation, Func<int, int, int, int> op)
{
int x = operation.GetSource(0).Value;
int y = operation.GetSource(1).Value;
int z = operation.GetSource(2).Value;
operation.TurnIntoCopy(Const(op(x, y, z)));
}
private static void EvaluateFPTernary(Operation operation, Func<float, float, float, float> op)
{
float x = operation.GetSource(0).AsFloat();
float y = operation.GetSource(1).AsFloat();
float z = operation.GetSource(2).AsFloat();
operation.TurnIntoCopy(ConstF(op(x, y, z)));
}
}
}

View file

@ -1,7 +1,7 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System.Collections.Generic;
namespace Ryujinx.Graphics.Shader.Translation
namespace Ryujinx.Graphics.Shader.Translation.Optimizations
{
static class Optimizer
{
@ -39,6 +39,10 @@ namespace Ryujinx.Graphics.Shader.Translation
continue;
}
ConstantFolding.FoldOperation(operation);
Simplification.Simplify(operation);
if (operation.Inst == Instruction.Copy && DestIsLocalVar(operation))
{
PropagateCopy(operation);

View file

@ -0,0 +1,147 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
namespace Ryujinx.Graphics.Shader.Translation.Optimizations
{
static class Simplification
{
private const int AllOnes = ~0;
public static void Simplify(Operation operation)
{
switch (operation.Inst)
{
case Instruction.Add:
case Instruction.BitwiseExclusiveOr:
TryEliminateBinaryOpComutative(operation, 0);
break;
case Instruction.BitwiseAnd:
TryEliminateBitwiseAnd(operation);
break;
case Instruction.BitwiseOr:
TryEliminateBitwiseOr(operation);
break;
case Instruction.ConditionalSelect:
TryEliminateConditionalSelect(operation);
break;
case Instruction.Divide:
TryEliminateBinaryOpY(operation, 1);
break;
case Instruction.Multiply:
TryEliminateBinaryOpComutative(operation, 1);
break;
case Instruction.ShiftLeft:
case Instruction.ShiftRightS32:
case Instruction.ShiftRightU32:
case Instruction.Subtract:
TryEliminateBinaryOpY(operation, 0);
break;
}
}
private static void TryEliminateBitwiseAnd(Operation operation)
{
//Try to recognize and optimize those 3 patterns (in order):
//x & 0xFFFFFFFF == x, 0xFFFFFFFF & y == y,
//x & 0x00000000 == 0x00000000, 0x00000000 & y == 0x00000000
Operand x = operation.GetSource(0);
Operand y = operation.GetSource(1);
if (IsConstEqual(x, AllOnes))
{
operation.TurnIntoCopy(y);
}
else if (IsConstEqual(y, AllOnes))
{
operation.TurnIntoCopy(x);
}
else if (IsConstEqual(x, 0) || IsConstEqual(y, 0))
{
operation.TurnIntoCopy(Const(0));
}
}
private static void TryEliminateBitwiseOr(Operation operation)
{
//Try to recognize and optimize those 3 patterns (in order):
//x | 0x00000000 == x, 0x00000000 | y == y,
//x | 0xFFFFFFFF == 0xFFFFFFFF, 0xFFFFFFFF | y == 0xFFFFFFFF
Operand x = operation.GetSource(0);
Operand y = operation.GetSource(1);
if (IsConstEqual(x, 0))
{
operation.TurnIntoCopy(y);
}
else if (IsConstEqual(y, 0))
{
operation.TurnIntoCopy(x);
}
else if (IsConstEqual(x, AllOnes) || IsConstEqual(y, AllOnes))
{
operation.TurnIntoCopy(Const(AllOnes));
}
}
private static void TryEliminateBinaryOpY(Operation operation, int comparand)
{
Operand x = operation.GetSource(0);
Operand y = operation.GetSource(1);
if (IsConstEqual(y, comparand))
{
operation.TurnIntoCopy(x);
}
}
private static void TryEliminateBinaryOpComutative(Operation operation, int comparand)
{
Operand x = operation.GetSource(0);
Operand y = operation.GetSource(1);
if (IsConstEqual(x, comparand))
{
operation.TurnIntoCopy(y);
}
else if (IsConstEqual(y, comparand))
{
operation.TurnIntoCopy(x);
}
}
private static void TryEliminateConditionalSelect(Operation operation)
{
Operand cond = operation.GetSource(0);
if (cond.Type != OperandType.Constant)
{
return;
}
//The condition is constant, we can turn it into a copy, and select
//the source based on the condition value.
int srcIndex = cond.Value != 0 ? 1 : 2;
Operand source = operation.GetSource(srcIndex);
operation.TurnIntoCopy(source);
}
private static bool IsConstEqual(Operand operand, int comparand)
{
if (operand.Type != OperandType.Constant)
{
return false;
}
return operand.Value == comparand;
}
}
}

View file

@ -4,6 +4,7 @@ using Ryujinx.Graphics.Shader.Decoders;
using Ryujinx.Graphics.Shader.Instructions;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation.Optimizations;
using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;

View file

@ -30,7 +30,9 @@ namespace Ryujinx.ShaderTools
{
Memory Mem = new Memory(FS);
Translator.Translate(Mem, 0, ShaderType);
string code = Translator.Translate(Mem, 0, ShaderType);
Console.WriteLine(code);
//GlslProgram Program = Decompiler.Decompile(Mem, 0, ShaderType);