Skip to content

Commit ceeaa5d

Browse files
author
Donglai Wei
committed
Add v3 refactor guardrails
1 parent ba0f482 commit ceeaa5d

2 files changed

Lines changed: 199 additions & 0 deletions

File tree

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""Smoke benchmark for chunked raw-prediction writes.
2+
3+
The assertion is intentionally broad: this is a regression tripwire for the V3
4+
chunked artifact refactor, not a statistically rigorous benchmark.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import time
10+
11+
import h5py
12+
import numpy as np
13+
import torch
14+
15+
from connectomics.config import Config
16+
from connectomics.data.io import write_hdf5
17+
from connectomics.inference.chunked import run_chunked_prediction_inference
18+
19+
20+
def _identity_forward(x: torch.Tensor) -> torch.Tensor:
21+
return x
22+
23+
24+
def test_chunked_prediction_inference_smoke_throughput(tmp_path):
    """Run a tiny chunked raw-prediction inference end to end on CPU.

    The wall-clock ceiling is intentionally generous — this is a regression
    tripwire, not a benchmark — and the shape check confirms the chunked
    writer produced the full output volume.
    """
    config = Config()
    # Minimal sliding-window setup: 3x3x3 windows, no overlap, no blending.
    config.data.image_transform.normalize = "none"
    config.data.dataloader.patch_size = [3, 3, 3]
    config.data.dataloader.batch_size = 1
    config.model.output_size = [3, 3, 3]
    config.inference.strategy = "chunked"
    config.inference.sliding_window.window_size = [3, 3, 3]
    config.inference.sliding_window.overlap = 0.0
    config.inference.sliding_window.blending = "constant"
    config.inference.sliding_window.snap_to_edge = True
    # Chunked raw-prediction writes with chunks smaller than the volume.
    config.inference.chunking.enabled = True
    config.inference.chunking.output_mode = "raw_prediction"
    config.inference.chunking.chunk_size = [2, 4, 4]
    config.inference.chunking.halo = [0, 0, 0]

    input_file = tmp_path / "input.h5"
    prediction_file = tmp_path / "prediction.h5"
    voxels = np.arange(4 * 5 * 6, dtype=np.float32).reshape(4, 5, 6)
    write_hdf5(str(input_file), voxels, dataset="main")

    t0 = time.perf_counter()
    run_chunked_prediction_inference(
        config,
        _identity_forward,
        str(input_file),
        output_path=prediction_file,
        device="cpu",
    )

    # Broad bound on purpose: only pathological slowdowns should trip it.
    assert time.perf_counter() - t0 < 10.0
    with h5py.File(prediction_file, "r") as handle:
        assert handle["main"].shape == (1, 4, 5, 6)

tests/unit/test_v3_guardrails.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""V3 refactor guardrails.
2+
3+
These tests document the intended package boundaries before the implementation
4+
stages move code. Known current violations are marked strict xfail so later
5+
stages must either fix and un-xfail them or keep the debt visible.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import ast
11+
from pathlib import Path
12+
13+
import pytest
14+
15+
# Repository root: this file lives at <repo>/tests/unit/, two levels down.
REPO_ROOT = Path(__file__).resolve().parents[2]
16+
17+
18+
def _module_name_for_path(path: Path) -> str:
19+
relative = path.relative_to(REPO_ROOT).with_suffix("")
20+
parts = list(relative.parts)
21+
if parts[-1] == "__init__":
22+
parts = parts[:-1]
23+
return ".".join(parts)
24+
25+
26+
def _resolve_import_from(module_name: str, node: ast.ImportFrom) -> str:
27+
if node.level == 0:
28+
return node.module or ""
29+
30+
package_parts = module_name.split(".")
31+
if path_name := package_parts[-1]:
32+
if path_name != "__init__":
33+
package_parts = package_parts[:-1]
34+
base_parts = package_parts[: max(0, len(package_parts) - node.level + 1)]
35+
if node.module:
36+
base_parts.extend(node.module.split("."))
37+
return ".".join(base_parts)
38+
39+
40+
def _forbidden_imports(root: Path, forbidden_prefixes: tuple[str, ...]) -> list[str]:
41+
violations: list[str] = []
42+
for path in sorted(root.rglob("*.py")):
43+
module_name = _module_name_for_path(path)
44+
tree = ast.parse(path.read_text(), filename=str(path))
45+
for node in ast.walk(tree):
46+
imported_modules: list[str] = []
47+
if isinstance(node, ast.Import):
48+
imported_modules = [alias.name for alias in node.names]
49+
elif isinstance(node, ast.ImportFrom):
50+
imported_modules = [_resolve_import_from(module_name, node)]
51+
52+
for imported_module in imported_modules:
53+
if imported_module.startswith(forbidden_prefixes):
54+
rel = path.relative_to(REPO_ROOT)
55+
violations.append(f"{rel}:{node.lineno}: {imported_module}")
56+
return violations
57+
58+
59+
@pytest.mark.xfail(strict=True, reason="V3 PR 2/5 removes decoding -> training imports")
def test_decoding_static_imports_do_not_reference_training():
    """Boundary: connectomics.decoding must not import connectomics.training."""
    decoding_root = REPO_ROOT / "connectomics" / "decoding"
    violations = _forbidden_imports(decoding_root, ("connectomics.training",))
    assert violations == []
