diff --git a/src/Ryujinx.Graphics.Metal/Auto.cs b/src/Ryujinx.Graphics.Metal/Auto.cs index 7e79ecbc37..2ca3a37e4a 100644 --- a/src/Ryujinx.Graphics.Metal/Auto.cs +++ b/src/Ryujinx.Graphics.Metal/Auto.cs @@ -19,6 +19,12 @@ namespace Ryujinx.Graphics.Metal void AddCommandBufferDependencies(CommandBufferScoped cbs); } + interface IMirrorable where T : IDisposable + { + Auto GetMirrorable(CommandBufferScoped cbs, ref int offset, int size, out bool mirrored); + void ClearMirrors(CommandBufferScoped cbs, int offset, int size); + } + [SupportedOSPlatform("macos")] class Auto : IAutoPrivate, IDisposable where T : IDisposable { @@ -27,6 +33,7 @@ namespace Ryujinx.Graphics.Metal private readonly BitMap _cbOwnership; private readonly MultiFenceHolder _waitable; + private readonly IMirrorable _mirrorable; private bool _disposed; private bool _destroyed; @@ -38,13 +45,22 @@ namespace Ryujinx.Graphics.Metal _cbOwnership = new BitMap(CommandBufferPool.MaxCommandBuffers); } - public Auto(T value, MultiFenceHolder waitable) : this(value) + public Auto(T value, IMirrorable mirrorable, MultiFenceHolder waitable) : this(value) { + _mirrorable = mirrorable; _waitable = waitable; } + public T GetMirrorable(CommandBufferScoped cbs, ref int offset, int size, out bool mirrored) + { + var mirror = _mirrorable.GetMirrorable(cbs, ref offset, size, out mirrored); + mirror._waitable?.AddBufferUse(cbs.CommandBufferIndex, offset, size, false); + return mirror.Get(cbs); + } + public T Get(CommandBufferScoped cbs, int offset, int size, bool write = false) { + _mirrorable?.ClearMirrors(cbs, offset, size); _waitable?.AddBufferUse(cbs.CommandBufferIndex, offset, size, write); return Get(cbs); } diff --git a/src/Ryujinx.Graphics.Metal/BitMapStruct.cs b/src/Ryujinx.Graphics.Metal/BitMapStruct.cs new file mode 100644 index 0000000000..9bc95af47d --- /dev/null +++ b/src/Ryujinx.Graphics.Metal/BitMapStruct.cs @@ -0,0 +1,263 @@ +using Ryujinx.Common.Memory; +using System; +using System.Numerics; + +namespace Ryujinx.Graphics.Metal +{ + interface IBitMapListener + { + void BitMapSignal(int index, int count); + } + + struct BitMapStruct where T : IArray + { + public const int IntSize = 64; + + private const int IntShift = 6; + private const int IntMask = IntSize - 1; + + private T _masks; + + public BitMapStruct() + { + _masks = default; + } + + public bool BecomesUnsetFrom(in BitMapStruct from, ref BitMapStruct into) + { + bool result = false; + + int masks = _masks.Length; + for (int i = 0; i < masks; i++) + { + long fromMask = from._masks[i]; + long unsetMask = (~fromMask) & (fromMask ^ _masks[i]); + into._masks[i] = unsetMask; + + result |= unsetMask != 0; + } + + return result; + } + + public void SetAndSignalUnset(in BitMapStruct from, ref T2 listener) where T2 : struct, IBitMapListener + { + BitMapStruct result = new(); + + if (BecomesUnsetFrom(from, ref result)) + { + // Iterate the set bits in the result, and signal them. + + int offset = 0; + int masks = _masks.Length; + ref T resultMasks = ref result._masks; + for (int i = 0; i < masks; i++) + { + long value = resultMasks[i]; + while (value != 0) + { + int bit = BitOperations.TrailingZeroCount((ulong)value); + + listener.BitMapSignal(offset + bit, 1); + + value &= ~(1L << bit); + } + + offset += IntSize; + } + } + + _masks = from._masks; + } + + public void SignalSet(Action action) + { + // Iterate the set bits in the result, and signal them. + + int offset = 0; + int masks = _masks.Length; + for (int i = 0; i < masks; i++) + { + long value = _masks[i]; + while (value != 0) + { + int bit = BitOperations.TrailingZeroCount((ulong)value); + + action(offset + bit, 1); + + value &= ~(1L << bit); + } + + offset += IntSize; + } + } + + public bool AnySet() + { + for (int i = 0; i < _masks.Length; i++) + { + if (_masks[i] != 0) + { + return true; + } + } + + return false; + } + + public bool IsSet(int bit) + { + int wordIndex = bit >> IntShift; + int wordBit = bit & IntMask; + + long wordMask = 1L << wordBit; + + return (_masks[wordIndex] & wordMask) != 0; + } + + public bool IsSet(int start, int end) + { + if (start == end) + { + return IsSet(start); + } + + int startIndex = start >> IntShift; + int startBit = start & IntMask; + long startMask = -1L << startBit; + + int endIndex = end >> IntShift; + int endBit = end & IntMask; + long endMask = (long)(ulong.MaxValue >> (IntMask - endBit)); + + if (startIndex == endIndex) + { + return (_masks[startIndex] & startMask & endMask) != 0; + } + + if ((_masks[startIndex] & startMask) != 0) + { + return true; + } + + for (int i = startIndex + 1; i < endIndex; i++) + { + if (_masks[i] != 0) + { + return true; + } + } + + if ((_masks[endIndex] & endMask) != 0) + { + return true; + } + + return false; + } + + public bool Set(int bit) + { + int wordIndex = bit >> IntShift; + int wordBit = bit & IntMask; + + long wordMask = 1L << wordBit; + + if ((_masks[wordIndex] & wordMask) != 0) + { + return false; + } + + _masks[wordIndex] |= wordMask; + + return true; + } + + public void Set(int bit, bool value) + { + if (value) + { + Set(bit); + } + else + { + Clear(bit); + } + } + + public void SetRange(int start, int end) + { + if (start == end) + { + Set(start); + return; + } + + int startIndex = start >> IntShift; + int startBit = start & IntMask; + long startMask = -1L << startBit; + + int endIndex = end >> IntShift; + int endBit = end & IntMask; + long endMask = (long)(ulong.MaxValue >> (IntMask - endBit)); + + if (startIndex == endIndex) + { + _masks[startIndex] |= startMask & endMask; + } + else + { + _masks[startIndex] |= startMask; + + for (int i = startIndex + 1; i < endIndex; i++) + { + _masks[i] |= -1L; + } + + _masks[endIndex] |= endMask; + } + } + + public BitMapStruct Union(BitMapStruct other) + { + var result = new BitMapStruct(); + + ref var masks = ref _masks; + ref var otherMasks = ref other._masks; + ref var newMasks = ref result._masks; + + for (int i = 0; i < masks.Length; i++) + { + newMasks[i] = masks[i] | otherMasks[i]; + } + + return result; + } + + public void Clear(int bit) + { + int wordIndex = bit >> IntShift; + int wordBit = bit & IntMask; + + long wordMask = 1L << wordBit; + + _masks[wordIndex] &= ~wordMask; + } + + public void Clear() + { + for (int i = 0; i < _masks.Length; i++) + { + _masks[i] = 0; + } + } + + public void ClearInt(int start, int end) + { + for (int i = start; i <= end; i++) + { + _masks[i] = 0; + } + } + } +} diff --git a/src/Ryujinx.Graphics.Metal/BufferHolder.cs b/src/Ryujinx.Graphics.Metal/BufferHolder.cs index f07143a430..f81c9d7675 100644 --- a/src/Ryujinx.Graphics.Metal/BufferHolder.cs +++ b/src/Ryujinx.Graphics.Metal/BufferHolder.cs @@ -1,6 +1,7 @@ using Ryujinx.Graphics.GAL; using SharpMetal.Metal; using System; +using System.Collections.Generic; using System.Runtime.InteropServices; using System.Runtime.Versioning; using System.Threading; @@ -8,7 +9,7 @@ using System.Threading; namespace Ryujinx.Graphics.Metal { [SupportedOSPlatform("macos")] - class BufferHolder : IDisposable + class BufferHolder : IDisposable, IMirrorable { private CacheByRange _cachedConvertedBuffers; @@ -25,19 +26,99 @@ namespace Ryujinx.Graphics.Metal private FenceHolder _flushFence; private int _flushWaiting; + private byte[] _pendingData; + private BufferMirrorRangeList _pendingDataRanges; + private Dictionary _mirrors; + public BufferHolder(MetalRenderer renderer, Pipeline pipeline, MTLBuffer buffer, int size) { _renderer = renderer; _pipeline = pipeline; _map = buffer.Contents; _waitable = new MultiFenceHolder(size); - _buffer = new Auto(new(buffer), _waitable); + _buffer = new Auto(new(buffer), this, _waitable); _flushLock = new ReaderWriterLockSlim(); Size = size; } + private static ulong ToMirrorKey(int offset, int size) + { + return ((ulong)offset << 32) | (uint)size; + } + + private static (int offset, int size) FromMirrorKey(ulong key) + { + return ((int)(key >> 32), (int)key); + } + + private unsafe bool TryGetMirror(CommandBufferScoped cbs, ref int offset, int size, out Auto buffer) + { + size = Math.Min(size, Size - offset); + + // Does this binding need to be mirrored? + + if (!_pendingDataRanges.OverlapsWith(offset, size)) + { + buffer = null; + return false; + } + + var key = ToMirrorKey(offset, size); + + if (_mirrors.TryGetValue(key, out StagingBufferReserved reserved)) + { + buffer = reserved.Buffer.GetBuffer(); + offset = reserved.Offset; + + return true; + } + + // Is this mirror allowed to exist? Can't be used for write in any in-flight write. + if (_waitable.IsBufferRangeInUse(offset, size, true)) + { + // Some of the data is not mirrorable, so upload the whole range. + ClearMirrors(cbs, offset, size); + + buffer = null; + return false; + } + + // Build data for the new mirror. + + var baseData = new Span((void*)(_map + offset), size); + var modData = _pendingData.AsSpan(offset, size); + + StagingBufferReserved? newMirror = _renderer.BufferManager.StagingBuffer.TryReserveData(cbs, size); + + if (newMirror != null) + { + var mirror = newMirror.Value; + _pendingDataRanges.FillData(baseData, modData, offset, new Span((void*)(mirror.Buffer._map + mirror.Offset), size)); + + if (_mirrors.Count == 0) + { + _pipeline.RegisterActiveMirror(this); + } + + _mirrors.Add(key, mirror); + + buffer = mirror.Buffer.GetBuffer(); + offset = mirror.Offset; + + return true; + } + else + { + // Data could not be placed on the mirror, likely out of space. Force the data to flush. + ClearMirrors(cbs, offset, size); + + buffer = null; + return false; + } + } + public Auto GetBuffer() { return _buffer; @@ -63,6 +144,74 @@ namespace Ryujinx.Graphics.Metal return _buffer; } + public Auto GetMirrorable(CommandBufferScoped cbs, ref int offset, int size, out bool mirrored) + { + if (_pendingData != null && TryGetMirror(cbs, ref offset, size, out Auto result)) + { + mirrored = true; + return result; + } + + mirrored = false; + return _buffer; + } + + public void ClearMirrors() + { + // Clear mirrors without forcing a flush. This happens when the command buffer is switched, + // as all reserved areas on the staging buffer are released. + + if (_pendingData != null) + { + _mirrors.Clear(); + } + } + + public void ClearMirrors(CommandBufferScoped cbs, int offset, int size) + { + // Clear mirrors in the given range, and submit overlapping pending data. + + if (_pendingData != null) + { + bool hadMirrors = _mirrors.Count > 0 && RemoveOverlappingMirrors(offset, size); + + if (_pendingDataRanges.Count() != 0) + { + UploadPendingData(cbs, offset, size); + } + + if (hadMirrors) + { + _pipeline.Rebind(_buffer, offset, size); + } + } + } + + private void UploadPendingData(CommandBufferScoped cbs, int offset, int size) + { + var ranges = _pendingDataRanges.FindOverlaps(offset, size); + + if (ranges != null) + { + _pendingDataRanges.Remove(offset, size); + + foreach (var range in ranges) + { + int rangeOffset = Math.Max(offset, range.Offset); + int rangeSize = Math.Min(offset + size, range.End) - rangeOffset; + + if (_pipeline.Cbs.CommandBuffer == cbs.CommandBuffer) + { + SetData(rangeOffset, _pendingData.AsSpan(rangeOffset, rangeSize), cbs, _pipeline.EndRenderPassDelegate, false); + } + else + { + SetData(rangeOffset, _pendingData.AsSpan(rangeOffset, rangeSize), cbs, null, false); + } + } + } + } + public void SignalWrite(int offset, int size) { if (offset == 0 && size == Size) @@ -162,6 +311,33 @@ namespace Ryujinx.Graphics.Metal throw new InvalidOperationException("The buffer is not mapped."); } + public bool RemoveOverlappingMirrors(int offset, int size) + { + List toRemove = null; + foreach (var key in _mirrors.Keys) + { + (int keyOffset, int keySize) = FromMirrorKey(key); + if (!(offset + size <= keyOffset || offset >= keyOffset + keySize)) + { + toRemove ??= new List(); + + toRemove.Add(key); + } + } + + if (toRemove != null) + { + foreach (var key in toRemove) + { + _mirrors.Remove(key); + } + + return true; + } + + return false; + } + public unsafe void SetData(int offset, ReadOnlySpan data, CommandBufferScoped? cbs = null, Action endRenderPass = null, bool allowCbsWait = true) { int dataSize = Math.Min(data.Length, Size - offset); @@ -170,6 +346,8 @@ namespace Ryujinx.Graphics.Metal return; } + bool allowMirror = allowCbsWait && cbs != null; + if (_map != IntPtr.Zero) { // If persistently mapped, set the data directly if the buffer is not currently in use. @@ -190,6 +368,32 @@ namespace Ryujinx.Graphics.Metal } } + // If the buffer does not have an in-flight write (including an inline update), then upload data to a pendingCopy. + if (allowMirror && !_waitable.IsBufferRangeInUse(offset, dataSize, true)) + { + if (_pendingData == null) + { + _pendingData = new byte[Size]; + _mirrors = new Dictionary(); + } + + data[..dataSize].CopyTo(_pendingData.AsSpan(offset, dataSize)); + _pendingDataRanges.Add(offset, dataSize); + + // Remove any overlapping mirrors. + RemoveOverlappingMirrors(offset, dataSize); + + // Tell the graphics device to rebind any constant buffer that overlaps the newly modified range, as it should access a mirror. + _pipeline.Rebind(_buffer, offset, dataSize); + + return; + } + + if (_pendingData != null) + { + _pendingDataRanges.Remove(offset, dataSize); + } + if (cbs != null && _pipeline.RenderPassActive && !(_buffer.HasCommandBufferDependency(cbs.Value) && diff --git a/src/Ryujinx.Graphics.Metal/BufferMirrorRangeList.cs b/src/Ryujinx.Graphics.Metal/BufferMirrorRangeList.cs new file mode 100644 index 0000000000..86e1f3426c --- /dev/null +++ b/src/Ryujinx.Graphics.Metal/BufferMirrorRangeList.cs @@ -0,0 +1,305 @@ +using System; +using System.Collections.Generic; + +namespace Ryujinx.Graphics.Metal +{ +/// + /// A structure tracking pending upload ranges for buffers. + /// Where a range is present, pending data exists that can either be used to build mirrors + /// or upload directly to the buffer. + /// + struct BufferMirrorRangeList + { + internal readonly struct Range + { + public int Offset { get; } + public int Size { get; } + + public int End => Offset + Size; + + public Range(int offset, int size) + { + Offset = offset; + Size = size; + } + + public bool OverlapsWith(int offset, int size) + { + return Offset < offset + size && offset < Offset + Size; + } + } + + private List _ranges; + + public readonly IEnumerable All() + { + return _ranges; + } + + public readonly bool Remove(int offset, int size) + { + var list = _ranges; + bool removedAny = false; + if (list != null) + { + int overlapIndex = BinarySearch(list, offset, size); + + if (overlapIndex >= 0) + { + // Overlaps with a range. Search back to find the first one it doesn't overlap with. + + while (overlapIndex > 0 && list[overlapIndex - 1].OverlapsWith(offset, size)) + { + overlapIndex--; + } + + int endOffset = offset + size; + int startIndex = overlapIndex; + + var currentOverlap = list[overlapIndex]; + + // Orphan the start of the overlap. + if (currentOverlap.Offset < offset) + { + list[overlapIndex] = new Range(currentOverlap.Offset, offset - currentOverlap.Offset); + currentOverlap = new Range(offset, currentOverlap.End - offset); + list.Insert(++overlapIndex, currentOverlap); + startIndex++; + + removedAny = true; + } + + // Remove any middle overlaps. + while (currentOverlap.Offset < endOffset) + { + if (currentOverlap.End > endOffset) + { + // Update the end overlap instead of removing it, if it spans beyond the removed range. + list[overlapIndex] = new Range(endOffset, currentOverlap.End - endOffset); + + removedAny = true; + break; + } + + if (++overlapIndex >= list.Count) + { + break; + } + + currentOverlap = list[overlapIndex]; + } + + int count = overlapIndex - startIndex; + + list.RemoveRange(startIndex, count); + + removedAny |= count > 0; + } + } + + return removedAny; + } + + public void Add(int offset, int size) + { + var list = _ranges; + if (list != null) + { + int overlapIndex = BinarySearch(list, offset, size); + if (overlapIndex >= 0) + { + while (overlapIndex > 0 && list[overlapIndex - 1].OverlapsWith(offset, size)) + { + overlapIndex--; + } + + int endOffset = offset + size; + int startIndex = overlapIndex; + + while (overlapIndex < list.Count && list[overlapIndex].OverlapsWith(offset, size)) + { + var currentOverlap = list[overlapIndex]; + var currentOverlapEndOffset = currentOverlap.Offset + currentOverlap.Size; + + if (offset > currentOverlap.Offset) + { + offset = currentOverlap.Offset; + } + + if (endOffset < currentOverlapEndOffset) + { + endOffset = currentOverlapEndOffset; + } + + overlapIndex++; + size = endOffset - offset; + } + + int count = overlapIndex - startIndex; + + list.RemoveRange(startIndex, count); + + overlapIndex = startIndex; + } + else + { + overlapIndex = ~overlapIndex; + } + + list.Insert(overlapIndex, new Range(offset, size)); + } + else + { + _ranges = new List + { + new Range(offset, size) + }; + } + } + + public readonly bool OverlapsWith(int offset, int size) + { + var list = _ranges; + if (list == null) + { + return false; + } + + return BinarySearch(list, offset, size) >= 0; + } + + public readonly List FindOverlaps(int offset, int size) + { + var list = _ranges; + if (list == null) + { + return null; + } + + List result = null; + + int index = BinarySearch(list, offset, size); + + if (index >= 0) + { + while (index > 0 && list[index - 1].OverlapsWith(offset, size)) + { + index--; + } + + do + { + (result ??= new List()).Add(list[index++]); + } + while (index < list.Count && list[index].OverlapsWith(offset, size)); + } + + return result; + } + + private static int BinarySearch(List list, int offset, int size) + { + int left = 0; + int right = list.Count - 1; + + while (left <= right) + { + int range = right - left; + + int middle = left + (range >> 1); + + var item = list[middle]; + + if (item.OverlapsWith(offset, size)) + { + return middle; + } + + if (offset < item.Offset) + { + right = middle - 1; + } + else + { + left = middle + 1; + } + } + + return ~left; + } + + public readonly void FillData(Span baseData, Span modData, int offset, Span result) + { + int size = baseData.Length; + int endOffset = offset + size; + + var list = _ranges; + if (list == null) + { + baseData.CopyTo(result); + } + + int srcOffset = offset; + int dstOffset = 0; + bool activeRange = false; + + for (int i = 0; i < list.Count; i++) + { + var range = list[i]; + + int rangeEnd = range.Offset + range.Size; + + if (activeRange) + { + if (range.Offset >= endOffset) + { + break; + } + } + else + { + if (rangeEnd <= offset) + { + continue; + } + + activeRange = true; + } + + int baseSize = range.Offset - srcOffset; + + if (baseSize > 0) + { + baseData.Slice(dstOffset, baseSize).CopyTo(result.Slice(dstOffset, baseSize)); + srcOffset += baseSize; + dstOffset += baseSize; + } + + int modSize = Math.Min(rangeEnd - srcOffset, endOffset - srcOffset); + if (modSize != 0) + { + modData.Slice(dstOffset, modSize).CopyTo(result.Slice(dstOffset, modSize)); + srcOffset += modSize; + dstOffset += modSize; + } + } + + int baseSizeEnd = endOffset - srcOffset; + + if (baseSizeEnd > 0) + { + baseData.Slice(dstOffset, baseSizeEnd).CopyTo(result.Slice(dstOffset, baseSizeEnd)); + } + } + + public readonly int Count() + { + return _ranges?.Count ?? 0; + } + + public void Clear() + { + _ranges = null; + } + } +} diff --git a/src/Ryujinx.Graphics.Metal/EncoderState.cs b/src/Ryujinx.Graphics.Metal/EncoderState.cs index 1ba7e26201..04db6090b7 100644 --- a/src/Ryujinx.Graphics.Metal/EncoderState.cs +++ b/src/Ryujinx.Graphics.Metal/EncoderState.cs @@ -1,3 +1,4 @@ +using Ryujinx.Common.Memory; using Ryujinx.Graphics.GAL; using SharpMetal.Metal; using System; @@ -67,6 +68,11 @@ namespace Ryujinx.Graphics.Metal public BufferRef[] UniformBuffers = new BufferRef[Constants.MaxUniformBuffersPerStage]; public BufferRef[] StorageBuffers = new BufferRef[Constants.MaxStorageBuffersPerStage]; + public BitMapStruct> UniformSet; + public BitMapStruct> StorageSet; + public BitMapStruct> UniformMirrored; + public BitMapStruct> StorageMirrored; + public Auto IndexBuffer = default; public MTLIndexType IndexType = MTLIndexType.UInt16; public ulong IndexBufferOffset = 0; diff --git a/src/Ryujinx.Graphics.Metal/EncoderStateManager.cs b/src/Ryujinx.Graphics.Metal/EncoderStateManager.cs index 218e378b07..9de3caf9c0 100644 --- a/src/Ryujinx.Graphics.Metal/EncoderStateManager.cs +++ b/src/Ryujinx.Graphics.Metal/EncoderStateManager.cs @@ -81,6 +81,8 @@ namespace Ryujinx.Graphics.Metal // Mark the other state as dirty _currentState.Dirty |= DirtyFlags.All; + _currentState.UniformSet.Clear(); + _currentState.StorageSet.Clear(); } else { @@ -368,6 +370,73 @@ namespace Ryujinx.Graphics.Metal computeCommandEncoder.SetComputePipelineState(pipelineState); } + private static bool BindingOverlaps(BufferRange info, int offset, int size) + { + return offset < info.Offset + info.Size && (offset + size) > info.Offset; + } + + public void Rebind(Auto buffer, int offset, int size) + { + if (_pipeline.CurrentEncoder == null) + { + return; + } + + // Check stage bindings + + var currentState = _currentState; + + _currentState.UniformMirrored.Union(_currentState.UniformSet).SignalSet((int binding, int count) => + { + for (int i = 0; i < count; i++) + { + ref BufferRef bufferRef = ref currentState.UniformBuffers[binding]; + if (bufferRef.Buffer == buffer) + { + if (bufferRef.Range != null) + { + if (!BindingOverlaps(bufferRef.Range.Value, offset, size)) + { + binding++; + continue; + } + } + + currentState.UniformSet.Clear(binding); + currentState.Dirty |= DirtyFlags.Buffers; + } + + binding++; + } + }); + + _currentState.StorageMirrored.Union(_currentState.StorageSet).SignalSet((int binding, int count) => + { + for (int i = 0; i < count; i++) + { + ref BufferRef bufferRef = ref currentState.StorageBuffers[binding]; + if (bufferRef.Buffer == buffer) + { + if (bufferRef.Range != null) + { + if (!BindingOverlaps(bufferRef.Range.Value, offset, size)) + { + binding++; + continue; + } + } + + currentState.StorageSet.Clear(binding); + currentState.Dirty |= DirtyFlags.Buffers; + } + + binding++; + } + }); + + _currentState = currentState; + } + public void UpdateIndexBuffer(BufferRange buffer, IndexType type) { if (buffer.Handle != BufferHandle.Null) @@ -702,7 +771,16 @@ namespace Ryujinx.Graphics.Metal ? null : _bufferManager.GetBuffer(buffer.Handle, buffer.Write); - _currentState.UniformBuffers[index] = new BufferRef(mtlBuffer, ref buffer); + ref BufferRef currentBufferRef = ref _currentState.UniformBuffers[index]; + + BufferRef newRef = new(mtlBuffer, ref buffer); + + if (!currentBufferRef.Equals(newRef)) + { + _currentState.UniformSet.Clear(index); + + currentBufferRef = newRef; + } } _currentState.Dirty |= DirtyFlags.Buffers; @@ -719,7 +797,16 @@ namespace Ryujinx.Graphics.Metal ? null : _bufferManager.GetBuffer(buffer.Handle, buffer.Write); - _currentState.StorageBuffers[index] = new BufferRef(mtlBuffer, ref buffer); + ref BufferRef currentBufferRef = ref _currentState.StorageBuffers[index]; + + BufferRef newRef = new(mtlBuffer, ref buffer); + + if (!currentBufferRef.Equals(newRef)) + { + _currentState.StorageSet.Clear(index); + + currentBufferRef = newRef; + } } _currentState.Dirty |= DirtyFlags.Buffers; @@ -986,7 +1073,7 @@ namespace Ryujinx.Graphics.Metal renderCommandEncoder.SetVertexBuffer(zeroMtlBuffer, 0, Constants.ZeroBufferIndex); } - private readonly void SetRenderBuffers(MTLRenderCommandEncoder renderCommandEncoder, BufferRef[] uniformBuffers, BufferRef[] storageBuffers) + private void SetRenderBuffers(MTLRenderCommandEncoder renderCommandEncoder, BufferRef[] uniformBuffers, BufferRef[] storageBuffers) { var uniformArgBufferRange = CreateArgumentBufferForRenderEncoder(renderCommandEncoder, uniformBuffers, true); var uniformArgBuffer = _bufferManager.GetBuffer(uniformArgBufferRange.Handle, false).Get(_pipeline.Cbs).Value; @@ -1001,7 +1088,7 @@ namespace Ryujinx.Graphics.Metal renderCommandEncoder.SetFragmentBuffer(storageArgBuffer, (ulong)storageArgBufferRange.Offset, Constants.StorageBuffersIndex); } - private readonly void SetComputeBuffers(MTLComputeCommandEncoder computeCommandEncoder, BufferRef[] uniformBuffers, BufferRef[] storageBuffers) + private void SetComputeBuffers(MTLComputeCommandEncoder computeCommandEncoder, BufferRef[] uniformBuffers, BufferRef[] storageBuffers) { var uniformArgBufferRange = CreateArgumentBufferForComputeEncoder(computeCommandEncoder, uniformBuffers, true); var uniformArgBuffer = _bufferManager.GetBuffer(uniformArgBufferRange.Handle, false).Get(_pipeline.Cbs).Value; @@ -1015,7 +1102,7 @@ namespace Ryujinx.Graphics.Metal computeCommandEncoder.SetBuffer(storageArgBuffer, (ulong)storageArgBufferRange.Offset, Constants.StorageBuffersIndex); } - private readonly BufferRange CreateArgumentBufferForRenderEncoder(MTLRenderCommandEncoder renderCommandEncoder, BufferRef[] buffers, bool constant) + private BufferRange CreateArgumentBufferForRenderEncoder(MTLRenderCommandEncoder renderCommandEncoder, BufferRef[] buffers, bool constant) { var usage = constant ? MTLResourceUsage.Read : MTLResourceUsage.Write; @@ -1037,8 +1124,16 @@ namespace Ryujinx.Graphics.Metal if (range.HasValue) { offset = range.Value.Offset; - mtlBuffer = autoBuffer.Get(_pipeline.Cbs, offset, range.Value.Size, range.Value.Write).Value; + mtlBuffer = autoBuffer.GetMirrorable(_pipeline.Cbs, ref offset, range.Value.Size, out bool mirrored).Value; + if (constant) + { + _currentState.UniformMirrored.Set(i, mirrored); + } + else + { + _currentState.StorageMirrored.Set(i, mirrored); + } } else { @@ -1057,7 +1152,7 @@ namespace Ryujinx.Graphics.Metal return argBuffer.Range; } - private readonly BufferRange CreateArgumentBufferForComputeEncoder(MTLComputeCommandEncoder computeCommandEncoder, BufferRef[] buffers, bool constant) + private BufferRange CreateArgumentBufferForComputeEncoder(MTLComputeCommandEncoder computeCommandEncoder, BufferRef[] buffers, bool constant) { var usage = constant ? MTLResourceUsage.Read : MTLResourceUsage.Write; @@ -1079,8 +1174,16 @@ namespace Ryujinx.Graphics.Metal if (range.HasValue) { offset = range.Value.Offset; - mtlBuffer = autoBuffer.Get(_pipeline.Cbs, offset, range.Value.Size, range.Value.Write).Value; + mtlBuffer = autoBuffer.GetMirrorable(_pipeline.Cbs, ref offset, range.Value.Size, out bool mirrored).Value; + if (constant) + { + _currentState.UniformMirrored.Set(i, mirrored); + } + else + { + _currentState.StorageMirrored.Set(i, mirrored); + } } else { diff --git a/src/Ryujinx.Graphics.Metal/Pipeline.cs b/src/Ryujinx.Graphics.Metal/Pipeline.cs index 93064e60a5..1768c20382 100644 --- a/src/Ryujinx.Graphics.Metal/Pipeline.cs +++ b/src/Ryujinx.Graphics.Metal/Pipeline.cs @@ -5,6 +5,7 @@ using SharpMetal.Foundation; using SharpMetal.Metal; using SharpMetal.QuartzCore; using System; +using System.Collections.Generic; using System.Runtime.Versioning; namespace Ryujinx.Graphics.Metal @@ -24,6 +25,8 @@ namespace Ryujinx.Graphics.Metal private readonly MTLDevice _device; private readonly MetalRenderer _renderer; + private readonly List _activeBufferMirrors; + private EncoderStateManager _encoderStateManager; private ulong _byteWeight; @@ -40,6 +43,7 @@ namespace Ryujinx.Graphics.Metal { _device = device; _renderer = renderer; + _activeBufferMirrors = new(); EndRenderPassDelegate = EndCurrentPass; @@ -201,6 +205,13 @@ namespace Ryujinx.Graphics.Metal CommandBuffer = (Cbs = _renderer.CommandBufferPool.ReturnAndRent(Cbs)).CommandBuffer; + // Restore per-command buffer state. + foreach (BufferHolder buffer in _activeBufferMirrors) + { + buffer.ClearMirrors(); + } + _activeBufferMirrors.Clear(); + // TODO: Auto flush counting _renderer.SyncManager.GetAndResetWaitTicks(); @@ -208,6 +219,16 @@ namespace Ryujinx.Graphics.Metal dst.Dispose(); } + public void RegisterActiveMirror(BufferHolder buffer) + { + _activeBufferMirrors.Add(buffer); + } + + public void Rebind(Auto buffer, int offset, int size) + { + _encoderStateManager.Rebind(buffer, offset, size); + } + public void FlushCommandsIfWeightExceeding(IAuto disposedResource, ulong byteWeight) { bool usedByCurrentCb = disposedResource.HasCommandBufferDependency(Cbs);