diff --git a/ProjectLighthouse.Servers.Website/Startup/WebsiteStartup.cs b/ProjectLighthouse.Servers.Website/Startup/WebsiteStartup.cs index db629433..b17da16d 100644 --- a/ProjectLighthouse.Servers.Website/Startup/WebsiteStartup.cs +++ b/ProjectLighthouse.Servers.Website/Startup/WebsiteStartup.cs @@ -80,8 +80,8 @@ public class WebsiteStartup app.UseForwardedHeaders(); app.UseMiddleware(); - app.UseMiddleware(); app.UseMiddleware(); + app.UseMiddleware(); app.UseRouting(); diff --git a/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs b/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs index 334207b4..04c1ffb1 100644 --- a/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs +++ b/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs @@ -19,7 +19,7 @@ public class RateLimitMiddleware : Middleware { // (ipAddress, requestData) - private static readonly ConcurrentDictionary> recentRequests = new(); + private static readonly ConcurrentDictionary> recentRequests = new(); public RateLimitMiddleware(RequestDelegate next) : base(next) { } @@ -54,7 +54,8 @@ public class RateLimitMiddleware : Middleware if (GetNumRequestsForPath(address, path, options) >= GetMaxNumRequests(options)) { Logger.Info($"Request limit reached for {address} ({ctx.Request.Path})", LogArea.RateLimit); - long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis; + recentRequests[address].TryPeek(out LighthouseRequest? request); + long nextExpiration = request?.Expiration ?? TimeHelper.TimestampMillis; ctx.Response.Headers.TryAdd("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f)); ctx.Response.StatusCode = 429; await ctx.Response.WriteAsync( @@ -103,21 +104,27 @@ public class RateLimitMiddleware : Middleware private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options) { - recentRequests.GetOrAdd(address, new List()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options)); + LighthouseRequest request = LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options); + recentRequests.GetOrAdd(address, new ConcurrentQueue()).Enqueue(request); } private static void RemoveExpiredEntries() { - for (int i = recentRequests.Count - 1; i >= 0; i--) + foreach((IPAddress address, ConcurrentQueue list) in recentRequests) { - IPAddress address = recentRequests.ElementAt(i).Key; - bool exists = recentRequests.TryGetValue(address, out List? requests); - if (!exists || requests == null || recentRequests[address].Count == 0) + if (list.IsEmpty) { recentRequests.TryRemove(address, out _); continue; } - requests.RemoveAll(r => TimeHelper.TimestampMillis >= (r?.Expiration ?? TimeHelper.TimestampMillis)); + + while (list.TryPeek(out LighthouseRequest? request)) + { + if (TimeHelper.TimestampMillis < (request?.Expiration ?? TimeHelper.TimestampMillis)) + break; + + list.TryDequeue(out _); + } } }