Skip to content

Commit 2d68de2

Browse files
author
Donglai Wei
committed
Store precomputed skeleton caches with source backend
1 parent 8076088 commit 2d68de2

3 files changed

Lines changed: 78 additions & 1 deletion

File tree

connectomics/data/io/io.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -388,7 +388,7 @@ def save_volume(
388 388
filename: Output filename or directory path.
389 389
volume: Volume data to save.
390 390
dataset: Dataset name for HDF5 format.
391-
file_format: 'h5', 'tiff', 'png', 'nii', 'nii.gz'.
391+
file_format: Optional override. If omitted, inferred from ``filename``.
392 392
"""
393 393
file_format = file_format or _detect_format(filename)
394 394

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
5+
from connectomics.config import Config
6+
from connectomics.data.io import save_volume
7+
from connectomics.training.lightning.data_factory import _maybe_precompute_label_aux
8+
9+
10+
def test_maybe_precompute_label_aux_reuses_existing_zarr_cache(monkeypatch, tmp_path: Path):
11+
label_path = tmp_path / "data.zarr" / "seg"
12+
aux_path = tmp_path / "data.zarr" / "seg_skeleton"
13+
save_volume(str(label_path), np.zeros((2, 2, 2), dtype=np.uint16), file_format="zarr")
14+
save_volume(str(aux_path), np.ones((2, 2, 2), dtype=np.uint16), file_format="zarr")
15+
16+
cfg = Config()
17+
cfg.data.label_transform.targets = [{"name": "skeleton_aware_edt", "kwargs": {}}]
18+
cfg.data.train.label_aux_type = "skeleton"
19+
20+
def fail(*args, **kwargs):
21+
raise AssertionError("existing zarr cache should be reused")
22+
23+
monkeypatch.setattr(
24+
"connectomics.data.processing.distance.precompute_skeleton_volume",
25+
fail,
26+
)
27+
28+
paths = _maybe_precompute_label_aux(
29+
cfg,
30+
cfg.data.train,
31+
[str(label_path)],
32+
split_name="train",
33+
)
34+
35+
assert paths == [str(aux_path)]

tests/unit/test_io_zarr.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pytest
5+
6+
from connectomics.data.io import get_vol_shape, read_volume, save_volume, volume_exists
7+
from connectomics.data.processing.distance import sdt_path_for_label
8+
9+
zarr = pytest.importorskip("zarr")
10+
11+
12+
def test_sdt_path_for_label_preserves_zarr_backend():
13+
assert sdt_path_for_label("/tmp/labels.h5", mode="skeleton") == "/tmp/labels_skeleton.h5"
14+
assert (
15+
sdt_path_for_label("/tmp/data.zarr/seg", mode="skeleton") == "/tmp/data.zarr/seg_skeleton"
16+
)
17+
assert (
18+
sdt_path_for_label("/tmp/data.zarr/group/seg", mode="sdt") == "/tmp/data.zarr/group/seg_sdt"
19+
)
20+
assert sdt_path_for_label("/tmp/data.zarr", mode="skeleton") == "/tmp/data_skeleton.zarr"
21+
22+
23+
def test_zarr_subkey_round_trip_and_shape(tmp_path: Path):
24+
path = tmp_path / "data.zarr" / "seg_skeleton"
25+
volume = np.arange(24, dtype=np.uint16).reshape(2, 3, 4)
26+
27+
save_volume(str(path), volume, file_format="zarr")
28+
29+
loaded = read_volume(str(path))
30+
np.testing.assert_array_equal(loaded, volume)
31+
assert get_vol_shape(str(path)) == volume.shape
32+
assert volume_exists(str(path))
33+
34+
35+
def test_detect_format_prefers_real_suffix_inside_zarr_directory(tmp_path: Path):
36+
h5_path = tmp_path / "data.zarr" / "seg_skeleton.h5"
37+
h5_path.parent.mkdir(parents=True, exist_ok=True)
38+
save_volume(str(h5_path), np.ones((2, 2, 2), dtype=np.uint8), file_format="h5")
39+
40+
loaded = read_volume(str(h5_path))
41+
42+
assert loaded.shape == (2, 2, 2)

0 commit comments

Comments
 (0)