Implement the HFMA instruction, and some misc. fixes

This commit is contained in:
gdkchan 2019-04-10 17:51:31 -03:00
parent c5faac2c00
commit 5b106a51ed
11 changed files with 385 additions and 198 deletions

View file

@ -116,6 +116,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
//and those never needs to be surrounded in parenthesis.
if (!(node is AstOperation operation))
{
//This is sort of a special case, if this is a negative constant,
//and it is consumed by a unary operation, we need to put on the parenthesis,
//as in GLSL a sequence like --2 or ~-1 is not valid.
if (IsNegativeConst(node) && pInfo.Type == InstType.OpUnary)
{
return true;
}
return false;
}
@ -141,12 +149,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
return false;
}
if (pInst == operation.Inst && (info.Type & InstType.Comutative) != 0)
if (pInst == operation.Inst && info.Type == InstType.OpBinaryCom)
{
return false;
}
return true;
}
private static bool IsNegativeConst(IAstNode node)
{
if (!(node is AstOperand operand))
{
return false;
}
return operand.Type == OperandType.Constant && operand.Value < 0;
}
}
}

View file

@ -34,5 +34,20 @@ namespace Ryujinx.Graphics.Shader.Decoders
return null;
}
public void UpdateSsyOpCodes()
{
SsyOpCodes.Clear();
for (int index = 0; index < OpCodes.Count; index++)
{
if (!(OpCodes[index] is OpCodeSsy op))
{
continue;
}
SsyOpCodes.Add(op);
}
}
}
}

View file

@ -98,6 +98,9 @@ namespace Ryujinx.Graphics.Shader.Decoders
current.OpCodes.Count - smaller.OpCodes.Count,
smaller.OpCodes.Count);
current.UpdateSsyOpCodes();
smaller.UpdateSsyOpCodes();
visitedEnd[smaller.EndAddress] = smaller;
}
@ -289,15 +292,12 @@ namespace Ryujinx.Graphics.Shader.Decoders
OpCode op = MakeOpCode(opCodeType, emitter, opAddress, opCode);
block.OpCodes.Add(op);
if (op.Emitter == InstEmit.Ssy)
{
block.SsyOpCodes.Add((OpCodeSsy)op);
}
}
while (!IsBranch(block.GetLastOp()));
block.EndAddress = address;
block.UpdateSsyOpCodes();
}
private static bool IsUnconditionalBranch(OpCode opCode)

View file

@ -16,15 +16,15 @@ namespace Ryujinx.Graphics.Shader.Decoders
if (negateH0)
{
immH0 |= 1 << 10;
immH0 |= 1 << 9;
}
if (negateH1)
{
immH1 |= 1 << 10;
immH1 |= 1 << 9;
}
Immediate = immH1 << 16 | immH0;
Immediate = immH1 << 22 | immH0 << 6;
}
}
}

View file

@ -68,6 +68,11 @@ namespace Ryujinx.Graphics.Shader.Decoders
Set("0111101x0xxxxx", InstEmit.Hadd2, typeof(OpCodeAluImm2x10));
Set("0010110xxxxxxx", InstEmit.Hadd2, typeof(OpCodeAluImm32));
Set("0101110100010x", InstEmit.Hadd2, typeof(OpCodeAluReg));
Set("01110xxx1xxxxx", InstEmit.Hfma2, typeof(OpCodeAluCbuf));
Set("01110xxx0xxxxx", InstEmit.Hfma2, typeof(OpCodeAluImm2x10));
Set("0010100xxxxxxx", InstEmit.Hfma2, typeof(OpCodeAluImm32));
Set("0101110100000x", InstEmit.Hfma2, typeof(OpCodeAluReg));
Set("01100xxx1xxxxx", InstEmit.Hfma2, typeof(OpCodeAluRegCbuf));
Set("0111100x1xxxxx", InstEmit.Hmul2, typeof(OpCodeAluCbuf));
Set("0111100x0xxxxx", InstEmit.Hmul2, typeof(OpCodeAluImm2x10));
Set("0010101xxxxxxx", InstEmit.Hmul2, typeof(OpCodeAluImm32));

View file

