From d6fd565492074b6b4b98cb0d670a5ad3a53eaf6c Mon Sep 17 00:00:00 2001 From: Isaac Marovitz Date: Sat, 22 Jun 2024 13:53:39 +0100 Subject: [PATCH] Fix atomic operations --- .../CodeGen/Msl/Declarations.cs | 9 ++++++--- .../CodeGen/Msl/Instructions/InstGen.cs | 7 ++++--- .../CodeGen/Msl/Instructions/InstGenHelper.cs | 12 ++++++------ 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 59552b885e..09de9d0a2b 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -98,15 +98,18 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl } } - public static string GetVarTypeName(CodeGenContext context, AggregateType type) + public static string GetVarTypeName(CodeGenContext context, AggregateType type, bool atomic = false) { + var s32 = atomic ? "atomic_int" : "int"; + var u32 = atomic ? "atomic_uint" : "uint"; + return type switch { AggregateType.Void => "void", AggregateType.Bool => "bool", AggregateType.FP32 => "float", - AggregateType.S32 => "int", - AggregateType.U32 => "uint", + AggregateType.S32 => s32, + AggregateType.U32 => u32, AggregateType.Vector2 | AggregateType.Bool => "bool2", AggregateType.Vector2 | AggregateType.FP32 => "float2", AggregateType.Vector2 | AggregateType.S32 => "int2", diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs index 7626f94aae..0bea4d1aa6 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs @@ -44,15 +44,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory)) { - builder.Append(GenerateLoadOrStore(context, operation, isStore: false)); - AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32 ? AggregateType.S32 : AggregateType.U32; + builder.Append($"(device {Declarations.GetVarTypeName(context, dstType, true)}*)&{GenerateLoadOrStore(context, operation, isStore: false)}"); + + for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++) { - builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}"); + builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}, memory_order_relaxed"); } } else diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs index 983441e59c..d230e2ed49 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs @@ -15,14 +15,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions _infoTable = new InstInfo[(int)Instruction.Count]; #pragma warning disable IDE0055 // Disable formatting - Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_add_explicit"); - Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_and_explicit"); + Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_fetch_add_explicit"); + Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_fetch_and_explicit"); Add(Instruction.AtomicCompareAndSwap, InstType.AtomicBinary, "atomic_compare_exchange_weak_explicit"); - Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_max_explicit"); - Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_min_explicit"); - Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_or_explicit"); + Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_fetch_max_explicit"); + Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_fetch_min_explicit"); + Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_fetch_or_explicit"); Add(Instruction.AtomicSwap, InstType.AtomicBinary, "atomic_exchange_explicit"); - Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_xor_explicit"); + Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_fetch_xor_explicit"); Add(Instruction.Absolute, InstType.CallUnary, "abs"); Add(Instruction.Add, InstType.OpBinaryCom, "+", 2); Add(Instruction.Ballot, InstType.Special);