Fix send static buffer copy

This commit is contained in:
gdkchan 2018-12-30 23:52:48 -03:00
commit 4f89f148b4

View file

@ -44,6 +44,11 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
Size = 0x100; Size = 0x100;
} }
} }
public Message(KSessionRequest request) : this(
request.ClientThread,
request.CustomCmdBuffAddr,
request.CustomCmdBuffSize) { }
} }
private struct MessageHeader private struct MessageHeader
@ -244,14 +249,11 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
request.ServerProcess = serverProcess; request.ServerProcess = serverProcess;
Message clientMsg = new Message( Message clientMsg = new Message(request);
clientThread,
request.CustomCmdBuffAddr,
request.CustomCmdBuffSize);
Message serverMsg = new Message(serverThread, customCmdBuffAddr, customCmdBuffSize); Message serverMsg = new Message(serverThread, customCmdBuffAddr, customCmdBuffSize);
MessageHeader header = GetClientMessageHeader(clientMsg); MessageHeader clientHeader = GetClientMessageHeader(clientMsg);
MessageHeader serverHeader = GetServerMessageHeader(serverMsg);
KernelResult serverResult = KernelResult.NotFound; KernelResult serverResult = KernelResult.NotFound;
KernelResult clientResult = KernelResult.Success; KernelResult clientResult = KernelResult.Success;
@ -263,7 +265,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
request.BufferDescriptorTable.RestoreClientBuffers(clientProcess.MemoryManager); request.BufferDescriptorTable.RestoreClientBuffers(clientProcess.MemoryManager);
} }
CloseAllHandles(serverMsg, header, serverProcess); CloseAllHandles(serverMsg, clientHeader, serverProcess);
System.CriticalSection.Enter(); System.CriticalSection.Enter();
@ -279,71 +281,74 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
WakeClient(request, clientResult); WakeClient(request, clientResult);
} }
if (header.ReceiveListType < 2 && if (clientHeader.ReceiveListType < 2 &&
header.ReceiveListOffset > clientMsg.Size) clientHeader.ReceiveListOffset > clientMsg.Size)
{ {
CleanUpForError(); CleanUpForError();
return KernelResult.InvalidCombination; return KernelResult.InvalidCombination;
} }
else if (header.ReceiveListType == 2 && else if (clientHeader.ReceiveListType == 2 &&
header.ReceiveListOffset + 8 > clientMsg.Size) clientHeader.ReceiveListOffset + 8 > clientMsg.Size)
{ {
CleanUpForError(); CleanUpForError();
return KernelResult.InvalidCombination; return KernelResult.InvalidCombination;
} }
else if (header.ReceiveListType > 2 && else if (clientHeader.ReceiveListType > 2 &&
header.ReceiveListType * 8 - 0x10 + header.ReceiveListOffset > clientMsg.Size) clientHeader.ReceiveListType * 8 - 0x10 + clientHeader.ReceiveListOffset > clientMsg.Size)
{ {
CleanUpForError(); CleanUpForError();
return KernelResult.InvalidCombination; return KernelResult.InvalidCombination;
} }
if (header.ReceiveListOffsetInWords < header.MessageSizeInWords) if (clientHeader.ReceiveListOffsetInWords < clientHeader.MessageSizeInWords)
{ {
CleanUpForError(); CleanUpForError();
return KernelResult.InvalidCombination; return KernelResult.InvalidCombination;
} }
if (header.MessageSizeInWords * 4 > clientMsg.Size) if (clientHeader.MessageSizeInWords * 4 > clientMsg.Size)
{ {
CleanUpForError(); CleanUpForError();
return KernelResult.CmdBufferTooSmall; return KernelResult.CmdBufferTooSmall;
} }
ulong[] receiveList = GetReceiveList(clientMsg, header.ReceiveListType, header.ReceiveListOffset); ulong[] receiveList = GetReceiveList(
serverMsg,
serverHeader.ReceiveListType,
serverHeader.ReceiveListOffset);
serverProcess.CpuMemory.WriteUInt32((long)serverMsg.Address + 0, header.Word0); serverProcess.CpuMemory.WriteUInt32((long)serverMsg.Address + 0, clientHeader.Word0);
serverProcess.CpuMemory.WriteUInt32((long)serverMsg.Address + 4, header.Word1); serverProcess.CpuMemory.WriteUInt32((long)serverMsg.Address + 4, clientHeader.Word1);
uint offset; uint offset;
//Copy handles. //Copy handles.
if (header.HasHandles) if (clientHeader.HasHandles)
{ {
if (header.MoveHandlesCount != 0) if (clientHeader.MoveHandlesCount != 0)
{ {
CleanUpForError(); CleanUpForError();
return KernelResult.InvalidCombination; return KernelResult.InvalidCombination;
} }
serverProcess.CpuMemory.WriteUInt32((long)serverMsg.Address + 8, header.Word2); serverProcess.CpuMemory.WriteUInt32((long)serverMsg.Address + 8, clientHeader.Word2);
offset = 3; offset = 3;
if (header.HasPid) if (clientHeader.HasPid)
{ {
serverProcess.CpuMemory.WriteInt64((long)serverMsg.Address + offset * 4, clientProcess.Pid); serverProcess.CpuMemory.WriteInt64((long)serverMsg.Address + offset * 4, clientProcess.Pid);
offset += 2; offset += 2;
} }
for (int index = 0; index < header.CopyHandlesCount; index++) for (int index = 0; index < clientHeader.CopyHandlesCount; index++)
{ {
int newHandle = 0; int newHandle = 0;
@ -359,7 +364,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
offset++; offset++;
} }
for (int index = 0; index < header.MoveHandlesCount; index++) for (int index = 0; index < clientHeader.MoveHandlesCount; index++)
{ {
int newHandle = 0; int newHandle = 0;
@ -395,7 +400,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
} }
//Copy pointer/receive list buffers. //Copy pointer/receive list buffers.
for (int index = 0; index < header.PointerBuffersCount; index++) uint recvListDstOffset = 0;
for (int index = 0; index < clientHeader.PointerBuffersCount; index++)
{ {
ulong pointerDesc = System.Device.Memory.ReadUInt64((long)clientMsg.DramAddress + offset * 4); ulong pointerDesc = System.Device.Memory.ReadUInt64((long)clientMsg.DramAddress + offset * 4);
@ -406,9 +413,10 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
clientResult = GetReceiveListAddress( clientResult = GetReceiveListAddress(
descriptor, descriptor,
serverMsg, serverMsg,
header.ReceiveListType, serverHeader.ReceiveListType,
header.MessageSizeInWords, clientHeader.MessageSizeInWords,
receiveList, receiveList,
ref recvListDstOffset,
out ulong recvListBufferAddress); out ulong recvListBufferAddress);
if (clientResult != KernelResult.Success) if (clientResult != KernelResult.Success)
@ -449,9 +457,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
//Copy send, receive and exchange buffers. //Copy send, receive and exchange buffers.
uint totalBuffersCount = uint totalBuffersCount =
header.SendBuffersCount + clientHeader.SendBuffersCount +
header.ReceiveBuffersCount + clientHeader.ReceiveBuffersCount +
header.ExchangeBuffersCount; clientHeader.ExchangeBuffersCount;
for (int index = 0; index < totalBuffersCount; index++) for (int index = 0; index < totalBuffersCount; index++)
{ {
@ -461,13 +469,13 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
uint descWord1 = System.Device.Memory.ReadUInt32(clientDescAddress + 4); uint descWord1 = System.Device.Memory.ReadUInt32(clientDescAddress + 4);
uint descWord2 = System.Device.Memory.ReadUInt32(clientDescAddress + 8); uint descWord2 = System.Device.Memory.ReadUInt32(clientDescAddress + 8);
bool isSendDesc = index < header.SendBuffersCount; bool isSendDesc = index < clientHeader.SendBuffersCount;
bool isExchangeDesc = index >= header.SendBuffersCount + header.ReceiveBuffersCount; bool isExchangeDesc = index >= clientHeader.SendBuffersCount + clientHeader.ReceiveBuffersCount;
bool notReceiveDesc = isSendDesc || isExchangeDesc; bool notReceiveDesc = isSendDesc || isExchangeDesc;
bool isReceiveDesc = !notReceiveDesc; bool isReceiveDesc = !notReceiveDesc;
MemoryPermission permission = index >= header.SendBuffersCount MemoryPermission permission = index >= clientHeader.SendBuffersCount
? MemoryPermission.ReadAndWrite ? MemoryPermission.ReadAndWrite
: MemoryPermission.Read; : MemoryPermission.Read;
@ -544,12 +552,12 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
} }
//Copy raw data. //Copy raw data.
if (header.RawDataSizeInWords != 0) if (clientHeader.RawDataSizeInWords != 0)
{ {
ulong copySrc = clientMsg.Address + offset * 4; ulong copySrc = clientMsg.Address + offset * 4;
ulong copyDst = serverMsg.Address + offset * 4; ulong copyDst = serverMsg.Address + offset * 4;
ulong copySize = header.RawDataSizeInWords * 4; ulong copySize = clientHeader.RawDataSizeInWords * 4;
if (serverMsg.IsCustom || clientMsg.IsCustom) if (serverMsg.IsCustom || clientMsg.IsCustom)
{ {
@ -614,27 +622,18 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
KThread clientThread = request.ClientThread; KThread clientThread = request.ClientThread;
KProcess clientProcess = clientThread.Owner; KProcess clientProcess = clientThread.Owner;
Message clientMsg = new Message( Message clientMsg = new Message(request);
clientThread,
request.CustomCmdBuffAddr,
request.CustomCmdBuffSize);
Message serverMsg = new Message(serverThread, customCmdBuffAddr, customCmdBuffSize); Message serverMsg = new Message(serverThread, customCmdBuffAddr, customCmdBuffSize);
uint word0 = serverProcess.CpuMemory.ReadUInt32((long)serverMsg.Address + 0);
uint word1 = serverProcess.CpuMemory.ReadUInt32((long)serverMsg.Address + 4);
uint word2 = serverProcess.CpuMemory.ReadUInt32((long)serverMsg.Address + 8);
MessageHeader header = new MessageHeader(word0, word1, word2);
MessageHeader clientHeader = GetClientMessageHeader(clientMsg); MessageHeader clientHeader = GetClientMessageHeader(clientMsg);
MessageHeader serverHeader = GetServerMessageHeader(serverMsg);
KernelResult clientResult = KernelResult.Success; KernelResult clientResult = KernelResult.Success;
KernelResult serverResult = KernelResult.Success; KernelResult serverResult = KernelResult.Success;
void CleanUpForError() void CleanUpForError()
{ {
CloseAllHandles(clientMsg, header, clientProcess); CloseAllHandles(clientMsg, serverHeader, clientProcess);
CancelRequest(request, clientResult); CancelRequest(request, clientResult);
} }
@ -668,16 +667,16 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
return KernelResult.InvalidCombination; return KernelResult.InvalidCombination;
} }
if (header.MessageSizeInWords * 4 > clientMsg.Size) if (serverHeader.MessageSizeInWords * 4 > clientMsg.Size)
{ {
CleanUpForError(); CleanUpForError();
return KernelResult.CmdBufferTooSmall; return KernelResult.CmdBufferTooSmall;
} }
if (header.SendBuffersCount != 0 || if (serverHeader.SendBuffersCount != 0 ||
header.ReceiveBuffersCount != 0 || serverHeader.ReceiveBuffersCount != 0 ||
header.ExchangeBuffersCount != 0) serverHeader.ExchangeBuffersCount != 0)
{ {
CleanUpForError(); CleanUpForError();
@ -701,26 +700,26 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
} }
//Copy header. //Copy header.
System.Device.Memory.WriteUInt32((long)clientMsg.DramAddress + 0, word0); System.Device.Memory.WriteUInt32((long)clientMsg.DramAddress + 0, serverHeader.Word0);
System.Device.Memory.WriteUInt32((long)clientMsg.DramAddress + 4, word1); System.Device.Memory.WriteUInt32((long)clientMsg.DramAddress + 4, serverHeader.Word1);
//Copy handles. //Copy handles.
uint offset; uint offset;
if (header.HasHandles) if (serverHeader.HasHandles)
{ {
offset = 3; offset = 3;
System.Device.Memory.WriteUInt32((long)clientMsg.DramAddress + 8, word2); System.Device.Memory.WriteUInt32((long)clientMsg.DramAddress + 8, serverHeader.Word2);
if (header.HasPid) if (serverHeader.HasPid)
{ {
System.Device.Memory.WriteInt64((long)clientMsg.DramAddress + offset * 4, serverProcess.Pid); System.Device.Memory.WriteInt64((long)clientMsg.DramAddress + offset * 4, serverProcess.Pid);
offset += 2; offset += 2;
} }
for (int index = 0; index < header.CopyHandlesCount; index++) for (int index = 0; index < serverHeader.CopyHandlesCount; index++)
{ {
int newHandle = 0; int newHandle = 0;
@ -736,7 +735,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
offset++; offset++;
} }
for (int index = 0; index < header.MoveHandlesCount; index++) for (int index = 0; index < serverHeader.MoveHandlesCount; index++)
{ {
int newHandle = 0; int newHandle = 0;
@ -765,7 +764,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
} }
//Copy pointer/receive list buffers. //Copy pointer/receive list buffers.
for (int index = 0; index < header.PointerBuffersCount; index++) uint recvListDstOffset = 0;
for (int index = 0; index < serverHeader.PointerBuffersCount; index++)
{ {
ulong pointerDesc = serverProcess.CpuMemory.ReadUInt64((long)serverMsg.Address + offset * 4); ulong pointerDesc = serverProcess.CpuMemory.ReadUInt64((long)serverMsg.Address + offset * 4);
@ -777,8 +778,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
descriptor, descriptor,
clientMsg, clientMsg,
clientHeader.ReceiveListType, clientHeader.ReceiveListType,
header.MessageSizeInWords, serverHeader.MessageSizeInWords,
receiveList, receiveList,
ref recvListDstOffset,
out ulong recvListBufferAddress); out ulong recvListBufferAddress);
if (clientResult != KernelResult.Success) if (clientResult != KernelResult.Success)
@ -811,9 +813,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
//Set send, receive and exchange buffer descriptors to zero. //Set send, receive and exchange buffer descriptors to zero.
uint totalBuffersCount = uint totalBuffersCount =
header.SendBuffersCount + serverHeader.SendBuffersCount +
header.ReceiveBuffersCount + serverHeader.ReceiveBuffersCount +
header.ExchangeBuffersCount; serverHeader.ExchangeBuffersCount;
for (int index = 0; index < totalBuffersCount; index++) for (int index = 0; index < totalBuffersCount; index++)
{ {
@ -827,12 +829,12 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
} }
//Copy raw data. //Copy raw data.
if (header.RawDataSizeInWords != 0) if (serverHeader.RawDataSizeInWords != 0)
{ {
ulong copyDst = clientMsg.Address + offset * 4; ulong copyDst = clientMsg.Address + offset * 4;
ulong copySrc = serverMsg.Address + offset * 4; ulong copySrc = serverMsg.Address + offset * 4;
ulong copySize = header.RawDataSizeInWords * 4; ulong copySize = serverHeader.RawDataSizeInWords * 4;
if (serverMsg.IsCustom || clientMsg.IsCustom) if (serverMsg.IsCustom || clientMsg.IsCustom)
{ {
@ -883,6 +885,17 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
return new MessageHeader(word0, word1, word2); return new MessageHeader(word0, word1, word2);
} }
private MessageHeader GetServerMessageHeader(Message serverMsg)
{
KProcess currentProcess = System.Scheduler.GetCurrentProcess();
uint word0 = currentProcess.CpuMemory.ReadUInt32((long)serverMsg.Address + 0);
uint word1 = currentProcess.CpuMemory.ReadUInt32((long)serverMsg.Address + 4);
uint word2 = currentProcess.CpuMemory.ReadUInt32((long)serverMsg.Address + 8);
return new MessageHeader(word0, word1, word2);
}
private KernelResult GetCopyObjectHandle( private KernelResult GetCopyObjectHandle(
KThread srcThread, KThread srcThread,
KProcess dstProcess, KProcess dstProcess,
@ -973,6 +986,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
uint recvListType, uint recvListType,
uint messageSizeInWords, uint messageSizeInWords,
ulong[] receiveList, ulong[] receiveList,
ref uint dstOffset,
out ulong address) out ulong address)
{ {
ulong recvListBufferAddress = address = 0; ulong recvListBufferAddress = address = 0;
@ -1007,7 +1021,11 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
recvListEndAddr = recvListBaseAddr + size; recvListEndAddr = recvListBaseAddr + size;
} }
recvListBufferAddress = BitUtils.AlignUp(recvListBaseAddr + descriptor.ReceiveIndex, 0x10); recvListBufferAddress = BitUtils.AlignUp(recvListBaseAddr + dstOffset, 0x10);
ulong endAddress = recvListBufferAddress + descriptor.BufferSize;
dstOffset = (uint)endAddress - (uint)recvListBaseAddr;
if (recvListBufferAddress + descriptor.BufferSize <= recvListBufferAddress || if (recvListBufferAddress + descriptor.BufferSize <= recvListBufferAddress ||
recvListBufferAddress + descriptor.BufferSize > recvListEndAddr) recvListBufferAddress + descriptor.BufferSize > recvListEndAddr)