Skip to content

Commit 8eed96b

Browse files
author
Donglai Wei
committed
Remove orphan v3 refactor code
1 parent ceeaa5d commit 8eed96b

10 files changed

Lines changed: 91 additions & 755 deletions

File tree

connectomics/data/datasets/sampling.py

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -89,80 +89,7 @@ def compute_total_samples(
8989
return total_samples, samples_per_volume
9090

9191

92-
def calculate_inference_grid(
93-
volume_shape: Tuple[int, int, int],
94-
patch_size: Tuple[int, int, int],
95-
stride: Tuple[int, int, int],
96-
) -> Tuple[np.ndarray, Tuple[int, int, int]]:
97-
"""
98-
Calculate grid of patch positions for sliding-window inference.
99-
100-
This function generates all patch positions needed to cover a volume
101-
with overlapping patches, using the specified stride.
102-
103-
Args:
104-
volume_shape: Shape of the input volume (D, H, W)
105-
patch_size: Size of each patch (D, H, W)
106-
stride: Stride between patch centers (D, H, W)
107-
108-
Returns:
109-
positions: Array of shape (N, 3) containing (z, y, x) start positions
110-
grid_shape: Tuple (num_z, num_y, num_x) indicating grid dimensions
111-
112-
Examples:
113-
>>> volume_shape = (256, 256, 256)
114-
>>> patch_size = (128, 128, 128)
115-
>>> stride = (64, 64, 64)
116-
>>> positions, grid = calculate_inference_grid(volume_shape, patch_size, stride)
117-
>>> print(f"Grid shape: {grid}")
118-
>>> # Grid shape: (3, 3, 3)
119-
>>> print(f"Total patches: {len(positions)}")
120-
>>> # Total patches: 27
121-
122-
Note:
123-
The last patch in each dimension is "tucked in" to ensure it fits
124-
within the volume boundaries, matching the legacy v1 behavior.
125-
"""
126-
volume_shape = np.array(volume_shape)
127-
patch_size = np.array(patch_size)
128-
stride = np.array(stride)
129-
130-
# Calculate grid dimensions
131-
grid_shape = count_volume(volume_shape, patch_size, stride)
132-
grid_shape = tuple(grid_shape)
133-
134-
positions = []
135-
136-
# Generate all grid positions
137-
for z_idx in range(grid_shape[0]):
138-
for y_idx in range(grid_shape[1]):
139-
for x_idx in range(grid_shape[2]):
140-
# Calculate position with boundary handling
141-
# Normal case: multiply by stride
142-
# Boundary case: tuck in to ensure patch fits
143-
z = (
144-
z_idx * stride[0]
145-
if z_idx < grid_shape[0] - 1
146-
else volume_shape[0] - patch_size[0]
147-
)
148-
y = (
149-
y_idx * stride[1]
150-
if y_idx < grid_shape[1] - 1
151-
else volume_shape[1] - patch_size[1]
152-
)
153-
x = (
154-
x_idx * stride[2]
155-
if x_idx < grid_shape[2] - 1
156-
else volume_shape[2] - patch_size[2]
157-
)
158-
159-
positions.append([z, y, x])
160-
161-
return np.array(positions, dtype=np.int32), grid_shape
162-
163-
16492
__all__ = [
16593
"count_volume",
16694
"compute_total_samples",
167-
"calculate_inference_grid",
16895
]

connectomics/data/processing/bbox.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
from __future__ import annotations
22

33
import itertools
4-
from collections import OrderedDict
54
from typing import Optional, Tuple, Union
65

76
import numpy as np
87

98
__all__ = [
109
"bbox_ND",
1110
"bbox_relax",
12-
"adjust_bbox",
13-
"index2bbox",
1411
"crop_ND",
1512
"replace_ND",
16-
"rand_window",
1713
"compute_bbox_all",
1814
]
1915

@@ -53,27 +49,6 @@ def bbox_relax(coord: Union[tuple, list], shape: tuple, relax: int = 0) -> tuple
5349
return tuple(coord)
5450

5551

56-
def adjust_bbox(low, high, sz):
57-
if high < low:
58-
raise ValueError(f"high ({high}) must be >= low ({low})")
59-
bbox_sz = high - low
60-
diff = abs(sz - bbox_sz) // 2
61-
if bbox_sz >= sz:
62-
return low + diff, low + diff + sz
63-
64-
return low - diff, low - diff + sz
65-
66-
67-
def index2bbox(seg: np.ndarray, indices: list, relax: int = 0, iterative: bool = True) -> dict:
68-
"""Calculate the bounding boxes associated with the given mask indices."""
69-
bbox_dict = OrderedDict()
70-
for idx in indices:
71-
temp = seg == idx
72-
bbox = bbox_ND(temp, relax=relax)
73-
bbox_dict[idx] = bbox
74-
return bbox_dict
75-
76-
7752
def _coord2slice(coord: Tuple[int], ndim: int, end_included: bool = False):
7853
if len(coord) != ndim * 2:
7954
raise ValueError(f"Expected {ndim * 2} coordinates for {ndim}D array, got {len(coord)}")
@@ -116,28 +91,6 @@ def replace_ND(
11691
return img.copy()
11792

11893

119-
def rand_window(w0, w1, sz, rand_shift: int = 0):
120-
if w1 < w0:
121-
raise ValueError(f"w1 ({w1}) must be >= w0 ({w0})")
122-
diff = np.abs((w1 - w0) - sz)
123-
if (w1 - w0) <= sz:
124-
if rand_shift > 0: # random shift augmentation
125-
start_l = max(w0 - diff // 2 - rand_shift, w1 - sz)
126-
start_r = min(w0, w0 - diff // 2 + rand_shift)
127-
low = np.random.randint(start_l, start_r)
128-
else:
129-
low = w0 - diff // 2
130-
else:
131-
if rand_shift > 0: # random shift augmentation
132-
start_l = max(w0, w0 + diff // 2 - rand_shift)
133-
start_r = min(w0 + diff // 2 + rand_shift, w1 - sz)
134-
low = np.random.randint(start_l, start_r)
135-
else:
136-
low = w0 + diff // 2
137-
high = low + sz
138-
return low, high
139-
140-
14194
def compute_bbox_all(
14295
seg: np.ndarray, do_count: bool = False, uid: Optional[np.ndarray] = None
14396
) -> Optional[np.ndarray]:

connectomics/data/processing/blend.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

connectomics/decoding/__init__.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,16 @@
66
tuning/ - Hyperparameter tuning for decoding parameters
77
"""
88

9+
from importlib import import_module
10+
911
# --- Framework / Infrastructure ---
1012
from .base import DecodeStep
11-
12-
# --- Segmentation Decoders ---
13-
from .decoders import (
14-
branch_merge,
15-
decode_abiss,
16-
decode_affinity_cc,
17-
decode_distance_watershed,
18-
decode_instance_binary_contour_distance,
19-
decode_waterz,
20-
polarity2instance,
21-
)
2213
from .pipeline import (
2314
apply_decode_mode,
2415
apply_decode_pipeline,
2516
normalize_decode_modes,
2617
resolve_decode_modes_from_cfg,
2718
)
28-
from .stage import (
29-
DecodingStageResult,
30-
apply_decoding_postprocessing,
31-
run_decoding_stage,
32-
)
3319

3420
# --- Post-processing & Utilities ---
3521
from .postprocessing import (
@@ -49,15 +35,38 @@
4935
register_builtin_decoders,
5036
register_decoder,
5137
)
52-
38+
from .stage import (
39+
DecodingStageResult,
40+
apply_decoding_postprocessing,
41+
run_decoding_stage,
42+
)
5343
from .utils import (
5444
cast2dtype,
5545
merge_small_objects,
5646
remove_large_instances,
5747
remove_small_instances,
5848
)
5949

60-
register_builtin_decoders()
50+
_LAZY_DECODER_EXPORTS = {
51+
"branch_merge": "connectomics.decoding.decoders.branch_merge",
52+
"decode_abiss": "connectomics.decoding.decoders.abiss",
53+
"decode_affinity_cc": "connectomics.decoding.decoders.segmentation",
54+
"decode_distance_watershed": "connectomics.decoding.decoders.segmentation",
55+
"decode_instance_binary_contour_distance": "connectomics.decoding.decoders.segmentation",
56+
"decode_waterz": "connectomics.decoding.decoders.waterz",
57+
"polarity2instance": "connectomics.decoding.decoders.synapse",
58+
}
59+
60+
61+
def __getattr__(name: str):
62+
module_name = _LAZY_DECODER_EXPORTS.get(name)
63+
if module_name is None:
64+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
65+
66+
value = getattr(import_module(module_name), name)
67+
globals()[name] = value
68+
return value
69+
6170

6271
__all__ = [
6372
# Registry / pipeline

connectomics/decoding/decoders/__init__.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,27 @@
11
"""Segmentation decoder implementations."""
22

3-
from .abiss import decode_abiss
4-
from .branch_merge import branch_merge
5-
from .segmentation import (
6-
decode_affinity_cc,
7-
decode_distance_watershed,
8-
decode_instance_binary_contour_distance,
9-
)
10-
from .synapse import polarity2instance
11-
from .waterz import decode_waterz
3+
from importlib import import_module
4+
5+
_LAZY_DECODERS = {
6+
"branch_merge": "connectomics.decoding.decoders.branch_merge",
7+
"decode_abiss": "connectomics.decoding.decoders.abiss",
8+
"decode_affinity_cc": "connectomics.decoding.decoders.segmentation",
9+
"decode_distance_watershed": "connectomics.decoding.decoders.segmentation",
10+
"decode_instance_binary_contour_distance": "connectomics.decoding.decoders.segmentation",
11+
"decode_waterz": "connectomics.decoding.decoders.waterz",
12+
"polarity2instance": "connectomics.decoding.decoders.synapse",
13+
}
14+
15+
16+
def __getattr__(name: str):
17+
module_name = _LAZY_DECODERS.get(name)
18+
if module_name is None:
19+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
20+
21+
value = getattr(import_module(module_name), name)
22+
globals()[name] = value
23+
return value
24+
1225

1326
__all__ = [
1427
"decode_instance_binary_contour_distance",

connectomics/decoding/pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from ..utils.channel_slices import resolve_channel_indices
1111
from .base import DecodeStep
12-
from .registry import DEFAULT_DECODER_REGISTRY, DecoderRegistry
12+
from .registry import DEFAULT_DECODER_REGISTRY, DecoderRegistry, ensure_builtin_decoders_registered
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -87,7 +87,9 @@ def apply_decode_pipeline(
8787
if not decode_modes:
8888
return data
8989

90-
registry = registry or DEFAULT_DECODER_REGISTRY
90+
if registry is None:
91+
ensure_builtin_decoders_registered()
92+
registry = DEFAULT_DECODER_REGISTRY
9193
steps = [s for s in normalize_decode_modes(decode_modes) if s.enabled]
9294
if not steps:
9395
return data

connectomics/decoding/registry.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def available(self) -> List[str]:
4141

4242

4343
DEFAULT_DECODER_REGISTRY = DecoderRegistry()
44+
_BUILTINS_REGISTERED = False
4445

4546

4647
def register_decoder(name: str, fn: DecodeFunction, *, overwrite: bool = False) -> None:
@@ -50,16 +51,27 @@ def register_decoder(name: str, fn: DecodeFunction, *, overwrite: bool = False)
5051

5152
def get_decoder(name: str) -> DecodeFunction:
5253
"""Get decoder from the default registry."""
54+
ensure_builtin_decoders_registered()
5355
return DEFAULT_DECODER_REGISTRY.get(name)
5456

5557

5658
def list_decoders() -> List[str]:
5759
"""List names in the default registry."""
60+
ensure_builtin_decoders_registered()
5861
return DEFAULT_DECODER_REGISTRY.available()
5962

6063

64+
def ensure_builtin_decoders_registered() -> None:
65+
"""Register built-in decoders on first use."""
66+
register_builtin_decoders()
67+
68+
6169
def register_builtin_decoders() -> None:
6270
"""Populate registry with built-in decoders."""
71+
global _BUILTINS_REGISTERED
72+
if _BUILTINS_REGISTERED:
73+
return
74+
6375
from .decoders.abiss import decode_abiss
6476
from .decoders.segmentation import (
6577
decode_affinity_cc,
@@ -79,3 +91,4 @@ def register_builtin_decoders() -> None:
7991
register_decoder("decode_waterz", decode_waterz, overwrite=True)
8092
register_decoder("decode_abiss", decode_abiss, overwrite=True)
8193
register_decoder("polarity2instance", polarity2instance, overwrite=True)
94+
_BUILTINS_REGISTERED = True

0 commit comments

Comments (0)