From fd210b3125342bc2a384ac7bf46807bb6e33c158 Mon Sep 17 00:00:00 2001 From: Slendy Date: Thu, 29 Feb 2024 15:27:37 -0600 Subject: [PATCH] Rewrite DigestMiddleware to use an opt-in instead of opt-out for endpoints --- .../Middlewares/DigestMiddleware.cs | 201 ++++++++---------- .../Types/UseDigestAttribute.cs | 11 + .../Unit/Middlewares/DigestMiddlewareTests.cs | 119 +++++------ 3 files changed, 159 insertions(+), 172 deletions(-) create mode 100644 ProjectLighthouse.Servers.GameServer/Types/UseDigestAttribute.cs diff --git a/ProjectLighthouse.Servers.GameServer/Middlewares/DigestMiddleware.cs b/ProjectLighthouse.Servers.GameServer/Middlewares/DigestMiddleware.cs index ecb0824c..b4385e72 100644 --- a/ProjectLighthouse.Servers.GameServer/Middlewares/DigestMiddleware.cs +++ b/ProjectLighthouse.Servers.GameServer/Middlewares/DigestMiddleware.cs @@ -2,6 +2,7 @@ using LBPUnion.ProjectLighthouse.Configuration; using LBPUnion.ProjectLighthouse.Extensions; using LBPUnion.ProjectLighthouse.Helpers; using LBPUnion.ProjectLighthouse.Middlewares; +using LBPUnion.ProjectLighthouse.Servers.GameServer.Types; using Microsoft.Extensions.Primitives; using Org.BouncyCastle.Utilities.Zlib; @@ -16,101 +17,17 @@ public class DigestMiddleware : Middleware this.computeDigests = computeDigests; } - #if !DEBUG - private static readonly HashSet exemptPathList = new() + private readonly List digestKeys; + + public DigestMiddleware(RequestDelegate next, List digestKeys) : base(next) { - "/login", - "/eula", - "/announce", - "/status", - "/farc_hashes", - "/t_conf", - "/network_settings.nws", - "/ChallengeConfig.xml", - }; - #endif + this.digestKeys = digestKeys; + } - public override async Task InvokeAsync(HttpContext context) + private static async Task HandleResponseCompression(HttpContext context, MemoryStream responseBuffer) { - // Client digest check. - if (!context.Request.Cookies.TryGetValue("MM_AUTH", out string? authCookie)) authCookie = string.Empty; - string digestPath = context.Request.Path; - #if !DEBUG - const string url = "/LITTLEBIGPLANETPS3_XML"; - string strippedPath = digestPath.Contains(url) ? digestPath[url.Length..] : ""; - #endif - byte[] bodyBytes = await context.Request.BodyReader.ReadAllAsync(); - - bool usedAlternateDigestKey = false; - - if (this.computeDigests && digestPath.StartsWith("/LITTLEBIGPLANETPS3_XML")) - { - // The game sets X-Digest-B on a resource upload instead of X-Digest-A - string digestHeaderKey = "X-Digest-A"; - bool excludeBodyFromDigest = false; - if (digestPath.Contains("/upload/")) - { - digestHeaderKey = "X-Digest-B"; - excludeBodyFromDigest = true; - } - - string clientRequestDigest = CryptoHelper.ComputeDigest(digestPath, - authCookie, - bodyBytes, - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey, - excludeBodyFromDigest); - - // Check the digest we've just calculated against the digest header if the game set the header. They should match. - if (context.Request.Headers.TryGetValue(digestHeaderKey, out StringValues sentDigest)) - { - if (clientRequestDigest != sentDigest) - { - // If we got here, the normal ServerDigestKey failed to validate. Lets try again with the alternate digest key. - usedAlternateDigestKey = true; - - clientRequestDigest = CryptoHelper.ComputeDigest(digestPath, - authCookie, - bodyBytes, - ServerConfiguration.Instance.DigestKey.AlternateDigestKey, - excludeBodyFromDigest); - if (clientRequestDigest != sentDigest) - { - #if DEBUG - Console.WriteLine("Digest failed"); - Console.WriteLine("digestKey: " + ServerConfiguration.Instance.DigestKey.PrimaryDigestKey); - Console.WriteLine("altDigestKey: " + ServerConfiguration.Instance.DigestKey.AlternateDigestKey); - Console.WriteLine("computed digest: " + clientRequestDigest); - #endif - // We still failed to validate. Abort the request. - context.Response.StatusCode = 403; - return; - } - } - } - - #if !DEBUG - // The game doesn't start sending digests until after the announcement so if it's not one of those requests - // and it doesn't include a digest we need to reject the request - else if (!exemptPathList.Contains(strippedPath)) - { - context.Response.StatusCode = 403; - return; - } - #endif - - context.Response.Headers.Append("X-Digest-B", clientRequestDigest); - context.Request.Body.Position = 0; - } - - // This does the same as above, but for the response stream. - await using MemoryStream responseBuffer = new(); - Stream oldResponseStream = context.Response.Body; - context.Response.Body = responseBuffer; - - await this.next(context); // Handle the request so we can get the server digest hash - responseBuffer.Position = 0; - - if (responseBuffer.Length > 1000 && + const int minCompressionLen = 1000; + if (responseBuffer.Length > minCompressionLen && context.Request.Headers.AcceptEncoding.Contains("deflate") && (context.Response.ContentType ?? string.Empty).Contains("text/xml")) { @@ -130,30 +47,94 @@ public class DigestMiddleware : Middleware } else { - string headerName = !context.Response.Headers.ContentLength.HasValue - ? "Content-Length" - : "X-Original-Content-Length"; + string headerName = !context.Response.Headers.ContentLength.HasValue ? "Content-Length" : "X-Original-Content-Length"; context.Response.Headers.Append(headerName, responseBuffer.Length.ToString()); } + } - // Compute the server digest hash. - if (this.computeDigests) + public override async Task InvokeAsync(HttpContext context) + { + UseDigestAttribute? digestAttribute = context.GetEndpoint()?.Metadata.OfType().FirstOrDefault(); + if (digestAttribute == null) { - responseBuffer.Position = 0; - - string digestKey = usedAlternateDigestKey - ? ServerConfiguration.Instance.DigestKey.AlternateDigestKey - : ServerConfiguration.Instance.DigestKey.PrimaryDigestKey; - - // Compute the digest for the response. - string serverDigest = - CryptoHelper.ComputeDigest(context.Request.Path, authCookie, responseBuffer.ToArray(), digestKey); - context.Response.Headers.Append("X-Digest-A", serverDigest); + await this.next(context); + return; } - // Copy the buffered response to the actual response stream. + if (!context.Request.Cookies.TryGetValue("MM_AUTH", out string? authCookie)) + { + context.Response.StatusCode = 403; + return; + } + + string digestPath = context.Request.Path; + + byte[] bodyBytes = await context.Request.BodyReader.ReadAllAsync(); + + if (!context.Request.Headers.TryGetValue(digestAttribute.DigestHeaderName, out StringValues digestHeaders) || + digestHeaders.Count != 1 && digestAttribute.EnforceDigest) + { + context.Response.StatusCode = 403; + return; + } + + string? clientDigest = digestHeaders[0]; + + string? matchingDigestKey = null; + string? calculatedRequestDigest = null; + + foreach (string digestKey in this.digestKeys) + { + string calculatedDigest = CryptoHelper.ComputeDigest(digestPath, + authCookie, + bodyBytes, + digestKey, + digestAttribute.ExcludeBodyFromDigest); + if (calculatedDigest != clientDigest) continue; + + matchingDigestKey = digestKey; + calculatedRequestDigest = calculatedDigest; + } + + matchingDigestKey ??= this.digestKeys.First(); + + switch (matchingDigestKey) + { + case null when digestAttribute.EnforceDigest: + context.Response.StatusCode = 403; + return; + case null: + calculatedRequestDigest = CryptoHelper.ComputeDigest(digestPath, + authCookie, + bodyBytes, + matchingDigestKey, + digestAttribute.ExcludeBodyFromDigest); + break; + } + + context.Response.Headers.Append("X-Digest-B", calculatedRequestDigest); + // context.Request.Body.Position = 0; + + // Let endpoint generate response so we can calculate the digest for it + Stream originalBody = context.Response.Body; + await using MemoryStream responseBuffer = new(); + context.Response.Body = responseBuffer; + + await this.next(context); + + await HandleResponseCompression(context, responseBuffer); + + string responseDigest = CryptoHelper.ComputeDigest(digestPath, + authCookie, + responseBuffer.ToArray(), + matchingDigestKey, + digestAttribute.ExcludeBodyFromDigest); + + context.Response.Headers.Append("X-Digest-A", responseDigest); + responseBuffer.Position = 0; - await responseBuffer.CopyToAsync(oldResponseStream); - context.Response.Body = oldResponseStream; + await responseBuffer.CopyToAsync(originalBody); + context.Response.Body = originalBody; } + } \ No newline at end of file diff --git a/ProjectLighthouse.Servers.GameServer/Types/UseDigestAttribute.cs b/ProjectLighthouse.Servers.GameServer/Types/UseDigestAttribute.cs new file mode 100644 index 00000000..55e03c1f --- /dev/null +++ b/ProjectLighthouse.Servers.GameServer/Types/UseDigestAttribute.cs @@ -0,0 +1,11 @@ +namespace LBPUnion.ProjectLighthouse.Servers.GameServer.Types; + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true)] +public class UseDigestAttribute : Attribute +{ + public bool EnforceDigest { get; set; } = true; + + public string DigestHeaderName { get; set; } = "X-Digest-A"; + + public bool ExcludeBodyFromDigest { get; set; } = false; +} \ No newline at end of file diff --git a/ProjectLighthouse.Tests.GameApiTests/Unit/Middlewares/DigestMiddlewareTests.cs b/ProjectLighthouse.Tests.GameApiTests/Unit/Middlewares/DigestMiddlewareTests.cs index b6b7354a..95a1205c 100644 --- a/ProjectLighthouse.Tests.GameApiTests/Unit/Middlewares/DigestMiddlewareTests.cs +++ b/ProjectLighthouse.Tests.GameApiTests/Unit/Middlewares/DigestMiddlewareTests.cs @@ -3,8 +3,8 @@ using System.Collections.Generic; using System.IO; using System.Text; using System.Threading.Tasks; -using LBPUnion.ProjectLighthouse.Configuration; using LBPUnion.ProjectLighthouse.Servers.GameServer.Middlewares; +using LBPUnion.ProjectLighthouse.Servers.GameServer.Types; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Primitives; using Xunit; @@ -14,25 +14,44 @@ namespace ProjectLighthouse.Tests.GameApiTests.Unit.Middlewares; [Trait("Category", "Unit")] public class DigestMiddlewareTests { - - [Fact] - public async Task DigestMiddleware_ShouldNotComputeDigests_WhenDigestsDisabled() + //TODO: fix remaining unit tests + private static DefaultHttpContext GetHttpContext + (Stream body, string path, string cookie, Dictionary? extraHeaders = null) { + DefaultHttpContext context = new() { Request = { - Body = new MemoryStream(), - Path = "/LITTLEBIGPLANETPS3_XML/notification", - Headers = { KeyValuePair.Create("Cookie", "MM_AUTH=unittest"), }, + Body = body, + Path = path, + Headers = + { + KeyValuePair.Create("Cookie", cookie), + } }, }; + if (extraHeaders == null) return context; + + foreach ((string key, StringValues value) in extraHeaders) + { + context.Request.Headers.Append(key, value); + } + + return context; + } + + [Fact] + public async Task DigestMiddleware_ShouldNotComputeDigests_WithoutDigestAttribute() + { + DefaultHttpContext context = GetHttpContext(new MemoryStream(), "/LITTLEBIGPLANETPS3_XML/notification", "MM_AUTH=unittest"); + context.SetEndpoint(new Endpoint(null, new EndpointMetadataCollection(), null)); DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync(""); return Task.CompletedTask; - }, false); + }, []); await middleware.InvokeAsync(context); @@ -46,26 +65,15 @@ public class DigestMiddlewareTests [Fact] public async Task DigestMiddleware_ShouldReject_WhenDigestHeaderIsMissing() { - DefaultHttpContext context = new() - { - Request = - { - Body = new MemoryStream(), - Path = "/LITTLEBIGPLANETPS3_XML/notification", - Headers = - { - KeyValuePair.Create("Cookie", "MM_AUTH=unittest"), - }, - }, - }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh"; + DefaultHttpContext context = GetHttpContext(new MemoryStream(), "/LITTLEBIGPLANETPS3_XML/notification", "MM_AUTH=unittest"); + context.SetEndpoint(new Endpoint(null, new EndpointMetadataCollection(new UseDigestAttribute()), null)); DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync(""); return Task.CompletedTask; }, - true); + ["bruh",]); await middleware.InvokeAsync(context); @@ -80,28 +88,23 @@ public class DigestMiddlewareTests [Fact] public async Task DigestMiddleware_ShouldReject_WhenRequestDigestInvalid() { - DefaultHttpContext context = new() - { - Request = + DefaultHttpContext context = GetHttpContext(new MemoryStream(), + "/LITTLEBIGPLANETPS3_XML/notification", + "MM_AUTH=unittest", + new Dictionary { - Body = new MemoryStream(), - Path = "/LITTLEBIGPLANETPS3_XML/notification", - Headers = { - KeyValuePair.Create("Cookie", "MM_AUTH=unittest"), - KeyValuePair.Create("X-Digest-A", "invalid_digest"), + "X-Digest-A", "invalid_digest" }, - }, - }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh"; - ServerConfiguration.Instance.DigestKey.AlternateDigestKey = "test"; + }); + context.SetEndpoint(new Endpoint(null, new EndpointMetadataCollection(new UseDigestAttribute()), null)); DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync(""); return Task.CompletedTask; }, - true); + ["bruh",]); await middleware.InvokeAsync(context); @@ -115,28 +118,23 @@ public class DigestMiddlewareTests [Fact] public async Task DigestMiddleware_ShouldUseAlternateDigest_WhenPrimaryDigestInvalid() { - DefaultHttpContext context = new() - { - Request = + DefaultHttpContext context = GetHttpContext(new MemoryStream(), + "/LITTLEBIGPLANETPS3_XML/notification", + "MM_AUTH=unittest", + new Dictionary { - Body = new MemoryStream(), - Path = "/LITTLEBIGPLANETPS3_XML/notification", - Headers = { - KeyValuePair.Create("Cookie", "MM_AUTH=unittest"), - KeyValuePair.Create("X-Digest-A", "df619790a2579a077eae4a6b6864966ff4768723"), + "X-Digest-A", "df619790a2579a077eae4a6b6864966ff4768723" }, - }, - }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "test"; - ServerConfiguration.Instance.DigestKey.AlternateDigestKey = "bruh"; + }); + DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync(""); return Task.CompletedTask; }, - true); + ["test, bruh",]); await middleware.InvokeAsync(context); @@ -166,14 +164,14 @@ public class DigestMiddlewareTests }, }, }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh"; + DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync(""); return Task.CompletedTask; }, - true); + ["bruh",]); await middleware.InvokeAsync(context); @@ -203,14 +201,14 @@ public class DigestMiddlewareTests }, }, }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh"; + DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync(""); return Task.CompletedTask; }, - true); + ["bruh",]); await middleware.InvokeAsync(context); @@ -241,14 +239,14 @@ public class DigestMiddlewareTests }, }, }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh"; + DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync(""); return Task.CompletedTask; }, - true); + ["bruh",]); await middleware.InvokeAsync(context); @@ -279,14 +277,14 @@ public class DigestMiddlewareTests }, }, }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh"; + DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync(""); return Task.CompletedTask; }, - true); + ["bruh",]); await middleware.InvokeAsync(context); @@ -317,14 +315,13 @@ public class DigestMiddlewareTests }, }, }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh"; DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync("digest test"); return Task.CompletedTask; }, - true); + ["bruh",]); await middleware.InvokeAsync(context); @@ -355,14 +352,13 @@ public class DigestMiddlewareTests }, }, }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh"; DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; httpContext.Response.WriteAsync(""); return Task.CompletedTask; }, - true); + ["bruh",]); await middleware.InvokeAsync(context); @@ -398,7 +394,6 @@ public class DigestMiddlewareTests }, }, }; - ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh"; DigestMiddleware middleware = new(httpContext => { httpContext.Response.StatusCode = 200; @@ -406,7 +401,7 @@ public class DigestMiddlewareTests httpContext.Response.Headers.ContentType = "text/xml"; return Task.CompletedTask; }, - true); + ["bruh",]); await middleware.InvokeAsync(context);