Skip to content

Commit 56c8da0

Browse files
committed
[feat] Use CoM for FM fiducial refinement
1 parent 66819f8 commit 56c8da0

6 files changed

Lines changed: 187 additions & 80 deletions

File tree

src/odemis/gui/cont/acquisition/cryo_z_localization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def _on_target_focus_pos(self, target_coordinates: List[float]) -> None:
243243
"""
244244
# Set the target Z ctrl with the focus position
245245
self._panel.ctrl_target_z.SetValue(target_coordinates[2])
246-
save_project(self._tab_data_model.main)
246+
save_project(self._tab_data.main)
247247

248248
def _on_ctrl_target_z_change(self) -> List[float]:
249249
"""

src/odemis/gui/cont/multi_point_correlation.py

Lines changed: 97 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
# This is not related to any particular wxPython version and is most likely permanent.
3636

3737
from odemis import model, util
38-
from odemis.acq.align.tdct import get_optimized_z_gauss, _convert_das_to_numpy_stack, run_tdct_correlation
38+
from odemis.acq.align.tdct import _convert_das_to_numpy_stack, run_tdct_correlation
3939
from odemis.acq.feature import FIBFMCorrelationData, Target, TargetType
4040
from odemis.acq.stream import StaticFluoStream, StaticSEMStream, StaticFIBStream, FluoStream
4141
from odemis.gui import conf
@@ -44,7 +44,7 @@
4444
from odemis.gui.util import call_in_wx_main
4545
from odemis.model import ListVA
4646
from odemis.util.dataio import data_to_static_streams
47-
from odemis.util.interpolation import interpolate_z_stack
47+
from odemis.util.img import get_brightest_channel, compute_center_of_mass
4848
from odemis.util.units import readable_str
4949

5050
# create an enum with column labels and position
@@ -62,16 +62,17 @@ class GridColumns(Enum):
6262
FIDUCIAL_PATTERN = r"^[^-]+-"
6363
RIM_COR_DEFAULT = 0.495 # See MD_RIM_COR. This value works fine for 50x objectives, which are common
6464

65-
# Both functions getPixel3DCoordinates(args*, kwargs*) and getPhysical3DCoordinates(args*, kwargs*) need special
65+
# Both functions get_pixel_3d_coordinates(args*, kwargs*) and get_physical_3d_coordinates(args*, kwargs*) need special
6666
# conditions to convert between physical and pixel coordinate systems in order for multipoint correlation to operate.
6767
# For coordinate conversions, we assume the pixels in 3D are isosymmetric
6868
# i.e. size in pixel[0]=pixel[1]=pixel[2].
69+
COM_ROI_PADDING = 12 # Padding (pixels) for center of mass ROI extraction
70+
# TODO: Could adapt padding based on pixel spacing for more flexibility if needed
6971

