From 0da3549bd350e0b1fbe707b447ce89e1535af428 Mon Sep 17 00:00:00 2001 From: measty <20169086+measty@users.noreply.github.com> Date: Fri, 22 May 2026 16:24:01 +0100 Subject: [PATCH 1/2] fix broken margin behaviour --- tests/engines/test_multi_task_segmentor.py | 36 ++++++++++++++++++- .../models/engine/multi_task_segmentor.py | 33 ++++++++++------- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/tests/engines/test_multi_task_segmentor.py b/tests/engines/test_multi_task_segmentor.py index f2ddb7f84..20bd5f74b 100644 --- a/tests/engines/test_multi_task_segmentor.py +++ b/tests/engines/test_multi_task_segmentor.py @@ -889,9 +889,10 @@ class FakeVM: ) # --- Call function --- - new_zarr, new_da = _save_multitask_vertical_to_cache( + new_zarr, new_da, zarr_group = _save_multitask_vertical_to_cache( probabilities_zarr=probabilities_zarr, probabilities_da=probabilities_da, + zarr_group=None, probabilities=probabilities, idx=idx, tqdm_loop=tqdm_loop, @@ -905,11 +906,44 @@ class FakeVM: # new_zarr must be a real zarr array assert isinstance(new_zarr[idx], zarr.Array) + assert zarr_group is not None # Data was written correctly assert np.array_equal(new_zarr[idx][:], np.array([[1, 2, 3]])) +def test_multitask_vertical_merge_continues_after_zarr_spill( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Test multitask vertical merge appends all chunks after spilling to Zarr.""" + + class FakeVM: + """Fake psutil.virtual_memory() with extremely low available memory.""" + + available = 1 + + monkeypatch.setattr(psutil, "virtual_memory", FakeVM) + + values = np.arange(8 * 3, dtype=np.float32).reshape(8, 3, 1) + canvas = [da.from_array(values, chunks=(2, 3, 1))] + count = [da.from_array(np.ones_like(values), chunks=(2, 3, 1))] + output_locs_y = np.array([[0, 2], [2, 4], [4, 6], [6, 8]]) + + result = merge_multitask_vertical_chunkwise( + canvas=canvas, + count=count, + output_locs_y_=output_locs_y, + zarr_group=None, + save_path=tmp_path / "vertical.zarr", + memory_threshold=0, + output_shape=(8, 3), + verbose=False, + ) + + assert result[0].shape == values.shape + assert np.array_equal(result[0].compute(), values) + + def test_qupath_feature_class_dict_lookup_fails() -> None: """Test qupath_feature_class_dict lookup fails.""" qupath_json = DaskDelayedJSONStore.__new__(DaskDelayedJSONStore) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index f183f325c..4a00208d3 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -2608,19 +2608,24 @@ def merge_multitask_vertical_chunkwise( chunk_shape=chunk_shape, probabilities_zarr=probabilities_zarr[idx], probabilities_da=probabilities_da[idx], - zarr_group=zarr_group, + zarr_group=( + zarr_group if probabilities_zarr[idx] is not None else None + ), name=f"probabilities/{idx}", ) - probabilities_zarr, probabilities_da = _save_multitask_vertical_to_cache( - probabilities_zarr=probabilities_zarr, - probabilities_da=probabilities_da, - probabilities=probabilities, - idx=idx, - tqdm_loop=tqdm_loop, - save_path=save_path, - chunk_shape=chunk_shape, - memory_threshold=memory_threshold, + probabilities_zarr, probabilities_da, zarr_group = ( + _save_multitask_vertical_to_cache( + probabilities_zarr=probabilities_zarr, + probabilities_da=probabilities_da, + zarr_group=zarr_group, + probabilities=probabilities, + idx=idx, + tqdm_loop=tqdm_loop, + save_path=save_path, + chunk_shape=chunk_shape, + memory_threshold=memory_threshold, + ) ) if next_chunk is not None: @@ -2647,13 +2652,14 @@ def merge_multitask_vertical_chunkwise( def _save_multitask_vertical_to_cache( probabilities_zarr: list[zarr.Array] | list[None], probabilities_da: list[da.Array] | list[None], + zarr_group: zarr.Group | None, probabilities: np.ndarray, idx: int, tqdm_loop: tqdm, save_path: Path, chunk_shape: tuple, memory_threshold: int = 80, -) -> tuple[list[zarr.Array], list[da.Array] | None]: +) -> tuple[list[zarr.Array], list[da.Array] | None, zarr.Group | None]: """Helper function to save to zarr if vertical merge is out of memory.""" used_percent = 0 if probabilities_da[idx] is not None: @@ -2669,7 +2675,8 @@ def _save_multitask_vertical_to_cache( f"Saving intermediate results to disk." ) update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg) - zarr_group = zarr.open(str(save_path), mode="a") + if zarr_group is None: + zarr_group = zarr.open(str(save_path), mode="a") probabilities_zarr[idx] = zarr_group.create_array( name=f"probabilities/{idx}", shape=probabilities_da[idx].shape, @@ -2681,7 +2688,7 @@ def _save_multitask_vertical_to_cache( update_tqdm_desc(tqdm_loop=tqdm_loop, desc=desc) probabilities_da[idx] = None - return probabilities_zarr, probabilities_da + return probabilities_zarr, probabilities_da, zarr_group def _clear_zarr( From e22ca77ca72ee1400b1af62b58d6eec82377e3fd Mon Sep 17 00:00:00 2001 From: measty <20169086+measty@users.noreply.github.com> Date: Thu, 28 May 2026 19:39:01 +0100 Subject: [PATCH 2/2] fix small stride bug --- tests/engines/test_semantic_segmentor.py | 26 +++ .../models/engine/semantic_segmentor.py | 179 +++++++++++++----- 2 files changed, 162 insertions(+), 43 deletions(-) diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py index 567f55b00..f0d7c03f1 100644 --- a/tests/engines/test_semantic_segmentor.py +++ b/tests/engines/test_semantic_segmentor.py @@ -404,6 +404,32 @@ def test_merge_vertical_chunkwise_memory_threshold_triggered() -> None: assert np.all(zarr_group["probabilities"][:] == data) +def test_merge_vertical_chunkwise_multi_row_overlap() -> None: + """Test vertical merging when one row overlaps multiple following rows.""" + rows = [ + np.ones((4, 2, 1), dtype=np.float32), + np.ones((4, 2, 1), dtype=np.float32) * 2, + np.ones((4, 2, 1), dtype=np.float32) * 4, + ] + data = np.concatenate(rows, axis=0) + canvas = da.from_array(data, chunks=(4, 2, 1)) + count = da.from_array(np.ones_like(data, dtype=np.uint8), chunks=(4, 2, 1)) + output_locs_y_ = np.array([[0, 4], [1, 5], [2, 6]]) + + result = merge_vertical_chunkwise( + canvas=canvas, + count=count, + output_locs_y_=output_locs_y_, + zarr_group=None, + save_path=Path("unused"), + verbose=False, + ) + + expected_rows = np.array([1, 1.5, 7 / 3, 7 / 3, 3, 4], dtype=np.float32) + expected = np.broadcast_to(expected_rows[:, None, None], (6, 2, 1)) + np.testing.assert_allclose(result.compute(), expected) + + def test_raise_value_error_return_labels_wsi( remote_sample: Callable, track_tmp_path: Path, diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 1d752b4df..aec7fa1cf 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1395,6 +1395,84 @@ def get_wsi_output_shape(dataset: object) -> tuple[int, int] | None: return int(wsi_shape[1]), int(wsi_shape[0]) +def _get_vertical_chunk_locations( + output_locs_y: np.ndarray, + num_chunks: int, +) -> np.ndarray: + """Return unique vertical chunk locations in processing order.""" + chunk_locs = np.unique(output_locs_y, axis=0) + chunk_locs = chunk_locs[np.argsort(chunk_locs[:, 0], kind="stable")] + if len(chunk_locs) != num_chunks: + msg = ( + "Number of vertical output locations does not match the number " + "of merged canvas chunks." + ) + raise ValueError(msg) + return chunk_locs.astype(np.int64, copy=False) + + +def _aggregate_vertical_segment( + active_chunks: list[tuple[int, int, np.ndarray, np.ndarray]], + start_y: int, + end_y: int, +) -> np.ndarray: + """Average all active chunks covering a finalized vertical segment.""" + if end_y <= start_y: + return np.empty((0, *active_chunks[0][2].shape[1:])) + + segment_shape = (end_y - start_y, *active_chunks[0][2].shape[1:]) + segment = np.zeros(segment_shape, dtype=active_chunks[0][2].dtype) + segment_count = np.zeros( + (end_y - start_y, *active_chunks[0][3].shape[1:]), + dtype=np.uint32, + ) + + for chunk_start_y, chunk_end_y, chunk, chunk_count in active_chunks: + overlap_start = max(start_y, chunk_start_y) + overlap_end = min(end_y, chunk_end_y) + if overlap_end <= overlap_start: + continue + + source_start = overlap_start - chunk_start_y + source_end = overlap_end - chunk_start_y + target_start = overlap_start - start_y + target_end = overlap_end - start_y + + segment[target_start:target_end] += chunk[source_start:source_end] + segment_count[target_start:target_end] += chunk_count[source_start:source_end] + + segment_count = np.where(segment_count == 0, 1, segment_count) + return segment / segment_count.astype(np.float32) + + +def _store_vertical_segment( + probabilities: np.ndarray, + output_shape: tuple[int, int] | None, + written_height: int, + chunk_shape: tuple[int, ...], + probabilities_zarr: zarr.Array | None, + probabilities_da: da.Array | None, + zarr_group: zarr.Group | None, +) -> tuple[zarr.Array | None, da.Array | None, int, bool]: + """Clip and store a finalized vertical probability segment.""" + probabilities, written_height, should_stop = clip_probabilities_to_shape( + probabilities=probabilities, + output_shape=output_shape, + written_height=written_height, + ) + if should_stop or probabilities.shape[0] == 0: + return probabilities_zarr, probabilities_da, written_height, should_stop + + probabilities_zarr, probabilities_da = store_probabilities( + probabilities=probabilities, + chunk_shape=chunk_shape, + probabilities_zarr=probabilities_zarr, + probabilities_da=probabilities_da, + zarr_group=zarr_group, + ) + return probabilities_zarr, probabilities_da, written_height, False + + def merge_vertical_chunkwise( canvas: da.Array, count: da.Array, @@ -1410,8 +1488,8 @@ def merge_vertical_chunkwise( This function processes vertically stacked image blocks (`canvas`) and their associated count arrays to compute normalized probabilities. It handles overlapping - regions between chunks by applying seam folding and trimming halos to ensure smooth - transitions. If a Zarr group is provided, the result is stored incrementally. + regions between chunks by keeping active rows until no later chunk can contribute + to them. If a Zarr group is provided, the result is stored incrementally. Args: canvas (da.Array): @@ -1441,16 +1519,14 @@ def merge_vertical_chunkwise( or constructed in memory. """ - y0s, y1s = np.unique(output_locs_y_[:, 0]), np.unique(output_locs_y_[:, 1]) - overlaps = np.append(y1s[:-1] - y0s[1:], 0) - num_chunks = canvas.numblocks[0] probabilities_zarr, probabilities_da = None, None chunk_shape = tuple(chunk[0] for chunk in canvas.chunks) written_height = 0 + chunk_locs = _get_vertical_chunk_locations(output_locs_y_, num_chunks) tqdm_loop = tqdm( - overlaps, + range(num_chunks), leave=False, desc="Merging rows", disable=not verbose, @@ -1458,37 +1534,46 @@ def merge_vertical_chunkwise( used_percent = 0 - curr_chunk = canvas.blocks[0, 0].compute() - curr_count = count.blocks[0, 0].compute() - next_chunk = canvas.blocks[1, 0].compute() if num_chunks > 1 else None - next_count = count.blocks[1, 0].compute() if num_chunks > 1 else None - + active_chunks: list[tuple[int, int, np.ndarray, np.ndarray]] = [] probabilities = np.empty(0) + current_y = int(chunk_locs[0, 0]) + should_stop = False - for i, overlap in enumerate(tqdm_loop): - if next_chunk is not None and overlap > 0: - curr_chunk[-overlap:] += next_chunk[:overlap] - curr_count[-overlap:] += next_count[:overlap] - - # Normalize - curr_count = np.where(curr_count == 0, 1, curr_count) - probabilities = curr_chunk / curr_count.astype(np.float32) + for chunk_idx in tqdm_loop: + chunk_start_y, chunk_end_y = map(int, chunk_locs[chunk_idx]) - probabilities, written_height, should_stop = clip_probabilities_to_shape( - probabilities=probabilities, - output_shape=output_shape, - written_height=written_height, - ) - if should_stop: - break - - probabilities_zarr, probabilities_da = store_probabilities( - probabilities=probabilities, - chunk_shape=chunk_shape, - probabilities_zarr=probabilities_zarr, - probabilities_da=probabilities_da, - zarr_group=zarr_group, - ) + if active_chunks and chunk_start_y > current_y: + probabilities = _aggregate_vertical_segment( + active_chunks=active_chunks, + start_y=current_y, + end_y=chunk_start_y, + ) + probabilities_zarr, probabilities_da, written_height, should_stop = ( + _store_vertical_segment( + probabilities=probabilities, + output_shape=output_shape, + written_height=written_height, + chunk_shape=chunk_shape, + probabilities_zarr=probabilities_zarr, + probabilities_da=probabilities_da, + zarr_group=zarr_group, + ) + ) + if should_stop: + break + + current_y = chunk_start_y + active_chunks = [ + active_chunk + for active_chunk in active_chunks + if active_chunk[1] > current_y + ] + + chunk = canvas.blocks[chunk_idx, 0].compute() + chunk_count = count.blocks[chunk_idx, 0].compute() + valid_chunk_end_y = min(chunk_end_y, chunk_start_y + chunk.shape[0]) + if valid_chunk_end_y > chunk_start_y: + active_chunks.append((chunk_start_y, valid_chunk_end_y, chunk, chunk_count)) if probabilities_da is not None: vm = psutil.virtual_memory() @@ -1514,16 +1599,24 @@ def merge_vertical_chunkwise( probabilities_da = None update_tqdm_desc(tqdm_loop=tqdm_loop, desc=desc) - if next_chunk is not None: - curr_chunk, curr_count = next_chunk[overlap:], next_count[overlap:] - - if i + 2 < num_chunks: - next_chunk = canvas.blocks[i + 2, 0].compute() - next_count = count.blocks[i + 2, 0].compute() - else: - next_chunk, next_count = None, None + if active_chunks and not should_stop: + final_y = max(active_chunk[1] for active_chunk in active_chunks) + probabilities = _aggregate_vertical_segment( + active_chunks=active_chunks, + start_y=current_y, + end_y=final_y, + ) + probabilities_zarr, probabilities_da, _, _ = _store_vertical_segment( + probabilities=probabilities, + output_shape=output_shape, + written_height=written_height, + chunk_shape=chunk_shape, + probabilities_zarr=probabilities_zarr, + probabilities_da=probabilities_da, + zarr_group=zarr_group, + ) - if probabilities_zarr: + if probabilities_zarr is not None: return _get_probabilities_da_from_zarr( zarr_group=zarr_group, probabilities_zarr=probabilities_zarr,