Skip to content

Commit 815041e

Browse files
authored
Support for RateLimiter in Beam Remote Model Handler (#37218)
* Support for EnvoyRateLimiter in Apache Beam * fix format issues * fix test formatting * Fix test and syntax * fix lint * Add dependency based on python version * revert setup to separate pr * fix lint * fix formatting * resolve comments * Support Ratelimiter through RemoteModelHandler * fix lint * fix lint * fix comments * Add custom RateLimited Exception * fix doc * fix test * fix lint
1 parent 139724d commit 815041e

4 files changed

Lines changed: 206 additions & 6 deletions

File tree

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""A simple example demonstrating usage of the EnvoyRateLimiter with Vertex AI.
19+
"""
20+
21+
import argparse
22+
import logging
23+
24+
import apache_beam as beam
25+
from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
26+
from apache_beam.ml.inference.base import RunInference
27+
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
28+
from apache_beam.options.pipeline_options import PipelineOptions
29+
from apache_beam.options.pipeline_options import SetupOptions
30+
31+
32+
def run(argv=None):
  """Runs a small Vertex AI inference pipeline gated by an EnvoyRateLimiter.

  The flags defined below are parsed out of ``argv``; every argument that is
  not recognized here is forwarded to ``PipelineOptions``.

  Args:
    argv: Command-line arguments to parse. When None, argparse reads them
      from ``sys.argv``.

  Raises:
    SystemExit: If any of the required flags is missing (raised by argparse).
  """
  parser = argparse.ArgumentParser()
  # Every flag is required: each value is passed straight into the rate
  # limiter or the model handler below, so a missing flag would otherwise
  # surface later as a None value instead of a clear usage error.
  parser.add_argument(
      '--project',
      dest='project',
      required=True,
      help='The Google Cloud project ID for Vertex AI.')
  parser.add_argument(
      '--location',
      dest='location',
      required=True,
      help='The Google Cloud location (e.g. us-central1) for Vertex AI.')
  parser.add_argument(
      '--endpoint_id',
      dest='endpoint_id',
      required=True,
      help='The ID of the Vertex AI endpoint.')
  parser.add_argument(
      '--rls_address',
      dest='rls_address',
      required=True,
      help='The address of the Envoy Rate Limit Service (e.g. localhost:8081).')

  known_args, pipeline_args = parser.parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = True

  # Initialize the EnvoyRateLimiter. The domain and descriptors are sample
  # values; they have to correspond to the rate limit configuration loaded
  # into the Rate Limit Service.
  rate_limiter = EnvoyRateLimiter(
      service_address=known_args.rls_address,
      domain="mongo_cps",
      descriptors=[{
          "database": "users"
      }],
      namespace='example_pipeline')

  # Initialize the VertexAIModelHandler with the rate limiter
  model_handler = VertexAIModelHandlerJSON(
      endpoint_id=known_args.endpoint_id,
      project=known_args.project,
      location=known_args.location,
      rate_limiter=rate_limiter)

  # Input features for the model
  features = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
              [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]

  with beam.Pipeline(options=pipeline_options) as p:
    _ = (
        p
        | 'CreateInputs' >> beam.Create(features)
        | 'RunInference' >> RunInference(model_handler)
        | 'PrintPredictions' >> beam.Map(logging.info))
81+
82+
83+
if __name__ == '__main__':
  # Raise the root logger to INFO so the predictions, which the pipeline
  # emits through logging.info, are actually shown.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  run()

sdks/python/apache_beam/io/components/rate_limiter.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@ def __init__(self, namespace: str = ""):
6262

6363
@abc.abstractmethod
6464
def throttle(self, **kwargs) -> bool:
65-
"""Check if request should be throttled.
65+
"""Applies rate limiting to the request.
66+
67+
This method checks if the request is permitted by the rate limiting policy.
68+
Depending on the implementation and configuration, it may block (sleep)
69+
until the request is allowed, or return False once the retry limit is
70+
exceeded.
6671
6772
Args:
6873
**kwargs: Keyword arguments specific to the RateLimiter implementation.
@@ -78,8 +83,12 @@ def throttle(self, **kwargs) -> bool:
7883

7984

8085
class EnvoyRateLimiter(RateLimiter):
81-
"""
82-
Rate limiter implementation that uses an external Envoy Rate Limit Service.
86+
"""Rate limiter implementation that uses an external Envoy Rate Limit Service.
87+
88+
This limiter connects to a gRPC Envoy Rate Limit Service (RLS) to determine
89+
whether a request should be allowed. It supports defining a domain and a
90+
list of descriptors that correspond to the rate limit configuration in the
91+
RLS.
8392
"""
8493
def __init__(
8594
self,
@@ -89,7 +98,7 @@ def __init__(
8998
timeout: float = 5.0,
9099
block_until_allowed: bool = True,
91100
retries: int = 3,
92-
namespace: str = ""):
101+
namespace: str = ''):
93102
"""
94103
Args:
95104
service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
@@ -140,7 +149,15 @@ def init_connection(self):
140149
self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)
141150

142151
def throttle(self, hits_added: int = 1) -> bool:
143-
"""Calls the Envoy RLS to check for rate limits.
152+
"""Calls the Envoy RLS to apply rate limits.
153+
154+
Sends a rate limit request to the configured Envoy Rate Limit Service.
155+
If 'block_until_allowed' is True, this method will sleep and retry
156+
if the limit is exceeded, effectively blocking until the request is
157+
permitted.
158+
159+
If 'block_until_allowed' is False, it will return False after the retry
160+
limit is exceeded.
144161
145162
Args:
146163
hits_added: Number of hits to add to the rate limit.
@@ -224,3 +241,16 @@ def throttle(self, hits_added: int = 1) -> bool:
224241
response.overall_code)
225242
break
226243
return throttled
244+
245+
def __getstate__(self):
246+
state = self.__dict__.copy()
247+
if '_lock' in state:
248+
del state['_lock']
249+
if '_stub' in state:
250+
del state['_stub']
251+
return state
252+
253+
def __setstate__(self, state):
254+
self.__dict__.update(state)
255+
self._lock = threading.Lock()
256+
self._stub = None

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656

5757
import apache_beam as beam
5858
from apache_beam.io.components.adaptive_throttler import ReactiveThrottler
59+
from apache_beam.io.components.rate_limiter import RateLimiter
5960
from apache_beam.utils import multi_process_shared
6061
from apache_beam.utils import retry
6162
from apache_beam.utils import shared
@@ -102,6 +103,11 @@ def __new__(cls, example, inference, model_id=None):
102103
PredictionResult.model_id.__doc__ = """Model ID used to run the prediction."""
103104

104105

106+
class RateLimitExceeded(RuntimeError):
  """Raised when a batch of requests is rejected by the rate limiter."""
109+
110+
105111
class ModelMetadata(NamedTuple):
106112
model_id: str
107113
model_name: str
@@ -349,7 +355,8 @@ def __init__(
349355
*,
350356
window_ms: int = 1 * _MILLISECOND_TO_SECOND,
351357
bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
352-
overload_ratio: float = 2):
358+
overload_ratio: float = 2,
359+
rate_limiter: Optional[RateLimiter] = None):
353360
"""Initializes a ReactiveThrottler class for enabling
354361
client-side throttling for remote calls to an inference service. Also wraps
355362
provided calls to the service with retry logic.
@@ -372,6 +379,7 @@ def __init__(
372379
overload_ratio: the target ratio between requests sent and successful
373380
requests. This is "K" in the formula in
374381
https://landing.google.com/sre/book/chapters/handling-overload.html.
382+
rate_limiter: A RateLimiter object for setting a global rate limit.
375383
"""
376384
# Configure ReactiveThrottler for client-side throttling behavior.
377385
self.throttler = ReactiveThrottler(
@@ -383,6 +391,9 @@ def __init__(
383391
self.logger = logging.getLogger(namespace)
384392
self.num_retries = num_retries
385393
self.retry_filter = retry_filter
394+
self._rate_limiter = rate_limiter
395+
self._shared_rate_limiter = None
396+
self._shared_handle = shared.Shared()
386397

387398
def __init_subclass__(cls):
388399
if cls.load_model is not RemoteModelHandler.load_model:
@@ -431,6 +442,19 @@ def run_inference(
431442
Returns:
432443
An Iterable of Predictions.
433444
"""
445+
if self._rate_limiter:
446+
if self._shared_rate_limiter is None:
447+
448+
def init_limiter():
449+
return self._rate_limiter
450+
451+
self._shared_rate_limiter = self._shared_handle.acquire(init_limiter)
452+
453+
if not self._shared_rate_limiter.throttle(hits_added=len(batch)):
454+
raise RateLimitExceeded(
455+
"Rate Limit Exceeded, "
456+
"Could not process this batch.")
457+
434458
self.throttler.throttle()
435459

436460
try:

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,67 @@ def run_inference(self,
20712071
responses.append(model.predict(example))
20722072
return responses
20732073

2074+
def test_run_inference_with_rate_limiter(self):
  """Runs RunInference with a rate limiter that always admits the batch.

  Checks the predictions and that the limiter's request counter is reported
  in the pipeline metrics under the test namespace.
  """
  class FakeRateLimiter(base.RateLimiter):
    def __init__(self):
      super().__init__(namespace='test_namespace')

    def throttle(self, hits_added=1):
      self.requests_counter.inc()
      return True

  class ConcreteRemoteModelHandler(base.RemoteModelHandler):
    def create_client(self):
      return FakeModel()

    def request(self, batch, model, inference_args=None):
      return [model.predict(example) for example in batch]

  limiter = FakeRateLimiter()
  model_handler = ConcreteRemoteModelHandler(
      rate_limiter=limiter, namespace='test_namespace')

  # The pipeline is run explicitly below so its result and metrics can be
  # inspected. It is deliberately not wrapped in a `with` block, which would
  # run the pipeline a second time on exit.
  pipeline = TestPipeline()
  examples = [1, 5]

  pcoll = pipeline | 'start' >> beam.Create(examples)
  actual = pcoll | base.RunInference(model_handler)

  expected = [2, 6]
  assert_that(actual, equal_to(expected))

  result = pipeline.run()
  result.wait_until_finish()

  metrics_filter = MetricsFilter().with_name(
      'RatelimitRequestsTotal').with_namespace('test_namespace')
  metrics = result.metrics().query(metrics_filter)
  # Indexing [0] fails if no matching counter was reported; the value check
  # itself only requires a non-negative committed count.
  self.assertGreaterEqual(metrics['counters'][0].committed, 0)
2111+
2112+
def test_run_inference_with_rate_limiter_exceeded(self):
  """run_inference raises RateLimitExceeded when the limiter returns False."""
  class ConcreteRemoteModelHandler(base.RemoteModelHandler):
    def create_client(self):
      return FakeModel()

    def request(self, batch, model, inference_args=None):
      return [model.predict(example) for example in batch]

  class FakeRateLimiter(base.RateLimiter):
    def __init__(self):
      super().__init__(namespace='test_namespace')

    def throttle(self, hits_added=1):
      # Never admits a request.
      return False

  denying_limiter = FakeRateLimiter()
  model_handler = ConcreteRemoteModelHandler(
      rate_limiter=denying_limiter, namespace='test_namespace', num_retries=0)

  # The handler is called directly, without a pipeline, so the exception
  # reaches the test unchanged.
  with self.assertRaises(base.RateLimitExceeded):
    model_handler.run_inference([1], FakeModel())
2134+
20742135

20752136
if __name__ == '__main__':
20762137
unittest.main()

0 commit comments

Comments
 (0)