Skip to content

Commit 35a622e

Browse files
committed
Refine length bucketing docs and fix boundary inclusivity
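Expands the parameter documentation for clarity and replaces bisect_left with bisect_right so that bucket boundaries are lower-inclusive (a length equal to a boundary falls into the bucket that starts at that boundary). Renames the ModelHandler arguments length_fn and bucket_boundaries to batch_length_fn and batch_bucket_boundaries, and updates the base_test.py and util_test.py assertions accordingly.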
1 parent 53454f3 commit 35a622e

4 files changed

Lines changed: 46 additions & 34 deletions

File tree

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ def __init__(
178178
max_batch_duration_secs: Optional[int] = None,
179179
max_batch_weight: Optional[int] = None,
180180
element_size_fn: Optional[Callable[[Any], int]] = None,
181-
length_fn: Optional[Callable[[Any], int]] = None,
182-
bucket_boundaries: Optional[list[int]] = None,
181+
batch_length_fn: Optional[Callable[[Any], int]] = None,
182+
batch_bucket_boundaries: Optional[list[int]] = None,
183183
large_model: bool = False,
184184
model_copies: Optional[int] = None,
185185
**kwargs):
@@ -192,11 +192,17 @@ def __init__(
192192
before emitting; used in streaming contexts.
193193
max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
194194
element_size_fn: a function that returns the size (weight) of an element.
195-
length_fn: a callable mapping an element to its length. When set with
196-
max_batch_duration_secs, enables length-aware bucketed keying so
197-
elements of similar length are batched together.
198-
bucket_boundaries: sorted list of positive boundary values for length
199-
bucketing. Requires length_fn.
195+
batch_length_fn: a callable mapping an element to its length (int). When
196+
set together with max_batch_duration_secs, enables length-aware bucketed
197+
keying so that elements of similar length are batched together, reducing
198+
padding waste for variable-length inputs. Bucket assignment uses
199+
bisect_right so boundaries are lower-inclusive: e.g., for boundaries
200+
[10, 50], buckets are (-inf, 10), [10, 50), [50, inf).
201+
batch_bucket_boundaries: a sorted list of positive boundary values for
202+
length bucketing. Boundaries are lower-inclusive (bisect_right
203+
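semantics): for 0 < i < len(boundaries), bucket i covers lengths in [boundaries[i-1], boundaries[i]); bucket 0 covers lengths below boundaries[0] and the last bucket covers lengths >= boundaries[-1].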
204+
Requires batch_length_fn. Defaults to [16, 32, 64, 128, 256, 512] when
205+
batch_length_fn is set.
200206
large_model: set to true if your model is large enough to run into
201207
memory pressure if you load multiple copies.
202208
model_copies: The exact number of models that you would like loaded
@@ -216,10 +222,10 @@ def __init__(
216222
self._batching_kwargs['max_batch_weight'] = max_batch_weight
217223
if element_size_fn is not None:
218224
self._batching_kwargs['element_size_fn'] = element_size_fn
219-
if length_fn is not None:
220-
self._batching_kwargs['length_fn'] = length_fn
221-
if bucket_boundaries is not None:
222-
self._batching_kwargs['bucket_boundaries'] = bucket_boundaries
225+
if batch_length_fn is not None:
226+
self._batching_kwargs['length_fn'] = batch_length_fn
227+
if batch_bucket_boundaries is not None:
228+
self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries
223229
self._large_model = large_model
224230
self._model_copies = model_copies
225231
self._share_across_processes = large_model or (model_copies is not None)

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,38 +2279,40 @@ def test_max_batch_duration_secs_only(self):
22792279

22802280
self.assertEqual(kwargs, {'max_batch_duration_secs': 60})
22812281

2282-
def test_length_fn_and_bucket_boundaries(self):
2283-
"""length_fn and bucket_boundaries are passed through to kwargs."""
2282+
def test_batch_length_fn_and_batch_bucket_boundaries(self):
2283+
"""batch_length_fn and batch_bucket_boundaries passed through to kwargs."""
22842284
handler = FakeModelHandlerForBatching(
2285-
length_fn=len, bucket_boundaries=[16, 32, 64])
2285+
batch_length_fn=len, batch_bucket_boundaries=[16, 32, 64])
22862286
kwargs = handler.batch_elements_kwargs()
22872287

22882288
self.assertIs(kwargs['length_fn'], len)
22892289
self.assertEqual(kwargs['bucket_boundaries'], [16, 32, 64])
22902290

2291-
def test_length_fn_only(self):
2292-
"""length_fn alone is passed through without bucket_boundaries."""
2293-
handler = FakeModelHandlerForBatching(length_fn=len)
2291+
def test_batch_length_fn_only(self):
2292+
"""batch_length_fn alone is passed through without bucket_boundaries."""
2293+
handler = FakeModelHandlerForBatching(batch_length_fn=len)
22942294
kwargs = handler.batch_elements_kwargs()
22952295

22962296
self.assertIs(kwargs['length_fn'], len)
22972297
self.assertNotIn('bucket_boundaries', kwargs)
22982298

2299-
def test_bucket_boundaries_without_length_fn(self):
2300-
"""Passing bucket_boundaries without length_fn should fail in BatchElements.
2299+
def test_batch_bucket_boundaries_without_batch_length_fn(self):
2300+
"""Passing batch_bucket_boundaries without batch_length_fn should fail in
2301+
BatchElements.
23012302
23022303
Note: ModelHandler.__init__ doesn't validate this; the error is raised
23032304
by BatchElements when batch_elements_kwargs are used."""
2304-
handler = FakeModelHandlerForBatching(bucket_boundaries=[10, 20])
2305+
handler = FakeModelHandlerForBatching(batch_bucket_boundaries=[10, 20])
23052306
kwargs = handler.batch_elements_kwargs()
23062307
# The kwargs are stored, but BatchElements will reject them
23072308
self.assertEqual(kwargs['bucket_boundaries'], [10, 20])
23082309
self.assertNotIn('length_fn', kwargs)
23092310

23102311
def test_batching_kwargs_none_values_omitted(self):
2311-
"""None values for length_fn and bucket_boundaries are not in kwargs."""
2312+
"""None values for batch_length_fn and batch_bucket_boundaries are not in
2313+
kwargs."""
23122314
handler = FakeModelHandlerForBatching(
2313-
min_batch_size=5, length_fn=None, bucket_boundaries=None)
2315+
min_batch_size=5, batch_length_fn=None, batch_bucket_boundaries=None)
23142316
kwargs = handler.batch_elements_kwargs()
23152317
self.assertNotIn('length_fn', kwargs)
23162318
self.assertNotIn('bucket_boundaries', kwargs)

sdks/python/apache_beam/transforms/util.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,7 +1223,9 @@ def setup(self):
12231223
load_shared_key, "WithLengthBucketKey").key
12241224

12251225
def _get_bucket(self, length):
1226-
return bisect.bisect_left(self._bucket_boundaries, length)
1226+
# bisect_right: boundaries are lower-inclusive.
1227+
# e.g., for boundaries [10, 50], buckets are (-inf, 10), [10, 50), [50, inf)
1228+
return bisect.bisect_right(self._bucket_boundaries, length)
12271229

12281230
def process(self, element):
12291231
length = self._length_fn(element)
@@ -1291,14 +1293,14 @@ class BatchElements(PTransform):
12911293
record_metrics: (optional) whether or not to record beam metrics on
12921294
distributions of the batch size. Defaults to True.
12931295
length_fn: (optional) a callable mapping an element to its length (int).
1294-
When set together with max_batch_duration_secs, enables length-aware
1295-
bucketed keying on the stateful path so that elements of similar length
1296-
are routed to the same batch, reducing padding waste.
1296+
When set together with max_batch_duration_secs, enables length-aware bucketed
1297+
keying on the stateful path so that elements of similar length are
1298+
routed to the same batch, reducing padding waste.
12971299
bucket_boundaries: (optional) a sorted list of positive boundary values
1298-
for length bucketing. Elements with length < boundaries[i] go to
1299-
bucket i; overflow goes to bucket len(boundaries). Defaults to
1300-
[16, 32, 64, 128, 256, 512] when length_fn is set. Requires
1301-
length_fn.
1300+
for length bucketing. Boundaries are lower-inclusive (bisect_right
1301+
semantics): e.g., for boundaries [10, 50], buckets are (-inf, 10),
1302+
[10, 50), [50, inf). Defaults to [16, 32, 64, 128, 256, 512] when
1303+
length_fn is set. Requires length_fn.
13021304
"""
13031305
_DEFAULT_BUCKET_BOUNDARIES = [16, 32, 64, 128, 256, 512]
13041306

sdks/python/apache_beam/transforms/util_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,13 +1030,15 @@ def test_length_bucket_assignment(self):
10301030
"""WithLengthBucketKey assigns correct bucket indices."""
10311031
boundaries = [10, 50, 100]
10321032
dofn = util.WithLengthBucketKey(length_fn=len, bucket_boundaries=boundaries)
1033-
# bisect_left: length < 10 -> bucket 0, 10 <= length < 50 -> bucket 1, etc.
1033+
# bisect_right: boundaries are lower-inclusive.
1034+
# e.g., for boundaries [10, 50, 100], buckets are:
1035+
# (-inf, 10), [10, 50), [50, 100), [100, inf)
10341036
self.assertEqual(dofn._get_bucket(5), 0)
1035-
self.assertEqual(dofn._get_bucket(10), 0)
1037+
self.assertEqual(dofn._get_bucket(10), 1)
10361038
self.assertEqual(dofn._get_bucket(11), 1)
1037-
self.assertEqual(dofn._get_bucket(50), 1)
1039+
self.assertEqual(dofn._get_bucket(50), 2)
10381040
self.assertEqual(dofn._get_bucket(51), 2)
1039-
self.assertEqual(dofn._get_bucket(100), 2)
1041+
self.assertEqual(dofn._get_bucket(100), 3)
10401042
self.assertEqual(dofn._get_bucket(101), 3)
10411043
self.assertEqual(dofn._get_bucket(999), 3)
10421044

0 commit comments

Comments
 (0)