Refactor code for reading the body of a request

This commit is contained in:
Slendy 2023-04-06 22:23:49 -05:00
parent 2372dfbb9e
commit 4150e44f80
No known key found for this signature in database
GPG key ID: 7288D68361B91428
6 changed files with 45 additions and 69 deletions

View file

@ -30,9 +30,7 @@ public class LoginController : ControllerBase
[HttpPost]
public async Task<IActionResult> Login()
{
MemoryStream ms = new();
await this.Request.Body.CopyToAsync(ms);
byte[] loginData = ms.ToArray();
byte[] loginData = await this.Request.BodyReader.ReadAllAsync();
NPTicket? npTicket;
try

View file

@ -1,5 +1,6 @@
using System.IO.Compression;
using LBPUnion.ProjectLighthouse.Configuration;
using LBPUnion.ProjectLighthouse.Extensions;
using LBPUnion.ProjectLighthouse.Helpers;
using LBPUnion.ProjectLighthouse.Middlewares;
using Microsoft.Extensions.Primitives;
@ -39,7 +40,7 @@ public class DigestMiddleware : Middleware
const string url = "/LITTLEBIGPLANETPS3_XML";
string strippedPath = digestPath.Contains(url) ? digestPath[url.Length..] : "";
#endif
Stream body = context.Request.Body;
byte[] bodyBytes = await context.Request.BodyReader.ReadAllAsync();
bool usedAlternateDigestKey = false;
@ -54,8 +55,8 @@ public class DigestMiddleware : Middleware
excludeBodyFromDigest = true;
}
string clientRequestDigest = await CryptoHelper.ComputeDigest
(digestPath, authCookie, body, ServerConfiguration.Instance.DigestKey.PrimaryDigestKey, excludeBodyFromDigest);
string clientRequestDigest = CryptoHelper.ComputeDigest
(digestPath, authCookie, bodyBytes, ServerConfiguration.Instance.DigestKey.PrimaryDigestKey, excludeBodyFromDigest);
// Check the digest we've just calculated against the digest header if the game set the header. They should match.
if (context.Request.Headers.TryGetValue(digestHeaderKey, out StringValues sentDigest))
@ -65,11 +66,8 @@ public class DigestMiddleware : Middleware
// If we got here, the normal ServerDigestKey failed to validate. Lets try again with the alternate digest key.
usedAlternateDigestKey = true;
// Reset the body stream
body.Position = 0;
clientRequestDigest = await CryptoHelper.ComputeDigest
(digestPath, authCookie, body, ServerConfiguration.Instance.DigestKey.AlternateDigestKey, excludeBodyFromDigest);
clientRequestDigest = CryptoHelper.ComputeDigest
(digestPath, authCookie, bodyBytes, ServerConfiguration.Instance.DigestKey.AlternateDigestKey, excludeBodyFromDigest);
if (clientRequestDigest != sentDigest)
{
#if DEBUG
@ -116,7 +114,7 @@ public class DigestMiddleware : Middleware
: ServerConfiguration.Instance.DigestKey.PrimaryDigestKey;
// Compute the digest for the response.
string serverDigest = await CryptoHelper.ComputeDigest(context.Request.Path, authCookie, responseBuffer, digestKey);
string serverDigest = CryptoHelper.ComputeDigest(context.Request.Path, authCookie, responseBuffer.ToArray(), digestKey);
context.Response.Headers.Add("X-Digest-A", serverDigest);
}

View file

@ -1,8 +1,6 @@
#nullable enable
using System;
using System.Buffers;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
@ -28,47 +26,16 @@ public static partial class ControllerExtensions
return token;
}
private static void AddStringToBuilder(StringBuilder builder, in ReadOnlySequence<byte> readOnlySequence)
{
// Separate method because Span/ReadOnlySpan cannot be used in async methods
ReadOnlySpan<byte> span = readOnlySequence.IsSingleSegment
? readOnlySequence.First.Span
: readOnlySequence.ToArray();
builder.Append(Encoding.UTF8.GetString(span));
}
public static async Task<string> ReadBodyAsync(this ControllerBase controller)
{
controller.Request.Body.Position = 0;
StringBuilder builder = new();
while (true)
{
ReadResult readResult = await controller.Request.BodyReader.ReadAsync();
ReadOnlySequence<byte> buffer = readResult.Buffer;
if (buffer.Length > 0)
{
AddStringToBuilder(builder, buffer);
}
controller.Request.BodyReader.AdvanceTo(buffer.End);
if (readResult.IsCompleted)
{
break;
}
}
string finalString = builder.ToString();
if (finalString.Length != controller.Request.ContentLength)
byte[] bodyBytes = await controller.Request.BodyReader.ReadAllAsync();
if (bodyBytes.Length != controller.Request.ContentLength)
{
Logger.Warn($"Failed to read entire body, contentType={controller.Request.ContentType}, " +
$"contentLen={controller.Request.ContentLength}, readLen={finalString.Length}",
$"contentLen={controller.Request.ContentLength}, readLen={bodyBytes.Length}",
LogArea.HTTP);
}
return builder.ToString();
return Encoding.ASCII.GetString(bodyBytes);
}
[GeneratedRegex("&(?!(amp|apos|quot|lt|gt);)")]

