Skip to content

Commit 2606508

Browse files
TeranisCopilot
andcommitted
feat: opt finding most common label on proj, tooltips
Co-authored-by: Copilot <copilot@github.com>
1 parent 91ac6d7 commit 2606508

5 files changed

Lines changed: 206 additions & 12 deletions

File tree

cellacdc/docs/source/tooltips.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ Edit tools: Segmentation and tracking
447447
* Note: right-click on a background ROI to remove it.
448448
* HELP: Use this function if you need to set the background level specific for each object. Cell-ACDC will save the metrics `amount`, `concentration` and `corrected_mean` where the background correction will be performed by subtracting the mean of the signal in the background ROI (for each object).
449449
* **Delete everything outside segmented areas (** |delObjsOutSegmMaskAction| **):** Select a segmentation file and delete everything outside segmented area.
450-
* **Hull contour (** |hullContToolButton| **"K"):** Right-click on a cell to replace it with its hull contour. Use it to fill cracks and holes. Shift right click fill holes of the entire 3D object.
450+
* **Hull contour (** |hullContToolButton| **"K"):** Right-click on a cell to replace it with its hull contour. Use it to fill cracks and holes. Shift right click to apply hull to the entire 3D object.
451451
* **Fill holes (** |fillHolesToolButton| **"F"):** Right-click on a cell to fill holes. Shift right click fill holes of the entire 3D object.
452452
* **Move object mask (** |moveLabelToolButton| **"P"):** Right-click drag and drop a labels to move it around.
453453
* **Expand/Shrink object mask (** |expandLabelToolButton| **"E"):** Leave mouse cursor on the label you want to expand/shrink and press arrow up/down on the keyboard to expand/shrink the mask.
Binary file not shown.

cellacdc/precompiled_functions.pyx

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,115 @@ def find_all_objects_3D(np.uint32_t[:, :, :] label_img):
128128
idx += 1
129129
return out_labels, out_bboxes
130130

