Skip to content

Commit 145a695

Browse files
author
Donglai Wei
committed
Split decoding schema and migrate tutorials
1 parent 83425ad commit 145a695

58 files changed

Lines changed: 716 additions & 713 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

connectomics/config/pipeline/profile_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,8 +608,8 @@ def _build_reference_profile_specs() -> List[Tuple[str, List[str]]]:
608608
(
609609
"decoding_templates",
610610
(_STAGE_ROOT, _STAGE_DEFAULT, _STAGE_TUNE, _STAGE_TEST),
611-
"decoding",
612-
"decoding",
611+
"decoding.steps",
612+
"steps",
613613
),
614614
]
615615

connectomics/config/pipeline/stage_resolver.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from ..schema.root import MergeContext
1717
from .dict_utils import as_plain_dict
1818

19-
_NO_PATCH = object()
20-
2119

2220
def _collect_explicit_paths(yaml_node: Any, path: str = "") -> set[str]:
2321
"""Collect all explicit key paths present in YAML."""
@@ -149,18 +147,9 @@ def _walk(node: Any, path: str) -> Any:
149147
return extracted if isinstance(extracted, dict) else {}
150148

151149

152-
def _extract_explicit_value(
153-
section_obj: Any,
154-
section_path: str,
155-
has_explicit_path: Callable[[str], bool],
156-
) -> Any:
157-
"""Extract an explicitly provided non-mapping section value."""
158-
return section_obj if has_explicit_path(section_path) else _NO_PATCH
159-
160-
161150
def _has_section_patch(value: Any) -> bool:
162151
"""Return whether a section override should participate in runtime merging."""
163-
return value is not _NO_PATCH and value != {}
152+
return value != {}
164153

165154

