Skip to content

Commit 0aa8bcb

Browse files
committed
Add length-aware batching to BatchElements and ModelHandler
- Add length_fn and bucket_boundaries parameters to ModelHandler.__init__ to
  support length-aware bucketed keying for ML inference batching
- Add WithLengthBucketKey DoFn to route elements by length buckets
- Update BatchElements to support length-aware batching when
  max_batch_duration_secs is set, reducing padding waste for variable-length
  sequences (e.g., NLP workloads)
- Default bucket boundaries: [16, 32, 64, 128, 256, 512]
- Add comprehensive tests validating bucket assignment, mixed-length batching,
  and padding efficiency improvements (77% vs 68% on bimodal data)
- All formatting (yapf) and lint (pylint 10/10) checks passed
1 parent 195cc59 commit 0aa8bcb

4 files changed

Lines changed: 335 additions & 2 deletions

File tree

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ def __init__(
178178
max_batch_duration_secs: Optional[int] = None,
179179
max_batch_weight: Optional[int] = None,
180180
element_size_fn: Optional[Callable[[Any], int]] = None,
181+
length_fn: Optional[Callable[[Any], int]] = None,
182+
bucket_boundaries: Optional[list[int]] = None,
181183
large_model: bool = False,
182184
model_copies: Optional[int] = None,
183185
**kwargs):
@@ -190,6 +192,11 @@ def __init__(
190192
before emitting; used in streaming contexts.
191193
max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
192194
element_size_fn: a function that returns the size (weight) of an element.
195+
length_fn: a callable mapping an element to its length. When set with
196+
max_batch_duration_secs, enables length-aware bucketed keying so
197+
elements of similar length are batched together.
198+
bucket_boundaries: sorted list of positive boundary values for length
199+
bucketing. Requires length_fn.
193200
large_model: set to true if your model is large enough to run into
194201
memory pressure if you load multiple copies.
195202
model_copies: The exact number of models that you would like loaded
@@ -209,6 +216,10 @@ def __init__(
209216
self._batching_kwargs['max_batch_weight'] = max_batch_weight
210217
if element_size_fn is not None:
211218
self._batching_kwargs['element_size_fn'] = element_size_fn
219+
if length_fn is not None:
220+
self._batching_kwargs['length_fn'] = length_fn
221+
if bucket_boundaries is not None:
222+
self._batching_kwargs['bucket_boundaries'] = bucket_boundaries
212223
self._large_model = large_model
213224
self._model_copies = model_copies
214225
self._share_across_processes = large_model or (model_copies is not None)

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2278,6 +2278,43 @@ def test_max_batch_duration_secs_only(self):
22782278

22792279
self.assertEqual(kwargs, {'max_batch_duration_secs': 60})
22802280

2281+
def test_length_fn_and_bucket_boundaries(self):
2282+
"""length_fn and bucket_boundaries are passed through to kwargs."""
2283+
handler = FakeModelHandlerForBatching(
2284+
length_fn=len, bucket_boundaries=[16, 32, 64])
2285+
kwargs = handler.batch_elements_kwargs()
2286+
2287+
self.assertIs(kwargs['length_fn'], len)
2288+
self.assertEqual(kwargs['bucket_boundaries'], [16, 32, 64])
2289+
2290+
def test_length_fn_only(self):
2291+
"""length_fn alone is passed through without bucket_boundaries."""
2292+
handler = FakeModelHandlerForBatching(length_fn=len)
2293+
kwargs = handler.batch_elements_kwargs()
2294+
2295+
self.assertIs(kwargs['length_fn'], len)
2296+
self.assertNotIn('bucket_boundaries', kwargs)
2297+
2298+
def test_bucket_boundaries_without_length_fn(self):
2299+
"""Passing bucket_boundaries without length_fn should fail in BatchElements.
2300+
2301+
Note: ModelHandler.__init__ doesn't validate this; the error is raised
2302+
by BatchElements when batch_elements_kwargs are used."""
2303+
handler = FakeModelHandlerForBatching(bucket_boundaries=[10, 20])
2304+
kwargs = handler.batch_elements_kwargs()
2305+
# The kwargs are stored, but BatchElements will reject them
2306+
self.assertEqual(kwargs['bucket_boundaries'], [10, 20])
2307+
self.assertNotIn('length_fn', kwargs)
2308+
2309+
def test_batching_kwargs_none_values_omitted(self):
2310+
"""None values for length_fn and bucket_boundaries are not in kwargs."""
2311+
handler = FakeModelHandlerForBatching(
2312+
min_batch_size=5, length_fn=None, bucket_boundaries=None)
2313+
kwargs = handler.batch_elements_kwargs()
2314+
self.assertNotIn('length_fn', kwargs)
2315+
self.assertNotIn('bucket_boundaries', kwargs)
2316+
self.assertEqual(kwargs['min_batch_size'], 5)
2317+
22812318

22822319
class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
22832320
def load_model(self):

sdks/python/apache_beam/transforms/util.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
# pytype: skip-file
2222

23+
import bisect
2324
import collections
2425
import contextlib
2526
import hashlib
@@ -1209,6 +1210,28 @@ def process(self, element):
12091210
yield (self.key, element)
12101211

12111212

1213+
class WithLengthBucketKey(DoFn):
1214+
"""Keys elements with (worker_uuid, length_bucket) for length-aware
1215+
stateful batching. Elements of similar length are routed to the same
1216+
state partition, reducing padding waste."""
1217+
def __init__(self, length_fn, bucket_boundaries):
1218+
self.shared_handle = shared.Shared()
1219+
self._length_fn = length_fn
1220+
self._bucket_boundaries = bucket_boundaries
1221+
1222+
def setup(self):
1223+
self.key = self.shared_handle.acquire(
1224+
load_shared_key, "WithLengthBucketKey").key
1225+
1226+
def _get_bucket(self, length):
1227+
return bisect.bisect_left(self._bucket_boundaries, length)
1228+
1229+
def process(self, element):
1230+
length = self._length_fn(element)
1231+
bucket = self._get_bucket(length)
1232+
yield ((self.key, bucket), element)
1233+
1234+
12121235
@typehints.with_input_types(T)
12131236
@typehints.with_output_types(list[T])
12141237
class BatchElements(PTransform):
@@ -1268,7 +1291,18 @@ class BatchElements(PTransform):
12681291
downstream operations (mostly for testing)
12691292
record_metrics: (optional) whether or not to record beam metrics on
12701293
distributions of the batch size. Defaults to True.
1294+
length_fn: (optional) a callable mapping an element to its length (int).
1295+
When set together with max_batch_duration_secs, enables length-aware
1296+
bucketed keying on the stateful path so that elements of similar length
1297+
are routed to the same batch, reducing padding waste.
1298+
bucket_boundaries: (optional) a sorted list of positive boundary values
1299+
for length bucketing. Elements with length < boundaries[i] go to
1300+
bucket i; overflow goes to bucket len(boundaries). Defaults to
1301+
[16, 32, 64, 128, 256, 512] when length_fn is set. Requires
1302+
length_fn.
12711303
"""
1304+
_DEFAULT_BUCKET_BOUNDARIES = [16, 32, 64, 128, 256, 512]
1305+
12721306
def __init__(
12731307
self,
12741308
min_batch_size=1,
@@ -1281,7 +1315,17 @@ def __init__(
12811315
element_size_fn=lambda x: 1,
12821316
variance=0.25,
12831317
clock=time.time,
1284-
record_metrics=True):
1318+
record_metrics=True,
1319+
length_fn=None,
1320+
bucket_boundaries=None):
1321+
if bucket_boundaries is not None and length_fn is None:
1322+
raise ValueError('bucket_boundaries requires length_fn to be set.')
1323+
if bucket_boundaries is not None:
1324+
if (not bucket_boundaries or any(b <= 0 for b in bucket_boundaries) or
1325+
bucket_boundaries != sorted(bucket_boundaries)):
1326+
raise ValueError(
1327+
'bucket_boundaries must be a non-empty sorted list of '
1328+
'positive values.')
12851329
self._batch_size_estimator = _BatchSizeEstimator(
12861330
min_batch_size=min_batch_size,
12871331
max_batch_size=max_batch_size,
@@ -1295,13 +1339,23 @@ def __init__(
12951339
self._element_size_fn = element_size_fn
12961340
self._max_batch_dur = max_batch_duration_secs
12971341
self._clock = clock
1342+
self._length_fn = length_fn
1343+
if length_fn is not None and bucket_boundaries is None:
1344+
self._bucket_boundaries = self._DEFAULT_BUCKET_BOUNDARIES
1345+
else:
1346+
self._bucket_boundaries = bucket_boundaries
12981347

12991348
def expand(self, pcoll):
13001349
if getattr(pcoll.pipeline.runner, 'is_streaming', False):
13011350
raise NotImplementedError("Requires stateful processing (BEAM-2687)")
13021351
elif self._max_batch_dur is not None:
13031352
coder = coders.registry.get_coder(pcoll)
1304-
return pcoll | ParDo(WithSharedKey()) | ParDo(
1353+
if self._length_fn is not None:
1354+
keying_dofn = WithLengthBucketKey(
1355+
self._length_fn, self._bucket_boundaries)
1356+
else:
1357+
keying_dofn = WithSharedKey()
1358+
return pcoll | ParDo(keying_dofn) | ParDo(
13051359
_pardo_stateful_batch_elements(
13061360
coder,
13071361
self._batch_size_estimator,

0 commit comments

Comments (0)