Skip to content

Commit 4ee2a55

Browse files
committed
Ruff format
1 parent afed20a commit 4ee2a55

31 files changed

Lines changed: 538 additions & 293 deletions

benchmarks/_common.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import os
88
import statistics
99
import time
10-
from typing import Callable
10+
from collections.abc import Callable
1111

1212
try:
1313
import psutil
@@ -152,9 +152,7 @@ def verify_manifest(data_dir: str) -> dict:
152152
for chunk in iter(lambda: f.read(65536), b""):
153153
h.update(chunk)
154154
if h.hexdigest() != entry["sha256"]:
155-
raise ValueError(
156-
f"Checksum mismatch for {entry['name']} — re-run gen_data.py"
157-
)
155+
raise ValueError(f"Checksum mismatch for {entry['name']} — re-run gen_data.py")
158156
return manifest
159157

160158

benchmarks/baselines/manual_sharded.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
from __future__ import annotations
1111

12+
from collections.abc import Iterator
1213
from glob import glob
13-
from typing import Iterator
1414

1515
import pyarrow.parquet as pq
1616
import torch.utils.data as tud

benchmarks/baselines/naive_iterable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from __future__ import annotations
1010

11+
from collections.abc import Iterator
1112
from glob import glob
12-
from typing import Iterator
1313

1414
import pyarrow.parquet as pq
1515
import torch.utils.data as tud

benchmarks/gen_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def generate(out_dir: str, dataset: str, fmt: str = "parquet") -> dict:
9595
table = make_table(n_rows, row_id_offset=total)
9696

9797
if cfg.get("sorted_by_label"):
98-
import pyarrow.compute as pc
98+
9999
table = table.sort_by("label")
100100

101101
if fmt == "parquet":

benchmarks/report.py

Lines changed: 309 additions & 126 deletions
Large diffs are not rendered by default.

benchmarks/run.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,28 +38,44 @@
3838
BASELINE_PATH = os.path.join(RESULTS_DIR, "baseline.json")
3939

4040
CI_DATASETS = {
41-
"S1": "tiny", "S2": "unequal", "S3": "single_large",
42-
"S4": "tiny", "S5": "tiny", "S6": "root",
43-
"S7": "tiny", "S8": "tiny",
41+
"S1": "tiny",
42+
"S2": "unequal",
43+
"S3": "single_large",
44+
"S4": "tiny",
45+
"S5": "tiny",
46+
"S6": "root",
47+
"S7": "tiny",
48+
"S8": "tiny",
4449
}
4550
DEFAULT_DATASETS = {
46-
"S1": "large", "S2": "unequal", "S3": "single_large",
47-
"S4": "large", "S5": "large", "S6": "root",
48-
"S7": "small", "S8": "small",
51+
"S1": "large",
52+
"S2": "unequal",
53+
"S3": "single_large",
54+
"S4": "large",
55+
"S5": "large",
56+
"S6": "root",
57+
"S7": "small",
58+
"S8": "small",
4959
}
5060

5161

