diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 5005a24203..f06af235ae 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -77,7 +77,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl { if (inputs.Any()) { - context.AppendLine("struct VertexIn"); + string prefix = ""; + + switch (context.Definitions.Stage) + { + case ShaderStage.Vertex: + prefix = "Vertex"; + break; + case ShaderStage.Fragment: + prefix = "Fragment"; + break; + case ShaderStage.Compute: + prefix = "Compute"; + break; + } + + context.AppendLine($"struct {prefix}In"); context.EnterScope(); foreach (var ioDefinition in inputs.OrderBy(x => x.Location)) diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs index e0ce97abef..b4d2ecad2a 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs @@ -90,10 +90,25 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl funcKeyword = "fragment"; funcName = "fragmentMain"; } + else if (stage == ShaderStage.Compute) + { + // TODO: Compute main + } if (context.AttributeUsage.UsedInputAttributes != 0) { - args = args.Prepend("VertexIn in [[stage_in]]").ToArray(); + if (stage == ShaderStage.Vertex) + { + args = args.Prepend("VertexIn in [[stage_in]]").ToArray(); + } + else if (stage == ShaderStage.Fragment) + { + args = args.Prepend("FragmentIn in [[stage_in]]").ToArray(); + } + else if (stage == ShaderStage.Compute) + { + // TODO: Compute input + } } }