Skip to content

Commit ef2eecf

Browse files
committed
feat: log convert metrics to benchmark local runs
1 parent 563e1f5 commit ef2eecf

3 files changed

Lines changed: 2976 additions & 1361 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ dependencies = [
3939
"s3fs>=2024.6.0",
4040
"boto3>=1.34.0",
4141
"pyproj>=3.7.0",
42+
"ipykernel>=6.30.1",
43+
"jupyter>=1.1.1",
44+
"bokeh>=3.7.3",
4245
]
4346

4447
[dependency-groups]

src/eopf_geozarr/cli.py

Lines changed: 262 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import sys
1010
from pathlib import Path
1111
from typing import Any, Optional
12+
from time import perf_counter
13+
from contextlib import nullcontext
14+
from datetime import datetime
1215

1316
import xarray as xr
1417

@@ -21,7 +24,13 @@
2124
)
2225

2326

24-
def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any]:
27+
def setup_dask_cluster(
28+
enable_dask: bool,
29+
verbose: bool = False,
30+
mode: str = "threads", # threads | processes | single-threaded
31+
n_workers: int = 4,
32+
threads_per_worker: int = 1,
33+
) -> Optional[Any]:
2534
"""
2635
Set up a dask cluster for parallel processing.
2736
@@ -41,18 +50,34 @@ def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any
4150
return None
4251

4352
try:
44-
from dask.distributed import Client
45-
46-
# Set up local cluster
47-
client = Client() # set up local cluster
53+
from dask.distributed import Client, LocalCluster
54+
55+
if mode not in {"threads", "processes", "single-threaded"}:
56+
raise ValueError(f"Unsupported --dask-mode: {mode}")
57+
58+
processes = (mode == "processes")
59+
# For single-threaded, use one worker, one thread, processes=False
60+
if mode == "single-threaded":
61+
n_workers = 1
62+
threads_per_worker = 1
63+
processes = False
64+
65+
cluster = LocalCluster(
66+
n_workers=n_workers,
67+
threads_per_worker=threads_per_worker,
68+
processes=processes,
69+
)
70+
client = Client(cluster)
4871

4972
if verbose:
5073
print(f"🚀 Dask cluster started: {client}")
5174
print(f" Dashboard: {client.dashboard_link}")
5275
print(f" Workers: {len(client.scheduler_info()['workers'])}")
5376
else:
54-
print("🚀 Dask cluster started for parallel processing")
55-
77+
print(
78+
f"🚀 Dask cluster started "
79+
f"({mode}; workers={n_workers}, threads/worker={threads_per_worker})"
80+
)
5681
return client
5782

5883
except ImportError:
@@ -65,6 +90,79 @@ def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any
6590
sys.exit(1)
6691

6792

93+
def _duration_from_startstops(ev: dict, action: str) -> float:
94+
"""Best-effort duration extractor from task_stream 'startstops'."""
95+
total = 0.0
96+
for ss in ev.get("startstops", []) or []:
97+
if ss.get("action") == action:
98+
s, t = ss.get("start"), ss.get("stop")
99+
if s is not None and t is not None:
100+
total += max(0.0, float(t) - float(s))
101+
return total
102+
103+
104+
def _parse_task_events(task_events: Optional[list]) -> dict:
105+
"""Normalize task_stream events across Dask versions into aggregate numbers."""
106+
if not task_events:
107+
return {"tasks_observed": 0, "compute_time_s_sum": 0.0, "transfer_time_s_sum": 0.0}
108+
comp_sum = 0.0
109+
xfer_sum = 0.0
110+
for e in task_events:
111+
# compute time
112+
cd = e.get("compute_duration")
113+
if cd is None:
114+
cs = e.get("compute_start") or e.get("start")
115+
ce = e.get("compute_stop") or e.get("stop")
116+
if cs is not None and ce is not None:
117+
cd = max(0.0, float(ce) - float(cs))
118+
else:
119+
cd = _duration_from_startstops(e, "compute")
120+
comp_sum += float(cd or 0.0)
121+
# transfer time
122+
td = e.get("transfer_duration")
123+
if td is None:
124+
ts = e.get("transfer_start")
125+
te = e.get("transfer_stop")
126+
if ts is not None and te is not None:
127+
td = max(0.0, float(te) - float(ts))
128+
else:
129+
td = _duration_from_startstops(e, "transfer")
130+
xfer_sum += float(td or 0.0)
131+
return {
132+
"tasks_observed": len(task_events),
133+
"compute_time_s_sum": comp_sum,
134+
"transfer_time_s_sum": xfer_sum,
135+
}
136+
137+
138+
def _summarize_dask_metrics(client: Any, task_events: Optional[list], wall_clock_s: float) -> dict:
139+
info = client.scheduler_info()
140+
workers = info.get("workers", {})
141+
total_threads = sum(int(w.get("nthreads", 0) or 0) for w in workers.values())
142+
memory_limit = sum(int(w.get("memory_limit", 0) or 0) for w in workers.values())
143+
memory_used = sum(int(w.get("metrics", {}).get("memory", 0) or 0) for w in workers.values())
144+
spilled = sum(int(w.get("metrics", {}).get("spilled_nbytes", 0) or 0) for w in workers.values())
145+
146+
parsed = _parse_task_events(task_events)
147+
tasks = parsed["tasks_observed"]
148+
base = {
149+
"wall_clock_s": wall_clock_s,
150+
"workers": len(workers),
151+
"threads_total": total_threads,
152+
"tasks_observed": tasks,
153+
"tasks_per_sec": (tasks / wall_clock_s) if tasks and wall_clock_s else 0.0,
154+
"compute_time_s_sum": parsed["compute_time_s_sum"],
155+
"transfer_time_s_sum": parsed["transfer_time_s_sum"],
156+
"memory_used_bytes": memory_used,
157+
"memory_limit_bytes": memory_limit,
158+
}
159+
if spilled:
160+
base["spilled_nbytes"] = spilled
161+
if getattr(client, "dashboard_link", None):
162+
base["dashboard_link"] = client.dashboard_link
163+
return base
164+
165+
68166
def convert_command(args: argparse.Namespace) -> None:
69167
"""
70168
Convert EOPF dataset to GeoZarr compliant format.
@@ -76,9 +174,19 @@ def convert_command(args: argparse.Namespace) -> None:
76174
"""
77175
# Set up dask cluster if requested
78176
dask_client = setup_dask_cluster(
79-
enable_dask=getattr(args, "dask_cluster", False), verbose=args.verbose
177+
enable_dask=getattr(args, "dask_cluster", False),
178+
verbose=args.verbose,
179+
mode=getattr(args, "dask_mode", "threads"),
180+
n_workers=getattr(args, "dask_workers", 4),
181+
threads_per_worker=getattr(args, "dask_threads_per_worker", 1),
80182
)
81183

184+
# Prepare for metrics writing even on failures (local outputs only)
185+
debug_dir: Optional[Path] = None
186+
run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
187+
status = "unknown"
188+
error_msg = None
189+
82190
try:
83191
# Validate input path (handle both local paths and URLs)
84192
input_path_str = args.input_path
@@ -129,6 +237,12 @@ def convert_command(args: argparse.Namespace) -> None:
129237
output_path = Path(output_path_str)
130238
output_path.parent.mkdir(parents=True, exist_ok=True)
131239
output_path = str(output_path)
240+
# Prepare debug dir for metrics
241+
debug_dir = Path(output_path) / "debug"
242+
try:
243+
debug_dir.mkdir(parents=True, exist_ok=True)
244+
except Exception:
245+
debug_dir = None
132246

133247
if args.verbose:
134248
print(f"Loading EOPF dataset from: {input_path}")
@@ -157,19 +271,84 @@ def convert_command(args: argparse.Namespace) -> None:
157271

158272
# Convert to GeoZarr compliant format
159273
print("Converting to GeoZarr compliant format...")
160-
dt_geozarr = create_geozarr_dataset(
161-
dt_input=dt,
162-
groups=args.groups,
163-
output_path=output_path,
164-
spatial_chunk=args.spatial_chunk,
165-
min_dimension=args.min_dimension,
166-
tile_width=args.tile_width,
167-
max_retries=args.max_retries,
168-
crs_groups=args.crs_groups,
169-
)
274+
t0 = perf_counter()
275+
task_events = None
276+
perf_ctx = nullcontext()
277+
if dask_client is not None and getattr(args, "dask_perf_html", None):
278+
from dask.distributed import performance_report
279+
perf_path = args.dask_perf_html
280+
if not (perf_path.startswith(("s3://", "gs://", "http://", "https://"))):
281+
Path(perf_path).parent.mkdir(parents=True, exist_ok=True)
282+
perf_ctx = performance_report(filename=perf_path)
283+
284+
capture_tasks = False
285+
if dask_client is not None:
286+
try:
287+
from dask.distributed import get_task_stream
288+
capture_tasks = True
289+
except Exception:
290+
pass
291+
292+
if capture_tasks:
293+
with perf_ctx:
294+
from dask.distributed import get_task_stream
295+
with get_task_stream(client=dask_client, plot=False) as ts:
296+
dt_geozarr = create_geozarr_dataset(
297+
dt_input=dt,
298+
groups=args.groups,
299+
output_path=output_path,
300+
spatial_chunk=args.spatial_chunk,
301+
min_dimension=args.min_dimension,
302+
tile_width=args.tile_width,
303+
max_retries=args.max_retries,
304+
crs_groups=args.crs_groups,
305+
)
306+
task_events = ts.data
307+
else:
308+
with perf_ctx:
309+
dt_geozarr = create_geozarr_dataset(
310+
dt_input=dt,
311+
groups=args.groups,
312+
output_path=output_path,
313+
spatial_chunk=args.spatial_chunk,
314+
min_dimension=args.min_dimension,
315+
tile_width=args.tile_width,
316+
max_retries=args.max_retries,
317+
crs_groups=args.crs_groups,
318+
)
319+
320+
wall_clock = perf_counter() - t0
321+
if dask_client is not None and getattr(args, "verbose", False) and getattr(args, "dask_perf_html", None):
322+
print(f"📊 Dask performance report: {args.dask_perf_html}")
170323

171324
print("✅ Successfully converted EOPF dataset to GeoZarr format")
172325
print(f"Output saved to: {output_path}")
326+
status = "ok"
327+
328+
# Drop a JSON run summary (local paths only)
329+
try:
330+
if debug_dir is not None:
331+
summary_path = Path(output_path) / "debug" / "dask_run_summary.json"
332+
summary_path.parent.mkdir(parents=True, exist_ok=True)
333+
from json import dumps as _dumps
334+
summary = {
335+
"dask_enabled": bool(dask_client is not None),
336+
"mode": getattr(args, "dask_mode", None),
337+
"workers": getattr(args, "dask_workers", None),
338+
"threads_per_worker": getattr(args, "dask_threads_per_worker", None),
339+
"perf_report": getattr(args, "dask_perf_html", None),
340+
"wall_clock_s": wall_clock if dask_client is not None else None,
341+
"groups": args.groups,
342+
"spatial_chunk": args.spatial_chunk,
343+
"min_dimension": args.min_dimension,
344+
"tile_width": args.tile_width,
345+
}
346+
summary_path.write_text(_dumps(summary, indent=2))
347+
if args.verbose:
348+
print(f"🧾 Wrote run summary: {summary_path}")
349+
except Exception as _exc:
350+
if args.verbose:
351+
print(f"(debug) could not write run summary: {_exc}")
173352

174353
if args.verbose:
175354
# Check if dt_geozarr is a DataTree or Dataset
@@ -183,6 +362,7 @@ def convert_command(args: argparse.Namespace) -> None:
183362

184363
except Exception as e:
185364
print(f"❌ Error during conversion: {e}")
365+
error_msg = str(e)
186366
if args.verbose:
187367
import traceback
188368

@@ -200,6 +380,46 @@ def convert_command(args: argparse.Namespace) -> None:
200380
if args.verbose:
201381
print(f"Warning: Error closing dask cluster: {e}")
202382

383+
# Best-effort metrics write (works for success or failure, local only)
384+
try:
385+
if debug_dir is not None:
386+
attempt = 1
387+
try:
388+
attempt = 1 + len(list(debug_dir.glob("dask_metrics_*.json")))
389+
except Exception:
390+
pass
391+
wall = None
392+
if "t0" in locals():
393+
wall = perf_counter() - t0
394+
metrics_doc = {
395+
"status": status,
396+
"run_id": run_id,
397+
"attempt": attempt,
398+
"dask_enabled": bool(dask_client is not None),
399+
"mode": getattr(args, "dask_mode", None),
400+
}
401+
if dask_client is not None and wall is not None:
402+
try:
403+
metrics_doc.update(_summarize_dask_metrics(dask_client, locals().get("task_events"), wall))
404+
except Exception:
405+
metrics_doc["wall_clock_s"] = wall
406+
elif wall is not None:
407+
metrics_doc["wall_clock_s"] = wall
408+
if error_msg:
409+
metrics_doc["error"] = error_msg
410+
411+
path_ts = debug_dir / f"dask_metrics_{run_id}_attempt{attempt}.json"
412+
path_latest = debug_dir / "dask_metrics.json"
413+
from json import dumps as _dumps
414+
txt = _dumps(metrics_doc, indent=2)
415+
path_ts.write_text(txt)
416+
path_latest.write_text(txt)
417+
if getattr(args, "verbose", False):
418+
print(f"📈 Wrote metrics: {path_ts}")
419+
except Exception as _mexc:
420+
if getattr(args, "verbose", False):
421+
print(f"(debug) could not write metrics snapshot: {_mexc}")
422+
203423

204424
def info_command(args: argparse.Namespace) -> None:
205425
"""
@@ -309,8 +529,6 @@ def format_data_vars(data_vars: Any) -> str:
309529
# Get xarray's HTML representation and extract just the variables section
310530
try:
311531
html_repr = temp_ds._repr_html_()
312-
# Extract the variables section from xarray's HTML
313-
# This gives us the rich, interactive variable display
314532
return f'<div class="xarray-variables">{html_repr}</div>'
315533
except Exception:
316534
# Fallback to simple format if xarray HTML fails
@@ -1095,6 +1313,30 @@ def create_parser() -> argparse.ArgumentParser:
10951313
action="store_true",
10961314
help="Start a local dask cluster for parallel processing of chunks",
10971315
)
1316+
convert_parser.add_argument(
1317+
"--dask-mode",
1318+
type=str,
1319+
default="threads",
1320+
choices=["threads", "processes", "single-threaded"],
1321+
help="Local Dask execution mode (default: threads).",
1322+
)
1323+
convert_parser.add_argument(
1324+
"--dask-workers",
1325+
type=int,
1326+
default=4,
1327+
help="Number of Dask workers (default: 4).",
1328+
)
1329+
convert_parser.add_argument(
1330+
"--dask-threads-per-worker",
1331+
type=int,
1332+
default=1,
1333+
help="Threads per worker (default: 1).",
1334+
)
1335+
convert_parser.add_argument(
1336+
"--dask-perf-html",
1337+
type=str,
1338+
help="Write a Dask performance report HTML to this path (local only).",
1339+
)
10981340
convert_parser.set_defaults(func=convert_command)
10991341

11001342
# Info command

0 commit comments

Comments
 (0)