Refactor code for reading the body of a request

This commit is contained in:
Slendy 2023-04-06 22:23:49 -05:00
parent 2372dfbb9e
commit 4150e44f80
No known key found for this signature in database
GPG key ID: 7288D68361B91428
6 changed files with 45 additions and 69 deletions

View file

@ -30,9 +30,7 @@ public class LoginController : ControllerBase
[HttpPost] [HttpPost]
public async Task<IActionResult> Login() public async Task<IActionResult> Login()
{ {
MemoryStream ms = new(); byte[] loginData = await this.Request.BodyReader.ReadAllAsync();
await this.Request.Body.CopyToAsync(ms);
byte[] loginData = ms.ToArray();
NPTicket? npTicket; NPTicket? npTicket;
try try

View file

@ -1,5 +1,6 @@
using System.IO.Compression; using System.IO.Compression;
using LBPUnion.ProjectLighthouse.Configuration; using LBPUnion.ProjectLighthouse.Configuration;
using LBPUnion.ProjectLighthouse.Extensions;
using LBPUnion.ProjectLighthouse.Helpers; using LBPUnion.ProjectLighthouse.Helpers;
using LBPUnion.ProjectLighthouse.Middlewares; using LBPUnion.ProjectLighthouse.Middlewares;
using Microsoft.Extensions.Primitives; using Microsoft.Extensions.Primitives;
@ -39,7 +40,7 @@ public class DigestMiddleware : Middleware
const string url = "/LITTLEBIGPLANETPS3_XML"; const string url = "/LITTLEBIGPLANETPS3_XML";
string strippedPath = digestPath.Contains(url) ? digestPath[url.Length..] : ""; string strippedPath = digestPath.Contains(url) ? digestPath[url.Length..] : "";
#endif #endif
Stream body = context.Request.Body; byte[] bodyBytes = await context.Request.BodyReader.ReadAllAsync();
bool usedAlternateDigestKey = false; bool usedAlternateDigestKey = false;
@ -54,8 +55,8 @@ public class DigestMiddleware : Middleware
excludeBodyFromDigest = true; excludeBodyFromDigest = true;
} }
string clientRequestDigest = await CryptoHelper.ComputeDigest string clientRequestDigest = CryptoHelper.ComputeDigest
(digestPath, authCookie, body, ServerConfiguration.Instance.DigestKey.PrimaryDigestKey, excludeBodyFromDigest); (digestPath, authCookie, bodyBytes, ServerConfiguration.Instance.DigestKey.PrimaryDigestKey, excludeBodyFromDigest);
// Check the digest we've just calculated against the digest header if the game set the header. They should match. // Check the digest we've just calculated against the digest header if the game set the header. They should match.
if (context.Request.Headers.TryGetValue(digestHeaderKey, out StringValues sentDigest)) if (context.Request.Headers.TryGetValue(digestHeaderKey, out StringValues sentDigest))
@ -65,11 +66,8 @@ public class DigestMiddleware : Middleware
// If we got here, the normal ServerDigestKey failed to validate. Lets try again with the alternate digest key. // If we got here, the normal ServerDigestKey failed to validate. Lets try again with the alternate digest key.
usedAlternateDigestKey = true; usedAlternateDigestKey = true;
// Reset the body stream clientRequestDigest = CryptoHelper.ComputeDigest
body.Position = 0; (digestPath, authCookie, bodyBytes, ServerConfiguration.Instance.DigestKey.AlternateDigestKey, excludeBodyFromDigest);
clientRequestDigest = await CryptoHelper.ComputeDigest
(digestPath, authCookie, body, ServerConfiguration.Instance.DigestKey.AlternateDigestKey, excludeBodyFromDigest);
if (clientRequestDigest != sentDigest) if (clientRequestDigest != sentDigest)
{ {
#if DEBUG #if DEBUG
@ -116,7 +114,7 @@ public class DigestMiddleware : Middleware
: ServerConfiguration.Instance.DigestKey.PrimaryDigestKey; : ServerConfiguration.Instance.DigestKey.PrimaryDigestKey;
// Compute the digest for the response. // Compute the digest for the response.
string serverDigest = await CryptoHelper.ComputeDigest(context.Request.Path, authCookie, responseBuffer, digestKey); string serverDigest = CryptoHelper.ComputeDigest(context.Request.Path, authCookie, responseBuffer.ToArray(), digestKey);
context.Response.Headers.Add("X-Digest-A", serverDigest); context.Response.Headers.Add("X-Digest-A", serverDigest);
} }

