Skip to content

Commit 3d748d7

Browse files
committed
Improved joined segmentation algorithm to fix splitted cells, and streamlined coding for performance
1 parent a85f345 commit 3d748d7

2 files changed

Lines changed: 58 additions & 256 deletions

File tree

Lines changed: 56 additions & 254 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import numpy as np
2-
from skimage.exposure import rescale_intensity, equalize_adapthist
2+
from skimage.exposure import rescale_intensity
33
from skimage.io import imsave
44
from scipy import ndimage as ndi
5-
from skimage.morphology import erosion
6-
from skimage.segmentation import join_segmentations, watershed, relabel_sequential
5+
from skimage.segmentation import watershed, relabel_sequential
76

87

98
def sharpen(image):
@@ -16,265 +15,74 @@ def adaptive_hist(image):
1615
q10 = np.percentile(image, 10.0)
1716
image[image < q10] = q10
1817
return image
19-
img_adapteq = equalize_adapthist(image)
20-
return img_adapteq
2118

2219

23-
def solve_conflict_watershed(cyto_img, nuc_img):
24-
distance = ndi.distance_transform_edt(cyto_img)
25-
wl = watershed(-distance, np.where(cyto_img > 0, nuc_img, 0), mask=cyto_img)
26-
return wl
20+
def assign_cyto_region(cyto_mask, seg_nuclei, result, fill_only_unfilled=False):
21+
target = cyto_mask if not fill_only_unfilled else (cyto_mask & (result == 0))
22+
if not target.any():
23+
return
2724

25+
seeds = np.where(cyto_mask, seg_nuclei, 0)
26+
nuc_ids = np.unique(seeds)
27+
nuc_ids = nuc_ids[nuc_ids != 0]
2828

29-
def solve_conflict_expand(to_fix_img, expanded_img):
30-
to_fix_img = np.where(to_fix_img != 0, expanded_img, 0)
31-
return to_fix_img
29+
if len(nuc_ids) == 0:
30+
dilated_labeled = ndi.grey_dilation(seg_nuclei, size=3)
31+
seeds = np.where(cyto_mask, dilated_labeled, 0)
32+
nuc_ids = np.unique(seeds)
33+
nuc_ids = nuc_ids[nuc_ids != 0]
34+
if len(nuc_ids) == 0:
35+
return
3236

33-
34-
def merge_segmentations(seg_nuclei, seg_cyto1, seg_cyto2, nuc_diameter):
35-
merged_masks = None
36-
seg_nuc = np.zeros_like(seg_nuclei)
37-
for label in np.unique(seg_nuclei):
38-
if label != 0:
39-
nuc_temp = np.zeros_like(seg_nuclei)
40-
nuc_temp = np.where(seg_nuclei == label, seg_nuclei, 0)
41-
for i in range(int(nuc_diameter / 10)):
42-
nuc_temp = erosion(nuc_temp)
43-
seg_nuc = np.where(nuc_temp != 0, nuc_temp, seg_nuc)
44-
45-
conflicting_seg1_1cytNnuc = {}
46-
conflicting_seg1_nuc_list = set()
47-
merged_cyto1_masks, seg1_ms2n, seg1_ms2c = join_segmentations(seg_nuc, seg_cyto1, return_mapping=True)
48-
49-
seg1_cyt2shape = {}
50-
seg1_nuc2shape = {}
51-
seg1_cyt2nuc = {}
52-
seg1_nuc2cyt = {}
53-
for label in seg1_ms2c.in_values:
54-
if label != 0:
55-
if seg1_ms2c[label] != 0:
56-
if seg1_ms2c[label] not in seg1_cyt2shape:
57-
seg1_cyt2shape[seg1_ms2c[label]] = set()
58-
seg1_cyt2nuc[seg1_ms2c[label]] = set()
59-
if label not in seg1_cyt2shape[seg1_ms2c[label]]: seg1_cyt2shape[seg1_ms2c[label]].add(label)
60-
if seg1_ms2n[label] != 0 and seg1_ms2n[label] not in seg1_cyt2nuc[seg1_ms2c[label]]:
61-
seg1_cyt2nuc[seg1_ms2c[label]].add(seg1_ms2n[label])
62-
else:
63-
merged_cyto1_masks[merged_cyto1_masks == label] = 0
64-
for label in seg1_ms2n.in_values:
65-
if label != 0 and seg1_ms2n[label] != 0:
66-
if seg1_ms2n[label] not in seg1_nuc2shape:
67-
seg1_nuc2shape[seg1_ms2n[label]] = set()
68-
seg1_nuc2cyt[seg1_ms2n[label]] = set()
69-
if label not in seg1_nuc2shape[seg1_ms2n[label]]: seg1_nuc2shape[seg1_ms2n[label]].add(label)
70-
if seg1_ms2c[label] != 0 and seg1_ms2c[label] not in seg1_nuc2cyt[seg1_ms2n[label]]:
71-
seg1_nuc2cyt[seg1_ms2n[label]].add(seg1_ms2c[label])
72-
73-
for cyt in seg1_cyt2nuc:
74-
if len(seg1_cyt2nuc[cyt]) == 0:
75-
for shape_label in seg1_cyt2shape[cyt]:
76-
merged_cyto1_masks[merged_cyto1_masks == shape_label] = 0
77-
del seg1_cyt2shape[cyt]
37+
if len(nuc_ids) == 1:
38+
result[target] = nuc_ids[0]
39+
else:
40+
dist = ndi.distance_transform_edt(cyto_mask)
41+
wl = watershed(-dist, seeds, mask=cyto_mask)
42+
if fill_only_unfilled:
43+
result[target] = wl[target]
7844
else:
79-
if len(seg1_cyt2nuc[cyt]) > 1:
80-
conflicting_seg1_1cytNnuc[cyt] = seg1_cyt2nuc[cyt]
81-
for curr_conf_nuc in seg1_cyt2nuc[cyt]:
82-
conflicting_seg1_nuc_list.add(curr_conf_nuc)
83-
84-
current_shape = seg1_cyt2shape[cyt].copy().pop()
85-
for shape_label in seg1_cyt2shape[cyt]:
86-
merged_cyto1_masks[merged_cyto1_masks == shape_label] = current_shape
87-
seg1_cyt2shape[cyt] = current_shape
88-
89-
if seg_cyto2 is not None:
90-
conflicting_seg2_1cytNnuc = {}
91-
conflicting_seg2_nuc_list = set()
92-
merged_cyto2_masks, seg2_ms2n, seg2_ms2c = join_segmentations(seg_nuc, seg_cyto2, return_mapping=True)
45+
result[cyto_mask] = wl[cyto_mask]
9346

