99import sys
1010from pathlib import Path
1111from typing import Any , Optional
12+ from time import perf_counter
13+ from contextlib import nullcontext
14+ from datetime import datetime
1215
1316import xarray as xr
1417
2124)
2225
2326
24- def setup_dask_cluster (enable_dask : bool , verbose : bool = False ) -> Optional [Any ]:
27+ def setup_dask_cluster (
28+ enable_dask : bool ,
29+ verbose : bool = False ,
30+ mode : str = "threads" , # threads | processes | single-threaded
31+ n_workers : int = 4 ,
32+ threads_per_worker : int = 1 ,
33+ ) -> Optional [Any ]:
2534 """
2635 Set up a dask cluster for parallel processing.
2736
@@ -41,18 +50,34 @@ def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any
4150 return None
4251
4352 try :
44- from dask .distributed import Client
45-
46- # Set up local cluster
47- client = Client () # set up local cluster
53+ from dask .distributed import Client , LocalCluster
54+
55+ if mode not in {"threads" , "processes" , "single-threaded" }:
56+ raise ValueError (f"Unsupported --dask-mode: { mode } " )
57+
58+ processes = (mode == "processes" )
59+ # For single-threaded, use one worker, one thread, processes=False
60+ if mode == "single-threaded" :
61+ n_workers = 1
62+ threads_per_worker = 1
63+ processes = False
64+
65+ cluster = LocalCluster (
66+ n_workers = n_workers ,
67+ threads_per_worker = threads_per_worker ,
68+ processes = processes ,
69+ )
70+ client = Client (cluster )
4871
4972 if verbose :
5073 print (f"🚀 Dask cluster started: { client } " )
5174 print (f" Dashboard: { client .dashboard_link } " )
5275 print (f" Workers: { len (client .scheduler_info ()['workers' ])} " )
5376 else :
54- print ("🚀 Dask cluster started for parallel processing" )
55-
77+ print (
78+ f"🚀 Dask cluster started "
79+ f"({ mode } ; workers={ n_workers } , threads/worker={ threads_per_worker } )"
80+ )
5681 return client
5782
5883 except ImportError :
@@ -65,6 +90,79 @@ def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any
6590 sys .exit (1 )
6691
6792
def _duration_from_startstops(ev: dict, action: str) -> float:
    """Sum the seconds an event spent in ``action`` per its 'startstops' list.

    Best-effort: entries that cannot be interpreted are skipped instead of
    raising, so one odd record cannot abort the whole metrics aggregation.

    Args:
        ev: A single task_stream event. Its optional ``"startstops"`` value is
            an iterable of phase records (may be missing or ``None``).
        action: Phase name to accumulate, e.g. ``"compute"`` or ``"transfer"``.

    Returns:
        Total duration of all matching records; ``0.0`` when nothing matches.
        Intervals whose stop precedes their start contribute ``0.0``.
    """
    total = 0.0
    for entry in ev.get("startstops") or ():
        if isinstance(entry, dict):
            name = entry.get("action")
            start = entry.get("start")
            stop = entry.get("stop")
        elif isinstance(entry, (tuple, list)) and len(entry) == 3:
            # NOTE(review): older distributed releases appear to have emitted
            # (action, start, stop) tuples instead of dicts -- confirm against
            # the oldest Dask version this tool must support.
            name, start, stop = entry
        else:
            continue

        if name != action or start is None or stop is None:
            continue
        try:
            total += max(0.0, float(stop) - float(start))
        except (TypeError, ValueError):
            # Non-numeric timestamps: ignore this record, keep the rest.
            continue
    return total
102+
103+
def _parse_task_events(task_events: Optional[list]) -> dict:
    """Aggregate task_stream events into total compute and transfer seconds.

    For each event and each phase (``compute``, ``transfer``) the duration is
    resolved in this order:

    1. an explicit ``<phase>_duration`` value;
    2. ``<phase>_stop - <phase>_start`` (for compute, the generic ``start`` /
       ``stop`` keys are accepted as a fallback), clamped at zero;
    3. the sum of matching ``startstops`` records.

    Presence is tested with ``is not None`` so that a timestamp of ``0`` is
    treated as a real value rather than as "missing".

    Args:
        task_events: Event dicts collected from the task stream, or ``None``.

    Returns:
        A dict with ``tasks_observed``, ``compute_time_s_sum`` and
        ``transfer_time_s_sum``. All three are zero for empty/``None`` input.
    """
    if not task_events:
        return {"tasks_observed": 0, "compute_time_s_sum": 0.0, "transfer_time_s_sum": 0.0}

    def _first_present(event: dict, keys: tuple) -> Any:
        """Return the first non-None value among ``keys``, else None."""
        for key in keys:
            value = event.get(key)
            if value is not None:
                return value
        return None

    def _phase_seconds(event: dict, phase: str, start_keys: tuple, stop_keys: tuple) -> float:
        """Resolve one phase duration for one event using the order above."""
        duration = event.get(f"{phase}_duration")
        if duration is None:
            begin = _first_present(event, start_keys)
            end = _first_present(event, stop_keys)
            if begin is not None and end is not None:
                duration = max(0.0, float(end) - float(begin))
            else:
                duration = _duration_from_startstops(event, phase)
        return float(duration or 0.0)

    compute_total = 0.0
    transfer_total = 0.0
    for event in task_events:
        compute_total += _phase_seconds(
            event, "compute", ("compute_start", "start"), ("compute_stop", "stop")
        )
        transfer_total += _phase_seconds(
            event, "transfer", ("transfer_start",), ("transfer_stop",)
        )

    return {
        "tasks_observed": len(task_events),
        "compute_time_s_sum": compute_total,
        "transfer_time_s_sum": transfer_total,
    }
136+
137+
def _summarize_dask_metrics(client: Any, task_events: Optional[list], wall_clock_s: float) -> dict:
    """Build a flat metrics dict from scheduler info and captured task events.

    Worker figures are read from ``client.scheduler_info()["workers"]`` and
    summed across workers. Values that are missing, ``None`` or not
    convertible to ``int`` count as ``0`` so that a single unexpected field
    does not discard the entire summary.

    Args:
        client: Object exposing ``scheduler_info()`` and, optionally, a
            ``dashboard_link`` attribute.
        task_events: Task-stream events passed to ``_parse_task_events``;
            may be ``None``.
        wall_clock_s: Elapsed wall-clock seconds for the measured run.

    Returns:
        Dict with ``wall_clock_s``, ``workers``, ``threads_total``,
        ``tasks_observed``, ``tasks_per_sec``, ``compute_time_s_sum``,
        ``transfer_time_s_sum``, ``memory_used_bytes`` and
        ``memory_limit_bytes``; plus ``spilled_nbytes`` and
        ``dashboard_link`` only when they are non-empty.
    """

    def _as_int(value: Any) -> int:
        """Coerce a metric to int, treating unusable values as 0."""
        try:
            return int(value or 0)
        except (TypeError, ValueError):
            # NOTE(review): the shape of worker metric values may differ
            # between distributed releases -- confirm for the pinned version.
            return 0

    info = client.scheduler_info() or {}
    workers = info.get("workers") or {}

    total_threads = 0
    memory_limit = 0
    memory_used = 0
    spilled = 0
    for worker in workers.values():
        metrics = worker.get("metrics") or {}
        total_threads += _as_int(worker.get("nthreads"))
        memory_limit += _as_int(worker.get("memory_limit"))
        memory_used += _as_int(metrics.get("memory"))
        spilled += _as_int(metrics.get("spilled_nbytes"))

    parsed = _parse_task_events(task_events)
    tasks = parsed["tasks_observed"]
    summary = {
        "wall_clock_s": wall_clock_s,
        "workers": len(workers),
        "threads_total": total_threads,
        "tasks_observed": tasks,
        # Guard against division by zero / a missing wall-clock value.
        "tasks_per_sec": (tasks / wall_clock_s) if tasks and wall_clock_s else 0.0,
        "compute_time_s_sum": parsed["compute_time_s_sum"],
        "transfer_time_s_sum": parsed["transfer_time_s_sum"],
        "memory_used_bytes": memory_used,
        "memory_limit_bytes": memory_limit,
    }
    if spilled:
        summary["spilled_nbytes"] = spilled

    dashboard = getattr(client, "dashboard_link", None)
    if dashboard:
        summary["dashboard_link"] = dashboard
    return summary
164+
165+
68166def convert_command (args : argparse .Namespace ) -> None :
69167 """
70168 Convert EOPF dataset to GeoZarr compliant format.
@@ -76,9 +174,19 @@ def convert_command(args: argparse.Namespace) -> None:
76174 """
77175 # Set up dask cluster if requested
78176 dask_client = setup_dask_cluster (
79- enable_dask = getattr (args , "dask_cluster" , False ), verbose = args .verbose
177+ enable_dask = getattr (args , "dask_cluster" , False ),
178+ verbose = args .verbose ,
179+ mode = getattr (args , "dask_mode" , "threads" ),
180+ n_workers = getattr (args , "dask_workers" , 4 ),
181+ threads_per_worker = getattr (args , "dask_threads_per_worker" , 1 ),
80182 )
81183
184+ # Prepare for metrics writing even on failures (local outputs only)
185+ debug_dir : Optional [Path ] = None
186+ run_id = datetime .now ().strftime ("%Y%m%d-%H%M%S" )
187+ status = "unknown"
188+ error_msg = None
189+
82190 try :
83191 # Validate input path (handle both local paths and URLs)
84192 input_path_str = args .input_path
@@ -129,6 +237,12 @@ def convert_command(args: argparse.Namespace) -> None:
129237 output_path = Path (output_path_str )
130238 output_path .parent .mkdir (parents = True , exist_ok = True )
131239 output_path = str (output_path )
240+ # Prepare debug dir for metrics
241+ debug_dir = Path (output_path ) / "debug"
242+ try :
243+ debug_dir .mkdir (parents = True , exist_ok = True )
244+ except Exception :
245+ debug_dir = None
132246
133247 if args .verbose :
134248 print (f"Loading EOPF dataset from: { input_path } " )
@@ -157,19 +271,84 @@ def convert_command(args: argparse.Namespace) -> None:
157271
158272 # Convert to GeoZarr compliant format
159273 print ("Converting to GeoZarr compliant format..." )
160- dt_geozarr = create_geozarr_dataset (
161- dt_input = dt ,
162- groups = args .groups ,
163- output_path = output_path ,
164- spatial_chunk = args .spatial_chunk ,
165- min_dimension = args .min_dimension ,
166- tile_width = args .tile_width ,
167- max_retries = args .max_retries ,
168- crs_groups = args .crs_groups ,
169- )
274+ t0 = perf_counter ()
275+ task_events = None
276+ perf_ctx = nullcontext ()
277+ if dask_client is not None and getattr (args , "dask_perf_html" , None ):
278+ from dask .distributed import performance_report
279+ perf_path = args .dask_perf_html
280+ if not (perf_path .startswith (("s3://" , "gs://" , "http://" , "https://" ))):
281+ Path (perf_path ).parent .mkdir (parents = True , exist_ok = True )
282+ perf_ctx = performance_report (filename = perf_path )
283+
284+ capture_tasks = False
285+ if dask_client is not None :
286+ try :
287+ from dask .distributed import get_task_stream
288+ capture_tasks = True
289+ except Exception :
290+ pass
291+
292+ if capture_tasks :
293+ with perf_ctx :
294+ from dask .distributed import get_task_stream
295+ with get_task_stream (client = dask_client , plot = False ) as ts :
296+ dt_geozarr = create_geozarr_dataset (
297+ dt_input = dt ,
298+ groups = args .groups ,
299+ output_path = output_path ,
300+ spatial_chunk = args .spatial_chunk ,
301+ min_dimension = args .min_dimension ,
302+ tile_width = args .tile_width ,
303+ max_retries = args .max_retries ,
304+ crs_groups = args .crs_groups ,
305+ )
306+ task_events = ts .data
307+ else :
308+ with perf_ctx :
309+ dt_geozarr = create_geozarr_dataset (
310+ dt_input = dt ,
311+ groups = args .groups ,
312+ output_path = output_path ,
313+ spatial_chunk = args .spatial_chunk ,
314+ min_dimension = args .min_dimension ,
315+ tile_width = args .tile_width ,
316+ max_retries = args .max_retries ,
317+ crs_groups = args .crs_groups ,
318+ )
319+
320+ wall_clock = perf_counter () - t0
321+ if dask_client is not None and getattr (args , "verbose" , False ) and getattr (args , "dask_perf_html" , None ):
322+ print (f"📊 Dask performance report: { args .dask_perf_html } " )
170323
171324 print ("✅ Successfully converted EOPF dataset to GeoZarr format" )
172325 print (f"Output saved to: { output_path } " )
326+ status = "ok"
327+
328+ # Drop a JSON run summary (local paths only)
329+ try :
330+ if debug_dir is not None :
331+ summary_path = Path (output_path ) / "debug" / "dask_run_summary.json"
332+ summary_path .parent .mkdir (parents = True , exist_ok = True )
333+ from json import dumps as _dumps
334+ summary = {
335+ "dask_enabled" : bool (dask_client is not None ),
336+ "mode" : getattr (args , "dask_mode" , None ),
337+ "workers" : getattr (args , "dask_workers" , None ),
338+ "threads_per_worker" : getattr (args , "dask_threads_per_worker" , None ),
339+ "perf_report" : getattr (args , "dask_perf_html" , None ),
340+ "wall_clock_s" : wall_clock if dask_client is not None else None ,
341+ "groups" : args .groups ,
342+ "spatial_chunk" : args .spatial_chunk ,
343+ "min_dimension" : args .min_dimension ,
344+ "tile_width" : args .tile_width ,
345+ }
346+ summary_path .write_text (_dumps (summary , indent = 2 ))
347+ if args .verbose :
348+ print (f"🧾 Wrote run summary: { summary_path } " )
349+ except Exception as _exc :
350+ if args .verbose :
351+ print (f"(debug) could not write run summary: { _exc } " )
173352
174353 if args .verbose :
175354 # Check if dt_geozarr is a DataTree or Dataset
@@ -183,6 +362,7 @@ def convert_command(args: argparse.Namespace) -> None:
183362
184363 except Exception as e :
185364 print (f"❌ Error during conversion: { e } " )
365+ error_msg = str (e )
186366 if args .verbose :
187367 import traceback
188368
@@ -200,6 +380,46 @@ def convert_command(args: argparse.Namespace) -> None:
200380 if args .verbose :
201381 print (f"Warning: Error closing dask cluster: { e } " )
202382
383+ # Best-effort metrics write (works for success or failure, local only)
384+ try :
385+ if debug_dir is not None :
386+ attempt = 1
387+ try :
388+ attempt = 1 + len (list (debug_dir .glob ("dask_metrics_*.json" )))
389+ except Exception :
390+ pass
391+ wall = None
392+ if "t0" in locals ():
393+ wall = perf_counter () - t0
394+ metrics_doc = {
395+ "status" : status ,
396+ "run_id" : run_id ,
397+ "attempt" : attempt ,
398+ "dask_enabled" : bool (dask_client is not None ),
399+ "mode" : getattr (args , "dask_mode" , None ),
400+ }
401+ if dask_client is not None and wall is not None :
402+ try :
403+ metrics_doc .update (_summarize_dask_metrics (dask_client , locals ().get ("task_events" ), wall ))
404+ except Exception :
405+ metrics_doc ["wall_clock_s" ] = wall
406+ elif wall is not None :
407+ metrics_doc ["wall_clock_s" ] = wall
408+ if error_msg :
409+ metrics_doc ["error" ] = error_msg
410+
411+ path_ts = debug_dir / f"dask_metrics_{ run_id } _attempt{ attempt } .json"
412+ path_latest = debug_dir / "dask_metrics.json"
413+ from json import dumps as _dumps
414+ txt = _dumps (metrics_doc , indent = 2 )
415+ path_ts .write_text (txt )
416+ path_latest .write_text (txt )
417+ if getattr (args , "verbose" , False ):
418+ print (f"📈 Wrote metrics: { path_ts } " )
419+ except Exception as _mexc :
420+ if getattr (args , "verbose" , False ):
421+ print (f"(debug) could not write metrics snapshot: { _mexc } " )
422+
203423
204424def info_command (args : argparse .Namespace ) -> None :
205425 """
@@ -309,8 +529,6 @@ def format_data_vars(data_vars: Any) -> str:
309529 # Get xarray's HTML representation and extract just the variables section
310530 try :
311531 html_repr = temp_ds ._repr_html_ ()
312- # Extract the variables section from xarray's HTML
313- # This gives us the rich, interactive variable display
314532 return f'<div class="xarray-variables">{ html_repr } </div>'
315533 except Exception :
316534 # Fallback to simple format if xarray HTML fails
@@ -1095,6 +1313,30 @@ def create_parser() -> argparse.ArgumentParser:
10951313 action = "store_true" ,
10961314 help = "Start a local dask cluster for parallel processing of chunks" ,
10971315 )
1316+ convert_parser .add_argument (
1317+ "--dask-mode" ,
1318+ type = str ,
1319+ default = "threads" ,
1320+ choices = ["threads" , "processes" , "single-threaded" ],
1321+ help = "Local Dask execution mode (default: threads)." ,
1322+ )
1323+ convert_parser .add_argument (
1324+ "--dask-workers" ,
1325+ type = int ,
1326+ default = 4 ,
1327+ help = "Number of Dask workers (default: 4)." ,
1328+ )
1329+ convert_parser .add_argument (
1330+ "--dask-threads-per-worker" ,
1331+ type = int ,
1332+ default = 1 ,
1333+ help = "Threads per worker (default: 1)." ,
1334+ )
1335+ convert_parser .add_argument (
1336+ "--dask-perf-html" ,
1337+ type = str ,
1338+ help = "Write a Dask performance report HTML to this path (local only)." ,
1339+ )
10981340 convert_parser .set_defaults (func = convert_command )
10991341
11001342 # Info command
0 commit comments