Skip to content

Commit 18fdef1

Browse files
author
Donglai Wei
committed
Move runtime dispatch out of main script
1 parent e7fc206 commit 18fdef1

3 files changed

Lines changed: 379 additions & 370 deletions

File tree

connectomics/runtime/dispatch.py

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
"""Runtime mode dispatch for the command-line entry point."""
2+
3+
from __future__ import annotations
4+
5+
import os
6+
import sys
7+
from pathlib import Path
8+
from typing import Any
9+
10+
import torch
11+
12+
from ..config import Config
13+
from ..training.lightning import (
14+
ConnectomicsModule,
15+
cleanup_run_directory,
16+
create_datamodule,
17+
create_trainer,
18+
modify_checkpoint_state,
19+
setup_seed_everything,
20+
)
21+
from .cache_resolver import (
22+
create_decode_only_datamodule,
23+
handle_test_cache_hit,
24+
has_cached_predictions_in_output_dir,
25+
has_tta_prediction_file,
26+
preflight_test_cache_hit,
27+
try_cache_only_test_execution,
28+
)
29+
from .checkpoint_dispatch import setup_runtime_directories
30+
from .output_naming import resolve_prediction_cache_suffix
31+
from .sharding import (
32+
has_assigned_test_shard,
33+
maybe_enable_independent_test_sharding,
34+
maybe_limit_test_devices,
35+
resolve_test_stage_runtime,
36+
shard_test_datamodule,
37+
)
38+
39+
# Handle opened on os.devnull by suppress_nonzero_rank_stdout(); it is stored at
# module level so the stream installed as sys.stdout stays referenced.
_RANK_STDOUT_REDIRECT = None
40+
# Obtained once at import time from setup_seed_everything(); dispatch_runtime()
# later calls it as seed_everything(cfg.system.seed, workers=True).
seed_everything = setup_seed_everything()
41+
42+
43+
def suppress_nonzero_rank_stdout() -> None:
    """Silence duplicate stdout from non-zero DDP subprocesses.

    Reads the ``LOCAL_RANK`` environment variable. When it is unset or equal
    to ``"0"`` nothing is changed. Otherwise ``sys.stdout`` is replaced by a
    write handle on ``os.devnull``; the handle is stored in the module-level
    ``_RANK_STDOUT_REDIRECT`` so it stays referenced while installed.

    The call is idempotent: a handle left open by an earlier call is reused
    rather than opening a new one and orphaning the previous handle.
    """
    global _RANK_STDOUT_REDIRECT
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is None or local_rank == "0":
        return
    # Only open a new sink when there is no usable one from a previous call.
    if _RANK_STDOUT_REDIRECT is None or _RANK_STDOUT_REDIRECT.closed:
        _RANK_STDOUT_REDIRECT = open(os.devnull, "w")
    sys.stdout = _RANK_STDOUT_REDIRECT
51+
52+
53+
def prepare_cli_args(args: Any, repo_root: Path) -> None:
    """Fill in CLI-only defaults on ``args`` before the config is resolved.

    With ``args.demo`` set, the demo config ``tutorials/minimal.yaml`` under
    ``repo_root`` must exist (otherwise the process exits with status 1).
    It becomes ``args.config`` when no config was given, a ``fast_dev_run``
    of 0 is raised to 1, and the mode is forced to ``"train"``.

    Afterwards a missing ``args.config`` prints usage help and exits with
    status 1. ``args`` is mutated in place; nothing is returned.
    """
    if args.demo:
        demo_yaml = repo_root / "tutorials" / "minimal.yaml"
        if not demo_yaml.exists():
            print(f"Error: Demo config not found: {demo_yaml}")
            sys.exit(1)
        # An explicitly supplied config wins over the bundled demo config.
        args.config = args.config or str(demo_yaml)
        if args.fast_dev_run == 0:
            args.fast_dev_run = 1
        args.mode = "train"
        print(f"Demo mode: using minimal config {args.config}")

    if args.config:
        return

    usage_lines = (
        "Error: --config is required (or use --demo for a quick test)",
        "\nUsage:",
        " python scripts/main.py --config tutorials/mito_lucchi++.yaml",
        " python scripts/main.py --demo",
    )
    for line in usage_lines:
        print(line)
    sys.exit(1)
