Skip to content

Commit 2c954e6

Browse files
author
Donglai Wei
committed
Extract public chunk-grid helpers
1 parent fae7e09 commit 2c954e6

4 files changed

Lines changed: 179 additions & 155 deletions

File tree

connectomics/decoding/streamed_chunked.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
import numpy as np
1111
import torch
1212

13-
from ..inference.chunked import (
14-
_build_chunk_grid,
15-
_resolve_chunk_output_mode,
16-
_resolve_chunk_shape,
17-
_resolve_global_prediction_crop,
18-
_resolve_h5_spatial_chunks,
19-
_validate_chunked_output_contract,
13+
from ..inference.chunk_grid import (
14+
build_chunk_grid,
15+
resolve_chunk_output_mode,
16+
resolve_chunk_shape,
17+
resolve_global_prediction_crop,
18+
resolve_h5_spatial_chunks,
19+
validate_chunked_output_format,
2020
)
2121
from ..inference.lazy import get_lazy_image_reference_shape, lazy_predict_region
2222
from ..inference.output import apply_prediction_transform
@@ -163,8 +163,8 @@ def run_chunked_affinity_cc_inference(
163163
requested_head: str | None = None,
164164
) -> Path:
165165
"""Run chunked lazy inference, decode chunks, stitch boundaries, and write HDF5."""
166-
_validate_chunked_output_contract(cfg)
167-
output_mode = _resolve_chunk_output_mode(cfg)
166+
validate_chunked_output_format(cfg)
167+
output_mode = resolve_chunk_output_mode(cfg)
168168
if output_mode != "decoded":
169169
raise ValueError(
170170
"run_chunked_affinity_cc_inference requires "
@@ -192,7 +192,7 @@ def run_chunked_affinity_cc_inference(
192192

193193
reference_shape = get_lazy_image_reference_shape(cfg, image_path, mode="test")
194194
input_shape = tuple(int(v) for v in reference_shape[-3:])
195-
crop_pad = _resolve_global_prediction_crop(cfg)
195+
crop_pad = resolve_global_prediction_crop(cfg)
196196
crop_before = tuple(int(crop_pad[axis][0]) for axis in range(3))
197197
crop_after = tuple(int(crop_pad[axis][1]) for axis in range(3))
198198
final_shape = tuple(
@@ -203,9 +203,9 @@ def run_chunked_affinity_cc_inference(
203203
f"Chunked inference crop {crop_pad} is too large for input shape {input_shape}."
204204
)
205205

206-
chunk_shape = _resolve_chunk_shape(cfg, final_shape)
206+
chunk_shape = resolve_chunk_shape(cfg, final_shape)
207207
halo = tuple(int(v) for v in getattr(chunking_cfg, "halo", [0, 0, 0]))
208-
chunks = _build_chunk_grid(final_shape, chunk_shape)
208+
chunks = build_chunk_grid(final_shape, chunk_shape)
209209
output_path = Path(output_path)
210210
temp_root = (
211211
Path(chunking_cfg.temp_dir)
@@ -329,7 +329,7 @@ def run_chunked_affinity_cc_inference(
329329
output_path.parent.mkdir(parents=True, exist_ok=True)
330330
compression = getattr(getattr(cfg.inference, "save_prediction", None), "compression", "gzip")
331331
compression = None if compression in (None, "", "none") else compression
332-
h5_chunks = _resolve_h5_spatial_chunks(final_shape)
332+
h5_chunks = resolve_h5_spatial_chunks(final_shape)
333333
with h5py.File(output_path, "w") as handle:
334334
dataset = handle.create_dataset(
335335
"main",
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""Public chunk-grid helpers shared by chunked inference and streamed decoding."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from itertools import product
7+
from typing import Any, Sequence
8+
9+
from ..data.processing.affinity import (
10+
compute_affinity_crop_pad,
11+
resolve_affinity_channel_groups_from_cfg,
12+
resolve_affinity_mode_from_cfg,
13+
)
14+
from ..utils.channel_slices import resolve_channel_indices
15+
16+
17+
@dataclass(frozen=True)
class ChunkRef:
    """Immutable reference to one axis-aligned chunk of a 3-D volume."""

    # (z, y, x) position of the chunk within the chunk grid.
    index: tuple[int, int, int]
    # Inclusive start voxel per axis.
    start: tuple[int, int, int]
    # Exclusive stop voxel per axis.
    stop: tuple[int, int, int]

    @property
    def key(self) -> str:
        """Stable string identifier for this chunk, e.g. ``z0_y1_x2``."""
        z, y, x = self.index
        return f"z{z}_y{y}_x{x}"

    @property
    def slices(self) -> tuple[slice, slice, slice]:
        """Per-axis slices selecting this chunk from the full volume."""
        return tuple(slice(lo, hi) for lo, hi in zip(self.start, self.stop))


def build_chunk_grid(volume_shape: Sequence[int], chunk_shape: Sequence[int]) -> list[ChunkRef]:
    """Tile *volume_shape* with *chunk_shape* blocks, clipping the last chunk per axis.

    Chunks are produced in z-major order and together cover every voxel of the
    volume exactly once.
    """
    vol = tuple(int(v) for v in volume_shape)
    blk = tuple(int(v) for v in chunk_shape)
    # Ceil-divide to count chunks along each axis.
    n_chunks = tuple(-(-vol[axis] // blk[axis]) for axis in range(3))
    grid: list[ChunkRef] = []
    for idx in product(range(n_chunks[0]), range(n_chunks[1]), range(n_chunks[2])):
        lo = tuple(i * b for i, b in zip(idx, blk))
        hi = tuple(min(l + b, v) for l, b, v in zip(lo, blk, vol))
        grid.append(ChunkRef(index=idx, start=lo, stop=hi))
    return grid
43+
44+
45+
def normalize_crop_pad(value: Any) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int]]:
    """Normalize a crop/pad spec to per-axis ``(before, after)`` pairs for (z, y, x).

    Accepts ``None`` or an empty sequence (no cropping), a length-3 sequence of
    symmetric per-axis amounts, or a length-6 sequence of explicit
    before/after amounts.

    Raises:
        ValueError: if *value* has any other length.
    """
    # Explicit None/emptiness checks instead of `value in (None, [], ())`: the
    # membership test compares with `==`, which is fragile for array-like
    # config values (e.g. numpy arrays raise on ambiguous truth value).
    if value is None:
        return ((0, 0), (0, 0), (0, 0))
    values = [int(v) for v in value]
    if not values:
        return ((0, 0), (0, 0), (0, 0))
    if len(values) == 3:
        # Symmetric: same amount before and after on each axis.
        return tuple((v, v) for v in values)  # type: ignore[return-value]
    if len(values) == 6:
        return ((values[0], values[1]), (values[2], values[3]), (values[4], values[5]))
    raise ValueError(f"inference.crop_pad must have length 3 or 6, got {value!r}")
54+
55+
56+
def resolve_selected_affinity_offsets(cfg: Any) -> list[tuple[int, int, int]]:
    """Return the affinity offsets per output channel, honoring ``inference.select_channel``."""
    groups = resolve_affinity_channel_groups_from_cfg(cfg)
    if not groups:
        return []

    # Expand grouped (channel-span, offsets) pairs into one slot per channel.
    num_channels = max(span[1] for span, _ in groups)
    per_channel: list[tuple[int, int, int] | None] = [None] * num_channels
    for (first, last), group_offsets in groups:
        for channel, offset in zip(range(first, last), group_offsets):
            per_channel[channel] = offset

    # Optionally restrict to the channels the user selected for inference.
    selection = getattr(getattr(cfg, "inference", None), "select_channel", None)
    if selection is not None:
        indices = resolve_channel_indices(
            selection,
            num_channels=len(per_channel),
            context="inference.select_channel",
        )
        per_channel = [per_channel[i] for i in indices]

    return [offset for offset in per_channel if offset is not None]
77+
78+
79+
def resolve_global_prediction_crop(
    cfg: Any,
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int]]:
    """Combine the user ``inference.crop_pad`` with the affinity-derived crop.

    Returns per-axis ``(before, after)`` pads for (z, y, x): the elementwise
    sum of the normalized user crop and the crop implied by the selected
    affinity offsets (zero when no affinity mode/offsets are configured).
    """
    zero = ((0, 0), (0, 0), (0, 0))
    user_crop = normalize_crop_pad(getattr(getattr(cfg, "inference", None), "crop_pad", None))

    mode = resolve_affinity_mode_from_cfg(cfg)
    if mode is None:
        affinity_crop = zero
    else:
        offsets = resolve_selected_affinity_offsets(cfg)
        affinity_crop = compute_affinity_crop_pad(offsets, affinity_mode=mode) if offsets else zero

    # Elementwise sum of the two (before, after) pairs on each axis.
    combined = []
    for axis in range(3):
        before = int(user_crop[axis][0]) + int(affinity_crop[axis][0])
        after = int(user_crop[axis][1]) + int(affinity_crop[axis][1])
        combined.append((before, after))
    return tuple(combined)  # type: ignore[return-value]
100+
101+
102+
def validate_chunked_output_format(cfg: Any) -> None:
    """Raise ``ValueError`` unless every requested output format is HDF5.

    Chunked inference streams one HDF5 file, so only 'h5'/'hdf5' (any case)
    are accepted in ``inference.save_prediction.output_formats``; the default
    when unset is ``["h5"]``.
    """
    save_cfg = getattr(getattr(cfg, "inference", None), "save_prediction", None)
    requested = getattr(save_cfg, "output_formats", ["h5"])
    allowed = {"h5", "hdf5"}
    unsupported_formats = [fmt for fmt in (str(f).lower() for f in requested) if fmt not in allowed]
    if unsupported_formats:
        raise ValueError(
            "Chunked inference writes a single streamed HDF5 output only; "
            f"unsupported save_prediction.output_formats={unsupported_formats}."
        )
111+
112+
113+
def resolve_chunk_shape(cfg: Any, final_shape: Sequence[int]) -> tuple[int, int, int]:
    """Resolve the processing chunk shape from ``inference.chunking``.

    With ``axes='z'`` the volume is sliced into z-slabs spanning full y/x
    planes; with ``axes='all'`` (the default) each configured chunk dimension
    is clamped to the volume extent.
    """
    chunking_cfg = cfg.inference.chunking
    requested = tuple(int(v) for v in chunking_cfg.chunk_size)
    axes_mode = str(getattr(chunking_cfg, "axes", "all")).lower()
    if axes_mode == "z":
        # Slab mode: chunk along z only, keep full spatial planes.
        return (requested[0], int(final_shape[1]), int(final_shape[2]))
    if axes_mode == "all":
        return tuple(min(requested[axis], int(final_shape[axis])) for axis in range(3))
    raise ValueError("inference.chunking.axes must be 'all' or 'z'")
122+
123+
124+
def resolve_h5_spatial_chunks(spatial_shape: Sequence[int]) -> tuple[int, int, int]:
    """Pick HDF5 dataset chunking: 64**3 blocks, clamped to the volume extent."""
    preferred_block = 64
    return tuple(min(int(spatial_shape[axis]), preferred_block) for axis in range(3))
127+
128+
129+
def resolve_chunk_output_mode(cfg: Any) -> str:
    """Return the chunked-output mode: 'decoded' (default) or 'raw_prediction'."""
    configured = getattr(cfg.inference.chunking, "output_mode", "decoded")
    mode = str(configured).lower()
    if mode in ("decoded", "raw_prediction"):
        return mode
    raise ValueError("inference.chunking.output_mode must be 'decoded' or 'raw_prediction'.")
135+
136+
137+
# Public API of the chunk-grid helper module.
__all__ = [
    "ChunkRef",
    "build_chunk_grid",
    "normalize_crop_pad",
    "resolve_selected_affinity_offsets",
    "resolve_global_prediction_crop",
    "validate_chunked_output_format",
    "resolve_chunk_shape",
    "resolve_h5_spatial_chunks",
    "resolve_chunk_output_mode",
]

0 commit comments

Comments
 (0)