11import numpy as np
2- from skimage .exposure import rescale_intensity , equalize_adapthist
2+ from skimage .exposure import rescale_intensity
33from skimage .io import imsave
44from scipy import ndimage as ndi
5- from skimage .morphology import erosion
6- from skimage .segmentation import join_segmentations , watershed , relabel_sequential
5+ from skimage .segmentation import watershed , relabel_sequential
76
87
98def sharpen (image ):
@@ -16,265 +15,74 @@ def adaptive_hist(image):
1615 q10 = np .percentile (image , 10.0 )
1716 image [image < q10 ] = q10
1817 return image
19- img_adapteq = equalize_adapthist (image )
20- return img_adapteq
2118
2219
23- def solve_conflict_watershed ( cyto_img , nuc_img ):
24- distance = ndi . distance_transform_edt ( cyto_img )
25- wl = watershed ( - distance , np . where ( cyto_img > 0 , nuc_img , 0 ), mask = cyto_img )
26- return wl
20+ def assign_cyto_region ( cyto_mask , seg_nuclei , result , fill_only_unfilled = False ):
21+ target = cyto_mask if not fill_only_unfilled else ( cyto_mask & ( result == 0 ) )
22+ if not target . any ():
23+ return
2724
25+ seeds = np .where (cyto_mask , seg_nuclei , 0 )
26+ nuc_ids = np .unique (seeds )
27+ nuc_ids = nuc_ids [nuc_ids != 0 ]
2828
29- def solve_conflict_expand (to_fix_img , expanded_img ):
30- to_fix_img = np .where (to_fix_img != 0 , expanded_img , 0 )
31- return to_fix_img
29+ if len (nuc_ids ) == 0 :
30+ dilated_labeled = ndi .grey_dilation (seg_nuclei , size = 3 )
31+ seeds = np .where (cyto_mask , dilated_labeled , 0 )
32+ nuc_ids = np .unique (seeds )
33+ nuc_ids = nuc_ids [nuc_ids != 0 ]
34+ if len (nuc_ids ) == 0 :
35+ return
3236
33-
34- def merge_segmentations (seg_nuclei , seg_cyto1 , seg_cyto2 , nuc_diameter ):
35- merged_masks = None
36- seg_nuc = np .zeros_like (seg_nuclei )
37- for label in np .unique (seg_nuclei ):
38- if label != 0 :
39- nuc_temp = np .zeros_like (seg_nuclei )
40- nuc_temp = np .where (seg_nuclei == label , seg_nuclei , 0 )
41- for i in range (int (nuc_diameter / 10 )):
42- nuc_temp = erosion (nuc_temp )
43- seg_nuc = np .where (nuc_temp != 0 , nuc_temp , seg_nuc )
44-
45- conflicting_seg1_1cytNnuc = {}
46- conflicting_seg1_nuc_list = set ()
47- merged_cyto1_masks , seg1_ms2n , seg1_ms2c = join_segmentations (seg_nuc , seg_cyto1 , return_mapping = True )
48-
49- seg1_cyt2shape = {}
50- seg1_nuc2shape = {}
51- seg1_cyt2nuc = {}
52- seg1_nuc2cyt = {}
53- for label in seg1_ms2c .in_values :
54- if label != 0 :
55- if seg1_ms2c [label ] != 0 :
56- if seg1_ms2c [label ] not in seg1_cyt2shape :
57- seg1_cyt2shape [seg1_ms2c [label ]] = set ()
58- seg1_cyt2nuc [seg1_ms2c [label ]] = set ()
59- if label not in seg1_cyt2shape [seg1_ms2c [label ]]: seg1_cyt2shape [seg1_ms2c [label ]].add (label )
60- if seg1_ms2n [label ] != 0 and seg1_ms2n [label ] not in seg1_cyt2nuc [seg1_ms2c [label ]]:
61- seg1_cyt2nuc [seg1_ms2c [label ]].add (seg1_ms2n [label ])
62- else :
63- merged_cyto1_masks [merged_cyto1_masks == label ] = 0
64- for label in seg1_ms2n .in_values :
65- if label != 0 and seg1_ms2n [label ] != 0 :
66- if seg1_ms2n [label ] not in seg1_nuc2shape :
67- seg1_nuc2shape [seg1_ms2n [label ]] = set ()
68- seg1_nuc2cyt [seg1_ms2n [label ]] = set ()
69- if label not in seg1_nuc2shape [seg1_ms2n [label ]]: seg1_nuc2shape [seg1_ms2n [label ]].add (label )
70- if seg1_ms2c [label ] != 0 and seg1_ms2c [label ] not in seg1_nuc2cyt [seg1_ms2n [label ]]:
71- seg1_nuc2cyt [seg1_ms2n [label ]].add (seg1_ms2c [label ])
72-
73- for cyt in seg1_cyt2nuc :
74- if len (seg1_cyt2nuc [cyt ]) == 0 :
75- for shape_label in seg1_cyt2shape [cyt ]:
76- merged_cyto1_masks [merged_cyto1_masks == shape_label ] = 0
77- del seg1_cyt2shape [cyt ]
37+ if len (nuc_ids ) == 1 :
38+ result [target ] = nuc_ids [0 ]
39+ else :
40+ dist = ndi .distance_transform_edt (cyto_mask )
41+ wl = watershed (- dist , seeds , mask = cyto_mask )
42+ if fill_only_unfilled :
43+ result [target ] = wl [target ]
7844 else :
79- if len (seg1_cyt2nuc [cyt ]) > 1 :
80- conflicting_seg1_1cytNnuc [cyt ] = seg1_cyt2nuc [cyt ]
81- for curr_conf_nuc in seg1_cyt2nuc [cyt ]:
82- conflicting_seg1_nuc_list .add (curr_conf_nuc )
83-
84- current_shape = seg1_cyt2shape [cyt ].copy ().pop ()
85- for shape_label in seg1_cyt2shape [cyt ]:
86- merged_cyto1_masks [merged_cyto1_masks == shape_label ] = current_shape
87- seg1_cyt2shape [cyt ] = current_shape
88-
89- if seg_cyto2 is not None :
90- conflicting_seg2_1cytNnuc = {}
91- conflicting_seg2_nuc_list = set ()
92- merged_cyto2_masks , seg2_ms2n , seg2_ms2c = join_segmentations (seg_nuc , seg_cyto2 , return_mapping = True )
45+ result [cyto_mask ] = wl [cyto_mask ]
9346
94- seg2_cyt2shape = {}
95- seg2_nuc2shape = {}
96- seg2_cyt2nuc = {}
97- seg2_nuc2cyt = {}
98- for label in seg2_ms2c .in_values :
99- if label != 0 :
100- if seg2_ms2c [label ] != 0 :
101- if seg2_ms2c [label ] not in seg2_cyt2shape :
102- seg2_cyt2shape [seg2_ms2c [label ]] = set ()
103- seg2_cyt2nuc [seg2_ms2c [label ]] = set ()
104- if label not in seg2_cyt2shape [seg2_ms2c [label ]]: seg2_cyt2shape [seg2_ms2c [label ]].add (label )
105- if seg2_ms2n [label ] != 0 and seg2_ms2n [label ] not in seg2_cyt2nuc [seg2_ms2c [label ]]:
106- seg2_cyt2nuc [seg2_ms2c [label ]].add (seg2_ms2n [label ])
107- else :
108- merged_cyto2_masks [merged_cyto2_masks == label ] = 0
109- for label in seg2_ms2n .in_values :
110- if label != 0 and seg2_ms2n [label ] != 0 :
111- if seg2_ms2n [label ] not in seg2_nuc2shape :
112- seg2_nuc2shape [seg2_ms2n [label ]] = set ()
113- seg2_nuc2cyt [seg2_ms2n [label ]] = set ()
114- if label not in seg2_nuc2shape [seg2_ms2n [label ]]: seg2_nuc2shape [seg2_ms2n [label ]].add (label )
115- if seg2_ms2c [label ] != 0 and seg2_ms2c [label ] not in seg2_nuc2cyt [seg2_ms2n [label ]]:
116- seg2_nuc2cyt [seg2_ms2n [label ]].add (seg2_ms2c [label ])
11747
118- for cyt in seg2_cyt2nuc :
119- if len (seg2_cyt2nuc [cyt ]) == 0 :
120- for shape_label in seg2_cyt2shape [cyt ]:
121- merged_cyto2_masks [merged_cyto2_masks == shape_label ] = 0
122- del seg2_cyt2shape [cyt ]
123- else :
124- if len (seg2_cyt2nuc [cyt ]) > 1 :
125- conflicting_seg2_1cytNnuc [cyt ] = seg2_cyt2nuc [cyt ]
126- for curr_conf_nuc in seg2_cyt2nuc [cyt ]:
127- conflicting_seg2_nuc_list .add (curr_conf_nuc )
128- current_shape = seg2_cyt2shape [cyt ].copy ().pop ()
129- for shape_label in seg2_cyt2shape [cyt ]:
130- merged_cyto2_masks [merged_cyto2_masks == shape_label ] = current_shape
131- seg2_cyt2shape [cyt ] = current_shape
132-
133- merged_masks , segm_ms2seg1 , segm_ms2seg2 = join_segmentations (merged_cyto1_masks , merged_cyto2_masks , return_mapping = True )
134-
135- merged_shape2final = {}
136- merged_shape_conflicting2cyt1 = {}
137- merged_shape_conflicting2cyt2 = {}
138-
139- for label in segm_ms2seg1 .in_values :
140- if segm_ms2seg1 [label ] != 0 and seg1_ms2c [segm_ms2seg1 [label ]] != 0 :
141- if segm_ms2seg2 [label ] != 0 and seg2_ms2c [segm_ms2seg2 [label ]] != 0 :
142- seg1_shape = segm_ms2seg1 [label ]
143- seg2_shape = segm_ms2seg2 [label ]
144- seg_cyt1_final = 0
145- seg_cyt2_final = 0
146- for seg1_cyt in seg1_cyt2shape :
147- if seg1_cyt2shape [seg1_cyt ] == seg1_shape :
148- if seg1_cyt in conflicting_seg1_1cytNnuc :
149- merged_shape_conflicting2cyt1 [label ] = seg1_cyt
150- else :
151- seg_cyt1_final = seg1_cyt
152- break
153- for seg2_cyt in seg2_cyt2shape :
154- if seg2_cyt2shape [seg2_cyt ] == seg2_shape :
155- if seg2_cyt in conflicting_seg2_1cytNnuc :
156- merged_shape_conflicting2cyt2 [label ] = seg2_cyt
157- else :
158- seg_cyt2_final = seg2_cyt
159- break
160- if seg_cyt1_final != 0 :
161- merged_shape2final [label ] = seg1_cyt2nuc [seg_cyt1_final ].copy ().pop () * 1000
162- if label in merged_shape_conflicting2cyt2 :
163- del merged_shape_conflicting2cyt2 [label ]
164- elif seg_cyt2_final != 0 :
165- merged_shape2final [label ] = seg2_cyt2nuc [seg_cyt2_final ].copy ().pop () * 1000
166- if label in merged_shape_conflicting2cyt1 :
167- del merged_shape_conflicting2cyt1 [label ]
168- else :
169- seg1_shape = segm_ms2seg1 [label ]
170- for seg1_cyt in seg1_cyt2shape :
171- if seg1_cyt2shape [seg1_cyt ] == seg1_shape :
172- if seg1_cyt not in conflicting_seg1_1cytNnuc :
173- merged_shape2final [label ] = seg1_cyt2nuc [seg1_cyt ].copy ().pop () * 1000
174- else :
175- merged_shape_conflicting2cyt1 [label ] = seg1_cyt
176- elif segm_ms2seg2 [label ] != 0 and seg2_ms2c [segm_ms2seg2 [label ]] != 0 :
177- seg2_shape = segm_ms2seg2 [label ]
178- for seg2_cyt in seg2_cyt2shape :
179- if seg2_cyt2shape [seg2_cyt ] == seg2_shape :
180- if seg2_cyt not in conflicting_seg2_1cytNnuc :
181- merged_shape2final [label ] = seg2_cyt2nuc [seg2_cyt ].copy ().pop () * 1000
182- else :
183- merged_shape_conflicting2cyt2 [label ] = seg2_cyt
184-
185- merged_masks_simple = np .zeros_like (merged_masks )
186- for shape in merged_shape2final :
187- merged_masks_simple = np .where (merged_masks == shape , merged_shape2final [shape ], merged_masks_simple )
48+ def merge_segmentations (seg_nuclei , seg_cyto1 , seg_cyto2 , nuc_diameter ):
49+ height , width = seg_nuclei .shape
50+ result = np .zeros ((height , width ), dtype = np .int32 )
51+ unique_nucs = sorted (n for n in np .unique (seg_nuclei ) if n != 0 )
52+
53+ if seg_cyto1 is not None :
54+ nuc_primary_cyto1 = {}
55+ for nuc in unique_nucs :
56+ nuc_pixels = seg_cyto1 [seg_nuclei == nuc ]
57+ vals , cnts = np .unique (nuc_pixels , return_counts = True )
58+ nz = vals != 0
59+ if nz .any ():
60+ nuc_primary_cyto1 [nuc ] = vals [nz ][np .argmax (cnts [nz ])]
61+
62+ for c1 in set (nuc_primary_cyto1 .values ()):
63+ assign_cyto_region (seg_cyto1 == c1 , seg_nuclei , result , fill_only_unfilled = False )
18864
189- for label in segm_ms2seg1 .in_values :
190- if label in merged_shape_conflicting2cyt1 :
191- if label in merged_shape_conflicting2cyt2 :
192- # 3 body problem
193- print ("3 body problem" )
194- else :
195- seg1_cyt = merged_shape_conflicting2cyt1 [label ]
196- affected_shapes = set ()
197- affected_nuclei = set ()
198- for seg1_nuc in conflicting_seg1_1cytNnuc [seg1_cyt ]:
199- affected_nuclei .add (seg1_nuc )
200- for curr_shape in segm_ms2seg1 .in_values :
201- if segm_ms2seg1 [curr_shape ] == seg1_cyt2shape [seg1_cyt ]:
202- affected_shapes .add (curr_shape )
203- if len (affected_shapes ) > 0 :
204- cyto_img = np .zeros_like (merged_masks )
205- for shape_label in affected_shapes :
206- cyto_img = np .where (merged_masks == shape_label , shape_label , cyto_img )
207- if shape_label in merged_shape_conflicting2cyt1 :
208- del merged_shape_conflicting2cyt1 [shape_label ]
209- nuc_img = np .zeros_like (merged_masks )
210- for nuc_label in affected_nuclei :
211- nuc_img = np .where (seg_nuc == nuc_label , seg_nuc , nuc_img )
212- wl = solve_conflict_watershed (cyto_img , nuc_img )
213- for curr_shape in affected_shapes :
214- if curr_shape in merged_shape2final :
215- wl = np .where (merged_masks_simple == merged_shape2final [curr_shape ], 0 , wl )
216- merged_masks_simple = np .where (wl != 0 , wl * 1000 , merged_masks_simple )
217- elif label in merged_shape_conflicting2cyt2 :
218- seg2_cyt = merged_shape_conflicting2cyt2 [label ]
219- affected_shapes = set ()
220- affected_nuclei = set ()
221- for seg2_nuc in conflicting_seg2_1cytNnuc [seg2_cyt ]:
222- affected_nuclei .add (seg2_nuc )
223- for curr_shape in segm_ms2seg2 .in_values :
224- if segm_ms2seg2 [curr_shape ] == seg2_cyt2shape [seg2_cyt ]:
225- affected_shapes .add (curr_shape )
226- if len (affected_shapes ) > 0 :
227- cyto_img = np .zeros_like (merged_masks )
228- for shape_label in affected_shapes :
229- cyto_img = np .where (merged_masks == shape_label , shape_label , cyto_img )
230- if shape_label in merged_shape_conflicting2cyt2 :
231- del merged_shape_conflicting2cyt2 [shape_label ]
232- nuc_img = np .zeros_like (merged_masks )
233- for nuc_label in affected_nuclei :
234- nuc_img = np .where (seg_nuc == nuc_label , seg_nuc , nuc_img )
235- wl = solve_conflict_watershed (cyto_img , nuc_img )
236- for curr_shape in affected_shapes :
237- if curr_shape in merged_shape2final :
238- wl = np .where (merged_masks_simple == merged_shape2final [curr_shape ], 0 , wl )
239- merged_masks_simple = np .where (wl != 0 , wl * 1000 , merged_masks_simple )
65+ if seg_cyto2 is not None :
66+ nuc_primary_cyto2 = {}
67+ for nuc in unique_nucs :
68+ nuc_pixels = seg_cyto2 [seg_nuclei == nuc ]
69+ vals , cnts = np .unique (nuc_pixels , return_counts = True )
70+ nz = vals != 0
71+ if nz .any ():
72+ nuc_primary_cyto2 [nuc ] = vals [nz ][np .argmax (cnts [nz ])]
24073
241- merged_masks = merged_masks_simple
242- else :
243- merged_masks = merged_cyto1_masks
74+ for c2 in set (nuc_primary_cyto2 .values ()):
75+ assign_cyto_region (seg_cyto2 == c2 , seg_nuclei , result , fill_only_unfilled = True )
24476
245- fixed_nuclei = set ()
246- fixed_shapes = set ()
247- for seg1_conf_nuc in conflicting_seg1_nuc_list :
248- if seg1_conf_nuc not in fixed_nuclei :
249- affected_shapes = set ()
250- affected_nuclei = set ()
251- for seg1_conf_cyt in conflicting_seg1_1cytNnuc :
252- if seg1_conf_nuc in conflicting_seg1_1cytNnuc [seg1_conf_cyt ]:
253- for curr_nuc in conflicting_seg1_1cytNnuc [seg1_conf_cyt ]:
254- affected_nuclei .add (curr_nuc )
255- for extra_cyt in seg1_nuc2cyt [curr_nuc ]:
256- affected_shapes .add (seg1_cyt2shape [extra_cyt ])
257- for curr_shape in seg1_ms2c .in_values :
258- if seg1_ms2c [curr_shape ] == seg1_cyt2shape [seg1_conf_cyt ]:
259- affected_shapes .add (curr_shape )
260- cyto_img = np .zeros_like (merged_masks )
261- for shape_label in affected_shapes :
262- cyto_img = np .where (merged_masks == shape_label , shape_label , cyto_img )
263- nuc_img = np .zeros_like (merged_masks )
264- for nuc_label in affected_nuclei :
265- nuc_img = np .where (seg_nuc == nuc_label , seg_nuc , nuc_img )
266- wl = solve_conflict_watershed (cyto_img , nuc_img )
267- merged_masks = np .where (wl != 0 , wl * 1000 , merged_masks )
268- fixed_nuclei .update (affected_nuclei )
269- fixed_shapes .update (affected_shapes )
77+ for nuc in unique_nucs :
78+ result [seg_nuclei == nuc ] = nuc
27079
271- return relabel_sequential (merged_masks )[0 ]
80+ return relabel_sequential (result )[0 ]
27281
27382
274- # Code to generate the segmentation
275- def segment ( model_nuc , model_cyto , nuclei_img , cyto_img1 , cyto_img2 , nuc_diameter , cell_diameter , output_folder , output_prefix ):
83+ def segment ( model_nuc , model_cyto , nuclei_img , cyto_img1 , cyto_img2 , nuc_diameter , cell_diameter , output_folder ,
84+ output_prefix ):
27685 channels = [1 , 0 ]
277- # We segment the nuclei using cellpose default nuclei model
27886 nuclei_masks , flows , styles = model_nuc .eval (
27987 np .stack ([nuclei_img , np .zeros_like (nuclei_img )]),
28088 channels = channels ,
@@ -286,7 +94,6 @@ def segment(model_nuc, model_cyto, nuclei_img, cyto_img1, cyto_img2, nuc_diamete
28694 if cyto_img1 is not None :
28795 cell_masks = None
28896
289- # We segment the 1st cytoplasm marker using cellpose default cyto3 model
29097 cyto1_masks , flows , styles = model_cyto .eval (
29198 np .stack ([sharpen (adaptive_hist (cyto_img1 )), nuclei_masks ]),
29299 channels = channels ,
@@ -296,10 +103,7 @@ def segment(model_nuc, model_cyto, nuclei_img, cyto_img1, cyto_img2, nuc_diamete
296103 )
297104 imsave (output_folder + "/" + output_prefix + "cyto1_mask.png" , cyto1_masks )
298105
299-
300106 if cyto_img2 is not None :
301- # In case we are using 2 cytoplasm markers
302- # We segment the 2nd cytoplasm marker using cellpose default cyto3 model
303107 cyto2_masks , flows , styles = model_cyto .eval (
304108 np .stack ([sharpen (adaptive_hist (cyto_img2 )), nuclei_masks ]),
305109 channels = channels ,
@@ -309,10 +113,8 @@ def segment(model_nuc, model_cyto, nuclei_img, cyto_img1, cyto_img2, nuc_diamete
309113 )
310114 imsave (output_folder + "/" + output_prefix + "cyto2_mask.png" , cyto2_masks )
311115
312- # We merge the nuclei and 2 cytoplasm segmentations
313116 cell_masks = merge_segmentations (nuclei_masks , cyto1_masks , cyto2_masks , nuc_diameter )
314117 else :
315- # We merge the nuclei and cytoplasm segmentations
316118 cell_masks = merge_segmentations (nuclei_masks , cyto1_masks , None , nuc_diameter )
317119
318- imsave (output_folder + "/" + output_prefix + "cell_mask.png" , cell_masks )
120+ imsave (output_folder + "/" + output_prefix + "cell_mask.png" , cell_masks )
0 commit comments