feat(ml): add stateless bundle-local size-aware batching and benchmark #37532
Conversation
Activity
Checks are failing. Will not request review until checks are succeeding. If you'd like to override that behavior, comment the override command.
Force-pushed 907ddfd to 501bf5c.
Assigning reviewers: R: @jrmccluskey for label python. Note: If you would like to opt out of this review, comment the opt-out command.

The PR bot will only process comments in the main thread (not review comments).
Force-pushed 4948b6e to 142962d.
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files:
@@ Coverage Diff @@
## master #37532 +/- ##
============================================
+ Coverage 56.97% 57.02% +0.04%
Complexity 3404 3404
============================================
Files 1177 1175 -2
Lines 187195 187192 -3
Branches 3581 3581
============================================
+ Hits 106661 106749 +88
+ Misses 77142 77051 -91
Partials 3392 3392
Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
Force-pushed 142962d to 541fd95.
- Exclude *_benchmark.py from codecov (standalone scripts, not production code)
- Remove redundant validation from internal DoFn classes (already validated by PTransform)
- Add direct in-process unit tests for DoFn internals to capture coverage (FnApiRunner runs DoFns in a separate process, invisible to coverage tools)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed ff83d22 to c813692.
Force-pushed c813692 to 9987114.
Reminder, please take a look at this PR: @jrmccluskey

Assigning new set of reviewers because the PR has gone too long without review. R: @shunping for label python. If you would like to opt out of this review, comment the opt-out command.
Reminder, please take a look at this PR: @shunping

Assigning new set of reviewers because the PR has gone too long without review. R: @claudevdm for label python. If you would like to opt out of this review, comment the opt-out command.
Reminder, please take a look at this PR: @claudevdm

Assigning new set of reviewers because the PR has gone too long without review. R: @tvalentyn for label python. If you would like to opt out of this review, comment the opt-out command.
Reminder, please take a look at this PR: @tvalentyn
…sses (Part 3/3)

Completes the smart bucketing feature (apache#37531) by exposing batch_length_fn and batch_bucket_boundaries parameters across all concrete ModelHandler implementations. This allows users to enable length-aware batching on any supported inference backend (PyTorch, TensorFlow, HuggingFace, sklearn, ONNX, XGBoost, TensorRT, vLLM, Vertex AI, Gemini) by simply passing these parameters to the handler constructor.

- Adds batch_length_fn / batch_bucket_boundaries to 16 handler classes
- Adds end-to-end test verifying short/long elements are bucketed separately through RunInference with FnApiRunner
- Handlers using ModelHandler pass through to super().__init__()
- Handlers using RemoteModelHandler (Gemini, Vertex AI) wire params directly into _batching_kwargs

Part 1: apache#37532 (Stateless SortAndBatchElements)
Part 2: apache#37565 (Stateful length-aware keying in BatchElements)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
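For illustration, a sketch of what enabling this could look like from the user side; the handler choice and keyword arguments below follow the Part 3 commit message and are assumptions, since those parameters are not part of this PR:

```python
# Hypothetical sketch; batch_length_fn / batch_bucket_boundaries are the Part 3
# additions described in the commit message above and are NOT part of this PR.
import torch
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor


class TinyModel(torch.nn.Module):  # stand-in model for illustration
  def forward(self, x):
    return x


handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/tiny_model.pt',  # placeholder path
    model_class=TinyModel,
    model_params={},
    batch_length_fn=lambda t: t.shape[0],            # per-element length
    batch_bucket_boundaries=[32, 64, 128])           # short vs long buckets
```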
Assigning new set of reviewers because the PR has gone too long without review. R: @damccorm for label python. If you would like to opt out of this review, comment the opt-out command.
damccorm left a comment:
I took a quick look, but didn't do a full review yet - @jrmccluskey could you please take this one?
    by size, then splits batches using max_batch_weight so that each batch
    has a bounded total weight. The improvement comes from *changing batch
    boundaries* (weight-based splitting), NOT from sorting alone -- sorting
    within fixed boundaries yields 0% gain (verified by strict-control).
The improvement comes from changing batch
boundaries (weight-based splitting), NOT from sorting alone
This is surprising to me - doesn't sorting the batch functionally change the batch boundaries as well?
The improvement comes from changing batch boundaries (weight-based splitting), NOT from sorting alone

This is surprising to me - doesn't sorting the batch functionally change the batch boundaries as well?
It does. The comment is misleading. What I meant is that sorting alone (with fixed count-based boundaries) yields ~0% gain, which is what the strict-control ablation verified. The actual improvement comes from sorting combined with weight-based splitting, where similar-sized elements cluster together and the weight constraint produces tighter batches. I'll reframe the comment to make that clearer.
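To make the distinction concrete, here is a small self-contained sketch (illustrative only, not the transform's actual code): reordering inside fixed batches leaves padding unchanged, while weight-based splitting after a size-aware sort changes the boundaries and reduces it.

```python
# Illustrative sketch only; this is not the SortAndBatchElements implementation.
# Padding model: each batch is padded to the length of its longest element.
def padded_cost(batches):
  return sum(len(batch) * max(batch) for batch in batches)


lengths = [5, 100, 7, 6, 90, 8]  # hypothetical element lengths

# Fixed count-based boundaries (batches of 3), as a count-based batcher forms them.
fixed = [lengths[0:3], lengths[3:6]]

# Strict control: reordering elements *within* each fixed batch leaves the
# per-batch max unchanged, so padding is unchanged (the "0% gain" case).
intra_sorted = [sorted(batch) for batch in fixed]

# Size-aware sort plus weight-based splitting (max total weight 40 per batch):
# membership changes, so short elements no longer share a batch with outliers.
weight_split, current = [], []
for n in sorted(lengths):
  if current and sum(current) + n > 40:
    weight_split.append(current)
    current = []
  current.append(n)
weight_split.append(current)

print(padded_cost(fixed))         # 570
print(padded_cost(intra_sorted))  # 570 (identical)
print(padded_cost(weight_split))  # 222 -- shorts no longer padded to outlier length
```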
    try:
      return len(element)
    except TypeError:
      return 1
We should probably warn in this case since this is almost certainly not what the user desired. To avoid warning on every element, we can make this function a member of the _SortAndBatchElementsDoFn and use a member variable to track if we've already warned. That way we only warn once per DoFn instance
We should probably warn in this case since this is almost certainly not what the user desired. To avoid warning on every element, we can make this function a member of the _SortAndBatchElementsDoFn and use a member variable to track if we've already warned. That way we only warn once per DoFn instance.
Makes sense, will do. I'll move this into _SortAndBatchElementsDoFn as a method and track with an instance variable so we only warn once per DoFn instance.
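A minimal sketch of that warn-once pattern (simplified; the method and attribute names below are assumptions, not the final code):

```python
# Simplified sketch of the warn-once fallback; names here are assumptions.
import logging

import apache_beam as beam


class _SortAndBatchElementsDoFn(beam.DoFn):
  def __init__(self, element_size_fn=None):
    self._element_size_fn = element_size_fn
    self._warned_default_size = False

  def _element_size(self, element):
    if self._element_size_fn is not None:
      return self._element_size_fn(element)
    try:
      return len(element)
    except TypeError:
      if not self._warned_default_size:
        logging.warning(
            'Element of type %s does not support len(); defaulting its size '
            'to 1. Pass element_size_fn for meaningful size-aware batching.',
            type(element).__name__)
        self._warned_default_size = True
      return 1
```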
Reframe benchmark docstring to clarify that sorting combined with weight-based splitting drives the improvement. Move default element size fallback into DoFn instances with a one-time warning when len() is unsupported, so users know to provide a custom element_size_fn.
Replace deprecated jupyter labextension install/link workflow with pip-installable prebuilt extension for JupyterLab 4+ compatibility.

- Add install.json for prebuilt extension discovery metadata
- Add style/index.js CSS entry point and styleModule field in package.json
- Include js in package.json files glob so style/index.js is published
- Add Extensions and Extensions :: Prebuilt classifiers to pyproject.toml
- Add missing src/yaml/* to tsconfig.json includes
- Remove deprecated labextension install/link/build instructions from READMEs
- Replace ipywidgets labextension install with pip install in Interactive README
Reminder, please take a look at this PR: @damccorm

R: @jrmccluskey

Stopping reviewer notifications for this pull request: review requested by someone other than the bot, ceding control. If you'd like to restart, comment the restart command.
jrmccluskey left a comment:
Some questions and notes on the benchmark and the testing side, but I like the core implementation
    from typing import Any
    from typing import Callable
    from typing import Dict
    from typing import List
    from typing import Optional
    from typing import Sequence
    from typing import Tuple
Swap to the collections.abc equivalent for Callable and Sequence, use the native built-ins for dict, list, and tuple
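For reference, a quick sketch of the requested swap (illustrative, not the actual diff):

```python
# Before: typing-module generics
# from typing import Callable, Dict, List, Sequence, Tuple

# After: collections.abc for the ABCs, built-in generics for containers (Python 3.9+)
from collections.abc import Callable, Sequence


def element_sizes(
    size_fn: Callable[[str], int],
    elements: Sequence[str]) -> list[tuple[str, int]]:
  return [(e, size_fn(e)) for e in elements]


print(element_sizes(len, ['a', 'bb', 'ccc']))  # [('a', 1), ('bb', 2), ('ccc', 3)]
```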
    def check_sorted(batch):
      lengths = [len(s) for s in batch]
      assert lengths == sorted(lengths), (
I would recommend listing the expected data here just for clarity, even if just in a comment block
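For example, one way to spell that out (the element values here are made up for illustration):

```python
# Hypothetical illustration only: name the expected, sorted batch contents explicitly.
inputs = ['ccc', 'a', 'bb']          # deliberately unsorted by length
expected_batch = ['a', 'bb', 'ccc']  # what check_sorted should observe
assert [len(s) for s in expected_batch] == sorted(len(s) for s in inputs)
```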
    beam.Map(lambda batch, ts=beam.DoFn.TimestampParam: (len(batch), ts)))
    assert_that(res, equal_to([(3, GlobalWindow().max_timestamp())]))

    def test_padding_efficiency_improvement(self):
I'm not sold on this test in particular, since I think there's a bit of an incongruity as far as the batching approaches being used. Using BatchElements with a default element weighting of 1 compared to weighting based on length does create a favorable outcome for this test, but putting the approaches on the same weighting function actually produces a better padding overhead in the case of traditional BatchElements (albeit creating five batches of 1, which by the definition here has a padding overhead of 0).
I don't disagree that sorting and batching has benefits, but I don't think we necessarily need a unit test to prove it.
      size/weight.
      """

    _MAX_LIVE_WINDOWS = 10
Is 10 an arbitrary number, or were there any experiences while testing that led you to this number? Would it be worth making this configurable as a kwarg in the DoFn?
    # limitations under the License.
    #

    """Benchmark: BatchElements vs SortAndBatchElements (weight-based splitting).
I generally like having something like this for proof-of-concept work and helping users pick the best options for their data. I don't love that the benchmark here doesn't actually use Beam or the BatchElements / SortAndBatchElements implementations directly, but considering that those implementations are generally pretty static and don't change often I'm okay including this code in this way.
- Reuse _WindowAwareBatchingDoFn._MAX_LIVE_WINDOWS instead of keeping a separate hard-coded limit in SortAndBatchElements.
- Drop the padding-efficiency unit test that compared incongruent batching strategies and keep the transform tests focused on deterministic behavior.
- Align benchmark typing with modern Python style by using collections.abc imports and native built-in generics.
- Make the sorted-order test clearer by naming the expected batch contents explicitly.
@jrmccluskey Thank you for your reviews. I addressed all four points.

- I removed the separate hard-coded 10 and now reuse _WindowAwareBatchingDoFn._MAX_LIVE_WINDOWS so the two window-aware batching implementations stay aligned. I did not make it a public kwarg, since this is an internal buffering heuristic rather than part of the transform's user-facing contract.
- Agreed on the padding-efficiency unit test. I removed it. It was comparing incongruent batching setups and trying to encode a performance claim rather than a deterministic semantic guarantee. The remaining unit tests now focus on deterministic behavior only.
- I updated the benchmark typing to use collections.abc for Callable / Sequence and native built-in generics for dict / list / tuple.
- I made the expected sorted batch contents explicit in the test for readability.
jrmccluskey left a comment:
Remaining failures are unrelated, LGTM. Thanks!
Updates #37531
Summary
This PR adds an opt-in, stateless, bundle-local size-aware batching path for variable-length inference workloads in RunInference.

It introduces SortAndBatchElements in apache_beam/transforms/util.py, which:

- buffers elements within a bundle (StartBundle → FinishBundle)
- sorts them by size (default len(x), overridable via element_size_fn)
- splits batches under count and weight limits (max_batch_size, max_batch_weight)

Default behavior remains unchanged unless this path is enabled.
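A minimal usage sketch of the opt-in path (parameter names follow the description above; the exact import location and signature may differ):

```python
# Illustrative sketch of enabling the opt-in path; exact signature may differ.
import apache_beam as beam
from apache_beam.transforms import util

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['x' * n for n in (3, 120, 5, 4, 110, 6)])  # variable-length strings
      | util.SortAndBatchElements(
          max_batch_size=32,       # cap on elements per batch
          max_batch_weight=200,    # cap on total element size per batch
          element_size_fn=len)     # defaults to len(x) when omitted
      | beam.Map(lambda batch: [len(e) for e in batch])
      | beam.Map(print))
```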
Motivation
BatchElements is count-based. For heavy-tail length distributions, long outliers can inflate padding cost for many short elements in the same batch, increasing tail latency and reducing effective throughput.

This PR provides a stateless (bundle-local, no shuffle) way to improve batch composition under variable-length inputs.
Mechanism clarification
A strict-control ablation is included to isolate effects:
- Sorting alone, with fixed count-based boundaries: ~0% gain
- Sorting combined with weight-based splitting via max_batch_weight: significant gain

In this workload, gains are primarily consistent with boundary changes under weight constraints after size-aware ordering, rather than intra-batch reordering alone.
Benchmark methodology
Script: apache_beam/transforms/sort_and_batch_benchmark.py

Pareto (heavy-tail) results
Configuration: max_batch_size=32, max_batch_weight=2000

Baseline → Stateless:
Here are the concrete improvements. Lower is better for padding_ratio and E2E latency; higher is better for throughput.
pareto
padding_ratio: 15.0622 -> 1.1743, a 92.20% reduction
E2E median: 12220.214 ms -> 948.606 ms, a 92.24% reduction
E2E p95: 12346.203 ms -> 1092.947 ms, a 91.15% reduction
throughput median: 1548.336 -> 19946.187 tok/s, a 1188.23% increase
pipeline runtime median: 328.974 ms -> 347.343 ms, Beam overhead itself was 5.58% slower
lognormal
padding_ratio: 4.0544 -> 1.0094, a 75.10% reduction
E2E median: 55431.920 ms -> 8632.172 ms, a 84.43% reduction
E2E p95: 55548.536 ms -> 8737.652 ms, a 84.27% reduction
throughput median: 4946.969 -> 31767.210 tok/s, a 542.16% increase
pipeline runtime median: 363.500 ms -> 377.133 ms, Beam overhead itself was 3.75% slower
bimodal
padding_ratio: 4.0008 -> 1.0079, a 74.81% reduction
E2E median: 56058.090 ms -> 10859.886 ms, a 80.63% reduction
E2E p95: 56123.844 ms -> 10915.934 ms, a 80.55% reduction
throughput median: 5138.759 -> 26525.980 tok/s, a 416.19% increase
pipeline runtime median: 383.392 ms -> 417.467 ms, Beam overhead itself was 8.89% slower
low_variance
padding_ratio: 1.2126 -> 1.0019, a 17.37% reduction
E2E median: 15396.794 ms -> 11688.260 ms, a 24.09% reduction
E2E p95: 15466.797 ms -> 11760.909 ms, a 23.96% reduction
throughput median: 25833.235 -> 34029.789 tok/s, a 31.73% increase
pipeline runtime median: 406.039 ms -> 395.115 ms, Beam overhead itself was 2.69% faster
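For context on the padding_ratio numbers above, here is one plausible way such a metric can be computed; this definition is an assumption for illustration, not necessarily the benchmark's exact formula:

```python
# Assumed definition: total padded work divided by total useful work; 1.0 = no padding.
def padding_ratio(batches):
  padded = sum(len(batch) * max(batch) for batch in batches)  # pad each batch to its max
  useful = sum(sum(batch) for batch in batches)
  return padded / useful


# One long outlier batched with short elements vs. split into its own batch:
print(padding_ratio([[5, 6, 7, 100]]))    # ~3.39
print(padding_ratio([[5, 6, 7], [100]]))  # ~1.03
```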
Scope
Included in this PR:
- The stateless, bundle-local transform (SortAndBatchElements), its unit tests, and the benchmark script

Not included in this PR:

- Part 2: stateful length-aware keying in BatchElements (apache#37565)
- Part 3: exposing batch_length_fn / batch_bucket_boundaries on concrete ModelHandler implementations
Files changed
- apache_beam/transforms/util.py
- apache_beam/transforms/util_test.py
- apache_beam/transforms/sort_and_batch_benchmark.py