Skip to content

Commit a169252

Browse files
author
Donglai Wei
committed
Redesign decoding module
1 parent b66f8cf commit a169252

12 files changed

Lines changed: 894 additions & 793 deletions

File tree

connectomics/decoding/__init__.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
1-
"""
2-
Decoding package for PyTorch Connectomics.
1+
"""Decoding package for PyTorch Connectomics."""
32

4-
This package provides post-processing functions to convert model predictions
5-
into final instance segmentation masks for various biological structures.
6-
7-
Modules:
8-
- segmentation: Mitochondria and organelle instance decoding
9-
- synapse: Synaptic polarity instance decoding
10-
- postprocess: General post-processing utilities
11-
- utils: Shared utility functions
12-
- auto_tuning: Hyperparameter optimization for post-processing
13-
14-
Import patterns:
15-
from connectomics.decoding import decode_binary_watershed, decode_binary_contour_watershed
16-
from connectomics.decoding import polarity2instance
17-
from connectomics.decoding import stitch_3d, watershed_split
18-
from connectomics.decoding import optimize_threshold, SkeletonMetrics
19-
"""
3+
from .base import DecodeStep
4+
from .pipeline import (
5+
apply_decode_mode,
6+
apply_decode_pipeline,
7+
normalize_decode_modes,
8+
resolve_decode_modes_from_cfg,
9+
)
10+
from .registry import (
11+
DecoderRegistry,
12+
get_decoder,
13+
list_decoders,
14+
register_builtin_decoders,
15+
register_decoder,
16+
)
2017

2118
from .auto_tuning import (
2219
SkeletonMetrics,
@@ -53,7 +50,19 @@
5350
remove_small_instances,
5451
)
5552