66+
67+
68+
@pytest.mark.xfail(strict=True, reason="V3 PR 6 moves streamed chunk decoding out of inference")
def test_inference_static_imports_do_not_reference_decoding():
    """Boundary: connectomics.inference must not import connectomics.decoding."""
    inference_root = REPO_ROOT / "connectomics" / "inference"
    violations = _forbidden_imports(inference_root, ("connectomics.decoding",))
    assert violations == []
75+
76+
77+
@pytest.mark.xfail(strict=True, reason="V3 PR 7 moves data-aware validation out of config")
def test_config_static_imports_do_not_reference_data_execution():
    """Boundary: connectomics.config must not import connectomics.data."""
    config_root = REPO_ROOT / "connectomics" / "config"
    violations = _forbidden_imports(config_root, ("connectomics.data",))
    assert violations == []
84+
85+
86+
@pytest.mark.xfail(strict=True, reason="V3 PR 3 makes unknown top-level keys hard errors")
def test_config_load_raises_on_unknown_top_level_key(tmp_path):
    """Loading a YAML with an unrecognized top-level section must raise."""
    from connectomics.config import load_config

    bogus_yaml = tmp_path / "unknown_key.yaml"
    bogus_yaml.write_text("unknown_section: {}\n")

    with pytest.raises(ValueError, match="unknown_section"):
        load_config(bogus_yaml)
95+
96+
97+
def test_connectomics_config_public_api_snapshot():
    """Pin ``connectomics.config.__all__`` so API changes are deliberate."""
    import connectomics.config as config

    expected = {
        "Config",
        "load_config",
        "save_config",
        "merge_configs",
        "update_from_cli",
        "to_dict",
        "from_dict",
        "print_config",
        "validate_config",
        "get_config_hash",
        "create_experiment_name",
        "resolve_data_paths",
        "resolve_default_profiles",
        "to_plain",
        "as_plain_dict",
        "cfg_get",
    }
    assert set(config.__all__) == expected
118+
119+
120+
def test_connectomics_inference_public_api_snapshot():
    """Pin ``connectomics.inference.__all__`` so API changes are deliberate."""
    import connectomics.inference as inference

    expected = {
        "InferenceManager",
        "PredictionArtifactMetadata",
        "read_prediction_artifact",
        "write_prediction_artifact",
        "write_prediction_artifact_attrs",
        "run_prediction_inference",
        "is_chunked_inference_enabled",
        "run_chunked_affinity_cc_inference",
        "run_chunked_prediction_inference",
        "apply_prediction_transform",
        "apply_storage_dtype_transform",
        "resolve_output_filenames",
        "write_outputs",
        "build_sliding_inferer",
        "resolve_inferer_roi_size",
        "resolve_inferer_overlap",
        "is_2d_inference_mode",
        "TTAPredictor",
    }
    assert set(inference.__all__) == expected

0 commit comments

Comments
 (0)