Skip to content

Commit aeea82a

Browse files
author
Donglai Wei
committed
Diagnose stalled mito_mitoEM_H train
1 parent a169252 commit aeea82a

6 files changed

Lines changed: 241 additions & 2 deletions

File tree

connectomics/decoding/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Decoding package for PyTorch Connectomics."""
22

33
from .base import DecodeStep
4+
from .abiss import decode_abiss
45
from .pipeline import (
56
apply_decode_mode,
67
apply_decode_pipeline,
@@ -67,6 +68,7 @@
6768
"decode_instance_binary_contour_distance",
6869
"decode_affinity_cc",
6970
"decode_distance_watershed",
71+
"decode_abiss",
7072
# Auto-tuning
7173
"optimize_threshold",
7274
"optimize_parameters",

connectomics/decoding/abiss.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""ABISS external wrapper decoder.
2+
3+
This module exposes a decoder that bridges prediction tensors to an external
4+
ABISS (or ABISS-compatible) command-line pipeline.
5+
6+
The wrapper writes predictions to temporary files, runs a user-specified
7+
command, then reads back an instance-label segmentation.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from pathlib import Path
13+
from tempfile import TemporaryDirectory
14+
from typing import Any, Dict, List, Mapping, Optional, Sequence
15+
import os
16+
import subprocess
17+
18+
import numpy as np
19+
20+
from connectomics.data.io import read_hdf5, write_hdf5
21+
22+
from .utils import cast2dtype
23+
24+
__all__ = ["decode_abiss"]
25+
26+
27+
def _format_command(
28+
command: str | Sequence[str],
29+
mapping: Mapping[str, str],
30+
) -> tuple[str | List[str], bool]:
31+
"""Format command placeholders for shell/list execution."""
32+
if isinstance(command, str):
33+
return command.format(**mapping), True
34+
if isinstance(command, Sequence):
35+
return [str(part).format(**mapping) for part in command], False
36+
raise TypeError(f"`command` must be str or sequence[str], got {type(command).__name__}.")
37+
38+
39+
def _load_output(output_h5: Path, output_npy: Path, output_dataset: str) -> np.ndarray:
40+
"""Load decoded segmentation output from file."""
41+
if output_h5.exists():
42+
seg = read_hdf5(str(output_h5), dataset=output_dataset)
43+
elif output_npy.exists():
44+
seg = np.load(output_npy)
45+
else:
46+
raise FileNotFoundError(
47+
"decode_abiss did not produce output file. "
48+
f"Expected one of: {output_h5}, {output_npy}"
49+
)
50+
51+
seg = np.asarray(seg)
52+
if seg.ndim == 4 and seg.shape[0] == 1:
53+
seg = seg[0]
54+
if seg.ndim != 3:
55+
raise ValueError(
56+
"decode_abiss output must be 3D label volume (Z, Y, X) "
57+
f"or singleton-channel 4D; got shape {seg.shape}."
58+
)
59+
60+
if not np.issubdtype(seg.dtype, np.integer):
61+
seg = np.rint(seg).astype(np.uint64, copy=False)
62+
63+
return cast2dtype(seg)
64+
65+
66+
def decode_abiss(
67+
predictions: np.ndarray,
68+
command: str | Sequence[str],
69+
*,
70+
input_dataset: str = "main",
71+
output_dataset: str = "main",
72+
channels: Optional[Sequence[int]] = None,
73+
workdir: Optional[str] = None,
74+
keep_workspace: bool = False,
75+
timeout_sec: Optional[int] = None,
76+
env: Optional[Dict[str, Any]] = None,
77+
check: bool = True,
78+
) -> np.ndarray:
79+
"""Decode instance segmentation with an external ABISS command.
80+
81+
Args:
82+
predictions: Model output, typically shape ``(C, Z, Y, X)``.
83+
command: External command to execute. Supports placeholders:
84+
- ``{workspace}``: working directory path
85+
- ``{input_h5}``, ``{input_npy}``: prediction file paths
86+
- ``{output_h5}``, ``{output_npy}``: expected output file paths
87+
- ``{input_dataset}``, ``{output_dataset}``: HDF5 dataset names
88+
input_dataset: Dataset name when writing input HDF5.
89+
output_dataset: Dataset name when reading output HDF5.
90+
channels: Optional channel indices to select before saving input.
91+
workdir: Optional fixed workspace directory. If None, uses temp dir.
92+
keep_workspace: Keep temp workspace when using auto temp dir.
93+
timeout_sec: Optional subprocess timeout in seconds.
94+
env: Optional extra environment variables for subprocess.
95+
check: Raise on non-zero return code if True.
96+
97+
Returns:
98+
3D instance label volume ``(Z, Y, X)``.
99+
"""
100+
pred = np.asarray(predictions)
101+
if pred.ndim not in (3, 4):
102+
raise ValueError(
103+
f"decode_abiss expects 3D/4D predictions, got shape {pred.shape}."
104+
)
105+
106+
if channels is not None:
107+
if pred.ndim != 4:
108+
raise ValueError("`channels` can only be used for 4D predictions (C, Z, Y, X).")
109+
pred = pred[np.asarray(channels)]
110+
111+
if workdir is not None:
112+
workspace_path = Path(workdir).resolve()
113+
workspace_path.mkdir(parents=True, exist_ok=True)
114+
temp_ctx = None
115+
else:
116+
temp_ctx = TemporaryDirectory(prefix="decode_abiss_")
117+
workspace_path = Path(temp_ctx.name).resolve()
118+
119+
try:
120+
input_h5 = workspace_path / "predictions.h5"
121+
input_npy = workspace_path / "predictions.npy"
122+
output_h5 = workspace_path / "segmentation.h5"
123+
output_npy = workspace_path / "segmentation.npy"
124+
125+
# Save both formats so external command can choose the easiest input.
126+
write_hdf5(str(input_h5), pred, dataset=input_dataset)
127+
np.save(input_npy, pred)
128+
129+
mapping = {
130+
"workspace": str(workspace_path),
131+
"input_h5": str(input_h5),
132+
"input_npy": str(input_npy),
133+
"output_h5": str(output_h5),
134+
"output_npy": str(output_npy),
135+
"input_dataset": input_dataset,
136+
"output_dataset": output_dataset,
137+
}
138+
cmd, use_shell = _format_command(command, mapping)
139+
140+
proc_env = os.environ.copy()
141+
if env:
142+
proc_env.update({str(k): str(v) for k, v in env.items()})
143+
144+
subprocess.run(
145+
cmd,
146+
shell=use_shell,
147+
env=proc_env,
148+
cwd=str(workspace_path),
149+
check=check,
150+
timeout=timeout_sec,
151+
)
152+
153+
return _load_output(output_h5=output_h5, output_npy=output_npy, output_dataset=output_dataset)
154+
finally:
155+
if temp_ctx is not None and not keep_workspace:
156+
temp_ctx.cleanup()

connectomics/decoding/registry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def list_decoders() -> List[str]:
6262

6363
def register_builtin_decoders() -> None:
6464
"""Populate registry with built-in decoders."""
65+
from .abiss import decode_abiss
6566
from .segmentation import (
6667
decode_affinity_cc,
6768
decode_distance_watershed,
@@ -76,5 +77,5 @@ def register_builtin_decoders() -> None:
7677
)
7778
register_decoder("decode_affinity_cc", decode_affinity_cc, overwrite=True)
7879
register_decoder("decode_distance_watershed", decode_distance_watershed, overwrite=True)
80+
register_decoder("decode_abiss", decode_abiss, overwrite=True)
7981
register_decoder("polarity2instance", polarity2instance, overwrite=True)
80-
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Tests for decode_abiss external wrapper."""
2+
3+
from __future__ import annotations
4+
5+
import sys
6+
7+
import numpy as np
8+
import pytest
9+
10+
from connectomics.decoding import decode_abiss
11+
12+
13+
def test_decode_abiss_with_list_command_writes_npy_output():
14+
pred = np.zeros((3, 6, 8, 10), dtype=np.float32)
15+
pred[0, 1:4, 2:6, 3:8] = 0.9
16+
17+
command = [
18+
sys.executable,
19+
"-c",
20+
(
21+
"import h5py, numpy as np; "
22+
"x = h5py.File('{input_h5}', 'r')['{input_dataset}'][:]; "
23+
"y = (x[0] > 0.5).astype(np.uint64); "
24+
"np.save('{output_npy}', y)"
25+
),
26+
]
27+
28+
seg = decode_abiss(pred, command=command)
29+
assert seg.shape == (6, 8, 10)
30+
assert np.issubdtype(seg.dtype, np.integer)
31+
assert seg.max() == 1
32+
assert seg[2, 3, 4] == 1
33+
34+
35+
def test_decode_abiss_with_string_command_writes_h5_output():
36+
pred = np.zeros((3, 5, 7, 9), dtype=np.float32)
37+
pred[1, 1:4, 2:5, 3:7] = 1.0
38+
39+
command = (
40+
f"{sys.executable} -c \""
41+
"import h5py, numpy as np; "
42+
"x = h5py.File('{input_h5}', 'r')['{input_dataset}'][:]; "
43+
"y = (x[1] > 0.5).astype(np.uint64); "
44+
"f = h5py.File('{output_h5}', 'w'); "
45+
"f.create_dataset('{output_dataset}', data=y); "
46+
"f.close()\""
47+
)
48+
49+
seg = decode_abiss(pred, command=command)
50+
assert seg.shape == (5, 7, 9)
51+
assert seg.max() == 1
52+
assert seg[2, 3, 4] == 1
53+
54+
55+
def test_decode_abiss_raises_if_output_missing():
56+
pred = np.zeros((3, 4, 4, 4), dtype=np.float32)
57+
58+
command = [sys.executable, "-c", "print('no output written')"]
59+
60+
with pytest.raises(FileNotFoundError, match="did not produce output file"):
61+
decode_abiss(pred, command=command)
62+