70-
def getPixel3DCoordinates(stream: FluoStream, p_pos: Tuple[float, float, float], check_bbox: bool = False) \
72+
def get_pixel_3d_coordinates(stream: FluoStream, p_pos: Tuple[float, float, float], check_bbox: bool = False) \
7173
-> Optional[Tuple[float, float, float]]:
7274
"""
73-
Translate 3D physical coordinates into 3D pixel coordinates. The z coordinate is computed assuming iso-voxel
74-
between x, y and z.
75+
Translate 3D physical coordinates into 3D pixel coordinates.
7576
:param stream: Stream which is used as reference for coordinate conversion
7677
:param p_pos: the position in physical coordinates (m). x and y are the sample position, z is the focus position
7778
:param check_bbox: if True, the function will return None if the position is outside of the image
@@ -84,35 +85,34 @@ def getPixel3DCoordinates(stream: FluoStream, p_pos: Tuple[float, float, float],
8485

8586
raw = stream.raw[0]
8687
md = stream._find_metadata(raw.metadata)
87-
pxs = md.get(model.MD_PIXEL_SIZE, (1e-6, 1e-6))
88+
pxs = md.get(model.MD_PIXEL_SIZE, (1e-6, 1e-6, 1e-6))
8889
# For multipoint correlation, we assume that the pixel size in x is the same as in y
8990
if not util.almost_equal(pxs[0], pxs[1], atol=1e-9):
9091
logging.warning("Pixel size in x and y are not equal while computing pixel coordinates")
9192

92-
# Z position is found by taking into account MD_POS and subtracting it from the physical coordinates.
93-
# Pixel value used for Z enforces the iso-voxel condition between x, y and z. It is not the real pixel value in z.
9493
tpos = md.get(model.MD_POS, (0, 0, 0))
9594
tpos_z = tpos[2] if len(tpos) >= 3 else 0.0
96-
z = (p_pos[2] - tpos_z) / pxs[1]
95+
z = (p_pos[2] - tpos_z) / pxs[2]
9796
pixel_pos = (pixel_pos[0], pixel_pos[1], z)
97+
9898
return pixel_pos
9999

100-
def getPhysical3DCoordinates(stream: FluoStream, pixel_pos: Tuple[float, float, float])\
100+
def get_physical_3d_coordinates(stream: FluoStream, pixel_pos: Tuple[float, float, float])\
101101
-> Optional[Tuple[float, float, float]]:
102102
"""
103-
Translate 3D pixel coordinates into 3D physical coordinates. The z coordinate is computed assuming iso-voxel
104-
between x, y and z.
103+
Translate 3D pixel coordinates into 3D physical coordinates.
105104
:param stream: Stream which is used as reference for coordinate conversion
106105
:param pixel_pos: the position in pixel coordinates (x, y, z)
107106
:returns: the position in physical coordinates (x, y, z) in meters
108107
"""
109108
p_pos = stream.getPhysicalCoordinates(pixel_pos[:2])
110109
raw = stream.raw[0]
111110
md = stream._find_metadata(raw.metadata)
112-
pxs = md.get(model.MD_PIXEL_SIZE, (1e-6, 1e-6))[0:2]
111+
pxs = md.get(model.MD_PIXEL_SIZE, (1e-6, 1e-6, 1e-6))
113112
tpos = md.get(model.MD_POS, (0, 0, 0))
114113
tpos_z = tpos[2] if len(tpos) >= 3 else 0.0
115-
p_pos_z = pixel_pos[2] * pxs[1] + tpos_z
114+
# Account for slice thickness, aka, the z distance between slices
115+
p_pos_z = pixel_pos[2] * pxs[2] + tpos_z
116116
return (p_pos[0], p_pos[1], p_pos_z)
117117

118118
def update_feature_correlation_target(correlation_target: FIBFMCorrelationData,
@@ -146,8 +146,7 @@ def update_feature_correlation_target(correlation_target: FIBFMCorrelationData,
146146
fm_fiducials.sort(key=lambda x: x.index.value)
147147
correlation_target.fm_fiducials = fm_fiducials
148148

149-
acq_conf = conf.get_acqui_conf()
150-
save_project(acq_conf.pj_last_path, tab_data.main.features.value, tab_data.main.overviews.value)
149+
save_project(tab_data.main)
151150

152151
return correlation_target
153152

@@ -176,20 +175,20 @@ def __init__(self, frame):
176175
# Access the correlation points table (wxListCtrl)
177176
self.grid = self._panel.table_grid
178177

179-
# Access the Refine Z text (to check if refine_z is working or not)
180-
self.txt_refinez_active = self._panel.txt_refinez_active
181-
self.txt_refinez_active.Show(True)
178+
# Access the Refine XYZ status text (to check if XYZ targeting is working or not)
179+
self.txt_refine_xyz_active = self._panel.txt_refine_xyz_active
180+
self.txt_refine_xyz_active.Show(True)
182181

183-
# Access the Z-targeting button
184-
self.z_targeting_btn = self._panel.btn_z_targeting
185-
self.z_targeting_btn.Bind(wx.EVT_BUTTON, self._on_z_targeting)
186-
self.z_targeting_btn.Enable(False)
187-
# Disable Z-targeting button if super z stream is available as Z-targeting is not required in that case
188-
self.refinez_active = True
182+
# Access the XYZ-targeting button
183+
self.xyz_targeting_btn = self._panel.btn_xyz_targeting
184+
self.xyz_targeting_btn.Bind(wx.EVT_BUTTON, self._on_xyz_targeting)
185+
self.xyz_targeting_btn.Enable(False)
186+
# Disable XYZ-targeting button if super z stream is available as XYZ-targeting is not required in that case
187+
self.refine_xyz_active = True
189188
if self._tab_data_model.main.currentFeature.value.superz_stream_name:
190-
self.z_targeting_btn.SetToolTip("Super Z information available, Refine Z disabled")
191-
self.txt_refinez_active.SetLabel("Super Z information in use")
192-
self.refinez_active = False
189+
self.xyz_targeting_btn.SetToolTip("Super Z information available, Refine XYZ disabled")
190+
self.txt_refine_xyz_active.SetLabel("Super Z information in use")
191+
self.refine_xyz_active = False
193192

194193
self._panel.btn_delete_row.Bind(wx.EVT_BUTTON, self._on_delete_row)
195194

@@ -407,7 +406,7 @@ def check_correlation_conditions(self) -> bool:
407406
self._panel.Layout()
408407
# Update the FIB viewport because it shows the output overlays
409408
# It is the second viewport out of total two viewports
410-
self._viewports[1].canvas.update_drawing()
409+
self._viewports[1].canvas.request_drawing_update()
411410
return False
412411
else:
413412
return False
@@ -478,13 +477,15 @@ def _do_correlation(self):
478477
fib_coords.append(fib_coord)
479478
fib_coords = numpy.array(fib_coords, dtype=numpy.float32)
480479
for fm_coord in self.correlation_target.fm_fiducials:
481-
fm_coord_px = getPixel3DCoordinates(self.correlation_target.fm_streams[0], fm_coord.coordinates.value)
480+
fm_coord_px = get_pixel_3d_coordinates(self.correlation_target.fm_streams[0], fm_coord.coordinates.value)
482481
fm_coords.append(fm_coord_px)
483482
fm_coords = numpy.array(fm_coords, dtype=numpy.float32)
484483
poi_coord = self.correlation_target.fm_pois[0]
485-
poi_coord_px = getPixel3DCoordinates(self.correlation_target.fm_streams[0], poi_coord.coordinates.value)
484+
poi_coord_px = get_pixel_3d_coordinates(self.correlation_target.fm_streams[0], poi_coord.coordinates.value)
486485
poi_coords.append(poi_coord_px)
487486
poi_coords = numpy.array(poi_coords, dtype=numpy.float32)
487+
# Fixing seed, and thus basically resetting randomness, to get more consistent results
488+
numpy.random.seed(0)
488489
# Run the correlation
489490
self.correlation_target.correlation_result = run_tdct_correlation(fib_coords=fib_coords, fm_coords=fm_coords,
490491
poi_coords=poi_coords,
@@ -563,7 +564,7 @@ def _on_cell_selected(self, event) -> None:
563564
break
564565

565566
for vp in self._viewports:
566-
vp.canvas.update_drawing()
567+
vp.canvas.request_drawing_update()
567568

568569
# Highlight the selected row
569570
# Note: as of wxPython 4.1, when AppendRow() is called on an empty grid, this event is
@@ -663,7 +664,7 @@ def _on_cell_changing(self, event) -> None:
663664
elif col_name == GridColumns.Z.name and (
664665
self._tab_data_model.main.currentTarget.value.type.value != TargetType.FibFiducial):
665666
self._tab_data_model.main.currentTarget.value.coordinates.value[2] = \
666-
getPhysical3DCoordinates(self.correlation_target.fm_streams[0], (x, y, float(new_value)))[2]
667+
get_physical_3d_coordinates(self.correlation_target.fm_streams[0], (x, y, float(new_value)))[2]
667668
except ValueError:
668669
wx.MessageBox("X, Y, Z values must be a float!", "Invalid Input", wx.OK | wx.ICON_ERROR)
669670
event.Veto() # Prevent the change
@@ -675,7 +676,7 @@ def _on_cell_changed(self, event) -> None:
675676
"""Get the cell column of the modified cell and reorder the grid/table based on the index column."""
676677
# Refresh the canvas to update the target overlays in the viewports
677678
for vp in self._viewports:
678-
vp.canvas.update_drawing()
679+
vp.canvas.request_drawing_update()
679680
# If the index column is modified, reorder the grid based on the index column
680681
col = event.GetCol()
681682
col_name = self.grid.GetColLabelValue(col)
@@ -695,27 +696,33 @@ def _on_current_target_changes(self, target: Target) -> None:
695696
# For new targets, automatically perform Z targeting if MIP is checked for at least one FM stream
696697
if not target:
697698
self.grid.ClearSelection()
698-
self.z_targeting_btn.Enable(False)
699+
self.xyz_targeting_btn.Enable(False)
699700
return
700701

701-
mip_enabled = any([stream.max_projection.value for stream in self.correlation_target.fm_streams])
702+
# Only check for visible and non-reflective streams, since the refinement method is discarding other streams
703+
# anyway. Doesn't make sense to check for MIP value of a stream that is not used in refining a fiducial.
704+
visible_streams = [
705+
s.stream for s in self._tab_data_model.views.value[0].stream_tree.flat.value
706+
if s.stream.getRawMetadata()[0].get(model.MD_OUT_WL) != model.BAND_PASS_THROUGH
707+
]
708+
mip_enabled = any([stream.max_projection.value for stream in visible_streams])
702709

703-
# Refine z should be disabled if the the Z information was obtained using SuperZ
704-
if self.refinez_active and (target.type.value in self.grid_targets):
710+
# Refine xyz should be disabled if the Z information was obtained using SuperZ
711+
if self.refine_xyz_active and (target.type.value in self.grid_targets):
705712
if TargetType.FibFiducial == target.type.value:
706-
self.z_targeting_btn.Enable(False)
713+
self.xyz_targeting_btn.Enable(False)
707714
else:
708-
self.z_targeting_btn.Enable(True)
715+
self.xyz_targeting_btn.Enable(True)
709716
if mip_enabled:
710-
self._on_z_targeting(None)
717+
self._on_xyz_targeting(None)
711718

712719
for row in range(self.grid.GetNumberRows()):
713720
if self._selected_target_in_grid(target, row):
714721
self.grid.SelectRow(row)
715722
break
716723

717724
for vp in self._viewports:
718-
vp.canvas.update_drawing()
725+
vp.canvas.request_drawing_update()
719726

720727
if self._subscribed_target is not None:
721728
self._subscribed_target.coordinates.unsubscribe(self._on_current_coordinates_changes)
@@ -739,7 +746,7 @@ def _on_current_coordinates_changes(self, coordinates: ListVA) -> None:
739746
pixel_coords = self.correlation_target.fib_stream.getPixelCoordinates(
740747
(target.coordinates.value[0], target.coordinates.value[1]), check_bbox=False)
741748
else:
742-
pixel_coords = getPixel3DCoordinates(self.correlation_target.fm_streams[0], target.coordinates.value)
749+
pixel_coords = get_pixel_3d_coordinates(self.correlation_target.fm_streams[0], target.coordinates.value)
743750
if (self.grid.GetCellValue(row,
744751
GridColumns.Z.value)) != f"{pixel_coords[2]:.{GRID_PRECISION}f}":
745752
temp_check = True
@@ -757,6 +764,9 @@ def _on_current_coordinates_changes(self, coordinates: ListVA) -> None:
757764
if self.check_correlation_conditions() and (temp_check or target.type.value == TargetType.SurfaceFiducial):
758765
self._need_reprocessing()
759766

767+
for vp in self._viewports:
768+
vp.canvas.request_drawing_update()
769+
760770
@call_in_wx_main
761771
def _on_target_changes(self, targets: List[Target]) -> None:
762772
"""
@@ -783,7 +793,7 @@ def _on_target_changes(self, targets: List[Target]) -> None:
783793
(target.coordinates.value[0], target.coordinates.value[1]), check_bbox=False)
784794
self.grid.SetCellValue(current_row_count, GridColumns.Z.value, "")
785795
else:
786-
pixel_coords = getPixel3DCoordinates(self.correlation_target.fm_streams[0], target.coordinates.value)
796+
pixel_coords = get_pixel_3d_coordinates(self.correlation_target.fm_streams[0], target.coordinates.value)
787797
self.grid.SetCellValue(current_row_count, GridColumns.Z.value,
788798
f"{pixel_coords[2]:.{GRID_PRECISION}f}")
789799
# Set x and y position in the grid
@@ -801,40 +811,58 @@ def _on_target_changes(self, targets: List[Target]) -> None:
801811
self._panel.Layout()
802812
self.correlation_target = update_feature_correlation_target(self.correlation_target, self._tab_data_model)
803813

