Rewrite DigestMiddleware to use an opt-in instead of opt-out for endpoints

This commit is contained in:
Slendy 2024-02-29 15:27:37 -06:00
commit fd210b3125
No known key found for this signature in database
GPG key ID: 7288D68361B91428
3 changed files with 159 additions and 172 deletions

View file

@ -2,6 +2,7 @@ using LBPUnion.ProjectLighthouse.Configuration;
using LBPUnion.ProjectLighthouse.Extensions;
using LBPUnion.ProjectLighthouse.Helpers;
using LBPUnion.ProjectLighthouse.Middlewares;
using LBPUnion.ProjectLighthouse.Servers.GameServer.Types;
using Microsoft.Extensions.Primitives;
using Org.BouncyCastle.Utilities.Zlib;
@ -16,101 +17,17 @@ public class DigestMiddleware : Middleware
this.computeDigests = computeDigests;
}
#if !DEBUG
private static readonly HashSet<string> exemptPathList = new()
private readonly List<string> digestKeys;
public DigestMiddleware(RequestDelegate next, List<string> digestKeys) : base(next)
{
"/login",
"/eula",
"/announce",
"/status",
"/farc_hashes",
"/t_conf",
"/network_settings.nws",
"/ChallengeConfig.xml",
};
#endif
this.digestKeys = digestKeys;
}
public override async Task InvokeAsync(HttpContext context)
private static async Task HandleResponseCompression(HttpContext context, MemoryStream responseBuffer)
{
// Client digest check.
if (!context.Request.Cookies.TryGetValue("MM_AUTH", out string? authCookie)) authCookie = string.Empty;
string digestPath = context.Request.Path;
#if !DEBUG
const string url = "/LITTLEBIGPLANETPS3_XML";
string strippedPath = digestPath.Contains(url) ? digestPath[url.Length..] : "";
#endif
byte[] bodyBytes = await context.Request.BodyReader.ReadAllAsync();
bool usedAlternateDigestKey = false;
if (this.computeDigests && digestPath.StartsWith("/LITTLEBIGPLANETPS3_XML"))
{
// The game sets X-Digest-B on a resource upload instead of X-Digest-A
string digestHeaderKey = "X-Digest-A";
bool excludeBodyFromDigest = false;
if (digestPath.Contains("/upload/"))
{
digestHeaderKey = "X-Digest-B";
excludeBodyFromDigest = true;
}
string clientRequestDigest = CryptoHelper.ComputeDigest(digestPath,
authCookie,
bodyBytes,
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey,
excludeBodyFromDigest);
// Check the digest we've just calculated against the digest header if the game set the header. They should match.
if (context.Request.Headers.TryGetValue(digestHeaderKey, out StringValues sentDigest))
{
if (clientRequestDigest != sentDigest)
{
// If we got here, the normal ServerDigestKey failed to validate. Lets try again with the alternate digest key.
usedAlternateDigestKey = true;
clientRequestDigest = CryptoHelper.ComputeDigest(digestPath,
authCookie,
bodyBytes,
ServerConfiguration.Instance.DigestKey.AlternateDigestKey,
excludeBodyFromDigest);
if (clientRequestDigest != sentDigest)
{
#if DEBUG
Console.WriteLine("Digest failed");
Console.WriteLine("digestKey: " + ServerConfiguration.Instance.DigestKey.PrimaryDigestKey);
Console.WriteLine("altDigestKey: " + ServerConfiguration.Instance.DigestKey.AlternateDigestKey);
Console.WriteLine("computed digest: " + clientRequestDigest);
#endif
// We still failed to validate. Abort the request.
context.Response.StatusCode = 403;
return;
}
}
}
#if !DEBUG
// The game doesn't start sending digests until after the announcement so if it's not one of those requests
// and it doesn't include a digest we need to reject the request
else if (!exemptPathList.Contains(strippedPath))
{
context.Response.StatusCode = 403;
return;
}
#endif
context.Response.Headers.Append("X-Digest-B", clientRequestDigest);
context.Request.Body.Position = 0;
}
// This does the same as above, but for the response stream.
await using MemoryStream responseBuffer = new();
Stream oldResponseStream = context.Response.Body;
context.Response.Body = responseBuffer;
await this.next(context); // Handle the request so we can get the server digest hash
responseBuffer.Position = 0;
if (responseBuffer.Length > 1000 &&
const int minCompressionLen = 1000;
if (responseBuffer.Length > minCompressionLen &&
context.Request.Headers.AcceptEncoding.Contains("deflate") &&
(context.Response.ContentType ?? string.Empty).Contains("text/xml"))
{
@ -130,30 +47,94 @@ public class DigestMiddleware : Middleware
}
else
{
string headerName = !context.Response.Headers.ContentLength.HasValue
? "Content-Length"
: "X-Original-Content-Length";
string headerName = !context.Response.Headers.ContentLength.HasValue ? "Content-Length" : "X-Original-Content-Length";
context.Response.Headers.Append(headerName, responseBuffer.Length.ToString());
}
}
// Compute the server digest hash.
if (this.computeDigests)
public override async Task InvokeAsync(HttpContext context)
{
UseDigestAttribute? digestAttribute = context.GetEndpoint()?.Metadata.OfType<UseDigestAttribute>().FirstOrDefault();
if (digestAttribute == null)
{
responseBuffer.Position = 0;
string digestKey = usedAlternateDigestKey
? ServerConfiguration.Instance.DigestKey.AlternateDigestKey
: ServerConfiguration.Instance.DigestKey.PrimaryDigestKey;
// Compute the digest for the response.
string serverDigest =
CryptoHelper.ComputeDigest(context.Request.Path, authCookie, responseBuffer.ToArray(), digestKey);
context.Response.Headers.Append("X-Digest-A", serverDigest);
await this.next(context);
return;
}
// Copy the buffered response to the actual response stream.
if (!context.Request.Cookies.TryGetValue("MM_AUTH", out string? authCookie))
{
context.Response.StatusCode = 403;
return;
}
string digestPath = context.Request.Path;
byte[] bodyBytes = await context.Request.BodyReader.ReadAllAsync();
if (!context.Request.Headers.TryGetValue(digestAttribute.DigestHeaderName, out StringValues digestHeaders) ||
digestHeaders.Count != 1 && digestAttribute.EnforceDigest)
{
context.Response.StatusCode = 403;
return;
}
string? clientDigest = digestHeaders[0];
string? matchingDigestKey = null;
string? calculatedRequestDigest = null;
foreach (string digestKey in this.digestKeys)
{
string calculatedDigest = CryptoHelper.ComputeDigest(digestPath,
authCookie,
bodyBytes,
digestKey,
digestAttribute.ExcludeBodyFromDigest);
if (calculatedDigest != clientDigest) continue;
matchingDigestKey = digestKey;
calculatedRequestDigest = calculatedDigest;
}
matchingDigestKey ??= this.digestKeys.First();
switch (matchingDigestKey)
{
case null when digestAttribute.EnforceDigest:
context.Response.StatusCode = 403;
return;
case null:
calculatedRequestDigest = CryptoHelper.ComputeDigest(digestPath,
authCookie,
bodyBytes,
matchingDigestKey,
digestAttribute.ExcludeBodyFromDigest);
break;
}
context.Response.Headers.Append("X-Digest-B", calculatedRequestDigest);
// context.Request.Body.Position = 0;
// Let endpoint generate response so we can calculate the digest for it
Stream originalBody = context.Response.Body;
await using MemoryStream responseBuffer = new();
context.Response.Body = responseBuffer;
await this.next(context);
await HandleResponseCompression(context, responseBuffer);
string responseDigest = CryptoHelper.ComputeDigest(digestPath,
authCookie,
responseBuffer.ToArray(),
matchingDigestKey,
digestAttribute.ExcludeBodyFromDigest);
context.Response.Headers.Append("X-Digest-A", responseDigest);
responseBuffer.Position = 0;
await responseBuffer.CopyToAsync(oldResponseStream);
context.Response.Body = oldResponseStream;
await responseBuffer.CopyToAsync(originalBody);
context.Response.Body = originalBody;
}
}

