Skip to content

Commit 333b3e2

Browse files
author
Donglai Wei
committed
Merge branch 'worktree-optuna-fail-orphan-running'
2 parents a4f3456 + ecb8d53 commit 333b3e2

1 file changed

Lines changed: 15 additions & 11 deletions

File tree

connectomics/decoding/qc/affinity.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -203,19 +203,22 @@ def _per_z_scan(pred, z_stride: int) -> dict:
203203
sel = [(i, z) for i, z in enumerate(z_idx) if z0 <= z < z1]
204204
if not sel:
205205
continue
206-
# Slice over the trailing Z axis regardless of leading dims.
207-
block = np.asarray(pred[..., z0:z1]).astype(np.float32)
206+
# Read native dtype (typically float16) — keep the slab compact and
207+
# only widen one z-plane at a time below. Halves peak RAM on large
208+
# (C, X, Y, block_z) reads.
209+
block = np.asarray(pred[..., z0:z1])
208210
nan_count += int(np.isnan(block).sum())
209211
inf_count += int(np.isinf(block).sum())
210212
for i, z in sel:
211-
sl = block[..., z - z0].reshape(C, -1)
213+
sl = block[..., z - z0].astype(np.float32, copy=False).reshape(C, -1)
212214
means[i] = sl.mean(axis=1)
213215
stds[i] = sl.std(axis=1)
214-
g_sum += sl.sum(axis=1)
215-
g_sq += (sl.astype(np.float64) ** 2).sum(axis=1)
216+
g_sum += sl.sum(axis=1, dtype=np.float64)
217+
g_sq += np.square(sl, dtype=np.float64).sum(axis=1)
216218
g_min = np.minimum(g_min, sl.min(axis=1))
217219
g_max = np.maximum(g_max, sl.max(axis=1))
218220
g_n += sl.shape[1]
221+
del block
219222
return {
220223
"z_idx": z_idx, "means": means, "stds": stds,
221224
"g_sum": g_sum, "g_sq": g_sq, "g_min": g_min, "g_max": g_max,
@@ -233,29 +236,30 @@ def _refine_z_cuts(pred, interior_mean: np.ndarray,
233236
low_z = head_end
234237
head_rows = []
235238
if head_end > 0:
236-
block = np.asarray(pred[..., 0:head_end]).astype(np.float32)
239+
# Read each Z-plane individually; refine_window is small (~30) so the
240+
# extra h5 calls are negligible vs holding (C, X, Y, refine_window)
241+
# widened to float32 in RAM.
237242
for z in range(head_end):
238-
m = block[..., z].reshape(C, -1).mean(axis=1)
243+
m = np.asarray(pred[..., z]).astype(np.float32, copy=False) \
244+
.reshape(C, -1).mean(axis=1)
239245
ok = bool((m >= cutoff).all())
240246
head_rows.append((z, m.copy(), ok))
241247
if ok and low_z == head_end:
242248
low_z = z
243-
del block
244249

245250
tail_start = max(0, Z - refine_window)
246251
high_z = tail_start
247252
tail_rows = []
248253
if tail_start < Z:
249-
block = np.asarray(pred[..., tail_start:Z]).astype(np.float32)
250254
last_ok = -1
251255
for z in range(tail_start, Z):
252-
m = block[..., z - tail_start].reshape(C, -1).mean(axis=1)
256+
m = np.asarray(pred[..., z]).astype(np.float32, copy=False) \
257+
.reshape(C, -1).mean(axis=1)
253258
ok = bool((m >= cutoff).all())
254259
tail_rows.append((z, m.copy(), ok))
255260
if ok:
256261
last_ok = z
257262
high_z = last_ok + 1 if last_ok >= 0 else tail_start
258-
del block
259263

260264
return low_z, high_z, head_rows, tail_rows
261265

0 commit comments

Comments
 (0)