@ -45,11 +45,9 @@ namespace Ryujinx.Graphics.Shader.Instructions
srcB = context.FPSaturate(srcB, op.Saturate);
Operand dest = GetDest(context);
WriteFP(context, dstType, srcB);
context.Copy(dest, srcB);
SetZnFlags(context, dest, op.SetCondCode);
//TODO: CC.
}
public static void F2I(EmitterContext context)
@ -105,18 +103,18 @@ namespace Ryujinx.Graphics.Shader.Instructions
context.Copy(dest, srcB);
SetZnFlags(context, dest, op.SetCondCode);
//TODO: CC.
}
public static void I2F(EmitterContext context)
{
OpCodeAlu op = (OpCodeAlu)context.CurrOp;
FPType floatType = (FPType)op.RawOpCode.Extract(8, 2);
FPType dstType = (FPType)op.RawOpCode.Extract(8, 2);
IntegerType intType = (IntegerType)op.RawOpCode.Extract(10, 2);
IntegerType srcType = (IntegerType)op.RawOpCode.Extract(10, 2);
bool isSmallInt = intType <= IntegerType.U16;
bool isSmallInt = srcType <= IntegerType.U16;
bool isSignedInt = op.RawOpCode.Extract(13);
bool negateB = op.RawOpCode.Extract(45);
@ -126,23 +124,20 @@ namespace Ryujinx.Graphics.Shader.Instructions
if (isSmallInt)
{
int size = intType == IntegerType.U16 ? 16 : 8;
int size = srcType == IntegerType.U16 ? 16 : 8;
srcB = isSignedInt
? context.BitfieldExtractS32(srcB, Const(op.ByteSelection * 8), Const(size))
: context.BitfieldExtractU32(srcB, Const(op.ByteSelection * 8), Const(size));
}
Operand dest = GetDest(context);
srcB = isSignedInt
? context.IConvertS32ToFP(srcB)
: context.IConvertU32ToFP(srcB);
if (isSignedInt)
{
context.Copy(dest, context.IConvertS32ToFP(srcB));
}
else
{
context.Copy(dest, context.IConvertU32ToFP(srcB));
}
WriteFP(context, dstType, srcB);
//TODO: CC.
}
public static void I2I(EmitterContext context)
@ -196,5 +191,23 @@ namespace Ryujinx.Graphics.Shader.Instructions
//TODO: CC.
}
private static void WriteFP(EmitterContext context, FPType type, Operand srcB)
{
Operand dest = GetDest(context);
if (type == FPType.FP32)
{
context.Copy(dest, srcB);
}
else if (type == FPType.FP16)
{
context.Copy(dest, context.PackHalf2x16(srcB, ConstF(0)));
}
else
{
//TODO.
}
}
}
}

View file