View file

@ -0,0 +1,11 @@
namespace LBPUnion.ProjectLighthouse.Servers.GameServer.Types;
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true)]
public class UseDigestAttribute : Attribute
{
public bool EnforceDigest { get; set; } = true;
public string DigestHeaderName { get; set; } = "X-Digest-A";
public bool ExcludeBodyFromDigest { get; set; } = false;
}

View file

@ -3,8 +3,8 @@ using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading.Tasks;
using LBPUnion.ProjectLighthouse.Configuration;
using LBPUnion.ProjectLighthouse.Servers.GameServer.Middlewares;
using LBPUnion.ProjectLighthouse.Servers.GameServer.Types;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using Xunit;
@ -14,25 +14,44 @@ namespace ProjectLighthouse.Tests.GameApiTests.Unit.Middlewares;
[Trait("Category", "Unit")]
public class DigestMiddlewareTests
{
[Fact]
public async Task DigestMiddleware_ShouldNotComputeDigests_WhenDigestsDisabled()
//TODO: fix remaining unit tests
private static DefaultHttpContext GetHttpContext
(Stream body, string path, string cookie, Dictionary<string, StringValues>? extraHeaders = null)
{
DefaultHttpContext context = new()
{
Request =
{
Body = new MemoryStream(),
Path = "/LITTLEBIGPLANETPS3_XML/notification",
Headers = { KeyValuePair.Create<string, StringValues>("Cookie", "MM_AUTH=unittest"), },
Body = body,
Path = path,
Headers =
{
KeyValuePair.Create<string, StringValues>("Cookie", cookie),
}
},
};
if (extraHeaders == null) return context;
foreach ((string key, StringValues value) in extraHeaders)
{
context.Request.Headers.Append(key, value);
}
return context;
}
[Fact]
public async Task DigestMiddleware_ShouldNotComputeDigests_WithoutDigestAttribute()
{
DefaultHttpContext context = GetHttpContext(new MemoryStream(), "/LITTLEBIGPLANETPS3_XML/notification", "MM_AUTH=unittest");
context.SetEndpoint(new Endpoint(null, new EndpointMetadataCollection(), null));
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("");
return Task.CompletedTask;
}, false);
}, []);
await middleware.InvokeAsync(context);
@ -46,26 +65,15 @@ public class DigestMiddlewareTests
[Fact]
public async Task DigestMiddleware_ShouldReject_WhenDigestHeaderIsMissing()
{
DefaultHttpContext context = new()
{
Request =
{
Body = new MemoryStream(),
Path = "/LITTLEBIGPLANETPS3_XML/notification",
Headers =
{
KeyValuePair.Create<string, StringValues>("Cookie", "MM_AUTH=unittest"),
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh";
DefaultHttpContext context = GetHttpContext(new MemoryStream(), "/LITTLEBIGPLANETPS3_XML/notification", "MM_AUTH=unittest");
context.SetEndpoint(new Endpoint(null, new EndpointMetadataCollection(new UseDigestAttribute()), null));
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("");
return Task.CompletedTask;
},
true);
["bruh",]);
await middleware.InvokeAsync(context);
@ -80,28 +88,23 @@ public class DigestMiddlewareTests
[Fact]
public async Task DigestMiddleware_ShouldReject_WhenRequestDigestInvalid()
{
DefaultHttpContext context = new()
{
Request =
DefaultHttpContext context = GetHttpContext(new MemoryStream(),
"/LITTLEBIGPLANETPS3_XML/notification",
"MM_AUTH=unittest",
new Dictionary<string, StringValues>
{
Body = new MemoryStream(),
Path = "/LITTLEBIGPLANETPS3_XML/notification",
Headers =
{
KeyValuePair.Create<string, StringValues>("Cookie", "MM_AUTH=unittest"),
KeyValuePair.Create<string, StringValues>("X-Digest-A", "invalid_digest"),
"X-Digest-A", "invalid_digest"
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh";
ServerConfiguration.Instance.DigestKey.AlternateDigestKey = "test";
});
context.SetEndpoint(new Endpoint(null, new EndpointMetadataCollection(new UseDigestAttribute()), null));
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("");
return Task.CompletedTask;
},
true);
["bruh",]);
await middleware.InvokeAsync(context);
@ -115,28 +118,23 @@ public class DigestMiddlewareTests
[Fact]
public async Task DigestMiddleware_ShouldUseAlternateDigest_WhenPrimaryDigestInvalid()
{
DefaultHttpContext context = new()
{
Request =
DefaultHttpContext context = GetHttpContext(new MemoryStream(),
"/LITTLEBIGPLANETPS3_XML/notification",
"MM_AUTH=unittest",
new Dictionary<string, StringValues>
{
Body = new MemoryStream(),
Path = "/LITTLEBIGPLANETPS3_XML/notification",
Headers =
{
KeyValuePair.Create<string, StringValues>("Cookie", "MM_AUTH=unittest"),
KeyValuePair.Create<string, StringValues>("X-Digest-A", "df619790a2579a077eae4a6b6864966ff4768723"),
"X-Digest-A", "df619790a2579a077eae4a6b6864966ff4768723"
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "test";
ServerConfiguration.Instance.DigestKey.AlternateDigestKey = "bruh";
});
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("");
return Task.CompletedTask;
},
true);
["test, bruh",]);
await middleware.InvokeAsync(context);
@ -166,14 +164,14 @@ public class DigestMiddlewareTests
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh";
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("");
return Task.CompletedTask;
},
true);
["bruh",]);
await middleware.InvokeAsync(context);
@ -203,14 +201,14 @@ public class DigestMiddlewareTests
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh";
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("");
return Task.CompletedTask;
},
true);
["bruh",]);
await middleware.InvokeAsync(context);
@ -241,14 +239,14 @@ public class DigestMiddlewareTests
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh";
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("");
return Task.CompletedTask;
},
true);
["bruh",]);
await middleware.InvokeAsync(context);
@ -279,14 +277,14 @@ public class DigestMiddlewareTests
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh";
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("");
return Task.CompletedTask;
},
true);
["bruh",]);
await middleware.InvokeAsync(context);
@ -317,14 +315,13 @@ public class DigestMiddlewareTests
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh";
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("digest test");
return Task.CompletedTask;
},
true);
["bruh",]);
await middleware.InvokeAsync(context);
@ -355,14 +352,13 @@ public class DigestMiddlewareTests
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh";
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
httpContext.Response.WriteAsync("");
return Task.CompletedTask;
},
true);
["bruh",]);
await middleware.InvokeAsync(context);
@ -398,7 +394,6 @@ public class DigestMiddlewareTests
},
},
};
ServerConfiguration.Instance.DigestKey.PrimaryDigestKey = "bruh";
DigestMiddleware middleware = new(httpContext =>
{
httpContext.Response.StatusCode = 200;
@ -406,7 +401,7 @@ public class DigestMiddlewareTests
httpContext.Response.Headers.ContentType = "text/xml";
return Task.CompletedTask;
},
true);
["bruh",]);
await middleware.InvokeAsync(context);