Skip to content

Commit 03d4b0d

Browse files
committed
Fix rate limiter wait
1 parent d3304e1 commit 03d4b0d

2 files changed

Lines changed: 197 additions & 25 deletions

File tree

src/pytest_load_testing/token_bucket_rate_limiter.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -137,24 +137,6 @@ def hourly_rate(self) -> int:
137137
rate = self._hourly_rate() if callable(self._hourly_rate) else self._hourly_rate
138138
return rate.calls_per_hour
139139

140-
def _get_or_initialize_state(self) -> Dict[str, Any]:
141-
"""Get the current state, initializing if necessary."""
142-
with self.shared_state.locked_dict() as state:
143-
if not state:
144-
current_time = time.time()
145-
state.update(
146-
{
147-
"start_time": current_time,
148-
"last_refill_time": current_time,
149-
"tokens": self.burst_capacity, # Start with full bucket
150-
"call_count": 0,
151-
"exceptions": 0,
152-
}
153-
)
154-
155-
# Return a copy to avoid modifications outside the lock
156-
return dict(state)
157-
158140
def _check_rate(self, state: Dict[str, Any]) -> None:
159141
"""Check if the current rate is within acceptable limits."""
160142
current_time = time.time()
@@ -228,16 +210,20 @@ def _calculate_wait_time_and_update(self) -> float:
228210
# Update tokens (can't exceed burst capacity)
229211
tokens = min(state["tokens"] + new_tokens, self.burst_capacity)
230212

231-
# If we have at least 1 token, we can proceed immediately
213+
# Always consume 1 token immediately, even if it makes tokens negative
214+
# This ensures proper serialization across multiple threads/processes
215+
# Negative tokens represent a "debt" that must be paid back with wait time
216+
state["tokens"] = tokens - 1
217+
state["last_refill_time"] = current_time
218+
219+
# If we had at least 1 token, we can proceed immediately
232220
if tokens >= 1:
233-
# Consume 1 token and update the state
234-
state["tokens"] = tokens - 1
235-
state["last_refill_time"] = current_time
236221
return 0
237222

238-
# Calculate wait time until we have 1 token
239-
wait_time = (1 - tokens) / tokens_per_second
240-
return max(0, wait_time)
223+
# Calculate wait time to pay back the token debt
224+
# We need to wait until tokens would have refilled to 0
225+
wait_time = abs(state["tokens"]) / tokens_per_second
226+
return wait_time
241227