5262
def _run_metadata() -> dict:
5363
import importlib.metadata
64+
5465
try:
5566
version = importlib.metadata.version("torch-dataloader-utils")
5667
except Exception:
5768
version = "dev"
5869
try:
5970
import subprocess
60-
git_sha = subprocess.check_output(
61-
["git", "rev-parse", "--short", "HEAD"], stderr=subprocess.DEVNULL
62-
).decode().strip()
71+
72+
git_sha = (
73+
subprocess.check_output(
74+
["git", "rev-parse", "--short", "HEAD"], stderr=subprocess.DEVNULL
75+
)
76+
.decode()
77+
.strip()
78+
)
6379
except Exception:
6480
git_sha = "unknown"
6581
return {
@@ -134,7 +150,6 @@ def main() -> int:
134150
module, uses_root = ALL_SCENARIOS[sid]
135151
ds = dataset_map[sid]
136152
d = _dataset_dir(args.data_dir, sid, ds, uses_root)
137-
check_dir = d if uses_root else d
138153
# For root-dir scenarios, verify the subdatasets that exist
139154
if uses_root:
140155
for sub in ["tiny", "small", "medium", "large"]:
@@ -158,11 +173,11 @@ def main() -> int:
158173
module, uses_root = ALL_SCENARIOS[sid]
159174
ds = dataset_map[sid]
160175
d = _dataset_dir(args.data_dir, sid, ds, uses_root)
161-
print(f"\n{'='*60}")
176+
print(f"\n{'=' * 60}")
162177
print(f"Running {sid}: {module.__name__.split('.')[-1]}")
163178
print(f" data_dir : {d}")
164179
print(f" n_runs : {n_runs} n_warmup: {n_warmup}")
165-
print(f"{'='*60}")
180+
print(f"{'=' * 60}")
166181
try:
167182
result = module.run(d, n_warmup=n_warmup, n_runs=n_runs)
168183
all_results["scenarios"][sid] = result

benchmarks/scenarios/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
# uses_root_dir=True → scenario receives the parent data dir (contains dataset subdirs)
1414
# uses_root_dir=False → scenario receives the single-dataset subdirectory directly
1515
ALL_SCENARIOS: dict[str, tuple] = {
16-
"S1": (s1_throughput, False),
17-
"S2": (s2_unequal, False),
18-
"S3": (s3_single_large, False),
19-
"S4": (s4_rank_sharding, False),
20-
"S5": (s5_column_projection, False),
16+
"S1": (s1_throughput, False),
17+
"S2": (s2_unequal, False),
18+
"S3": (s3_single_large, False),
19+
"S4": (s4_rank_sharding, False),
20+
"S5": (s5_column_projection, False),
2121
"S6": (s6_predicate_pushdown, True),
22-
"S7": (s7_startup_latency, True),
23-
"S8": (s8_format_comparison, True),
22+
"S7": (s7_startup_latency, True),
23+
"S8": (s8_format_comparison, True),
2424
}

benchmarks/scenarios/s1_throughput.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from benchmarks._common import load_manifest, measure, parquet_glob, passthrough
1515
from torch_dataloader_utils import StructuredDataset
1616

17-
DESCRIPTION = "Baseline throughput sweep across num_workers on equal-sized files. All three implementations."
17+
DESCRIPTION = (
18+
"Baseline throughput sweep across num_workers on equal-sized files. All three implementations."
19+
)
1820
DATASET = "small"
1921
WORKER_COUNTS = [0, 2, 4, 8]
2022
BATCH_SIZE = 1024

benchmarks/scenarios/s3_single_large.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import statistics
1414
import time
1515

16-
from benchmarks._common import parquet_glob, passthrough, load_manifest, run_epoch
16+
from benchmarks._common import load_manifest, parquet_glob, passthrough, run_epoch
1717
from torch_dataloader_utils import StructuredDataset
1818

1919
DESCRIPTION = (
@@ -60,24 +60,35 @@ def run(data_dir: str, n_warmup: int = 1, n_runs: int = 5) -> dict:
6060
# Warmup: prime OS disk cache
6161
for _ in range(n_warmup):
6262
loader, _ = StructuredDataset.create_dataloader(
63-
path=parquet_glob(data_dir), format="parquet", num_workers=0,
64-
batch_size=BATCH_SIZE, split_bytes=SPLIT_BYTES,
65-
output_format="arrow", collate_fn=passthrough,
63+
path=parquet_glob(data_dir),
64+
format="parquet",
65+
num_workers=0,
66+
batch_size=BATCH_SIZE,
67+
split_bytes=SPLIT_BYTES,
68+
output_format="arrow",
69+
collate_fn=passthrough,
6670
)
6771
run_epoch(loader)
6872

6973
def _lib_worker(w: int):
7074
loader, _ = StructuredDataset.create_dataloader(
71-
path=parquet_glob(data_dir), format="parquet", num_workers=0,
72-
batch_size=BATCH_SIZE, split_bytes=SPLIT_BYTES,
73-
num_ranks=NUM_WORKERS, rank=w,
74-
output_format="arrow", collate_fn=passthrough,
75+
path=parquet_glob(data_dir),
76+
format="parquet",
77+
num_workers=0,
78+
batch_size=BATCH_SIZE,
79+
split_bytes=SPLIT_BYTES,
80+
num_ranks=NUM_WORKERS,
81+
rank=w,
82+
output_format="arrow",
83+
collate_fn=passthrough,
7584
)
7685
return loader
7786

7887
def _manual_worker(w: int):
7988
from glob import glob
89+
8090
import pyarrow.parquet as pq
91+
8192
files = sorted(glob(f"{data_dir}/*.parquet"))
8293
my_files = files[w::NUM_WORKERS]
8394

@@ -88,6 +99,7 @@ def __iter__(self_):
8899
for batch in pf.iter_batches(BATCH_SIZE):
89100
if batch.num_rows > 0:
90101
yield batch
102+
91103
return _W()
92104

93105
lib_stats = _simulate_parallel(_lib_worker, NUM_WORKERS, n_runs)

benchmarks/scenarios/s4_rank_sharding.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from __future__ import annotations
1313

14-
from benchmarks._common import parquet_glob, passthrough, load_manifest, measure
14+
from benchmarks._common import load_manifest, measure, parquet_glob, passthrough
1515
from torch_dataloader_utils import StructuredDataset
1616

1717
DESCRIPTION = (
@@ -34,7 +34,11 @@ def run(data_dir: str, n_warmup: int = 1, n_runs: int = 5) -> dict:
3434
"dataset": DATASET,
3535
"total_rows": total,
3636
"num_workers": NUM_WORKERS,
37-
"config": {"num_workers": NUM_WORKERS, "batch_size": BATCH_SIZE, "rank_counts": RANK_COUNTS},
37+
"config": {
38+
"num_workers": NUM_WORKERS,
39+
"batch_size": BATCH_SIZE,
40+
"rank_counts": RANK_COUNTS,
41+
},
3842
"this_library": {},
3943
"naive_ddp": {},
4044
}
@@ -89,7 +93,7 @@ def _loader(nr=nr):
8993

9094
# naive_ddp at this num_ranks: reads all rows but delivers only 1/nr fraction
9195
naive_elapsed = naive_stats["elapsed_sec"]["median"]
92-
naive_rps = naive_stats["rows_per_sec"]["median"]
96+
_naive_rps = naive_stats["rows_per_sec"]["median"]
9397
results["naive_ddp"][f"num_ranks={nr}"] = {
9498
"rows_received": actual_rows,
9599
"fraction_of_total": round(actual_rows / total, 4),

0 commit comments

Comments
 (0)