Skip to content

Commit 915af4c

Browse files
author
Donglai Wei
committed
fix just train lucchi++ monai_unet --fast-dev-run
1 parent 67b6624 commit 915af4c

4 files changed

Lines changed: 13 additions & 8 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,4 @@ lib/
156156
.claude/
157157
.codex/
158158
.CLAUDE.md
159+
tmp/

connectomics/config/hydra_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,7 @@ class TestTimeAugmentationConfig:
915915
- rotation90_axes: Uses spatial-only indices (e.g., [1, 2] for H-W plane where 0=D, 1=H, 2=W)
916916
"""
917917

918+
enabled: bool = True # Enable flip/rotation TTA (preprocessing still applies)
918919
flip_axes: Any = None # TTA flip strategy: "all" (8 flips), null (no aug),
919920
# or list like [[2], [3]] (full tensor indices)
920921
rotation90_axes: Any = None # TTA rotation90 strategy: "all" (3 planes × 4 rotations),

connectomics/inference/tta.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,16 @@ def predict(self, images: torch.Tensor, mask: Optional[torch.Tensor] = None) ->
198198
if getattr(self.cfg.data, "do_2d", False) and images.size(2) == 1:
199199
images = images.squeeze(2)
200200

201-
# Get TTA configuration
201+
# Get TTA configuration (respect enabled flag for augmentations)
202202
if hasattr(self.cfg, "inference") and hasattr(self.cfg.inference, "test_time_augmentation"):
203-
tta_flip_axes_config = getattr(
204-
self.cfg.inference.test_time_augmentation, "flip_axes", None
205-
)
206-
tta_rotation90_axes_config = getattr(
207-
self.cfg.inference.test_time_augmentation, "rotation90_axes", None
208-
)
203+
tta_cfg = self.cfg.inference.test_time_augmentation
204+
tta_enabled = getattr(tta_cfg, "enabled", True)
205+
if tta_enabled:
206+
tta_flip_axes_config = getattr(tta_cfg, "flip_axes", None)
207+
tta_rotation90_axes_config = getattr(tta_cfg, "rotation90_axes", None)
208+
else:
209+
tta_flip_axes_config = None
210+
tta_rotation90_axes_config = None
209211
else:
210212
tta_flip_axes_config = None
211213
tta_rotation90_axes_config = None

tutorials/lucchi++.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,13 @@ optimization:
172172

173173
# Scheduler - ReduceLROnPlateau for adaptive learning
174174
scheduler:
175-
name: ReduceLROnPlateau # Reduce LR when validation loss plateaus
175+
name: ReduceLROnPlateau # Reduce LR when training loss plateaus
176176
mode: min # Monitor minimum loss
177177
factor: 0.5 # Reduce LR by 50%
178178
patience: 50 # Wait 50 epochs before reducing
179179
threshold: 1.0e-4 # Minimum change to qualify as improvement
180180
min_lr: 1.0e-6 # Don't go below 1e-6
181+
monitor: train_loss_total_epoch # No validation; monitor training loss
181182

182183
monitor:
183184
# Loss monitoring and validation frequency

0 commit comments

Comments
 (0)