94-
seg2_cyt2shape = {}
95-
seg2_nuc2shape = {}
96-
seg2_cyt2nuc = {}
97-
seg2_nuc2cyt = {}
98-
for label in seg2_ms2c.in_values:
99-
if label != 0:
100-
if seg2_ms2c[label] != 0:
101-
if seg2_ms2c[label] not in seg2_cyt2shape:
102-
seg2_cyt2shape[seg2_ms2c[label]] = set()
103-
seg2_cyt2nuc[seg2_ms2c[label]] = set()
104-
if label not in seg2_cyt2shape[seg2_ms2c[label]]: seg2_cyt2shape[seg2_ms2c[label]].add(label)
105-
if seg2_ms2n[label] != 0 and seg2_ms2n[label] not in seg2_cyt2nuc[seg2_ms2c[label]]:
106-
seg2_cyt2nuc[seg2_ms2c[label]].add(seg2_ms2n[label])
107-
else:
108-
merged_cyto2_masks[merged_cyto2_masks == label] = 0
109-
for label in seg2_ms2n.in_values:
110-
if label != 0 and seg2_ms2n[label] != 0:
111-
if seg2_ms2n[label] not in seg2_nuc2shape:
112-
seg2_nuc2shape[seg2_ms2n[label]] = set()
113-
seg2_nuc2cyt[seg2_ms2n[label]] = set()
114-
if label not in seg2_nuc2shape[seg2_ms2n[label]]: seg2_nuc2shape[seg2_ms2n[label]].add(label)
115-
if seg2_ms2c[label] != 0 and seg2_ms2c[label] not in seg2_nuc2cyt[seg2_ms2n[label]]:
116-
seg2_nuc2cyt[seg2_ms2n[label]].add(seg2_ms2c[label])
11747

