-
-
Notifications
You must be signed in to change notification settings - Fork 30
Expand file tree
/
Copy pathRateLimitMiddleware.cs
More file actions
133 lines (114 loc) · 4.88 KB
/
RateLimitMiddleware.cs
File metadata and controls
133 lines (114 loc) · 4.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
using Helldivers.API.Configuration;
using Helldivers.API.Extensions;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;
using Microsoft.Net.Http.Headers;
using System.Net;
using System.Security.Claims;
using System.Text;
using System.Text.Json;
using System.Threading.RateLimiting;
namespace Helldivers.API.Middlewares;
/// <summary>
/// Handles applying rate limit logic to the API's requests.
/// </summary>
/// <remarks>
/// Requests are first checked for the required client identification headers
/// (unless validation is disabled or the request targets <c>/metrics</c>).
/// Anonymous requests are then limited per remote IP address using the configured
/// <c>RateLimit</c>, and authenticated requests per user name using the user's
/// <c>RateLimit</c> claim. Limiters live in the <see cref="IMemoryCache"/> with a
/// sliding expiration of <c>RateLimitWindow</c> seconds and are disposed on eviction.
/// </remarks>
public sealed partial class RateLimitMiddleware(
    ILogger<RateLimitMiddleware> logger,
    IOptions<ApiConfiguration> options,
    IMemoryCache cache
) : IMiddleware
{
    [LoggerMessage(Level = LogLevel.Debug, Message = "Retrieving rate limiter for {Key}")]
    private static partial void LogRateLimitKey(ILogger logger, IPAddress key);

    [LoggerMessage(Level = LogLevel.Information, Message = "Retrieving rate limit for {Name} ({Limit})")]
    private static partial void LogRateLimitForUser(ILogger logger, string name, int limit);

    /// <inheritdoc />
    public async Task InvokeAsync(HttpContext context, RequestDelegate next)
    {
        if (IsValidRequest(context) is false)
        {
            await RejectRequest(context);
            return;
        }

        var limiter = GetRateLimiter(context);
        using var lease = await limiter.AcquireAsync(permitCount: 1, context.RequestAborted);

        if (limiter.GetStatistics() is { } statistics)
        {
            // Report the limit that actually applies to this caller: authenticated users
            // are limited by their claim, not by the globally configured value.
            context.Response.Headers["X-RateLimit-Limit"] = $"{GetRateLimit(context)}";
            context.Response.Headers["X-RateLimit-Remaining"] = $"{statistics.CurrentAvailablePermits}";
        }

        // Lease metadata does not depend on statistics being available, so handle it separately.
        if (lease.TryGetMetadata(MetadataName.RetryAfter, out var retryAfter))
        {
            // TimeSpan.Seconds is only the 0-59 seconds *component*; use the total duration,
            // rounded up so clients never retry before a token is available.
            context.Response.Headers["Retry-After"] = $"{(long)Math.Ceiling(retryAfter.TotalSeconds)}";
        }

        if (lease.IsAcquired is false)
        {
            context.Response.StatusCode = StatusCodes.Status429TooManyRequests;
            return;
        }

        await next(context);
    }

    /// <summary>
    /// Checks if the request is "valid" (contains the correct X-Super-* headers).
    /// </summary>
    /// <returns>
    /// <c>true</c> when client validation is disabled, when the request path starts with
    /// <c>/metrics</c>, or when both the client and contact identifiers are present.
    /// </returns>
    private bool IsValidRequest(HttpContext context)
    {
        if (options.Value.ValidateClients is false || context.Request.Path.StartsWithSegments("/metrics"))
        {
            return true;
        }

        return HasSuperHeaderOrQuery(context, Constants.CLIENT_HEADER_NAME)
               && HasSuperHeaderOrQuery(context, Constants.CONTACT_HEADER_NAME);
    }

    /// <summary>
    /// Resolves the numeric limit that applies to the current request: the user's
    /// <c>RateLimit</c> claim when authenticated, otherwise the configured default.
    /// </summary>
    private int GetRateLimit(HttpContext http)
    {
        if (http.User.Identity?.IsAuthenticated ?? false)
        {
            return http.User.GetIntClaim("RateLimit");
        }

        return options.Value.RateLimit;
    }

    /// <summary>
    /// Fetches (or lazily creates) the rate limiter for the current request, keyed by
    /// user for authenticated requests and by remote IP address otherwise.
    /// </summary>
    /// <exception cref="InvalidOperationException">The cache did not yield a limiter.</exception>
    private RateLimiter GetRateLimiter(HttpContext http)
    {
        if (http.User.Identity?.IsAuthenticated ?? false)
        {
            return GetRateLimiterForUser(http.User);
        }

        // Fall back to loopback when the connection exposes no remote address.
        var key = http.Connection.RemoteIpAddress ?? IPAddress.Loopback;
        LogRateLimitKey(logger, key);

        return cache.GetOrCreate(key, entry => CreateRateLimiter(entry, options.Value.RateLimit))
               ?? throw new InvalidOperationException($"Creating rate limiter failed for {key}");
    }

    /// <summary>
    /// Fetches (or lazily creates) the rate limiter for an authenticated user, keyed by
    /// the user's name and sized by the user's <c>RateLimit</c> claim.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The user has no name to key the limiter on, or the cache did not yield a limiter.
    /// </exception>
    private RateLimiter GetRateLimiterForUser(ClaimsPrincipal user)
    {
        // A null name cannot be used as a cache key; fail with a descriptive error
        // instead of an ArgumentNullException from inside the cache.
        var name = user.Identity?.Name
                   ?? throw new InvalidOperationException("Cannot rate limit an authenticated user without a name");
        var limit = user.GetIntClaim("RateLimit");
        LogRateLimitForUser(logger, name, limit);

        return cache.GetOrCreate(name, entry => CreateRateLimiter(entry, limit))
               ?? throw new InvalidOperationException($"Creating rate limiter failed for {name}");
    }

    /// <summary>
    /// Configures the cache entry and builds a token bucket that allows
    /// <paramref name="tokenLimit"/> requests per <c>RateLimitWindow</c> seconds.
    /// </summary>
    private TokenBucketRateLimiter CreateRateLimiter(ICacheEntry entry, int tokenLimit)
    {
        var window = TimeSpan.FromSeconds(options.Value.RateLimitWindow);
        entry.SlidingExpiration = window;

        // Auto-replenishing limiters are disposable; dispose them once the cache drops
        // the entry so they do not accumulate for every client ever seen.
        entry.RegisterPostEvictionCallback(static (_, value, _, _) => (value as IDisposable)?.Dispose());

        return new TokenBucketRateLimiter(new()
        {
            AutoReplenishment = true,
            TokenLimit = tokenLimit,
            TokensPerPeriod = tokenLimit,
            QueueLimit = 0,
            ReplenishmentPeriod = window
        });
    }

    /// <summary>
    /// Writes a 400 response with a JSON body explaining which headers are required.
    /// </summary>
    private async Task RejectRequest(HttpContext context)
    {
        context.Response.StatusCode = StatusCodes.Status400BadRequest;
        context.Response.Headers.WWWAuthenticate = "X-Super-Client";
        context.Response.ContentType = "application/json";

        // Dispose asynchronously: the writer is IAsyncDisposable and must not be left
        // undisposed. Disposing the writer does not close the response body stream.
        await using var writer = new Utf8JsonWriter(context.Response.Body);
        writer.WriteStartObject();
        writer.WritePropertyName("message");
        writer.WriteStringValue("The X-Super-Client and X-Super-Contact headers are required");
        writer.WriteEndObject();

        await writer.FlushAsync(context.RequestAborted);
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="name"/> is present as a request header,
    /// or (lower-cased) as a query string parameter.
    /// </summary>
    private bool HasSuperHeaderOrQuery(HttpContext context, string name)
    {
        if (context.Request.Headers.ContainsKey(name))
        {
            return true;
        }

        return context.Request.Query.ContainsKey(name.ToLowerInvariant());
    }
}