tests/unit/test_decoding_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_builtin_decoders_are_registered():
3838
assert "decode_affinity_cc" in names
3939
assert "decode_distance_watershed" in names
4040
assert "decode_instance_binary_contour_distance" in names
41+
assert "decode_abiss" in names
4142

4243

4344
def test_decode_pipeline_dict_mode_matches_direct_decoder(affinity_with_redundant_channels):
@@ -62,4 +63,3 @@ def test_decode_pipeline_unknown_decoder_raises(affinity_with_redundant_channels
6263
decode_modes = [{"name": "decode_not_exists", "kwargs": {}}]
6364
with pytest.raises(ValueError, match="Unknown decode function"):
6465
apply_decode_pipeline(affinity_with_redundant_channels, decode_modes)
65-

tutorials/neuron_snemi.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,20 @@ test:
195195
- name: decode_affinity_cc
196196
kwargs:
197197
threshold: 0.5
198+
# Example: external ABISS wrapper (replace decode_affinity_cc above)
199+
# - name: decode_abiss
200+
# kwargs:
201+
# command:
202+
# - python
203+
# - scripts/run_abiss_single.py
204+
# - --input
205+
# - "{input_h5}"
206+
# - --input-dataset
207+
# - "{input_dataset}"
208+
# - --output
209+
# - "{output_h5}"
210+
# - --output-dataset
211+
# - "{output_dataset}"
198212
evaluation:
199213
enabled: true
200214
metrics:
@@ -228,6 +242,10 @@ inference:
228242
- name: decode_affinity_cc
229243
kwargs:
230244
threshold: 0.5
245+
# Example (same as test.decoding):
246+
# - name: decode_abiss
247+
# kwargs:
248+
# command: "python scripts/run_abiss_single.py --input {input_h5} --output {output_h5}"
231249
evaluation:
232250
enabled: false
233251
metrics:

0 commit comments

Comments
 (0)