diff --git a/Ryujinx.HLE/HOS/Kernel/Common/KSynchronizationObject.cs b/Ryujinx.HLE/HOS/Kernel/Common/KSynchronizationObject.cs index 87e5531210..77d8fbff8f 100644 --- a/Ryujinx.HLE/HOS/Kernel/Common/KSynchronizationObject.cs +++ b/Ryujinx.HLE/HOS/Kernel/Common/KSynchronizationObject.cs @@ -5,7 +5,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Common { class KSynchronizationObject : KAutoObject { - public LinkedList WaitingThreads; + public LinkedList WaitingThreads { get; } public KSynchronizationObject(Horizon system) : base(system) { diff --git a/Ryujinx.HLE/HOS/Kernel/Ipc/KClientSession.cs b/Ryujinx.HLE/HOS/Kernel/Ipc/KClientSession.cs index 31c1e43e9b..f139d3d45d 100644 --- a/Ryujinx.HLE/HOS/Kernel/Ipc/KClientSession.cs +++ b/Ryujinx.HLE/HOS/Kernel/Ipc/KClientSession.cs @@ -22,6 +22,10 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc _parent = parent; State = ChannelState.Open; + + CreatorProcess = system.Scheduler.GetCurrentProcess(); + + CreatorProcess.IncrementReferenceCount(); } public KernelResult SendSyncRequest(ulong customCmdBuffAddr = 0, ulong customCmdBuffSize = 0) @@ -30,8 +34,6 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc KSessionRequest request = new KSessionRequest(currentThread, customCmdBuffAddr, customCmdBuffSize); - currentThread.IncrementReferenceCount(); - System.CriticalSection.Enter(); currentThread.SignaledObj = null; diff --git a/Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs b/Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs index d477403ad2..5a45ff4a99 100644 --- a/Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs +++ b/Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs @@ -226,7 +226,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc return KernelResult.PortRemoteClosed; } - if (_activeRequest != null || !PickRequest(out KSessionRequest request)) + if (_activeRequest != null || !DequeueRequest(out KSessionRequest request)) { System.CriticalSection.Leave(); @@ -278,7 +278,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc System.CriticalSection.Leave(); - WakeClient(request, clientResult); + WakeClientThread(request, clientResult); } if (clientHeader.ReceiveListType < 2 && @@ -871,7 +871,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc return serverResult; } - WakeClient(request, clientResult); + WakeClientThread(request, clientResult); return serverResult; } @@ -1044,9 +1044,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc recvListBufferAddress = packed & 0x7fffffffff; - uint transferSize = (uint)(packed >> 48); + uint size = (uint)(packed >> 48); - if (recvListBufferAddress == 0 || transferSize == 0 || transferSize < descriptor.BufferSize) + if (recvListBufferAddress == 0 || size == 0 || size < descriptor.BufferSize) { return KernelResult.OutOfResource; } @@ -1102,12 +1102,42 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc { _parent.DisconnectServer(); - CancelAllRequests(KernelResult.PortRemoteClosed); + CancelAllRequestsServerDisconnected(); _parent.DecrementReferenceCount(); } - private void CancelAllRequests(KernelResult result) + private void CancelAllRequestsServerDisconnected() + { + foreach (KSessionRequest request in IterateWithRemovalOfAllRequests()) + { + CancelRequest(request, KernelResult.PortRemoteClosed); + } + } + + public void CancelAllRequestsClientDisconnected() + { + foreach (KSessionRequest request in IterateWithRemovalOfAllRequests()) + { + if (request.ClientThread.ShallBeTerminated || + request.ClientThread.SchedFlags == ThreadSchedState.TerminationPending) + { + continue; + } + + //Client sessions can only be disconnected on async requests (because + //the client would be otherwise blocked waiting for the response), so + //we only need to handle the async case here. + if (request.AsyncEvent != null) + { + SendResultToAsyncRequestClient(request, KernelResult.PortRemoteClosed); + } + } + + WakeServerThreads(KernelResult.PortRemoteClosed); + } + + private IEnumerable IterateWithRemovalOfAllRequests() { System.CriticalSection.Enter(); @@ -1117,22 +1147,22 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc _activeRequest = null; - CancelRequest(request, result); - System.CriticalSection.Leave(); + + yield return request; } else { System.CriticalSection.Leave(); } - while (PickRequest(out KSessionRequest request)) + while (DequeueRequest(out KSessionRequest request)) { - CancelRequest(request, result); + yield return request; } } - private bool PickRequest(out KSessionRequest request) + private bool DequeueRequest(out KSessionRequest request) { request = null; @@ -1169,41 +1199,64 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc request.BufferDescriptorTable.RestoreClientBuffers(clientProcess.MemoryManager); } - WakeClient(request, result); + WakeClientThread(request, result); } - private void WakeClient(KSessionRequest request, KernelResult result) + private void WakeClientThread(KSessionRequest request, KernelResult result) { - KThread clientThread = request.ClientThread; - KProcess clientProcess = clientThread.Owner; - + //Wait client thread waiting for a response for the given request. if (request.AsyncEvent != null) { - ulong address = clientProcess.MemoryManager.GetDramAddressFromVa(request.CustomCmdBuffAddr); - - System.Device.Memory.WriteInt64((long)address + 0, 0); - System.Device.Memory.WriteInt32((long)address + 8, (int)result); - - clientProcess.MemoryManager.UnborrowIpcBuffer( - request.CustomCmdBuffAddr, - request.CustomCmdBuffSize); - - request.AsyncEvent.Signal(); + SendResultToAsyncRequestClient(request, result); } else { System.CriticalSection.Enter(); - if ((clientThread.SchedFlags & ThreadSchedState.LowMask) == ThreadSchedState.Paused) - { - clientThread.SignaledObj = null; - clientThread.ObjSyncResult = result; - - clientThread.Reschedule(ThreadSchedState.Running); - } + WakeAndSetResult(request.ClientThread, result); System.CriticalSection.Leave(); } } + + private void SendResultToAsyncRequestClient(KSessionRequest request, KernelResult result) + { + KProcess clientProcess = request.ClientThread.Owner; + + ulong address = clientProcess.MemoryManager.GetDramAddressFromVa(request.CustomCmdBuffAddr); + + System.Device.Memory.WriteInt64((long)address + 0, 0); + System.Device.Memory.WriteInt32((long)address + 8, (int)result); + + clientProcess.MemoryManager.UnborrowIpcBuffer( + request.CustomCmdBuffAddr, + request.CustomCmdBuffSize); + + request.AsyncEvent.Signal(); + } + + private void WakeServerThreads(KernelResult result) + { + //Wake all server threads waiting for requests. + System.CriticalSection.Enter(); + + foreach (KThread thread in WaitingThreads) + { + WakeAndSetResult(thread, result); + } + + System.CriticalSection.Leave(); + } + + private void WakeAndSetResult(KThread thread, KernelResult result) + { + if ((thread.SchedFlags & ThreadSchedState.LowMask) == ThreadSchedState.Paused) + { + thread.SignaledObj = null; + thread.ObjSyncResult = result; + + thread.Reschedule(ThreadSchedState.Running); + } + } } } \ No newline at end of file diff --git a/Ryujinx.HLE/HOS/Kernel/Ipc/KSession.cs b/Ryujinx.HLE/HOS/Kernel/Ipc/KSession.cs index 0759aafa71..cbf689a57c 100644 --- a/Ryujinx.HLE/HOS/Kernel/Ipc/KSession.cs +++ b/Ryujinx.HLE/HOS/Kernel/Ipc/KSession.cs @@ -15,6 +15,8 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc { ServerSession = new KServerSession(system, this); ClientSession = new KClientSession(system, this); + + _hasBeenInitialized = true; } public void DisconnectClient() @@ -23,7 +25,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc { ClientSession.State = ChannelState.ClientDisconnected; - //TODO: Wake up client, etc. + ServerSession.CancelAllRequestsClientDisconnected(); } } @@ -55,6 +57,8 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc KProcess creatorProcess = ClientSession.CreatorProcess; creatorProcess.ResourceLimit?.Release(LimitableResource.Session, 1); + + creatorProcess.DecrementReferenceCount(); } } } diff --git a/Ryujinx.HLE/HOS/Kernel/SupervisorCall/SvcIpc.cs b/Ryujinx.HLE/HOS/Kernel/SupervisorCall/SvcIpc.cs index aa463ca0e1..d8c8d3d52f 100644 --- a/Ryujinx.HLE/HOS/Kernel/SupervisorCall/SvcIpc.cs +++ b/Ryujinx.HLE/HOS/Kernel/SupervisorCall/SvcIpc.cs @@ -65,7 +65,7 @@ namespace Ryujinx.HLE.HOS.Kernel.SupervisorCall return result; } - result = clientPort.Connect(out KClientSession session); + result = clientPort.Connect(out KClientSession clientSession); if (result != KernelResult.Success) { @@ -74,7 +74,9 @@ namespace Ryujinx.HLE.HOS.Kernel.SupervisorCall return result; } - currentProcess.HandleTable.SetReservedHandleObj(handle, session); + currentProcess.HandleTable.SetReservedHandleObj(handle, clientSession); + + clientSession.DecrementReferenceCount(); return result; } diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KScheduler.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KScheduler.cs index 60e15efa08..c9686df363 100644 --- a/Ryujinx.HLE/HOS/Kernel/Threading/KScheduler.cs +++ b/Ryujinx.HLE/HOS/Kernel/Threading/KScheduler.cs @@ -210,9 +210,29 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading } } + return GetDummyThread(); + throw new InvalidOperationException("Current thread is not scheduled!"); } + private KThread _dummyThread; + + private KThread GetDummyThread() + { + if (_dummyThread != null) + { + return _dummyThread; + } + + KProcess dummyProcess = new KProcess(_system); + + KThread dummyThread = new KThread(_system); + + dummyThread.Initialize(0, 0, 0, 44, 0, dummyProcess, ThreadType.Dummy); + + return _dummyThread = dummyThread; + } + public KProcess GetCurrentProcess() { return GetCurrentThread().Owner;