Skip to content

Commit 16719ac

Browse files
committed
feat: add read throughput micro-benchmark for ArrowScan configurations
1 parent 13feb8d commit 16719ac

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Read throughput micro-benchmark for ArrowScan configurations.
18+
19+
Measures records/sec and peak Arrow memory across streaming, concurrent_files,
20+
and batch_size configurations introduced for issue #3036.
21+
22+
Memory is measured using pa.total_allocated_bytes() which tracks PyArrow's C++
23+
memory pool (Arrow buffers, Parquet decompression), not Python heap allocations.
24+
25+
Run with: uv run pytest tests/benchmark/test_read_benchmark.py -v -s -m benchmark
26+
"""
27+
28+
import gc
29+
import statistics
30+
import timeit
31+
from datetime import datetime, timezone
32+
33+
import pyarrow as pa
34+
import pyarrow.parquet as pq
35+
import pytest
36+
37+
from pyiceberg.catalog.sql import SqlCatalog
38+
from pyiceberg.table import Table
39+
40+
NUM_FILES = 32
41+
ROWS_PER_FILE = 500_000
42+
TOTAL_ROWS = NUM_FILES * ROWS_PER_FILE
43+
NUM_RUNS = 3
44+
45+
46+
def _generate_parquet_file(path: str, num_rows: int, seed: int) -> pa.Schema:
    """Write a synthetic Parquet file to *path* and return its Arrow schema.

    Columns: int64 ``id`` (offset by *seed* so ids are unique across files),
    float64 ``value``, string ``label``, bool ``flag``, and a UTC microsecond
    timestamp ``ts`` (a single "now" value repeated for every row).
    """
    row_idx = range(num_rows)
    # One shared timestamp for the whole file keeps generation cheap.
    stamp = datetime.now(timezone.utc)
    data = pa.table(
        {
            "id": pa.array(range(seed, seed + num_rows), type=pa.int64()),
            "value": pa.array([float(i) * 0.1 for i in row_idx], type=pa.float64()),
            "label": pa.array([f"row_{i}" for i in row_idx], type=pa.string()),
            "flag": pa.array([i % 2 == 0 for i in row_idx], type=pa.bool_()),
            "ts": pa.array([stamp] * num_rows, type=pa.timestamp("us", tz="UTC")),
        }
    )
    pq.write_table(data, path)
    return data.schema
59+
60+
61+
@pytest.fixture(scope="session")
def benchmark_table(tmp_path_factory: pytest.TempPathFactory) -> Table:
    """Create a catalog and table with synthetic Parquet files for benchmarking.

    Builds a session-scoped SQLite-backed SqlCatalog in a temporary warehouse,
    generates NUM_FILES Parquet files of ROWS_PER_FILE rows each, and appends
    every file to a single Iceberg table so each append produces one data file.

    Returns:
        The populated ``default.benchmark_read`` table.

    Raises:
        RuntimeError: if NUM_FILES is zero, in which case no table was created
            and returning would violate the ``-> Table`` annotation.
    """
    warehouse_path = str(tmp_path_factory.mktemp("benchmark_warehouse"))
    catalog = SqlCatalog(
        "benchmark",
        uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db",
        warehouse=f"file://{warehouse_path}",
    )
    catalog.create_namespace("default")

    # Generate files and append to table; the table itself is created lazily
    # from the first file's schema.
    table: Table | None = None
    for i in range(NUM_FILES):
        file_path = f"{warehouse_path}/data_{i}.parquet"
        # Seed by cumulative row count so "id" values are unique across files.
        _generate_parquet_file(file_path, ROWS_PER_FILE, seed=i * ROWS_PER_FILE)

        file_table = pq.read_table(file_path)
        if table is None:
            table = catalog.create_table("default.benchmark_read", schema=file_table.schema)
        table.append(file_table)

    if table is None:
        # Guards the return annotation; only reachable if NUM_FILES == 0.
        raise RuntimeError("NUM_FILES must be > 0 to build the benchmark table")
    return table
84+
85+
86+
@pytest.mark.benchmark
@pytest.mark.parametrize(
    "streaming,concurrent_files,batch_size",
    [
        pytest.param(False, 1, None, id="default"),
        pytest.param(True, 1, None, id="streaming-cf1"),
        pytest.param(True, 2, None, id="streaming-cf2"),
        pytest.param(True, 4, None, id="streaming-cf4"),
        pytest.param(True, 8, None, id="streaming-cf8"),
        pytest.param(True, 16, None, id="streaming-cf16"),
    ],
)
def test_read_throughput(
    benchmark_table: Table,
    streaming: bool,
    concurrent_files: int,
    batch_size: int | None,
) -> None:
    """Measure records/sec and peak Arrow memory for a scan configuration.

    Runs each scan configuration NUM_RUNS times, printing per-run elapsed time,
    throughput, and peak PyArrow-pool usage above the pre-scan baseline, then a
    mean/stdev summary. Memory is sampled via pa.total_allocated_bytes(), which
    tracks Arrow's C++ pool only — Python-heap allocations are not included.
    """
    effective_batch_size = batch_size or 131_072  # PyArrow default
    if streaming:
        config_str = f"streaming=True, concurrent_files={concurrent_files}, batch_size={effective_batch_size}"
    else:
        config_str = f"streaming=False (executor.map, all files parallel), batch_size={effective_batch_size}"
    # Plain string: this line has no placeholders (fixes ruff F541 f-string).
    print("\n--- ArrowScan Read Throughput Benchmark ---")
    print(f"Config: {config_str}")
    print(f" Files: {NUM_FILES}, Rows per file: {ROWS_PER_FILE}, Total rows: {TOTAL_ROWS}")

    elapsed_times: list[float] = []
    throughputs: list[float] = []
    peak_memories: list[int] = []

    for run in range(NUM_RUNS):
        # Measure throughput from a clean memory baseline so the per-run peak
        # deltas are comparable across runs and configurations.
        gc.collect()
        pa.default_memory_pool().release_unused()
        baseline_mem = pa.total_allocated_bytes()
        peak_mem = baseline_mem

        start = timeit.default_timer()
        total_rows = 0
        for batch in benchmark_table.scan().to_arrow_batch_reader(
            batch_size=batch_size,
            streaming=streaming,
            concurrent_files=concurrent_files,
        ):
            total_rows += len(batch)
            # Sample the pool once per batch; the max sample approximates peak.
            current_mem = pa.total_allocated_bytes()
            if current_mem > peak_mem:
                peak_mem = current_mem
        elapsed = timeit.default_timer() - start

        peak_above_baseline = peak_mem - baseline_mem
        rows_per_sec = total_rows / elapsed if elapsed > 0 else 0
        elapsed_times.append(elapsed)
        throughputs.append(rows_per_sec)
        peak_memories.append(peak_above_baseline)

        print(
            f" Run {run + 1}: {elapsed:.2f}s, {rows_per_sec:,.0f} rows/s, "
            f"peak arrow mem: {peak_above_baseline / (1024 * 1024):.1f} MB"
        )

        assert total_rows == TOTAL_ROWS, f"Expected {TOTAL_ROWS} rows, got {total_rows}"

    mean_elapsed = statistics.mean(elapsed_times)
    stdev_elapsed = statistics.stdev(elapsed_times) if len(elapsed_times) > 1 else 0.0
    mean_throughput = statistics.mean(throughputs)
    mean_peak_mem = statistics.mean(peak_memories)

    print(
        f" Mean: {mean_elapsed:.2f}s ± {stdev_elapsed:.2f}s, {mean_throughput:,.0f} rows/s, "
        f"peak arrow mem: {mean_peak_mem / (1024 * 1024):.1f} MB"
    )

0 commit comments

Comments
 (0)