View file

@ -1,8 +1,6 @@
#nullable enable #nullable enable
using System; using System;
using System.Buffers;
using System.IO; using System.IO;
using System.IO.Pipelines;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
@ -28,47 +26,16 @@ public static partial class ControllerExtensions
return token; return token;
} }
private static void AddStringToBuilder(StringBuilder builder, in ReadOnlySequence<byte> readOnlySequence)
{
// Separate method because Span/ReadOnlySpan cannot be used in async methods
ReadOnlySpan<byte> span = readOnlySequence.IsSingleSegment
? readOnlySequence.First.Span
: readOnlySequence.ToArray();
builder.Append(Encoding.UTF8.GetString(span));
}
public static async Task<string> ReadBodyAsync(this ControllerBase controller) public static async Task<string> ReadBodyAsync(this ControllerBase controller)
{ {
controller.Request.Body.Position = 0; byte[] bodyBytes = await controller.Request.BodyReader.ReadAllAsync();
StringBuilder builder = new(); if (bodyBytes.Length != controller.Request.ContentLength)
while (true)
{
ReadResult readResult = await controller.Request.BodyReader.ReadAsync();
ReadOnlySequence<byte> buffer = readResult.Buffer;
if (buffer.Length > 0)
{
AddStringToBuilder(builder, buffer);
}
controller.Request.BodyReader.AdvanceTo(buffer.End);
if (readResult.IsCompleted)
{
break;
}
}
string finalString = builder.ToString();
if (finalString.Length != controller.Request.ContentLength)
{ {
Logger.Warn($"Failed to read entire body, contentType={controller.Request.ContentType}, " + Logger.Warn($"Failed to read entire body, contentType={controller.Request.ContentType}, " +
$"contentLen={controller.Request.ContentLength}, readLen={finalString.Length}", $"contentLen={controller.Request.ContentLength}, readLen={bodyBytes.Length}",
LogArea.HTTP); LogArea.HTTP);
} }
return Encoding.ASCII.GetString(bodyBytes);
return builder.ToString();
} }
[GeneratedRegex("&(?!(amp|apos|quot|lt|gt);)")] [GeneratedRegex("&(?!(amp|apos|quot|lt|gt);)")]

View file

@ -0,0 +1,26 @@
using System.Buffers;
using System.IO.Pipelines;
using System.Threading.Tasks;
namespace LBPUnion.ProjectLighthouse.Extensions;
public static class PipeExtensions
{
public static async Task<byte[]> ReadAllAsync(this PipeReader reader)
{
do
{
ReadResult readResult = await reader.ReadAsync();
if (readResult.IsCompleted || readResult.IsCanceled)
{
return readResult.Buffer.ToArray();
}
// consume nothing, keep reading from the pipe reader until all data is there
reader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
}
while (true);
}
}

View file

