From cb971781d8e9b0617a1f5519fd37301d4d948ebb Mon Sep 17 00:00:00 2001 From: Isaac Marovitz Date: Tue, 15 Aug 2023 23:06:07 +0100 Subject: [PATCH] Fix UserDefined IO Vars --- .../CodeGen/Msl/Instructions/InstGenMemory.cs | 14 +++++- .../CodeGen/Msl/Instructions/IoMap.cs | 47 ++++++++++++++++++- .../CodeGen/Msl/OperandManager.cs | 17 ++++--- 3 files changed, 70 insertions(+), 8 deletions(-) diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs index 1773c17ab7..2f0434c964 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs @@ -74,6 +74,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions IoVariable ioVariable = (IoVariable)varId.Value; bool isOutput = storageKind.IsOutput(); + bool isPerPatch = storageKind.IsPerPatch(); + int location = -1; + int component = 0; if (context.Definitions.HasPerLocationInputOrOutput(ioVariable, isOutput)) { @@ -82,16 +85,25 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand."); } + location = vecIndex.Value; + if (operation.SourcesCount > srcIndex && operation.GetSource(srcIndex) is AstOperand elemIndex && elemIndex.Type == OperandType.Constant && context.Definitions.HasPerLocationInputOrOutputComponent(ioVariable, vecIndex.Value, elemIndex.Value, isOutput)) { + component = elemIndex.Value; srcIndex++; } } - (varName, varType) = IoMap.GetMslBuiltIn(ioVariable); + (varName, varType) = IoMap.GetMslBuiltIn( + context.Definitions, + ioVariable, + location, + component, + isOutput, + isPerPatch); break; default: diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs index b89e8b0201..67e78c9f69 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs @@ -1,11 +1,18 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.Translation; +using System.Globalization; namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions { static class IoMap { - public static (string, AggregateType) GetMslBuiltIn(IoVariable ioVariable) + public static (string, AggregateType) GetMslBuiltIn( + ShaderDefinitions definitions, + IoVariable ioVariable, + int location, + int component, + bool isOutput, + bool isPerPatch) { return ioVariable switch { @@ -20,10 +27,48 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions IoVariable.PointSize => ("point_size", AggregateType.FP32), IoVariable.Position => ("position", AggregateType.Vector4 | AggregateType.FP32), IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32), + IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch), IoVariable.VertexId => ("vertex_id", AggregateType.S32), IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32), _ => (null, AggregateType.Invalid), }; } + + private static (string, AggregateType) GetUserDefinedVariableName(ShaderDefinitions definitions, int location, int component, bool isOutput, bool isPerPatch) + { + string name = isPerPatch + ? DefaultNames.PerPatchAttributePrefix + : (isOutput ? DefaultNames.OAttributePrefix : DefaultNames.IAttributePrefix); + + if (location < 0) + { + return (name, definitions.GetUserDefinedType(0, isOutput)); + } + + name += location.ToString(CultureInfo.InvariantCulture); + + if (definitions.HasPerLocationInputOrOutputComponent(IoVariable.UserDefined, location, component, isOutput)) + { + name += "_" + "xyzw"[component & 3]; + } + + string prefix = ""; + switch (definitions.Stage) + { + case ShaderStage.Vertex: + prefix = "Vertex"; + break; + case ShaderStage.Fragment: + prefix = "Fragment"; + break; + case ShaderStage.Compute: + prefix = "Compute"; + break; + } + + prefix += isOutput ? "Out" : "In"; + + return (prefix + "." + name, definitions.GetUserDefinedType(location, isOutput)); + } } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/OperandManager.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/OperandManager.cs index 7ec653fa88..6d211b7e8b 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/OperandManager.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/OperandManager.cs @@ -46,9 +46,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public static AggregateType GetNodeDestType(CodeGenContext context, IAstNode node) { - // TODO: Get rid of that function entirely and return the type from the operation generation - // functions directly, like SPIR-V does. - if (node is AstOperation operation) { if (operation.Inst == Instruction.Load || operation.Inst.IsAtomic()) @@ -99,6 +96,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable ioVariable = (IoVariable)varId.Value; bool isOutput = operation.StorageKind == StorageKind.Output || operation.StorageKind == StorageKind.OutputPerPatch; bool isPerPatch = operation.StorageKind == StorageKind.InputPerPatch || operation.StorageKind == StorageKind.OutputPerPatch; + int location = 0; + int component = 0; if (context.Definitions.HasPerLocationInputOrOutput(ioVariable, isOutput)) { @@ -107,18 +106,24 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl throw new InvalidOperationException($"Second input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand."); } - int location = vecIndex.Value; + location = vecIndex.Value; if (operation.SourcesCount > 2 && operation.GetSource(2) is AstOperand elemIndex && elemIndex.Type == OperandType.Constant && context.Definitions.HasPerLocationInputOrOutputComponent(ioVariable, location, elemIndex.Value, isOutput)) { - int component = elemIndex.Value; + component = elemIndex.Value; } } - (_, AggregateType varType) = IoMap.GetMslBuiltIn(ioVariable); + (_, AggregateType varType) = IoMap.GetMslBuiltIn( + context.Definitions, + ioVariable, + location, + component, + isOutput, + isPerPatch); return varType & AggregateType.ElementTypeMask; }