Make rate limiter match zone rather than directly comparing url (#656)

Make rate limiter match regex rather than directly comparing
This commit is contained in:
Josh 2023-02-05 11:53:58 -06:00 committed by GitHub
commit 886771ec3c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -49,12 +49,16 @@ public class RateLimitMiddleware : Middleware
RemoveExpiredEntries(); RemoveExpiredEntries();
if (GetNumRequestsForPath(address, path) >= GetMaxNumRequests(options)) if (GetNumRequestsForPath(address, path, options) >= GetMaxNumRequests(options))
{ {
Logger.Info($"Request limit reached for {address} ({ctx.Request.Path})", LogArea.RateLimit); Logger.Info($"Request limit reached for {address} ({ctx.Request.Path})", LogArea.RateLimit);
long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis; long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis;
ctx.Response.Headers.TryAdd("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f)); ctx.Response.Headers.TryAdd("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f));
ctx.Response.StatusCode = 429; ctx.Response.StatusCode = 429;
await ctx.Response.WriteAsync(
"<html><head><title>Rate limit reached</title><style>html{font-family: Tahoma, Verdana, Arial, sans-serif;}</style></head>" +
"<h1>You have reached the rate limit</h1>" +
$"<p>Try again in {ctx.Response.Headers.RetryAfter} seconds</html>");
return; return;
} }
@ -97,7 +101,7 @@ public class RateLimitMiddleware : Middleware
private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options) private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options)
{ {
recentRequests.GetOrAdd(address, new List<LighthouseRequest?>()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis)); recentRequests.GetOrAdd(address, new List<LighthouseRequest?>()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options));
} }
private static void RemoveExpiredEntries() private static void RemoveExpiredEntries()
@ -117,22 +121,31 @@ public class RateLimitMiddleware : Middleware
private static string RemoveTrailingSlash(string s) => s.TrimEnd('/').TrimEnd('\\'); private static string RemoveTrailingSlash(string s) => s.TrimEnd('/').TrimEnd('\\');
private static int GetNumRequestsForPath(IPAddress address, PathString path) private static int GetNumRequestsForPath(IPAddress address, PathString path, RateLimitOptions? options)
{ {
return !recentRequests.ContainsKey(address) ? 0 : recentRequests[address].Count(r => (r?.Path ?? "") == path); if (!recentRequests.ContainsKey(address)) return 0;
int? optionsHash = options?.GetHashCode();
// If there are no custom options then count requests based on exact url matches, otherwise use regex matching
return options switch
{
null => recentRequests[address].Count(r => (r?.Path ?? "") == path),
_ => recentRequests[address].Count(r => r?.OptionsHash == optionsHash),
};
} }
private class LighthouseRequest private class LighthouseRequest
{ {
public PathString Path { get; private init; } = ""; public PathString Path { get; private init; } = "";
public int? OptionsHash { get; private init; }
public long Expiration { get; private init; } public long Expiration { get; private init; }
public static LighthouseRequest Create(PathString path, long expiration) public static LighthouseRequest Create(PathString path, long expiration, RateLimitOptions? options = null)
{ {
LighthouseRequest request = new() LighthouseRequest request = new()
{ {
Path = path, Path = path,
Expiration = expiration, Expiration = expiration,
OptionsHash = options?.GetHashCode(),
}; };
return request; return request;
} }