242228
def _increment_call_count_and_check_rate(self) -> Tuple[int, Dict[str, Any]]:
243229
"""
@@ -275,6 +261,13 @@ def _check_max_calls(self, call_count: int) -> None:
275261
class RateLimitContext:
276262
"""
277263
Context object yielded by rate_limited_context that provides access to rate limiter metrics.
264+
265+
Properties:
266+
id: Rate limiter identifier
267+
hourly_rate: Configured rate limit in calls per hour
268+
call_count: Total number of calls made
269+
exceptions: Total number of exceptions encountered
270+
start_time: Unix timestamp of when the first call was made
278271
"""
279272

280273
_limiter: "TokenBucketRateLimiter"
@@ -296,6 +289,11 @@ def call_count(self) -> int:
296289
def exceptions(self) -> int:
297290
return self._state["exceptions"]
298291

292+
@property
293+
def start_time(self) -> float:
294+
"""Timestamp of when the first call was made (Unix timestamp)."""
295+
return self._state["start_time"]
296+
299297
@contextlib.contextmanager
300298
def rate_limited_context(self) -> Generator[RateLimitContext, Any, None]:
301299
"""
@@ -305,6 +303,7 @@ def rate_limited_context(self) -> Generator[RateLimitContext, Any, None]:
305303
with rate_limiter.rate_limited_context() as ctx:
306304
print(f"Using rate limiter {ctx.id} with rate {ctx.hourly_rate}/hr")
307305
print(f"Current call count: {ctx.call_count}")
306+
print(f"First call at: {ctx.start_time}")
308307
perform_action()
309308
"""
310309
# Calculate wait time and update tokens atomically
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""
2+
Tests for concurrent rate limiting behavior.
3+
4+
These tests verify that the token bucket rate limiter properly enforces
5+
rate limits when multiple threads/processes are competing for tokens.
6+
"""
7+
8+
import time
9+
from concurrent.futures import ThreadPoolExecutor
10+
11+
from pytest_load_testing.concurrent_fixtures import SharedJson
12+
from pytest_load_testing.token_bucket_rate_limiter import RateLimit, TokenBucketRateLimiter
13+
14+
15+
def test_concurrent_workers_respect_rate_limit(tmp_path):
16+
"""
17+
Test that multiple concurrent workers properly respect the rate limit.
18+
19+
This test verifies the fix for the bug where multiple threads could
20+
calculate wait times based on the same token state, leading to rate
21+
limit violations.
22+
23+
With the old buggy code, this test would fail because multiple threads
24+
would see the same token state and all proceed after the same wait time.
25+
26+
With the fix, tokens are consumed immediately (even going negative),
27+
ensuring proper serialization.
28+
"""
29+
# Create a rate limiter with 1 call per second and burst capacity of 1
30+
data_file = tmp_path / "concurrent_test.json"
31+
lock_file = tmp_path / "concurrent_test.lock"
32+
shared_state = SharedJson(
33+
data_file=data_file,
34+
lock_file=lock_file,
35+
)
36+
37+
limiter = TokenBucketRateLimiter(
38+
shared_state=shared_state,
39+
hourly_rate=RateLimit.per_second(1), # 1 call per second
40+
burst_capacity=1, # No burst allowance
41+
max_drift=0.5,
42+
num_calls_between_checks=1000,
43+
seconds_before_first_check=100.0,
44+
)
45+
46+
# Track execution times
47+
execution_times = []
48+
49+
def make_call():
50+
"""Make a rate-limited call and record the time."""
51+
with limiter.rate_limited_context():
52+
execution_times.append(time.time())
53+
54+
# Run 5 calls concurrently with 2 workers
55+
start_time = time.time()
56+
with ThreadPoolExecutor(max_workers=2) as executor:
57+
futures = [executor.submit(make_call) for _ in range(5)]
58+
for future in futures:
59+
future.result()
60+
61+
elapsed = time.time() - start_time
62+
63+
# With 1 call/second rate limit, 5 calls should take at least 4 seconds
64+
# (first call is immediate, then 4 more calls at 1/second)
65+
assert elapsed >= 4.0, f"Expected at least 4 seconds for 5 calls at 1/sec rate, but took only {elapsed:.2f}s"
66+
67+
# Verify calls were properly spaced
68+
# Sort execution times
69+
execution_times.sort()
70+
71+
# Check spacing between consecutive calls
72+
for i in range(1, len(execution_times)):
73+
gap = execution_times[i] - execution_times[i - 1]
74+
# Each gap should be at least 0.9 seconds (allowing small timing variance)
75+
assert gap >= 0.9, f"Gap between call {i - 1} and {i} was only {gap:.2f}s, expected at least 0.9s"
76+
77+
78+
def test_concurrent_workers_with_burst_capacity(tmp_path):
79+
"""
80+
Test that burst capacity allows initial rapid calls, then enforces rate limit.
81+
"""
82+
data_file = tmp_path / "burst_test.json"
83+
lock_file = tmp_path / "burst_test.lock"
84+
shared_state = SharedJson(
85+
data_file=data_file,
86+
lock_file=lock_file,
87+
)
88+
89+
limiter = TokenBucketRateLimiter(
90+
shared_state=shared_state,
91+
hourly_rate=RateLimit.per_second(1), # 1 call per second
92+
burst_capacity=3, # Allow 3 rapid calls
93+
max_drift=0.5,
94+
num_calls_between_checks=1000,
95+
seconds_before_first_check=100.0,
96+
)
97+
98+
execution_times = []
99+
100+
def make_call():
101+
with limiter.rate_limited_context():
102+
execution_times.append(time.time())
103+
104+
# Run 5 calls concurrently
105+
with ThreadPoolExecutor(max_workers=2) as executor:
106+
futures = [executor.submit(make_call) for _ in range(5)]
107+
for future in futures:
108+
future.result()
109+
110+
execution_times.sort()
111+
112+
# First 3 calls should be rapid (using burst capacity)
113+
first_three_duration = execution_times[2] - execution_times[0]
114+
assert first_three_duration < 0.5, f"First 3 calls should be rapid, but took {first_three_duration:.2f}s"
115+
116+
# Calls 4 and 5 should be rate-limited
117+
# They should take at least 1 second each after the burst
118+
gap_3_to_4 = execution_times[3] - execution_times[2]
119+
gap_4_to_5 = execution_times[4] - execution_times[3]
120+
121+
assert gap_3_to_4 >= 0.9, f"Gap from call 3 to 4 was only {gap_3_to_4:.2f}s, expected ~1s"
122+
assert gap_4_to_5 >= 0.9, f"Gap from call 4 to 5 was only {gap_4_to_5:.2f}s, expected ~1s"
123+
124+
125+
def test_negative_tokens_prevent_race_condition(tmp_path):
126+
"""
127+
Test that the fix properly prevents the race condition by allowing negative tokens.
128+
129+
This test specifically targets the bug where multiple threads could see
130+
the same positive token count and all calculate the same wait time.
131+
"""
132+
data_file = tmp_path / "negative_tokens_test.json"
133+
lock_file = tmp_path / "negative_tokens_test.lock"
134+
shared_state = SharedJson(
135+
data_file=data_file,
136+
lock_file=lock_file,
137+
)
138+
139+
limiter = TokenBucketRateLimiter(
140+
shared_state=shared_state,
141+
hourly_rate=RateLimit.per_second(2), # 2 calls per second
142+
burst_capacity=1, # Only 1 token available initially
143+
max_drift=0.5,
144+
num_calls_between_checks=1000,
145+
seconds_before_first_check=100.0,
146+
)
147+
148+
call_count = [0]
149+
150+
def make_call():
151+
with limiter.rate_limited_context():
152+
call_count[0] += 1
153+
154+
# Launch 4 calls simultaneously
155+
start_time = time.time()
156+
with ThreadPoolExecutor(max_workers=4) as executor:
157+
futures = [executor.submit(make_call) for _ in range(4)]
158+
for future in futures:
159+
future.result()
160+
161+
elapsed = time.time() - start_time
162+
163+
# With 2 calls/second and 4 calls:
164+
# - Call 1: immediate (uses burst token)
165+
# - Call 2: waits 0.5s (token debt of -1)
166+
# - Call 3: waits 1.0s (token debt of -2)
167+
# - Call 4: waits 1.5s (token debt of -3)
168+
# Total time should be at least 1.5 seconds
169+
assert elapsed >= 1.4, (
170+
f"Expected at least 1.4 seconds for 4 calls at 2/sec rate with burst=1, but took only {elapsed:.2f}s"
171+
)
172+
173+
assert call_count[0] == 4, f"Expected 4 calls, got {call_count[0]}"

0 commit comments

Comments
 (0)