From 886771ec3ca3b9d131c9f21e9ef5ced663dc594a Mon Sep 17 00:00:00 2001 From: Josh Date: Sun, 5 Feb 2023 11:53:58 -0600 Subject: [PATCH] Make rate limiter match zone rather than directly comparing url (#656) Make rate limiter match regex rather than directly comparing --- .../Middlewares/RateLimitMiddleware.cs | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs b/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs index 66f60fd5..c668f434 100644 --- a/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs +++ b/ProjectLighthouse/Middlewares/RateLimitMiddleware.cs @@ -49,12 +49,16 @@ public class RateLimitMiddleware : Middleware RemoveExpiredEntries(); - if (GetNumRequestsForPath(address, path) >= GetMaxNumRequests(options)) + if (GetNumRequestsForPath(address, path, options) >= GetMaxNumRequests(options)) { Logger.Info($"Request limit reached for {address} ({ctx.Request.Path})", LogArea.RateLimit); long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis; ctx.Response.Headers.TryAdd("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f)); ctx.Response.StatusCode = 429; + await ctx.Response.WriteAsync( + "Rate limit reached" + + "

You have reached the rate limit

" + + $"

Try again in {ctx.Response.Headers.RetryAfter} seconds"); return; } @@ -97,7 +101,7 @@ public class RateLimitMiddleware : Middleware private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options) { - recentRequests.GetOrAdd(address, new List()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis)); + recentRequests.GetOrAdd(address, new List()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options)); } private static void RemoveExpiredEntries() @@ -117,22 +121,31 @@ public class RateLimitMiddleware : Middleware private static string RemoveTrailingSlash(string s) => s.TrimEnd('/').TrimEnd('\\'); - private static int GetNumRequestsForPath(IPAddress address, PathString path) + private static int GetNumRequestsForPath(IPAddress address, PathString path, RateLimitOptions? options) { - return !recentRequests.ContainsKey(address) ? 0 : recentRequests[address].Count(r => (r?.Path ?? "") == path); + if (!recentRequests.ContainsKey(address)) return 0; + int? optionsHash = options?.GetHashCode(); + // If there are no custom options then count requests based on exact url matches, otherwise use regex matching + return options switch + { + null => recentRequests[address].Count(r => (r?.Path ?? "") == path), + _ => recentRequests[address].Count(r => r?.OptionsHash == optionsHash), + }; } private class LighthouseRequest { public PathString Path { get; private init; } = ""; + public int? OptionsHash { get; private init; } public long Expiration { get; private init; } - public static LighthouseRequest Create(PathString path, long expiration) + public static LighthouseRequest Create(PathString path, long expiration, RateLimitOptions? options = null) { LighthouseRequest request = new() { Path = path, Expiration = expiration, + OptionsHash = options?.GetHashCode(), }; return request; }