118-
for cyt in seg2_cyt2nuc:
119-
if len(seg2_cyt2nuc[cyt]) == 0:
120-
for shape_label in seg2_cyt2shape[cyt]:
121-
merged_cyto2_masks[merged_cyto2_masks == shape_label] = 0
122-
del seg2_cyt2shape[cyt]
123-
else:
124-
if len(seg2_cyt2nuc[cyt]) > 1:
125-
conflicting_seg2_1cytNnuc[cyt] = seg2_cyt2nuc[cyt]
126-
for curr_conf_nuc in seg2_cyt2nuc[cyt]:
127-
conflicting_seg2_nuc_list.add(curr_conf_nuc)
128-
current_shape = seg2_cyt2shape[cyt].copy().pop()
129-
for shape_label in seg2_cyt2shape[cyt]:
130-
merged_cyto2_masks[merged_cyto2_masks == shape_label] = current_shape
131-
seg2_cyt2shape[cyt] = current_shape
132-
133-
merged_masks, segm_ms2seg1, segm_ms2seg2 = join_segmentations(merged_cyto1_masks, merged_cyto2_masks, return_mapping=True)
134-
135-
merged_shape2final = {}
136-
merged_shape_conflicting2cyt1 = {}
137-
merged_shape_conflicting2cyt2 = {}
138-
139-
for label in segm_ms2seg1.in_values:
140-
if segm_ms2seg1[label] != 0 and seg1_ms2c[segm_ms2seg1[label]] != 0:
141-
if segm_ms2seg2[label] != 0 and seg2_ms2c[segm_ms2seg2[label]] != 0:
142-
seg1_shape = segm_ms2seg1[label]
143-
seg2_shape = segm_ms2seg2[label]
144-
seg_cyt1_final = 0
145-
seg_cyt2_final = 0
146-
for seg1_cyt in seg1_cyt2shape:
147-
if seg1_cyt2shape[seg1_cyt] == seg1_shape:
148-
if seg1_cyt in conflicting_seg1_1cytNnuc:
149-
merged_shape_conflicting2cyt1[label] = seg1_cyt
150-
else:
151-
seg_cyt1_final = seg1_cyt
152-
break
153-
for seg2_cyt in seg2_cyt2shape:
154-
if seg2_cyt2shape[seg2_cyt] == seg2_shape:
155-
if seg2_cyt in conflicting_seg2_1cytNnuc:
156-
merged_shape_conflicting2cyt2[label] = seg2_cyt
157-
else:
158-
seg_cyt2_final = seg2_cyt
159-
break
160-
if seg_cyt1_final != 0:
161-
merged_shape2final[label] = seg1_cyt2nuc[seg_cyt1_final].copy().pop() * 1000
162-
if label in merged_shape_conflicting2cyt2:
163-
del merged_shape_conflicting2cyt2[label]
164-
elif seg_cyt2_final != 0:
165-
merged_shape2final[label] = seg2_cyt2nuc[seg_cyt2_final].copy().pop() * 1000
166-
if label in merged_shape_conflicting2cyt1:
167-
del merged_shape_conflicting2cyt1[label]
168-
else:
169-
seg1_shape = segm_ms2seg1[label]
170-
for seg1_cyt in seg1_cyt2shape:
171-
if seg1_cyt2shape[seg1_cyt] == seg1_shape:
172-
if seg1_cyt not in conflicting_seg1_1cytNnuc:
173-
merged_shape2final[label] = seg1_cyt2nuc[seg1_cyt].copy().pop() * 1000
174-
else:
175-
merged_shape_conflicting2cyt1[label] = seg1_cyt
176-
elif segm_ms2seg2[label] != 0 and seg2_ms2c[segm_ms2seg2[label]] != 0:
177-
seg2_shape = segm_ms2seg2[label]
178-
for seg2_cyt in seg2_cyt2shape:
179-
if seg2_cyt2shape[seg2_cyt] == seg2_shape:
180-
if seg2_cyt not in conflicting_seg2_1cytNnuc:
181-
merged_shape2final[label] = seg2_cyt2nuc[seg2_cyt].copy().pop() * 1000
182-
else:
183-
merged_shape_conflicting2cyt2[label] = seg2_cyt
184-
185-
merged_masks_simple = np.zeros_like(merged_masks)
186-
for shape in merged_shape2final:
187-
merged_masks_simple = np.where(merged_masks == shape, merged_shape2final[shape], merged_masks_simple)
48+
def merge_segmentations(seg_nuclei, seg_cyto1, seg_cyto2, nuc_diameter):
49+
height, width = seg_nuclei.shape
50+
result = np.zeros((height, width), dtype=np.int32)
51+
unique_nucs = sorted(n for n in np.unique(seg_nuclei) if n != 0)
52+
53+
if seg_cyto1 is not None:
54+
nuc_primary_cyto1 = {}
55+
for nuc in unique_nucs:
56+
nuc_pixels = seg_cyto1[seg_nuclei == nuc]
57+
vals, cnts = np.unique(nuc_pixels, return_counts=True)
58+
nz = vals != 0
59+
if nz.any():
60+
nuc_primary_cyto1[nuc] = vals[nz][np.argmax(cnts[nz])]
61+
62+
for c1 in set(nuc_primary_cyto1.values()):
63+
assign_cyto_region(seg_cyto1 == c1, seg_nuclei, result, fill_only_unfilled=False)
18864

