Make the rate limiter more thread safe

This commit is contained in:
Slendy 2023-02-18 04:33:09 -06:00
commit cf1adbe640
No known key found for this signature in database
GPG key ID: 7288D68361B91428
2 changed files with 16 additions and 9 deletions

View file

@ -80,8 +80,8 @@ public class WebsiteStartup
app.UseForwardedHeaders();
app.UseMiddleware<RequestLogMiddleware>();
app.UseMiddleware<UserRequiredRedirectMiddleware>();
app.UseMiddleware<RateLimitMiddleware>();
app.UseMiddleware<UserRequiredRedirectMiddleware>();
app.UseRouting();

View file

@ -19,7 +19,7 @@ public class RateLimitMiddleware : Middleware
{
// (ipAddress, requestData)
private static readonly ConcurrentDictionary<IPAddress, List<LighthouseRequest?>> recentRequests = new();
private static readonly ConcurrentDictionary<IPAddress, ConcurrentQueue<LighthouseRequest?>> recentRequests = new();
public RateLimitMiddleware(RequestDelegate next) : base(next)
{ }
@ -54,7 +54,8 @@ public class RateLimitMiddleware : Middleware
if (GetNumRequestsForPath(address, path, options) >= GetMaxNumRequests(options))
{
Logger.Info($"Request limit reached for {address} ({ctx.Request.Path})", LogArea.RateLimit);
long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis;
recentRequests[address].TryPeek(out LighthouseRequest? request);
long nextExpiration = request?.Expiration ?? TimeHelper.TimestampMillis;
ctx.Response.Headers.TryAdd("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f));
ctx.Response.StatusCode = 429;
await ctx.Response.WriteAsync(
@ -103,21 +104,27 @@ public class RateLimitMiddleware : Middleware
private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options)
{
recentRequests.GetOrAdd(address, new List<LighthouseRequest?>()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options));
LighthouseRequest request = LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options);
recentRequests.GetOrAdd(address, new ConcurrentQueue<LighthouseRequest?>()).Enqueue(request);
}
private static void RemoveExpiredEntries()
{
for (int i = recentRequests.Count - 1; i >= 0; i--)
foreach((IPAddress address, ConcurrentQueue<LighthouseRequest?> list) in recentRequests)
{
IPAddress address = recentRequests.ElementAt(i).Key;
bool exists = recentRequests.TryGetValue(address, out List<LighthouseRequest?>? requests);
if (!exists || requests == null || recentRequests[address].Count == 0)
if (list.IsEmpty)
{
recentRequests.TryRemove(address, out _);
continue;
}
requests.RemoveAll(r => TimeHelper.TimestampMillis >= (r?.Expiration ?? TimeHelper.TimestampMillis));
while (list.TryPeek(out LighthouseRequest? request))
{
if (TimeHelper.TimestampMillis < (request?.Expiration ?? TimeHelper.TimestampMillis))
break;
list.TryDequeue(out _);
}
}
}