Remove redundant branches, add expression propagation and other improvements on the code

This commit is contained in:
gdkchan 2019-04-04 17:37:23 -03:00
parent c561235925
commit baecbfb31a
24 changed files with 833 additions and 371 deletions

View file

@ -67,6 +67,29 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
}
}
public static void DeclareLocals(CodeGenContext context, StructuredProgramInfo prgInfo)
{
foreach (AstOperand decl in prgInfo.Locals)
{
string name = context.DeclareLocal(decl);
context.AppendLine(GetVarTypeName(decl.VarType) + " " + name + ";");
}
}
private static string GetVarTypeName(VariableType type)
{
switch (type)
{
case VariableType.Bool: return "bool";
case VariableType.F32: return "float";
case VariableType.S32: return "int";
case VariableType.U32: return "uint";
}
throw new ArgumentException($"Invalid variable type \"{type}\".");
}
private static void DeclareUniforms(CodeGenContext context, StructuredProgramInfo prgInfo)
{
foreach (int cbufSlot in prgInfo.ConstantBuffers.OrderBy(x => x))

View file

@ -9,47 +9,73 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
class GlslGenerator
{
public string Generate(StructuredProgramInfo prgInfo, GalShaderType shaderType)
public string Generate(StructuredProgramInfo info, GalShaderType shaderType)
{
CodeGenContext cgContext = new CodeGenContext(prgInfo, shaderType);
CodeGenContext context = new CodeGenContext(info, shaderType);
Declarations.Declare(cgContext, prgInfo);
Declarations.Declare(context, info);
PrintBlock(cgContext, prgInfo.MainBlock);
PrintMainBlock(context, info);
return cgContext.GetCode();
return context.GetCode();
}
private void PrintMainBlock(CodeGenContext context, StructuredProgramInfo info)
{
context.AppendLine("void main()");
context.EnterScope();
Declarations.DeclareLocals(context, info);
PrintBlock(context, info.MainBlock);
context.LeaveScope();
}
private void PrintBlock(CodeGenContext context, AstBlock block)
{
switch (block.Type)
AstBlockVisitor visitor = new AstBlockVisitor(block);
visitor.BlockEntered += (sender, e) =>
{
case AstBlockType.DoWhile:
context.AppendLine("do");
break;
case AstBlockType.Else:
context.AppendLine("else");
break;
case AstBlockType.If:
context.AppendLine($"if ({GetCondExpr(context, block.Condition)})");
break;
case AstBlockType.Main:
context.AppendLine("void main()");
break;
}
context.EnterScope();
foreach (IAstNode node in block)
{
if (node is AstBlock subBlock)
switch (e.Block.Type)
{
PrintBlock(context, subBlock);
case AstBlockType.DoWhile:
context.AppendLine("do");
break;
case AstBlockType.Else:
context.AppendLine("else");
break;
case AstBlockType.ElseIf:
context.AppendLine($"else if ({GetCondExpr(context, e.Block.Condition)})");
break;
case AstBlockType.If:
context.AppendLine($"if ({GetCondExpr(context, e.Block.Condition)})");
break;
default: throw new InvalidOperationException($"Found unexpected block type \"{e.Block.Type}\".");
}
else if (node is AstOperation operation)
context.EnterScope();
};
visitor.BlockLeft += (sender, e) =>
{
context.LeaveScope();
if (e.Block.Type == AstBlockType.DoWhile)
{
context.AppendLine($"while ({GetCondExpr(context, e.Block.Condition)});");
}
};
foreach (IAstNode node in visitor.Visit())
{
if (node is AstOperation operation)
{
if (operation.Inst == Instruction.Return)
{
@ -58,60 +84,38 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
context.AppendLine(Instructions.GetExpression(context, operation) + ";");
}
else if (node is AstAssignment asg)
else if (node is AstAssignment assignment)
{
VariableType srcType = OperandManager.GetNodeDestType(asg.Source);
VariableType dstType = OperandManager.GetNodeDestType(asg.Destination);
VariableType srcType = OperandManager.GetNodeDestType(assignment.Source);
VariableType dstType = OperandManager.GetNodeDestType(assignment.Destination);
string dest;
if (asg.Destination is AstOperand operand && operand.Type == OperandType.Attribute)
if (assignment.Destination is AstOperand operand && operand.Type == OperandType.Attribute)
{
dest = OperandManager.GetOutAttributeName(context, operand);
}
else
{
dest = Instructions.GetExpression(context, asg.Destination);
dest = Instructions.GetExpression(context, assignment.Destination);
}
string src = ReinterpretCast(context, asg.Source, srcType, dstType);
string src = ReinterpretCast(context, assignment.Source, srcType, dstType);
context.AppendLine(dest + " = " + src + ";");
}
else if (node is AstDeclaration decl && decl.Operand.Type != OperandType.Undefined)
else
{
string name = context.DeclareLocal(decl.Operand);
context.AppendLine(GetVarTypeName(decl.Operand.VarType) + " " + name + ";");
throw new InvalidOperationException($"Found unexpected node type \"{node?.GetType().Name ?? "null"}\".");
}
}
context.LeaveScope();
if (block.Type == AstBlockType.DoWhile)
{
context.AppendLine($"while ({GetCondExpr(context, block.Condition)});");
}
}
private static string GetCondExpr(CodeGenContext context, IAstNode cond)
{
VariableType srcType = OperandManager.GetNodeDestType(cond);
return ReinterpretCast(Instructions.GetExpression(context, cond), srcType, VariableType.Bool);
}
private string GetVarTypeName(VariableType type)
{
switch (type)
{
case VariableType.Bool: return "bool";
case VariableType.F32: return "float";
case VariableType.S32: return "int";
case VariableType.U32: return "uint";
}
throw new ArgumentException($"Invalid variable type \"{type}\".");
return ReinterpretCast(context, cond, srcType, VariableType.Bool);
}
private static void PrepareForReturn(CodeGenContext context)

View file

@ -9,6 +9,112 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
static class Instructions
{
[Flags]
private enum InstFlags
{
OpNullary = 0,
OpUnary = 1,
OpBinary = 2,
OpTernary = 3,
CallNullary = Call | 0,
CallUnary = Call | 1,
CallBinary = Call | 2,
CallTernary = Call | 3,
CallQuaternary = Call | 4,
Call = 1 << 8,
ArityMask = 0xff
}
private struct InstInfo
{
public InstFlags Flags { get; }
public string OpName { get; }
public int Precedence { get; }
public InstInfo(InstFlags flags, string opName, int precedence)
{
Flags = flags;
OpName = opName;
Precedence = precedence;
}
}
private static InstInfo[] _infoTbl;
static Instructions()
{
_infoTbl = new InstInfo[(int)Instruction.Count];
Add(Instruction.Absolute, InstFlags.CallUnary, "abs");
Add(Instruction.Add, InstFlags.OpBinary, "+", 2);
Add(Instruction.BitfieldExtractS32, InstFlags.CallTernary, "bitfieldExtract");
Add(Instruction.BitfieldExtractU32, InstFlags.CallTernary, "bitfieldExtract");
Add(Instruction.BitfieldInsert, InstFlags.CallQuaternary, "bitfieldInsert");
Add(Instruction.BitfieldReverse, InstFlags.CallUnary, "bitfieldReverse");
Add(Instruction.BitwiseAnd, InstFlags.OpBinary, "&", 6);
Add(Instruction.BitwiseExclusiveOr, InstFlags.OpBinary, "^", 7);
Add(Instruction.BitwiseNot, InstFlags.OpUnary, "~", 0);
Add(Instruction.BitwiseOr, InstFlags.OpBinary, "|", 8);
Add(Instruction.Ceiling, InstFlags.CallUnary, "ceil");
Add(Instruction.Clamp, InstFlags.CallTernary, "clamp");
Add(Instruction.ClampU32, InstFlags.CallTernary, "clamp");
Add(Instruction.CompareEqual, InstFlags.OpBinary, "==", 5);
Add(Instruction.CompareGreater, InstFlags.OpBinary, ">", 4);
Add(Instruction.CompareGreaterOrEqual, InstFlags.OpBinary, ">=", 4);
Add(Instruction.CompareGreaterOrEqualU32, InstFlags.OpBinary, ">=", 4);
Add(Instruction.CompareGreaterU32, InstFlags.OpBinary, ">", 4);
Add(Instruction.CompareLess, InstFlags.OpBinary, "<", 4);
Add(Instruction.CompareLessOrEqual, InstFlags.OpBinary, "<=", 4);
Add(Instruction.CompareLessOrEqualU32, InstFlags.OpBinary, "<=", 4);
Add(Instruction.CompareLessU32, InstFlags.OpBinary, "<", 4);
Add(Instruction.CompareNotEqual, InstFlags.OpBinary, "!=", 5);
Add(Instruction.ConditionalSelect, InstFlags.OpTernary, "?:", 12);
Add(Instruction.ConvertFPToS32, InstFlags.CallUnary, "int");
Add(Instruction.ConvertS32ToFP, InstFlags.CallUnary, "float");
Add(Instruction.ConvertU32ToFP, InstFlags.CallUnary, "float");
Add(Instruction.Cosine, InstFlags.CallUnary, "cos");
Add(Instruction.Discard, InstFlags.OpNullary, "discard");
Add(Instruction.Divide, InstFlags.OpBinary, "/", 1);
Add(Instruction.EmitVertex, InstFlags.CallNullary, "EmitVertex");
Add(Instruction.EndPrimitive, InstFlags.CallNullary, "EndPrimitive");
Add(Instruction.ExponentB2, InstFlags.CallUnary, "exp2");
Add(Instruction.Floor, InstFlags.CallUnary, "floor");
Add(Instruction.FusedMultiplyAdd, InstFlags.CallTernary, "fma");
Add(Instruction.IsNan, InstFlags.CallUnary, "isnan");
Add(Instruction.LoadConstant, InstFlags.Call);
Add(Instruction.LogarithmB2, InstFlags.CallUnary, "log2");
Add(Instruction.LogicalAnd, InstFlags.OpBinary, "&&", 9);
Add(Instruction.LogicalExclusiveOr, InstFlags.OpBinary, "^^", 10);
Add(Instruction.LogicalNot, InstFlags.OpUnary, "!", 0);
Add(Instruction.LogicalOr, InstFlags.OpBinary, "||", 11);
Add(Instruction.ShiftLeft, InstFlags.OpBinary, "<<", 3);
Add(Instruction.ShiftRightS32, InstFlags.OpBinary, ">>", 3);
Add(Instruction.ShiftRightU32, InstFlags.OpBinary, ">>", 3);
Add(Instruction.Maximum, InstFlags.CallBinary, "max");
Add(Instruction.MaximumU32, InstFlags.CallBinary, "max");
Add(Instruction.Minimum, InstFlags.CallBinary, "min");
Add(Instruction.MinimumU32, InstFlags.CallBinary, "min");
Add(Instruction.Multiply, InstFlags.OpBinary, "*", 1);
Add(Instruction.Negate, InstFlags.OpUnary, "-", 0);
Add(Instruction.ReciprocalSquareRoot, InstFlags.CallUnary, "inversesqrt");
Add(Instruction.Return, InstFlags.OpNullary, "return");
Add(Instruction.Sine, InstFlags.CallUnary, "sin");
Add(Instruction.SquareRoot, InstFlags.CallUnary, "sqrt");
Add(Instruction.Subtract, InstFlags.OpBinary, "-", 2);
Add(Instruction.TextureSample, InstFlags.Call);
Add(Instruction.Truncate, InstFlags.CallUnary, "trunc");
}
private static void Add(Instruction inst, InstFlags flags, string opName = null, int precedence = 0)
{
_infoTbl[(int)inst] = new InstInfo(flags, opName, precedence);
}
public static string GetExpression(CodeGenContext context, IAstNode node)
{
if (node is AstOperation operation)
@ -43,233 +149,114 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
private static string GetExpression(CodeGenContext context, AstOperation operation)
{
switch (operation.Inst & Instruction.Mask)
Instruction inst = operation.Inst & Instruction.Mask;
switch (inst)
{
case Instruction.Absolute:
return GetUnaryCallExpr(context, operation, "abs");
case Instruction.Add:
return GetBinaryExpr(context, operation, "+");
case Instruction.BitfieldExtractS32:
case Instruction.BitfieldExtractU32:
return GetTernaryCallExpr(context, operation, "bitfieldExtract");
case Instruction.BitfieldInsert:
return GetQuaternaryCallExpr(context, operation, "bitfieldInsert");
case Instruction.BitfieldReverse:
return GetUnaryCallExpr(context, operation, "bitfieldReverse");
case Instruction.BitwiseAnd:
return GetBinaryExpr(context, operation, "&");
case Instruction.BitwiseExclusiveOr:
return GetBinaryExpr(context, operation, "^");
case Instruction.BitwiseNot:
return GetUnaryExpr(context, operation, "~");
case Instruction.BitwiseOr:
return GetBinaryExpr(context, operation, "|");
case Instruction.Ceiling:
return GetUnaryCallExpr(context, operation, "ceil");
case Instruction.CompareEqual:
return GetBinaryExpr(context, operation, "==");
case Instruction.CompareGreater:
case Instruction.CompareGreaterU32:
return GetBinaryExpr(context, operation, ">");
case Instruction.CompareGreaterOrEqual:
case Instruction.CompareGreaterOrEqualU32:
return GetBinaryExpr(context, operation, ">=");
case Instruction.CompareLess:
case Instruction.CompareLessU32:
return GetBinaryExpr(context, operation, "<");
case Instruction.CompareLessOrEqual:
case Instruction.CompareLessOrEqualU32:
return GetBinaryExpr(context, operation, "<=");
case Instruction.CompareNotEqual:
return GetBinaryExpr(context, operation, "!=");
case Instruction.ConditionalSelect:
return GetConditionalSelectExpr(context, operation);
case Instruction.Cosine:
return GetUnaryCallExpr(context, operation, "cos");
case Instruction.Clamp:
case Instruction.ClampU32:
return GetTernaryCallExpr(context, operation, "clamp");
case Instruction.ConvertFPToS32:
return GetUnaryCallExpr(context, operation, "int");
case Instruction.ConvertS32ToFP:
case Instruction.ConvertU32ToFP:
return GetUnaryCallExpr(context, operation, "float");
case Instruction.Discard:
return "discard";
case Instruction.Divide:
return GetBinaryExpr(context, operation, "/");
case Instruction.EmitVertex:
return "EmitVertex()";
case Instruction.EndPrimitive:
return "EndPrimitive()";
case Instruction.ExponentB2:
return GetUnaryCallExpr(context, operation, "exp2");
case Instruction.Floor:
return GetUnaryCallExpr(context, operation, "floor");
case Instruction.FusedMultiplyAdd:
return GetTernaryCallExpr(context, operation, "fma");
case Instruction.IsNan:
return GetUnaryCallExpr(context, operation, "isnan");
case Instruction.LoadConstant:
return GetLoadConstantExpr(context, operation);
case Instruction.LogarithmB2:
return GetUnaryCallExpr(context, operation, "log2");
case Instruction.LogicalAnd:
return GetBinaryExpr(context, operation, "&&");
case Instruction.LogicalExclusiveOr:
return GetBinaryExpr(context, operation, "^^");
case Instruction.LogicalNot:
return GetUnaryExpr(context, operation, "!");
case Instruction.LogicalOr:
return GetBinaryExpr(context, operation, "||");
case Instruction.LoopBreak:
return "break";
case Instruction.LoopContinue:
return "continue";
case Instruction.Maximum:
case Instruction.MaximumU32:
return GetBinaryCallExpr(context, operation, "max");
case Instruction.Minimum:
case Instruction.MinimumU32:
return GetBinaryCallExpr(context, operation, "min");
case Instruction.Multiply:
return GetBinaryExpr(context, operation, "*");
case Instruction.Negate:
return GetUnaryExpr(context, operation, "-");
case Instruction.ReciprocalSquareRoot:
return GetUnaryCallExpr(context, operation, "inversesqrt");
case Instruction.Return:
return "return";
case Instruction.ShiftLeft:
return GetBinaryExpr(context, operation, "<<");
case Instruction.ShiftRightS32:
case Instruction.ShiftRightU32:
return GetBinaryExpr(context, operation, ">>");
case Instruction.Sine:
return GetUnaryCallExpr(context, operation, "sin");
case Instruction.SquareRoot:
return GetUnaryCallExpr(context, operation, "sqrt");
case Instruction.Subtract:
return GetBinaryExpr(context, operation, "-");
case Instruction.TextureSample:
return GetTextureSampleExpr(context, operation);
case Instruction.Truncate:
return GetUnaryCallExpr(context, operation, "trunc");
}
throw new ArgumentException($"Operation has invalid instruction \"{operation.Inst}\".");
}
InstInfo info = _infoTbl[(int)inst];
private static string GetUnaryCallExpr(CodeGenContext context, AstOperation operation, string funcName)
{
return funcName + "(" + GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ")";
}
if ((info.Flags & InstFlags.Call) != 0)
{
int arity = (int)(info.Flags & InstFlags.ArityMask);
private static string GetBinaryCallExpr(CodeGenContext context, AstOperation operation, string funcName)
{
return funcName + "(" +
GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ", " +
GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)) + ")";
}
string args = string.Empty;
private static string GetTernaryCallExpr(CodeGenContext context, AstOperation operation, string funcName)
{
return funcName + "(" +
GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ", " +
GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)) + ", " +
GetSoureExpr(context, operation.Sources[2], GetSrcVarType(operation.Inst, 2)) + ")";
}
for (int argIndex = 0; argIndex < arity; argIndex++)
{
if (argIndex != 0)
{
args += ", ";
}
private static string GetQuaternaryCallExpr(CodeGenContext context, AstOperation operation, string funcName)
{
return funcName + "(" +
GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ", " +
GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)) + ", " +
GetSoureExpr(context, operation.Sources[2], GetSrcVarType(operation.Inst, 2)) + ", " +
GetSoureExpr(context, operation.Sources[3], GetSrcVarType(operation.Inst, 3)) + ")";
}
VariableType dstType = GetSrcVarType(operation.Inst, argIndex);
private static string GetUnaryExpr(CodeGenContext context, AstOperation operation, string op)
{
return op + GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0));
}
args += GetSoureExpr(context, operation.GetSource(argIndex), dstType);
}
private static string GetBinaryExpr(CodeGenContext context, AstOperation operation, string op)
{
return GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + " " + op + " " +
GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1));
}
return info.OpName + "(" + args + ")";
}
else
{
if (info.Flags == InstFlags.OpNullary)
{
return info.OpName;
}
else if (info.Flags == InstFlags.OpUnary)
{
IAstNode src = operation.GetSource(0);
private static string GetConditionalSelectExpr(CodeGenContext context, AstOperation operation)
{
return "((" +
GetSoureExpr(context, operation.Sources[0], GetSrcVarType(operation.Inst, 0)) + ") ? (" +
GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1)) + ") : (" +
GetSoureExpr(context, operation.Sources[2], GetSrcVarType(operation.Inst, 2)) + "))";
string expr = GetSoureExpr(context, src, GetSrcVarType(operation.Inst, 0));
return info.OpName + Enclose(expr, src, info);
}
else if (info.Flags == InstFlags.OpBinary)
{
IAstNode src0 = operation.GetSource(0);
IAstNode src1 = operation.GetSource(1);
string expr0 = GetSoureExpr(context, src0, GetSrcVarType(operation.Inst, 0));
string expr1 = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 1));
expr0 = Enclose(expr0, src0, info, isLhs: true);
expr1 = Enclose(expr1, src1, info, isLhs: false);
return expr0 + " " + info.OpName + " " + expr1;
}
else if (info.Flags == InstFlags.OpTernary)
{
IAstNode src0 = operation.GetSource(0);
IAstNode src1 = operation.GetSource(1);
IAstNode src2 = operation.GetSource(2);
string expr0 = GetSoureExpr(context, src0, GetSrcVarType(operation.Inst, 0));
string expr1 = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 1));
string expr2 = GetSoureExpr(context, src2, GetSrcVarType(operation.Inst, 2));
expr0 = Enclose(expr0, src0, info);
expr1 = Enclose(expr1, src1, info);
expr2 = Enclose(expr2, src2, info);
char op0 = info.OpName[0];
char op1 = info.OpName[1];
return expr0 + " " + op0 + " " + expr1 + " " + op1 + " " + expr2;
}
else
{
throw new InvalidOperationException($"Unexpected instruction flags \"{info.Flags}\".");
}
}
}
private static string GetLoadConstantExpr(CodeGenContext context, AstOperation operation)
{
string offsetExpr = GetSoureExpr(context, operation.Sources[1], GetSrcVarType(operation.Inst, 1));
IAstNode src1 = operation.GetSource(1);
return OperandManager.GetConstantBufferName(context, operation.Sources[0], offsetExpr);
string offsetExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 1));
offsetExpr = Enclose(offsetExpr, src1, Instruction.ShiftRightS32, isLhs: true);
return OperandManager.GetConstantBufferName(context, operation.GetSource(0), offsetExpr);
}
private static string GetTextureSampleExpr(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
bool isGather = (texOp.Flags & TextureFlags.Gather) != 0;
bool isShadow = (texOp.Type & TextureType.Shadow) != 0;
bool isGather = (texOp.Flags & TextureFlags.Gather) != 0;
bool hasLodBias = (texOp.Flags & TextureFlags.LodBias) != 0;
bool hasLodLevel = (texOp.Flags & TextureFlags.LodLevel) != 0;
bool hasOffset = (texOp.Flags & TextureFlags.Offset) != 0;
bool hasOffsets = (texOp.Flags & TextureFlags.Offsets) != 0;
bool isArray = (texOp.Type & TextureType.Array) != 0;
bool isShadow = (texOp.Type & TextureType.Shadow) != 0;
string samplerName = OperandManager.GetSamplerName(context.ShaderType, texOp.TextureHandle);
@ -280,16 +267,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
texCall += "Gather";
}
if ((texOp.Flags & TextureFlags.LodLevel) != 0)
if (hasLodLevel)
{
texCall += "Lod";
}
if ((texOp.Flags & TextureFlags.Offset) != 0)
if (hasOffset)
{
texCall += "Offset";
}
else if ((texOp.Flags & TextureFlags.Offsets) != 0)
else if (hasOffsets)
{
texCall += "Offsets";
}
@ -314,7 +301,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
int arrayIndexElem = -1;
if ((texOp.Type & TextureType.Array) != 0)
if (isArray)
{
arrayIndexElem = pCount++;
}
@ -338,7 +325,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
string Src(VariableType type)
{
return GetSoureExpr(context, texOp.Sources[srcIndex++], type);
return GetSoureExpr(context, texOp.GetSource(srcIndex++), type);
}
string AssembleVector(int count, VariableType type, bool isP = false)
@ -376,16 +363,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
texCall += ", " + Src(VariableType.F32);
}
if ((texOp.Flags & TextureFlags.LodLevel) != 0)
if (hasLodLevel)
{
texCall += ", " + Src(VariableType.F32);
}
if ((texOp.Flags & TextureFlags.Offset) != 0)
if (hasOffset)
{
texCall += ", " + AssembleVector(elemsCount, VariableType.S32);
}
else if ((texOp.Flags & TextureFlags.Offsets) != 0)
else if (hasOffsets)
{
const int gatherTexelsCount = 4;
@ -404,7 +391,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
texCall += ")";
}
if ((texOp.Flags & TextureFlags.LodBias) != 0)
if (hasLodBias && !hasLodLevel)
{
texCall += ", " + Src(VariableType.F32);
}
@ -434,5 +421,56 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{
return ReinterpretCast(context, node, OperandManager.GetNodeDestType(node), dstType);
}
public static string Enclose(string expr, IAstNode node, Instruction inst, bool isLhs)
{
InstInfo info = _infoTbl[(int)(inst & Instruction.Mask)];
return Enclose(expr, node, info, isLhs);
}
private static string Enclose(string expr, IAstNode node, InstInfo pInfo, bool isLhs = false)
{
if (NeedsParenthesis(node, pInfo, isLhs))
{
expr = "(" + expr + ")";
}
return expr;
}
private static bool NeedsParenthesis(IAstNode node, InstInfo pInfo, bool isLhs)
{
//If the node isn't a operation, then it can only be a operand,
//and those never needs to be surrounded in parenthesis.
if (!(node is AstOperation operation))
{
return false;
}
if ((pInfo.Flags & InstFlags.Call) != 0)
{
return false;
}
InstInfo info = _infoTbl[(int)(operation.Inst & Instruction.Mask)];
if ((info.Flags & InstFlags.Call) != 0)
{
return false;
}
if (info.Precedence < pInfo.Precedence)
{
return false;
}
if (info.Precedence == pInfo.Precedence && isLhs)
{
return false;
}
return true;
}
}
}