189-
for label in segm_ms2seg1.in_values:
190-
if label in merged_shape_conflicting2cyt1:
191-
if label in merged_shape_conflicting2cyt2:
192-
# 3 body problem
193-
print("3 body problem")
194-
else:
195-
seg1_cyt = merged_shape_conflicting2cyt1[label]
196-
affected_shapes = set()
197-
affected_nuclei = set()
198-
for seg1_nuc in conflicting_seg1_1cytNnuc[seg1_cyt]:
199-
affected_nuclei.add(seg1_nuc)
200-
for curr_shape in segm_ms2seg1.in_values:
201-
if segm_ms2seg1[curr_shape] == seg1_cyt2shape[seg1_cyt]:
202-
affected_shapes.add(curr_shape)
203-
if len(affected_shapes) > 0:
204-
cyto_img = np.zeros_like(merged_masks)
205-
for shape_label in affected_shapes:
206-
cyto_img = np.where(merged_masks == shape_label, shape_label, cyto_img)
207-
if shape_label in merged_shape_conflicting2cyt1:
208-
del merged_shape_conflicting2cyt1[shape_label]
209-
nuc_img = np.zeros_like(merged_masks)
210-
for nuc_label in affected_nuclei:
211-
nuc_img = np.where(seg_nuc == nuc_label, seg_nuc, nuc_img)
212-
wl = solve_conflict_watershed(cyto_img, nuc_img)
213-
for curr_shape in affected_shapes:
214-
if curr_shape in merged_shape2final:
215-
wl = np.where(merged_masks_simple == merged_shape2final[curr_shape], 0, wl)
216-
merged_masks_simple = np.where(wl != 0, wl * 1000, merged_masks_simple)
217-
elif label in merged_shape_conflicting2cyt2:
218-
seg2_cyt = merged_shape_conflicting2cyt2[label]
219-
affected_shapes = set()
220-
affected_nuclei = set()
221-
for seg2_nuc in conflicting_seg2_1cytNnuc[seg2_cyt]:
222-
affected_nuclei.add(seg2_nuc)
223-
for curr_shape in segm_ms2seg2.in_values:
224-
if segm_ms2seg2[curr_shape] == seg2_cyt2shape[seg2_cyt]:
225-
affected_shapes.add(curr_shape)
226-
if len(affected_shapes) > 0:
227-
cyto_img = np.zeros_like(merged_masks)
228-
for shape_label in affected_shapes:
229-
cyto_img = np.where(merged_masks == shape_label, shape_label, cyto_img)
230-
if shape_label in merged_shape_conflicting2cyt2:
231-
del merged_shape_conflicting2cyt2[shape_label]
232-
nuc_img = np.zeros_like(merged_masks)
233-
for nuc_label in affected_nuclei:
234-
nuc_img = np.where(seg_nuc == nuc_label, seg_nuc, nuc_img)
235-
wl = solve_conflict_watershed(cyto_img, nuc_img)
236-
for curr_shape in affected_shapes:
237-
if curr_shape in merged_shape2final:
238-
wl = np.where(merged_masks_simple == merged_shape2final[curr_shape], 0, wl)
239-
merged_masks_simple = np.where(wl != 0, wl * 1000, merged_masks_simple)
65+
if seg_cyto2 is not None:
66+
nuc_primary_cyto2 = {}
67+
for nuc in unique_nucs:
68+
nuc_pixels = seg_cyto2[seg_nuclei == nuc]
69+
vals, cnts = np.unique(nuc_pixels, return_counts=True)
70+
nz = vals != 0
71+
if nz.any():
72+
nuc_primary_cyto2[nuc] = vals[nz][np.argmax(cnts[nz])]
24073

241-
merged_masks = merged_masks_simple
242-
else:
243-
merged_masks = merged_cyto1_masks
74+
for c2 in set(nuc_primary_cyto2.values()):
75+
assign_cyto_region(seg_cyto2 == c2, seg_nuclei, result, fill_only_unfilled=True)
24476

