Skip to content

Commit d3304e1

Browse files
committed
Add a rate limiter
1 parent ea59c11 commit d3304e1

10 files changed

Lines changed: 930 additions & 59 deletions

File tree

examples/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,3 @@
44

55
# Re-export the fixture so it's available in examples
66
__all__ = ["shared_json_fixture_factory"]
7-
8-
# Made with Bob

src/pytest_load_testing/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
from .api import stop_load_testing, weight
44
from .concurrent_fixtures import SharedJson, shared_json_fixture_factory
5+
from .token_bucket_rate_limiter import RateLimit, TokenBucketRateLimiter
56

67
__version__ = "0.1.0"
78
__all__ = [
89
"weight",
910
"stop_load_testing",
1011
"shared_json_fixture_factory",
1112
"SharedJson",
13+
"RateLimit",
14+
"TokenBucketRateLimiter",
1215
]

src/pytest_load_testing/concurrent_fixtures.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
# Constants
2020
INIT_LOCK_TIMEOUT = 30 # seconds for initialization lock acquisition
21+
SHARED_FILE_PREFIX = "pytest_shared_" # prefix for shared fixture files
2122

2223

2324
class SharedJson:
@@ -51,6 +52,19 @@ def __init__(self, data_file: Path, lock_file: Path, timeout: float = -1):
5152
self.timeout = timeout
5253
self._lock = FileLock(str(lock_file), timeout=timeout)
5354

55+
@property
def name(self) -> str:
    """Return a short identifier derived from the data file's name.

    Returns:
        str: The data file's stem (filename without its final extension).
            When the stem begins with ``SHARED_FILE_PREFIX`` that prefix is
            dropped; otherwise the stem is returned unchanged.
    """
    file_stem = self.data_file.stem
    has_prefix = file_stem.startswith(SHARED_FILE_PREFIX)
    # Slice by the prefix length rather than str.replace so that only a
    # leading occurrence is removed.
    return file_stem[len(SHARED_FILE_PREFIX) :] if has_prefix else file_stem
67+
5468
@contextmanager
5569
def locked_dict(self):
5670
"""Context manager for atomic read-modify-write operations.
@@ -209,7 +223,7 @@ def factory(
209223
Raises:
210224
filelock.Timeout: If lock cannot be acquired within timeout period
211225
"""
212-
base_path = shared_temp / f"pytest_shared_{name}"
226+
base_path = shared_temp / f"{SHARED_FILE_PREFIX}{name}"
213227
data_file = base_path.with_suffix(".json")
214228
init_marker = base_path.with_name(f"{name}_init.marker")
215229
data_lock_file = base_path.with_name(f"{name}_data.lock")
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
"""
2+
Token Bucket Rate Limiter
3+
4+
This module provides a token bucket rate limiter implementation that can be used
5+
to control the rate of operations across multiple processes.
6+
"""
7+
8+
import contextlib
9+
import logging
10+
import time
11+
from dataclasses import dataclass
12+
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
13+
14+
from pytest_load_testing.concurrent_fixtures import SharedJson
15+
16+
logger = logging.getLogger(__name__)
17+
logger.setLevel(logging.INFO)
18+
19+
20+
class RateLimit:
    """
    A call budget normalised to a whole number of calls per hour.

    Instances are normally built through one of the unit-specific factory
    methods, each of which converts its argument to calls per hour and
    truncates the result toward zero with ``int()``.

    Examples:
        >>> rate = RateLimit.per_second(10)  # 10 calls per second
        >>> rate = RateLimit.per_minute(600)  # 600 calls per minute
        >>> rate = RateLimit.per_hour(3600)  # 3600 calls per hour
        >>> rate = RateLimit.per_day(86400)  # 86400 calls per day
    """

    def __init__(self, calls_per_hour: int):
        """Store the hourly budget.

        Args:
            calls_per_hour: Number of calls allowed per hour.

        Raises:
            ValueError: If ``calls_per_hour`` is zero or negative.
        """
        if calls_per_hour <= 0:
            raise ValueError("calls_per_hour must be positive")
        self._calls_per_hour = calls_per_hour

    @property
    def calls_per_hour(self) -> int:
        """The budget expressed as calls per hour."""
        return self._calls_per_hour

    @classmethod
    def per_second(cls, calls: Union[int, float]) -> "RateLimit":
        """Build a limit from a calls-per-second figure."""
        hourly = int(calls * 3600)
        return cls(hourly)

    @classmethod
    def per_minute(cls, calls: Union[int, float]) -> "RateLimit":
        """Build a limit from a calls-per-minute figure."""
        hourly = int(calls * 60)
        return cls(hourly)

    @classmethod
    def per_hour(cls, calls: int) -> "RateLimit":
        """Build a limit from a calls-per-hour figure (stored as given)."""
        return cls(calls)

    @classmethod
    def per_day(cls, calls: Union[int, float]) -> "RateLimit":
        """Build a limit from a calls-per-day figure.

        Fewer than 24 calls per day truncates to 0 calls per hour, which the
        constructor rejects with ``ValueError``.
        """
        hourly = int(calls / 24)
        return cls(hourly)

    def __repr__(self) -> str:
        return f"RateLimit({self._calls_per_hour} calls/hour)"