View file

@ -0,0 +1,26 @@
using System.Buffers;
using System.IO.Pipelines;
using System.Threading.Tasks;
namespace LBPUnion.ProjectLighthouse.Extensions;
public static class PipeExtensions
{
public static async Task<byte[]> ReadAllAsync(this PipeReader reader)
{
do
{
ReadResult readResult = await reader.ReadAsync();
if (readResult.IsCompleted || readResult.IsCanceled)
{
return readResult.Buffer.ToArray();
}
// consume nothing, keep reading from the pipe reader until all data is there
reader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
}
while (true);
}
}

View file

@ -1,20 +1,15 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using LBPUnion.ProjectLighthouse.Extensions;
namespace LBPUnion.ProjectLighthouse.Helpers;
[SuppressMessage("ReSharper", "UnusedMember.Global")]
public static class CryptoHelper
{
// private static readonly SHA1 sha1 = SHA1.Create();
private static readonly SHA256 sha256 = SHA256.Create();
/// <summary>
@ -28,23 +23,18 @@ public static class CryptoHelper
return BCryptHash(Sha256Hash(bytes));
}
public static async Task<string> ComputeDigest(string path, string authCookie, Stream body, string digestKey, bool excludeBody = false)
public static string ComputeDigest(string path, string authCookie, byte[] body, string digestKey, bool excludeBody = false)
{
MemoryStream memoryStream = new();
byte[] pathBytes = Encoding.UTF8.GetBytes(path);
byte[] cookieBytes = string.IsNullOrEmpty(authCookie) ? Array.Empty<byte>() : Encoding.UTF8.GetBytes(authCookie);
byte[] keyBytes = Encoding.UTF8.GetBytes(digestKey);
await body.CopyToAsync(memoryStream);
byte[] bodyBytes = memoryStream.ToArray();
using IncrementalHash sha1 = IncrementalHash.CreateHash(HashAlgorithmName.SHA1);
// LBP games will sometimes opt to calculate the digest without the body
// (one example is resource upload requests)
if (!excludeBody)
sha1.AppendData(bodyBytes);
sha1.AppendData(body);
if (cookieBytes.Length > 0) sha1.AppendData(cookieBytes);
sha1.AppendData(pathBytes);
sha1.AppendData(keyBytes);
@ -162,13 +152,9 @@ public static class CryptoHelper
public static string Sha256Hash(byte[] bytes) => BitConverter.ToString(sha256.ComputeHash(bytes)).Replace("-", "").ToLower();
public static string Sha1Hash(string str) => Sha1Hash(Encoding.UTF8.GetBytes(str));
public static string Sha1Hash(byte[] bytes) => BitConverter.ToString(SHA1.HashData(bytes)).Replace("-", "");
public static string BCryptHash(string str) => BCrypt.Net.BCrypt.HashPassword(str);
public static string BCryptHash(byte[] bytes) => BCrypt.Net.BCrypt.HashPassword(Encoding.UTF8.GetString(bytes));
#endregion
}

View file

@ -1,6 +1,7 @@
using System.Diagnostics;
using System.IO;
using System.Text;
using System.Threading.Tasks;
using LBPUnion.ProjectLighthouse.Extensions;
using LBPUnion.ProjectLighthouse.Logging;
using LBPUnion.ProjectLighthouse.Types.Logging;
using Microsoft.AspNetCore.Http;
@ -23,7 +24,7 @@ public class RequestLogMiddleware : Middleware
ctx.Request.EnableBuffering(); // Allows us to reset the position of Request.Body for later logging
// Log all headers.
// foreach (KeyValuePair<string, StringValues> header in context.Request.Headers) Logger.Log($"{header.Key}: {header.Value}");
// foreach (KeyValuePair<string, StringValues> header in ctx.Request.Headers) Logger.Debug($"{header.Key}: {header.Value}", LogArea.HTTP);
await this.next(ctx); // Handle the request so we can get the status code from it
@ -39,8 +40,8 @@ public class RequestLogMiddleware : Middleware
// Log post body
if (ctx.Request.Method == "POST")
{
ctx.Request.Body.Position = 0;
Logger.Debug(await new StreamReader(ctx.Request.Body).ReadToEndAsync(), LogArea.HTTP);
string body = Encoding.ASCII.GetString(await ctx.Request.BodyReader.ReadAllAsync());
Logger.Debug(body, LogArea.HTTP);
}
#endif
}