Skip to content

Commit f599b1d

Browse files
authored
fix: prevent rate limiter bypass via Lua atomic script (#593)
1 parent 0119c74 commit f599b1d

5 files changed

Lines changed: 84 additions & 14 deletions

File tree

.changeset/metal-snails-prove.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"nostream": patch
3+
---
4+
5+
fix: resolve TOCTOU race condition and key collisions in SlidingWindowRateLimiter

src/@types/adapters.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ export interface ICacheAdapter {
2525
removeRangeByScoreFromSortedSet(key: string, min: number, max: number): Promise<number>
2626
getRangeFromSortedSet(key: string, start: number, stop: number): Promise<string[]>
2727
setKeyExpiry(key: string, expiry: number): Promise<void>
28+
2829
deleteKey(key: string): Promise<number>
2930
getHKey(key: string, field: string): Promise<string>
3031
setHKey(key: string, fields: Record<string, string>): Promise<boolean>
32+
33+
3134
eval(script: string, keys: string[], args: string[]): Promise<unknown>
3235
}

src/adapters/redis-adapter.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ export class RedisAdapter implements ICacheAdapter {
9696
return this.client.zAdd(key, members)
9797
}
9898

99+
99100
public async deleteKey(key: string): Promise<number> {
100101
await this.connection
101102
logger('delete %s key', key)
@@ -123,4 +124,5 @@ export class RedisAdapter implements ICacheAdapter {
123124
return await this.client.evalSha(this.scriptShas.get(script)!, { keys, arguments: args })
124125
}
125126

127+
126128
}

src/utils/sliding-window-rate-limiter.ts

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,67 @@ import { ICacheAdapter } from '../@types/adapters'
44

55
const logger = createLogger('sliding-window-rate-limiter')
66

7+
const SLIDING_WINDOW_RATE_LIMITER_LUA_SCRIPT = `
8+
local key = KEYS[1]
9+
local timestamp = tonumber(ARGV[1])
10+
local period = tonumber(ARGV[2])
11+
local step = tonumber(ARGV[3])
12+
local max_rate = tonumber(ARGV[4])
13+
14+
local windowStart = timestamp - period
15+
16+
redis.call('ZREMRANGEBYSCORE', key, 0, windowStart)
17+
18+
local entries = redis.call('ZRANGE', key, 0, -1)
19+
local hits = 0
20+
for i=1, #entries do
21+
local step_str = string.match(entries[i], "^[^:]+:([^:]+)")
22+
if step_str then
23+
local entry_step = tonumber(step_str)
24+
if entry_step then
25+
hits = hits + entry_step
26+
end
27+
end
28+
end
29+
30+
if hits + step > max_rate then
31+
return 1
32+
end
33+
34+
local base_member = timestamp .. ':' .. step
35+
local member = base_member
36+
local counter = 0
37+
while redis.call('ZSCORE', key, member) do
38+
counter = counter + 1
39+
member = base_member .. ':' .. counter
40+
end
41+
42+
redis.call('ZADD', key, timestamp, member)
43+
redis.call('PEXPIRE', key, period)
44+
45+
return 0
46+
`
47+
748
export class SlidingWindowRateLimiter implements IRateLimiter {
8-
public constructor(private readonly cache: ICacheAdapter) {}
49+
public constructor(
50+
private readonly cache: ICacheAdapter,
51+
) { }
952

1053
public async hit(key: string, step: number, options: IRateLimiterOptions): Promise<boolean> {
1154
const timestamp = Date.now()
12-
const { period } = options
55+
const { period, rate } = options
1356

14-
const [, , entries] = await Promise.all([
15-
this.cache.removeRangeByScoreFromSortedSet(key, 0, timestamp - period),
16-
this.cache.addToSortedSet(key, { [`${timestamp}:${step}`]: timestamp.toString() }),
17-
this.cache.getRangeFromSortedSet(key, 0, -1),
18-
this.cache.setKeyExpiry(key, period),
57+
const result = await this.cache.eval(SLIDING_WINDOW_RATE_LIMITER_LUA_SCRIPT, [key], [
58+
timestamp.toString(),
59+
period.toString(),
60+
step.toString(),
61+
rate.toString(),
1962
])
2063

21-
const hits = entries.reduce((acc, timestampAndStep) => acc + Number(timestampAndStep.split(':')[1]), 0)
64+
const isRateLimited = result === 1 || result === '1'
2265

23-
logger('hit count on %s bucket: %d', key, hits)
66+
logger('hit on %s bucket: is rate limited? %s', key, isRateLimited)
2467

25-
return hits > options.rate
68+
return isRateLimited
2669
}
2770
}

test/unit/utils/sliding-window-rate-limiter.spec.ts

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ describe('SlidingWindowRateLimiter', () => {
1717
let getKeyStub: Sinon.SinonStub
1818
let hasKeyStub: Sinon.SinonStub
1919
let setKeyStub: Sinon.SinonStub
20+
let evalStub: Sinon.SinonStub
2021

2122
let sandbox: Sinon.SinonSandbox
2223

@@ -30,6 +31,7 @@ describe('SlidingWindowRateLimiter', () => {
3031
getKeyStub = sandbox.stub()
3132
hasKeyStub = sandbox.stub()
3233
setKeyStub = sandbox.stub()
34+
evalStub = sandbox.stub()
3335
cache = {
3436
removeRangeByScoreFromSortedSet: removeRangeByScoreFromSortedSetStub,
3537
addToSortedSet: addToSortedSetStub,
@@ -38,7 +40,10 @@ describe('SlidingWindowRateLimiter', () => {
3840
getKey: getKeyStub,
3941
hasKey: hasKeyStub,
4042
setKey: setKeyStub,
43+
eval: evalStub,
4144
} as unknown as ICacheAdapter
45+
46+
4247
rateLimiter = new SlidingWindowRateLimiter(cache)
4348
})
4449

@@ -48,20 +53,32 @@ describe('SlidingWindowRateLimiter', () => {
4853
})
4954

5055
it('returns true if rate limited', async () => {
51-
const now = Date.now()
52-
getRangeFromSortedSetStub.resolves([`${now}:6`, `${now}:4`, `${now}:1`])
56+
evalStub.resolves(1)
5357

5458
const actualResult = await rateLimiter.hit('key', 1, { period: 60000, rate: 10 })
5559

5660
expect(actualResult).to.be.true
61+
expect(evalStub).to.have.been.calledOnce
62+
const args = evalStub.firstCall.args
63+
expect(args[1]).to.deep.equal(['key'])
64+
expect(args[2][1]).to.equal('60000') // period
65+
expect(args[2][2]).to.equal('1') // step
66+
expect(args[2][3]).to.equal('10') // max_rate
5767
})
5868

5969
it('returns false if not rate limited', async () => {
60-
const now = Date.now()
61-
getRangeFromSortedSetStub.resolves([`${now}:10`])
70+
evalStub.resolves(0)
6271

6372
const actualResult = await rateLimiter.hit('key', 1, { period: 60000, rate: 10 })
6473

6574
expect(actualResult).to.be.false
6675
})
76+
77+
it('robustly handles string return types from Redis', async () => {
78+
evalStub.resolves('1')
79+
80+
const actualResult = await rateLimiter.hit('key', 1, { period: 60000, rate: 10 })
81+
82+
expect(actualResult).to.be.true
83+
})
6784
})

0 commit comments

Comments
 (0)