@@ -203,19 +203,22 @@ def _per_z_scan(pred, z_stride: int) -> dict:
203203 sel = [(i , z ) for i , z in enumerate (z_idx ) if z0 <= z < z1 ]
204204 if not sel :
205205 continue
206- # Slice over the trailing Z axis regardless of leading dims.
207- block = np .asarray (pred [..., z0 :z1 ]).astype (np .float32 )
206+ # Read native dtype (typically float16) — keep the slab compact and
207+ # only widen one z-plane at a time below. Halves peak RAM on large
208+ # (C, X, Y, block_z) reads.
209+ block = np .asarray (pred [..., z0 :z1 ])
208210 nan_count += int (np .isnan (block ).sum ())
209211 inf_count += int (np .isinf (block ).sum ())
210212 for i , z in sel :
211- sl = block [..., z - z0 ].reshape (C , - 1 )
213+ sl = block [..., z - z0 ].astype ( np . float32 , copy = False ). reshape (C , - 1 )
212214 means [i ] = sl .mean (axis = 1 )
213215 stds [i ] = sl .std (axis = 1 )
214- g_sum += sl .sum (axis = 1 )
215- g_sq += ( sl . astype ( np .float64 ) ** 2 ).sum (axis = 1 )
216+ g_sum += sl .sum (axis = 1 , dtype = np . float64 )
217+ g_sq += np . square ( sl , dtype = np .float64 ).sum (axis = 1 )
216218 g_min = np .minimum (g_min , sl .min (axis = 1 ))
217219 g_max = np .maximum (g_max , sl .max (axis = 1 ))
218220 g_n += sl .shape [1 ]
221+ del block
219222 return {
220223 "z_idx" : z_idx , "means" : means , "stds" : stds ,
221224 "g_sum" : g_sum , "g_sq" : g_sq , "g_min" : g_min , "g_max" : g_max ,
@@ -233,29 +236,30 @@ def _refine_z_cuts(pred, interior_mean: np.ndarray,
233236 low_z = head_end
234237 head_rows = []
235238 if head_end > 0 :
236- block = np .asarray (pred [..., 0 :head_end ]).astype (np .float32 )
239+ # Read each Z-plane individually; refine_window is small (~30) so the
240+ # extra h5 calls are negligible vs holding (C, X, Y, refine_window)
241+ # widened to float32 in RAM.
237242 for z in range (head_end ):
238- m = block [..., z ].reshape (C , - 1 ).mean (axis = 1 )
243+ m = np .asarray (pred [..., z ]).astype (np .float32 , copy = False ) \
244+ .reshape (C , - 1 ).mean (axis = 1 )
239245 ok = bool ((m >= cutoff ).all ())
240246 head_rows .append ((z , m .copy (), ok ))
241247 if ok and low_z == head_end :
242248 low_z = z
243- del block
244249
245250 tail_start = max (0 , Z - refine_window )
246251 high_z = tail_start
247252 tail_rows = []
248253 if tail_start < Z :
249- block = np .asarray (pred [..., tail_start :Z ]).astype (np .float32 )
250254 last_ok = - 1
251255 for z in range (tail_start , Z ):
252- m = block [..., z - tail_start ].reshape (C , - 1 ).mean (axis = 1 )
256+ m = np .asarray (pred [..., z ]).astype (np .float32 , copy = False ) \
257+ .reshape (C , - 1 ).mean (axis = 1 )
253258 ok = bool ((m >= cutoff ).all ())
254259 tail_rows .append ((z , m .copy (), ok ))
255260 if ok :
256261 last_ok = z
257262 high_z = last_ok + 1 if last_ok >= 0 else tail_start
258- del block
259263
260264 return low_z , high_z , head_rows , tail_rows
261265
0 commit comments