804-
for vp in self._viewports:
805-
vp.canvas.update_drawing()
806-
807814
if self.check_correlation_conditions():
808815
self._need_reprocessing()
809816

810-
def _on_z_targeting(self, evt) -> None:
817+
for vp in self._viewports:
818+
vp.canvas.request_drawing_update()
819+
820+
def _on_xyz_targeting(self, evt) -> None:
811821
"""
812-
Handle Z-targeting when the Z-targeting button is clicked.
822+
Handle targeting when the targeting button is clicked, or automatically triggered for MIP streams.
823+
Performs 3D Center of Mass targeting (X, Y, Z).
813824
"""
814825
if self._tab_data_model.main.currentTarget.value:
815-
816-
# Select the streams which are visible in the view for Z-targeting
817-
streams_projections = self._tab_data_model.views.value[0].stream_tree.flat.value
818-
if not streams_projections:
819-
wx.MessageBox("FM streams are not available for refining Z", "Error", wx.OK | wx.ICON_ERROR)
826+
# Select the non-reflective streams visible in the view for targeting
827+
streams = [
828+
s.stream for s in self._tab_data_model.views.value[0].stream_tree.flat.value
829+
if s.stream.getRawMetadata()[0].get(model.MD_OUT_WL) != model.BAND_PASS_THROUGH
830+
]
831+
if not streams:
832+
wx.MessageBox("FM streams are not available for refining targets", "Error", wx.OK | wx.ICON_ERROR)
820833
return
821834