@ -180,6 +180,33 @@ namespace Ryujinx.Graphics.Shader.Instructions
Hadd2Hmul2Impl(context, isAdd: true);
}
public static void Hfma2(EmitterContext context)
{
OpCode op = context.CurrOp;
bool saturate = false;
if (!(op is OpCodeAluImm32))
{
saturate = op.RawOpCode.Extract(op is IOpCodeReg ? 32 : 52);
}
Operand[] srcA = GetHfmaSrcA(context);
Operand[] srcB = GetHfmaSrcB(context);
Operand[] srcC = GetHfmaSrcC(context);
Operand[] res = new Operand[2];
for (int index = 0; index < res.Length; index++)
{
res[index] = context.FPFusedMultiplyAdd(srcA[index], srcB[index], srcC[index]);
res[index] = context.FPSaturate(res[index], saturate);
}
context.Copy(GetDest(context), GetHalfPacked(context, res));
}
public static void Hmul2(EmitterContext context)
{
Hadd2Hmul2Impl(context, isAdd: false);
@ -259,148 +286,6 @@ namespace Ryujinx.Graphics.Shader.Instructions
context.Copy(GetDest(context), context.FPSaturate(res, op.Saturate));
}
private static Operand[] GetHalfSrcA(EmitterContext context)
{
OpCode op = context.CurrOp;
bool absoluteA = false, negateA = false;
if (op is IOpCodeCbuf || op is IOpCodeImm)
{
negateA = op.RawOpCode.Extract(43);
absoluteA = op.RawOpCode.Extract(44);
}
else if (op is IOpCodeReg)
{
absoluteA = op.RawOpCode.Extract(44);
}
else if (op is OpCodeAluImm32 && op.Emitter == Hadd2)
{
negateA = op.RawOpCode.Extract(56);
}
FPHalfSwizzle swizzle = (FPHalfSwizzle)context.CurrOp.RawOpCode.Extract(47, 2);
Operand[] operands = GetHalfSources(context, GetSrcA(context), swizzle);
return FPAbsNeg(context, operands, absoluteA, negateA);
}
private static Operand[] GetHalfSrcB(EmitterContext context)
{
OpCode op = context.CurrOp;
FPHalfSwizzle swizzle = FPHalfSwizzle.FP16;
bool absoluteB = false, negateB = false;
if (op is IOpCodeReg)
{
swizzle = (FPHalfSwizzle)op.RawOpCode.Extract(28, 2);
absoluteB = op.RawOpCode.Extract(30);
negateB = op.RawOpCode.Extract(31);
}
else if (op is IOpCodeCbuf)
{
swizzle = FPHalfSwizzle.FP32;
absoluteB = op.RawOpCode.Extract(54);
}
Operand[] operands = GetHalfSources(context, GetSrcB(context), swizzle);
return FPAbsNeg(context, operands, absoluteB, negateB);
}
private static Operand[] GetHalfSources(EmitterContext context, Operand src, FPHalfSwizzle swizzle)
{
switch (swizzle)
{
case FPHalfSwizzle.FP16:
return new Operand[]
{
context.UnpackHalf2x16Low (src),
context.UnpackHalf2x16High(src)
};
case FPHalfSwizzle.FP32: return new Operand[] { src, src };
case FPHalfSwizzle.DupH0:
return new Operand[]
{
context.UnpackHalf2x16Low(src),
context.UnpackHalf2x16Low(src)
};
case FPHalfSwizzle.DupH1:
return new Operand[]
{
context.UnpackHalf2x16High(src),
context.UnpackHalf2x16High(src)
};
}
throw new ArgumentException($"Invalid swizzle \"{swizzle}\".");
}
private static Operand[] FPAbsNeg(EmitterContext context, Operand[] operands, bool abs, bool neg)
{
for (int index = 0; index < operands.Length; index++)
{
operands[index] = context.FPAbsNeg(operands[index], abs, neg);
}
return operands;
}
private static Operand GetHalfPacked(EmitterContext context, Operand[] results)
{
OpCode op = context.CurrOp;
FPHalfSwizzle swizzle = FPHalfSwizzle.FP16;
if (!(op is OpCodeAluImm32))
{
swizzle = (FPHalfSwizzle)context.CurrOp.RawOpCode.Extract(49, 2);
}
switch (swizzle)
{
case FPHalfSwizzle.FP16: return context.PackHalf2x16(results[0], results[1]);
case FPHalfSwizzle.FP32: return results[0];
case FPHalfSwizzle.DupH0:
{
Operand h1 = GetHalfDest(context, isHigh: true);
return context.PackHalf2x16(results[0], h1);
}
case FPHalfSwizzle.DupH1:
{
Operand h0 = GetHalfDest(context, isHigh: false);
return context.PackHalf2x16(h0, results[1]);
}
}
throw new ArgumentException($"Invalid swizzle \"{swizzle}\".");
}
private static Operand GetHalfDest(EmitterContext context, bool isHigh)
{
if (isHigh)
{
return context.UnpackHalf2x16High(GetDest(context));
}
else
{
return context.UnpackHalf2x16Low(GetDest(context));
}
}
private static Operand GetFPComparison(
EmitterContext context,
Condition cond,
@ -462,5 +347,89 @@ namespace Ryujinx.Graphics.Shader.Instructions
context.Copy(GetNF(context), context.FPCompareLess (dest, ConstF(0)));
}
}
private static Operand[] GetHfmaSrcA(EmitterContext context)
{
OpCode op = context.CurrOp;
FPHalfSwizzle swizzle = (FPHalfSwizzle)op.RawOpCode.Extract(47, 2);
return GetHalfSources(context, GetSrcA(context), swizzle);
}
private static Operand[] GetHfmaSrcB(EmitterContext context)
{
OpCode op = context.CurrOp;
FPHalfSwizzle swizzle = FPHalfSwizzle.FP16;
bool negateB = false;
//Note: OpCodeAluRegCbuf also implements IOpCodeReg.
//Check IOpCodeRegCbuf before checking IOpCodeReg.
if (op is IOpCodeRegCbuf)
{
negateB = op.RawOpCode.Extract(56);
swizzle = (FPHalfSwizzle)op.RawOpCode.Extract(53, 2);
}
else if (op is IOpCodeReg)
{
swizzle = (FPHalfSwizzle)op.RawOpCode.Extract(28, 2);
negateB = op.RawOpCode.Extract(31);
}
else if (op is IOpCodeCbuf)
{
swizzle = FPHalfSwizzle.FP32;
negateB = op.RawOpCode.Extract(56);
}
Operand[] operands = GetHalfSources(context, GetSrcB(context), swizzle);
return FPAbsNeg(context, operands, abs: false, neg: negateB);
}
private static Operand[] GetHfmaSrcC(EmitterContext context)
{
OpCode op = context.CurrOp;
Operand[] operands;
if (op is OpCodeAluImm32)
{
operands = GetHalfSources(context, GetDest(context), FPHalfSwizzle.FP16);
return FPAbsNeg(context, operands, abs: false, neg: op.RawOpCode.Extract(52));
}
FPHalfSwizzle swizzle = FPHalfSwizzle.FP16;
bool negateC = false;
//Note: OpCodeAluRegCbuf also implements IOpCodeReg.
//Check IOpCodeRegCbuf before checking IOpCodeReg.
if (op is IOpCodeRegCbuf)
{
swizzle = FPHalfSwizzle.FP32;
negateC = op.RawOpCode.Extract(51);
}
else if (op is IOpCodeReg)
{
swizzle = (FPHalfSwizzle)op.RawOpCode.Extract(35, 2);
negateC = op.RawOpCode.Extract(30);
}
else
{
swizzle = (FPHalfSwizzle)op.RawOpCode.Extract(53, 2);
}
operands = GetHalfSources(context, GetSrcC(context), swizzle);
return FPAbsNeg(context, operands, abs: false, neg: negateC);
}
}
}

