Skip to content

Commit 534553d

Browse files
Donglai Wei and claude committed
decoding: decoder graph (DAG) generalizing the linear steps pipeline
Add a decoder-graph execution model to connectomics/decoding/: named nodes with typed inputs (raw prediction channels and/or upstream node outputs), multi-parent combine nodes, and one declared saved output. The linear decoding.steps chain is lowered to a chain graph and run by the same run_decode_graph executor (single code path; byte-parity by construction). Adds channel_gate and combine_split (background-preserving coarsest common refinement) ops and a strict decoding.graph schema with graph-xor-steps validation. Built via CCC (Claude plan / Codex code, Claude review); 35 tests pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 7e460de commit 534553d

9 files changed

Lines changed: 838 additions & 63 deletions

File tree

connectomics/config/schema/decoding.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,25 @@ class DecodeModeConfig:
5656
kwargs: Dict[str, Any] = field(default_factory=dict)
5757

5858

59+
@dataclass
class GraphNodeConfig:
    """Configuration entry for one node of a decoder graph.

    Every field has a default, so a node can be constructed with no
    arguments and filled in afterwards.
    """

    # Whether the node is switched on (defaults to True).
    enabled: bool = True
    # Name of the node; empty by default.
    name: str = ""
    # Identifier of the operation the node runs; empty by default.
    op: str = ""
    # Names of the node's inputs (per the commit description: raw prediction
    # channels and/or upstream node outputs). A fresh list per instance.
    inputs: List[str] = field(default_factory=list)
    # Keyword arguments for the operation. A fresh dict per instance.
    kwargs: Dict[str, Any] = field(default_factory=dict)
68+
69+
70+
@dataclass
class DecodeGraphConfig:
    """Configuration of a decoder graph: its nodes plus one output name."""

    # Node configurations making up the graph. A fresh list per instance.
    nodes: List[GraphNodeConfig] = field(default_factory=list)
    # Name identifying the graph's declared output; empty by default.
    output: str = ""
76+
77+
5978
@dataclass
6079
class TuningParameterConfig:
6180
"""Single tunable decoding/postprocessing parameter."""
@@ -154,6 +173,7 @@ class DecodingConfig:
154173
save_path: str = ""
155174
# Optional user-controlled filename suffix appended to decoded outputs.
156175
save_suffix: str = ""
176+
graph: Optional[DecodeGraphConfig] = None
157177
steps: List[DecodeModeConfig] = field(default_factory=list)
158178
postprocessing: PostprocessingConfig = field(default_factory=PostprocessingConfig)
159179
# Optional explicit raw-prediction file (.h5). If set, pipeline loads
@@ -162,3 +182,7 @@ class DecodingConfig:
162182
affinity_mask_path: str = ""
163183
affinity_qc: AffinityQCConfig = field(default_factory=AffinityQCConfig)
164184
tuning: Optional[DecodingTuningConfig] = None
185+
186+
def __post_init__(self) -> None:
    """Reject configs that set ``graph`` together with a non-empty ``steps``.

    Raises:
        ValueError: If ``self.graph`` is not None and ``self.steps`` is truthy.
    """
    # Either field alone is fine; only the combination is an error.
    if self.graph is None or not self.steps:
        return
    raise ValueError("decoding.graph cannot be set together with non-empty decoding.steps.")

connectomics/decoding/base.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass, field
6-
from typing import Any, Dict, Protocol
6+
from typing import Any, Dict, List, Protocol
77

88
import numpy as np
99

@@ -14,10 +14,35 @@ class DecodeFunction(Protocol):
1414
def __call__(self, predictions: np.ndarray, **kwargs: Any) -> np.ndarray: ...
1515

1616

17+
class GraphOp(Protocol):
    """Structural type for decoder-graph operations stored in the registry.

    A conforming callable takes a list of arrays as ``inputs``, plus
    arbitrary keyword arguments, and returns a single array.
    """

    def __call__(self, inputs: List[np.ndarray], **kwargs: Any) -> np.ndarray:
        ...
21+
22+
1723
@dataclass
class DecodeStep:
    """One step of a linear decoding pipeline."""

    # Whether the step is switched on (defaults to True).
    enabled: bool = True
    # Name of the step; empty by default.
    name: str = ""
    # Keyword arguments associated with the step. A fresh dict per instance.
    kwargs: Dict[str, Any] = field(default_factory=dict)
30+
31+
32+
@dataclass
class DecodeNode:
    """One named node of a decoding graph."""

    # Whether the node is switched on (defaults to True).
    enabled: bool = True
    # Name of the node; empty by default.
    name: str = ""
    # Identifier of the operation the node runs; empty by default.
    op: str = ""
    # Names of the node's inputs. A fresh list per instance.
    inputs: List[str] = field(default_factory=list)
    # Keyword arguments for the operation. A fresh dict per instance.
    kwargs: Dict[str, Any] = field(default_factory=dict)
41+
42+
43+
@dataclass
class DecodeGraph:
    """Decoder graph: a list of nodes plus the name of one declared output."""

    # Nodes belonging to the graph. A fresh list per instance.
    nodes: List[DecodeNode] = field(default_factory=list)
    # Name identifying the declared output; empty by default.
    output: str = ""
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Multi-input decoder-graph combine operations."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Sequence
6+
7+
import numpy as np
8+
9+
10+
def _validate_label_input(arr: np.ndarray, *, name: str) -> None:
    """Check that ``arr`` is an integer label array with no negative values.

    Args:
        arr: Array to validate.
        name: Human-readable input name used in error messages.

    Raises:
        TypeError: If ``arr`` does not have an integer dtype.
        ValueError: If ``arr`` is a signed, non-empty array with a negative value.
    """
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError(f"combine_split {name} must be an integer label array.")
    # Unsigned arrays cannot hold negatives, so only signed non-empty arrays are scanned.
    if np.issubdtype(arr.dtype, np.signedinteger) and arr.size and int(arr.min()) < 0:
        raise ValueError(f"combine_split {name} must not contain negative labels.")


