Skip to content

Commit 3b066f0

Browse files
author
Donglai Wei
committed
fix pytorch DDP in just slurm
1 parent beae237 commit 3b066f0

32 files changed

Lines changed: 191 additions & 151 deletions

.claude/optuna/optuna_decoding_tuning.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ description: Automated optimization of decoding parameters using Optuna
2222
# ============================================================================
2323
system:
2424
num_gpus: 1
25-
num_cpus: 8
2625
seed: 42
2726

2827
# ============================================================================

connectomics/config/hydra_config.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,11 @@ class SystemTrainingConfig:
5454
5555
Attributes:
5656
num_gpus: Number of GPUs to use for training (0 for CPU-only)
57-
num_cpus: Number of CPU cores available for data loading
5857
num_workers: Number of parallel data loading workers
5958
batch_size: Training batch size (per GPU)
6059
"""
6160

6261
num_gpus: int = 1
63-
num_cpus: int = 4
6462
num_workers: int = 8
6563
batch_size: int = 4
6664

@@ -74,13 +72,11 @@ class SystemInferenceConfig:
7472
7573
Attributes:
7674
num_gpus: Number of GPUs to use for inference (0 for CPU-only)
77-
num_cpus: Number of CPU cores available for data loading
7875
num_workers: Number of parallel data loading workers
7976
batch_size: Inference batch size (usually 1 for large volumes)
8077
"""
8178

8279
num_gpus: int = 1
83-
num_cpus: int = 1
8480
num_workers: int = 1
8581
batch_size: int = 1
8682

@@ -422,7 +418,7 @@ class DataConfig:
422418

423419
# Data properties
424420
patch_size: List[int] = field(default_factory=lambda: [128, 128, 128])
425-
pad_size: List[int] = field(default_factory=lambda: [8, 32, 32])
421+
pad_size: List[int] = field(default_factory=lambda: [0, 0, 0])
426422
pad_mode: str = "reflect" # Padding mode: 'reflect', 'replicate', 'constant', 'edge'
427423
stride: List[int] = field(default_factory=lambda: [1, 1, 1]) # Sampling stride (z, y, x)
428424

@@ -476,6 +472,10 @@ class DataConfig:
476472
use_preloaded_cache: bool = (
477473
True # Preload volumes into memory for fast random cropping (default: True)
478474
)
475+
cached_sampling_max_attempts: int = 10 # Retry attempts for foreground-aware sampling
476+
cached_sampling_foreground_threshold: float = (
477+
0.0 # Minimum (label > 0) fraction required for training crops; 0 disables foreground sampling
478+
)
479479

480480
# Reject sampling configuration (for volumetric patch sampling)
481481
reject_sampling: Optional[Dict[str, Any]] = None # Dict with 'size_thres' and 'p' keys
@@ -572,7 +572,7 @@ class OptimizationConfig:
572572
benchmark: bool = True
573573

574574
# Validation and logging
575-
val_check_interval: Union[int, float] = 1.0
575+
val_check_interval: Union[int, float] = 1.0 # Validate every N epochs (legacy key name)
576576
log_every_n_steps: int = 50
577577
num_sanity_val_steps: int = 0
578578

@@ -1137,7 +1137,6 @@ class InferenceConfig:
11371137
# Inference-specific overrides (override system settings during inference)
11381138
# Use -1 to keep training values, or >= 0 to override
11391139
num_gpus: int = -1 # Override system.training.num_gpus if >= 0
1140-
num_cpus: int = -1 # Override system.training.num_cpus if >= 0
11411140
batch_size: int = -1 # Override system.training.batch_size if >= 0 (typically 1 for inference)
11421141
num_workers: int = -1 # Override system.training.num_workers if >= 0
11431142

connectomics/data/dataset/dataset_volume_cached.py