View file

@ -41,7 +41,22 @@ namespace Ryujinx.Graphics.Shader.Instructions
public static Operand GetSrcB(EmitterContext context, FPType floatType)
{
return GetSrcB(context);
if (floatType == FPType.FP32)
{
return GetSrcB(context);
}
else if (floatType == FPType.FP16)
{
int h = context.CurrOp.RawOpCode.Extract(41, 1);
return GetHalfSources(context, GetSrcB(context), FPHalfSwizzle.FP16)[h];
}
else if (floatType == FPType.FP64)
{
//TODO.
}
throw new ArgumentException($"Invalid floating point type \"{floatType}\".");
}
public static Operand GetSrcB(EmitterContext context)
@ -78,6 +93,148 @@ namespace Ryujinx.Graphics.Shader.Instructions
throw new InvalidOperationException($"Unexpected opcode type \"{context.CurrOp.GetType().Name}\".");
}
public static Operand[] GetHalfSrcA(EmitterContext context)
{
OpCode op = context.CurrOp;
bool absoluteA = false, negateA = false;
if (op is IOpCodeCbuf || op is IOpCodeImm)
{
negateA = op.RawOpCode.Extract(43);
absoluteA = op.RawOpCode.Extract(44);
}
else if (op is IOpCodeReg)
{
absoluteA = op.RawOpCode.Extract(44);
}
else if (op is OpCodeAluImm32 && op.Emitter == InstEmit.Hadd2)
{
negateA = op.RawOpCode.Extract(56);
}
FPHalfSwizzle swizzle = (FPHalfSwizzle)op.RawOpCode.Extract(47, 2);
Operand[] operands = GetHalfSources(context, GetSrcA(context), swizzle);
return FPAbsNeg(context, operands, absoluteA, negateA);
}
public static Operand[] GetHalfSrcB(EmitterContext context)
{
OpCode op = context.CurrOp;
FPHalfSwizzle swizzle = FPHalfSwizzle.FP16;
bool absoluteB = false, negateB = false;
if (op is IOpCodeReg)
{
swizzle = (FPHalfSwizzle)op.RawOpCode.Extract(28, 2);
absoluteB = op.RawOpCode.Extract(30);
negateB = op.RawOpCode.Extract(31);
}
else if (op is IOpCodeCbuf)
{
swizzle = FPHalfSwizzle.FP32;
absoluteB = op.RawOpCode.Extract(54);
}
Operand[] operands = GetHalfSources(context, GetSrcB(context), swizzle);
return FPAbsNeg(context, operands, absoluteB, negateB);
}
public static Operand[] FPAbsNeg(EmitterContext context, Operand[] operands, bool abs, bool neg)
{
for (int index = 0; index < operands.Length; index++)
{
operands[index] = context.FPAbsNeg(operands[index], abs, neg);
}
return operands;
}
public static Operand[] GetHalfSources(EmitterContext context, Operand src, FPHalfSwizzle swizzle)
{
switch (swizzle)
{
case FPHalfSwizzle.FP16:
return new Operand[]
{
context.UnpackHalf2x16Low (src),
context.UnpackHalf2x16High(src)
};
case FPHalfSwizzle.FP32: return new Operand[] { src, src };
case FPHalfSwizzle.DupH0:
return new Operand[]
{
context.UnpackHalf2x16Low(src),
context.UnpackHalf2x16Low(src)
};
case FPHalfSwizzle.DupH1:
return new Operand[]
{
context.UnpackHalf2x16High(src),
context.UnpackHalf2x16High(src)
};
}
throw new ArgumentException($"Invalid swizzle \"{swizzle}\".");
}
public static Operand GetHalfPacked(EmitterContext context, Operand[] results)
{
OpCode op = context.CurrOp;
FPHalfSwizzle swizzle = FPHalfSwizzle.FP16;
if (!(op is OpCodeAluImm32))
{
swizzle = (FPHalfSwizzle)context.CurrOp.RawOpCode.Extract(49, 2);
}
switch (swizzle)
{
case FPHalfSwizzle.FP16: return context.PackHalf2x16(results[0], results[1]);
case FPHalfSwizzle.FP32: return results[0];
case FPHalfSwizzle.DupH0:
{
Operand h1 = GetHalfDest(context, isHigh: true);
return context.PackHalf2x16(results[0], h1);
}
case FPHalfSwizzle.DupH1:
{
Operand h0 = GetHalfDest(context, isHigh: false);
return context.PackHalf2x16(h0, results[1]);
}
}
throw new ArgumentException($"Invalid swizzle \"{swizzle}\".");
}
public static Operand GetHalfDest(EmitterContext context, bool isHigh)
{
if (isHigh)
{
return context.UnpackHalf2x16High(GetDest(context));
}
else
{
return context.UnpackHalf2x16Low(GetDest(context));
}
}
public static Operand GetPredicate39(EmitterContext context)
{
IOpCodeAlu op = (IOpCodeAlu)context.CurrOp;

View file

@ -20,13 +20,15 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
public Operation(Instruction inst, Operand dest, params Operand[] sources)
{
Inst = inst;
Dest = dest;
_sources = sources;
Inst = inst;
Dest = dest;
for (int index = 0; index < sources.Length; index++)
//The array may be modified externally, so we store a copy.
_sources = (Operand[])sources.Clone();
for (int index = 0; index < _sources.Length; index++)
{
Operand source = sources[index];
Operand source = _sources[index];
if (source.Type == OperandType.LocalVariable)
{
@ -59,14 +61,21 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
return _sources[index];
}
public void SetSource(int index, Operand operand)
public void SetSource(int index, Operand source)
{
if (operand.Type == OperandType.LocalVariable)
Operand oldSrc = _sources[index];
if (oldSrc != null && oldSrc.Type == OperandType.LocalVariable)
{
operand.UseOps.Add(this);
oldSrc.UseOps.Remove(this);
}
_sources[index] = operand;
if (source.Type == OperandType.LocalVariable)
{
source.UseOps.Add(this);
}
_sources[index] = source;
}
public void TurnIntoCopy(Operand source)
@ -81,7 +90,10 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
}
}
source.UseOps.Add(this);
if (source.Type == OperandType.LocalVariable)
{
source.UseOps.Add(this);
}
_sources = new Operand[] { source };
}

