Skip to content

Commit c3b62b7

Browse files
Donglai Weiclaude
andcommitted
Add skip_loss for decode-only mode: no model build, no loss, no checkpoint
ConnectomicsModule(cfg, model=Identity(), skip_loss=True) skips all loss/optimizer/weighter construction. main.py uses this when inference.saved_prediction_path is set. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d4f5118 commit c3b62b7

2 files changed

Lines changed: 19 additions & 1 deletion

File tree

connectomics/training/lightning/model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090
self,
9191
cfg: Union[Config, DictConfig],
9292
model: Optional[nn.Module] = None,
93+
skip_loss: bool = False,
9394
):
9495
super().__init__()
9596
self.cfg = cfg
@@ -98,6 +99,17 @@ def __init__(
9899
# Build model
99100
self.model = model if model is not None else self._build_model(cfg)
100101

102+
# Skip loss/optimizer setup for decode-only mode
103+
if skip_loss:
104+
self.loss_functions = []
105+
self.loss_weights = []
106+
self.loss_metadata = []
107+
self.loss_weighter = None
108+
self.enable_nan_detection = False
109+
self.debug_on_nan = False
110+
self.loss_orchestrator = None
111+
return
112+
101113
# Build loss functions
102114
self.loss_functions = self._build_losses(cfg)
103115
self.loss_weights = self._extract_loss_weights(cfg)

scripts/main.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,13 @@ def main():
904904
)
905905

906906
# Create model
907-
if tta_cached:
907+
if has_saved_prediction:
908+
print(f" Decode-only mode: loading predictions from {_saved_pred}")
909+
print(f" Skipping model build entirely.")
910+
model = ConnectomicsModule(cfg, model=torch.nn.Identity(), skip_loss=True)
911+
model._skip_inference = True
912+
ckpt_path = None
913+
elif tta_cached:
908914
print(
909915
f" Cached intermediate predictions found; "
910916
f"creating lightweight module (skipping {cfg.model.arch.type} build)."

0 commit comments

Comments
 (0)