mirror of
https://github.com/LBPUnion/ProjectLighthouse.git
synced 2025-07-26 15:08:39 +00:00
154 lines
No EOL
5.8 KiB
C#
154 lines
No EOL
5.8 KiB
C#
#nullable enable
|
|
using System;
|
|
using System.Collections.Concurrent;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using System.Net;
|
|
using System.Text.RegularExpressions;
|
|
using System.Threading.Tasks;
|
|
using LBPUnion.ProjectLighthouse.Configuration;
|
|
using LBPUnion.ProjectLighthouse.Helpers;
|
|
using LBPUnion.ProjectLighthouse.Logging;
|
|
using Microsoft.AspNetCore.Http;
|
|
|
|
namespace LBPUnion.ProjectLighthouse.Middlewares;
|
|
|
|
public class RateLimitMiddleware : Middleware
|
|
{
|
|
|
|
// (ipAddress, requestData)
|
|
private static readonly ConcurrentDictionary<IPAddress, List<LighthouseRequest?>> recentRequests = new();
|
|
|
|
public RateLimitMiddleware(RequestDelegate next) : base(next)
|
|
{ }
|
|
|
|
public override async Task InvokeAsync(HttpContext ctx)
|
|
{
|
|
// We only want to rate limit POST requests
|
|
if (ctx.Request.Method != "POST")
|
|
{
|
|
await this.next(ctx);
|
|
return;
|
|
}
|
|
IPAddress? address = ctx.Connection.RemoteIpAddress;
|
|
if (address == null)
|
|
{
|
|
await this.next(ctx);
|
|
return;
|
|
}
|
|
|
|
PathString path = RemoveTrailingSlash(ctx.Request.Path.ToString());
|
|
|
|
RateLimitOptions? options = GetRateLimitOverride(path);
|
|
|
|
if (!IsRateLimitEnabled(options))
|
|
{
|
|
await this.next(ctx);
|
|
return;
|
|
}
|
|
|
|
RemoveExpiredEntries();
|
|
|
|
if (GetNumRequestsForPath(address, path, options) >= GetMaxNumRequests(options))
|
|
{
|
|
Logger.Info($"Request limit reached for {address} ({ctx.Request.Path})", LogArea.RateLimit);
|
|
long nextExpiration = recentRequests[address][0]?.Expiration ?? TimeHelper.TimestampMillis;
|
|
ctx.Response.Headers.TryAdd("Retry-After", "" + Math.Ceiling((nextExpiration - TimeHelper.TimestampMillis) / 1000f));
|
|
ctx.Response.StatusCode = 429;
|
|
await ctx.Response.WriteAsync(
|
|
"<html><head><title>Rate limit reached</title><style>html{font-family: Tahoma, Verdana, Arial, sans-serif;}</style></head>" +
|
|
"<h1>You have reached the rate limit</h1>" +
|
|
$"<p>Try again in {ctx.Response.Headers.RetryAfter} seconds</html>");
|
|
return;
|
|
}
|
|
|
|
LogRequest(address, path, options);
|
|
|
|
// Handle request as normal
|
|
await this.next(ctx);
|
|
}
|
|
|
|
private static int GetMaxNumRequests(RateLimitOptions? options) => options?.RequestsPerInterval ?? ServerConfiguration.Instance.RateLimitConfiguration.GlobalOptions.RequestsPerInterval;
|
|
|
|
private static bool IsRateLimitEnabled(RateLimitOptions? options) => options?.Enabled ?? ServerConfiguration.Instance.RateLimitConfiguration.GlobalOptions.Enabled;
|
|
|
|
private static long GetRequestInterval(RateLimitOptions? options) => options?.RequestInterval ?? ServerConfiguration.Instance.RateLimitConfiguration.GlobalOptions.RequestInterval;
|
|
|
|
private static RateLimitOptions? GetRateLimitOverride(PathString path)
|
|
{
|
|
Dictionary<string, RateLimitOptions> overrides = ServerConfiguration.Instance.RateLimitConfiguration.OverrideOptions;
|
|
List<string> matchingOptions = overrides.Keys.Where(s => new Regex("^" + s.Replace("/", @"\/").Replace("*", ".*") + "$").Match(path).Success).ToList();
|
|
if (matchingOptions.Count == 0) return null;
|
|
// return 0 for equal, -1 for a, and 1 for b
|
|
matchingOptions.Sort((a, b) =>
|
|
{
|
|
int aWeight = 100;
|
|
int bWeight = 100;
|
|
if (a.Contains('*')) aWeight -= 20;
|
|
if (b.Contains('*')) bWeight -= 20;
|
|
|
|
aWeight += a.Length;
|
|
bWeight += b.Length;
|
|
|
|
if (aWeight > bWeight) return -1;
|
|
|
|
if (bWeight > aWeight) return 1;
|
|
|
|
return 0;
|
|
});
|
|
return overrides[matchingOptions.First()];
|
|
}
|
|
|
|
private static void LogRequest(IPAddress address, PathString path, RateLimitOptions? options)
|
|
{
|
|
recentRequests.GetOrAdd(address, new List<LighthouseRequest?>()).Add(LighthouseRequest.Create(path, GetRequestInterval(options) * 1000 + TimeHelper.TimestampMillis, options));
|
|
}
|
|
|
|
private static void RemoveExpiredEntries()
|
|
{
|
|
for (int i = recentRequests.Count - 1; i >= 0; i--)
|
|
{
|
|
IPAddress address = recentRequests.ElementAt(i).Key;
|
|
bool exists = recentRequests.TryGetValue(address, out List<LighthouseRequest?>? requests);
|
|
if (!exists || requests == null || recentRequests[address].Count == 0)
|
|
{
|
|
recentRequests.TryRemove(address, out _);
|
|
continue;
|
|
}
|
|
requests.RemoveAll(r => TimeHelper.TimestampMillis >= (r?.Expiration ?? TimeHelper.TimestampMillis));
|
|
}
|
|
}
|
|
|
|
private static string RemoveTrailingSlash(string s) => s.TrimEnd('/').TrimEnd('\\');
|
|
|
|
private static int GetNumRequestsForPath(IPAddress address, PathString path, RateLimitOptions? options)
|
|
{
|
|
if (!recentRequests.ContainsKey(address)) return 0;
|
|
int? optionsHash = options?.GetHashCode();
|
|
// If there are no custom options then count requests based on exact url matches, otherwise use regex matching
|
|
return options switch
|
|
{
|
|
null => recentRequests[address].Count(r => (r?.Path ?? "") == path),
|
|
_ => recentRequests[address].Count(r => r?.OptionsHash == optionsHash),
|
|
};
|
|
}
|
|
|
|
private class LighthouseRequest
|
|
{
|
|
public PathString Path { get; private init; } = "";
|
|
public int? OptionsHash { get; private init; }
|
|
public long Expiration { get; private init; }
|
|
|
|
public static LighthouseRequest Create(PathString path, long expiration, RateLimitOptions? options = null)
|
|
{
|
|
LighthouseRequest request = new()
|
|
{
|
|
Path = path,
|
|
Expiration = expiration,
|
|
OptionsHash = options?.GetHashCode(),
|
|
};
|
|
return request;
|
|
}
|
|
}
|
|
|
|
} |