74+
75+
76+
def configure_matmul_precision(cfg: Config) -> None:
    """Switch float32 matmul precision to ``"medium"`` on Tensor Core GPUs.

    Does nothing when ``cfg.system.num_gpus`` is not positive or CUDA is not
    available. Otherwise the first ``min(num_gpus, visible devices)`` CUDA
    devices are probed, and a compute-capability major version of 7 or
    higher on any of them is treated as Tensor Core support. Any exception
    raised while probing or configuring is reported as a warning instead of
    being propagated.
    """
    num_requested = cfg.system.num_gpus
    if not (num_requested > 0 and torch.cuda.is_available()):
        return

    try:
        probe_count = min(num_requested, torch.cuda.device_count())
        # any() stops at the first capable device, like the early break would.
        tensor_cores_found = any(
            torch.cuda.get_device_capability(device_idx)[0] >= 7
            for device_idx in range(probe_count)
        )
        if tensor_cores_found:
            torch.set_float32_matmul_precision("medium")
            print("Enabled float32 matmul precision='medium' (Tensor Cores detected)")
    except Exception as err:
        print(f"WARNING: Could not configure float32 matmul precision automatically: {err}")
98+
99+
100+
def _create_runtime_model(
    args: Any,
    cfg: Config,
    run_dir: Path,
    *,
    has_saved_prediction: bool,
    saved_prediction_path: str,
    tta_cached: bool,
) -> tuple[ConnectomicsModule, str | None]:
    """Build the Lightning module for this run and pick the checkpoint path.

    Four cases, checked in this order:

    * ``has_saved_prediction``: wrap ``torch.nn.Identity`` with
      ``skip_loss=True`` and mark the module to skip inference.
    * ``tta_cached``: wrap ``torch.nn.Identity`` and mark the module to skip
      inference.
    * ``args.external_prefix`` is truthy: build the configured model; no
      checkpoint path is returned.
    * otherwise: build the configured model, print its trainable parameter
      count and obtain the checkpoint path from ``modify_checkpoint_state``.

    Returns:
        ``(model, ckpt_path)``, where ``ckpt_path`` is ``None`` in every case
        except the last one.
    """
    # Only the plain "build and resume" path produces a checkpoint path.
    resume_path = None

    if has_saved_prediction:
        print(f" Decode-only mode: loading predictions from {saved_prediction_path}")
        print(" Skipping model build entirely.")
        model = ConnectomicsModule(cfg, model=torch.nn.Identity(), skip_loss=True)
        model._skip_inference = True
    elif tta_cached:
        print(
            f" Cached intermediate predictions found; "
            f"creating lightweight module (skipping {cfg.model.arch.type} build)."
        )
        model = ConnectomicsModule(cfg, model=torch.nn.Identity())
        model._skip_inference = True
    else:
        uses_external_weights = bool(args.external_prefix)
        print(f"Creating model: {cfg.model.arch.type}")
        model = ConnectomicsModule(cfg)
        if uses_external_weights:
            print(
                " WARNING: External weights loaded - checkpoint path will not "
                "be used for training/testing"
            )
        else:
            trainable_count = sum(
                param.numel() for param in model.parameters() if param.requires_grad
            )
            print(f" Model parameters: {trainable_count:,}")
            resume_path = modify_checkpoint_state(
                args.checkpoint,
                run_dir,
                reset_optimizer=args.reset_optimizer,
                reset_scheduler=args.reset_scheduler,
                reset_epoch=args.reset_epoch,
                reset_early_stopping=args.reset_early_stopping,
            )

    # Record where the weights came from: the CLI checkpoint if given,
    # otherwise the config's external weights path (or None).
    weights_source = args.checkpoint
    if not weights_source:
        model_cfg = getattr(cfg, "model", None)
        weights_source = getattr(model_cfg, "external_weights_path", None)
    model._prediction_checkpoint_path = weights_source

    return model, resume_path