def _check_key_space(max_a: int, max_b: int) -> int:
    """Return the base for pair keys ``a * base + b`` after checking uint64 fit.

    Args:
        max_a: Largest label of the first input.
        max_b: Largest label of the second input.

    Returns:
        ``max_b + 1``, the multiplier that keeps ``(a, b)`` pairs distinct.

    Raises:
        OverflowError: If the base or the largest possible key exceeds uint64.
    """
    max_uint64 = int(np.iinfo(np.uint64).max)
    if max_b >= max_uint64:
        raise OverflowError("combine_split pair-key base exceeds uint64 range.")
    base = max_b + 1
    # The largest key is max_a * base + max_b; Python ints keep this check exact.
    if max_a > (max_uint64 - max_b) // base:
        raise OverflowError("combine_split pair keys would overflow uint64.")
    return base


def combine_split(
    inputs: Sequence[np.ndarray],
    *,
    output_dtype: str | np.dtype = "uint32",
) -> np.ndarray:
    """Return the background-preserving coarsest common refinement of two labels.

    Positions where either input is 0 stay 0 in the output. Every other
    position gets a label that identifies its ``(a, b)`` label pair; labels
    are numbered consecutively from 1 in ascending order of the pair key
    ``a * (max_b + 1) + b``.

    Args:
        inputs: Exactly two integer label arrays of identical shape.
        output_dtype: Integer dtype of the returned array.

    Returns:
        Array with the inputs' shape and dtype ``output_dtype``.

    Raises:
        ValueError: If there are not exactly two inputs, the shapes differ, or
            an input contains negative labels.
        TypeError: If an input or ``output_dtype`` is not of integer type.
        OverflowError: If pair keys do not fit in uint64, or the number of
            resulting labels does not fit in ``output_dtype``.
    """
    if len(inputs) != 2:
        raise ValueError(f"combine_split expects exactly two inputs, got {len(inputs)}.")

    a = np.asarray(inputs[0])
    b = np.asarray(inputs[1])
    if a.shape != b.shape:
        raise ValueError(
            f"combine_split inputs must have matching shapes, got {a.shape} and {b.shape}."
        )
    _validate_label_input(a, name="input 0")
    _validate_label_input(b, name="input 1")

    dtype = np.dtype(output_dtype)
    if not np.issubdtype(dtype, np.integer):
        raise TypeError(f"combine_split output_dtype must be an integer dtype, got {dtype}.")

    out = np.zeros(a.shape, dtype=dtype)
    fg = (a != 0) & (b != 0)
    if not bool(fg.any()):
        return out

    a_fg = a[fg]
    b_fg = b[fg]
    base = _check_key_space(int(a_fg.max()), int(b_fg.max()))

    # astype() copies by default, so the in-place arithmetic never touches the inputs.
    key = a_fg.astype(np.uint64)
    key *= np.uint64(base)
    key += b_fg.astype(np.uint64, copy=False)

    unique_keys, inv = np.unique(key, return_inverse=True)
    n_labels = int(unique_keys.size)
    # The limit comes from the requested dtype rather than a fixed 2**32, so
    # 64-bit outputs are not rejected by a uint32-specific bound.
    if n_labels > int(np.iinfo(dtype).max):
        raise OverflowError(
            f"combine_split produced {n_labels} labels, exceeding output dtype {dtype}."
        )

    # Inverse indices are 0-based ranks of the sorted keys; shift so 0 stays background.
    out[fg] = (inv.reshape(-1) + 1).astype(dtype, copy=False)
    return out
79+
80+
81+
__all__ = ["combine_split"]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Pure array transforms for decoder graphs."""
2+
3+
from __future__ import annotations
4+
5+
import numpy as np
6+
7+
from ...utils.channel_slices import resolve_channel_indices
8+
9+
10+
def channel_gate(
    predictions: np.ndarray,
    *,
    signal_channels,
    gate_channel,
) -> np.ndarray:
    """Scale the selected signal channels by one gate channel.

    Channels are indexed along axis 0 of ``predictions``. Both selectors are
    resolved by ``resolve_channel_indices``; the gate selector must resolve
    to exactly one channel.

    Args:
        predictions: Array whose first axis is the channel axis.
        signal_channels: Selector for the channels to be gated.
        gate_channel: Selector for the single gating channel.

    Returns:
        Product of the signal channels and the gate channel, in the dtype of
        ``predictions``.

    Raises:
        ValueError: If ``predictions`` has no axes, or the gate selector does
            not resolve to exactly one channel.
    """
    data = np.asarray(predictions)
    if data.ndim < 1:
        raise ValueError("channel_gate expects an array with a channel axis.")

    n_channels = int(data.shape[0])
    signal_indices = resolve_channel_indices(
        signal_channels,
        num_channels=n_channels,
        context="channel_gate.signal_channels",
    )
    gate_indices = resolve_channel_indices(
        gate_channel,
        num_channels=n_channels,
        context="channel_gate.gate_channel",
    )
    if len(gate_indices) != 1:
        raise ValueError(
            f"channel_gate.gate_channel must resolve to one channel, got {gate_indices}."
        )

    signal = data[signal_indices]
    gate_start = gate_indices[0]
    # Slicing (not integer indexing) keeps the channel axis so the gate
    # broadcasts across every selected signal channel.
    gate = data[gate_start : gate_start + 1]
    product = signal * gate
    return product.astype(data.dtype, copy=False)
38+
39+
40+
__all__ = ["channel_gate"]

0 commit comments

Comments
 (0)