Skip to content

Commit 0400e64

Browse files
author
Donglai Wei
committed
Fix demo config num_cpus error
1 parent f266946 commit 0400e64

3 files changed

Lines changed: 26 additions & 22 deletions

File tree

connectomics/training/lit/utils.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,9 @@ def setup_config(args) -> Config:
201201

202202
# Override config for fast-dev-run mode
203203
if args.fast_dev_run:
204+
fast_dev_num_gpus = 1 if torch.cuda.is_available() else 0
204205
print("🔧 Fast-dev-run mode: Overriding config for debugging")
205-
print(f"   - num_gpus: {cfg.system.training.num_gpus} → 1")
206+
print(f"   - num_gpus: {cfg.system.training.num_gpus} → {fast_dev_num_gpus}")
206207
print(
207208
f" - num_workers: {cfg.system.training.num_workers} → 0 "
208209
"(avoid multiprocessing in debug mode)"
@@ -212,9 +213,9 @@ def setup_config(args) -> Config:
212213
)
213214
print(" - input patch: 64^3 for lightweight debug")
214215
print(" - MedNeXt size: S for lightweight debug")
215-
cfg.system.training.num_gpus = 1
216+
cfg.system.training.num_gpus = fast_dev_num_gpus
216217
cfg.system.training.num_workers = 0
217-
cfg.system.inference.num_gpus = 1
218+
cfg.system.inference.num_gpus = fast_dev_num_gpus
218219
cfg.system.inference.num_workers = 0
219220
if hasattr(cfg.model, "input_size"):
220221
cfg.model.input_size = [64, 64, 64]
@@ -230,19 +231,6 @@ def setup_config(args) -> Config:
230231
# Resolve -1 sentinels (auto-max resources for current runtime allocation).
231232
cfg = resolve_runtime_resource_sentinels(cfg, print_results=True)
232233

233-
# CPU-only fallback: avoid multiprocessing workers when no CUDA is available
234-
if not torch.cuda.is_available():
235-
if cfg.system.training.num_workers > 0:
236-
print(
237-
"🔧 CUDA not available, setting training num_workers=0 to avoid dataloader crashes"
238-
)
239-
cfg.system.training.num_workers = 0
240-
if cfg.system.inference.num_workers > 0:
241-
print(
242-
"🔧 CUDA not available, setting inference num_workers=0 to avoid dataloader crashes"
243-
)
244-
cfg.system.inference.num_workers = 0
245-
246234
# Apply inference-specific overrides if in test/tune mode
247235
if args.mode in ["test", "tune", "tune-test"]:
248236
if cfg.inference.num_gpus >= 0:
@@ -255,6 +243,25 @@ def setup_config(args) -> Config:
255243
print(f"🔧 Inference override: num_workers={cfg.inference.num_workers}")
256244
cfg.system.inference.num_workers = cfg.inference.num_workers
257245

246+
# CPU-only fallback after all overrides: ensure no CUDA-only settings remain.
247+
if not torch.cuda.is_available():
248+
if cfg.system.training.num_gpus > 0:
249+
print("🔧 CUDA not available, setting training num_gpus=0")
250+
cfg.system.training.num_gpus = 0
251+
if cfg.system.inference.num_gpus > 0:
252+
print("🔧 CUDA not available, setting inference num_gpus=0")
253+
cfg.system.inference.num_gpus = 0
254+
if cfg.system.training.num_workers > 0:
255+
print(
256+
"🔧 CUDA not available, setting training num_workers=0 to avoid dataloader crashes"
257+
)
258+
cfg.system.training.num_workers = 0
259+
if cfg.system.inference.num_workers > 0:
260+
print(
261+
"🔧 CUDA not available, setting inference num_workers=0 to avoid dataloader crashes"
262+
)
263+
cfg.system.inference.num_workers = 0
264+
258265
# Optional convenience toggle to enable nnU-Net preprocessing via CLI
259266
if getattr(args, "nnunet_preprocess", False):
260267
print("🔧 Enabling nnU-Net preprocessing from CLI flag")

scripts/demo.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,11 @@ def create_demo_config():
111111
seed=42,
112112
training=SystemTrainingConfig(
113113
num_gpus=1 if torch.cuda.is_available() else 0,
114-
num_cpus=2,
115114
batch_size=2,
116115
num_workers=0, # 0 for demo to avoid multiprocessing issues
117116
),
118117
inference=SystemInferenceConfig(
119118
num_gpus=1 if torch.cuda.is_available() else 0,
120-
num_cpus=2,
121119
batch_size=2,
122120
num_workers=0,
123121
),
@@ -142,7 +140,8 @@ def create_demo_config():
142140
stride=[16, 32, 32],
143141
iter_num_per_epoch=10, # Just 10 iterations per epoch
144142
use_cache=False,
145-
use_preloaded_cache=False,
143+
use_preloaded_cache_train=False,
144+
use_preloaded_cache_val=False,
146145
pin_memory=False,
147146
persistent_workers=False,
148147
),
@@ -187,7 +186,6 @@ def create_demo_config():
187186
),
188187
inference=InferenceConfig(
189188
num_gpus=-1,
190-
num_cpus=-1,
191189
batch_size=-1,
192190
num_workers=-1,
193191
),
@@ -403,4 +401,3 @@ def run_demo():
403401

404402
if __name__ == "__main__":
405403
run_demo()
406-

scripts/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Testing mode
1616
python scripts/main.py --config tutorials/mito_lucchi++.yaml --mode test --checkpoint path/to/checkpoint.ckpt
1717
18-
# Fast dev run (1 batch for debugging, auto-sets num_gpus=1, num_cpus=1, num_workers=1)
18+
# Fast dev run (1 batch for debugging, auto-sets num_workers=0 and uses GPU only if CUDA is available)
1919
python scripts/main.py --config tutorials/mito_lucchi++.yaml --fast-dev-run
2020
python scripts/main.py --config tutorials/mito_lucchi++.yaml --fast-dev-run 2 # Run 2 batches
2121

0 commit comments

Comments
 (0)