View file

@ -74,14 +74,21 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
return _sources[index].Block;
}
public void SetSource(int index, Operand operand)
public void SetSource(int index, Operand source)
{
if (operand.Type == OperandType.LocalVariable)
Operand oldSrc = _sources[index].Operand;
if (oldSrc != null && oldSrc.Type == OperandType.LocalVariable)
{
operand.UseOps.Add(this);
oldSrc.UseOps.Remove(this);
}
_sources[index].Operand = operand;
if (source.Type == OperandType.LocalVariable)
{
source.UseOps.Add(this);
}
_sources[index].Operand = source;
}
}
}

View file

@ -86,7 +86,9 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
Operand dest = copyOp.Dest;
Operand src = copyOp.GetSource(0);
foreach (INode useNode in dest.UseOps)
INode[] uses = dest.UseOps.ToArray();
foreach (INode useNode in uses)
{
for (int index = 0; index < useNode.SourcesCount; index++)
{
@ -112,25 +114,14 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
foreach (INode useNode in uses)
{
if (!(useNode is Operation operation))
{
continue;
}
Operand src;
if (operation.Inst == Instruction.UnpackHalf2x16)
{
src = operation.ComponentIndex == 1 ? src1 : src0;
}
else
if (!(useNode is Operation operation) || operation.Inst != Instruction.UnpackHalf2x16)
{
continue;
}
if (operation.GetSource(0) == dest)
{
operation.TurnIntoCopy(src);
operation.TurnIntoCopy(operation.ComponentIndex == 1 ? src1 : src0);
modified = true;
}