822-
self.txt_refinez_active.SetLabel("active ...")
823-
wx.CallLater(1000, self.txt_refinez_active.SetLabel, "")
835+
self.txt_refine_xyz_active.SetLabel("active ...")
836+
wx.CallLater(1000, self.txt_refine_xyz_active.SetLabel, "")
824837

825838
coords = self._tab_data_model.main.currentTarget.value.coordinates.value
826-
pixel_coords = getPixel3DCoordinates(self.correlation_target.fm_streams[0], coords)
827-
das = [interpolate_z_stack(da=stream_projection.stream.raw[0]
828-
[:,
829-
int(pixel_coords[1]):int(pixel_coords[1])+1,
830-
int(pixel_coords[0]):int(pixel_coords[0])+1],
831-
method="linear")
832-
for stream_projection in streams_projections]
833-
834-
z = float(get_optimized_z_gauss(das, int(0), int(0), int(pixel_coords[2])))
835-
z_p = getPhysical3DCoordinates(self.correlation_target.fm_streams[0],
836-
(pixel_coords[0],pixel_coords[1], z))[2]
837-
self._tab_data_model.main.currentTarget.value.coordinates.value[2] = z_p
839+
pixel_coords = get_pixel_3d_coordinates(streams[0], coords)
840+
841+
# We are going to refine around the clicked position
842+
target_x, target_y = int(pixel_coords[0]), int(pixel_coords[1])
843+
# Ensure multi-channel compatibility
844+
raw_multi = numpy.asarray([s.raw[0] for s in streams])
845+
shape_y, shape_x = raw_multi.shape[-2], raw_multi.shape[-1]
846+
# Get boundary-safe slice & crop
847+
y_start = max(0, target_y - COM_ROI_PADDING)
848+
y_end = min(shape_y, target_y + COM_ROI_PADDING + 1)
849+
x_start = max(0, target_x - COM_ROI_PADDING)
850+
x_end = min(shape_x, target_x + COM_ROI_PADDING + 1)
851+
roi = numpy.s_[:, y_start:y_end, x_start:x_end]
852+
multi_crop = raw_multi[(slice(None),) + roi] # We search along all stack slices (first axis)
853+
# Find best channel and compute COM
854+
best_c = get_brightest_channel(multi_crop)
855+
com = compute_center_of_mass(multi_crop[best_c], baseline_percentile=95.0)
856+
com_z = com[0]
857+
com_y_crop = com[1] + roi[1].start
858+
com_x_crop = com[2] + roi[2].start
859+
# Map back to physical coordinates using optimized X, Y, and Z
860+
physical_coords = get_physical_3d_coordinates(streams[0],(com_x_crop, com_y_crop, com_z))
861+
# Update the model with the refined 3D coordinates
862+
target_coords = self._tab_data_model.main.currentTarget.value.coordinates.value
863+
target_coords[0] = physical_coords[0]
864+
target_coords[1] = physical_coords[1]
865+
target_coords[2] = physical_coords[2]
838866

839867
def _reorder_grid(self) -> None:
840868
"""

0 commit comments

Comments
 (0)