
Commit 124317b

Donglai Wei and claude committed
Add patch-first local TTA, ABISS merge function arg, and prediction cache improvements
- Add patch_first_local TTA mode: slide once over volume, apply TTA augmentations locally inside each ROI batch to reduce redundant sliding-window passes
- Add --ws-merge-function arg to run_abiss_single.py for edge scoring (max, mean, p75, p90) in region graph construction
- Improve prediction cache lookup: try final _x{N}_prediction.h5 before intermediate _tta_x{N}_prediction.h5 files
- Skip redundant crop_pad/affinity crop on cached final predictions
- Use in-place activations (sigmoid_, tanh_) for contiguous channel slices
- Include TTA pass tag in evaluation metrics filename
- Update neuron_snemi.yaml with patch_first_local and p75 merge function

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 91ebc2d commit 124317b

10 files changed

Lines changed: 1099 additions & 205 deletions


.claude/reference/waterz.md

Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@
1+
# waterz - Watershed and Region Agglomeration Library
2+
3+
**Location:** `/projects/weilab/weidf/lib/waterz/`
4+
**Version:** 0.8
5+
**License:** MIT
6+
**Origin:** donglaiw's fork of [funkey/waterz](https://github.com/funkey/waterz), extended with CREMI scoring functions from Mala_v2.zip
7+
**Language:** Python + Cython + C++11 (Boost multi_array)
8+
**Dependencies:** cython, numpy, scipy, mahotas, boost (C++ headers)
9+
10+
## Purpose
11+
12+
Waterz ("water-zed") performs watershed segmentation and hierarchical region agglomeration on 3D affinity graphs. It is the core post-processing library for converting voxel-level affinity predictions (from neural networks) into instance segmentations of neurons/organelles in connectomics EM volumes.
13+
14+
## Installation
15+
16+
```bash
17+
conda create -n zw python==3.8 cython numpy
conda activate zw  # install into the new environment, not the base one
18+
pip install --editable /projects/weilab/weidf/lib/waterz
19+
# or: cd /projects/weilab/weidf/lib/waterz && python setup.py build_ext --inplace
20+
```
21+
22+
## Architecture Overview
23+
24+
```
25+
waterz/
26+
__init__.py # Public API: agglomerate(), waterz(), watershed(), etc.
27+
seg_waterz.py # High-level waterz() and getRegionGraph() wrappers
28+
seg_watershed.py # Python 2D watershed (mahotas-based, slice-by-slice)
29+
seg_region_graph.py # Soma-aware BFS merging, branch IoU utilities
30+
seg_util.py # Helpers: scoring function string builder, HDF5 I/O, border masks
31+
agglomerate.pyx # Cython bridge: agglomerate() -> C++ initialize/mergeUntil/free
32+
region_graph.pyx # Cython bridge: merge_id() variants -> C++ union-find merging
33+
evaluate.pyx # Cython bridge: Rand/VOI metrics -> C++ compare_volumes
34+
frontend_agglomerate.h/cpp # C++ agglomeration pipeline (WaterzContext state machine)
35+
frontend_region_graph.h/cpp # C++ union-find merge with optional affinity/count filtering
36+
frontend_evaluate.h/cpp # C++ evaluation (compare_arrays, chunked statistics)
37+
frontend_basic.h # Type definitions: SegID=uint32, AffValue=uint8, ScoreValue=uint8
38+
backend/ # C++ template library (header-only)
39+
types.hpp # boost::multi_array typedefs, watershed_traits
40+
basic_watershed.hpp # C++ watershed on affinity graph (BFS plateau division)
41+
RegionGraph.hpp # Region Adjacency Graph with node/edge maps
42+
region_graph.hpp # Extract RAG from segmentation + affinities
43+
IterativeRegionMerging.hpp # Priority-queue agglomeration engine
44+
PriorityQueue.hpp # Min-heap priority queue wrapper
45+
BinQueue.hpp # Discretized bin queue (approximate priority queue)
46+
StatisticsProvider.hpp # Base class with merge/edge callbacks
47+
MergeProviders.hpp # Template meta-programming to combine providers
48+
CompoundProvider.hpp # Multiple inheritance provider combiner
49+
MergeFunctions.hpp # Scoring functions: MinSize, MaxSize, MinAffinity, MeanAffinity, etc.
50+
Operators.hpp # Composable operators: OneMinus, One255Minus, Multiply, Add, etc.
51+
MeanAffinityProvider.hpp # Running mean of edge affinities
52+
MinAffinityProvider.hpp # Min affinity per edge
53+
MaxAffinityProvider.hpp # Max affinity per edge
54+
HistogramQuantileProvider.hpp # Histogram-based quantile (approximate, 256 bins)
55+
VectorQuantileProvider.hpp # Exact quantile via nth_element
56+
MaxKAffinityProvider.hpp # Top-K affinities per edge
57+
RegionSizeProvider.hpp # Voxel count per region (node statistic)
58+
ContactAreaProvider.hpp # Contact area per edge (edge statistic)
59+
RandomNumberProvider.hpp # Random scoring (baseline)
60+
ConstantProvider.hpp # Constant scoring
61+
Histogram.hpp # Fixed-bin histogram data structure
62+
MaxKValues.hpp # Sorted top-K value tracker
63+
discretize.hpp # [0,1] <-> integer bin conversion
64+
evaluate.hpp # Rand index and VOI computation
65+
```
66+
67+
## Key Data Types
68+
69+
| Type | C++ | Notes |
70+
|------|-----|-------|
71+
| **Affinities** | `uint8_t[3][Z][Y][X]` | 3-channel (z/y/x neighbor) affinity predictions, range [0, 255] |
72+
| **Segmentation** | `uint32_t[Z][Y][X]` | Fragment/segment IDs, 0 = background |
73+
| **Ground truth** | `uint32_t[Z][Y][X]` | For evaluation |
74+
| **Score** | `uint8_t` | Edge merge score (lower = more similar) |
75+
| **Region graph** | `(uint32_t u, uint32_t v, uint8_t score)[]` | Weighted edge list |
76+
77+
**Important:** This fork operates on **uint8 affinities** (0-255 range), not float32. The Python `waterz()` wrapper and scoring functions are designed for this integer representation.
78+
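Since the fork expects the integer representation, float32 predictions in [0, 1] have to be rescaled before use. A minimal sketch of that conversion, assuming round-to-nearest scaling; the helper name `to_uint8_affinities` is hypothetical and not part of waterz, and the wrapper may convert differently:

```python
import numpy as np


def to_uint8_affinities(affs):
    """Rescale float affinities in [0, 1] to uint8 in [0, 255].

    Hypothetical helper: round-to-nearest is an assumption, not
    necessarily what the waterz() wrapper does internally.
    """
    affs = np.asarray(affs)
    if affs.dtype == np.uint8:
        return affs  # already in the integer representation
    # Clip first so out-of-range predictions cannot wrap around in uint8.
    scaled = np.rint(np.clip(affs, 0.0, 1.0) * 255.0)
    return scaled.astype(np.uint8)


affs_f32 = np.array([-0.2, 0.0, 0.5, 1.0, 1.3], dtype=np.float32)
affs_u8 = to_uint8_affinities(affs_f32)
```

Clipping before the cast matters here: casting an out-of-range float directly to uint8 would not saturate reliably.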

79+
## Public Python API
80+
81+
### `waterz.waterz(affs, thresholds, ...)` - Main entry point
82+
83+
```python
84+
import waterz
85+
seg_list = waterz.waterz(
86+
affs, # [3,Z,Y,X] uint8 or float32 affinities
87+
thresholds=[0.1, 0.3, 0.6], # agglomeration thresholds
88+
merge_function='aff50_his256', # scoring function (see below)
89+
aff_threshold=[1, 254], # low/high for initial watershed
90+
gt=None, # optional ground truth for metrics
91+
gt_border=25/4.0, # border mask distance for GT
92+
fragments=None, # pre-computed fragments (skip watershed)
93+
fragments_opt=0, # 0: use C++ watershed; !=0: use mahotas watershed
94+
return_rg=False, # also return region graph
95+
return_seg=True, # return segmentation arrays
96+
)
97+
# Returns: list of uint32 segmentation arrays (one per threshold)
98+
```
99+
100+
### `waterz.agglomerate(affs, thresholds, ...)` - Low-level generator
101+
102+
Returns a generator yielding `(segmentation, [metrics], [merge_history])` tuples, one per threshold. The segmentation array is modified in place between yields, so copy it if you need to keep a result beyond the current iteration.
103+
104+
### `waterz.watershed(affs, ...)` - 2D slice-by-slice watershed
105+
106+
```python
107+
import numpy as np
import waterz

fragments = waterz.watershed(
108+
affs, # [3,Z,Y,X] affinities
109+
seed_method='maxima_distance', # 'grid', 'minima', 'maxima_distance', 'maxima_distance2'
110+
label_nb=np.ones([5,5]), # structuring element for seed labeling
111+
bg_thres=1, # background threshold (<1 to assign background)
112+
)
113+
```
114+
115+
Uses mahotas `cwatershed` per 2D slice. Converts affinities to boundary map: `boundary = 1 - 0.5*(aff_y + aff_x) / 255`.
116+
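The boundary-map formula above can be sketched in NumPy as follows. This is an illustration of the arithmetic only, assuming the z/y/x channel order given in the data-type table; the function name `affinities_to_boundary` is hypothetical:

```python
import numpy as np


def affinities_to_boundary(affs):
    """Compute boundary = 1 - 0.5 * (aff_y + aff_x) / 255 per voxel.

    affs: [3, Z, Y, X] uint8 array, channels assumed ordered (z, y, x).
    Returns a [Z, Y, X] float32 map in [0, 1] (1 = strong boundary).
    """
    # Promote to float first so aff_y + aff_x cannot overflow uint8.
    affs = np.asarray(affs, dtype=np.float32)
    aff_y, aff_x = affs[1], affs[2]
    return 1.0 - 0.5 * (aff_y + aff_x) / 255.0


demo_affs = np.zeros((3, 2, 4, 4), dtype=np.uint8)
demo_affs[:, 1] = 255  # second slice: fully connected
demo_boundary = affinities_to_boundary(demo_affs)
```

The z channel is not used, which matches the slice-by-slice nature of this watershed.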
117+
### `waterz.getRegionGraph(affs, fragments, ...)` - Extract region graph
118+
119+
```python
120+
rg_ids, rg_scores = waterz.getRegionGraph(
121+
affs, fragments,
122+
rg_opt=1, # 1: all slices, 2: skip first, 3: z-border only
123+
merge_function='aff50_his256',
124+
)
125+
# rg_ids: [N,2] uint32 - edge endpoints
126+
# rg_scores: [N] uint8 - edge scores (sorted ascending)
127+
```
128+
129+
### `waterz.merge_id(id1, id2, ...)` - Union-find merging
130+
131+
```python
132+
mapping = waterz.merge_id(
133+
id1, id2, # [N] uint32 edge endpoint arrays
134+
score=None, # [N] uint8 affinity scores (optional)
135+
count=None, # [M] uint32 segment sizes (optional)
136+
id_thres=0, # relabel threshold
137+
aff_thres=1, # affinity threshold for filtering
138+
count_thres=50, # size threshold (don't merge if both sides >= this)
139+
dust_thres=50, # remove segments smaller than this
140+
)
141+
# Returns: [M] uint32 mapping array (old_id -> new_id)
142+
```
143+
144+
Four merge modes based on which optional args are provided:
145+
1. `score=None, count=None`: merge by ID only
146+
2. `score!=None, count=None`: merge by ID + affinity threshold
147+
3. `score=None, count!=None`: merge by ID + size constraint
148+
4. `score!=None, count!=None`: merge by ID + affinity + size
149+
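The four modes can be illustrated with a small pure-Python union-find. This is a sketch, not the C++ implementation: it assumes an edge is skipped when its score is below `aff_thres`, and it keeps the smaller root ID as the representative. It omits `id_thres` and `dust_thres`, and the name `merge_ids_sketch` is hypothetical:

```python
import numpy as np


def merge_ids_sketch(id1, id2, score=None, count=None,
                     aff_thres=1, count_thres=50):
    """Union-find merge over an edge list with optional filters."""
    n = int(max(np.max(id1), np.max(id2))) + 1
    parent = np.arange(n)
    # Running segment sizes; count must cover every ID (assumption).
    size = None if count is None else np.asarray(count, dtype=np.int64).copy()

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for i, (u, v) in enumerate(zip(id1, id2)):
        if score is not None and score[i] < aff_thres:
            continue  # affinity filter (direction is an assumption)
        ru, rv = find(int(u)), find(int(v))
        if ru == rv:
            continue
        if size is not None and size[ru] >= count_thres and size[rv] >= count_thres:
            continue  # size constraint: both sides are already large
        lo, hi = min(ru, rv), max(ru, rv)
        parent[hi] = lo
        if size is not None:
            size[lo] += size[hi]
    return np.array([find(i) for i in range(n)])


e1 = np.array([1, 2, 3], dtype=np.uint32)
e2 = np.array([2, 3, 4], dtype=np.uint32)
by_id = merge_ids_sketch(e1, e2)
by_aff = merge_ids_sketch(e1, e2, score=np.array([200, 0, 200], dtype=np.uint8))
by_size = merge_ids_sketch(e1, e2, count=np.array([0, 100, 100, 10, 10]))
```

Passing neither, one, or both of `score` and `count` selects the mode, mirroring the list above.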
150+
### `waterz.evaluate_total_volume(seg, gt)` - Evaluation metrics
151+
152+
```python
153+
metrics = waterz.evaluate_total_volume(seg_uint64, gt_uint64)
154+
# Returns dict with: V_Rand_split, V_Rand_merge, V_Info_split, V_Info_merge
155+
```
156+
157+
### Chunked evaluation
158+
159+
```python
160+
stat = waterz.initialize_stats()
161+
for seg_uint16, gt_uint16 in chunks:
162+
    stat = waterz.update_statistics_using_volume(stat, seg_uint16, gt_uint16)
163+
metrics = waterz.compute_final_metrics(stat)
164+
```
165+
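The idea behind chunked evaluation is that co-occurrence counts are additive across chunks, so metrics can be finalized once at the end. The sketch below shows this with conditional entropies, one common way to define VOI split and merge. It assumes voxels with `gt == 0` are ignored, and the reported `V_Info_*` values may be normalized differently; all function names are hypothetical:

```python
import math
from collections import Counter

import numpy as np


def update_counts(counts, seg, gt):
    """Accumulate (seg_id, gt_id) co-occurrence counts for one chunk."""
    mask = gt != 0  # assumption: background in gt is ignored
    counts.update(zip(seg[mask].tolist(), gt[mask].tolist()))
    return counts


def conditional_entropies(counts):
    """Return (H(seg|gt), H(gt|seg)) in bits from accumulated counts."""
    n = sum(counts.values())
    seg_tot, gt_tot = Counter(), Counter()
    for (s, g), c in counts.items():
        seg_tot[s] += c
        gt_tot[g] += c
    split = -sum(c / n * math.log2(c / gt_tot[g]) for (s, g), c in counts.items())
    merge = -sum(c / n * math.log2(c / seg_tot[s]) for (s, g), c in counts.items())
    return split, merge


# Two chunks of a volume where one segment covers two gt objects.
chunks = [(np.array([1, 1]), np.array([1, 1])),
          (np.array([1, 1]), np.array([2, 2]))]
counts = Counter()
for chunk_seg, chunk_gt in chunks:
    counts = update_counts(counts, chunk_seg, chunk_gt)
voi_split, voi_merge = conditional_entropies(counts)
```

In this toy case the segmentation merges two equally sized objects, so the merge term is one bit and the split term is zero.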
166+
## Scoring Functions (Merge Functions)
167+
168+
Scoring functions are specified as C++ template type strings. The `getScoreFunc()` helper in `seg_util.py` translates shorthand notation:
169+
170+
### Shorthand notation
171+
172+
Format: `aff{Q}_his{B}[_ran255]` or `max{K}[_ran255]`
173+
174+
| Shorthand | C++ Type | Description |
175+
|-----------|----------|-------------|
176+
| `aff50_his256` | `OneMinus<HistogramQuantileAffinity<RG, 50, SV, 256>>` | Median affinity via 256-bin histogram |
177+
| `aff50_his0` | `OneMinus<QuantileAffinity<RG, 50, SV>>` | Exact median affinity (vector-based) |
178+
| `aff85_his256` | `OneMinus<HistogramQuantileAffinity<RG, 85, SV, 256>>` | 85th percentile via histogram |
179+
| `aff50_his256_ran255` | `One255Minus<HistogramQuantileAffinity<RG, 50, SV, 256>>` | Same but score = 255 - quantile |
180+
| `max10` | `OneMinus<MeanMaxKAffinity<RG, 10, SV>>` | Mean of top-10 affinities |
181+
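The mapping in the table can be reproduced with a small parser. This is a sketch of the translation only, not the actual `getScoreFunc()` code: `shorthand_to_cpp_type` is a hypothetical name, the `RG`/`SV` placeholders are copied from the table, and treating `max{K}_ran255` like the `aff` case is an assumption since the table has no such row:

```python
import re


def shorthand_to_cpp_type(name):
    """Translate e.g. 'aff50_his256' into the C++ type string above."""
    suffix = "_ran255"
    if name.endswith(suffix):
        wrapper, core = "One255Minus", name[: -len(suffix)]
    else:
        wrapper, core = "OneMinus", name

    m = re.fullmatch(r"aff(\d+)_his(\d+)", core)
    if m:
        q, bins = int(m.group(1)), int(m.group(2))
        if bins == 0:
            # his0 selects the exact, vector-based quantile
            inner = f"QuantileAffinity<RG, {q}, SV>"
        else:
            inner = f"HistogramQuantileAffinity<RG, {q}, SV, {bins}>"
        return f"{wrapper}<{inner}>"

    m = re.fullmatch(r"max(\d+)", core)
    if m:
        return f"{wrapper}<MeanMaxKAffinity<RG, {int(m.group(1))}, SV>>"

    raise ValueError(f"unrecognized merge function shorthand: {name!r}")


median_hist = shorthand_to_cpp_type("aff50_his256")
```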
182+
### Available C++ scoring primitives
183+
184+
**Edge statistics (from providers):**
185+
- `MinAffinity` - minimum affinity across edge voxels
186+
- `MaxAffinity` - maximum affinity
187+
- `MeanAffinity` - running mean affinity
188+
- `HistogramQuantileAffinity<RG, Q, Prec, Bins>` - Q-th percentile via histogram
189+
- `QuantileAffinity<RG, Q, Prec>` - exact Q-th percentile
190+
- `MeanMaxKAffinity<RG, K, Prec>` - mean of top-K affinities
191+
- `ContactArea` - number of adjacent voxel pairs
192+
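To make the difference between the two quantile providers concrete, here is a NumPy sketch of both strategies. It assumes a simple rank convention (`k = (n - 1) * Q // 100`), which may not match the C++ code exactly; the function names are hypothetical. With uint8 values and 256 bins every value has its own bin, so in this sketch the two results coincide:

```python
import numpy as np


def exact_quantile(values, q):
    """Q-th percentile by partial sort, analogous to nth_element."""
    values = np.asarray(values)
    k = (len(values) - 1) * q // 100  # rank convention is an assumption
    return int(np.partition(values, k)[k])


def histogram_quantile(values, q, bins=256):
    """Q-th percentile from a fixed-bin histogram of uint8 values."""
    values = np.asarray(values)
    k = (len(values) - 1) * q // 100
    hist = np.bincount(values, minlength=bins)
    cum = np.cumsum(hist)
    # First bin whose cumulative count exceeds rank k.
    return int(np.searchsorted(cum, k, side="right"))


rng = np.random.default_rng(0)
edge_affs = rng.integers(0, 256, size=1000, dtype=np.uint8)
median_exact = exact_quantile(edge_affs, 50)
median_hist = histogram_quantile(edge_affs, 50)
```

The histogram variant needs only fixed memory per edge and merges by adding bin counts, which is what makes it attractive during agglomeration.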
193+
**Node statistics:**
194+
- `MinSize` / `MaxSize` - min/max region size of edge endpoints
195+
196+
**Operators (composable):**
197+
- `OneMinus<F>` - `1 - f(e)` (converts affinity to distance)
198+
- `One255Minus<F>` - `255 - f(e)` (for uint8 range)
199+
- `Multiply<F1, F2>` - `f1(e) * f2(e)`
200+
- `Add<F1, F2>` / `Subtract<F1, F2>`
201+
- `Divide<F1, F2>` (safe division)
202+
- `Invert<F>` - `1/f(e)`
203+
- `Square<F>` - `f(e)^2`
204+
- `Step<F1, F2>` - `f1(e) < f2(e) ? 0 : 1`
205+
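The operators above compose at the C++ template level. A rough Python analogue using closures may help show how a composed scoring function evaluates; it is only an illustration, and the edge representation (a dict with `mean_aff` and `min_size`) and all names are hypothetical:

```python
def one_minus(f):
    """Analogue of OneMinus<F>: 1 - f(e)."""
    return lambda e: 1 - f(e)


def multiply(f1, f2):
    """Analogue of Multiply<F1, F2>: f1(e) * f2(e)."""
    return lambda e: f1(e) * f2(e)


def step(f1, f2):
    """Analogue of Step<F1, F2>: 0 if f1(e) < f2(e) else 1."""
    return lambda e: 0 if f1(e) < f2(e) else 1


def constant(c):
    """Analogue of Constant<C>."""
    return lambda e: c


# Hypothetical edge statistics, standing in for provider output.
mean_affinity = lambda e: e["mean_aff"]
min_size = lambda e: e["min_size"]

# Distance-like score that drops to 0 when one side is a small fragment.
score_fn = multiply(one_minus(mean_affinity), step(min_size, constant(100)))

small_edge_score = score_fn({"mean_aff": 0.25, "min_size": 10})
large_edge_score = score_fn({"mean_aff": 0.25, "min_size": 500})
```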
206+
**Special:**
207+
- `Random` - random score (baseline)
208+
- `Constant<C>` - constant integer score
209+
210+
## Agglomeration Pipeline
211+
212+
1. **Watershed** (C++ `basic_watershed.hpp` or Python `seg_watershed.py`):
213+
- Finds local maxima in the affinity graph
214+
- BFS to divide plateaus
215+
- Assigns fragment IDs to each basin
216+
- Background (affinity below `aff_threshold_low`) gets ID 0
217+
218+
2. **Region Graph Extraction** (`region_graph.hpp`):
219+
- Scans all voxel pairs across the 3 affinity channels
220+
- Collects affinities between adjacent fragments
221+
- Builds RAG with edge statistics (via StatisticsProvider callbacks)
222+
223+
3. **Iterative Region Merging** (`IterativeRegionMerging.hpp`):
224+
- Scores all edges using the scoring function
225+
- Pushes edges into priority queue (min-score first)
226+
- Pops cheapest edge, merges regions if score < threshold
227+
- Updates incident edges (marks stale for lazy re-scoring)
228+
- Uses union-find with path compression for fast root lookup
229+
- Supports incremental merging across multiple thresholds
230+
231+
4. **Evaluation** (optional, `evaluate.hpp`):
232+
- Computes Rand index (split/merge) and VOI (split/merge)
233+
- Uses co-occurrence matrix between segmentation and ground truth
234+
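Step 3 of the pipeline can be sketched with `heapq` and a union-find. This is a simplified illustration, not the C++ engine: it scores edges as `1 - mean affinity` on floats in [0, 1], re-scores edges eagerly after each merge, and discards outdated heap entries lazily when they are popped. All names are hypothetical:

```python
import heapq


def find(parent, x):
    """Union-find root lookup with path halving."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def score(affinities):
    """Lower score = more similar (1 - mean affinity)."""
    return 1.0 - sum(affinities) / len(affinities)


def agglomerate_sketch(edge_affs, threshold):
    """edge_affs: {(u, v): [affinities]}. Returns {node: root}."""
    nodes = {n for edge in edge_affs for n in edge}
    parent = {n: n for n in nodes}
    affs = {tuple(sorted(e)): list(a) for e, a in edge_affs.items()}
    heap = [(score(a), e) for e, a in affs.items()]
    heapq.heapify(heap)

    while heap:
        s, (u, v) = heapq.heappop(heap)
        if (u, v) not in affs or s != score(affs[(u, v)]):
            continue  # outdated entry: edge was merged away or re-scored
        if s >= threshold:
            break  # cheapest valid edge is too expensive, so stop
        parent[v] = u  # merge v into u
        del affs[(u, v)]
        # Redirect v's remaining edges to u, pooling their statistics.
        for a, b in [e for e in affs if v in e]:
            other = a if b == v else b
            key = tuple(sorted((u, other)))
            affs.setdefault(key, []).extend(affs.pop((a, b)))
            heapq.heappush(heap, (score(affs[key]), key))
    return {n: find(parent, n) for n in nodes}


labels = agglomerate_sketch(
    {(1, 2): [0.9, 0.9], (2, 3): [0.2], (3, 4): [0.95]}, threshold=0.5
)
```

Because the edge statistics survive the loop, the same state could be reused to continue merging at a higher threshold, which is the idea behind incremental merging across multiple thresholds.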
235+
## Queue Types
236+
237+
- **PriorityQueue** (default): Standard min-heap, exact ordering
238+
- **BinQueue**: Discretized into N bins, approximate but faster for large graphs
239+
240+
Selected at compile time via `discretize_queue` parameter (0 = PriorityQueue, N>0 = BinQueue with N bins).
241+
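The trade-off between the two queues can be seen in a toy bin queue. This is a sketch under the assumption that scores lie in [0, 1] and that items within a bin are served first-in, first-out; `BinQueueSketch` is a hypothetical name, not the C++ class:

```python
from collections import deque


class BinQueueSketch:
    """Approximate priority queue: scores are bucketed into N bins."""

    def __init__(self, num_bins):
        self.bins = [deque() for _ in range(num_bins)]

    def push(self, item, score):
        # Discretize a score in [0, 1]; score == 1.0 goes to the last bin.
        idx = min(int(score * len(self.bins)), len(self.bins) - 1)
        self.bins[idx].append(item)

    def pop(self):
        # Serve the lowest non-empty bin; order inside a bin is FIFO.
        for b in self.bins:
            if b:
                return b.popleft()
        raise IndexError("pop from empty BinQueueSketch")

    def __len__(self):
        return sum(len(b) for b in self.bins)


queue = BinQueueSketch(num_bins=10)
queue.push("a", 0.90)
queue.push("b", 0.19)
queue.push("c", 0.10)  # lower score than "b", but same bin
pop_order = [queue.pop() for _ in range(len(queue))]
```

Here "b" is returned before "c" even though "c" has the lower score: that is the approximation a bin queue accepts in exchange for constant-time pushes.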
242+
## JIT Compilation
243+
244+
The `agglomerate()` function uses **just-in-time Cython compilation**. The scoring function is specified as a C++ type string, which gets written into a generated header file (`ScoringFunction.h`). The module is compiled once and cached in `~/.cython/inline/` with a hash-based filename for reuse.
245+
246+
## Connectomics-Specific Utilities
247+
248+
### `seg_region_graph.py`
249+
250+
- **`somaBFS(edges, soma_ids)`**: BFS-based soma-aware merging that prevents false merges between different somas. Iteratively removes edges that would merge two somas into one segment.
251+
- **`branchIoU(bbs, sid, ...)`**: Tracks neurite branches across z-slices using IoU overlap between consecutive 2D segmentations. Used for reconstructing neuron morphology from 2D segments.
252+
- **`branchIoUBFS(bbs, sid, ...)`**: BFS extension of branchIoU that recursively finds all branches belonging to a neuron by following IoU connections.
253+
254+
### `seg_util.py`
255+
256+
- **`create_border_mask(gt, max_dist, bg_label)`**: Creates border masks on ground truth by labeling boundary pixels (within `max_dist` of a label boundary) as background. Used for fair evaluation by ignoring ambiguous boundary regions.
257+
- **`mappingToList(mapping)`**: Converts a dense mapping array to a sparse `[N,2]` list of `(old_id, new_id)` pairs for efficient I/O.
258+
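The dense-to-sparse conversion described for `mappingToList` might look like the sketch below. It assumes that only IDs whose label actually changes are kept, which the description does not state explicitly; `mapping_to_pairs` is a hypothetical stand-in, not the library function:

```python
import numpy as np


def mapping_to_pairs(mapping):
    """Convert a dense old_id -> new_id array to an [N, 2] pair list."""
    mapping = np.asarray(mapping)
    # Keep only entries that are not identity mappings (assumption).
    old_ids = np.flatnonzero(mapping != np.arange(mapping.size))
    return np.stack([old_ids, mapping[old_ids]], axis=1)


dense = np.array([0, 1, 1, 3, 1], dtype=np.uint32)
pairs = mapping_to_pairs(dense)
```

For large volumes where most fragments keep their ID, the pair list is much smaller than the dense array.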
259+
## Usage in PyTorch Connectomics
260+
261+
Waterz is used in the decoding/post-processing pipeline of PyTorch Connectomics for:
262+
1. Converting predicted affinity maps to initial over-segmentation (watershed)
263+
2. Agglomerating fragments into neuron instances at various thresholds
264+
3. Extracting region graphs for downstream processing
265+
4. Evaluating segmentation quality (Rand/VOI metrics)
266+
267+
Typical workflow:
268+
```python
269+
import waterz
270+
import numpy as np
271+
272+
# affs: [3, Z, Y, X] uint8 affinity predictions from model
273+
seg = waterz.waterz(
274+
affs,
275+
thresholds=[0.3],
276+
merge_function='aff50_his256',
277+
aff_threshold=[1, 254],
278+
)[0] # single threshold -> single segmentation
279+
```

connectomics/config/schema/inference.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ class TestTimeAugmentationConfig:
57 57
flip_combinations: Optional[List[List[int]]] = None # explicit list of axis subsets
58 58
rotation90_axes: Any = None # "all" | None | [[int, int], ...] spatial plane pairs
59 59
rotate90_k: Optional[List[int]] = None # subset of quarter-turns, defaults to [0,1,2,3]
60+
patch_first_local: bool = False # slide once, apply local TTA inside each ROI batch
60 61
apply_mask: bool = True
61 62
transforms: Optional[List[Dict[str, Any]]] = None # advanced explicit transforms
62 63
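The behavior named by `patch_first_local` (augment each ROI batch locally instead of re-sliding over the whole volume once per augmentation) can be sketched in NumPy. This is an illustration under stated assumptions, not the repository's implementation: it handles flips only, treats `model` as any callable on arrays, and ignores the extra channel handling that affinity outputs may need under flips. `local_flip_tta` is a hypothetical name:

```python
import itertools

import numpy as np


def local_flip_tta(patch_batch, model, spatial_axes=(-3, -2, -1)):
    """Average model predictions over all flip combinations of one batch.

    patch_batch: array shaped [..., Z, Y, X] holding one ROI batch.
    model: callable mapping an array to a same-shaped prediction.
    """
    combos = [c for r in range(len(spatial_axes) + 1)
              for c in itertools.combinations(spatial_axes, r)]
    acc = None
    for axes in combos:
        aug = np.flip(patch_batch, axis=axes) if axes else patch_batch
        pred = model(aug)
        # Undo the flip so predictions align with the original patch.
        pred = np.flip(pred, axis=axes) if axes else pred
        acc = pred.astype(np.float64) if acc is None else acc + pred
    return acc / len(combos)


batch = np.arange(2 * 1 * 2 * 3 * 4, dtype=np.float32).reshape(2, 1, 2, 3, 4)
tta_pred = local_flip_tta(batch, model=lambda x: 2.0 * x)
```

With three spatial axes this runs the model eight times per batch, but the sliding-window traversal of the volume happens only once.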
