From baecbfb31a06c6ccd0c5d5c42850584754441975 Mon Sep 17 00:00:00 2001 From: gdkchan Date: Thu, 4 Apr 2019 17:37:23 -0300 Subject: [PATCH] Remove redundant branches, add expression propagation and other improvements on the code --- .../Shader/CodeGen/Glsl/Declarations.cs | 23 + .../Shader/CodeGen/Glsl/GlslGenerator.cs | 126 ++--- .../Shader/CodeGen/Glsl/Instructions.cs | 458 ++++++++++-------- .../Shader/CodeGen/Glsl/TypeConversion.cs | 8 +- .../IntermediateRepresentation/BasicBlock.cs | 4 +- .../IntermediateRepresentation/INode.cs | 2 + .../IntermediateRepresentation/Operation.cs | 2 + .../IntermediateRepresentation/PhiNode.cs | 2 + .../Shader/StructuredIr/AstAssignment.cs | 23 +- .../Shader/StructuredIr/AstBlock.cs | 46 +- .../Shader/StructuredIr/AstBlockType.cs | 1 + .../Shader/StructuredIr/AstBlockVisitor.cs | 64 +++ .../Shader/StructuredIr/AstDeclaration.cs | 12 - .../Shader/StructuredIr/AstHelper.cs | 47 ++ .../Shader/StructuredIr/AstOperand.cs | 7 + .../Shader/StructuredIr/AstOperation.cs | 29 +- .../Shader/StructuredIr/AstOptimizer.cs | 149 ++++++ .../Shader/StructuredIr/GotoElimination.cs | 27 +- .../Shader/StructuredIr/InstructionInfo.cs | 16 + .../Shader/StructuredIr/StructuredProgram.cs | 26 +- .../StructuredIr/StructuredProgramContext.cs | 57 +-- .../StructuredIr/StructuredProgramInfo.cs | 4 + .../Optimizations/BranchElimination.cs | 64 +++ .../Translation/Optimizations/Optimizer.cs | 7 + 24 files changed, 833 insertions(+), 371 deletions(-) create mode 100644 Ryujinx.Graphics/Shader/StructuredIr/AstBlockVisitor.cs delete mode 100644 Ryujinx.Graphics/Shader/StructuredIr/AstDeclaration.cs create mode 100644 Ryujinx.Graphics/Shader/StructuredIr/AstOptimizer.cs create mode 100644 Ryujinx.Graphics/Shader/Translation/Optimizations/BranchElimination.cs diff --git a/Ryujinx.Graphics/Shader/CodeGen/Glsl/Declarations.cs b/Ryujinx.Graphics/Shader/CodeGen/Glsl/Declarations.cs index 21d4a285fa..de589cc1f5 100644 --- a/Ryujinx.Graphics/Shader/CodeGen/Glsl/Declarations.cs +++ b/Ryujinx.Graphics/Shader/CodeGen/Glsl/Declarations.cs @@ -67,6 +67,29 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl } } + public static void DeclareLocals(CodeGenContext context, StructuredProgramInfo prgInfo) + { + foreach (AstOperand decl in prgInfo.Locals) + { + string name = context.DeclareLocal(decl); + + context.AppendLine(GetVarTypeName(decl.VarType) + " " + name + ";"); + } + } + + private static string GetVarTypeName(VariableType type) + { + switch (type) + { + case VariableType.Bool: return "bool"; + case VariableType.F32: return "float"; + case VariableType.S32: return "int"; + case VariableType.U32: return "uint"; + } + + throw new ArgumentException($"Invalid variable type \"{type}\"."); + } + private static void DeclareUniforms(CodeGenContext context, StructuredProgramInfo prgInfo) { foreach (int cbufSlot in prgInfo.ConstantBuffers.OrderBy(x => x)) diff --git a/Ryujinx.Graphics/Shader/CodeGen/Glsl/GlslGenerator.cs b/Ryujinx.Graphics/Shader/CodeGen/Glsl/GlslGenerator.cs index 607d6786e9..801224a2c7 100644 --- a/Ryujinx.Graphics/Shader/CodeGen/Glsl/GlslGenerator.cs +++ b/Ryujinx.Graphics/Shader/CodeGen/Glsl/GlslGenerator.cs @@ -9,47 +9,73 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { class GlslGenerator { - public string Generate(StructuredProgramInfo prgInfo, GalShaderType shaderType) + public string Generate(StructuredProgramInfo info, GalShaderType shaderType) { - CodeGenContext cgContext = new CodeGenContext(prgInfo, shaderType); + CodeGenContext context = new CodeGenContext(info, shaderType); - Declarations.Declare(cgContext, prgInfo); + Declarations.Declare(context, info); - PrintBlock(cgContext, prgInfo.MainBlock); + PrintMainBlock(context, info); - return cgContext.GetCode(); + return context.GetCode(); + } + + private void PrintMainBlock(CodeGenContext context, StructuredProgramInfo info) + { + context.AppendLine("void main()"); + + context.EnterScope(); + + Declarations.DeclareLocals(context, info); + + PrintBlock(context, info.MainBlock); + + context.LeaveScope(); } private void PrintBlock(CodeGenContext context, AstBlock block) { - switch (block.Type) + AstBlockVisitor visitor = new AstBlockVisitor(block); + + visitor.BlockEntered += (sender, e) => { - case AstBlockType.DoWhile: - context.AppendLine("do"); - break; - - case AstBlockType.Else: - context.AppendLine("else"); - break; - - case AstBlockType.If: - context.AppendLine($"if ({GetCondExpr(context, block.Condition)})"); - break; - - case AstBlockType.Main: - context.AppendLine("void main()"); - break; - } - - context.EnterScope(); - - foreach (IAstNode node in block) - { - if (node is AstBlock subBlock) + switch (e.Block.Type) { - PrintBlock(context, subBlock); + case AstBlockType.DoWhile: + context.AppendLine("do"); + break; + + case AstBlockType.Else: + context.AppendLine("else"); + break; + + case AstBlockType.ElseIf: + context.AppendLine($"else if ({GetCondExpr(context, e.Block.Condition)})"); + break; + + case AstBlockType.If: + context.AppendLine($"if ({GetCondExpr(context, e.Block.Condition)})"); + break; + + default: throw new InvalidOperationException($"Found unexpected block type \"{e.Block.Type}\"."); } - else if (node is AstOperation operation) + + context.EnterScope(); + }; + + visitor.BlockLeft += (sender, e) => + { + context.LeaveScope(); + + if (e.Block.Type == AstBlockType.DoWhile) + { + context.AppendLine($"while ({GetCondExpr(context, e.Block.Condition)});"); + } + }; + + foreach (IAstNode node in visitor.Visit()) + { + if (node is AstOperation operation) { if (operation.Inst == Instruction.Return) { @@ -58,60 +84,38 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl context.AppendLine(Instructions.GetExpression(context, operation) + ";"); } - else if (node is AstAssignment asg) + else if (node is AstAssignment assignment) { - VariableType srcType = OperandManager.GetNodeDestType(asg.Source); - VariableType dstType = OperandManager.GetNodeDestType(asg.Destination); + VariableType srcType = OperandManager.GetNodeDestType(assignment.Source); + VariableType dstType = OperandManager.GetNodeDestType(assignment.Destination); string dest; - if (asg.Destination is AstOperand operand && operand.Type == OperandType.Attribute) + if (assignment.Destination is AstOperand operand && operand.Type == OperandType.Attribute) { dest = OperandManager.GetOutAttributeName(context, operand); } else { - dest = Instructions.GetExpression(context, asg.Destination); + dest = Instructions.GetExpression(context, assignment.Destination); } - string src = ReinterpretCast(context, asg.Source, srcType, dstType); + string src = ReinterpretCast(context, assignment.Source, srcType, dstType); context.AppendLine(dest + " = " + src + ";"); } - else if (node is AstDeclaration decl && decl.Operand.Type != OperandType.Undefined) + else { - string name = context.DeclareLocal(decl.Operand); - - context.AppendLine(GetVarTypeName(decl.Operand.VarType) + " " + name + ";"); + throw new InvalidOperationException($"Found unexpected node type \"{node?.GetType().Name ?? "null"}\"."); } } - - context.LeaveScope(); - - if (block.Type == AstBlockType.DoWhile) - { - context.AppendLine($"while ({GetCondExpr(context, block.Condition)});"); - } } private static string GetCondExpr(CodeGenContext context, IAstNode cond) { VariableType srcType = OperandManager.GetNodeDestType(cond); - return ReinterpretCast(Instructions.GetExpression(context, cond), srcType, VariableType.Bool); - } - - private string GetVarTypeName(VariableType type) - { - switch (type) - { - case VariableType.Bool: return "bool"; - case VariableType.F32: return "float"; - case VariableType.S32: return "int"; - case VariableType.U32: return "uint"; - } - - throw new ArgumentException($"Invalid variable type \"{type}\"."); + return ReinterpretCast(context, cond, srcType, VariableType.Bool); } private static void PrepareForReturn(CodeGenContext context) diff --git a/Ryujinx.Graphics/Shader/CodeGen/Glsl/Instructions.cs b/Ryujinx.Graphics/Shader/CodeGen/Glsl/Instructions.cs index 896904bbfd..fe6d2e0ce0 100644 --- a/Ryujinx.Graphics/Shader/CodeGen/Glsl/Instructions.cs +++ b/Ryujinx.Graphics/Shader/CodeGen/Glsl/Instructions.cs @@ -9,6 +9,112 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { static class Instructions { + [Flags] + private enum InstFlags + { + OpNullary = 0, + OpUnary = 1, + OpBinary = 2, + OpTernary = 3, + + CallNullary = Call | 0, + CallUnary = Call | 1, + CallBinary = Call | 2, + CallTernary = Call | 3, + CallQuaternary = Call | 4, + + Call = 1 << 8, + + ArityMask = 0xff + } + + private struct InstInfo + { + public InstFlags Flags { get; } + + public string OpName { get; } + + public int Precedence { get; } + + public InstInfo(InstFlags flags, string opName, int precedence) + { + Flags = flags; + OpName = opName; + Precedence = precedence; + } + } + + private static InstInfo[] _infoTbl; + + static Instructions() + { + _infoTbl = new InstInfo[(int)Instruction.Count]; + + Add(Instruction.Absolute, InstFlags.CallUnary, "abs"); + Add(Instruction.Add, InstFlags.OpBinary, "+", 2); + Add(Instruction.BitfieldExtractS32, InstFlags.CallTernary, "bitfieldExtract"); + Add(Instruction.BitfieldExtractU32, InstFlags.CallTernary, "bitfieldExtract"); + Add(Instruction.BitfieldInsert, InstFlags.CallQuaternary, "bitfieldInsert"); + Add(Instruction.BitfieldReverse, InstFlags.CallUnary, "bitfieldReverse"); + Add(Instruction.BitwiseAnd, InstFlags.OpBinary, "&", 6); + Add(Instruction.BitwiseExclusiveOr, InstFlags.OpBinary, "^", 7); + Add(Instruction.BitwiseNot, InstFlags.OpUnary, "~", 0); + Add(Instruction.BitwiseOr, InstFlags.OpBinary, "|", 8); + Add(Instruction.Ceiling, InstFlags.CallUnary, "ceil"); + Add(Instruction.Clamp, InstFlags.CallTernary, "clamp"); + Add(Instruction.ClampU32, InstFlags.CallTernary, "clamp"); + Add(Instruction.CompareEqual, InstFlags.OpBinary, "==", 5); + Add(Instruction.CompareGreater, InstFlags.OpBinary, ">", 4); + Add(Instruction.CompareGreaterOrEqual, InstFlags.OpBinary, ">=", 4); + Add(Instruction.CompareGreaterOrEqualU32, InstFlags.OpBinary, ">=", 4); + Add(Instruction.CompareGreaterU32, InstFlags.OpBinary, ">", 4); + Add(Instruction.CompareLess, InstFlags.OpBinary, "<", 4); + Add(Instruction.CompareLessOrEqual, InstFlags.OpBinary, "<=", 4); + Add(Instruction.CompareLessOrEqualU32, InstFlags.OpBinary, "<=", 4); + Add(Instruction.CompareLessU32, InstFlags.OpBinary, "<", 4); + Add(Instruction.CompareNotEqual, InstFlags.OpBinary, "!=", 5); + Add(Instruction.ConditionalSelect, InstFlags.OpTernary, "?:", 12); + Add(Instruction.ConvertFPToS32, InstFlags.CallUnary, "int"); + Add(Instruction.ConvertS32ToFP, InstFlags.CallUnary, "float"); + Add(Instruction.ConvertU32ToFP, InstFlags.CallUnary, "float"); + Add(Instruction.Cosine, InstFlags.CallUnary, "cos"); + Add(Instruction.Discard, InstFlags.OpNullary, "discard"); + Add(Instruction.Divide, InstFlags.OpBinary, "/", 1); + Add(Instruction.EmitVertex, InstFlags.CallNullary, "EmitVertex"); + Add(Instruction.EndPrimitive, InstFlags.CallNullary, "EndPrimitive"); + Add(Instruction.ExponentB2, InstFlags.CallUnary, "exp2"); + Add(Instruction.Floor, InstFlags.CallUnary, "floor"); + Add(Instruction.FusedMultiplyAdd, InstFlags.CallTernary, "fma"); + Add(Instruction.IsNan, InstFlags.CallUnary, "isnan"); + Add(Instruction.LoadConstant, InstFlags.Call); + Add(Instruction.LogarithmB2, InstFlags.CallUnary, "log2"); + Add(Instruction.LogicalAnd, InstFlags.OpBinary, "&&", 9); + Add(Instruction.LogicalExclusiveOr, InstFlags.OpBinary, "^^", 10); + Add(Instruction.LogicalNot, InstFlags.OpUnary, "!", 0); + Add(Instruction.LogicalOr, InstFlags.OpBinary, "||", 11); + Add(Instruction.ShiftLeft, InstFlags.OpBinary, "<<", 3); + Add(Instruction.ShiftRightS32, InstFlags.OpBinary, ">>", 3); + Add(Instruction.ShiftRightU32, InstFlags.OpBinary, ">>", 3); + Add(Instruction.Maximum, InstFlags.CallBinary, "max"); + Add(Instruction.MaximumU32, InstFlags.CallBinary, "max"); + Add(Instruction.Minimum, InstFlags.CallBinary, "min"); + Add(Instruction.MinimumU32, InstFlags.CallBinary, "min"); + Add(Instruction.Multiply, InstFlags.OpBinary, "*", 1); + Add(Instruction.Negate, InstFlags.OpUnary, "-", 0); + Add(Instruction.ReciprocalSquareRoot, InstFlags.CallUnary, "inversesqrt"); + Add(Instruction.Return, InstFlags.OpNullary, "return"); + Add(Instruction.Sine, InstFlags.CallUnary, "sin"); + Add(Instruction.SquareRoot, InstFlags.CallUnary, "sqrt"); + Add(Instruction.Subtract, InstFlags.OpBinary, "-", 2); + Add(Instruction.TextureSample, InstFlags.Call); + Add(Instruction.Truncate, InstFlags.CallUnary, "trunc"); + } + + private static void Add(Instruction inst, InstFlags flags, string opName = null, int precedence = 0) + { + _infoTbl[(int)inst] = new InstInfo(flags, opName, precedence); + } + public static string GetExpression(CodeGenContext context, IAstNode node) { if (node is AstOperation operation) @@ -43,233 +149,114 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl private static string GetExpression(CodeGenContext context, AstOperation operation) { - switch (operation.Inst & Instruction.Mask) + Instruction inst = operation.Inst & Instruction.Mask; + + switch (inst) { - case Instruction.Absolute: - return GetUnaryCallExpr(context, operation, "abs"); - - case Instruction.Add: - return GetBinaryExpr(context, operation, "+"); - - case Instruction.BitfieldExtractS32: - case Instruction.BitfieldExtractU32: - return GetTernaryCallExpr(context, operation, "bitfieldExtract"); - - case Instruction.BitfieldInsert: - return GetQuaternaryCallExpr(context, operation, "bitfieldInsert"); - - case Instruction.BitfieldReverse: - return GetUnaryCallExpr(context, operation, "bitfieldReverse"); - - case Instruction.BitwiseAnd: - return GetBinaryExpr(context, operation, "&"); - - case Instruction.BitwiseExclusiveOr: - return GetBinaryExpr(context, operation, "^"); - - case Instruction.BitwiseNot: - return GetUnaryExpr(context, operation, "~"); - - case Instruction.BitwiseOr: - return GetBinaryExpr(context, operation, "|"); - - case Instruction.Ceiling: - return GetUnaryCallExpr(context, operation, "ceil"); - - case Instruction.CompareEqual: - return GetBinaryExpr(context, operation, "=="); - - case Instruction.CompareGreater: - case Instruction.CompareGreaterU32: - return GetBinaryExpr(context, operation, ">"); - - case Instruction.CompareGreaterOrEqual: - case Instruction.CompareGreaterOrEqualU32: - return GetBinaryExpr(context, operation, ">="); - - case Instruction.CompareLess: - case Instruction.CompareLessU32: - return GetBinaryExpr(context, operation, "<"); - - case Instruction.CompareLessOrEqual: - case Instruction.CompareLessOrEqualU32: - return GetBinaryExpr(context, operation, "<="); - - case Instruction.CompareNotEqual: - return GetBinaryExpr(context, operation, "!="); - - case Instruction.ConditionalSelect: - return GetConditionalSelectExpr(context, operation); - - case Instruction.Cosine: - return GetUnaryCallExpr(context, operation, "cos"); - - case Instruction.Clamp: - case Instruction.ClampU32: - return GetTernaryCallExpr(context, operation, "clamp"); - - case Instruction.ConvertFPToS32: - return GetUnaryCallExpr(context, operation, "int"); - - case Instruction.ConvertS32ToFP: - case Instruction.ConvertU32ToFP: - return GetUnaryCallExpr(context, operation, "float"); - - case Instruction.Discard: - return "discard"; - - case Instruction.Divide: - return GetBinaryExpr(context, operation, "/"); - - case Instruction.EmitVertex: - return "EmitVertex()"; - - case Instruction.EndPrimitive: - return "EndPrimitive()"; - - case Instruction.ExponentB2: - return GetUnaryCallExpr(context, operation, "exp2"); - - case Instruction.Floor: - return GetUnaryCallExpr(context, operation, "floor"); - - case Instruction.FusedMultiplyAdd: - return GetTernaryCallExpr(context, operation, "fma"); - - case Instruction.IsNan: - return GetUnaryCallExpr(context, operation, "isnan"); - case Instruction.LoadConstant: return GetLoadConstantExpr(context, operation); - case Instruction.LogarithmB2: - return GetUnaryCallExpr(context, operation, "log2"); - - case Instruction.LogicalAnd: - return GetBinaryExpr(context, operation, "&&"); - - case Instruction.LogicalExclusiveOr: - return GetBinaryExpr(context, operation, "^^"); - - case Instruction.LogicalNot: - return GetUnaryExpr(context, operation, "!"); - - case Instruction.LogicalOr: - return GetBinaryExpr(context, operation, "||"); - - case Instruction.LoopBreak: - return "break"; - - case Instruction.LoopContinue: - return "continue"; - - case Instruction.Maximum: - case Instruction.MaximumU32: - return GetBinaryCallExpr(context, operation, "max"); - - case Instruction.Minimum: - case Instruction.MinimumU32: - return GetBinaryCallExpr(context, operation, "min"); - - case Instruction.Multiply: - return GetBinaryExpr(context, operation, "*"); - - case Instruction.Negate: - return GetUnaryExpr(context, operation, "-"); - - case Instruction.ReciprocalSquareRoot: - return GetUnaryCallExpr(context, operation, "inversesqrt"); - - case Instruction.Return: - return "return"; - - case Instruction.ShiftLeft: - return GetBinaryExpr(context, operation, "<<"); - - case Instruction.ShiftRightS32: - case Instruction.ShiftRightU32: - return GetBinaryExpr(context, operation, ">>"); - - case Instruction.Sine: - return GetUnaryCallExpr(context, operation, "sin"); - - case Instruction.SquareRoot: - return GetUnaryCallExpr(context, operation, "sqrt"); - - case Instruction.Subtract: - return GetBinaryExpr(context, operation, "-"); - case Instruction.TextureSample: return GetTextureSampleExpr(context, operation); - - case Instruction.Truncate: - return GetUnaryCallExpr(context, operation, "trunc"); } - throw new ArgumentException($"Operation has invalid instruction \"{operation.Inst}\"."); - } + InstInfo info = _infoTbl[(int)inst]; - private static string GetUnaryCallExpr(CodeGenContext context, AstOperation operation, string funcName) - { - return funcName + "(" + GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ")"; - } + if ((info.Flags & InstFlags.Call) != 0) + { + int arity = (int)(info.Flags & InstFlags.ArityMask); - private static string GetBinaryCallExpr(CodeGenContext context, AstOperation operation, string funcName) - { - return funcName + "(" + - GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ", " + - GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)) + ")"; - } + string args = string.Empty; - private static string GetTernaryCallExpr(CodeGenContext context, AstOperation operation, string funcName) - { - return funcName + "(" + - GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ", " + - GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)) + ", " + - GetSoureExpr(context, operation.Sources[2], GetSrcVarType(operation.Inst, 2)) + ")"; - } + for (int argIndex = 0; argIndex < arity; argIndex++) + { + if (argIndex != 0) + { + args += ", "; + } - private static string GetQuaternaryCallExpr(CodeGenContext context, AstOperation operation, string funcName) - { - return funcName + "(" + - GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ", " + - GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)) + ", " + - GetSoureExpr(context, operation.Sources[2], GetSrcVarType(operation.Inst, 2)) + ", " + - GetSoureExpr(context, operation.Sources[3], GetSrcVarType(operation.Inst, 3)) + ")"; - } + VariableType dstType = GetSrcVarType(operation.Inst, argIndex); - private static string GetUnaryExpr(CodeGenContext context, AstOperation operation, string op) - { - return op + GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)); - } + args += GetSoureExpr(context, operation.GetSource(argIndex), dstType); + } - private static string GetBinaryExpr(CodeGenContext context, AstOperation operation, string op) - { - return GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + " " + op + " " + - GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)); - } + return info.OpName + "(" + args + ")"; + } + else + { + if (info.Flags == InstFlags.OpNullary) + { + return info.OpName; + } + else if (info.Flags == InstFlags.OpUnary) + { + IAstNode src = operation.GetSource(0); - private static string GetConditionalSelectExpr(CodeGenContext context, AstOperation operation) - { - return "((" + - GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ") ? (" + - GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)) + ") : (" + - GetSoureExpr(context, operation.Sources[2], GetSrcVarType(operation.Inst, 2)) + "))"; + string expr = GetSoureExpr(context, src, GetSrcVarType(operation.Inst, 0)); + + return info.OpName + Enclose(expr, src, info); + } + else if (info.Flags == InstFlags.OpBinary) + { + IAstNode src0 = operation.GetSource(0); + IAstNode src1 = operation.GetSource(1); + + string expr0 = GetSoureExpr(context, src0, GetSrcVarType(operation.Inst, 0)); + string expr1 = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 1)); + + expr0 = Enclose(expr0, src0, info, isLhs: true); + expr1 = Enclose(expr1, src1, info, isLhs: false); + + return expr0 + " " + info.OpName + " " + expr1; + } + else if (info.Flags == InstFlags.OpTernary) + { + IAstNode src0 = operation.GetSource(0); + IAstNode src1 = operation.GetSource(1); + IAstNode src2 = operation.GetSource(2); + + string expr0 = GetSoureExpr(context, src0, GetSrcVarType(operation.Inst, 0)); + string expr1 = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 1)); + string expr2 = GetSoureExpr(context, src2, GetSrcVarType(operation.Inst, 2)); + + expr0 = Enclose(expr0, src0, info); + expr1 = Enclose(expr1, src1, info); + expr2 = Enclose(expr2, src2, info); + + char op0 = info.OpName[0]; + char op1 = info.OpName[1]; + + return expr0 + " " + op0 + " " + expr1 + " " + op1 + " " + expr2; + } + else + { + throw new InvalidOperationException($"Unexpected instruction flags \"{info.Flags}\"."); + } + } } private static string GetLoadConstantExpr(CodeGenContext context, AstOperation operation) { - string offsetExpr = GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)); + IAstNode src1 = operation.GetSource(1); - return OperandManager.GetConstantBufferName(context, operation.Sources[0], offsetExpr); + string offsetExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 1)); + + offsetExpr = Enclose(offsetExpr, src1, Instruction.ShiftRightS32, isLhs: true); + + return OperandManager.GetConstantBufferName(context, operation.GetSource(0), offsetExpr); } private static string GetTextureSampleExpr(CodeGenContext context, AstOperation operation) { AstTextureOperation texOp = (AstTextureOperation)operation; - bool isGather = (texOp.Flags & TextureFlags.Gather) != 0; - bool isShadow = (texOp.Type & TextureType.Shadow) != 0; + bool isGather = (texOp.Flags & TextureFlags.Gather) != 0; + bool hasLodBias = (texOp.Flags & TextureFlags.LodBias) != 0; + bool hasLodLevel = (texOp.Flags & TextureFlags.LodLevel) != 0; + bool hasOffset = (texOp.Flags & TextureFlags.Offset) != 0; + bool hasOffsets = (texOp.Flags & TextureFlags.Offsets) != 0; + bool isArray = (texOp.Type & TextureType.Array) != 0; + bool isShadow = (texOp.Type & TextureType.Shadow) != 0; string samplerName = OperandManager.GetSamplerName(context.ShaderType, texOp.TextureHandle); @@ -280,16 +267,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl texCall += "Gather"; } - if ((texOp.Flags & TextureFlags.LodLevel) != 0) + if (hasLodLevel) { texCall += "Lod"; } - if ((texOp.Flags & TextureFlags.Offset) != 0) + if (hasOffset) { texCall += "Offset"; } - else if ((texOp.Flags & TextureFlags.Offsets) != 0) + else if (hasOffsets) { texCall += "Offsets"; } @@ -314,7 +301,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl int arrayIndexElem = -1; - if ((texOp.Type & TextureType.Array) != 0) + if (isArray) { arrayIndexElem = pCount++; } @@ -338,7 +325,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl string Src(VariableType type) { - return GetSoureExpr(context, texOp.Sources[srcIndex++], type); + return GetSoureExpr(context, texOp.GetSource(srcIndex++), type); } string AssembleVector(int count, VariableType type, bool isP = false) @@ -376,16 +363,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl texCall += ", " + Src(VariableType.F32); } - if ((texOp.Flags & TextureFlags.LodLevel) != 0) + if (hasLodLevel) { texCall += ", " + Src(VariableType.F32); } - if ((texOp.Flags & TextureFlags.Offset) != 0) + if (hasOffset) { texCall += ", " + AssembleVector(elemsCount, VariableType.S32); } - else if ((texOp.Flags & TextureFlags.Offsets) != 0) + else if (hasOffsets) { const int gatherTexelsCount = 4; @@ -404,7 +391,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl texCall += ")"; } - if ((texOp.Flags & TextureFlags.LodBias) != 0) + if (hasLodBias && !hasLodLevel) { texCall += ", " + Src(VariableType.F32); } @@ -434,5 +421,56 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { return ReinterpretCast(context, node, OperandManager.GetNodeDestType(node), dstType); } + + public static string Enclose(string expr, IAstNode node, Instruction inst, bool isLhs) + { + InstInfo info = _infoTbl[(int)(inst & Instruction.Mask)]; + + return Enclose(expr, node, info, isLhs); + } + + private static string Enclose(string expr, IAstNode node, InstInfo pInfo, bool isLhs = false) + { + if (NeedsParenthesis(node, pInfo, isLhs)) + { + expr = "(" + expr + ")"; + } + + return expr; + } + + private static bool NeedsParenthesis(IAstNode node, InstInfo pInfo, bool isLhs) + { + //If the node isn't a operation, then it can only be a operand, + //and those never needs to be surrounded in parenthesis. + if (!(node is AstOperation operation)) + { + return false; + } + + if ((pInfo.Flags & InstFlags.Call) != 0) + { + return false; + } + + InstInfo info = _infoTbl[(int)(operation.Inst & Instruction.Mask)]; + + if ((info.Flags & InstFlags.Call) != 0) + { + return false; + } + + if (info.Precedence < pInfo.Precedence) + { + return false; + } + + if (info.Precedence == pInfo.Precedence && isLhs) + { + return false; + } + + return true; + } } } \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/CodeGen/Glsl/TypeConversion.cs b/Ryujinx.Graphics/Shader/CodeGen/Glsl/TypeConversion.cs index 801545e4f7..7642869f3b 100644 --- a/Ryujinx.Graphics/Shader/CodeGen/Glsl/TypeConversion.cs +++ b/Ryujinx.Graphics/Shader/CodeGen/Glsl/TypeConversion.cs @@ -22,10 +22,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl string expr = Instructions.GetExpression(context, node); - return ReinterpretCast(expr, srcType, dstType); + return ReinterpretCast(expr, node, srcType, dstType); } - public static string ReinterpretCast(string expr, VariableType srcType, VariableType dstType) + private static string ReinterpretCast(string expr, IAstNode node, VariableType srcType, VariableType dstType) { if (srcType == dstType) { @@ -55,7 +55,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl } else if (dstType == VariableType.Bool) { - return $"(({expr}) != 0)"; + expr = Instructions.Enclose(expr, node, Instruction.CompareNotEqual, isLhs: true); + + return $"({expr} != 0)"; } else if (dstType == VariableType.S32) { diff --git a/Ryujinx.Graphics/Shader/IntermediateRepresentation/BasicBlock.cs b/Ryujinx.Graphics/Shader/IntermediateRepresentation/BasicBlock.cs index d0aaa28a4a..cbd4d64bca 100644 --- a/Ryujinx.Graphics/Shader/IntermediateRepresentation/BasicBlock.cs +++ b/Ryujinx.Graphics/Shader/IntermediateRepresentation/BasicBlock.cs @@ -24,6 +24,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation set => _branch = AddSuccessor(_branch, value); } + public bool HasBranch => _branch != null; + public List Predecessors { get; } public HashSet DominanceFrontiers { get; } @@ -47,7 +49,7 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation private BasicBlock AddSuccessor(BasicBlock oldBlock, BasicBlock newBlock) { oldBlock?.Predecessors.Remove(this); - newBlock.Predecessors.Add(this); + newBlock?.Predecessors.Add(this); return newBlock; } diff --git a/Ryujinx.Graphics/Shader/IntermediateRepresentation/INode.cs b/Ryujinx.Graphics/Shader/IntermediateRepresentation/INode.cs index 48dda24b1e..e1a312cc7d 100644 --- a/Ryujinx.Graphics/Shader/IntermediateRepresentation/INode.cs +++ b/Ryujinx.Graphics/Shader/IntermediateRepresentation/INode.cs @@ -2,6 +2,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation { interface INode { + BasicBlock Parent { get; set; } + Operand Dest { get; set; } int SourcesCount { get; } diff --git a/Ryujinx.Graphics/Shader/IntermediateRepresentation/Operation.cs b/Ryujinx.Graphics/Shader/IntermediateRepresentation/Operation.cs index 96c210a084..f3c44f73ac 100644 --- a/Ryujinx.Graphics/Shader/IntermediateRepresentation/Operation.cs +++ b/Ryujinx.Graphics/Shader/IntermediateRepresentation/Operation.cs @@ -2,6 +2,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation { class Operation : INode { + public BasicBlock Parent { get; set; } + public Instruction Inst { get; private set; } private Operand _dest; diff --git a/Ryujinx.Graphics/Shader/IntermediateRepresentation/PhiNode.cs b/Ryujinx.Graphics/Shader/IntermediateRepresentation/PhiNode.cs index 81609ea9f1..ea1cfa1850 100644 --- a/Ryujinx.Graphics/Shader/IntermediateRepresentation/PhiNode.cs +++ b/Ryujinx.Graphics/Shader/IntermediateRepresentation/PhiNode.cs @@ -4,6 +4,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation { class PhiNode : INode { + public BasicBlock Parent { get; set; } + private Operand _dest; public Operand Dest diff --git a/Ryujinx.Graphics/Shader/StructuredIr/AstAssignment.cs b/Ryujinx.Graphics/Shader/StructuredIr/AstAssignment.cs index 5a4acb88a3..bb3fe7af4b 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/AstAssignment.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/AstAssignment.cs @@ -1,14 +1,35 @@ +using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper; + namespace Ryujinx.Graphics.Shader.StructuredIr { class AstAssignment : AstNode { public IAstNode Destination { get; } - public IAstNode Source { get; } + + private IAstNode _source; + + public IAstNode Source + { + get + { + return _source; + } + set + { + RemoveUse(_source, this); + + AddUse(value, this); + + _source = value; + } + } public AstAssignment(IAstNode destination, IAstNode source) { Destination = destination; Source = source; + + AddDef(destination, this); } } } \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/StructuredIr/AstBlock.cs b/Ryujinx.Graphics/Shader/StructuredIr/AstBlock.cs index 888ff38e8f..fdef87de56 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/AstBlock.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/AstBlock.cs @@ -1,20 +1,39 @@ -using System; using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using System; using System.Collections; using System.Collections.Generic; +using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper; + namespace Ryujinx.Graphics.Shader.StructuredIr { class AstBlock : AstNode, IEnumerable { - public AstBlockType Type { get; } + public AstBlockType Type { get; private set; } - public IAstNode Condition { get; private set; } + private IAstNode _condition; + + public IAstNode Condition + { + get + { + return _condition; + } + set + { + RemoveUse(_condition, this); + + AddUse(value, this); + + _condition = value; + } + } private LinkedList _nodes; public IAstNode First => _nodes.First?.Value; - public IAstNode Last => _nodes.Last?.Value; + + public int Count => _nodes.Count; public AstBlock(AstBlockType type, IAstNode condition = null) { @@ -34,14 +53,14 @@ namespace Ryujinx.Graphics.Shader.StructuredIr Add(node, _nodes.AddFirst(node)); } - public void AddBefore(IAstNode oldNode, IAstNode node) + public void AddBefore(IAstNode next, IAstNode node) { - Add(node, _nodes.AddBefore(oldNode.LLNode, node)); + Add(node, _nodes.AddBefore(next.LLNode, node)); } - public void AddAfter(IAstNode oldNode, IAstNode node) + public void AddAfter(IAstNode prev, IAstNode node) { - Add(node, _nodes.AddAfter(oldNode.LLNode, node)); + Add(node, _nodes.AddAfter(prev.LLNode, node)); } private void Add(IAstNode node, LinkedListNode newNode) @@ -72,6 +91,17 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { Condition = new AstOperation(Instruction.LogicalOr, Condition, cond); } + public void TurnIntoIf(IAstNode cond) + { + Condition = cond; + + Type = AstBlockType.If; + } + + public void TurnIntoElseIf() + { + Type = AstBlockType.ElseIf; + } public IEnumerator GetEnumerator() { diff --git a/Ryujinx.Graphics/Shader/StructuredIr/AstBlockType.cs b/Ryujinx.Graphics/Shader/StructuredIr/AstBlockType.cs index b607a3b21c..c12efda909 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/AstBlockType.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/AstBlockType.cs @@ -5,6 +5,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr DoWhile, If, Else, + ElseIf, Main, While } diff --git a/Ryujinx.Graphics/Shader/StructuredIr/AstBlockVisitor.cs b/Ryujinx.Graphics/Shader/StructuredIr/AstBlockVisitor.cs new file mode 100644 index 0000000000..1ee630bc71 --- /dev/null +++ b/Ryujinx.Graphics/Shader/StructuredIr/AstBlockVisitor.cs @@ -0,0 +1,64 @@ +using System; +using System.Collections.Generic; + +using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper; + +namespace Ryujinx.Graphics.Shader.StructuredIr +{ + class AstBlockVisitor + { + public AstBlock Block { get; private set; } + + public class BlockVisitationEventArgs : EventArgs + { + public AstBlock Block { get; } + + public BlockVisitationEventArgs(AstBlock block) + { + Block = block; + } + } + + public event EventHandler BlockEntered; + public event EventHandler BlockLeft; + + public AstBlockVisitor(AstBlock mainBlock) + { + Block = mainBlock; + } + + public IEnumerable Visit() + { + IAstNode node = Block.First; + + while (node != null) + { + //We reached a child block, visit the nodes inside. + while (node is AstBlock childBlock) + { + Block = childBlock; + + node = childBlock.First; + + BlockEntered?.Invoke(this, new BlockVisitationEventArgs(Block)); + } + + IAstNode next = Next(node); + + yield return node; + + node = next; + + //We reached the end of the list, go up on tree to the parent blocks. + while (node == null && Block.Type != AstBlockType.Main) + { + BlockLeft?.Invoke(this, new BlockVisitationEventArgs(Block)); + + node = Next(Block); + + Block = Block.Parent; + } + } + } + } +} \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/StructuredIr/AstDeclaration.cs b/Ryujinx.Graphics/Shader/StructuredIr/AstDeclaration.cs deleted file mode 100644 index 4d5e65c01f..0000000000 --- a/Ryujinx.Graphics/Shader/StructuredIr/AstDeclaration.cs +++ /dev/null @@ -1,12 +0,0 @@ -namespace Ryujinx.Graphics.Shader.StructuredIr -{ - class AstDeclaration : AstNode - { - public AstOperand Operand { get; } - - public AstDeclaration(AstOperand operand) - { - Operand = operand; - } - } -} \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/StructuredIr/AstHelper.cs b/Ryujinx.Graphics/Shader/StructuredIr/AstHelper.cs index 2dcb188e21..9d3148e1bb 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/AstHelper.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/AstHelper.cs @@ -4,6 +4,38 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { static class AstHelper { + public static void AddUse(IAstNode node, IAstNode parent) + { + if (node is AstOperand operand && operand.Type == OperandType.LocalVariable) + { + operand.Uses.Add(parent); + } + } + + public static void AddDef(IAstNode node, IAstNode parent) + { + if (node is AstOperand operand && operand.Type == OperandType.LocalVariable) + { + operand.Defs.Add(parent); + } + } + + public static void RemoveUse(IAstNode node, IAstNode parent) + { + if (node is AstOperand operand && operand.Type == OperandType.LocalVariable) + { + operand.Uses.Remove(parent); + } + } + + public static void RemoveDef(IAstNode node, IAstNode parent) + { + if (node is AstOperand operand && operand.Type == OperandType.LocalVariable) + { + operand.Defs.Remove(parent); + } + } + public static AstAssignment Assign(IAstNode destination, IAstNode source) { return new AstAssignment(destination, source); @@ -22,5 +54,20 @@ namespace Ryujinx.Graphics.Shader.StructuredIr return local; } + + public static IAstNode InverseCond(IAstNode cond) + { + return new AstOperation(Instruction.LogicalNot, cond); + } + + public static IAstNode Next(IAstNode node) + { + return node.LLNode.Next?.Value; + } + + public static IAstNode Previous(IAstNode node) + { + return node.LLNode.Previous?.Value; + } } } \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/StructuredIr/AstOperand.cs b/Ryujinx.Graphics/Shader/StructuredIr/AstOperand.cs index 5a5f3b7ee5..97ff3ca97c 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/AstOperand.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/AstOperand.cs @@ -1,9 +1,13 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using System.Collections.Generic; namespace Ryujinx.Graphics.Shader.StructuredIr { class AstOperand : AstNode { + public HashSet Defs { get; } + public HashSet Uses { get; } + public OperandType Type { get; } public VariableType VarType { get; set; } @@ -15,6 +19,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr private AstOperand() { + Defs = new HashSet(); + Uses = new HashSet(); + VarType = VariableType.S32; } diff --git a/Ryujinx.Graphics/Shader/StructuredIr/AstOperation.cs b/Ryujinx.Graphics/Shader/StructuredIr/AstOperation.cs index e8b699d4f9..8594fe6d64 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/AstOperation.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/AstOperation.cs @@ -1,17 +1,40 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper; + namespace Ryujinx.Graphics.Shader.StructuredIr { class AstOperation : AstNode { public Instruction Inst { get; } - public IAstNode[] Sources { get; } + private IAstNode[] _sources; + + public int SourcesCount => _sources.Length; public AstOperation(Instruction inst, params IAstNode[] sources) { - Inst = inst; - Sources = sources; + Inst = inst; + _sources = sources; + + foreach (IAstNode source in sources) + { + AddUse(source, this); + } + } + + public IAstNode GetSource(int index) + { + return _sources[index]; + } + + public void SetSource(int index, IAstNode source) + { + RemoveUse(_sources[index], this); + + AddUse(source, this); + + _sources[index] = source; } } } \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/StructuredIr/AstOptimizer.cs b/Ryujinx.Graphics/Shader/StructuredIr/AstOptimizer.cs new file mode 100644 index 0000000000..0f5392b7d6 --- /dev/null +++ b/Ryujinx.Graphics/Shader/StructuredIr/AstOptimizer.cs @@ -0,0 +1,149 @@ +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using System.Collections.Generic; +using System.Linq; + +using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper; + +namespace Ryujinx.Graphics.Shader.StructuredIr +{ + static class AstOptimizer + { + public static void Optimize(StructuredProgramInfo info) + { + AstBlock mainBlock = info.MainBlock; + + AstBlockVisitor visitor = new AstBlockVisitor(mainBlock); + + foreach (IAstNode node in visitor.Visit()) + { + if (node is AstAssignment assignment && assignment.Destination is AstOperand propVar) + { + bool isWorthPropagating = propVar.Uses.Count == 1 || IsWorthPropagating(assignment.Source); + + if (propVar.Defs.Count == 1 && isWorthPropagating) + { + PropagateExpression(propVar, assignment.Source); + } + + if (propVar.Type == OperandType.LocalVariable && propVar.Uses.Count == 0) + { + visitor.Block.Remove(assignment); + + info.Locals.Remove(propVar); + } + } + } + + RemoveEmptyBlocks(mainBlock); + } + + private static bool IsWorthPropagating(IAstNode source) + { + if (!(source is AstOperation srcOp)) + { + return false; + } + + if (!InstructionInfo.IsUnary(srcOp.Inst)) + { + return false; + } + + return srcOp.GetSource(0) is AstOperand || srcOp.Inst == Instruction.Copy; + } + + private static void PropagateExpression(AstOperand propVar, IAstNode source) + { + IAstNode[] uses = propVar.Uses.ToArray(); + + foreach (IAstNode useNode in uses) + { + if (useNode is AstBlock useBlock) + { + useBlock.Condition = source; + } + else if (useNode is AstOperation useOperation) + { + for (int srcIndex = 0; srcIndex < useOperation.SourcesCount; srcIndex++) + { + if (useOperation.GetSource(srcIndex) == propVar) + { + useOperation.SetSource(srcIndex, source); + } + } + } + else if (useNode is AstAssignment useAssignment) + { + useAssignment.Source = source; + } + } + } + + private static void RemoveEmptyBlocks(AstBlock mainBlock) + { + Queue pending = new Queue(); + + pending.Enqueue(mainBlock); + + while (pending.TryDequeue(out AstBlock block)) + { + foreach (IAstNode node in block) + { + if (node is AstBlock childBlock) + { + pending.Enqueue(childBlock); + } + } + + AstBlock parent = block.Parent; + + if (parent == null) + { + continue; + } + + AstBlock nextBlock = Next(block) as AstBlock; + + bool hasElse = nextBlock != null && nextBlock.Type == AstBlockType.Else; + + bool isIf = block.Type == AstBlockType.If; + + if (block.Count == 0) + { + if (isIf) + { + if (hasElse) + { + nextBlock.TurnIntoIf(InverseCond(block.Condition)); + } + + parent.Remove(block); + } + else if (block.Type == AstBlockType.Else) + { + parent.Remove(block); + } + } + else if (isIf && parent.Type == AstBlockType.Else && parent.Count == (hasElse ? 2 : 1)) + { + AstBlock parentOfParent = parent.Parent; + + parent.Remove(block); + + parentOfParent.AddAfter(parent, block); + + if (hasElse) + { + parent.Remove(nextBlock); + + parentOfParent.AddAfter(block, nextBlock); + } + + parentOfParent.Remove(parent); + + block.TurnIntoElseIf(); + } + } + } + } +} \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/StructuredIr/GotoElimination.cs b/Ryujinx.Graphics/Shader/StructuredIr/GotoElimination.cs index b9c9c8ff88..c2fd76055e 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/GotoElimination.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/GotoElimination.cs @@ -2,6 +2,8 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using System; using System.Collections.Generic; +using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper; + namespace Ryujinx.Graphics.Shader.StructuredIr { static class GotoElimination @@ -290,14 +292,14 @@ namespace Ryujinx.Graphics.Shader.StructuredIr private static bool ContainsCondComb(IAstNode node, Instruction inst, IAstNode newCond) { - while (node is AstOperation operation && operation.Sources.Length == 2) + while (node is AstOperation operation && operation.SourcesCount == 2) { - if (operation.Inst == inst && IsSameCond(operation.Sources[1], newCond)) + if (operation.Inst == inst && IsSameCond(operation.GetSource(1), newCond)) { return true; } - node = operation.Sources[0]; + node = operation.GetSource(0); } return false; @@ -396,8 +398,8 @@ namespace Ryujinx.Graphics.Shader.StructuredIr return false; } - lCond = lCondOp.Sources[0]; - rCond = rCondOp.Sources[0]; + lCond = lCondOp.GetSource(0); + rCond = rCondOp.GetSource(0); } return lCond == rCond; @@ -447,20 +449,5 @@ namespace Ryujinx.Graphics.Shader.StructuredIr return level; } - - private static IAstNode InverseCond(IAstNode cond) - { - return new AstOperation(Instruction.LogicalNot, cond); - } - - private static IAstNode Next(IAstNode node) - { - return node.LLNode.Next?.Value; - } - - private static IAstNode Previous(IAstNode node) - { - return node.LLNode.Previous?.Value; - } } } \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/StructuredIr/InstructionInfo.cs b/Ryujinx.Graphics/Shader/StructuredIr/InstructionInfo.cs index 2d7fa1f281..3130200eeb 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/InstructionInfo.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/InstructionInfo.cs @@ -77,6 +77,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr Add(Instruction.ReciprocalSquareRoot, VariableType.Scalar, VariableType.Scalar); Add(Instruction.Sine, VariableType.Scalar, VariableType.Scalar); Add(Instruction.SquareRoot, VariableType.Scalar, VariableType.Scalar); + Add(Instruction.Subtract, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); Add(Instruction.Truncate, VariableType.F32, VariableType.F32); } @@ -129,5 +130,20 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { return inst == Instruction.TextureSample; } + + public static bool IsUnary(Instruction inst) + { + if (inst == Instruction.TextureSample) + { + return false; + } + + if (inst == Instruction.Copy) + { + return true; + } + + return _infoTbl[(int)(inst & Instruction.Mask)].SrcTypes.Length == 1; + } } } \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgram.cs b/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgram.cs index 6684768819..888bdcc853 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgram.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgram.cs @@ -20,15 +20,22 @@ namespace Ryujinx.Graphics.Shader.StructuredIr foreach (INode node in block.Operations) { - AddOperation(context, (Operation)node); - } + Operation operation = (Operation)node; - context.LeaveBlock(block); + if (IsBranchInst(operation.Inst)) + { + context.LeaveBlock(block, operation); + } + else + { + AddOperation(context, operation); + } + } } GotoElimination.Eliminate(context.GetGotos()); - context.PrependLocalDeclarations(); + AstOptimizer.Optimize(context.Info); return context.Info; } @@ -44,13 +51,20 @@ namespace Ryujinx.Graphics.Shader.StructuredIr sources[index] = context.GetOperandUse(operation.GetSource(index)); } - if (operation.Dest != null && !IsBranchInst(inst)) + if (operation.Dest != null) { AstOperand dest = context.GetOperandDef(operation.Dest); if (inst == Instruction.LoadConstant) { - context.Info.ConstantBuffers.Add((sources[0] as AstOperand).Value); + Operand ldcSource = operation.GetSource(0); + + if (ldcSource.Type != OperandType.Constant) + { + throw new InvalidOperationException("Found LDC with non-constant constant buffer slot."); + } + + context.Info.ConstantBuffers.Add(ldcSource.Value); } AstAssignment assignment; diff --git a/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgramContext.cs b/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgramContext.cs index ebfffd6101..0cf438c396 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgramContext.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgramContext.cs @@ -14,8 +14,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr private Dictionary _localsMap; - private List _locals; - private Dictionary _gotoTempAsgs; private List _gotos; @@ -34,8 +32,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr _localsMap = new Dictionary(); - _locals = new List(); - _gotoTempAsgs = new Dictionary(); _gotos = new List(); @@ -62,9 +58,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr LookForDoWhileStatements(block); } - public void LeaveBlock(BasicBlock block) + public void LeaveBlock(BasicBlock block, Operation branchOp) { - LookForIfStatements(block); + LookForIfStatements(block, branchOp); } private void LookForDoWhileStatements(BasicBlock block) @@ -99,38 +95,30 @@ namespace Ryujinx.Graphics.Shader.StructuredIr } } - private void LookForIfStatements(BasicBlock block) + private void LookForIfStatements(BasicBlock block, Operation branchOp) { if (block.Branch == null) { return; } - Operation branchOp = (Operation)block.GetLastOp(); - bool isLoop = block.Branch.Index <= block.Index; - AstOperation branch = _currBlock.Last as AstOperation; - if (block.Branch.Index <= _currEndIndex && !isLoop) { - _currBlock.Remove(branch); - NewBlock(AstBlockType.If, branchOp, block.Branch.Index); } - else if (_loopTails.Contains(block)) - { - //Loop handled by "LookForDoWhileStatements". - //We can safely remove the branch as it was already taken care of. - _currBlock.Remove(branch); - } - else + else if (!_loopTails.Contains(block)) { AstAssignment gotoTempAsg = GetGotoTempAsg(block.Branch.Index); IAstNode cond = GetBranchCond(AstBlockType.DoWhile, branchOp); - _currBlock.AddBefore(branch, Assign(gotoTempAsg.Destination, cond)); + AddNode(Assign(gotoTempAsg.Destination, cond)); + + AstOperation branch = new AstOperation(branchOp.Inst); + + AddNode(branch); GotoStatement gotoStmt = new GotoStatement(branch, gotoTempAsg, isLoop); @@ -214,29 +202,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr _currBlock.Add(node); } - public void PrependLocalDeclarations() - { - AstBlock mainBlock = Info.MainBlock; - - AstDeclaration decl = null; - - foreach (AstOperand operand in _locals) - { - AstDeclaration oldDecl = decl; - - decl = new AstDeclaration(operand); - - if (oldDecl == null) - { - mainBlock.AddFirst(decl); - } - else - { - mainBlock.AddAfter(oldDecl, decl); - } - } - } - public GotoStatement[] GetGotos() { return _gotos.ToArray(); @@ -246,7 +211,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { AstOperand newTemp = Local(type); - _locals.Add(newTemp); + Info.Locals.Add(newTemp); return newTemp; } @@ -293,7 +258,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr _localsMap.Add(operand, astOperand); - _locals.Add(astOperand); + Info.Locals.Add(astOperand); } return astOperand; diff --git a/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgramInfo.cs b/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgramInfo.cs index e7dc799d95..d8f4377692 100644 --- a/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgramInfo.cs +++ b/Ryujinx.Graphics/Shader/StructuredIr/StructuredProgramInfo.cs @@ -7,6 +7,8 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { public AstBlock MainBlock { get; } + public HashSet Locals { get; } + public HashSet ConstantBuffers { get; } public HashSet IAttributes { get; } @@ -18,6 +20,8 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { MainBlock = mainBlock; + Locals = new HashSet(); + ConstantBuffers = new HashSet(); IAttributes = new HashSet(); diff --git a/Ryujinx.Graphics/Shader/Translation/Optimizations/BranchElimination.cs b/Ryujinx.Graphics/Shader/Translation/Optimizations/BranchElimination.cs new file mode 100644 index 0000000000..2b0f19052b --- /dev/null +++ b/Ryujinx.Graphics/Shader/Translation/Optimizations/BranchElimination.cs @@ -0,0 +1,64 @@ +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using System; + +namespace Ryujinx.Graphics.Shader.Translation.Optimizations +{ + static class BranchElimination + { + public static bool Eliminate(BasicBlock block) + { + if (block.HasBranch && IsRedundantBranch((Operation)block.GetLastOp(), Next(block))) + { + block.Branch = null; + + return true; + } + + return false; + } + + private static bool IsRedundantBranch(Operation current, BasicBlock nextBlock) + { + //Here we check that: + //- The current block ends with a branch. + //- The next block only contains a branch. + //- The branch on the next block is unconditional. + //- Both branches are jumping to the same location. + //In this case, the branch on the current block can be removed, + //as the next block is going to jump to the same place anyway. + if (nextBlock == null) + { + return false; + } + + if (!(nextBlock.Operations.First?.Value is Operation next)) + { + return false; + } + + if (next.Inst != Instruction.Branch) + { + return false; + } + + return current.Dest == next.Dest; + } + + private static BasicBlock Next(BasicBlock block) + { + block = block.Next; + + while (block != null && block.Operations.Count == 0) + { + if (block.HasBranch) + { + throw new InvalidOperationException("Found a bogus empty block that \"ends with a branch\"."); + } + + block = block.Next; + } + + return block; + } + } +} \ No newline at end of file diff --git a/Ryujinx.Graphics/Shader/Translation/Optimizations/Optimizer.cs b/Ryujinx.Graphics/Shader/Translation/Optimizations/Optimizer.cs index 0af5d260b1..e28361a0e2 100644 --- a/Ryujinx.Graphics/Shader/Translation/Optimizations/Optimizer.cs +++ b/Ryujinx.Graphics/Shader/Translation/Optimizations/Optimizer.cs @@ -54,6 +54,13 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations node = nextNode; } + + if (BranchElimination.Eliminate(block)) + { + RemoveNode(block, block.Operations.Last); + + modified = true; + } } } while (modified);