
Commit fbabdcf

Author: Donglai Wei (committed)
hydra-lv decoding
1 parent 1b36294 commit fbabdcf

4 files changed: 18 additions & 25 deletions

File tree: 1 file renamed without changes.

connectomics/config/hydra_config.py (10 additions, 21 deletions)

```diff
@@ -37,8 +37,9 @@
 """
 
 from __future__ import annotations
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, is_dataclass
 from typing import Dict, List, Optional, Tuple, Any, Union
+import inspect
 
 # Note: MISSING can be imported from omegaconf if needed for required fields
 
@@ -1269,26 +1270,14 @@ class Config:
     import torch
 
     if hasattr(torch, "serialization") and hasattr(torch.serialization, "add_safe_globals"):
-        torch.serialization.add_safe_globals(
-            [
-                ParameterConfig,
-                DecodingParameterSpace,
-                PostprocessingParameterSpace,
-                ParameterSpaceConfig,
-                # Core config dataclasses (for Lightning checkpoints)
-                Config,
-                SystemConfig,
-                SystemTrainingConfig,
-                SystemInferenceConfig,
-                ModelConfig,
-                DataConfig,
-                OptimizationConfig,
-                MonitorConfig,
-                InferenceConfig,
-                TestConfig,
-                TuneConfig,
-            ]
-        )
+        # Register every dataclass defined in this module so Lightning checkpoints
+        # can be loaded safely when torch.load defaults to weights_only=True.
+        safe_dataclasses = [
+            obj
+            for obj in globals().values()
+            if inspect.isclass(obj) and obj.__module__ == __name__ and is_dataclass(obj)
+        ]
+        torch.serialization.add_safe_globals(safe_dataclasses)
 except Exception:
     # Best-effort registration; ignore if torch not available at import time
     pass
```
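The switch from a hand-maintained allowlist to a `globals()` scan means any dataclass later added to this module is registered automatically instead of silently breaking checkpoint loads. The registration itself matters because recent PyTorch versions default `torch.load` to `weights_only=True`, which refuses to unpickle classes that are not allowlisted. A minimal sketch of the failure and the fix, using a hypothetical `DemoConfig` in place of the real config dataclasses (assumes PyTorch >= 2.4, where `add_safe_globals` was introduced):

```python
# Minimal sketch: why dataclasses must be allowlisted for weights_only loads.
# DemoConfig is a hypothetical stand-in for the dataclasses in hydra_config.py.
from dataclasses import dataclass

import torch


@dataclass
class DemoConfig:
    lr: float = 1e-3


torch.save({"cfg": DemoConfig()}, "demo.pt")

try:
    torch.load("demo.pt", weights_only=True)  # rejected: DemoConfig not allowlisted
except Exception as e:
    print(f"blocked: {e}")

torch.serialization.add_safe_globals([DemoConfig])
obj = torch.load("demo.pt", weights_only=True)  # succeeds after registration
print(obj["cfg"])
```

The same mechanism is what lets Lightning checkpoints that embed these config objects round-trip under the safe loader.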

connectomics/training/lit/model.py (4 additions, 0 deletions)

```diff
@@ -83,6 +83,10 @@ def __init__(
         num_tasks = len(cfg.model.multi_task_config) if hasattr(cfg.model, 'multi_task_config') and cfg.model.multi_task_config else 1
         self.loss_weighter = build_loss_weighter(cfg, num_tasks, model=self.model)
 
+        # Track multi-task configuration state for downstream logic/tests
+        self.multi_task_config = getattr(cfg.model, 'multi_task_config', None)
+        self.multi_task_enabled = bool(self.multi_task_config)
+
         # Enable inline NaN detection (can be disabled via config)
         self.enable_nan_detection = getattr(cfg.model, 'enable_nan_detection', True)
         self.debug_on_nan = getattr(cfg.model, 'debug_on_nan', True)
```
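The `bool(...)` cast collapses an absent config and an empty one into the same disabled state. A tiny illustration (the task names are made up, not from the repo):

```python
# bool() maps None and empty containers to False, so multi_task_enabled is
# only True when a non-empty multi_task_config is actually present.
for multi_task_config in (None, [], {}, ["binary", "contour", "distance"]):
    print(repr(multi_task_config), "->", bool(multi_task_config))
# None -> False, [] -> False, {} -> False, ['binary', 'contour', 'distance'] -> True
```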

tutorials/hydra-lv-finetune.yaml (4 additions, 4 deletions)

```diff
@@ -279,10 +279,10 @@ test:
   decoding:
     - name: decode_instance_binary_contour_distance
       kwargs:
-        binary_threshold: [0.9, 0.1]
-        contour_threshold: [0.8, 1.1]
-        distance_threshold: [0.5, 0.0]
-        min_seed_size: 8
+        binary_threshold: [0.5, 0.7]
+        contour_threshold: [0.8, 13.1]
+        distance_threshold: [0.5, -0.5]
+        min_seed_size: 4
 
   evaluation:
     enabled: true
```
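For context, each `decoding` entry names a decode step and passes its `kwargs` through verbatim, so the two-element threshold lists above reach the decode function unchanged. A hypothetical dispatcher sketch; the registry and the decode function's signature are assumptions, not the repo's actual API:

```python
from typing import Any, Callable, Dict, List

DecodeFn = Callable[..., Any]

def run_decoding(volume: Any,
                 steps: List[Dict[str, Any]],
                 registry: Dict[str, DecodeFn]) -> Any:
    """Apply each configured decode step in order (hypothetical sketch)."""
    for step in steps:
        # e.g. step["name"] == "decode_instance_binary_contour_distance"
        decode_fn = registry[step["name"]]
        volume = decode_fn(volume, **step.get("kwargs", {}))
    return volume
```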
