Add better null handling to RateLimitMiddleware

This commit is contained in:
Slendy 2022-11-02 15:45:13 -05:00
commit f6f0f04548
No known key found for this signature in database
GPG key ID: 7288D68361B91428

View file

@ -17,7 +17,7 @@ public class RateLimitMiddleware : MiddlewareDBContext
{
// (userId, requestData)
private static readonly ConcurrentDictionary<IPAddress, List<LighthouseRequest>> recentRequests = new();
private static readonly ConcurrentDictionary<IPAddress, List<LighthouseRequest?>> recentRequests = new();
public RateLimitMiddleware(RequestDelegate next) : base(next)
{ }
@ -52,7 +52,8 @@ public class RateLimitMiddleware : MiddlewareDBContext
if (GetNumRequestsForPath(address, path) >= GetMaxNumRequests(options))
{
Logger.Info($"Request limit reached for {address.ToString()} ({ctx.Request.Path})", LogArea.RateLimit);
ctx.Response.Headers.Add("Retry-After", "" + Math.Ceiling((recentRequests[address][0].Expiration - TimeHelper.TimestampMillis) / 1000f));
long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis;
ctx.Response.Headers.Add("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f));
ctx.Response.StatusCode = 429;
return;
}
@ -96,7 +97,7 @@ public class RateLimitMiddleware : MiddlewareDBContext
private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options)
{
recentRequests.GetOrAdd(address, new List<LighthouseRequest>()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis));
recentRequests.GetOrAdd(address, new List<LighthouseRequest?>()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis));
}
private static void RemoveExpiredEntries()
@ -104,13 +105,13 @@ public class RateLimitMiddleware : MiddlewareDBContext
for (int i = recentRequests.Count - 1; i >= 0; i--)
{
IPAddress address = recentRequests.ElementAt(i).Key;
bool exists = recentRequests.TryGetValue(address, out List<LighthouseRequest>? requests);
if (!exists || recentRequests[address].Count == 0)
bool exists = recentRequests.TryGetValue(address, out List<LighthouseRequest?>? requests);
if (!exists || requests == null || recentRequests[address].Count == 0)
{
recentRequests.TryRemove(address, out _);
continue;
}
requests?.RemoveAll(r => TimeHelper.TimestampMillis > r.Expiration);
requests.RemoveAll(r => TimeHelper.TimestampMillis >= (r?.Expiration ?? TimeHelper.TimestampMillis));
}
}
@ -118,10 +119,7 @@ public class RateLimitMiddleware : MiddlewareDBContext
private static int GetNumRequestsForPath(IPAddress address, PathString path)
{
if (!recentRequests.ContainsKey(address)) return 0;
List<LighthouseRequest> requests = recentRequests[address];
return requests.Count(r => r.Path == path);
return !recentRequests.ContainsKey(address) ? 0 : recentRequests[address].Count(r => (r?.Path ?? "") == path);
}
private class LighthouseRequest