Add better null handling to RateLimitMiddleware

This commit is contained in:
Slendy 2022-11-02 15:45:13 -05:00
commit f6f0f04548
No known key found for this signature in database
GPG key ID: 7288D68361B91428

View file

@ -17,7 +17,7 @@ public class RateLimitMiddleware : MiddlewareDBContext
{ {
// (userId, requestData) // (userId, requestData)
private static readonly ConcurrentDictionary<IPAddress, List<LighthouseRequest>> recentRequests = new(); private static readonly ConcurrentDictionary<IPAddress, List<LighthouseRequest?>> recentRequests = new();
public RateLimitMiddleware(RequestDelegate next) : base(next) public RateLimitMiddleware(RequestDelegate next) : base(next)
{ } { }
@ -52,7 +52,8 @@ public class RateLimitMiddleware : MiddlewareDBContext
if (GetNumRequestsForPath(address, path) >= GetMaxNumRequests(options)) if (GetNumRequestsForPath(address, path) >= GetMaxNumRequests(options))
{ {
Logger.Info($"Request limit reached for {address.ToString()} ({ctx.Request.Path})", LogArea.RateLimit); Logger.Info($"Request limit reached for {address.ToString()} ({ctx.Request.Path})", LogArea.RateLimit);
ctx.Response.Headers.Add("Retry-After", "" + Math.Ceiling((recentRequests[address][0].Expiration - TimeHelper.TimestampMillis) / 1000f)); long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis;
ctx.Response.Headers.Add("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f));
ctx.Response.StatusCode = 429; ctx.Response.StatusCode = 429;
return; return;
} }
@ -96,7 +97,7 @@ public class RateLimitMiddleware : MiddlewareDBContext
private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options) private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options)
{ {
recentRequests.GetOrAdd(address, new List<LighthouseRequest>()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis)); recentRequests.GetOrAdd(address, new List<LighthouseRequest?>()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis));
} }
private static void RemoveExpiredEntries() private static void RemoveExpiredEntries()
@ -104,13 +105,13 @@ public class RateLimitMiddleware : MiddlewareDBContext
for (int i = recentRequests.Count - 1; i >= 0; i--) for (int i = recentRequests.Count - 1; i >= 0; i--)
{ {
IPAddress address = recentRequests.ElementAt(i).Key; IPAddress address = recentRequests.ElementAt(i).Key;
bool exists = recentRequests.TryGetValue(address, out List<LighthouseRequest>? requests); bool exists = recentRequests.TryGetValue(address, out List<LighthouseRequest?>? requests);
if (!exists || recentRequests[address].Count == 0) if (!exists || requests == null || recentRequests[address].Count == 0)
{ {
recentRequests.TryRemove(address, out _); recentRequests.TryRemove(address, out _);
continue; continue;
} }
requests?.RemoveAll(r => TimeHelper.TimestampMillis > r.Expiration); requests.RemoveAll(r => TimeHelper.TimestampMillis >= (r?.Expiration ?? TimeHelper.TimestampMillis));
} }
} }
@ -118,10 +119,7 @@ public class RateLimitMiddleware : MiddlewareDBContext
private static int GetNumRequestsForPath(IPAddress address, PathString path) private static int GetNumRequestsForPath(IPAddress address, PathString path)
{ {
if (!recentRequests.ContainsKey(address)) return 0; return !recentRequests.ContainsKey(address) ? 0 : recentRequests[address].Count(r => (r?.Path ?? "") == path);
List<LighthouseRequest> requests = recentRequests[address];
return requests.Count(r => r.Path == path);
} }
private class LighthouseRequest private class LighthouseRequest