Make the rate limiter more thread safe

This commit is contained in:
Slendy 2023-02-18 04:33:09 -06:00
commit cf1adbe640
No known key found for this signature in database
GPG key ID: 7288D68361B91428
2 changed files with 16 additions and 9 deletions

View file

@ -80,8 +80,8 @@ public class WebsiteStartup
app.UseForwardedHeaders(); app.UseForwardedHeaders();
app.UseMiddleware<RequestLogMiddleware>(); app.UseMiddleware<RequestLogMiddleware>();
app.UseMiddleware<UserRequiredRedirectMiddleware>();
app.UseMiddleware<RateLimitMiddleware>(); app.UseMiddleware<RateLimitMiddleware>();
app.UseMiddleware<UserRequiredRedirectMiddleware>();
app.UseRouting(); app.UseRouting();

View file

@ -19,7 +19,7 @@ public class RateLimitMiddleware : Middleware
{ {
// (ipAddress, requestData) // (ipAddress, requestData)
private static readonly ConcurrentDictionary<IPAddress, List<LighthouseRequest?>> recentRequests = new(); private static readonly ConcurrentDictionary<IPAddress, ConcurrentQueue<LighthouseRequest?>> recentRequests = new();
public RateLimitMiddleware(RequestDelegate next) : base(next) public RateLimitMiddleware(RequestDelegate next) : base(next)
{ } { }
@ -54,7 +54,8 @@ public class RateLimitMiddleware : Middleware
if (GetNumRequestsForPath(address, path, options) >= GetMaxNumRequests(options)) if (GetNumRequestsForPath(address, path, options) >= GetMaxNumRequests(options))
{ {
Logger.Info($"Request limit reached for {address} ({ctx.Request.Path})", LogArea.RateLimit); Logger.Info($"Request limit reached for {address} ({ctx.Request.Path})", LogArea.RateLimit);
long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis; recentRequests[address].TryPeek(out LighthouseRequest? request);
long nextExpiration = request?.Expiration ?? TimeHelper.TimestampMillis;
ctx.Response.Headers.TryAdd("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f)); ctx.Response.Headers.TryAdd("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f));
ctx.Response.StatusCode = 429; ctx.Response.StatusCode = 429;
await ctx.Response.WriteAsync( await ctx.Response.WriteAsync(
@ -103,21 +104,27 @@ public class RateLimitMiddleware : Middleware
private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options) private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options)
{ {
recentRequests.GetOrAdd(address, new List<LighthouseRequest?>()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options)); LighthouseRequest request = LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options);
recentRequests.GetOrAdd(address, new ConcurrentQueue<LighthouseRequest?>()).Enqueue(request);
} }
private static void RemoveExpiredEntries() private static void RemoveExpiredEntries()
{ {
for (int i = recentRequests.Count - 1; i >= 0; i--) foreach((IPAddress address, ConcurrentQueue<LighthouseRequest?> list) in recentRequests)
{ {
IPAddress address = recentRequests.ElementAt(i).Key; if (list.IsEmpty)
bool exists = recentRequests.TryGetValue(address, out List<LighthouseRequest?>? requests);
if (!exists || requests == null || recentRequests[address].Count == 0)
{ {
recentRequests.TryRemove(address, out _); recentRequests.TryRemove(address, out _);
continue; continue;
} }
requests.RemoveAll(r => TimeHelper.TimestampMillis >= (r?.Expiration ?? TimeHelper.TimestampMillis));
while (list.TryPeek(out LighthouseRequest? request))
{
if (TimeHelper.TimestampMillis < (request?.Expiration ?? TimeHelper.TimestampMillis))
break;
list.TryDequeue(out _);
}
} }
} }