Skip to content

Commit 5126e14

Browse files
0x46616c6bclaude
andcommitted
🐛 Count rejected requests in Redis sliding window
Mirror the in-memory fix from #79: the Lua script now ZADDs the current request before checking limits and uses '>' instead of '>=' for the comparison. This way a sender who keeps hammering the policy server while over-limit extends their own blocking window instead of getting a free reset once the oldest timestamps fall out. Update the hourly/daily limit tests to expect count=4 on the rejected 4th attempt, and replace the old "rejected does not increment" test with the equivalent of the in-memory test that asserts rejected attempts keep accumulating. Co-Authored-By: Claude <claude@anthropic.com>
1 parent 33fc124 commit 5126e14

2 files changed

Lines changed: 38 additions & 20 deletions

File tree

ratelimit.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@ const rateLimitKeyPrefix = "userli:ratelimit:sender:"
2020

2121
// rateLimitScript implements the sliding-window check atomically.
2222
// ARGV: hourLimit, dayLimit, now (unix nano), hourAgo, dayAgo, member suffix, TTL seconds.
23-
// Returns {allowed (1/0), hourCount, dayCount}. The new entry is only added when allowed.
23+
// Returns {allowed (1/0), hourCount, dayCount}.
24+
//
25+
// The current request is always recorded — even when it gets rejected — so a
26+
// sender who keeps trying while over-limit extends their own blocking window
27+
// instead of getting a free reset once old timestamps expire. The reported
28+
// counts therefore include the just-recorded attempt; the limit comparison
29+
// uses '>' rather than '>='.
2430
const rateLimitScript = `
2531
local key = KEYS[1]
2632
local hour_limit = tonumber(ARGV[1])
@@ -33,20 +39,20 @@ local ttl = tonumber(ARGV[7])
3339
3440
redis.call("ZREMRANGEBYSCORE", key, "-inf", "(" .. day_ago)
3541
42+
redis.call("ZADD", key, now, now .. ":" .. suffix)
43+
redis.call("EXPIRE", key, ttl)
44+
3645
local day_count = tonumber(redis.call("ZCARD", key))
3746
local hour_count = tonumber(redis.call("ZCOUNT", key, "(" .. hour_ago, "+inf"))
3847
39-
if hour_limit > 0 and hour_count >= hour_limit then
48+
if hour_limit > 0 and hour_count > hour_limit then
4049
return {0, hour_count, day_count}
4150
end
42-
if day_limit > 0 and day_count >= day_limit then
51+
if day_limit > 0 and day_count > day_limit then
4352
return {0, hour_count, day_count}
4453
end
4554
46-
redis.call("ZADD", key, now, now .. ":" .. suffix)
47-
redis.call("EXPIRE", key, ttl)
48-
49-
return {1, hour_count + 1, day_count + 1}
55+
return {1, hour_count, day_count}
5056
`
5157

5258
// RateLimiter enforces per-sender sliding-window quotas using Redis as backing store.

ratelimit_test.go

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,13 @@ func TestRateLimiter_CheckAndIncrement_HourlyLimit(t *testing.T) {
122122
}
123123
}
124124

125+
// 4th message should be rejected but still counted
125126
allowed, hourCount, _ := rl.CheckAndIncrement(ctx, sender, quota)
126127
if allowed {
127128
t.Error("4th message should be rejected due to hourly limit")
128129
}
129-
if hourCount != 3 {
130-
t.Errorf("Expected hourCount to be 3, got %d", hourCount)
130+
if hourCount != 4 {
131+
t.Errorf("Expected hourCount to be 4, got %d", hourCount)
131132
}
132133
}
133134

@@ -145,12 +146,13 @@ func TestRateLimiter_CheckAndIncrement_DailyLimit(t *testing.T) {
145146
}
146147
}
147148

149+
// 4th message should be rejected but still counted
148150
allowed, _, dayCount := rl.CheckAndIncrement(ctx, sender, quota)
149151
if allowed {
150152
t.Error("4th message should be rejected due to daily limit")
151153
}
152-
if dayCount != 3 {
153-
t.Errorf("Expected dayCount to be 3, got %d", dayCount)
154+
if dayCount != 4 {
155+
t.Errorf("Expected dayCount to be 4, got %d", dayCount)
154156
}
155157
}
156158

@@ -228,29 +230,39 @@ func TestRateLimiter_KeyTTL(t *testing.T) {
228230
}
229231
}
230232

231-
// TestRateLimiter_RejectedDoesNotIncrement guards the conditional-ZADD invariant
232-
// in the Lua script: a rejected message must not push the counter above the limit.
233-
func TestRateLimiter_RejectedDoesNotIncrement(t *testing.T) {
233+
// TestRateLimiter_CheckAndIncrement_RejectedRequestsExtendWindow ensures
234+
// rejected attempts are still recorded so a sender who keeps trying while
235+
// over-limit extends their own blocking window instead of getting a free
236+
// reset once old timestamps expire.
237+
func TestRateLimiter_CheckAndIncrement_RejectedRequestsExtendWindow(t *testing.T) {
234238
rl, _ := newTestRateLimiter(t)
235239
ctx := context.Background()
236240

237241
quota := &Quota{PerHour: 2, PerDay: 100}
238-
sender := "reject@example.org"
242+
sender := "spammer@example.org"
239243

244+
// Use up the hourly quota.
240245
for i := 0; i < 2; i++ {
241-
rl.CheckAndIncrement(ctx, sender, quota)
246+
allowed, _, _ := rl.CheckAndIncrement(ctx, sender, quota)
247+
if !allowed {
248+
t.Errorf("Message %d should be allowed", i+1)
249+
}
242250
}
243251

252+
// Continued spam attempts must be rejected and still counted.
244253
for i := 0; i < 5; i++ {
245254
allowed, _, _ := rl.CheckAndIncrement(ctx, sender, quota)
246255
if allowed {
247-
t.Errorf("Attempt %d above limit should be rejected", i+1)
256+
t.Errorf("Spam attempt %d should be rejected", i+1)
248257
}
249258
}
250259

251-
hourCount, _ := rl.GetCounts(ctx, sender)
252-
if hourCount != 2 {
253-
t.Errorf("Expected hourCount to stay at 2 after rejected attempts, got %d", hourCount)
260+
hourCount, dayCount := rl.GetCounts(ctx, sender)
261+
if hourCount != 7 {
262+
t.Errorf("Expected hourCount to include rejected attempts (7), got %d", hourCount)
263+
}
264+
if dayCount != 7 {
265+
t.Errorf("Expected dayCount to include rejected attempts (7), got %d", dayCount)
254266
}
255267
}
256268

0 commit comments

Comments
 (0)