Skip to content

Commit fae7e09

Browse files
author
Donglai Wei
committed
Type decoding tuning schema
1 parent 5aba493 commit fae7e09

3 files changed

Lines changed: 91 additions & 1 deletion

File tree

connectomics/config/schema/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@
3636
DecodeBinaryContourDistanceWatershedConfig,
3737
DecodeModeConfig,
3838
DecodingConfig,
39+
DecodingTuningConfig,
40+
DecodingTuningParameterSpaceConfig,
3941
PostprocessingConfig,
42+
TuningFunctionSpaceConfig,
43+
TuningParameterConfig,
4044
)
4145
from .evaluation import EvaluationConfig
4246
from .inference import (
@@ -133,12 +137,16 @@
133137
"SavePredictionConfig",
134138
"InferenceMemoryCleanupConfig",
135139
"DecodingConfig",
140+
"DecodingTuningConfig",
141+
"DecodingTuningParameterSpaceConfig",
136142
"PostprocessingConfig",
137143
"BinaryPostprocessingConfig",
138144
"ConnectedComponentsConfig",
139145
"EvaluationConfig",
140146
"DecodeModeConfig",
141147
"DecodeBinaryContourDistanceWatershedConfig",
148+
"TuningFunctionSpaceConfig",
149+
"TuningParameterConfig",
142150
# Test configuration
143151
"TestConfig",
144152
# Tuning configuration

connectomics/config/schema/decoding.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,48 @@ class DecodeModeConfig:
5656
kwargs: Dict[str, Any] = field(default_factory=dict)
5757

5858

59+
@dataclass
60+
class TuningParameterConfig:
61+
"""Single tunable decoding/postprocessing parameter."""
62+
63+
type: str = "float"
64+
range: List[Any] = field(default_factory=list)
65+
choices: List[Any] = field(default_factory=list)
66+
step: Optional[float] = None
67+
log: bool = False
68+
param_group: Optional[str] = None
69+
tuple_index: Optional[int] = None
70+
description: Optional[str] = None
71+
72+
73+
@dataclass
74+
class TuningFunctionSpaceConfig:
75+
"""Search space for one decoding or postprocessing function."""
76+
77+
enabled: bool = False
78+
function_name: str = ""
79+
defaults: Dict[str, Any] = field(default_factory=dict)
80+
parameters: Dict[str, TuningParameterConfig] = field(default_factory=dict)
81+
82+
83+
@dataclass
84+
class DecodingTuningParameterSpaceConfig:
85+
"""Parameter spaces attached to a decoding pipeline."""
86+
87+
decoding: TuningFunctionSpaceConfig = field(default_factory=TuningFunctionSpaceConfig)
88+
postprocessing: TuningFunctionSpaceConfig = field(default_factory=TuningFunctionSpaceConfig)
89+
90+
91+
@dataclass
92+
class DecodingTuningConfig:
93+
"""Structured tuning metadata for decoded-output workflows."""
94+
95+
enabled: bool = False
96+
parameter_space: DecodingTuningParameterSpaceConfig = field(
97+
default_factory=DecodingTuningParameterSpaceConfig
98+
)
99+
100+
59101
@dataclass
60102
class DecodingConfig:
61103
"""Decoded-output orchestration configuration."""
@@ -64,4 +106,4 @@ class DecodingConfig:
64106
postprocessing: PostprocessingConfig = field(default_factory=PostprocessingConfig)
65107
output_path: str = ""
66108
input_prediction_path: str = ""
67-
tuning: Optional[Dict[str, Any]] = None
109+
tuning: Optional[DecodingTuningConfig] = None

tests/unit/test_hydra_config.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,46 @@ def test_top_level_decoding_template_list_ref(tmp_path):
875875
assert cfg.decoding.steps[0].kwargs["min_instance_size"] == 7
876876

877877

878+
def test_decoding_tuning_uses_structured_schema(tmp_path):
879+
config_yaml = tmp_path / "decoding_tuning.yaml"
880+
config_yaml.write_text("""
881+
decoding:
882+
tuning:
883+
enabled: true
884+
parameter_space:
885+
decoding:
886+
enabled: true
887+
function_name: decode_waterz
888+
parameters:
889+
thresholds:
890+
range: [0.1, 0.9]
891+
step: 0.1
892+
""".strip())
893+
894+
cfg = load_config(config_yaml)
895+
896+
assert cfg.decoding.tuning is not None
897+
assert cfg.decoding.tuning.enabled is True
898+
decoding_space = cfg.decoding.tuning.parameter_space.decoding
899+
assert decoding_space.function_name == "decode_waterz"
900+
threshold_param = decoding_space.parameters["thresholds"]
901+
assert threshold_param.range == [0.1, 0.9]
902+
assert threshold_param.step == 0.1
903+
904+
905+
def test_decoding_tuning_rejects_unknown_nested_keys(tmp_path):
906+
config_yaml = tmp_path / "bad_decoding_tuning.yaml"
907+
config_yaml.write_text("""
908+
decoding:
909+
tuning:
910+
enabled: true
911+
unknown_key: true
912+
""".strip())
913+
914+
with pytest.raises(Exception, match="unknown_key"):
915+
load_config(config_yaml)
916+
917+
878918
def test_loss_profile_positional_overrides(tmp_path):
879919
"""Loss profile + overrides dict patches individual list entries by index."""
880920
base_yaml = tmp_path / "base.yaml"

0 commit comments

Comments
 (0)