151+
152+
153+
def _run_training(
    args: Any, cfg: Config, model: ConnectomicsModule, trainer: Any, ckpt_path
) -> None:
    """Create the datamodule for ``args.mode`` and run ``trainer.fit``.

    ``ckpt_path`` is passed to ``trainer.fit`` unchanged. A banner is printed
    before fitting and a success line after it returns.
    """
    quick_run = bool(args.fast_dev_run)
    datamodule = create_datamodule(cfg, mode=args.mode, fast_dev_run=quick_run)

    rule = "=" * 60
    print("\n" + rule)
    print("STARTING TRAINING")
    print(rule)

    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    print("\n[OK]Training completed successfully!")
167+
168+
169+
def _run_test(
    args: Any,
    cfg: Config,
    model: ConnectomicsModule,
    trainer: Any,
    run_dir: Path,
    ckpt_path,
    *,
    has_saved_prediction: bool,
    saved_prediction_path: str,
) -> None:
    """Run the test stage for the ``test`` and ``tune-test`` modes.

    Steps, in order:

    1. Resolve the test-stage config and refresh the prediction cache suffix.
    2. Rebuild the trainer if independent test sharding gets enabled, and
       return early when this process has no assigned shard.
    3. Build the datamodule (decode-only when a saved prediction is given),
       shard it when both ``args.shard_id`` and ``args.num_shards`` are set,
       and rebuild the trainer if the device count gets limited.
    4. In ``tune-test`` mode, load the best tuned parameters into the config
       and refresh the cache suffix again.
    5. Run the cache preflight; on a hit the handler may ask to skip the test
       loop or supply a different checkpoint path.
    6. Call ``trainer.test``.
    """

    def _rebuild_test_trainer(active_cfg):
        # Same trainer construction wherever the runtime layout changes.
        return create_trainer(
            active_cfg,
            run_dir=run_dir,
            fast_dev_run=args.fast_dev_run,
            ckpt_path=ckpt_path,
            mode="test",
        )

    def _refresh_cache_suffix(active_cfg) -> None:
        # Recompute the suffix against whichever config object is current.
        active_cfg.inference.save_prediction.cache_suffix = resolve_prediction_cache_suffix(
            active_cfg,
            args.mode,
            checkpoint_path=args.checkpoint,
        )

    rule = "=" * 60
    print("\n" + rule)
    print("RUNNING TEST")
    print(rule)

    cfg = resolve_test_stage_runtime(cfg)
    _refresh_cache_suffix(cfg)

    if maybe_enable_independent_test_sharding(args, cfg):
        trainer = _rebuild_test_trainer(cfg)
    if not has_assigned_test_shard(cfg, args):
        return

    datamodule = (
        create_decode_only_datamodule(cfg, saved_prediction_path)
        if has_saved_prediction
        else create_datamodule(cfg, mode="test")
    )

    explicit_shard = args.shard_id is not None and args.num_shards is not None
    if explicit_shard:
        datamodule = shard_test_datamodule(datamodule, args.shard_id, args.num_shards)

    if maybe_limit_test_devices(cfg, datamodule):
        trainer = _rebuild_test_trainer(cfg)

    if args.mode == "tune-test":
        from .tune_runner import load_and_apply_best_params

        wide_rule = "=" * 80
        print("\n" + wide_rule)
        print("LOADING BEST PARAMETERS")
        print(wide_rule)

        cfg = load_and_apply_best_params(cfg, checkpoint_path=args.checkpoint)
        _refresh_cache_suffix(cfg)

    test_ckpt_path = ckpt_path
    cache_hit, cached_suffix, cache_count = preflight_test_cache_hit(
        cfg,
        datamodule,
        checkpoint_path=args.checkpoint,
    )
    if cache_hit:
        skip_test_loop, test_ckpt_path = handle_test_cache_hit(
            args,
            cfg,
            cached_suffix,
            cache_count,
            ckpt_path,
        )
        if skip_test_loop:
            return

    trainer.test(model, datamodule, ckpt_path=test_ckpt_path)