@ -1,20 +1,15 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq; using System.Linq;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Threading.Tasks;
using LBPUnion.ProjectLighthouse.Extensions; using LBPUnion.ProjectLighthouse.Extensions;
namespace LBPUnion.ProjectLighthouse.Helpers; namespace LBPUnion.ProjectLighthouse.Helpers;
[SuppressMessage("ReSharper", "UnusedMember.Global")]
public static class CryptoHelper public static class CryptoHelper
{ {
// private static readonly SHA1 sha1 = SHA1.Create();
private static readonly SHA256 sha256 = SHA256.Create(); private static readonly SHA256 sha256 = SHA256.Create();
/// <summary> /// <summary>
@ -28,23 +23,18 @@ public static class CryptoHelper
return BCryptHash(Sha256Hash(bytes)); return BCryptHash(Sha256Hash(bytes));
} }
public static async Task<string> ComputeDigest(string path, string authCookie, Stream body, string digestKey, bool excludeBody = false) public static string ComputeDigest(string path, string authCookie, byte[] body, string digestKey, bool excludeBody = false)
{ {
MemoryStream memoryStream = new();
byte[] pathBytes = Encoding.UTF8.GetBytes(path); byte[] pathBytes = Encoding.UTF8.GetBytes(path);
byte[] cookieBytes = string.IsNullOrEmpty(authCookie) ? Array.Empty<byte>() : Encoding.UTF8.GetBytes(authCookie); byte[] cookieBytes = string.IsNullOrEmpty(authCookie) ? Array.Empty<byte>() : Encoding.UTF8.GetBytes(authCookie);
byte[] keyBytes = Encoding.UTF8.GetBytes(digestKey); byte[] keyBytes = Encoding.UTF8.GetBytes(digestKey);
await body.CopyToAsync(memoryStream);
byte[] bodyBytes = memoryStream.ToArray();
using IncrementalHash sha1 = IncrementalHash.CreateHash(HashAlgorithmName.SHA1); using IncrementalHash sha1 = IncrementalHash.CreateHash(HashAlgorithmName.SHA1);
// LBP games will sometimes opt to calculate the digest without the body // LBP games will sometimes opt to calculate the digest without the body
// (one example is resource upload requests) // (one example is resource upload requests)
if (!excludeBody) if (!excludeBody)
sha1.AppendData(bodyBytes); sha1.AppendData(body);
if (cookieBytes.Length > 0) sha1.AppendData(cookieBytes); if (cookieBytes.Length > 0) sha1.AppendData(cookieBytes);
sha1.AppendData(pathBytes); sha1.AppendData(pathBytes);
sha1.AppendData(keyBytes); sha1.AppendData(keyBytes);
@ -162,13 +152,9 @@ public static class CryptoHelper
public static string Sha256Hash(byte[] bytes) => BitConverter.ToString(sha256.ComputeHash(bytes)).Replace("-", "").ToLower(); public static string Sha256Hash(byte[] bytes) => BitConverter.ToString(sha256.ComputeHash(bytes)).Replace("-", "").ToLower();
public static string Sha1Hash(string str) => Sha1Hash(Encoding.UTF8.GetBytes(str));
public static string Sha1Hash(byte[] bytes) => BitConverter.ToString(SHA1.HashData(bytes)).Replace("-", ""); public static string Sha1Hash(byte[] bytes) => BitConverter.ToString(SHA1.HashData(bytes)).Replace("-", "");
public static string BCryptHash(string str) => BCrypt.Net.BCrypt.HashPassword(str); public static string BCryptHash(string str) => BCrypt.Net.BCrypt.HashPassword(str);
public static string BCryptHash(byte[] bytes) => BCrypt.Net.BCrypt.HashPassword(Encoding.UTF8.GetString(bytes));
#endregion #endregion
} }

View file

@ -1,6 +1,7 @@
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using LBPUnion.ProjectLighthouse.Extensions;
using LBPUnion.ProjectLighthouse.Logging; using LBPUnion.ProjectLighthouse.Logging;
using LBPUnion.ProjectLighthouse.Types.Logging; using LBPUnion.ProjectLighthouse.Types.Logging;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
@ -23,7 +24,7 @@ public class RequestLogMiddleware : Middleware
ctx.Request.EnableBuffering(); // Allows us to reset the position of Request.Body for later logging ctx.Request.EnableBuffering(); // Allows us to reset the position of Request.Body for later logging
// Log all headers. // Log all headers.
// foreach (KeyValuePair<string, StringValues> header in context.Request.Headers) Logger.Log($"{header.Key}: {header.Value}"); // foreach (KeyValuePair<string, StringValues> header in ctx.Request.Headers) Logger.Debug($"{header.Key}: {header.Value}", LogArea.HTTP);
await this.next(ctx); // Handle the request so we can get the status code from it await this.next(ctx); // Handle the request so we can get the status code from it
@ -39,8 +40,8 @@ public class RequestLogMiddleware : Middleware
// Log post body // Log post body
if (ctx.Request.Method == "POST") if (ctx.Request.Method == "POST")
{ {
ctx.Request.Body.Position = 0; string body = Encoding.ASCII.GetString(await ctx.Request.BodyReader.ReadAllAsync());
Logger.Debug(await new StreamReader(ctx.Request.Body).ReadToEndAsync(), LogArea.HTTP); Logger.Debug(body, LogArea.HTTP);
} }
#endif #endif
} }