diff --git a/.github/codecov.yml b/.github/codecov.yml index 0936f392ccef..5d0eaccf22da 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -73,6 +73,7 @@ ignore: - "**/*_microbenchmark.py" - "sdks/go/pkg/beam/register/register.go" - "sdks/python/apache_beam/testing/benchmarks/nexmark/**" + - "**/*_benchmark.py" - "sdks/python/apache_beam/examples/**" # See https://docs.codecov.com/docs/flags for options. diff --git a/sdks/python/apache_beam/runners/interactive/README.md b/sdks/python/apache_beam/runners/interactive/README.md index ff6c57a94e61..f95b2765c3fa 100644 --- a/sdks/python/apache_beam/runners/interactive/README.md +++ b/sdks/python/apache_beam/runners/interactive/README.md @@ -244,23 +244,10 @@ a quick reference). For a more general and complete getting started guide, see jupyter kernelspec list ``` -* Extend JupyterLab through labextension. **Note**: labextension is different from nbextension - from pre-lab jupyter notebooks. - - All jupyter labextensions need nodejs - - ```bash - # Homebrew users do - brew install node - # Or Conda users do - conda install -c conda-forge nodejs - ``` - - Enable ipywidgets +* Install ipywidgets (includes the JupyterLab widget manager as a prebuilt extension): ```bash pip install ipywidgets - jupyter labextension install @jupyter-widgets/jupyterlab-manager ``` ### Start the notebook diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/README.md b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/README.md index 83fddf491f68..4c0baf3b2d53 100644 --- a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/README.md +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/README.md @@ -31,41 +31,22 @@ Includes two different side panels: ## Installation -There are two ways to install the extension: - -### 1. 
Via pip (recommended) - -The extension is now available as a Python package on PyPI. You can install it with: +This extension is distributed as a prebuilt Python package. Install it with pip: ```bash pip install apache-beam-jupyterlab-sidepanel ``` -After installation, rebuild JupyterLab to activate the extension: - -```bash -jupyter lab clean -jupyter lab build -``` - -Then restart JupyterLab. The side panels will be available automatically. +Then restart JupyterLab. The side panels will be available automatically — no +`jupyter lab build` step is needed. - -### 2. Via JupyterLab Extension Manager (legacy, will be deprecated soon) +You can verify the extension is installed: ```bash -jupyter labextension install apache-beam-jupyterlab-sidepanel +jupyter labextension list ``` -This installs the extension using JupyterLab's legacy extension system. - ---- - -## Notes - -- Pip installation is now the preferred method as it handles Python packaging and JupyterLab extension registration seamlessly. -- After any upgrade or reinstallation, always rebuild JupyterLab to ensure the extension is activated. -- For detailed usage and development, refer to the source code and issues on [GitHub](https://github.com/apache/beam). +The extension should appear under the **prebuilt extensions** section. --- @@ -90,15 +71,12 @@ The `jlpm` command is JupyterLab's pinned version of # Install dependencies jlpm -# Build Typescript source -jlpm build -# Link your development version of the extension with JupyterLab -jupyter labextension link . -# Rebuild Typescript source after making changes -jlpm build -# Rebuild JupyterLab after making any changes -jupyter lab build +# Install the extension in editable mode (runs an initial JS build) +pip install -e . + +# Verify installation +jupyter labextension list ``` You can watch the source directory and run JupyterLab in watch mode to watch for changes in the extension's source and automatically rebuild the extension and application. 
@@ -110,7 +88,7 @@ jlpm watch jupyter lab --watch ``` -Now every change will be built locally and bundled into JupyterLab. Be sure to refresh your browser page after saving file changes to reload the extension (note: you'll need to wait for webpack to finish, which can take 10s+ at times). +Now every change will be built locally and bundled into JupyterLab. Be sure to refresh your browser page after saving file changes to reload the extension (note: you'll need to wait for the build to finish, which can take 10s+ at times). ### Test @@ -214,9 +192,5 @@ $PREFIX/share/jupyter/labextensions/apache-beam-jupyterlab-sidepanel/ ### Uninstall ```bash -jupyter labextension uninstall apache-beam-jupyterlab-sidepanel -``` -or -```bash -pip uninstall apache-beam-jupyterlab-sidepanel +pip uninstall apache_beam_jupyterlab_sidepanel ``` diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/install.json b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/install.json new file mode 100644 index 000000000000..3ef6567c6a81 --- /dev/null +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/install.json @@ -0,0 +1,5 @@ +{ + "packageManager": "python", + "packageName": "apache_beam_jupyterlab_sidepanel", + "uninstallInstructions": "Use your Python package manager (pip, conda, etc.) 
to uninstall the package apache_beam_jupyterlab_sidepanel" +} diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/package.json b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/package.json index eef3fcaa80f4..6bca80350ff7 100644 --- a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/package.json +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/package.json @@ -15,7 +15,7 @@ "author": "apache-beam", "files": [ "lib/**/*.{d.ts,eot,gif,html,jpg,js,js.map,json,png,svg,woff2,ttf}", - "style/**/*.{css,eot,gif,html,jpg,json,png,svg,woff2,ttf}" + "style/**/*.{css,js,eot,gif,html,jpg,json,png,svg,woff2,ttf}" ], "main": "lib/index.js", "types": "lib/index.d.ts", @@ -100,6 +100,7 @@ "style/*.css", "style/index.js" ], + "styleModule": "style/index.js", "jupyterlab": { "extension": true, "outputDir": "apache_beam_jupyterlab_sidepanel/labextension" diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/pyproject.toml b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/pyproject.toml index 6831535a2c1e..a28fd40b2ca6 100644 --- a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/pyproject.toml +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/pyproject.toml @@ -33,6 +33,8 @@ classifiers = [ "Framework :: Jupyter", "Framework :: Jupyter :: JupyterLab", "Framework :: Jupyter :: JupyterLab :: 4", + "Framework :: Jupyter :: JupyterLab :: Extensions", + "Framework :: Jupyter :: JupyterLab :: Extensions :: Prebuilt", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", ] diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/style/index.js 
b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/style/index.js new file mode 100644 index 000000000000..b533d5a9c6d5 --- /dev/null +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/style/index.js @@ -0,0 +1,13 @@ +// Licensed under the Apache License, Version 2.0 (the 'License'); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an 'AS IS' BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +import './index.css'; diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/tsconfig.json b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/tsconfig.json index c684cabf44a3..058bf17e1861 100644 --- a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/tsconfig.json +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/tsconfig.json @@ -29,6 +29,7 @@ "src/common/*", "src/kernel/*", "src/inspector/*", + "src/yaml/*", "src/__tests__/**/*" ] } diff --git a/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py new file mode 100644 index 000000000000..695c4c2c995c --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py @@ -0,0 +1,655 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Benchmark: BatchElements vs SortAndBatchElements on real Beam pipelines. + +Compares two batching strategies for variable-length inference workloads by +running the actual Beam transforms under DirectRunner: + +- Baseline (BatchElements): fixed-count batching by setting + ``min_batch_size == max_batch_size``. +- Stateless (SortAndBatchElements): sorts elements by size within each runner + bundle, then splits batches using ``max_batch_weight``. + +The benchmark materializes per-batch summaries through a temporary Beam sink and +analyzes them after the pipeline completes. This keeps the benchmark on the +normal Beam execution path rather than relying on InteractiveRunner-specific +result materialization or local side effects. + +Bundle boundaries are runner-defined. As a result, these measurements are meant +to compare the actual DirectRunner behavior of the two transforms rather than a +synthetic, user-configurable bundle model. + +Padding ratio:: + + padding_ratio = sum(max_len_in_batch * batch_size) / sum(actual_lengths) + Lower is better. 1.0 = no padding waste. + +Methodology: + +- N=20 independent trials per condition (3 warmup trials excluded). +- Same input corpus (seed=42) for A/B comparison. +- DirectRunner with in-memory execution and one worker for reproducibility. 
+- Percentile method: linear interpolation between adjacent ranks + (equivalent to numpy.percentile with method='linear'). + For N=20 trials: P50 interpolates ranks 10-11 (0-indexed 9-10), + P95 interpolates ranks 19-20 (0-indexed 18-19), + P99 interpolates near rank 20 (0-indexed 18.81). +- Reports median [IQR] and P95 for each metric. +- Inference model: latency = batch_size * (max_seq_len / 50)^1.5 ms + (simulates downstream transformer-like scaling). + +Run:: + + python3 -m apache_beam.testing.benchmarks.sort_and_batch_benchmark +""" + +import glob +import json +import math +import os +import random +import statistics +import tempfile +import time +from collections.abc import Sequence +from typing import Any + +import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.transforms import util + +# --------------------------------------------------------------------------- +# Data generators +# --------------------------------------------------------------------------- + + +def generate_highly_skewed_data( + num_elements: int, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> list[str]: + """Pareto(alpha=1.2) -- most short, few very long.""" + random.seed(seed) + data = [] + for _ in range(num_elements): + length = int(random.paretovariate(1.2) * min_length) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +def generate_lognormal_data( + num_elements: int, + mean_length: int = 50, + std_factor: float = 0.8, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> list[str]: + """Log-normal -- moderate skew, typical NLP.""" + random.seed(seed) + mu = math.log(mean_length) + sigma = std_factor + data = [] + for _ in range(num_elements): + length = int(random.lognormvariate(mu, sigma)) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +def generate_bimodal_data( + num_elements: int, + 
mode1_mean: int = 20, + mode2_mean: int = 200, + mode1_ratio: float = 0.7, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> list[str]: + """Bimodal -- two distinct length groups.""" + random.seed(seed) + data = [] + for _ in range(num_elements): + if random.random() < mode1_ratio: + length = int(random.gauss(mode1_mean, mode1_mean * 0.3)) + else: + length = int(random.gauss(mode2_mean, mode2_mean * 0.3)) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +def generate_low_variance_data( + num_elements: int, + mean_length: int = 100, + cv: float = 0.1, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> list[str]: + """Low-variance control (CV=10%).""" + random.seed(seed) + std = mean_length * cv + data = [] + for _ in range(num_elements): + length = int(random.gauss(mean_length, std)) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +# --------------------------------------------------------------------------- +# Real Beam batching +# --------------------------------------------------------------------------- + + +def _direct_runner_options() -> PipelineOptions: + return PipelineOptions([ + '--runner=DirectRunner', + '--direct_running_mode=in_memory', + '--direct_num_workers=1', + ]) + + +def _batch_to_json(batch: list[str]) -> str: + lengths = [len(element) for element in batch] + return json.dumps({ + 'batch_size': len(batch), + 'actual_total_length': sum(lengths), + 'max_len': max(lengths) if lengths else 0, + }) + + +def _read_batch_summaries(output_prefix: str) -> list[dict[str, int]]: + summaries = [] + for path in sorted(glob.glob(f'{output_prefix}*')): + if path.endswith('.crc'): + continue + with open(path, encoding='utf-8') as handle: + for line in handle: + line = line.strip() + if line: + summaries.append(json.loads(line)) + return summaries + + +def _run_batching_pipeline( + strategy: str, data: list[str], max_batch_size: 
int, + max_batch_weight: int) -> tuple[list[dict[str, int]], float]: + """Runs one Beam pipeline and returns batch summaries plus runtime.""" + with tempfile.TemporaryDirectory(prefix='beam_batch_benchmark_') as temp_dir: + output_prefix = os.path.join(temp_dir, strategy) + pipeline = beam.Pipeline(options=_direct_runner_options()) + batched = pipeline | 'CreateInput' >> beam.Create(data, reshuffle=False) + + if strategy == 'baseline': + batched = batched | 'BatchElements' >> util.BatchElements( + min_batch_size=max_batch_size, max_batch_size=max_batch_size) + elif strategy == 'stateless': + batched = batched | 'SortAndBatchElements' >> util.SortAndBatchElements( + min_batch_size=1, + max_batch_size=max_batch_size, + max_batch_weight=max_batch_weight) + else: + raise ValueError(f'Unknown strategy: {strategy}') + + _ = ( + batched + | 'SerializeBatchSummary' >> beam.Map(_batch_to_json) + | 'WriteBatchSummary' >> beam.io.WriteToText(output_prefix)) + + start = time.perf_counter() + result = pipeline.run() + result.wait_until_finish() + runtime_ms = (time.perf_counter() - start) * 1000 + + return _read_batch_summaries(output_prefix), runtime_ms + + +# --------------------------------------------------------------------------- +# Simulated inference +# --------------------------------------------------------------------------- + + +def simulate_inference_latency( + batch_size: int, max_len: int, base_latency_ms: float = 1.0) -> float: + """Simulate downstream inference: O(batch_size * seq_len^1.5).""" + if not batch_size or not max_len: + return 0.0 + return base_latency_ms * batch_size * (max_len / 50)**1.5 + + +# --------------------------------------------------------------------------- +# Stats helpers +# --------------------------------------------------------------------------- + + +def percentile(data: Sequence[float], p: float) -> float: + """Percentile via linear interpolation between adjacent ranks. 
+ + Equivalent to numpy.percentile(data, p, method='linear'). + For N=20: P50 interpolates ranks 10-11, P95 ranks 19-20, + P99 near rank 20 (fractional index 18.81). + """ + if not data: + return 0.0 + s = sorted(data) + k = (len(s) - 1) * p / 100 + f = int(k) + c = min(f + 1, len(s) - 1) + return s[f] + (k - f) * (s[c] - s[f]) + + +def compute_padding_stats( + batch_summaries: list[dict[str, int]]) -> dict[str, Any]: + """Padding-efficiency statistics for materialized batch summaries.""" + total_actual = sum(s['actual_total_length'] for s in batch_summaries) + total_padded = sum(s['max_len'] * s['batch_size'] for s in batch_summaries) + batch_sizes = [s['batch_size'] for s in batch_summaries if s['batch_size']] + max_lengths = [s['max_len'] for s in batch_summaries if s['batch_size']] + + efficiency = total_actual / total_padded if total_padded else 0.0 + padding_ratio = total_padded / total_actual if total_actual else float('inf') + + return { + 'efficiency': efficiency, + 'padding_ratio': padding_ratio, + 'num_batches': len(batch_summaries), + 'avg_batch_size': statistics.mean(batch_sizes) if batch_sizes else 0, + 'total_actual_length': total_actual, + 'total_padded_length': total_padded, + 'padding_overhead': total_padded - total_actual, + 'batch_size_p50': percentile(batch_sizes, 50) if batch_sizes else 0, + 'batch_size_p95': percentile(batch_sizes, 95) if batch_sizes else 0, + 'batch_size_max': max(batch_sizes) if batch_sizes else 0, + 'max_len_p50': percentile(max_lengths, 50) if max_lengths else 0, + 'max_len_p95': percentile(max_lengths, 95) if max_lengths else 0, + } + + +# --------------------------------------------------------------------------- +# Invariant validation +# --------------------------------------------------------------------------- + + +def validate_invariants( + data: list[str], + baseline_summaries: list[dict[str, int]], + stateless_summaries: list[dict[str, int]]) -> dict[str, Any]: + """Validate element/token counts and batch-size 
equality.""" + n = len(data) + b_n = sum(s['batch_size'] for s in baseline_summaries) + s_n = sum(s['batch_size'] for s in stateless_summaries) + tok = sum(len(s) for s in data) + b_tok = sum(s['actual_total_length'] for s in baseline_summaries) + s_tok = sum(s['actual_total_length'] for s in stateless_summaries) + + return { + 'input_elements': n, + 'baseline_elements': b_n, + 'stateless_elements': s_n, + 'elements_match': n == b_n == s_n, + 'input_tokens': tok, + 'baseline_tokens': b_tok, + 'stateless_tokens': s_tok, + 'tokens_match': tok == b_tok == s_tok, + 'baseline_num_batches': len(baseline_summaries), + 'stateless_num_batches': len(stateless_summaries), + } + + +# --------------------------------------------------------------------------- +# Performance benchmark (N=20 trials) +# --------------------------------------------------------------------------- + + +def run_performance_benchmark( + data: list[str], + max_batch_size: int, + max_batch_weight: int, + num_trials: int = 20, + warmup_trials: int = 3 +) -> tuple[ + dict[str, Any], + dict[str, Any], + list[dict[str, int]], + list[dict[str, int]], +]: + """Run N=20 trials for baseline and stateless.""" + total_tokens = sum(len(s) for s in data) + + baseline_trials = [] + stateless_trials = [] + baseline_sample_summaries = [] + stateless_sample_summaries = [] + + for trial_idx in range(warmup_trials + num_trials): + is_warmup = trial_idx < warmup_trials + trial_results = {} + + if trial_idx % 2 == 0: + trial_order = ('baseline', 'stateless') + else: + trial_order = ('stateless', 'baseline') + + for strategy in trial_order: + summaries, runtime_ms = _run_batching_pipeline( + strategy, data, max_batch_size, max_batch_weight) + batch_latencies = [ + simulate_inference_latency(s['batch_size'], s['max_len']) + for s in summaries + ] + trial_results[strategy] = { + 'runtime_ms': runtime_ms, + 'inference_ms': sum(batch_latencies), + 'e2e_ms': runtime_ms + sum(batch_latencies), + 'batch_latencies': batch_latencies, 
+ 'num_batches': len(summaries), + 'summaries': summaries, + } + + if not is_warmup: + baseline_trials.append(trial_results['baseline']) + stateless_trials.append(trial_results['stateless']) + if not baseline_sample_summaries: + baseline_sample_summaries = trial_results['baseline']['summaries'] + if not stateless_sample_summaries: + stateless_sample_summaries = trial_results['stateless']['summaries'] + + def _stats(trials): + e2e = [t['e2e_ms'] for t in trials] + tput = [total_tokens / (t['e2e_ms'] / 1000) for t in trials] + runtime = [t['runtime_ms'] for t in trials] + all_lat = [l for t in trials for l in t['batch_latencies']] + return { + 'e2e_median': percentile(e2e, 50), + 'e2e_p25': percentile(e2e, 25), + 'e2e_p75': percentile(e2e, 75), + 'e2e_p95': percentile(e2e, 95), + 'tput_median': percentile(tput, 50), + 'tput_p25': percentile(tput, 25), + 'tput_p75': percentile(tput, 75), + 'tput_p95': percentile(tput, 95), + 'runtime_median': percentile(runtime, 50), + 'runtime_p25': percentile(runtime, 25), + 'runtime_p75': percentile(runtime, 75), + 'runtime_p95': percentile(runtime, 95), + 'batch_lat_p50': percentile(all_lat, 50), + 'batch_lat_p95': percentile(all_lat, 95), + 'batch_lat_p99': percentile(all_lat, 99), + 'inf_p95': percentile(all_lat, 95), + 'num_trials': len(trials), + 'num_batches': trials[0]['num_batches'] if trials else 0, + } + + return ( + _stats(baseline_trials), + _stats(stateless_trials), + baseline_sample_summaries, + stateless_sample_summaries, + ) + + +# --------------------------------------------------------------------------- +# Single benchmark run +# --------------------------------------------------------------------------- + + +def run_benchmark( + num_elements: int = 10000, + min_length: int = 1, + max_length: int = 500, + max_batch_size: int = 32, + max_batch_weight: int = 2000, + distribution: str = 'pareto', + seed: int = 42) -> dict[str, Any]: + """Run baseline vs stateless comparison.""" + generators = { + 'pareto': lambda: 
generate_highly_skewed_data( + num_elements, min_length, max_length, seed), + 'lognormal': lambda: generate_lognormal_data( + num_elements, 50, 0.8, min_length, max_length, seed), + 'bimodal': lambda: generate_bimodal_data( + num_elements, 20, 200, 0.7, min_length, max_length, seed), + 'low_variance': lambda: generate_low_variance_data( + num_elements, 100, 0.1, min_length, max_length, seed), + } + if distribution not in generators: + raise ValueError(f"Unknown distribution: {distribution}") + + data = generators[distribution]() + lengths = [len(s) for s in data] + + baseline_perf, stateless_perf, baseline_summaries, stateless_summaries = ( + run_performance_benchmark(data, max_batch_size, max_batch_weight)) + baseline_pad = compute_padding_stats(baseline_summaries) + stateless_pad = compute_padding_stats(stateless_summaries) + baseline_pad.update(baseline_perf) + stateless_pad.update(stateless_perf) + + validation = validate_invariants( + data, baseline_summaries, stateless_summaries) + + return { + 'config': { + 'num_elements': num_elements, + 'max_batch_size': max_batch_size, + 'max_batch_weight': max_batch_weight, + 'distribution': distribution, + 'runner': 'DirectRunner', + }, + 'data_stats': { + 'min': min(lengths), + 'max': max(lengths), + 'mean': statistics.mean(lengths), + 'median': statistics.median(lengths), + 'std': statistics.stdev(lengths), + }, + 'baseline': baseline_pad, + 'stateless': stateless_pad, + 'validation': validation, + } + + +# --------------------------------------------------------------------------- +# Printing +# --------------------------------------------------------------------------- + + +def _fmt_iqr(median, p25, p75, unit=''): + return f"{median:.1f} [{p25:.1f}-{p75:.1f}]{unit}" + + +def print_results(results: dict[str, Any]) -> None: + cfg = results['config'] + ds = results['data_stats'] + bl = results['baseline'] + st = results['stateless'] + val = results['validation'] + + print("=" * 80) + print( + f"Distribution: 
{cfg['distribution']} | " + f"N={cfg['num_elements']} | " + f"runner={cfg['runner']} | " + f"max_batch_size={cfg['max_batch_size']} | " + f"max_batch_weight={cfg['max_batch_weight']}") + print( + f"Input lengths: min={ds['min']} max={ds['max']} " + f"mean={ds['mean']:.1f} median={ds['median']:.0f} std={ds['std']:.1f}") + print("-" * 80) + + def _arm(label, s): + print(f"\n {label}:") + print(f" Num batches: {s['num_batches']}") + print(f" Padding ratio: {s['padding_ratio']:.2f}x") + print(" ") + print(" Throughput (Ktok/s):") + med = s['tput_median'] / 1000 + p25 = s['tput_p25'] / 1000 + p75 = s['tput_p75'] / 1000 + print(f" Median [IQR]: {med:.1f}" + f" [{p25:.1f}-{p75:.1f}]") + print(f" P95: {s['tput_p95']/1000:.1f}") + print(" ") + print(" E2E latency (ms):") + print( + f" Median [IQR]: {s['e2e_median']:.1f}" + f" [{s['e2e_p25']:.1f}-{s['e2e_p75']:.1f}]") + print(f" P95: {s['e2e_p95']:.1f}") + print(" ") + print(" Pipeline runtime (ms):") + print( + f" Median [IQR]:" + f" {s['runtime_median']:.2f}" + f" [{s['runtime_p25']:.2f}" + f"-{s['runtime_p75']:.2f}]") + print(f" P95: {s['runtime_p95']:.2f}") + print(" ") + print(" Batch latency (ms):") + print(f" P50: {s['batch_lat_p50']:.1f}") + print(f" P95: {s['batch_lat_p95']:.1f}") + print(f" P99: {s['batch_lat_p99']:.1f}") + + _arm("Baseline (BatchElements)", bl) + _arm("Stateless (SortAndBatchElements w/ weight-based splitting)", st) + + # Explicit arrows so direction is unambiguous. 
+ # down arrow = value decreased (good for latency/padding) + # up arrow = value increased (good for throughput) + def _delta_lower(base, new): + """For metrics where lower is better (latency, padding).""" + if base == 0: + return 'N/A' + pct = (base - new) / base * 100 + arrow = '\u2193' if pct > 0 else '\u2191' + return f"{arrow}{abs(pct):.1f}%" + + def _delta_higher(base, new): + """For metrics where higher is better (throughput).""" + if base == 0: + return 'N/A' + pct = (new - base) / base * 100 + arrow = '\u2191' if pct > 0 else '\u2193' + return f"{arrow}{abs(pct):.1f}%" + + print(f"\n {'_' * 76}") + print(" DELTA (Baseline -> Stateless):") + + def _line(label, bv, sv, delta_fn, fmt='.1f', unit=''): + d = delta_fn(bv, sv) + print(f" {label}: {bv:{fmt}}{unit}" + f" -> {sv:{fmt}}{unit} ({d})") + + bl_tmed = bl['tput_median'] / 1000 + st_tmed = st['tput_median'] / 1000 + bl_tp95 = bl['tput_p95'] / 1000 + st_tp95 = st['tput_p95'] / 1000 + + _line( + 'Padding ratio ', + bl['padding_ratio'], + st['padding_ratio'], + _delta_lower, + fmt='.2f', + unit='x') + _line('Throughput median', bl_tmed, st_tmed, _delta_higher, unit=' Ktok/s') + _line('Throughput p95 ', bl_tp95, st_tp95, _delta_higher, unit=' Ktok/s') + _line( + 'E2E latency med ', + bl['e2e_median'], + st['e2e_median'], + _delta_lower, + unit=' ms') + _line( + 'E2E latency p95 ', + bl['e2e_p95'], + st['e2e_p95'], + _delta_lower, + unit=' ms') + _line( + 'Pipeline runtime ', + bl['runtime_median'], + st['runtime_median'], + _delta_lower, + unit=' ms') + _line( + 'Batch lat p95 ', + bl['batch_lat_p95'], + st['batch_lat_p95'], + _delta_lower, + unit=' ms') + _line( + 'Batch lat p99 ', + bl['batch_lat_p99'], + st['batch_lat_p99'], + _delta_lower, + unit=' ms') + + # Invariants + e_ok = "Y" if val['elements_match'] else "X" + t_ok = "Y" if val['tokens_match'] else "X" + b_nb = val['baseline_num_batches'] + s_nb = val['stateless_num_batches'] + print( + f"\n Invariants: elements {e_ok} tokens {t_ok}" + f" (baseline 
{b_nb} -> stateless {s_nb}" + f" batches)") + print("=" * 80) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + print("=" * 80) + print("BASELINE (BatchElements) vs STATELESS (SortAndBatchElements)") + print("=" * 80) + print() + print("Experiment design:") + print(" A = Baseline : BatchElements with min=max=32") + print(" B = Stateless : SortAndBatchElements with max_batch_weight=2000") + print(" (sort within runner bundle, then split by weight)") + print() + print("Implementation notes:") + print(" - Runs beam.Create(...) pipelines on DirectRunner") + print(" - Materializes per-batch summaries through a temporary text sink") + print(" - Uses runner-defined bundle boundaries rather than a synthetic") + print(" bundle_size knob") + print() + print("Methodology:") + print(" - N=20 trials, 3 warmup excluded") + print(" - DirectRunner, in_memory mode, single worker") + print(" - Percentiles: linear interpolation (= numpy default)") + print(" - Same seed=42 for both arms") + print(" - Inference model: latency = batch_size * (max_seq_len/50)^1.5 ms") + print() + + dist = 'pareto' + print(f"\nRunning: {dist}...") + r = run_benchmark( + num_elements=10000, + max_batch_size=32, + max_batch_weight=2000, + distribution=dist, + seed=42) + print_results(r) + + +if __name__ == '__main__': + main() diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index fbaab6b4ebbb..0e2693be7fcc 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -90,6 +90,8 @@ if TYPE_CHECKING: from apache_beam.runners.pipeline_context import PipelineContext +_LOGGER = logging.getLogger(__name__) + __all__ = [ 'BatchElements', 'CoGroupByKey', @@ -104,6 +106,7 @@ 'RemoveDuplicates', 'Reshuffle', 'Secret', + 'SortAndBatchElements', 'ToString', 'Tee', 'Values', @@ -1319,6 
+1322,285 @@ def expand(self, pcoll): self._batch_size_estimator, self._element_size_fn)) +class _SortAndBatchElementsDoFn(DoFn): + """DoFn that buffers, sorts by element size, and batches elements. + + This DoFn is used internally by ``SortAndBatchElements`` for + PCollections with the default (global) window. It accumulates all + elements in the current bundle, sorts them by size in ascending order, + and emits optimally-sized batches on ``finish_bundle``. + + Args: + min_batch_size: The minimum number of elements per batch. Must be >= 1. + max_batch_size: The maximum number of elements per batch. + Must be >= ``min_batch_size``. + max_batch_weight: The maximum total weight of elements in a batch, + where weight is computed by ``element_size_fn``. Must be >= 1. + element_size_fn: An optional callable mapping an element to its integer + size/weight. + """ + def __init__( + self, + min_batch_size: int, + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Optional[Callable[[Any], int]]): + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size + self._max_batch_weight = max_batch_weight + self._element_size_fn = element_size_fn or self._default_element_size + self._has_warned_type_error = False + self._buffer = [] + + def _default_element_size(self, element): + try: + return len(element) + except TypeError: + if not self._has_warned_type_error: + _LOGGER.warning( + 'Element of type %s does not support len(). Falling back to ' + 'size 1. 
Consider providing a custom element_size_fn to '
+          'SortAndBatchElements for meaningful size-based batching.',
+          type(element).__name__)
+      self._has_warned_type_error = True
+    return 1
+
+  def start_bundle(self):
+    self._buffer = []
+
+  def process(self, element):
+    self._buffer.append(element)
+
+  def finish_bundle(self):
+    if not self._buffer:
+      return
+
+    # Sort elements by size (ascending) for optimal batching
+    # Elements of similar sizes will be grouped together
+    sorted_elements = sorted(self._buffer, key=self._element_size_fn)
+
+    batch = []
+    batch_weight = 0
+
+    for element in sorted_elements:
+      element_size = self._element_size_fn(element)
+
+      # Split when a limit would be exceeded; a total of exactly
+      would_exceed_count = len(batch) >= self._max_batch_size
+      would_exceed_weight = (
+          batch_weight + element_size > self._max_batch_weight and batch)
+
+      if would_exceed_count or would_exceed_weight:
+        # Emit current batch
+        yield window.GlobalWindows.windowed_value_at_end_of_window(batch)
+        batch = []
+        batch_weight = 0
+
+      batch.append(element)
+      batch_weight += element_size
+
+    # Emit remaining elements
+    if batch:
+      yield window.GlobalWindows.windowed_value_at_end_of_window(batch)
+
+    self._buffer = None
+
+
+class _WindowAwareSortAndBatchElementsDoFn(DoFn):
+  """DoFn that buffers, sorts by element size, and batches elements per window.
+
+  This DoFn is used internally by ``SortAndBatchElements`` for
+  PCollections with non-default (e.g. fixed, sliding, or session) windows.
+  Elements are buffered per window and each window is flushed independently.
+  To prevent a single bundle from retaining too many per-window buffers at
+  once, when the number of live windows exceeds ``_MAX_LIVE_WINDOWS`` the
+  largest window buffer is flushed early. This DoFn reuses
+  ``_WindowAwareBatchingDoFn._MAX_LIVE_WINDOWS`` so it follows the same
+  existing window-aware batching behavior already used in this module.
+ + Args: + min_batch_size: The minimum number of elements per batch. Must be >= 1. + max_batch_size: The maximum number of elements per batch. + Must be >= ``min_batch_size``. + max_batch_weight: The maximum total weight of elements in a batch, + where weight is computed by ``element_size_fn``. Must be >= 1. + element_size_fn: An optional callable mapping an element to its integer + size/weight. + """ + + _MAX_LIVE_WINDOWS = _WindowAwareBatchingDoFn._MAX_LIVE_WINDOWS + + def __init__( + self, + min_batch_size: int, + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Optional[Callable[[Any], int]]): + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size + self._max_batch_weight = max_batch_weight + self._element_size_fn = element_size_fn or self._default_element_size + self._has_warned_type_error = False + self._buffers = collections.defaultdict(list) + + def _default_element_size(self, element): + try: + return len(element) + except TypeError: + if not self._has_warned_type_error: + _LOGGER.warning( + 'Element of type %s does not support len(). Falling back to ' + 'size 1. 
Consider providing a custom element_size_fn to '
+          'SortAndBatchElements for meaningful size-based batching.',
+          type(element).__name__)
+      self._has_warned_type_error = True
+    return 1
+
+  def start_bundle(self):
+    self._buffers = collections.defaultdict(list)
+
+  def process(self, element, window=DoFn.WindowParam):
+    self._buffers[window].append(element)
+
+    # If we have too many live windows, flush the largest one
+    if len(self._buffers) > self._MAX_LIVE_WINDOWS:
+      largest_window = max(
+          self._buffers.keys(), key=lambda w: len(self._buffers[w]))
+      yield from self._flush_window(largest_window)
+
+  def _flush_window(self, win):
+    """Flush all elements for a given window."""
+    buffer = self._buffers.pop(win, [])
+    if not buffer:
+      return
+
+    # Sort elements by size (ascending); exactly max_batch_weight is allowed
+    sorted_elements = sorted(buffer, key=self._element_size_fn)
+
+    batch = []
+    batch_weight = 0
+
+    for element in sorted_elements:
+      element_size = self._element_size_fn(element)
+
+      would_exceed_count = len(batch) >= self._max_batch_size
+      would_exceed_weight = (
+          batch_weight + element_size > self._max_batch_weight and batch)
+
+      if would_exceed_count or would_exceed_weight:
+        yield windowed_value.WindowedValue(batch, win.max_timestamp(), (win, ))
+        batch = []
+        batch_weight = 0
+
+      batch.append(element)
+      batch_weight += element_size
+
+    if batch:
+      yield windowed_value.WindowedValue(batch, win.max_timestamp(), (win, ))
+
+  def finish_bundle(self):
+    for win in list(self._buffers.keys()):
+      yield from self._flush_window(win)
+    self._buffers = None
+
+
+@typehints.with_input_types(T)
+@typehints.with_output_types(list[T])
+class SortAndBatchElements(PTransform):
+  """A Transform that sorts elements by size before batching.
+
+  This transform is designed to optimize batch processing by grouping elements
+  of similar sizes together. 
This is particularly useful for ML inference + workloads where input sequences of varying lengths need to be padded to the + maximum length in the batch - by sorting elements by size before batching, + padding overhead is minimized. + + The transform consumes a PCollection of element type T and produces a + PCollection of element type list[T], where elements within each batch are + sorted by their size (as determined by element_size_fn). + + Elements are batched per-window and batches emitted in the window + corresponding to its contents. Each batch is emitted with a timestamp at + the end of their window. + + Unlike BatchElements which emits batches as soon as size limits are reached, + SortAndBatchElements buffers all elements in a bundle, sorts them by size, + and then creates optimally-sized batches. This trade-off of increased memory + usage for better batch homogeneity can significantly reduce padding overhead. + + Args: + min_batch_size: The minimum number of elements in a batch. Must be >= 1. + max_batch_size: The maximum number of elements in a batch. + Must be >= min_batch_size. + max_batch_weight: The maximum total weight of elements in a batch, + where weight is computed by element_size_fn. Must be >= 1. + element_size_fn: (optional) A function mapping an element to its + size/weight. + If not provided, defaults to trying len(element) and falling back to 1 + if the element doesn't support len(). This default allows sorting to + work for common types like strings, lists, and arrays. 
+ + Example usage:: + + # Batch strings by total character count + strings = ['a', 'bb', 'ccc', 'dddd', 'eeeee'] + batched = strings | SortAndBatchElements( + min_batch_size=1, + max_batch_size=3, + max_batch_weight=10) + # Possible output: [['a', 'bb', 'ccc'], ['dddd', 'eeeee']] + # Elements are sorted by length and batched optimally + + # Batch with custom size function + data = [{'text': 'short'}, {'text': 'medium text'}, + {'text': 'long text here'}] + batched = data | SortAndBatchElements( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=lambda x: len(x['text'])) + """ + def __init__( + self, + min_batch_size: int, + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Optional[Callable[[Any], int]] = None): + if min_batch_size < 1: + raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') + if max_batch_size < min_batch_size: + raise ValueError( + f'max_batch_size ({max_batch_size}) must be >= ' + f'min_batch_size ({min_batch_size})') + if max_batch_weight < 1: + raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') + if element_size_fn is not None and not callable(element_size_fn): + raise TypeError('element_size_fn must be callable') + + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size + self._max_batch_weight = max_batch_weight + + # None means the DoFn will use its own _default_element_size method, + # which tries len() and warns once on TypeError before falling back to 1. 
+ self._element_size_fn = element_size_fn + + def expand(self, pcoll): + if pcoll.windowing.is_default(): + return pcoll | ParDo( + _SortAndBatchElementsDoFn( + self._min_batch_size, + self._max_batch_size, + self._max_batch_weight, + self._element_size_fn)) + return pcoll | ParDo( + _WindowAwareSortAndBatchElementsDoFn( + self._min_batch_size, + self._max_batch_size, + self._max_batch_weight, + self._element_size_fn)) + + class _IdentityWindowFn(NonMergingWindowFn): """Windowing function that preserves existing windows. diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 7389568691cd..30f34b59d3f5 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1026,6 +1026,341 @@ def test_stateful_grows_to_max_batch(self): assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50])) +class SortAndBatchElementsTest(unittest.TestCase): + """Tests for SortAndBatchElements transform.""" + def test_elements_are_sorted_by_size(self): + """Test that elements are sorted by size within batches.""" + with TestPipeline() as p: + # Create elements with varying sizes + data = ['aaaaa', 'bb', 'cccc', 'a', 'ddd'] + expected = [['a', 'bb', 'ddd', 'cccc', 'aaaaa']] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=5, max_batch_weight=100)) + # All elements fit in one batch, so the expected order is explicit. 
+ assert_that(res, equal_to(expected)) + + def test_batch_respects_max_batch_size(self): + """Test that batches do not exceed max_batch_size.""" + with TestPipeline() as p: + res = ( + p + | beam.Create(['a'] * 10, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=3, max_batch_weight=100) + | beam.Map(len)) + assert_that(res, equal_to([3, 3, 3, 1])) + + def test_batch_respects_max_batch_weight(self): + """Test that batches do not exceed max_batch_weight.""" + with TestPipeline() as p: + # Each element has size 5, max_batch_weight is 12 + # So we can fit at most 2 elements per batch + data = ['aaaaa', 'bbbbb', 'ccccc', 'ddddd'] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=12) + | beam.Map(len)) + assert_that(res, equal_to([2, 2])) + + def test_default_element_size_fn_with_strings(self): + """Test default element_size_fn works with strings.""" + with TestPipeline() as p: + data = ['a', 'bbb', 'cc'] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=3, max_batch_weight=100) + | beam.FlatMap(lambda batch: [len(s) for s in batch])) + # Elements should be sorted by length: 'a'(1), 'cc'(2), 'bbb'(3) + assert_that(res, equal_to([1, 2, 3])) + + def test_default_element_size_fn_with_integers(self): + """Test default element_size_fn falls back to 1 for integers.""" + with TestPipeline() as p: + data = [10, 20, 30, 40, 50] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=3, max_batch_weight=100) + | beam.Map(len)) + # With size=1 for all, should batch by max_batch_size + assert_that(res, equal_to([3, 2])) + + def test_custom_element_size_fn(self): + """Test using a custom element_size_fn.""" + with TestPipeline() as p: + data = [{'text': 'a'}, {'text': 'bbb'}, {'text': 'cc'}] + res = ( + p + | beam.Create(data, 
reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, + max_batch_size=3, + max_batch_weight=100, + element_size_fn=lambda x: len(x['text'])) + | beam.FlatMap(lambda batch: [len(e['text']) for e in batch])) + # Should be sorted by text length + assert_that(res, equal_to([1, 2, 3])) + + def test_empty_input(self): + """Test with empty input produces no output.""" + with TestPipeline() as p: + res = ( + p + | beam.Create([], reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100) + | beam.Map(len)) + assert_that(res, equal_to([])) + + def test_single_element(self): + """Test with a single element.""" + with TestPipeline() as p: + res = ( + p + | beam.Create(['hello'], reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100)) + assert_that(res, equal_to([['hello']])) + + def test_windowed_batches(self): + """Test that windowed elements are batched per window.""" + with TestPipeline('FnApiRunner') as p: + res = ( + p + | beam.Create(range(1, 8), reshuffle=False) + | beam.Map(lambda t: window.TimestampedValue('a' * t, t)) + | beam.WindowInto(window.FixedWindows(3)) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100) + | beam.Map(lambda batch: ''.join(batch))) + # FixedWindows(3) with default offset 0 produces: + # Window [0, 3): elements at t=1,2 with sizes 1,2 + # Window [3, 6): elements at t=3,4,5 with sizes 3,4,5 + # Window [6, 9): elements at t=6,7 with sizes 6,7 + assert_that( + res, + equal_to([ + 'a' * (1 + 2), # Window [0, 3) + 'a' * (3 + 4 + 5), # Window [3, 6) + 'a' * (6 + 7), # Window [6, 9) + ])) + + def test_validation_min_batch_size(self): + """Test that min_batch_size validation raises ValueError.""" + with self.assertRaises(ValueError) as cm: + util.SortAndBatchElements( + min_batch_size=0, max_batch_size=10, max_batch_weight=100) + self.assertIn('min_batch_size must be >= 1', str(cm.exception)) + 
+ def test_validation_max_batch_size(self): + """Test that max_batch_size < min_batch_size raises ValueError.""" + with self.assertRaises(ValueError) as cm: + util.SortAndBatchElements( + min_batch_size=10, max_batch_size=5, max_batch_weight=100) + self.assertIn('max_batch_size', str(cm.exception)) + self.assertIn('min_batch_size', str(cm.exception)) + + def test_validation_max_batch_weight(self): + """Test that max_batch_weight validation raises ValueError.""" + with self.assertRaises(ValueError) as cm: + util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=0) + self.assertIn('max_batch_weight must be >= 1', str(cm.exception)) + + def test_validation_element_size_fn_callable(self): + """Test that a non-callable element_size_fn raises TypeError.""" + with self.assertRaises(TypeError) as cm: + util.SortAndBatchElements( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=123) + self.assertIn('element_size_fn must be callable', str(cm.exception)) + + def test_batch_timestamps(self): + """Test that batches have correct timestamps.""" + with TestPipeline('FnApiRunner') as p: + res = ( + p + | beam.Create(['a', 'bb', 'ccc'], reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100) + | + beam.Map(lambda batch, ts=beam.DoFn.TimestampParam: (len(batch), ts))) + # The single global-window batch is emitted at end-of-window. + expected = [(3, GlobalWindow().max_timestamp())] + assert_that(res, equal_to(expected)) + + +class SortAndBatchElementsDoFnDirectTest(unittest.TestCase): + """Direct unit tests for DoFn internals to ensure coverage. + + Beam's FnApiRunner executes DoFns in a separate SDK harness process, + so coverage tools in the main process cannot capture DoFn code paths. + These tests exercise the DoFn methods directly in-process. 
+ """ + def test_default_element_size_len(self): + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=None) + self.assertEqual(dofn._element_size_fn('abc'), 3) + self.assertEqual(dofn._element_size_fn([1, 2]), 2) + + def test_default_element_size_fallback_warns_once(self): + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=None) + with self.assertLogs('apache_beam.transforms.util', level='WARNING') as cm: + self.assertEqual(dofn._element_size_fn(42), 1) + self.assertIn('does not support len()', cm.output[0]) + # Second call should not warn again + self.assertEqual(dofn._element_size_fn(3.14), 1) + self.assertTrue(dofn._has_warned_type_error) + + def test_global_dofn_sort_and_batch(self): + """Test _SortAndBatchElementsDoFn directly.""" + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=3, + max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + for elem in ['ccccc', 'bb', 'dddd', 'a', 'eee']: + dofn.process(elem) + batches = [wv.value for wv in dofn.finish_bundle()] + # All elements emitted + self.assertEqual(sum(len(b) for b in batches), 5) + # Each batch respects max_batch_size=3 + for batch in batches: + self.assertLessEqual(len(batch), 3) + # Elements within each batch are sorted by size + for batch in batches: + lengths = [len(s) for s in batch] + self.assertEqual(lengths, sorted(lengths)) + + def test_global_dofn_empty_bundle(self): + """Test finish_bundle with no elements returns nothing.""" + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=len) + 
dofn.start_bundle() + result = list(dofn.finish_bundle() or []) + self.assertEqual(result, []) + + def test_global_dofn_weight_splitting(self): + """Test weight-based splitting in the global DoFn.""" + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + + # Each element has size 5, max_batch_weight=12 -> 2 per batch + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=100, + max_batch_weight=12, + element_size_fn=len) + dofn.start_bundle() + for elem in ['aaaaa', 'bbbbb', 'ccccc', 'ddddd']: + dofn.process(elem) + batches = [wv.value for wv in dofn.finish_bundle()] + self.assertEqual(len(batches), 2) + for batch in batches: + self.assertEqual(len(batch), 2) + + def test_windowed_dofn_flush_and_finish(self): + """Test _WindowAwareSortAndBatchElementsDoFn directly.""" + from apache_beam.transforms.util import _WindowAwareSortAndBatchElementsDoFn + + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + win1 = IntervalWindow(0, 3) + win2 = IntervalWindow(3, 6) + # Manually add to buffers (bypass process() to avoid DoFn.WindowParam) + dofn._buffers[win1].extend(['aa', 'b', 'ccc']) + dofn._buffers[win2].extend(['dddd', 'ee']) + batches = list(dofn.finish_bundle()) + # All elements across both windows emitted + total_elements = sum(len(wv.value) for wv in batches) + self.assertEqual(total_elements, 5) + # Each batch has the correct window + for wv in batches: + self.assertIn(wv.windows[0], (win1, win2)) + + def test_windowed_dofn_overflow_flush(self): + """Test that exceeding _MAX_LIVE_WINDOWS triggers early flush.""" + from apache_beam.transforms.util import _WindowAwareSortAndBatchElementsDoFn + + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + # Fill up to _MAX_LIVE_WINDOWS + for i in range(dofn._MAX_LIVE_WINDOWS): + win = 
IntervalWindow(i * 10, (i + 1) * 10) + dofn._buffers[win].append('x' * (i + 1)) + self.assertEqual(len(dofn._buffers), dofn._MAX_LIVE_WINDOWS) + # Adding one more window should trigger overflow flush + overflow_win = IntervalWindow(100, 110) + results = list(dofn.process('overflow', overflow_win)) + # One window was flushed, so buffer count stays at _MAX_LIVE_WINDOWS + self.assertLessEqual(len(dofn._buffers), dofn._MAX_LIVE_WINDOWS) + # The flushed window produced output + self.assertGreater(len(results), 0) + + def test_windowed_dofn_flush_empty_window(self): + """Test _flush_window with a non-existent window returns nothing.""" + from apache_beam.transforms.util import _WindowAwareSortAndBatchElementsDoFn + + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + result = list(dofn._flush_window(IntervalWindow(0, 10))) + self.assertEqual(result, []) + + def test_windowed_dofn_weight_splitting(self): + """Test weight-based splitting in the windowed DoFn.""" + from apache_beam.transforms.util import _WindowAwareSortAndBatchElementsDoFn + + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=100, + max_batch_weight=12, + element_size_fn=len) + dofn.start_bundle() + win = IntervalWindow(0, 10) + dofn._buffers[win].extend(['aaaaa', 'bbbbb', 'ccccc', 'ddddd']) + batches = list(dofn._flush_window(win)) + self.assertEqual(len(batches), 2) + for wv in batches: + self.assertEqual(len(wv.value), 2) + self.assertEqual(wv.windows[0], win) + + class IdentityWindowTest(unittest.TestCase): def test_window_preserved(self): expected_timestamp = timestamp.Timestamp(5)