58+
59+
60+
class TokenBucketRateLimiter:
    """
    A token bucket rate limiter that tracks and limits the rate of operations.

    This class implements the token bucket algorithm, a classical rate limiting
    algorithm that allows for controlled bursts of activity. It is designed to be
    used with pytest-xdist to coordinate rate limiting across multiple worker processes.

    All mutable state lives in the ``shared_state`` dictionary (accessed through
    ``shared_state.locked_dict()``) under these keys:

    - ``start_time``: wall-clock time the state was first initialised
    - ``last_refill_time``: wall-clock time of the last token accounting
    - ``tokens``: current token balance; may be negative when callers have
      reserved tokens that have not been refilled yet
    - ``call_count``: number of rate-limited calls started
    - ``exceptions``: number of rate-limited calls that raised
    """

    def __init__(
        self,
        shared_state: "SharedJson",
        hourly_rate: Union["RateLimit", Callable[[], "RateLimit"]],
        max_drift: float = 0.1,
        on_drift_callback: Optional[Callable[[str, float, float, float], None]] = None,
        num_calls_between_checks: int = 10,
        seconds_before_first_check: float = 60.0,
        burst_capacity: Optional[int] = None,
        max_calls: int = -1,
        max_call_callback: Optional[Callable[[str, int], None]] = None,
    ):
        """
        Initialize a token bucket rate limiter.

        Args:
            shared_state: SharedJson instance for state management across workers
            hourly_rate: Rate limit specification. Can be:
                - RateLimit: rate limit object (e.g., RateLimit.per_second(10))
                - Callable: function returning RateLimit
            max_drift: Maximum allowed drift from the expected rate (as a fraction)
            on_drift_callback: Callback function to execute when drift exceeds max_drift
                Function signature: (id: str, current_rate: float,
                target_rate: float, drift: float) -> None
            num_calls_between_checks: Number of calls between rate drift checks (default: 10)
            seconds_before_first_check: Minimum elapsed time (seconds) before rate checking begins
                (default: 60.0 seconds)
            burst_capacity: Maximum number of tokens that can be stored in the bucket
                (defaults to 10% of hourly rate or 1, whichever is larger)
            max_calls: Maximum number of calls allowed (-1 for unlimited)
            max_call_callback: Callback function to execute when max_calls is reached
                Function signature: (id: str, call_count: int) -> None

        Raises:
            ValueError: If max_drift is outside [0, 1], num_calls_between_checks < 1,
                seconds_before_first_check is negative, or burst_capacity < 1.
        """
        # Validate input parameters
        if not 0 <= max_drift <= 1:
            raise ValueError(f"max_drift must be between 0 and 1, got {max_drift}")
        if num_calls_between_checks < 1:
            raise ValueError(f"num_calls_between_checks must be positive, got {num_calls_between_checks}")
        if seconds_before_first_check < 0:
            raise ValueError(f"seconds_before_first_check must be non-negative, got {seconds_before_first_check}")
        if burst_capacity is not None and burst_capacity < 1:
            raise ValueError(f"burst_capacity must be positive, got {burst_capacity}")

        self.shared_state = shared_state
        self._hourly_rate = hourly_rate
        self.max_drift = max_drift
        self.on_drift_callback = on_drift_callback
        self.num_calls_between_checks = num_calls_between_checks
        self.seconds_before_first_check = seconds_before_first_check
        # The default is computed once, here; a callable hourly_rate that later
        # returns a different value does not resize the bucket.
        self.burst_capacity = (
            burst_capacity if burst_capacity is not None else self._calculate_default_burst_capacity(self.hourly_rate)
        )
        self.max_calls = max_calls
        self.max_call_callback = max_call_callback

    @staticmethod
    def _calculate_default_burst_capacity(hourly_rate: int) -> int:
        """Calculate default burst capacity as 10% of hourly rate, minimum 1."""
        return max(1, int(hourly_rate * 0.1))

    @property
    def id(self) -> str:
        """Get the identifier from the shared state name."""
        return self.shared_state.name

    @property
    def hourly_rate(self) -> int:
        """Get the current hourly rate (re-evaluated on each access if a callable was given)."""
        rate = self._hourly_rate() if callable(self._hourly_rate) else self._hourly_rate
        return rate.calls_per_hour

    def _ensure_initialized(self, state: Dict[str, Any], now: float) -> None:
        """Populate ``state`` with a full bucket and zeroed counters if it is empty."""
        if state:
            return
        state.update(
            {
                "start_time": now,
                "last_refill_time": now,
                "tokens": self.burst_capacity,  # Start with full bucket
                "call_count": 0,
                "exceptions": 0,
            }
        )

    def _get_or_initialize_state(self) -> Dict[str, Any]:
        """Get a copy of the current state, initializing it first if necessary."""
        with self.shared_state.locked_dict() as state:
            self._ensure_initialized(state, time.time())
            # Return a copy to avoid modifications outside the lock
            return dict(state)

    def _check_rate(self, state: Dict[str, Any]) -> None:
        """
        Compare the observed call rate against the target and report excessive drift.

        Args:
            state: State dictionary (or snapshot) providing ``start_time``,
                ``call_count`` and ``exceptions``.
        """
        elapsed_time = time.time() - state["start_time"]

        # Only check if we have enough data. The explicit <= 0 guard prevents a
        # division by zero when seconds_before_first_check is 0 and no measurable
        # time has passed yet.
        if elapsed_time <= 0 or elapsed_time < self.seconds_before_first_check:
            return

        current_rate = (state["call_count"] / elapsed_time) * 3600
        target_rate = self.hourly_rate

        # Calculate drift as a fraction of the target rate
        if target_rate > 0:
            drift = abs(current_rate - target_rate) / target_rate
        else:
            drift = 0 if current_rate == 0 else float("inf")

        logger.info(
            f"Rate check for {self.shared_state.name}: current={current_rate:.2f}/hr, "
            f"target={target_rate}/hr, drift={drift:.2%}. "
            f"Total calls: {state['call_count']}. Exceptions: {state['exceptions']}"
        )

        if drift > self.max_drift:
            message = (
                f"Rate drift for {self.shared_state.name} exceeds maximum allowed: "
                f"current={current_rate:.2f}/hr, target={target_rate}/hr, "
                f"drift={drift:.2%} (max allowed: {self.max_drift:.2%})"
            )
            logger.error(message)

            if self.on_drift_callback:
                self.on_drift_callback(self.shared_state.name, current_rate, target_rate, drift)

    def _calculate_wait_time_and_update(self) -> float:
        """
        Reserve one token for the caller and return how long it must wait for it.

        The token bucket algorithm works by:
        1. Adding tokens to the bucket at a constant rate (the refill rate)
        2. When a request arrives, it takes a token from the bucket if one is available
        3. If no tokens are available, the request must wait until a token becomes available
        4. The bucket has a maximum capacity to limit bursts

        A token is always deducted, even when the caller has to wait, so the
        stored balance may become negative. This makes each waiting caller queue
        behind earlier reservations instead of all waiters receiving the same
        delay and then proceeding without ever paying for their call.

        Returns:
            float: Wait time in seconds (0 if can proceed immediately)
        """
        with self.shared_state.locked_dict() as state:
            # Read the clock only once the lock is held: a timestamp taken
            # before a contended acquire could be older than last_refill_time.
            current_time = time.time()
            self._ensure_initialized(state, current_time)

            # Calculate tokens to add based on time elapsed since last refill;
            # clamp at zero so clock skew can never remove tokens.
            tokens_per_second = self.hourly_rate / 3600
            elapsed_seconds = max(0.0, current_time - state["last_refill_time"])
            new_tokens = elapsed_seconds * tokens_per_second

            # Update tokens (can't exceed burst capacity)
            tokens = min(state["tokens"] + new_tokens, self.burst_capacity)

            # Consume (or reserve) this caller's token and record the accounting time
            state["tokens"] = tokens - 1
            state["last_refill_time"] = current_time

            if tokens >= 1:
                return 0.0

            # Time until the balance, before this reservation, would reach 1 token
            return max(0.0, (1 - tokens) / tokens_per_second)

    def _increment_call_count_and_check_rate(self) -> Tuple[int, Dict[str, Any]]:
        """
        Increment call count and check rate if needed.

        The drift check (and any user callback it triggers) runs on a snapshot
        after the lock has been released, so a slow callback does not block
        other workers.

        Returns:
            tuple: (call_count, state_snapshot)
        """
        with self.shared_state.locked_dict() as state:
            self._ensure_initialized(state, time.time())
            call_count = state.get("call_count", 0) + 1
            state["call_count"] = call_count

            # Create snapshot for context
            state_snapshot = dict(state)

        # Check rate periodically
        if call_count % self.num_calls_between_checks == 0:
            self._check_rate(state_snapshot)

        return call_count, state_snapshot

    def _track_exception(self) -> None:
        """Track that an exception occurred during rate-limited execution."""
        with self.shared_state.locked_dict() as state:
            state["exceptions"] = state.get("exceptions", 0) + 1

    def _check_max_calls(self, call_count: int) -> None:
        """Check if max_calls limit has been reached and invoke callback if configured."""
        if self.max_calls > 0 and call_count >= self.max_calls:
            logger.info(f"Rate limiter {self.shared_state.name} reached max_calls limit of {self.max_calls}")
            if self.max_call_callback:
                self.max_call_callback(self.shared_state.name, call_count)

    @dataclass
    class RateLimitContext:
        """
        Context object yielded by rate_limited_context that provides access to rate limiter metrics.
        """

        # The limiter that produced this context
        _limiter: "TokenBucketRateLimiter"
        # Snapshot of the shared state taken when the call was counted
        _state: dict

        @property
        def id(self) -> str:
            """Identifier of the limiter (the shared state name)."""
            return self._limiter.shared_state.name

        @property
        def hourly_rate(self) -> int:
            """The limiter's current hourly rate."""
            return self._limiter.hourly_rate

        @property
        def call_count(self) -> int:
            """Call count recorded in the snapshot (includes the current call)."""
            return self._state["call_count"]

        @property
        def exceptions(self) -> int:
            """Exception count recorded in the snapshot."""
            return self._state["exceptions"]

    @contextlib.contextmanager
    def rate_limited_context(self) -> Generator[RateLimitContext, Any, None]:
        """
        Context manager that rate-limits the enclosed code using token bucket algorithm.

        Blocks (via ``time.sleep``) until a token is available, counts the call,
        then yields a ``RateLimitContext``. Exceptions raised by the enclosed
        code are counted and re-raised; the max_calls check runs on exit either way.

        Example:
            with rate_limiter.rate_limited_context() as ctx:
                print(f"Using rate limiter {ctx.id} with rate {ctx.hourly_rate}/hr")
                print(f"Current call count: {ctx.call_count}")
                perform_action()
        """
        # Calculate wait time and update tokens atomically
        wait_time = self._calculate_wait_time_and_update()

        if wait_time > 0:
            logger.debug(f"Token bucket rate limiter {self.id} waiting for {wait_time:.2f} seconds")
            time.sleep(wait_time)
        else:
            logger.debug(f"Token bucket rate limiter {self.id} can proceed immediately")

        # Update call count and check rate
        call_count, state_snapshot = self._increment_call_count_and_check_rate()
        context = self.RateLimitContext(self, state_snapshot)

        try:
            yield context
        except Exception:
            self._track_exception()
            raise
        finally:
            self._check_max_calls(call_count)

0 commit comments

Comments
 (0)