diff --git a/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs b/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs index c4a6132b..46b4d1b9 100644 --- a/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs +++ b/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs @@ -17,7 +17,7 @@ public class RateLimitMiddleware : MiddlewareDBContext { // (userId, requestData) - private static readonly ConcurrentDictionary> recentRequests = new(); + private static readonly ConcurrentDictionary> recentRequests = new(); public RateLimitMiddleware(RequestDelegate next) : base(next) { } @@ -52,7 +52,8 @@ public class RateLimitMiddleware : MiddlewareDBContext if (GetNumRequestsForPath(address, path) >= GetMaxNumRequests(options)) { Logger.Info($"Request limit reached for {address.ToString()} ({ctx.Request.Path})", LogArea.RateLimit); - ctx.Response.Headers.Add("Retry-After", "" + Math.Ceiling((recentRequests[address][0].Expiration - TimeHelper.TimestampMillis) / 1000f)); + long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis; + ctx.Response.Headers.Add("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f)); ctx.Response.StatusCode = 429; return; } @@ -96,7 +97,7 @@ public class RateLimitMiddleware : MiddlewareDBContext private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options) { - recentRequests.GetOrAdd(address, new List()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis)); + recentRequests.GetOrAdd(address, new List()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis)); } private static void RemoveExpiredEntries() @@ -104,13 +105,13 @@ public class RateLimitMiddleware : MiddlewareDBContext for (int i = recentRequests.Count - 1; i >= 0; i--) { IPAddress address = recentRequests.ElementAt(i).Key; - bool exists = recentRequests.TryGetValue(address, out List? requests); - if (!exists || recentRequests[address].Count == 0) + bool exists = recentRequests.TryGetValue(address, out List? requests); + if (!exists || requests == null || recentRequests[address].Count == 0) { recentRequests.TryRemove(address, out _); continue; } - requests?.RemoveAll(r => TimeHelper.TimestampMillis > r.Expiration); + requests.RemoveAll(r => TimeHelper.TimestampMillis >= (r?.Expiration ?? TimeHelper.TimestampMillis)); } } @@ -118,10 +119,7 @@ public class RateLimitMiddleware : MiddlewareDBContext private static int GetNumRequestsForPath(IPAddress address, PathString path) { - if (!recentRequests.ContainsKey(address)) return 0; - - List requests = recentRequests[address]; - return requests.Count(r => r.Path == path); + return !recentRequests.ContainsKey(address) ? 0 : recentRequests[address].Count(r => (r?.Path ?? "") == path); } private class LighthouseRequest