View file

@ -22,10 +22,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
string expr = Instructions.GetExpression(context, node);
return ReinterpretCast(expr, srcType, dstType);
return ReinterpretCast(expr, node, srcType, dstType);
}
public static string ReinterpretCast(string expr, VariableType srcType, VariableType dstType)
private static string ReinterpretCast(string expr, IAstNode node, VariableType srcType, VariableType dstType)
{
if (srcType == dstType)
{
@ -55,7 +55,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
}
else if (dstType == VariableType.Bool)
{
return $"(({expr}) != 0)";
expr = Instructions.Enclose(expr, node, Instruction.CompareNotEqual, isLhs: true);
return $"({expr} != 0)";
}
else if (dstType == VariableType.S32)
{

View file

@ -24,6 +24,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
set => _branch = AddSuccessor(_branch, value);
}
public bool HasBranch => _branch != null;
public List<BasicBlock> Predecessors { get; }
public HashSet<BasicBlock> DominanceFrontiers { get; }
@ -47,7 +49,7 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
private BasicBlock AddSuccessor(BasicBlock oldBlock, BasicBlock newBlock)
{
oldBlock?.Predecessors.Remove(this);
newBlock.Predecessors.Add(this);
newBlock?.Predecessors.Add(this);
return newBlock;
}

View file

@ -2,6 +2,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
{
interface INode
{
BasicBlock Parent { get; set; }
Operand Dest { get; set; }
int SourcesCount { get; }

View file

@ -2,6 +2,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
{
class Operation : INode
{
public BasicBlock Parent { get; set; }
public Instruction Inst { get; private set; }
private Operand _dest;

View file

@ -4,6 +4,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
{
class PhiNode : INode
{
public BasicBlock Parent { get; set; }
private Operand _dest;
public Operand Dest

View file

@ -1,14 +1,35 @@
using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper;
namespace Ryujinx.Graphics.Shader.StructuredIr
{
class AstAssignment : AstNode
{
public IAstNode Destination { get; }
public IAstNode Source { get; }
private IAstNode _source;
public IAstNode Source
{
get
{
return _source;
}
set
{
RemoveUse(_source, this);
AddUse(value, this);
_source = value;
}
}
public AstAssignment(IAstNode destination, IAstNode source)
{
Destination = destination;
Source = source;
AddDef(destination, this);
}
}
}

View file

@ -1,20 +1,39 @@
using System;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System;
using System.Collections;
using System.Collections.Generic;
using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper;
namespace Ryujinx.Graphics.Shader.StructuredIr
{
class AstBlock : AstNode, IEnumerable<IAstNode>
{
public AstBlockType Type { get; }
public AstBlockType Type { get; private set; }
public IAstNode Condition { get; private set; }
private IAstNode _condition;
public IAstNode Condition
{
get
{
return _condition;
}
set
{
RemoveUse(_condition, this);
AddUse(value, this);
_condition = value;
}
}
private LinkedList<IAstNode> _nodes;
public IAstNode First => _nodes.First?.Value;
public IAstNode Last => _nodes.Last?.Value;
public int Count => _nodes.Count;
public AstBlock(AstBlockType type, IAstNode condition = null)
{
@ -34,14 +53,14 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
Add(node, _nodes.AddFirst(node));
}
public void AddBefore(IAstNode oldNode, IAstNode node)
public void AddBefore(IAstNode next, IAstNode node)
{
Add(node, _nodes.AddBefore(oldNode.LLNode, node));
Add(node, _nodes.AddBefore(next.LLNode, node));
}
public void AddAfter(IAstNode oldNode, IAstNode node)
public void AddAfter(IAstNode prev, IAstNode node)
{
Add(node, _nodes.AddAfter(oldNode.LLNode, node));
Add(node, _nodes.AddAfter(prev.LLNode, node));
}
private void Add(IAstNode node, LinkedListNode<IAstNode> newNode)
@ -72,6 +91,17 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{
Condition = new AstOperation(Instruction.LogicalOr, Condition, cond);
}
public void TurnIntoIf(IAstNode cond)
{
Condition = cond;
Type = AstBlockType.If;
}
public void TurnIntoElseIf()
{
Type = AstBlockType.ElseIf;
}
public IEnumerator<IAstNode> GetEnumerator()
{

View file

@ -5,6 +5,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
DoWhile,
If,
Else,
ElseIf,
Main,
While
}

View file

@ -0,0 +1,64 @@
using System;
using System.Collections.Generic;
using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper;
namespace Ryujinx.Graphics.Shader.StructuredIr
{
class AstBlockVisitor
{
public AstBlock Block { get; private set; }
public class BlockVisitationEventArgs : EventArgs
{
public AstBlock Block { get; }
public BlockVisitationEventArgs(AstBlock block)
{
Block = block;
}
}
public event EventHandler<BlockVisitationEventArgs> BlockEntered;
public event EventHandler<BlockVisitationEventArgs> BlockLeft;
public AstBlockVisitor(AstBlock mainBlock)
{
Block = mainBlock;
}
public IEnumerable<IAstNode> Visit()
{
IAstNode node = Block.First;
while (node != null)
{
//We reached a child block, visit the nodes inside.
while (node is AstBlock childBlock)
{
Block = childBlock;
node = childBlock.First;
BlockEntered?.Invoke(this, new BlockVisitationEventArgs(Block));
}
IAstNode next = Next(node);
yield return node;
node = next;
//We reached the end of the list, go up on tree to the parent blocks.
while (node == null && Block.Type != AstBlockType.Main)
{
BlockLeft?.Invoke(this, new BlockVisitationEventArgs(Block));
node = Next(Block);
Block = Block.Parent;
}
}
}
}
}

View file

@ -1,12 +0,0 @@
namespace Ryujinx.Graphics.Shader.StructuredIr
{
class AstDeclaration : AstNode
{
public AstOperand Operand { get; }
public AstDeclaration(AstOperand operand)
{
Operand = operand;
}
}
}

View file

@ -4,6 +4,38 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{
static class AstHelper
{
public static void AddUse(IAstNode node, IAstNode parent)
{
if (node is AstOperand operand && operand.Type == OperandType.LocalVariable)
{
operand.Uses.Add(parent);
}
}
public static void AddDef(IAstNode node, IAstNode parent)
{
if (node is AstOperand operand && operand.Type == OperandType.LocalVariable)
{
operand.Defs.Add(parent);
}
}
public static void RemoveUse(IAstNode node, IAstNode parent)
{
if (node is AstOperand operand && operand.Type == OperandType.LocalVariable)
{
operand.Uses.Remove(parent);
}
}
public static void RemoveDef(IAstNode node, IAstNode parent)
{
if (node is AstOperand operand && operand.Type == OperandType.LocalVariable)
{
operand.Defs.Remove(parent);
}
}
public static AstAssignment Assign(IAstNode destination, IAstNode source)
{
return new AstAssignment(destination, source);
@ -22,5 +54,20 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
return local;
}
public static IAstNode InverseCond(IAstNode cond)
{
return new AstOperation(Instruction.LogicalNot, cond);
}
public static IAstNode Next(IAstNode node)
{
return node.LLNode.Next?.Value;
}
public static IAstNode Previous(IAstNode node)
{
return node.LLNode.Previous?.Value;
}
}
}

View file

@ -1,9 +1,13 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System.Collections.Generic;
namespace Ryujinx.Graphics.Shader.StructuredIr
{
class AstOperand : AstNode
{
public HashSet<IAstNode> Defs { get; }
public HashSet<IAstNode> Uses { get; }
public OperandType Type { get; }
public VariableType VarType { get; set; }
@ -15,6 +19,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
private AstOperand()
{
Defs = new HashSet<IAstNode>();
Uses = new HashSet<IAstNode>();
VarType = VariableType.S32;
}

View file

@ -1,17 +1,40 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper;
namespace Ryujinx.Graphics.Shader.StructuredIr
{
class AstOperation : AstNode
{
public Instruction Inst { get; }
public IAstNode[] Sources { get; }
private IAstNode[] _sources;
public int SourcesCount => _sources.Length;
public AstOperation(Instruction inst, params IAstNode[] sources)
{
Inst = inst;
Sources = sources;
Inst = inst;
_sources = sources;
foreach (IAstNode source in sources)
{
AddUse(source, this);
}
}
public IAstNode GetSource(int index)
{
return _sources[index];
}
public void SetSource(int index, IAstNode source)
{
RemoveUse(_sources[index], this);
AddUse(source, this);
_sources[index] = source;
}
}
}

View file

@ -0,0 +1,149 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System.Collections.Generic;
using System.Linq;
using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper;
namespace Ryujinx.Graphics.Shader.StructuredIr
{
static class AstOptimizer
{
public static void Optimize(StructuredProgramInfo info)
{
AstBlock mainBlock = info.MainBlock;
AstBlockVisitor visitor = new AstBlockVisitor(mainBlock);
foreach (IAstNode node in visitor.Visit())
{
if (node is AstAssignment assignment && assignment.Destination is AstOperand propVar)
{
bool isWorthPropagating = propVar.Uses.Count == 1 || IsWorthPropagating(assignment.Source);
if (propVar.Defs.Count == 1 && isWorthPropagating)
{
PropagateExpression(propVar, assignment.Source);
}
if (propVar.Type == OperandType.LocalVariable && propVar.Uses.Count == 0)
{
visitor.Block.Remove(assignment);
info.Locals.Remove(propVar);
}
}
}
RemoveEmptyBlocks(mainBlock);
}
private static bool IsWorthPropagating(IAstNode source)
{
if (!(source is AstOperation srcOp))
{
return false;
}
if (!InstructionInfo.IsUnary(srcOp.Inst))
{
return false;
}
return srcOp.GetSource(0) is AstOperand || srcOp.Inst == Instruction.Copy;
}
private static void PropagateExpression(AstOperand propVar, IAstNode source)
{
IAstNode[] uses = propVar.Uses.ToArray();
foreach (IAstNode useNode in uses)
{
if (useNode is AstBlock useBlock)
{
useBlock.Condition = source;
}
else if (useNode is AstOperation useOperation)
{
for (int srcIndex = 0; srcIndex < useOperation.SourcesCount; srcIndex++)
{
if (useOperation.GetSource(srcIndex) == propVar)
{
useOperation.SetSource(srcIndex, source);
}
}
}
else if (useNode is AstAssignment useAssignment)
{
useAssignment.Source = source;
}
}
}
private static void RemoveEmptyBlocks(AstBlock mainBlock)
{
Queue<AstBlock> pending = new Queue<AstBlock>();
pending.Enqueue(mainBlock);
while (pending.TryDequeue(out AstBlock block))
{
foreach (IAstNode node in block)
{
if (node is AstBlock childBlock)
{
pending.Enqueue(childBlock);
}
}
AstBlock parent = block.Parent;
if (parent == null)
{
continue;
}
AstBlock nextBlock = Next(block) as AstBlock;
bool hasElse = nextBlock != null && nextBlock.Type == AstBlockType.Else;
bool isIf = block.Type == AstBlockType.If;
if (block.Count == 0)
{
if (isIf)
{
if (hasElse)
{
nextBlock.TurnIntoIf(InverseCond(block.Condition));
}
parent.Remove(block);
}
else if (block.Type == AstBlockType.Else)
{
parent.Remove(block);
}
}
else if (isIf && parent.Type == AstBlockType.Else && parent.Count == (hasElse ? 2 : 1))
{
AstBlock parentOfParent = parent.Parent;
parent.Remove(block);
parentOfParent.AddAfter(parent, block);
if (hasElse)
{
parent.Remove(nextBlock);
parentOfParent.AddAfter(block, nextBlock);
}
parentOfParent.Remove(parent);
block.TurnIntoElseIf();
}
}
}
}
}

View file

@ -2,6 +2,8 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System;
using System.Collections.Generic;
using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper;
namespace Ryujinx.Graphics.Shader.StructuredIr
{
static class GotoElimination
@ -290,14 +292,14 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
private static bool ContainsCondComb(IAstNode node, Instruction inst, IAstNode newCond)
{
while (node is AstOperation operation && operation.Sources.Length == 2)
while (node is AstOperation operation && operation.SourcesCount == 2)
{
if (operation.Inst == inst && IsSameCond(operation.Sources[1], newCond))
if (operation.Inst == inst && IsSameCond(operation.GetSource(1), newCond))
{
return true;
}
node = operation.Sources[0];
node = operation.GetSource(0);
}
return false;
@ -396,8 +398,8 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
return false;
}
lCond = lCondOp.Sources[0];
rCond = rCondOp.Sources[0];
lCond = lCondOp.GetSource(0);
rCond = rCondOp.GetSource(0);
}
return lCond == rCond;
@ -447,20 +449,5 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
return level;
}
private static IAstNode InverseCond(IAstNode cond)
{
return new AstOperation(Instruction.LogicalNot, cond);
}
private static IAstNode Next(IAstNode node)
{
return node.LLNode.Next?.Value;
}
private static IAstNode Previous(IAstNode node)
{
return node.LLNode.Previous?.Value;
}
}
}

View file

@ -77,6 +77,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
Add(Instruction.ReciprocalSquareRoot, VariableType.Scalar, VariableType.Scalar);
Add(Instruction.Sine, VariableType.Scalar, VariableType.Scalar);
Add(Instruction.SquareRoot, VariableType.Scalar, VariableType.Scalar);
Add(Instruction.Subtract, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar);
Add(Instruction.Truncate, VariableType.F32, VariableType.F32);
}
@ -129,5 +130,20 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{
return inst == Instruction.TextureSample;
}
public static bool IsUnary(Instruction inst)
{
if (inst == Instruction.TextureSample)
{
return false;
}
if (inst == Instruction.Copy)
{
return true;
}
return _infoTbl[(int)(inst & Instruction.Mask)].SrcTypes.Length == 1;
}
}
}

View file

@ -20,15 +20,22 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
foreach (INode node in block.Operations)
{
AddOperation(context, (Operation)node);
}
Operation operation = (Operation)node;
context.LeaveBlock(block);
if (IsBranchInst(operation.Inst))
{
context.LeaveBlock(block, operation);
}
else
{
AddOperation(context, operation);
}
}
}
GotoElimination.Eliminate(context.GetGotos());
context.PrependLocalDeclarations();
AstOptimizer.Optimize(context.Info);
return context.Info;
}
@ -44,13 +51,20 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
sources[index] = context.GetOperandUse(operation.GetSource(index));
}
if (operation.Dest != null && !IsBranchInst(inst))
if (operation.Dest != null)
{
AstOperand dest = context.GetOperandDef(operation.Dest);
if (inst == Instruction.LoadConstant)
{
context.Info.ConstantBuffers.Add((sources[0] as AstOperand).Value);
Operand ldcSource = operation.GetSource(0);
if (ldcSource.Type != OperandType.Constant)
{
throw new InvalidOperationException("Found LDC with non-constant constant buffer slot.");
}
context.Info.ConstantBuffers.Add(ldcSource.Value);
}
AstAssignment assignment;

View file

@ -14,8 +14,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
private Dictionary<Operand, AstOperand> _localsMap;
private List<AstOperand> _locals;
private Dictionary<int, AstAssignment> _gotoTempAsgs;
private List<GotoStatement> _gotos;
@ -34,8 +32,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
_localsMap = new Dictionary<Operand, AstOperand>();
_locals = new List<AstOperand>();
_gotoTempAsgs = new Dictionary<int, AstAssignment>();
_gotos = new List<GotoStatement>();
@ -62,9 +58,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
LookForDoWhileStatements(block);
}
public void LeaveBlock(BasicBlock block)
public void LeaveBlock(BasicBlock block, Operation branchOp)
{
LookForIfStatements(block);
LookForIfStatements(block, branchOp);
}
private void LookForDoWhileStatements(BasicBlock block)
@ -99,38 +95,30 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
}
}
private void LookForIfStatements(BasicBlock block)
private void LookForIfStatements(BasicBlock block, Operation branchOp)
{
if (block.Branch == null)
{
return;
}
Operation branchOp = (Operation)block.GetLastOp();
bool isLoop = block.Branch.Index <= block.Index;
AstOperation branch = _currBlock.Last as AstOperation;
if (block.Branch.Index <= _currEndIndex && !isLoop)
{
_currBlock.Remove(branch);
NewBlock(AstBlockType.If, branchOp, block.Branch.Index);
}
else if (_loopTails.Contains(block))
{
//Loop handled by "LookForDoWhileStatements".
//We can safely remove the branch as it was already taken care of.
_currBlock.Remove(branch);
}
else
else if (!_loopTails.Contains(block))
{
AstAssignment gotoTempAsg = GetGotoTempAsg(block.Branch.Index);
IAstNode cond = GetBranchCond(AstBlockType.DoWhile, branchOp);
_currBlock.AddBefore(branch, Assign(gotoTempAsg.Destination, cond));
AddNode(Assign(gotoTempAsg.Destination, cond));
AstOperation branch = new AstOperation(branchOp.Inst);
AddNode(branch);
GotoStatement gotoStmt = new GotoStatement(branch, gotoTempAsg, isLoop);
@ -214,29 +202,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
_currBlock.Add(node);
}
public void PrependLocalDeclarations()
{
AstBlock mainBlock = Info.MainBlock;
AstDeclaration decl = null;
foreach (AstOperand operand in _locals)
{
AstDeclaration oldDecl = decl;
decl = new AstDeclaration(operand);
if (oldDecl == null)
{
mainBlock.AddFirst(decl);
}
else
{
mainBlock.AddAfter(oldDecl, decl);
}
}
}
public GotoStatement[] GetGotos()
{
return _gotos.ToArray();
@ -246,7 +211,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{
AstOperand newTemp = Local(type);
_locals.Add(newTemp);
Info.Locals.Add(newTemp);
return newTemp;
}
@ -293,7 +258,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
_localsMap.Add(operand, astOperand);
_locals.Add(astOperand);
Info.Locals.Add(astOperand);
}
return astOperand;

View file

@ -7,6 +7,8 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{
public AstBlock MainBlock { get; }
public HashSet<AstOperand> Locals { get; }
public HashSet<int> ConstantBuffers { get; }
public HashSet<int> IAttributes { get; }
@ -18,6 +20,8 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{
MainBlock = mainBlock;
Locals = new HashSet<AstOperand>();
ConstantBuffers = new HashSet<int>();
IAttributes = new HashSet<int>();

View file

@ -0,0 +1,64 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System;
namespace Ryujinx.Graphics.Shader.Translation.Optimizations
{
static class BranchElimination
{
public static bool Eliminate(BasicBlock block)
{
if (block.HasBranch && IsRedundantBranch((Operation)block.GetLastOp(), Next(block)))
{
block.Branch = null;
return true;
}
return false;
}
private static bool IsRedundantBranch(Operation current, BasicBlock nextBlock)
{
//Here we check that:
//- The current block ends with a branch.
//- The next block only contains a branch.
//- The branch on the next block is unconditional.
//- Both branches are jumping to the same location.
//In this case, the branch on the current block can be removed,
//as the next block is going to jump to the same place anyway.
if (nextBlock == null)
{
return false;
}
if (!(nextBlock.Operations.First?.Value is Operation next))
{
return false;
}
if (next.Inst != Instruction.Branch)
{
return false;
}
return current.Dest == next.Dest;
}
private static BasicBlock Next(BasicBlock block)
{
block = block.Next;
while (block != null && block.Operations.Count == 0)
{
if (block.HasBranch)
{
throw new InvalidOperationException("Found a bogus empty block that \"ends with a branch\".");
}
block = block.Next;
}
return block;
}
}
}

View file

@ -54,6 +54,13 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
node = nextNode;
}
if (BranchElimination.Eliminate(block))
{
RemoveNode(block, block.Operations.Last);
modified = true;
}
}
}
while (modified);