Lines changed: 79 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ def __init__(
162162
mode: str = "train",
163163
pad_size: Optional[Tuple[int, ...]] = None,
164164
pad_mode: str = "reflect",
165+
max_attempts: int = 10,
166+
foreground_threshold: float = 0.05,
165167
):
166168
self.image_paths = image_paths
167169
self.label_paths = label_paths if label_paths else [None] * len(image_paths)
@@ -194,6 +196,8 @@ def __init__(
194196
self._d2_rejected_patches = 0
195197
self._d2_foreground_fracs = []
196198
self._d2_last_report_step = 0
199+
self.max_attempts = max_attempts
200+
self.foreground_threshold = foreground_threshold
197201

198202
# Load all volumes into memory
199203
print(f" Loading {len(image_paths)} volumes into memory...")
@@ -271,13 +275,16 @@ def __init__(
271275
# Support both 2D and 3D: get last N dimensions matching patch_size
272276
ndim = len(self.patch_size)
273277
self.volume_sizes = [img.shape[-ndim:] for img in self.cached_images] # (Z, Y, X) or (Y, X)
274-
278+
275279
# [D2 DIAGNOSTIC] Print foreground sampling configuration
276280
if self.mode == "train":
277-
print(f" [D2] Foreground sampling ENABLED:")
278-
print(f" - Minimum foreground threshold: 5.0%")
279-
print(f" - Max retry attempts: 10")
280-
print(f" - Will report statistics every 100 batches")
281+
if self.foreground_threshold > 0:
282+
print(" [D2] Foreground sampling ENABLED:")
283+
print(f" - Minimum foreground threshold: {self.foreground_threshold * 100:.1f}%")
284+
print(f" - Max retry attempts: {self.max_attempts}")
285+
print(" - Will report statistics every 100 batches")
286+
else:
287+
print(" [D2] Foreground sampling DISABLED (threshold <= 0)")
281288

282289
def _apply_padding(
283290
self, volume: np.ndarray, mode: Optional[str] = None, constant_values: float = 0
@@ -366,15 +373,15 @@ def __len__(self) -> int:
366373
def set_epoch(self, epoch: int, base_seed: int = 0):
367374
"""
368375
Set current epoch for epoch-based validation reseeding.
369-
376+
370377
This method enables validation to sample different patches each epoch
371378
while maintaining determinism. For training, this has no effect since
372379
training already uses random sampling.
373-
380+
374381
Args:
375382
epoch: Current training epoch
376383
base_seed: Base random seed (typically from cfg.system.seed)
377-
384+
378385
Usage:
379386
Called by ValidationReseedingCallback at the start of each validation epoch:
380387
if hasattr(dataset, 'set_epoch'):
@@ -387,32 +394,36 @@ def set_epoch(self, epoch: int, base_seed: int = 0):
387394
self.current_epoch = epoch
388395
effective_seed = self.base_seed + epoch
389396
random.seed(effective_seed)
390-
397+
391398
# IMPORTANT: Print to verify reseeding is happening
392399
# This should appear in logs at the start of EACH validation epoch
393-
print(f"[Validation] Set epoch={epoch}, base_seed={base_seed}, effective_seed={effective_seed}")
394-
print(f"[Validation] Dataset: {type(self).__name__}@{id(self)}, mode={self.mode}, iter_num={self.iter_num}")
395-
400+
print(
401+
f"[Validation] Set epoch={epoch}, base_seed={base_seed}, effective_seed={effective_seed}"
402+
)
403+
print(
404+
f"[Validation] Dataset: {type(self).__name__}@{id(self)}, mode={self.mode}, iter_num={self.iter_num}"
405+
)
406+
396407
def get_sampling_fingerprint(self, num_samples: int = 5) -> str:
397408
"""
398409
Generate a deterministic fingerprint of validation sampling.
399-
410+
400411
This allows verification that validation patches change across epochs.
401412
The fingerprint is based on the first N random samples that would be
402413
generated with the current RNG state.
403-
414+
404415
Args:
405416
num_samples: Number of random samples to include in fingerprint
406-
417+
407418
Returns:
408419
String representing the sampling fingerprint
409420
"""
410421
if self.mode != "val":
411422
return "N/A (training mode)"
412-
423+
413424
# Save current RNG state
414425
state = random.getstate()
415-
426+
416427
try:
417428
# Generate deterministic samples
418429
samples = []
@@ -422,11 +433,11 @@ def get_sampling_fingerprint(self, num_samples: int = 5) -> str:
422433
# Sample patch position
423434
pos = self._get_random_crop_position(vol_idx)
424435
samples.append((vol_idx, pos))
425-
436+
426437
# Create fingerprint string
427438
fingerprint = ", ".join([f"v{v}@{p}" for v, p in samples])
428439
return fingerprint
429-
440+
430441
finally:
431442
# Restore RNG state
432443
random.setstate(state)
@@ -488,25 +499,46 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
488499
label = self.cached_labels[vol_idx]
489500
mask = self.cached_masks[vol_idx]
490501

491-
# [D2] Foreground-aware patch sampling: ensure patches contain sufficient mitochondria
492-
# This prevents SDT collapse by avoiding background-only patches
493-
max_attempts = 10
494-
foreground_threshold = 0.05 # Require at least 5% foreground (SDT > 0)
495-
496-
# [D2 DIAGNOSTIC] Track sampling attempts
497-
attempts_used = 0
502+
# [D2] Foreground-aware patch sampling: optional retry loop for training.
503+
# Disabled when foreground_threshold <= 0, which is the DataConfig default (0.0).
504+
max_attempts = self.max_attempts
505+
foreground_threshold = self.foreground_threshold
506+
use_foreground_sampling = (
507+
self.mode == "train" and label is not None and foreground_threshold > 0
508+
)
509+
510+
# [D2 DIAGNOSTIC] Track sampling attempts only when foreground sampling is active.
511+
attempts_used = 1
498512
final_foreground_frac = 0.0
499-
500-
for attempt in range(max_attempts):
501-
attempts_used = attempt + 1
502-
503-
# Get crop position
513+
514+
if use_foreground_sampling:
515+
for attempt in range(max_attempts):
516+
attempts_used = attempt + 1
517+
pos = self._get_random_crop_position(vol_idx)
518+
519+
# Crop using fast numpy slicing (like v1)
520+
image_crop = crop_volume(image, self.patch_size, pos)
521+
label_crop = crop_volume(label, self.patch_size, pos)
522+
if mask is not None:
523+
mask_crop = crop_volume(mask, self.patch_size, pos)
524+
else:
525+
mask_crop = np.zeros_like(image_crop)
526+
527+
foreground_frac = (label_crop > 0).sum() / label_crop.size
528+
final_foreground_frac = foreground_frac
529+
530+
if foreground_frac >= foreground_threshold:
531+
break
532+
533+
# [D2 DIAGNOSTIC] Patch rejected, increment counter
534+
self._d2_rejected_patches += 1
535+
else:
536+
# Standard single-crop behavior (no foreground-based retry)
504537
if self.mode == "train":
505538
pos = self._get_random_crop_position(vol_idx)
506539
else:
507540
pos = self._get_center_crop_position(vol_idx)
508541

509-
# Crop using fast numpy slicing (like v1)
510542
image_crop = crop_volume(image, self.patch_size, pos)
511543
if label is not None:
512544
label_crop = crop_volume(label, self.patch_size, pos)
@@ -518,39 +550,30 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
518550
else:
519551
mask_crop = np.zeros_like(image_crop)
520552

521-
# [D2] Check if patch has sufficient foreground (only during training)
522-
if self.mode == "train" and label is not None:
523-
foreground_frac = (label_crop > 0).sum() / label_crop.size
524-
final_foreground_frac = foreground_frac
525-
526-
if foreground_frac >= foreground_threshold:
527-
# [D2 DIAGNOSTIC] Good patch found
528-
break
529-
else:
530-
# [D2 DIAGNOSTIC] Patch rejected, increment counter
531-
self._d2_rejected_patches += 1
532-
else:
533-
# Val/test mode or no label: accept any patch
534-
break
535-
536-
# [D2 DIAGNOSTIC] Record statistics
537-
self._d2_total_samples += 1
538-
self._d2_total_attempts += attempts_used
539-
self._d2_foreground_fracs.append(final_foreground_frac * 100) # Convert to percentage
540-
553+
# [D2 DIAGNOSTIC] Record/report sampling stats only when enabled.
554+
if use_foreground_sampling:
555+
self._d2_total_samples += 1
556+
self._d2_total_attempts += attempts_used
557+
self._d2_foreground_fracs.append(final_foreground_frac * 100) # percentage
558+
541559
# [D2 DIAGNOSTIC] Print report every 100 samples (not too verbose)
542-
if self.mode == "train" and self._d2_total_samples % 100 == 0:
560+
if use_foreground_sampling and self._d2_total_samples % 100 == 0:
543561
avg_attempts = self._d2_total_attempts / self._d2_total_samples
544562
reject_rate = (self._d2_rejected_patches / self._d2_total_attempts) * 100
545563
avg_fg = sum(self._d2_foreground_fracs) / len(self._d2_foreground_fracs)
546564
min_fg = min(self._d2_foreground_fracs)
547565
max_fg = max(self._d2_foreground_fracs)
548-
566+
549567
print(f"[D2 Sampling Stats after {self._d2_total_samples} batches]")
550568
print(f" Avg attempts per patch: {avg_attempts:.2f}/{max_attempts}")
551-
print(f" Patches rejected: {self._d2_rejected_patches}/{self._d2_total_attempts} ({reject_rate:.1f}%)")
569+
print(
570+
f" Patches rejected: {self._d2_rejected_patches}/{self._d2_total_attempts} ({reject_rate:.1f}%)"
571+
)
552572
print(f" Final foreground %: avg={avg_fg:.1f}%, min={min_fg:.1f}%, max={max_fg:.1f}%")
553-
print(f" Threshold: {foreground_threshold*100:.1f}% (5% minimum)")
573+
print(
574+
f" Threshold: {foreground_threshold * 100:.1f}% "
575+
f"({self.foreground_threshold * 100:.1f}% minimum)"
576+
)
554577

555578
# Create data dict
556579
data = {

connectomics/data/process/distance.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,10 @@ def skeleton_aware_distance_transform(
306306
"""
307307
eps = 1e-6
308308

309+
# Fast path: a label with no foreground (no voxel > 0) yields an array filled with bg_value.
310+
if np.sum(label > 0) == 0:
311+
return np.full(label.shape, bg_value, dtype=np.float32)
312+
309313
# Configure bbox processor
310314
config = BBoxProcessorConfig(
311315
bg_value=bg_value,

connectomics/training/lit/data_factory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ def create_datamodule(
568568
mode="train",
569569
pad_size=tuple(pad_size) if pad_size else None,
570570
pad_mode=pad_mode,
571+
max_attempts=cfg.data.cached_sampling_max_attempts,
572+
foreground_threshold=cfg.data.cached_sampling_foreground_threshold,
571573
)
572574

573575
# Use fewer workers since we're loading from memory
@@ -623,6 +625,8 @@ def create_datamodule(
623625
mode="val",
624626
pad_size=tuple(pad_size) if pad_size else None,
625627
pad_mode=pad_mode,
628+
max_attempts=cfg.data.cached_sampling_max_attempts,
629+
foreground_threshold=cfg.data.cached_sampling_foreground_threshold,
626630
)
627631

628632
# Create validation dataloader

connectomics/training/lit/trainer.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def create_trainer(
176176
f" EMA: Enabled (decay={ema_cfg.decay}, warmup_steps={ema_cfg.warmup_steps}, "
177177
f"validate_with_ema={ema_cfg.validate_with_ema})"
178178
)
179-
179+
180180
# [FIX 1 - PROPER IMPLEMENTATION] Validation reseeding callback
181181
# This ensures validation datasets are reseeded at the start of EACH validation epoch
182182
# Previous fix in val_dataloader() only ran once during setup
@@ -272,6 +272,27 @@ def create_trainer(
272272
max_steps = -1 # -1 means unlimited steps
273273
training_mode = f"epoch-based ({max_epochs} epochs)"
274274

275+
# Treat optimization.val_check_interval as an epoch interval (the key name is legacy).
276+
# Accept values like 1.0 from existing YAMLs, but reject non-integer floats.
277+
val_check_cfg = cfg.optimization.val_check_interval
278+
if isinstance(val_check_cfg, float):
279+
if not val_check_cfg.is_integer():
280+
raise ValueError(
281+
"optimization.val_check_interval must be an integer number of epochs "
282+
f"(got {val_check_cfg})."
283+
)
284+
check_val_every_n_epoch = int(val_check_cfg)
285+
else:
286+
check_val_every_n_epoch = int(val_check_cfg)
287+
288+
if check_val_every_n_epoch < 1:
289+
raise ValueError(
290+
"optimization.val_check_interval must be >= 1 "
291+
f"(got {check_val_every_n_epoch})."
292+
)
293+
294+
print(f" Validation: every {check_val_every_n_epoch} epoch(s)")
295+
275296
trainer = pl.Trainer(
276297
max_epochs=max_epochs,
277298
max_steps=max_steps,
@@ -282,7 +303,8 @@ def create_trainer(
282303
precision=cfg.optimization.precision,
283304
gradient_clip_val=cfg.optimization.gradient_clip_val,
284305
accumulate_grad_batches=cfg.optimization.accumulate_grad_batches,
285-
val_check_interval=cfg.optimization.val_check_interval,
306+
val_check_interval=1.0,
307+
check_val_every_n_epoch=check_val_every_n_epoch,
286308
log_every_n_steps=cfg.optimization.log_every_n_steps,
287309
callbacks=callbacks,
288310
logger=logger,

0 commit comments

Comments
 (0)