255+
256+
257+
def dispatch_runtime(args: Any, cfg: Config) -> None:
    """Dispatch the configured runtime mode.

    Recognised values of ``args.mode`` are ``"train"``, ``"test"``,
    ``"tune"`` and ``"tune-test"``. The function configures matmul
    precision, sets the prediction cache suffix for the non-train modes,
    runs the preflight check for training, prepares the run directories and
    seeds the RNGs when ``cfg.system.seed`` is set.

    It returns early when a test run has no assigned shard or when the
    cache-only test path reports completion. Otherwise it builds the model
    and trainer and runs training, tuning and/or testing.

    Any exception raised while running is printed with its traceback and
    the process exits with status 1. In ``train`` mode
    ``cleanup_run_directory(output_base)`` always runs afterwards, including
    on that failure path.
    """
    configure_matmul_precision(cfg)

    # args.mode is re-read each time rather than cached, in case a helper
    # that receives ``args`` changes it.
    inference_modes = ("test", "tune", "tune-test")

    if args.mode in inference_modes:
        cfg.inference.save_prediction.cache_suffix = resolve_prediction_cache_suffix(cfg, args.mode)

    if args.mode == "train":
        from . import preflight_check, print_preflight_issues

        issues = preflight_check(cfg)
        if issues:
            print_preflight_issues(issues)

    run_dir, output_base = setup_runtime_directories(args, cfg)

    seed = cfg.system.seed
    if seed is not None:
        print(f"Random seed set to: {seed}")
        seed_everything(seed, workers=True)

    if args.mode == "test":
        maybe_enable_independent_test_sharding(args, cfg)
        if not has_assigned_test_shard(cfg, args):
            return

    served_from_cache = try_cache_only_test_execution(
        cfg,
        args.mode,
        args.shard_id,
        args.num_shards,
        checkpoint_path=args.checkpoint,
    )
    if served_from_cache:
        return

    decoding_cfg = getattr(cfg, "decoding", None)
    saved_prediction_path = getattr(decoding_cfg, "input_prediction_path", "")
    # Only a non-blank string counts as a saved prediction path.
    has_saved_prediction = isinstance(saved_prediction_path, str) and bool(
        saved_prediction_path.strip()
    )

    tta_cached = False
    if args.mode in inference_modes:
        tta_cached = (
            has_saved_prediction
            or has_tta_prediction_file(cfg)
            or has_cached_predictions_in_output_dir(
                cfg,
                mode=args.mode,
                checkpoint_path=args.checkpoint,
            )
        )

    model, ckpt_path = _create_runtime_model(
        args,
        cfg,
        run_dir,
        has_saved_prediction=has_saved_prediction,
        saved_prediction_path=saved_prediction_path,
        tta_cached=tta_cached,
    )

    trainer = create_trainer(
        cfg,
        run_dir=run_dir,
        fast_dev_run=args.fast_dev_run,
        ckpt_path=ckpt_path,
        mode=args.mode,
    )

    try:
        if args.mode == "train":
            _run_training(args, cfg, model, trainer, ckpt_path)

        if args.mode in ("tune", "tune-test"):
            from .tune_runner import run_tuning

            run_tuning(model, trainer, cfg, checkpoint_path=ckpt_path)

        # "tune-test" matches both this check and the one above.
        if args.mode in ("tune-test", "test"):
            _run_test(
                args,
                cfg,
                model,
                trainer,
                run_dir,
                ckpt_path,
                has_saved_prediction=has_saved_prediction,
                saved_prediction_path=saved_prediction_path,
            )

    except Exception as err:
        failure_label = "Operation"
        if args.mode:
            failure_label = args.mode.capitalize()
        print(f"\n{failure_label} failed: {err}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
    finally:
        if args.mode == "train":
            cleanup_run_directory(output_base)
355+
356+
357+
# Names exported by ``from <this module> import *``; the underscore-prefixed
# helpers defined in this module are left out.
__all__ = [
    "configure_matmul_precision",
    "dispatch_runtime",
    "prepare_cli_args",
    "suppress_nonzero_rank_stdout",
]

0 commit comments

Comments
 (0)