53+
register_builtin_decoders()
54+
5655
__all__ = [
56+
# Registry / pipeline
57+
"DecodeStep",
58+
"DecoderRegistry",
59+
"register_decoder",
60+
"get_decoder",
61+
"list_decoders",
62+
"normalize_decode_modes",
63+
"apply_decode_pipeline",
64+
"resolve_decode_modes_from_cfg",
65+
"apply_decode_mode",
5766
# Segmentation decoding
5867
"decode_instance_binary_contour_distance",
5968
"decode_affinity_cc",

connectomics/decoding/base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Core decoding type definitions."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass, field
6+
from typing import Any, Dict, Protocol
7+
8+
import numpy as np
9+
10+
11+
class DecodeFunction(Protocol):
12+
"""Callable decoder signature used by the registry."""
13+
14+
def __call__(self, predictions: np.ndarray, **kwargs: Any) -> np.ndarray: ...
15+
16+
17+
@dataclass
18+
class DecodeStep:
19+
"""Single step in a decoding pipeline."""
20+
21+
name: str
22+
kwargs: Dict[str, Any] = field(default_factory=dict)
23+

connectomics/decoding/pipeline.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""Decoding pipeline helpers."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, Iterable, List, Sequence, Tuple
6+
7+
import numpy as np
8+
9+
from .base import DecodeStep
10+
from .registry import DecoderRegistry, DEFAULT_DECODER_REGISTRY, register_builtin_decoders
11+
12+
13+
def _coerce_kwargs(kwargs: Any) -> dict:
14+
if kwargs is None:
15+
return {}
16+
if hasattr(kwargs, "items"):
17+
return dict(kwargs)
18+
raise TypeError(
19+
f"Decode kwargs must be a mapping, got {type(kwargs).__name__}"
20+
)
21+
22+
23+
def normalize_decode_modes(decode_modes: Iterable[Any]) -> List[DecodeStep]:
24+
"""Normalize decode configuration entries into DecodeStep objects."""
25+
steps: List[DecodeStep] = []
26+
for mode in decode_modes:
27+
if isinstance(mode, DecodeStep):
28+
steps.append(DecodeStep(name=mode.name, kwargs=_coerce_kwargs(mode.kwargs)))
29+
continue
30+
31+
if hasattr(mode, "name"):
32+
name = mode.name
33+
kwargs = _coerce_kwargs(getattr(mode, "kwargs", {}))
34+
steps.append(DecodeStep(name=name, kwargs=kwargs))
35+
continue
36+
37+
if isinstance(mode, dict):
38+
name = mode.get("name")
39+
kwargs = _coerce_kwargs(mode.get("kwargs", {}))
40+
steps.append(DecodeStep(name=name, kwargs=kwargs))
41+
continue
42+
43+
raise TypeError(f"Unsupported decode mode type: {type(mode).__name__}")
44+
45+
for step in steps:
46+
if not step.name:
47+
raise ValueError("Decode step is missing required field 'name'.")
48+
49+
return steps
50+
51+
52+
def _prepare_batched_input(data: np.ndarray) -> Tuple[np.ndarray, int]:
53+
arr = np.asarray(data)
54+
if arr.ndim == 5:
55+
return arr, arr.shape[0]
56+
if arr.ndim == 4:
57+
return arr[np.newaxis, ...], 1
58+
if arr.ndim == 3:
59+
return arr[np.newaxis, np.newaxis, ...], 1
60+
if arr.ndim == 2:
61+
return arr[np.newaxis, np.newaxis, np.newaxis, ...], 1
62+
raise ValueError(
63+
f"Expected input with 2-5 dimensions, got shape {arr.shape}."
64+
)
65+
66+
67+
def apply_decode_pipeline(
68+
data: np.ndarray,
69+
decode_modes: Sequence[Any] | None,
70+
registry: DecoderRegistry | None = None,
71+
) -> np.ndarray:
72+
"""Apply configured decode steps to prediction data."""
73+
if not decode_modes:
74+
return data
75+
76+
register_builtin_decoders()
77+
registry = registry or DEFAULT_DECODER_REGISTRY
78+
steps = normalize_decode_modes(decode_modes)
79+
batched, batch_size = _prepare_batched_input(data)
80+
81+
results: List[np.ndarray] = []
82+
for batch_idx in range(batch_size):
83+
sample = batched[batch_idx]
84+
for step in steps:
85+
try:
86+
decoder = registry.get(step.name)
87+
except KeyError as exc:
88+
available = ", ".join(registry.available())
89+
raise ValueError(
90+
f"Unknown decode function '{step.name}'. "
91+
f"Available functions: [{available}]."
92+
) from exc
93+
94+
try:
95+
sample = decoder(sample, **step.kwargs)
96+
except Exception as exc:
97+
raise RuntimeError(
98+
f"Error applying decode function '{step.name}': {exc}"
99+
) from exc
100+
101+
results.append(sample)
102+
103+
if len(results) == 1:
104+
return results[0]
105+
return np.stack(results, axis=0)
106+
107+
108+
def resolve_decode_modes_from_cfg(cfg: Any) -> Sequence[Any] | None:
109+
"""Resolve decode mode list from config.
110+
111+
Priority:
112+
1. ``cfg.test.decoding``
113+
2. ``cfg.inference.decoding``
114+
"""
115+
if hasattr(cfg, "test") and cfg.test and hasattr(cfg.test, "decoding"):
116+
return cfg.test.decoding
117+
if hasattr(cfg, "inference") and hasattr(cfg.inference, "decoding"):
118+
return cfg.inference.decoding
119+
return None
120+
121+
122+
def apply_decode_mode(cfg: Any, data: np.ndarray, *, verbose: bool = True) -> np.ndarray:
123+
"""Apply decode pipeline resolved from ``test.decoding`` or ``inference.decoding``."""
124+
decode_modes = resolve_decode_modes_from_cfg(cfg)
125+
if not decode_modes:
126+
if verbose:
127+
print(" No decoding configuration found (test.decoding or inference.decoding)")
128+
return data
129+
130+
if verbose:
131+
source = "test.decoding"
132+
if not (hasattr(cfg, "test") and cfg.test and hasattr(cfg.test, "decoding")):
133+
source = "inference.decoding"
134+
print(f" Using {source}: {decode_modes}")
135+
136+
return apply_decode_pipeline(data, decode_modes)

connectomics/decoding/registry.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Decoder registry for configurable decode pipelines."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Dict, List
6+
7+
from .base import DecodeFunction
8+
9+
10+
class DecoderRegistry:
11+
"""Name -> decoder function registry."""
12+
13+
def __init__(self) -> None:
14+
self._decoders: Dict[str, DecodeFunction] = {}
15+
16+
def register(self, name: str, fn: DecodeFunction, *, overwrite: bool = False) -> None:
17+
"""Register a decoder function."""
18+
if not name or not isinstance(name, str):
19+
raise ValueError("Decoder name must be a non-empty string.")
20+
if not callable(fn):
21+
raise TypeError(f"Decoder '{name}' must be callable.")
22+
23+
existing = self._decoders.get(name)
24+
if existing is not None and not overwrite and existing is not fn:
25+
raise ValueError(
26+
f"Decoder '{name}' already registered. Use overwrite=True to replace it."
27+
)
28+
self._decoders[name] = fn
29+
30+
def get(self, name: str) -> DecodeFunction:
31+
"""Get decoder function by name."""
32+
try:
33+
return self._decoders[name]
34+
except KeyError as exc:
35+
available = ", ".join(sorted(self._decoders))
36+
raise KeyError(
37+
f"Unknown decoder '{name}'. Available decoders: [{available}]"
38+
) from exc
39+
40+
def available(self) -> List[str]:
41+
"""Return registered decoder names."""
42+
return sorted(self._decoders.keys())
43+
44+
45+
DEFAULT_DECODER_REGISTRY = DecoderRegistry()
46+
47+
48+
def register_decoder(name: str, fn: DecodeFunction, *, overwrite: bool = False) -> None:
49+
"""Register decoder in the default registry."""
50+
DEFAULT_DECODER_REGISTRY.register(name, fn, overwrite=overwrite)
51+
52+
53+
def get_decoder(name: str) -> DecodeFunction:
54+
"""Get decoder from the default registry."""
55+
return DEFAULT_DECODER_REGISTRY.get(name)
56+
57+
58+
def list_decoders() -> List[str]:
59+
"""List names in the default registry."""
60+
return DEFAULT_DECODER_REGISTRY.available()
61+
62+
63+
def register_builtin_decoders() -> None:
64+
"""Populate registry with built-in decoders."""
65+
from .segmentation import (
66+
decode_affinity_cc,
67+
decode_distance_watershed,
68+
decode_instance_binary_contour_distance,
69+
)
70+
from .synapse import polarity2instance
71+
72+
register_decoder(
73+
"decode_instance_binary_contour_distance",
74+
decode_instance_binary_contour_distance,
75+
overwrite=True,
76+
)
77+
register_decoder("decode_affinity_cc", decode_affinity_cc, overwrite=True)
78+
register_decoder("decode_distance_watershed", decode_distance_watershed, overwrite=True)
79+
register_decoder("polarity2instance", polarity2instance, overwrite=True)
80+

connectomics/inference/__init__.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
11
"""Inference utilities package."""
22

33
from .manager import InferenceManager
4-
from .io import (
5-
apply_save_prediction_transform,
6-
apply_postprocessing,
7-
apply_decode_mode,
8-
resolve_output_filenames,
9-
write_outputs,
10-
)
4+
from .output import resolve_output_filenames, write_outputs
5+
from .postprocessing import apply_save_prediction_transform, apply_postprocessing
116
from .sliding import build_sliding_inferer, resolve_inferer_roi_size, resolve_inferer_overlap
127
from .tta import TTAPredictor
138

149
__all__ = [
1510
"InferenceManager",
1611
"apply_save_prediction_transform",
1712
"apply_postprocessing",
18-
"apply_decode_mode",
1913
"resolve_output_filenames",
2014
"write_outputs",
2115
"build_sliding_inferer",

0 commit comments

Comments
 (0)