166155
def _resolve_stage_key(mode: str) -> str:
@@ -229,7 +218,7 @@ def _collect_default_overrides(
229218
"default.inference",
230219
has_explicit_path,
231220
),
232-
"decoding": _extract_explicit_value(
221+
"decoding": _extract_explicit_patch(
233222
getattr(default_stage, "decoding", None),
234223
"default.decoding",
235224
has_explicit_path,
@@ -258,18 +247,11 @@ def _collect_stage_overrides(
258247
for section_name in _MODE_SECTIONS[stage_key]:
259248
section_obj = getattr(stage_cfg, section_name, None)
260249
section_path = f"{stage_key}.{section_name}"
261-
if section_name == "decoding":
262-
stage_overrides[section_name] = _extract_explicit_value(
263-
section_obj,
264-
section_path,
265-
has_explicit_path,
266-
)
267-
else:
268-
stage_overrides[section_name] = _extract_explicit_patch(
269-
section_obj,
270-
section_path,
271-
has_explicit_path,
272-
)
250+
stage_overrides[section_name] = _extract_explicit_patch(
251+
section_obj,
252+
section_path,
253+
has_explicit_path,
254+
)
273255

274256
return stage_overrides
275257

@@ -289,13 +271,6 @@ def _merge_runtime_sections(
289271
if not _has_section_patch(default_section) and not _has_section_patch(stage_section):
290272
continue
291273

292-
if section_name == "decoding":
293-
if _has_section_patch(stage_section):
294-
cfg.decoding = stage_section
295-
elif _has_section_patch(default_section):
296-
cfg.decoding = default_section
297-
continue
298-
299274
target_section = getattr(cfg, section_name)
300275
if not _has_section_patch(default_section):
301276
default_section = {}

connectomics/config/profiles/pipeline_profiles.yaml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ pipeline_profiles:
1717
label_transform:
1818
profile: label_bcd
1919
decoding:
20-
- template: decoding_bcd
20+
steps:
21+
- template: decoding_bcd
2122
inference:
2223
test_time_augmentation:
2324
ensemble_mode: mean
@@ -31,7 +32,8 @@ pipeline_profiles:
3132
label_transform:
3233
profile: label_aff9
3334
decoding:
34-
- template: decoding_waterz
35+
steps:
36+
- template: decoding_waterz
3537
inference:
3638
test_time_augmentation:
3739
ensemble_mode: min
@@ -45,7 +47,8 @@ pipeline_profiles:
4547
label_transform:
4648
profile: label_aff12
4749
decoding:
48-
- template: decoding_waterz
50+
steps:
51+
- template: decoding_waterz
4952
inference:
5053
select_channel: [0, 1, 2]
5154
test_time_augmentation:
@@ -60,7 +63,8 @@ pipeline_profiles:
6063
label_transform:
6164
profile: label_aff9_sdt
6265
decoding:
63-
- template: decoding_waterz
66+
steps:
67+
- template: decoding_waterz
6468
inference:
6569
test_time_augmentation:
6670
ensemble_mode: [["0:9", min], ["9:", mean]]

connectomics/config/schema/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,21 @@
3232
SliceShiftZConfig,
3333
StripeConfig,
3434
)
35-
from .inference import (
35+
from .decoding import (
3636
BinaryPostprocessingConfig,
37-
ChunkingConfig,
38-
ChunkStitchingConfig,
3937
ConnectedComponentsConfig,
4038
DecodeBinaryContourDistanceWatershedConfig,
4139
DecodeModeConfig,
42-
EvaluationConfig,
40+
DecodingConfig,
41+
PostprocessingConfig,
42+
)
43+
from .evaluation import EvaluationConfig
44+
from .inference import (
45+
ChunkingConfig,
46+
ChunkStitchingConfig,
4347
InferenceConfig,
4448
InferenceDataConfig,
4549
InferenceMemoryCleanupConfig,
46-
PostprocessingConfig,
4750
PredictionTransformConfig,
4851
SavePredictionConfig,
4952
SlidingWindowConfig,
@@ -133,6 +136,7 @@
133136
"PredictionTransformConfig",
134137
"SavePredictionConfig",
135138
"InferenceMemoryCleanupConfig",
139+
"DecodingConfig",
136140
"PostprocessingConfig",
137141
"BinaryPostprocessingConfig",
138142
"ConnectedComponentsConfig",
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
from typing import Any, Dict, List, Optional, Tuple
5+
6+
7+
@dataclass
class ConnectedComponentsConfig:
    """Connected components filtering configuration."""

    # Enable connected components filtering.
    enabled: bool = False
    # Keep only the top-k largest components (None = keep all).
    top_k: Optional[int] = None
    # Minimum component size in voxels; smaller components are removed.
    min_size: int = 0
    # Connectivity for CC labeling (1=face, 2=face+edge, 3=face+edge+corner).
    connectivity: int = 1
15+
16+
17+
@dataclass
class BinaryPostprocessingConfig:
    """Binary postprocessing pipeline configuration."""

    # Enable the binary postprocessing pipeline.
    enabled: bool = False
    # Median filter kernel size (e.g., (3, 3) for 2D); None disables filtering.
    median_filter_size: Optional[Tuple[int, ...]] = None
    # Number of morphological opening iterations (0 = skip).
    opening_iterations: int = 0
    # Number of morphological closing iterations (0 = skip).
    closing_iterations: int = 0
    # Optional connected-components filtering applied after morphology.
    connected_components: Optional[ConnectedComponentsConfig] = None
26+
27+
28+
@dataclass
class PostprocessingConfig:
    """Postprocessing configuration for decoded outputs.

    Controls how decoded outputs are transformed after decoding:
    binary refinement (morphological operations and connected-components
    filtering), instance relabeling, and axis permutation.
    """

    # Enable the postprocessing pipeline.
    enabled: bool = False
    # Binary postprocessing options (e.g., {'opening_iterations': 2}).
    binary: Optional[Dict[str, Any]] = field(default_factory=dict)
    # Instance cc3d relabeling: split disconnected components and drop small
    # ones, e.g. {connectivity: 6, min_size: 100, remove_boundary: false}.
    instance_cc3d: Optional[Dict[str, Any]] = None
    # Axis permutation for output (e.g., [2, 1, 0] for zyx -> xyz).
    output_transpose: List[int] = field(default_factory=list)
36+
37+
38+
@dataclass
class DecodeBinaryContourDistanceWatershedConfig:
    """Parameters for binary+contour+distance watershed decoding."""

    # NOTE(review): each threshold is a two-value pair; presumably a
    # (seed, mask) threshold pair for watershed — confirm against the decoder.
    binary_threshold: Tuple[float, float] = (0.9, 0.8)
    contour_threshold: Tuple[float, float] = (0.5, 0.5)
    distance_threshold: Tuple[float, float] = (0.5, -0.5)
    # Instances smaller than this are discarded.
    min_instance_size: int = 10
    # Seeds smaller than this are discarded before watershed.
    min_seed_size: int = 10
    # Scale factor applied to the distance map when generating seeds.
    seed_distance_scale: float = 1.0
48+
49+
50+
@dataclass
class DecodeModeConfig:
    """Single decode mode configuration."""

    # Whether this decode step runs.
    enabled: bool = True
    # Name of the decode function to invoke (default: semantic decoding).
    name: str = "decode_semantic"
    # Extra keyword arguments forwarded to the decode function.
    kwargs: Dict[str, Any] = field(default_factory=dict)
57+
58+
59+
@dataclass
class DecodingConfig:
    """Decoded-output orchestration configuration."""

    # Ordered list of decode steps applied to the raw prediction.
    steps: List[DecodeModeConfig] = field(default_factory=list)
    # Postprocessing applied to decoded outputs.
    postprocessing: PostprocessingConfig = field(default_factory=PostprocessingConfig)
    # Path to save the decoded instance segmentation (separate from the raw
    # prediction output).
    output_path: str = ""
    # Path to a pre-computed prediction to decode directly; when set,
    # model inference is presumably skipped — confirm in the pipeline runner.
    input_prediction_path: str = ""
    # Optional parameter-tuning options; schema not constrained here.
    tuning: Optional[Dict[str, Any]] = None
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any, List, Optional
5+
6+
7+
@dataclass
class EvaluationConfig:
    """Evaluation configuration."""

    # Auto-enabled when evaluation keys are provided in YAML.
    enabled: bool = False
    # Metric names to compute, e.g. ['dice', 'jaccard', 'accuracy'].
    metrics: Optional[List[str]] = None
    # Probability/logit threshold for binary metrics.
    prediction_threshold: float = 0.5
    # IoU threshold for instance matching metrics.
    instance_iou_threshold: float = 0.5
    # Neurite ERL evaluation via lib/em_erl. nerl_graph accepts an ERLGraph
    # .npz or a BANIS/NISB-style NetworkX skeleton.pkl.
    nerl_graph: Any = None
    nerl_mask: Any = None
    nerl_resolution: Optional[List[float]] = None
    nerl_merge_threshold: int = 1
    nerl_chunk_num: int = 1
    # Attribute names used to read nodes/edges from the skeleton graph.
    nerl_skeleton_id_attribute: str = "id"
    nerl_skeleton_position_attribute: str = "index_position"
    nerl_skeleton_edge_length_attribute: str = "edge_length"
    # Axis order of skeleton positions; prediction order defaults to the
    # same when None.
    nerl_skeleton_position_order: str = "xyz"
    nerl_prediction_position_order: Optional[str] = None

connectomics/config/schema/inference.py

Lines changed: 1 addition & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field
4-
from typing import Any, Dict, List, Optional, Tuple
4+
from typing import Any, Dict, List, Optional
55

66
from .system import SystemConfig
77

@@ -157,102 +157,6 @@ class InferenceMemoryCleanupConfig:
157157
release_model_after_inference: bool = False
158158

159159

160-
@dataclass
161-
class DecodeBinaryContourDistanceWatershedConfig:
162-
"""Parameters for binary+contour+distance watershed decoding."""
163-
164-
binary_threshold: Tuple[float, float] = (0.9, 0.8)
165-
contour_threshold: Tuple[float, float] = (0.5, 0.5)
166-
distance_threshold: Tuple[float, float] = (0.5, -0.5)
167-
min_instance_size: int = 10
168-
min_seed_size: int = 10
169-
seed_distance_scale: float = 1.0
170-
171-
172-
@dataclass
173-
class DecodeModeConfig:
174-
"""Single decode mode configuration."""
175-
176-
enabled: bool = True
177-
name: str = "decode_semantic"
178-
kwargs: Dict[str, Any] = field(default_factory=dict)
179-
180-
181-
@dataclass
182-
class BinaryPostprocessingConfig:
183-
"""Binary postprocessing pipeline configuration."""
184-
185-
enabled: bool = False # Enable binary postprocessing pipeline
186-
median_filter_size: Optional[Tuple[int, ...]] = (
187-
None # Median filter kernel size (e.g., (3, 3) for 2D)
188-
)
189-
opening_iterations: int = 0 # Number of morphological opening iterations
190-
closing_iterations: int = 0 # Number of morphological closing iterations
191-
connected_components: Optional[ConnectedComponentsConfig] = None # CC filtering config
192-
193-
194-
@dataclass
195-
class ConnectedComponentsConfig:
196-
"""Connected components filtering configuration."""
197-
198-
enabled: bool = False # Enable connected components filtering
199-
top_k: Optional[int] = None # Keep only top-k largest components (None = keep all)
200-
min_size: int = 0 # Minimum component size in voxels
201-
connectivity: int = 1 # Connectivity for CC (1=face, 2=face+edge, 3=face+edge+corner)
202-
203-
204-
@dataclass
205-
class PostprocessingConfig:
206-
"""Postprocessing configuration for inference output.
207-
208-
Controls how decoded outputs are transformed after decoding:
209-
- Binary refinement: Morphological operations and connected components filtering
210-
- Transpose: Reorder axes (e.g., [2,1,0] for zyx->xyz)
211-
212-
Note: raw prediction scaling/dtype changes that should affect decoding are
213-
handled by PredictionTransformConfig. Save-only encoding is handled by
214-
SavePredictionConfig.
215-
"""
216-
217-
enabled: bool = False # Enable postprocessing pipeline
218-
219-
# Binary segmentation refinement (morphological ops, connected components)
220-
binary: Optional[Dict[str, Any]] = field(
221-
default_factory=dict
222-
) # Binary postprocessing config (e.g., {'opening_iterations': 2})
223-
224-
# Instance cc3d relabeling: split disconnected components and remove small ones
225-
instance_cc3d: Optional[Dict[str, Any]] = None
226-
# Example: {connectivity: 6, min_size: 100, remove_boundary: false}
227-
228-
# Axis permutation
229-
output_transpose: List[int] = field(
230-
default_factory=list
231-
) # Axis permutation for output (e.g., [2,1,0] for zyx->xyz)
232-
233-
234-
@dataclass
235-
class EvaluationConfig:
236-
"""Evaluation configuration."""
237-
238-
enabled: bool = False # Auto-enabled when evaluation keys are provided in YAML
239-
metrics: Optional[List[str]] = None # e.g., ['dice', 'jaccard', 'accuracy']
240-
prediction_threshold: float = 0.5 # Probability/logit threshold for binary metrics
241-
instance_iou_threshold: float = 0.5 # IoU threshold for instance matching metrics
242-
# Neurite ERL evaluation via lib/em_erl. nerl_graph accepts an ERLGraph
243-
# .npz or a BANIS/NISB-style NetworkX skeleton.pkl.
244-
nerl_graph: Any = None
245-
nerl_mask: Any = None
246-
nerl_resolution: Optional[List[float]] = None
247-
nerl_merge_threshold: int = 1
248-
nerl_chunk_num: int = 1
249-
nerl_skeleton_id_attribute: str = "id"
250-
nerl_skeleton_position_attribute: str = "index_position"
251-
nerl_skeleton_edge_length_attribute: str = "edge_length"
252-
nerl_skeleton_position_order: str = "xyz"
253-
nerl_prediction_position_order: Optional[str] = None
254-
255-
256160
@dataclass
257161
class InferenceConfig:
258162
"""Inference configuration.
@@ -289,16 +193,10 @@ class InferenceConfig:
289193
# Optional explicit intermediate prediction file (.h5). If set in test
290194
# mode, pipeline loads this file directly and proceeds to top-level decoding.
291195
tta_result_path: str = ""
292-
# Path to pre-computed affinity prediction HDF5 (dataset "main").
293-
# When set, skips model inference — loads and decodes directly.
294-
saved_prediction_path: str = ""
295-
# Path to save decoded instance segmentation (separate from raw prediction).
296-
decoding_path: str = ""
297196
prediction_transform: PredictionTransformConfig = field(
298197
default_factory=PredictionTransformConfig
299198
)
300199
save_prediction: SavePredictionConfig = field(default_factory=SavePredictionConfig)
301-
postprocessing: PostprocessingConfig = field(default_factory=PostprocessingConfig)
302200
memory_cleanup: InferenceMemoryCleanupConfig = field(
303201
default_factory=InferenceMemoryCleanupConfig
304202
)

0 commit comments

Comments
 (0)