245-
fixed_nuclei = set()
246-
fixed_shapes = set()
247-
for seg1_conf_nuc in conflicting_seg1_nuc_list:
248-
if seg1_conf_nuc not in fixed_nuclei:
249-
affected_shapes = set()
250-
affected_nuclei = set()
251-
for seg1_conf_cyt in conflicting_seg1_1cytNnuc:
252-
if seg1_conf_nuc in conflicting_seg1_1cytNnuc[seg1_conf_cyt]:
253-
for curr_nuc in conflicting_seg1_1cytNnuc[seg1_conf_cyt]:
254-
affected_nuclei.add(curr_nuc)
255-
for extra_cyt in seg1_nuc2cyt[curr_nuc]:
256-
affected_shapes.add(seg1_cyt2shape[extra_cyt])
257-
for curr_shape in seg1_ms2c.in_values:
258-
if seg1_ms2c[curr_shape] == seg1_cyt2shape[seg1_conf_cyt]:
259-
affected_shapes.add(curr_shape)
260-
cyto_img = np.zeros_like(merged_masks)
261-
for shape_label in affected_shapes:
262-
cyto_img = np.where(merged_masks == shape_label, shape_label, cyto_img)
263-
nuc_img = np.zeros_like(merged_masks)
264-
for nuc_label in affected_nuclei:
265-
nuc_img = np.where(seg_nuc == nuc_label, seg_nuc, nuc_img)
266-
wl = solve_conflict_watershed(cyto_img, nuc_img)
267-
merged_masks = np.where(wl != 0, wl * 1000, merged_masks)
268-
fixed_nuclei.update(affected_nuclei)
269-
fixed_shapes.update(affected_shapes)
77+
for nuc in unique_nucs:
78+
result[seg_nuclei == nuc] = nuc
27079

271-
return relabel_sequential(merged_masks)[0]
80+
return relabel_sequential(result)[0]
27281

27382

274-
# Code to generate the segmentation
275-
def segment(model_nuc, model_cyto, nuclei_img, cyto_img1, cyto_img2, nuc_diameter, cell_diameter, output_folder, output_prefix):
83+
def segment(model_nuc, model_cyto, nuclei_img, cyto_img1, cyto_img2, nuc_diameter, cell_diameter, output_folder,
84+
output_prefix):
27685
channels = [1, 0]
277-
# We segment the nuclei using cellpose default nuclei model
27886
nuclei_masks, flows, styles = model_nuc.eval(
27987
np.stack([nuclei_img, np.zeros_like(nuclei_img)]),
28088
channels=channels,
@@ -286,7 +94,6 @@ def segment(model_nuc, model_cyto, nuclei_img, cyto_img1, cyto_img2, nuc_diamete
28694
if cyto_img1 is not None:
28795
cell_masks = None
28896

289-
# We segment the 1st cytoplasm marker using cellpose default cyto3 model
29097
cyto1_masks, flows, styles = model_cyto.eval(
29198
np.stack([sharpen(adaptive_hist(cyto_img1)), nuclei_masks]),
29299
channels=channels,
@@ -296,10 +103,7 @@ def segment(model_nuc, model_cyto, nuclei_img, cyto_img1, cyto_img2, nuc_diamete
296103
)
297104
imsave(output_folder + "/" + output_prefix + "cyto1_mask.png", cyto1_masks)
298105

299-
300106
if cyto_img2 is not None:
301-
# In case we are using 2 cytoplasm markers
302-
# We segment the 2nd cytoplasm marker using cellpose default cyto3 model
303107
cyto2_masks, flows, styles = model_cyto.eval(
304108
np.stack([sharpen(adaptive_hist(cyto_img2)), nuclei_masks]),
305109
channels=channels,
@@ -309,10 +113,8 @@ def segment(model_nuc, model_cyto, nuclei_img, cyto_img1, cyto_img2, nuc_diamete
309113
)
310114
imsave(output_folder + "/" + output_prefix + "cyto2_mask.png", cyto2_masks)
311115

312-
# We merge the nuclei and 2 cytoplasm segmentations
313116
cell_masks = merge_segmentations(nuclei_masks, cyto1_masks, cyto2_masks, nuc_diameter)
314117
else:
315-
# We merge the nuclei and cytoplasm segmentations
316118
cell_masks = merge_segmentations(nuclei_masks, cyto1_masks, None, nuc_diameter)
317119

318-
imsave(output_folder + "/" + output_prefix + "cell_mask.png", cell_masks)
120+
imsave(output_folder + "/" + output_prefix + "cell_mask.png", cell_masks)

examples/cellpose_segmentation/process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
# If you want to use constants with your script, add them here
2020
config["nuclei_only"] = False
21-
config["nuc_diameter"] = 250
22-
config["cyto_diameter"] = 500
21+
config["nuc_diameter"] = 100
22+
config["cyto_diameter"] = 150
2323

2424
# Log the start time and the final configuration so you can keep track of what you did
2525
config["log"].info('Start: ' + datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S"))

0 commit comments

Comments
 (0)