Skip to content

Commit 12c6dab

Browse files
author
Donglai Wei
committed
Finalize raw prediction artifact stage
1 parent 145a695 commit 12c6dab

12 files changed

Lines changed: 372 additions & 143 deletions

File tree

connectomics/data/processing/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Core processing functions
22
from .bbox import bbox_ND, crop_ND, replace_ND
33
from .bbox_processor import BBoxInstanceProcessor, BBoxProcessorConfig
4-
from .iou import seg_to_iou, segs_to_iou
54

65
# Pipeline builder (primary entry point for label transforms)
76
from .build import create_label_transform_pipeline
7+
from .iou import seg_to_iou, segs_to_iou
88

99
# Utility functions used by decoding
1010
from .misc import get_seg_type
1111

1212
# MONAI-native transforms and composition
13-
from .nnunet_preprocess import NNUNetPreprocessd
13+
from .nnunet_preprocess import NNUNetPreprocessd, restore_prediction_to_input_space
1414
from .transforms import (
1515
ComputeBinaryRatioWeightd,
1616
ComputeUNet3DWeightd,
@@ -61,6 +61,7 @@
6161
"SegSelectiond",
6262
"MultiTaskLabelTransformd",
6363
"NNUNetPreprocessd",
64+
"restore_prediction_to_input_space",
6465
# Pipelines
6566
"create_label_transform_pipeline",
6667
# IoU

connectomics/data/processing/nnunet_preprocess.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,135 @@
2222
logger = logging.getLogger(__name__)
2323

2424

25+
def _infer_spatial_dims_from_array(array: np.ndarray) -> int:
26+
if array.ndim <= 2:
27+
return array.ndim
28+
if array.ndim == 3:
29+
return 3
30+
return array.ndim - 1
31+
32+
33+
def _spatial_shape(array: np.ndarray, spatial_dims: int) -> tuple:
34+
if array.ndim == spatial_dims:
35+
return tuple(int(v) for v in array.shape)
36+
return tuple(int(v) for v in array.shape[-spatial_dims:])
37+
38+
39+
def _resample_array_to_shape(
40+
array: np.ndarray,
41+
target_shape: Sequence[int],
42+
spatial_dims: int,
43+
order: int,
44+
) -> np.ndarray:
45+
from scipy.ndimage import zoom
46+
47+
target = tuple(int(v) for v in target_shape)
48+
if _spatial_shape(array, spatial_dims) == target:
49+
return array
50+
51+
def _zoom_single(vol: np.ndarray) -> np.ndarray:
52+
factors = np.array(target, dtype=np.float32) / np.maximum(
53+
np.array(vol.shape, dtype=np.float32), 1.0
54+
)
55+
return zoom(
56+
vol.astype(np.float32, copy=False),
57+
zoom=factors,
58+
order=order,
59+
mode="nearest",
60+
prefilter=order > 1,
61+
)
62+
63+
if array.ndim == spatial_dims + 1:
64+
channels = [_zoom_single(array[c])[None] for c in range(array.shape[0])]
65+
return np.vstack(channels).astype(array.dtype, copy=False)
66+
if array.ndim == spatial_dims:
67+
return _zoom_single(array).astype(array.dtype, copy=False)
68+
return array
69+
70+
71+
def _fit_array_to_shape(
72+
array: np.ndarray, target_shape: Sequence[int], spatial_dims: int
73+
) -> np.ndarray:
74+
target = tuple(int(v) for v in target_shape)
75+
if _spatial_shape(array, spatial_dims) == target:
76+
return array
77+
78+
if array.ndim == spatial_dims + 1:
79+
out = np.zeros((array.shape[0], *target), dtype=array.dtype)
80+
in_shape = array.shape[1:]
81+
write_shape = tuple(min(int(in_shape[d]), target[d]) for d in range(spatial_dims))
82+
out_slices = (slice(None),) + tuple(slice(0, w) for w in write_shape)
83+
in_slices = (slice(None),) + tuple(slice(0, w) for w in write_shape)
84+
out[out_slices] = array[in_slices]
85+
return out
86+
87+
if array.ndim == spatial_dims:
88+
out = np.zeros(target, dtype=array.dtype)
89+
in_shape = array.shape
90+
write_shape = tuple(min(int(in_shape[d]), target[d]) for d in range(spatial_dims))
91+
out_slices = tuple(slice(0, w) for w in write_shape)
92+
in_slices = tuple(slice(0, w) for w in write_shape)
93+
out[out_slices] = array[in_slices]
94+
return out
95+
96+
return array
97+
98+
99+
def restore_prediction_to_input_space(sample: np.ndarray, meta: Dict[str, Any]) -> np.ndarray:
    """Invert nnU-Net preprocessing metadata for one prediction sample.

    Undoes, in reverse order of the forward pipeline, the stages recorded
    under ``meta["nnunet_preprocess"]``: resampling back to the pre-resample
    (cropped) shape, re-embedding the cropped region into the original
    spatial extent, and finally inverting any axis transpose. Each stage is
    applied only when its flag/fields are present and well-formed; otherwise
    it is silently skipped, so a partially populated metadata dict degrades
    to a partial (best-effort) inversion.

    Args:
        sample: Prediction array, either purely spatial (rank == spatial_dims)
            or channel-first (rank == spatial_dims + 1).
        meta: Sample metadata dict; only the ``"nnunet_preprocess"`` entry is
            consulted.

    Returns:
        The restored array, or *sample* unchanged when preprocessing metadata
        is absent or disabled.
    """
    preprocess_meta = meta.get("nnunet_preprocess")
    # No (or disabled) preprocessing record: nothing to invert.
    if not isinstance(preprocess_meta, dict) or not preprocess_meta.get("enabled", False):
        return sample

    array = sample
    spatial_dims = int(preprocess_meta.get("spatial_dims", _infer_spatial_dims_from_array(array)))
    is_integer = np.issubdtype(array.dtype, np.integer)
    # Integer arrays (label maps) use nearest-neighbour; floats use linear.
    interp_order = 0 if is_integer else 1

    # Stage 1: undo resampling — bring the array back to the spatial shape it
    # had right after cropping (before the forward resample).
    if preprocess_meta.get("applied_resample", False):
        cropped_shape = preprocess_meta.get("cropped_spatial_shape")
        if isinstance(cropped_shape, (list, tuple)) and len(cropped_shape) == spatial_dims:
            array = _resample_array_to_shape(
                array,
                target_shape=cropped_shape,
                spatial_dims=spatial_dims,
                order=interp_order,
            )

    # Stage 2: undo cropping — fit the array to the bbox extent, then paste it
    # into a zero background of the original spatial shape at the bbox offset.
    if preprocess_meta.get("applied_crop", False):
        bbox = preprocess_meta.get("crop_bbox")
        original_shape = preprocess_meta.get("original_spatial_shape")
        if (
            isinstance(bbox, (list, tuple))
            and isinstance(original_shape, (list, tuple))
            and len(bbox) == spatial_dims
            and len(original_shape) == spatial_dims
        ):
            # bbox is per-axis (start, stop) pairs; its extent is the paste size.
            crop_target_shape = tuple(int(b[1]) - int(b[0]) for b in bbox)
            array = _fit_array_to_shape(array, crop_target_shape, spatial_dims=spatial_dims)

            if array.ndim == spatial_dims + 1:
                # Channel-first: keep channel axis, embed each spatial slab.
                restored = np.zeros((array.shape[0], *original_shape), dtype=array.dtype)
                slices = tuple(slice(int(b[0]), int(b[1])) for b in bbox)
                restored[(slice(None), *slices)] = array
            else:
                restored = np.zeros(tuple(int(v) for v in original_shape), dtype=array.dtype)
                slices = tuple(slice(int(b[0]), int(b[1])) for b in bbox)
                restored[slices] = array
            array = restored

    # Stage 3: undo the axis transpose. argsort of the forward permutation is
    # its inverse permutation.
    transpose_axes = preprocess_meta.get("transpose_axes")
    if isinstance(transpose_axes, (list, tuple)) and len(transpose_axes) == spatial_dims:
        inverse_axes = np.argsort(np.asarray(transpose_axes))
        if array.ndim == spatial_dims + 1:
            # Shift spatial axis indices by one to account for the channel axis.
            perm = [0] + [int(i) + 1 for i in inverse_axes]
            array = np.transpose(array, perm)
        elif array.ndim == spatial_dims:
            array = np.transpose(array, tuple(int(i) for i in inverse_axes))

    return array
152+
153+
25154
class NNUNetPreprocessd(MapTransform):
26155
"""nnU-Net style preprocessing transform."""
27156

connectomics/inference/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .artifact import (
44
PredictionArtifactMetadata,
5+
build_prediction_artifact_metadata,
56
read_prediction_artifact,
67
write_prediction_artifact,
78
write_prediction_artifact_attrs,
@@ -29,6 +30,7 @@
2930
__all__ = [
3031
"InferenceManager",
3132
"PredictionArtifactMetadata",
33+
"build_prediction_artifact_metadata",
3234
"read_prediction_artifact",
3335
"write_prediction_artifact",
3436
"write_prediction_artifact_attrs",

connectomics/inference/artifact.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
from dataclasses import asdict, dataclass, field
77
from pathlib import Path
8-
from typing import Any, Callable, Mapping
8+
from typing import Any, Callable, Mapping, Sequence
99

1010
import numpy as np
1111

@@ -22,6 +22,10 @@ class PredictionArtifactMetadata:
2222
input_shape: tuple[int, ...] | None = None
2323
final_shape: tuple[int, ...] | None = None
2424
crop_pad: tuple[tuple[int, int], ...] | None = None
25+
transpose: tuple[int, ...] | None = None
26+
model_architecture: str | None = None
27+
model_output_identity: str | None = None
28+
decode_after_inference: bool | None = None
2529
chunk_shape: tuple[int, ...] | None = None
2630
halo: tuple[int, ...] | None = None
2731
channel_order: tuple[str, ...] | None = None
@@ -31,6 +35,86 @@ class PredictionArtifactMetadata:
3135
extra: Mapping[str, Any] = field(default_factory=dict)
3236

3337

38+
def _cfg_get(obj: Any, path: str, default: Any = None) -> Any:
39+
node = obj
40+
for part in path.split("."):
41+
if node is None:
42+
return default
43+
if isinstance(node, Mapping):
44+
node = node.get(part, default)
45+
else:
46+
node = getattr(node, part, default)
47+
return node
48+
49+
50+
def _tuple_or_none(value: Sequence[Any] | None) -> tuple[int, ...] | None:
51+
if value in (None, [], ()):
52+
return None
53+
return tuple(int(v) for v in value)
54+
55+
56+
def _model_output_identity(cfg: Any, output_head: str | None) -> str | None:
    """Describe which model output an artifact holds as a ``key=value;...`` string.

    An explicit *output_head* takes precedence over the configured primary
    head; a configured channel selection is always appended. Returns ``None``
    when neither piece of information is available.
    """
    tokens: list[str] = []

    if output_head:
        tokens.append(f"head={output_head}")
    else:
        primary = _cfg_get(cfg, "model.primary_head")
        if primary:
            tokens.append(f"primary_head={primary}")

    channel = _cfg_get(cfg, "inference.select_channel")
    if channel is not None:
        tokens.append(f"select_channel={channel}")

    if not tokens:
        return None
    return ";".join(tokens)
70+
71+
72+
def build_prediction_artifact_metadata(
    cfg: Any,
    *,
    image_path: str | None = None,
    checkpoint_path: str | None = None,
    output_head: str | None = None,
    input_shape: Sequence[int] | None = None,
    final_shape: Sequence[int] | None = None,
    crop_pad: Sequence[Sequence[int]] | None = None,
    chunk_shape: Sequence[int] | None = None,
    halo: Sequence[int] | None = None,
    intensity_scale: float | None = None,
    intensity_dtype: str | None = None,
    extra: Mapping[str, Any] | None = None,
) -> PredictionArtifactMetadata:
    """Build standard metadata for raw prediction artifacts.

    Combines caller-supplied facts (paths, shapes, chunking) with values read
    from *cfg* (transpose, architecture, output identity, decode flag) into a
    ``PredictionArtifactMetadata``. Explicitly passed ``intensity_scale`` /
    ``intensity_dtype`` win; when they are ``None`` and the configured
    prediction transform is enabled, the configured values are used instead
    (scale defaults to ``-1.0`` when the transform declares none).

    Args:
        cfg: Config object or nested mapping; read via dotted-path lookups.
        image_path: Source image path recorded verbatim.
        checkpoint_path: Model checkpoint path; stringified when given.
        output_head: Name of the model head this artifact was produced from.
        input_shape / final_shape: Spatial shapes before/after inference.
        crop_pad: Per-axis (before, after) crop/pad pairs.
        chunk_shape / halo: Chunked-inference geometry.
        intensity_scale / intensity_dtype: Optional overrides for the
            prediction-transform intensity settings.
        extra: Free-form extra metadata; defaults to an empty dict.
    """
    transform_cfg = _cfg_get(cfg, "inference.prediction_transform")
    # NOTE(review): the transform node is read with getattr here, while
    # _cfg_get also supports Mapping nodes — a dict-based config would make
    # these getattr calls fall back to their defaults. Confirm configs are
    # attribute-style at this point.
    transform_enabled = bool(getattr(transform_cfg, "enabled", False))
    if intensity_scale is None and transform_enabled:
        intensity_scale = float(getattr(transform_cfg, "intensity_scale", -1.0))
    if intensity_dtype is None and transform_enabled:
        intensity_dtype = getattr(transform_cfg, "intensity_dtype", None)

    return PredictionArtifactMetadata(
        image_path=image_path,
        checkpoint_path=str(checkpoint_path) if checkpoint_path is not None else None,
        output_head=output_head,
        input_shape=_tuple_or_none(input_shape),
        final_shape=_tuple_or_none(final_shape),
        crop_pad=(
            tuple((int(pair[0]), int(pair[1])) for pair in crop_pad)
            if crop_pad is not None
            else None
        ),
        # Transpose applied to validation data; recorded so readers can invert it.
        transpose=_tuple_or_none(_cfg_get(cfg, "data.data_transform.val_transpose")),
        model_architecture=_cfg_get(cfg, "model.arch.type"),
        model_output_identity=_model_output_identity(cfg, output_head),
        # Decoding defaults to enabled when the config does not say otherwise.
        decode_after_inference=bool(_cfg_get(cfg, "inference.decode_after_inference", True)),
        chunk_shape=_tuple_or_none(chunk_shape),
        halo=_tuple_or_none(halo),
        intensity_scale=intensity_scale,
        intensity_dtype=intensity_dtype,
        extra=extra or {},
    )
116+
117+
34118
def _json_attr(value: Any) -> Any:
35119
if value is None or isinstance(value, (str, int, float, bool)):
36120
return value
@@ -140,6 +224,7 @@ def read_prediction_artifact(
140224

141225
__all__ = [
142226
"PredictionArtifactMetadata",
227+
"build_prediction_artifact_metadata",
143228
"read_prediction_artifact",
144229
"write_prediction_artifact",
145230
"write_prediction_artifact_attrs",

connectomics/inference/chunked.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from ..utils.channel_slices import resolve_channel_indices
2121
from .artifact import (
22-
PredictionArtifactMetadata,
22+
build_prediction_artifact_metadata,
2323
write_prediction_artifact,
2424
)
2525
from .lazy import get_lazy_image_reference_shape, lazy_predict_region
@@ -278,7 +278,8 @@ def _run_chunked_prediction_per_rank(
278278
write_prediction_artifact(
279279
chunk_path,
280280
core_pred,
281-
metadata=PredictionArtifactMetadata(
281+
metadata=build_prediction_artifact_metadata(
282+
cfg,
282283
image_path=str(image_path),
283284
checkpoint_path=str(checkpoint_path) if checkpoint_path is not None else None,
284285
output_head=requested_head,
@@ -479,7 +480,8 @@ def write_chunks(dataset) -> None:
479480

480481
write_prediction_artifact(
481482
output_path,
482-
metadata=PredictionArtifactMetadata(
483+
metadata=build_prediction_artifact_metadata(
484+
cfg,
483485
image_path=str(image_path),
484486
checkpoint_path=str(checkpoint_path) if checkpoint_path is not None else None,
485487
output_head=requested_head,

0 commit comments

Comments
 (0)