diff --git a/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs b/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs
index 66f60fd5..c668f434 100644
--- a/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs
+++ b/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs
@@ -49,12 +49,16 @@ public class RateLimitMiddleware : Middleware
RemoveExpiredEntries();
- if (GetNumRequestsForPath(address, path) >= GetMaxNumRequests(options))
+ if (GetNumRequestsForPath(address, path, options) >= GetMaxNumRequests(options))
{
Logger.Info($"Request limit reached for {address} ({ctx.Request.Path})", LogArea.RateLimit);
long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis;
ctx.Response.Headers.TryAdd("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f));
ctx.Response.StatusCode = 429;
+ await ctx.Response.WriteAsync(
+ "
Rate limit reached" +
+ "You have reached the rate limit
" +
+ $"Try again in {ctx.Response.Headers.RetryAfter} seconds");
return;
}
@@ -97,7 +101,7 @@ public class RateLimitMiddleware : Middleware
private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options)
{
- recentRequests.GetOrAdd(address, new List()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis));
+ recentRequests.GetOrAdd(address, new List()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options));
}
private static void RemoveExpiredEntries()
@@ -117,22 +121,31 @@ public class RateLimitMiddleware : Middleware
private static string RemoveTrailingSlash(string s) => s.TrimEnd('/').TrimEnd('\\');
- private static int GetNumRequestsForPath(IPAddress address, PathString path)
+ private static int GetNumRequestsForPath(IPAddress address, PathString path, RateLimitOptions? options)
{
- return !recentRequests.ContainsKey(address) ? 0 : recentRequests[address].Count(r => (r?.Path ?? "") == path);
+ if (!recentRequests.ContainsKey(address)) return 0;
+ int? optionsHash = options?.GetHashCode();
+ // If there are no custom options then count requests based on exact url matches, otherwise use regex matching
+ return options switch
+ {
+ null => recentRequests[address].Count(r => (r?.Path ?? "") == path),
+ _ => recentRequests[address].Count(r => r?.OptionsHash == optionsHash),
+ };
}
private class LighthouseRequest
{
public PathString Path { get; private init; } = "";
+ public int? OptionsHash { get; private init; }
public long Expiration { get; private init; }
- public static LighthouseRequest Create(PathString path, long expiration)
+ public static LighthouseRequest Create(PathString path, long expiration, RateLimitOptions? options = null)
{
LighthouseRequest request = new()
{
Path = path,
Expiration = expiration,
+ OptionsHash = options?.GetHashCode(),
};
return request;
}