Skip to content

Commit 482a8a4

Browse files
Donglai Weiclaude
andcommitted
Add waterz decoder with dust merge, Optuna batch tuning, and experiment auto-logging
- Add decode_waterz: waterz agglomeration decoder with channel_order (xyz/zyx), dust_merge_size/dust_merge_affinity/dust_remove_size for zwatershed-style size+affinity dust cleanup (C++ implementation in lib/waterz) - Add waterz batch threshold mode to Optuna tuner: sweep all thresholds in one waterz call (watershed computed once, incremental merging), with categorical merge_function tuning - Add decode experiment auto-logger: appends decode params + metrics to decode_experiments.tsv after every test run for systematic tracking - Add tune_waterz profile and update neuron_snemi.yaml for waterz tuning - Add reference docs: waterz (new fork), waterz_dw (old fork), zwatershed_dw, autoresearch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9e2bd43 commit 482a8a4

13 files changed

Lines changed: 1406 additions & 235 deletions

File tree

.claude/reference/autoresearch.md

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# autoresearch Reference
2+
3+
**Location:** `lib/autoresearch/`
4+
**Origin:** [karpathy/autoresearch](https://github.com/karpathy/autoresearch)
5+
**License:** MIT
6+
7+
Autonomous AI research framework: an agent modifies code, trains for 5 minutes, evaluates, keeps or discards, and repeats — indefinitely. Designed for LLM pretraining experiments on a single GPU.
8+
9+
## Core Idea
10+
11+
The human writes `program.md` (agent instructions). The agent edits `train.py` (model + optimizer + training loop). `prepare.py` is read-only (data, tokenizer, eval). The metric is `val_bpb` (validation bits per byte) — lower is better. Training always runs for exactly 5 minutes wall clock.
12+
13+
## Files
14+
15+
| File | Role | Who edits |
16+
|------|------|-----------|
17+
| `program.md` | Agent instructions, experiment loop protocol | Human |
18+
| `train.py` | GPT model, Muon+AdamW optimizer, training loop | Agent |
19+
| `prepare.py` | Constants, data download, tokenizer, dataloader, `evaluate_bpb` | Nobody (read-only) |
20+
21+
## Experiment Loop (from `program.md`)
22+
23+
```
24+
LOOP FOREVER:
25+
1. Read git state
26+
2. Edit train.py with experimental idea
27+
3. git commit
28+
4. Run: uv run train.py > run.log 2>&1
29+
5. Read results: grep "^val_bpb:" run.log
30+
6. If crash → read traceback, attempt fix or skip
31+
7. Log to results.tsv
32+
8. If val_bpb improved → keep commit (advance branch)
33+
9. If val_bpb equal or worse → git reset to previous
34+
```
35+
36+
Each experiment runs on a dedicated branch (`autoresearch/<tag>`). The agent never stops to ask — it runs autonomously until interrupted.
37+
38+
## Logging Format (`results.tsv`)
39+
40+
Tab-separated, 5 columns:
41+
42+
```
43+
commit val_bpb memory_gb status description
44+
a1b2c3d 0.997900 44.0 keep baseline
45+
b2c3d4e 0.993200 44.2 keep increase LR to 0.04
46+
c3d4e5f 1.005000 44.0 discard switch to GeLU activation
47+
d4e5f6g 0.000000 0.0 crash double model width (OOM)
48+
```
49+
50+
Status: `keep` (improved), `discard` (equal/worse), `crash` (failed).
51+
52+
## Key Design Principles
53+
54+
1. **Single file to modify** — only `train.py`. Keeps scope manageable and diffs reviewable.
55+
2. **Fixed time budget** — always 5 minutes. Makes experiments directly comparable regardless of what changed (model size, batch size, architecture).
56+
3. **One metric**`val_bpb`. Vocab-size-independent, so architectural changes (vocab size, tokenizer) are fairly compared.
57+
4. **Self-contained** — no external dependencies beyond PyTorch + small packages. One GPU, one file, one metric.
58+
5. **Simplicity criterion** — all else equal, simpler is better. Removing code that doesn't help is a win.
59+
6. **Never stop** — the agent runs indefinitely. ~12 experiments/hour, ~100 overnight.
60+
61+
## Model Architecture (`train.py`)
62+
63+
GPT with modern tricks:
64+
- **RMSNorm** (pre-norm)
65+
- **Rotary Position Embeddings** (RoPE)
66+
- **Flash Attention 3** (via `kernels` package)
67+
- **Sliding window attention** (pattern: SSSL — 3 short + 1 long)
68+
- **Value Embeddings** (ResFormer-style, alternating layers, gated)
69+
- **Residual lambdas + x0 lambdas** (per-layer learnable scalars)
70+
- **Squared ReLU** activation in MLP
71+
- **Logit softcapping** (softcap=15)
72+
73+
Default: 8 layers, 768 dim, 6 heads, ~50M params.
74+
75+
## Optimizer: MuonAdamW
76+
77+
Combined optimizer:
78+
- **Muon** for 2D matrix params (attention, MLP projections) — polar express orthogonalization + NorMuon variance reduction + cautious weight decay
79+
- **AdamW** for everything else (embeddings, scalars, lm_head)
80+
81+
Separate LR groups: embedding (0.6), unembedding (0.004), matrices (0.04), scalars (0.5). All scaled by `1/sqrt(d_model/768)`.
82+
83+
## Relevance to PyTorch Connectomics
84+
85+
The autoresearch loop pattern is directly applicable to decoding parameter tuning:
86+
87+
| autoresearch | PyTC decoding |
88+
|-------------|---------------|
89+
| Edit `train.py` | Edit decode params (threshold, merge_function, aff_threshold, ...) |
90+
| Run 5min training | Run decode + evaluate (~seconds) |
91+
| Metric: val_bpb | Metric: adapted_rand |
92+
| Log to results.tsv | Log to experimental log in crackit/decoding.md |
93+
| keep/discard/crash | keep/discard based on ARE improvement |
94+
| git branch per run | Optuna study per sweep |
95+
96+
The key difference: decoding experiments are much faster (~seconds vs 5 minutes), so the agent can run hundreds of experiments per hour instead of 12.

.claude/reference/waterz.md

Lines changed: 133 additions & 195 deletions
Large diffs are not rendered by default.

.claude/reference/waterz_dw.md

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
# waterz - Watershed and Region Agglomeration Library
2+
3+
**Location:** `/projects/weilab/weidf/lib/waterz/`
4+
**Version:** 0.8
5+
**License:** MIT
6+
**Origin:** Fork of [funkey/waterz](https://github.com/funkey/waterz) by donglaiw, with CREMI scoring functions from Mala_v2.zip
7+
**Language:** Python + Cython + C++11 (Boost multi_array)
8+
**Dependencies:** cython, numpy, scipy, mahotas, boost (C++ headers)
9+
10+
## Purpose
11+
12+
Waterz ("water-zed") performs watershed segmentation and hierarchical region agglomeration on 3D affinity graphs. It is the core post-processing library for converting voxel-level affinity predictions (from neural networks) into instance segmentations of neurons/organelles in connectomics EM volumes.
13+
14+
## Installation
15+
16+
```bash
17+
conda create -n zw python==3.8 cython numpy
18+
pip install --editable /projects/weilab/weidf/lib/waterz
19+
# or: cd /projects/weilab/weidf/lib/waterz && python setup.py build_ext --inplace
20+
```
21+
22+
## Architecture Overview
23+
24+
```
25+
waterz/
26+
__init__.py # Public API: agglomerate(), waterz(), watershed(), etc.
27+
seg_waterz.py # High-level waterz() and getRegionGraph() wrappers
28+
seg_watershed.py # Python 2D watershed (mahotas-based, slice-by-slice)
29+
seg_region_graph.py # Soma-aware BFS merging, branch IoU utilities
30+
seg_util.py # Helpers: scoring function string builder, HDF5 I/O, border masks
31+
agglomerate.pyx # Cython bridge: agglomerate() -> C++ initialize/mergeUntil/free
32+
region_graph.pyx # Cython bridge: merge_id() variants -> C++ union-find merging
33+
evaluate.pyx # Cython bridge: Rand/VOI metrics -> C++ compare_volumes
34+
frontend_agglomerate.h/cpp # C++ agglomeration pipeline (WaterzContext state machine)
35+
frontend_region_graph.h/cpp # C++ union-find merge with optional affinity/count filtering
36+
frontend_evaluate.h/cpp # C++ evaluation (compare_arrays, chunked statistics)
37+
frontend_basic.h # Type definitions: SegID=uint32, AffValue=uint8, ScoreValue=uint8
38+
backend/ # C++ template library (header-only)
39+
types.hpp # boost::multi_array typedefs, watershed_traits
40+
basic_watershed.hpp # C++ watershed on affinity graph (BFS plateau division)
41+
RegionGraph.hpp # Region Adjacency Graph with node/edge maps
42+
region_graph.hpp # Extract RAG from segmentation + affinities
43+
IterativeRegionMerging.hpp # Priority-queue agglomeration engine
44+
PriorityQueue.hpp # Min-heap priority queue wrapper
45+
BinQueue.hpp # Discretized bin queue (approximate priority queue)
46+
StatisticsProvider.hpp # Base class with merge/edge callbacks
47+
MergeProviders.hpp # Template meta-programming to combine providers
48+
CompoundProvider.hpp # Multiple inheritance provider combiner
49+
MergeFunctions.hpp # Scoring functions: MinSize, MaxSize, MinAffinity, MeanAffinity, etc.
50+
Operators.hpp # Composable operators: OneMinus, One255Minus, Multiply, Add, etc.
51+
MeanAffinityProvider.hpp # Running mean of edge affinities
52+
MinAffinityProvider.hpp # Min affinity per edge
53+
MaxAffinityProvider.hpp # Max affinity per edge
54+
HistogramQuantileProvider.hpp # Histogram-based quantile (approximate, 256 bins)
55+
VectorQuantileProvider.hpp # Exact quantile via nth_element
56+
MaxKAffinityProvider.hpp # Top-K affinities per edge
57+
RegionSizeProvider.hpp # Voxel count per region (node statistic)
58+
ContactAreaProvider.hpp # Contact area per edge (edge statistic)
59+
RandomNumberProvider.hpp # Random scoring (baseline)
60+
ConstantProvider.hpp # Constant scoring
61+
Histogram.hpp # Fixed-bin histogram data structure
62+
MaxKValues.hpp # Sorted top-K value tracker
63+
discretize.hpp # [0,1] <-> integer bin conversion
64+
evaluate.hpp # Rand index and VOI computation
65+
```
66+
67+
## Key Data Types
68+
69+
| Type | C++ | Notes |
70+
|------|-----|-------|
71+
| **Affinities** | `uint8_t[3][Z][Y][X]` | 3-channel (z/y/x neighbor) affinity predictions, range [0, 255] |
72+
| **Segmentation** | `uint32_t[Z][Y][X]` | Fragment/segment IDs, 0 = background |
73+
| **Ground truth** | `uint32_t[Z][Y][X]` | For evaluation |
74+
| **Score** | `uint8_t` | Edge merge score (lower = more similar) |
75+
| **Region graph** | `(uint32_t u, uint32_t v, uint8_t score)[]` | Weighted edge list |
76+
77+
**Important:** This fork operates on **uint8 affinities** (0-255 range), not float32. The Python `waterz()` wrapper and scoring functions are designed for this integer representation.
78+
79+
## Public Python API
80+
81+
### `waterz.waterz(affs, thresholds, ...)` - Main entry point
82+
83+
```python
84+
import waterz
85+
seg_list = waterz.waterz(
86+
affs, # [3,Z,Y,X] uint8 or float32 affinities
87+
thresholds=[0.1, 0.3, 0.6], # agglomeration thresholds
88+
merge_function='aff50_his256', # scoring function (see below)
89+
aff_threshold=[1, 254], # low/high for initial watershed
90+
gt=None, # optional ground truth for metrics
91+
gt_border=25/4.0, # border mask distance for GT
92+
fragments=None, # pre-computed fragments (skip watershed)
93+
fragments_opt=0, # 0: use C++ watershed; !=0: use mahotas watershed
94+
return_rg=False, # also return region graph
95+
return_seg=True, # return segmentation arrays
96+
)
97+
# Returns: list of uint32 segmentation arrays (one per threshold)
98+
```
99+
100+
### `waterz.agglomerate(affs, thresholds, ...)` - Low-level generator
101+
102+
Returns a generator yielding `(segmentation, [metrics], [merge_history])` tuples. The segmentation array is modified in-place between yields (copy if needed).
103+
104+
### `waterz.watershed(affs, ...)` - 2D slice-by-slice watershed
105+
106+
```python
107+
fragments = waterz.watershed(
108+
affs, # [3,Z,Y,X] affinities
109+
seed_method='maxima_distance', # 'grid', 'minima', 'maxima_distance', 'maxima_distance2'
110+
label_nb=np.ones([5,5]), # structuring element for seed labeling
111+
bg_thres=1, # background threshold (<1 to assign background)
112+
)
113+
```
114+
115+
Uses mahotas `cwatershed` per 2D slice. Converts affinities to boundary map: `boundary = 1 - 0.5*(aff_y + aff_x) / 255`.
116+
117+
### `waterz.getRegionGraph(affs, fragments, ...)` - Extract region graph
118+
119+
```python
120+
rg_ids, rg_scores = waterz.getRegionGraph(
121+
affs, fragments,
122+
rg_opt=1, # 1: all slices, 2: skip first, 3: z-border only
123+
merge_function='aff50_his256',
124+
)
125+
# rg_ids: [N,2] uint32 - edge endpoints
126+
# rg_scores: [N] uint8 - edge scores (sorted ascending)
127+
```
128+
129+
### `waterz.merge_id(id1, id2, ...)` - Union-find merging
130+
131+
```python
132+
mapping = waterz.merge_id(
133+
id1, id2, # [N] uint32 edge endpoint arrays
134+
score=None, # [N] uint8 affinity scores (optional)
135+
count=None, # [M] uint32 segment sizes (optional)
136+
id_thres=0, # relabel threshold
137+
aff_thres=1, # affinity threshold for filtering
138+
count_thres=50, # size threshold (don't merge if both sides >= this)
139+
dust_thres=50, # remove segments smaller than this
140+
)
141+
# Returns: [M] uint32 mapping array (old_id -> new_id)
142+
```
143+
144+
Four merge modes based on which optional args are provided:
145+
1. `score=None, count=None`: merge by ID only
146+
2. `score!=None, count=None`: merge by ID + affinity threshold
147+
3. `score=None, count!=None`: merge by ID + size constraint
148+
4. `score!=None, count!=None`: merge by ID + affinity + size
149+
150+
### `waterz.evaluate_total_volume(seg, gt)` - Evaluation metrics
151+
152+
```python
153+
metrics = waterz.evaluate_total_volume(seg_uint64, gt_uint64)
154+
# Returns dict with: V_Rand_split, V_Rand_merge, V_Info_split, V_Info_merge
155+
```
156+
157+
### Chunked evaluation
158+
159+
```python
160+
stat = waterz.initialize_stats()
161+
for chunk_seg, chunk_gt in chunks:
162+
stat = waterz.update_statistics_using_volume(stat, seg_uint16, gt_uint16)
163+
metrics = waterz.compute_final_metrics(stat)
164+
```
165+
166+
## Scoring Functions (Merge Functions)
167+
168+
Scoring functions are specified as C++ template type strings. The `getScoreFunc()` helper in `seg_util.py` translates shorthand notation:
169+
170+
### Shorthand notation
171+
172+
Format: `aff{Q}_his{B}[_ran255]` or `max{K}[_ran255]`
173+
174+
| Shorthand | C++ Type | Description |
175+
|-----------|----------|-------------|
176+
| `aff50_his256` | `OneMinus<HistogramQuantileAffinity<RG, 50, SV, 256>>` | Median affinity via 256-bin histogram |
177+
| `aff50_his0` | `OneMinus<QuantileAffinity<RG, 50, SV>>` | Exact median affinity (vector-based) |
178+
| `aff85_his256` | `OneMinus<HistogramQuantileAffinity<RG, 85, SV, 256>>` | 85th percentile via histogram |
179+
| `aff50_his256_ran255` | `One255Minus<HistogramQuantileAffinity<RG, 50, SV, 256>>` | Same but score = 255 - quantile |
180+
| `max10` | `OneMinus<MeanMaxKAffinity<RG, 10, SV>>` | Mean of top-10 affinities |
181+
182+
### Available C++ scoring primitives
183+
184+
**Edge statistics (from providers):**
185+
- `MinAffinity` - minimum affinity across edge voxels
186+
- `MaxAffinity` - maximum affinity
187+
- `MeanAffinity` - running mean affinity
188+
- `HistogramQuantileAffinity<RG, Q, Prec, Bins>` - Q-th percentile via histogram
189+
- `QuantileAffinity<RG, Q, Prec>` - exact Q-th percentile
190+
- `MeanMaxKAffinity<RG, K, Prec>` - mean of top-K affinities
191+
- `ContactArea` - number of adjacent voxel pairs
192+
193+
**Node statistics:**
194+
- `MinSize` / `MaxSize` - min/max region size of edge endpoints
195+
196+
**Operators (composable):**
197+
- `OneMinus<F>` - `1 - f(e)` (converts affinity to distance)
198+
- `One255Minus<F>` - `255 - f(e)` (for uint8 range)
199+
- `Multiply<F1, F2>` - `f1(e) * f2(e)`
200+
- `Add<F1, F2>` / `Subtract<F1, F2>`
201+
- `Divide<F1, F2>` (safe division)
202+
- `Invert<F>` - `1/f(e)`
203+
- `Square<F>` - `f(e)^2`
204+
- `Step<F1, F2>` - `f1(e) < f2(e) ? 0 : 1`
205+
206+
**Special:**
207+
- `Random` - random score (baseline)
208+
- `Constant<C>` - constant integer score
209+
210+
## Agglomeration Pipeline
211+
212+
1. **Watershed** (C++ `basic_watershed.hpp` or Python `seg_watershed.py`):
213+
- Finds local maxima in the affinity graph
214+
- BFS to divide plateaus
215+
- Assigns fragment IDs to each basin
216+
- Background (affinity below `aff_threshold_low`) gets ID 0
217+
218+
2. **Region Graph Extraction** (`region_graph.hpp`):
219+
- Scans all voxel pairs across the 3 affinity channels
220+
- Collects affinities between adjacent fragments
221+
- Builds RAG with edge statistics (via StatisticsProvider callbacks)
222+
223+
3. **Iterative Region Merging** (`IterativeRegionMerging.hpp`):
224+
- Scores all edges using the scoring function
225+
- Pushes edges into priority queue (min-score first)
226+
- Pops cheapest edge, merges regions if score < threshold
227+
- Updates incident edges (marks stale for lazy re-scoring)
228+
- Uses union-find with path compression for fast root lookup
229+
- Supports incremental merging across multiple thresholds
230+
231+
4. **Evaluation** (optional, `evaluate.hpp`):
232+
- Computes Rand index (split/merge) and VOI (split/merge)
233+
- Uses co-occurrence matrix between segmentation and ground truth
234+
235+
## Queue Types
236+
237+
- **PriorityQueue** (default): Standard min-heap, exact ordering
238+
- **BinQueue**: Discretized into N bins, approximate but faster for large graphs
239+
240+
Selected at compile time via `discretize_queue` parameter (0 = PriorityQueue, N>0 = BinQueue with N bins).
241+
242+
## JIT Compilation
243+
244+
The `agglomerate()` function uses **just-in-time Cython compilation**. The scoring function is specified as a C++ type string, which gets written into a generated header file (`ScoringFunction.h`). The module is compiled once and cached in `~/.cython/inline/` with a hash-based filename for reuse.
245+
246+
## Connectomics-Specific Utilities
247+
248+
### `seg_region_graph.py`
249+
250+
- **`somaBFS(edges, soma_ids)`**: BFS-based soma-aware merging that prevents false merges between different soma bodies. Iteratively removes edges that would merge two somas into one segment.
251+
- **`branchIoU(bbs, sid, ...)`**: Tracks neurite branches across z-slices using IoU overlap between consecutive 2D segmentations. Used for reconstructing neuron morphology from 2D segments.
252+
- **`branchIoUBFS(bbs, sid, ...)`**: BFS extension of branchIoU that recursively finds all branches belonging to a neuron by following IoU connections.
253+
254+
### `seg_util.py`
255+
256+
- **`create_border_mask(gt, max_dist, bg_label)`**: Creates border masks on ground truth by labeling boundary pixels (within `max_dist` of a label boundary) as background. Used for fair evaluation by ignoring ambiguous boundary regions.
257+
- **`mappingToList(mapping)`**: Converts a dense mapping array to a sparse `[N,2]` list of `(old_id, new_id)` pairs for efficient I/O.
258+
259+
## Usage in PyTorch Connectomics
260+
261+
Waterz is used in the decoding/post-processing pipeline of PyTorch Connectomics for:
262+
1. Converting predicted affinity maps to initial over-segmentation (watershed)
263+
2. Agglomerating fragments into neuron instances at various thresholds
264+
3. Extracting region graphs for downstream processing
265+
4. Evaluating segmentation quality (Rand/VOI metrics)
266+
267+
Typical workflow:
268+
```python
269+
import waterz
270+
import numpy as np
271+
272+
# affs: [3, Z, Y, X] uint8 affinity predictions from model
273+
seg = waterz.waterz(
274+
affs,
275+
thresholds=[0.3],
276+
merge_function='aff50_his256',
277+
aff_threshold=[1, 254],
278+
)[0] # single threshold -> single segmentation
279+
```

0 commit comments

Comments
 (0)