131+
def most_common_projection_3D(np.uint32_t[:, :, :] lab, int axis):
132+
"""Most-common-value projection for a 3-D label image along `axis`.
133+
134+
Tie-break matches np.unique(..., return_counts=True) + np.argmax(counts),
135+
i.e. the smallest label wins when counts are equal.
136+
"""
137+
if axis < 0 or axis > 2:
138+
raise ValueError(f'axis must be 0, 1, or 2. Got {axis}.')
139+
140+
cdef Py_ssize_t z = lab.shape[0]
141+
cdef Py_ssize_t y = lab.shape[1]
142+
cdef Py_ssize_t x = lab.shape[2]
143+
cdef Py_ssize_t i, j, a, b, depth
144+
cdef unsigned int v, vv
145+
cdef unsigned int best_label, best_count, curr_count
146+
cdef bint seen
147+
cdef np.uint32_t[:, :] out_view
148+
149+
if axis == 0:
150+
depth = z
151+
out = np.empty((y, x), dtype=np.uint32)
152+
out_view = out
153+
for i in range(y):
154+
for j in range(x):
155+
best_count = 0
156+
best_label = UINT_MAX
157+
for a in range(depth):
158+
v = lab[a, i, j]
159+
seen = False
160+
for b in range(a):
161+
if lab[b, i, j] == v:
162+
seen = True
163+
break
164+
if seen:
165+
continue
166+
167+
curr_count = 1
168+
for b in range(a + 1, depth):
169+
if lab[b, i, j] == v:
170+
curr_count += 1
171+
172+
if curr_count > best_count or (curr_count == best_count and v < best_label):
173+
best_count = curr_count
174+
best_label = v
175+
176+
out_view[i, j] = best_label
177+
return out
178+
179+
if axis == 1:
180+
depth = y
181+
out = np.empty((z, x), dtype=np.uint32)
182+
out_view = out
183+
for i in range(z):
184+
for j in range(x):
185+
best_count = 0
186+
best_label = UINT_MAX
187+
for a in range(depth):
188+
v = lab[i, a, j]
189+
seen = False
190+
for b in range(a):
191+
if lab[i, b, j] == v:
192+
seen = True
193+
break
194+
if seen:
195+
continue
196+
197+
curr_count = 1
198+
for b in range(a + 1, depth):
199+
if lab[i, b, j] == v:
200+
curr_count += 1
201+
202+
if curr_count > best_count or (curr_count == best_count and v < best_label):
203+
best_count = curr_count
204+
best_label = v
205+
206+
out_view[i, j] = best_label
207+
return out
208+
209+
depth = x
210+
out = np.empty((z, y), dtype=np.uint32)
211+
out_view = out
212+
for i in range(z):
213+
for j in range(y):
214+
best_count = 0
215+
best_label = UINT_MAX
216+
for a in range(depth):
217+
v = lab[i, j, a]
218+
seen = False
219+
for b in range(a):
220+
vv = lab[i, j, b]
221+
if vv == v:
222+
seen = True
223+
break
224+
if seen:
225+
continue
226+
227+
curr_count = 1
228+
for b in range(a + 1, depth):
229+
vv = lab[i, j, b]
230+
if vv == v:
231+
curr_count += 1
232+
233+
if curr_count > best_count or (curr_count == best_count and v < best_label):
234+
best_count = curr_count
235+
best_label = v
236+
237+
out_view[i, j] = best_label
238+
return out
239+
131240
def calc_IoA_matrix_2D(
132241
np.uint32_t[:, :] lab,
133242
np.uint32_t[:, :] prev_lab,

cellacdc/regionprops.py

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from scipy import ndimage as ndi
3+
from scipy import stats as scipy_stats
34
import skimage.measure
45
import cv2
56
from . import printl, debugutils
@@ -9,10 +10,19 @@
910
import traceback as traceback
1011

1112
try:
12-
from cellacdc.precompiled.precompiled_functions import find_all_objects_2D, find_all_objects_3D
13+
from cellacdc.precompiled.precompiled_functions import (
14+
find_all_objects_2D,
15+
find_all_objects_3D,
16+
)
1317
_CYTHON_FIND_OBJECTS = True
1418
except Exception:
1519
_CYTHON_FIND_OBJECTS = False
20+
21+
try:
22+
from cellacdc.precompiled.precompiled_functions import most_common_projection_3D
23+
_CYTHON_MOST_COMMON_PROJECTION = True
24+
except Exception:
25+
_CYTHON_MOST_COMMON_PROJECTION = False
1626
# WARNING: Developers have already used
1727
# 14 hrs
1828
# to optimize this.
@@ -475,13 +485,7 @@ def _get_lab_projection(self, lab, slicing='z', kind='max'):
475485
return np.max(lab, axis=axis)
476486

477487
if kind == 'most_common':
478-
moved = np.moveaxis(lab, axis, 0)
479-
flat = moved.reshape(moved.shape[0], -1)
480-
projected_flat = np.empty(flat.shape[1], dtype=lab.dtype)
481-
for i in range(flat.shape[1]):
482-
values, counts = np.unique(flat[:, i], return_counts=True)
483-
projected_flat[i] = values[np.argmax(counts)]
484-
return projected_flat.reshape(moved.shape[1:])
488+
return self._compute_most_common_projection(lab, axis=axis)
485489

486490
if kind == 'mean':
487491
projected = np.mean(lab, axis=axis)
@@ -491,6 +495,47 @@ def _get_lab_projection(self, lab, slicing='z', kind='max'):
491495
# Regionprops requires integer labels.
492496
return np.rint(projected).astype(lab.dtype, copy=False)
493497

498+
def _compute_most_common_projection(self, lab, axis):
499+
if _CYTHON_MOST_COMMON_PROJECTION:
500+
lab_uint32 = lab.astype(np.uint32, copy=False)
501+
projected = most_common_projection_3D(lab_uint32, int(axis))
502+
return projected.astype(lab.dtype, copy=False)
503+
504+
moved = np.moveaxis(lab, axis, 0)
505+
projected = scipy_stats.mode(moved, axis=0, keepdims=False).mode
506+
return projected.astype(lab.dtype, copy=False)
507+
508+
def _get_projection_patch_slices(self, slicing, cutout_bbox):
509+
z0, y0, x0, z1, y1, x1 = [int(v) for v in cutout_bbox]
510+
if slicing == 'z':
511+
return (slice(y0, y1), slice(x0, x1))
512+
if slicing == 'y':
513+
return (slice(z0, z1), slice(x0, x1))
514+
return (slice(z0, z1), slice(y0, y1))
515+
516+
def _compute_most_common_projection_patch(self, slicing, cutout_bbox):
517+
z0, y0, x0, z1, y1, x1 = [int(v) for v in cutout_bbox]
518+
if slicing == 'z':
519+
patch_lab = self.lab[:, y0:y1, x0:x1]
520+
elif slicing == 'y':
521+
patch_lab = self.lab[z0:z1, :, x0:x1]
522+
else:
523+
patch_lab = self.lab[z0:z1, y0:y1, :]
524+
525+
axis = self._slice_axis_index(slicing)
526+
return self._compute_most_common_projection(patch_lab, axis=axis)
527+
528+
def _update_cached_most_common_projection_locally(self, slicing, cutout_bbox):
529+
lab_proj = self._get_cached_or_new_lab_projection(slicing, 'most_common')
530+
patch_slices = self._get_projection_patch_slices(slicing, cutout_bbox)
531+
if any(slc.start >= slc.stop for slc in patch_slices):
532+
return lab_proj
533+
534+
patch = self._compute_most_common_projection_patch(slicing, cutout_bbox)
535+
lab_proj[patch_slices] = patch
536+
self._proj_labs[slicing]['most_common'] = lab_proj
537+
return lab_proj
538+
494539
def _get_cached_or_new_lab_projection(self, slicing, kind):
495540
lab_proj = self._proj_labs[slicing].get(kind)
496541
if lab_proj is None:
@@ -581,12 +626,21 @@ def _sync_initialized_slice_rps_via_update(self, specific_IDs_update_centroids=N
581626
specific_IDs_update_centroids=specific_IDs_update_centroids,
582627
)
583628

584-
def _sync_initialized_proj_rps_via_update(self, specific_IDs_update_centroids=None):
629+
def _sync_initialized_proj_rps_via_update(
630+
self,
631+
specific_IDs_update_centroids=None,
632+
cutout_bbox=None,
633+
):
585634
if not self._has_initialized_proj_rps():
586635
return
587636

588637
for slicing, kind, rp in self._iter_initialized_proj_rps():
589-
lab_proj = self._replace_cached_lab_projection(slicing, kind)
638+
if cutout_bbox is not None and kind == 'most_common':
639+
lab_proj = self._update_cached_most_common_projection_locally(
640+
slicing, cutout_bbox
641+
)
642+
else:
643+
lab_proj = self._replace_cached_lab_projection(slicing, kind)
590644
rp.update_regionprops(
591645
lab_proj,
592646
specific_IDs_update_centroids=specific_IDs_update_centroids,
@@ -1083,7 +1137,8 @@ def update_regionprops_via_cutout(
10831137
specific_IDs_update_centroids=target_IDs
10841138
)
10851139
self._sync_initialized_proj_rps_via_update(
1086-
specific_IDs_update_centroids=target_IDs
1140+
specific_IDs_update_centroids=target_IDs,
1141+
cutout_bbox=cutout_bbox,
10871142
)
10881143

10891144
def get_centroid(self, ID, exact=False):

tests/test_regionprops_cutout_update.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,36 @@ def test_projection_regionprops_support_most_common_kind():
153153
assert rp.get_obj_from_proj_rp(3, kind='most common z-projection', warn=False) is not None
154154

155155

156+
def test_most_common_projection_uses_local_cutout_update(monkeypatch):
157+
old_lab = np.zeros((3, 6, 6), dtype=np.uint16)
158+
old_lab[:, 1:4, 1:4] = 1
159+
160+
rp = acdcRegionprops(old_lab)
161+
proj_before = rp.get_proj_rp(kind='most_common', slicing='z')
162+
expected_before = rp._get_lab_projection(old_lab, slicing='z', kind='most_common')
163+
np.testing.assert_array_equal(proj_before.lab, expected_before)
164+
165+
new_lab = old_lab.copy()
166+
new_lab[0:2, 2:5, 2:5] = 2
167+
168+
original_replace_cached = rp._replace_cached_lab_projection
169+
170+
def _replace_cached_should_not_run_for_most_common(slicing, kind):
171+
if kind == 'most_common':
172+
raise AssertionError(
173+
'most_common projection should be updated locally for cutout updates.'
174+
)
175+
return original_replace_cached(slicing, kind)
176+
177+
monkeypatch.setattr(rp, '_replace_cached_lab_projection', _replace_cached_should_not_run_for_most_common)
178+
179+
rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(2, 2, 5, 5))
180+
181+
proj_after = rp.get_proj_rp(kind='most_common', slicing='z')
182+
expected_after = rp._get_lab_projection(new_lab, slicing='z', kind='most_common')
183+
np.testing.assert_array_equal(proj_after.lab, expected_after)
184+
185+
156186
def test_get_obj_from_id_for_stored_slice_and_projection_rps():
157187
lab = np.zeros((4, 8, 8), dtype=np.uint16)
158188
lab[1:3, 2:5, 3:6] = 2

0 commit comments

Comments
 (0)