44from matplotlib import pyplot as plt
55from matplotlib .backends .backend_qtagg import FigureCanvasQTAgg as FigureCanvas
66from matplotlib .backends .backend_qtagg import NavigationToolbar2QT
7+ from matplotlib .container import ErrorbarContainer
78from PySide6 .QtCore import QObject , QRunnable , Qt , QThreadPool , QTimer , Signal
9+ from PySide6 .QtGui import QAction
810from PySide6 .QtWidgets import (
911 QDockWidget ,
1012 QMenu ,
@@ -22,12 +24,13 @@ class PlotWorkerSignals(QObject):
2224
2325
2426class PlotWorker (QRunnable ):
25- def __init__ (self , vis_df , cond_df , meas_df , sim_df ):
27+ def __init__ (self , vis_df , cond_df , meas_df , sim_df , group_by ):
2628 super ().__init__ ()
2729 self .vis_df = vis_df
2830 self .cond_df = cond_df
2931 self .meas_df = meas_df
3032 self .sim_df = sim_df
33+ self .group_by = group_by
3134 self .signals = PlotWorkerSignals ()
3235
3336 def run (self ):
@@ -61,6 +64,7 @@ def run(self):
6164 self .cond_df ,
6265 measurements_df = self .meas_df ,
6366 simulations_df = sim_df ,
67+ group_by = self .group_by ,
6468 )
6569 fig = plt .gcf ()
6670 fig .subplots_adjust (left = 0.12 , bottom = 0.15 , right = 0.95 , top = 0.9 , wspace = 0.3 , hspace = 0.4 )
@@ -77,6 +81,7 @@ class MeasurementPlotter(QDockWidget):
7781 def __init__ (self , parent = None ):
7882 super ().__init__ ("Measurement Plot" , parent )
7983 self .setObjectName ("plot_dock" )
84+ self .options_manager = ToolbarOptionManager ()
8085
8186 self .meas_proxy = None
8287 self .sim_proxy = None
@@ -102,6 +107,7 @@ def initialize(self, meas_proxy, sim_proxy, cond_proxy):
102107 self .vis_df = None
103108
104109 # Connect data changes
110+ self .options_manager .option_changed .connect (self ._debounced_plot )
105111 self .meas_proxy .dataChanged .connect (self ._debounced_plot )
106112 self .meas_proxy .rowsInserted .connect (self ._debounced_plot )
107113 self .meas_proxy .rowsRemoved .connect (self ._debounced_plot )
@@ -121,16 +127,26 @@ def plot_it(self):
121127 measurements_df = proxy_to_dataframe (self .meas_proxy )
122128 simulations_df = proxy_to_dataframe (self .sim_proxy )
123129 conditions_df = proxy_to_dataframe (self .cond_proxy )
130+ group_by = self .options_manager .get_option ()
131+ # group_by different value in petab.visualize
132+ if group_by == "condition" :
133+ group_by = "simulation"
124134
125135 worker = PlotWorker (
126- self .vis_df , conditions_df , measurements_df , simulations_df
136+ self .vis_df ,
137+ conditions_df ,
138+ measurements_df ,
139+ simulations_df ,
140+ group_by
127141 )
128142 worker .signals .finished .connect (self ._update_tabs )
129143 QThreadPool .globalInstance ().start (worker )
130144
131145 def _update_tabs (self , fig : plt .Figure ):
132146 # Clean previous tabs
133147 self .tab_widget .clear ()
148+ # Clear Highlighter
149+ self .highlighter .clear_highlight ()
134150 if fig is None :
135151 # Fallback: show one empty plot tab
136152 empty_fig , _ = plt .subplots ()
@@ -164,11 +180,18 @@ def _update_tabs(self, fig: plt.Figure):
164180 for idx , ax in enumerate (fig .axes ):
165181 # Create a new figure and copy Axes content
166182 sub_fig , sub_ax = plt .subplots (constrained_layout = True )
167- for line in ax .get_lines ():
183+ handles , labels = ax .get_legend_handles_labels ()
184+ for handle , label in zip (handles , labels , strict = False ):
185+ if isinstance (handle , ErrorbarContainer ):
186+ line = handle .lines [0 ]
187+ elif isinstance (handle , plt .Line2D ):
188+ line = handle
189+ else :
190+ continue
168191 sub_ax .plot (
169192 line .get_xdata (),
170193 line .get_ydata (),
171- label = line . get_label () ,
194+ label = label ,
172195 linestyle = line .get_linestyle (),
173196 marker = line .get_marker (),
174197 color = line .get_color (),
@@ -178,9 +201,7 @@ def _update_tabs(self, fig: plt.Figure):
178201 sub_ax .set_title (ax .get_title ())
179202 sub_ax .set_xlabel (ax .get_xlabel ())
180203 sub_ax .set_ylabel (ax .get_ylabel ())
181- handles , labels = ax .get_legend_handles_labels ()
182- if handles :
183- sub_ax .legend (handles = handles , labels = labels , loc = "best" )
204+ sub_ax .legend ()
184205
185206 sub_canvas = FigureCanvas (sub_fig )
186207 sub_toolbar = CustomNavigationToolbar (sub_canvas , self )
@@ -202,20 +223,18 @@ def _update_tabs(self, fig: plt.Figure):
202223 obs_id = f"subplot_{ idx } "
203224
204225 self .observable_to_subplot [obs_id ] = idx
205- # Also register the original ax from the full figure (main tab)
206226 self .highlighter .register_subplot (ax , idx )
207227 # Register subplot canvas
208228 self .highlighter .register_subplot (sub_ax , idx )
229+ # Also register the original ax from the full figure (main tab)
209230 self .highlighter .connect_picking (sub_canvas )
210231
211232 def highlight_from_selection (self , selected_rows : list [int ], proxy = None , y_axis_col = "measurement" ):
212233 proxy = proxy or self .meas_proxy
213234 if not proxy :
214235 return
215236
216- # x_axis_col = self.x_axis_selector.currentText()
217237 x_axis_col = "time"
218- y_axis_col = "measurement" if proxy == self .meas_proxy else "simulation"
219238 observable_col = "observableId"
220239
221240 def column_index (name ):
@@ -260,6 +279,9 @@ def __init__(self):
260279 self .point_index_map = {} # (subplot index, observableId, x, y) → row index
261280 self .click_callback = None
262281
282+ def clear_highlight (self ):
283+ self .highlight_scatters = defaultdict (list )
284+
263285 def register_subplot (self , ax , subplot_idx ):
264286 scatter = ax .scatter (
265287 [], [], s = 80 , edgecolors = 'black' , facecolors = 'none' , zorder = 5
@@ -293,24 +315,76 @@ def _on_pick(self, event):
293315 ax = artist .axes
294316
295317 # Try to recover the label from the legend (handle → label mapping)
296- label = ax .get_legend ().texts [1 ].get_text ().split ()[- 1 ]
318+ handles , labels = ax .get_legend_handles_labels ()
319+ label = None
320+ for h , l in zip (handles , labels , strict = False ):
321+ if h is artist :
322+ label_parts = l .split ()
323+ if label_parts [- 1 ] == "simulation" :
324+ data_type = "simulation"
325+ label = label_parts [- 2 ]
326+ else :
327+ data_type = "measurement"
328+ label = label_parts [- 1 ]
329+ break
297330
298331 for i in ind :
299332 x = xdata [i ]
300333 y = ydata [i ]
301- self .click_callback (x , y , label )
334+ self .click_callback (x , y , label , data_type )
335+
336+
337+ class ToolbarOptionManager (QObject ):
338+ """A Manager, synchronizing the selected option across all toolbars."""
339+
340+ option_changed = Signal (str )
341+ _instance = None
342+ _initialized = False
343+
344+ def __new__ (cls ):
345+ if cls ._instance is None :
346+ cls ._instance = super (ToolbarOptionManager , cls ).__new__ (cls )
347+ return cls ._instance
348+
349+ def __init__ (self ):
350+ # Ensure QObject.__init__ runs only once
351+ if not self ._initialized :
352+ super ().__init__ ()
353+ self ._selected_option = "observable"
354+ ToolbarOptionManager ._initialized = True
355+
356+ def set_option (self , option ):
357+ if option != self ._selected_option :
358+ self ._selected_option = option
359+ self .option_changed .emit (option )
360+
361+ def get_option (self ):
362+ return self ._selected_option
302363
303364
304365class CustomNavigationToolbar (NavigationToolbar2QT ):
305366 def __init__ (self , canvas , parent ):
306367 super ().__init__ (canvas , parent )
368+ self .manager = ToolbarOptionManager ()
307369
308370 self .settings_btn = QToolButton (self )
309371 self .settings_btn .setIcon (qta .icon ("mdi6.cog-outline" ))
310372 self .settings_btn .setPopupMode (QToolButton .InstantPopup )
311373 self .settings_menu = QMenu (self .settings_btn )
312- self .settings_menu .addAction ("Option 1" )
313- self .settings_menu .addAction ("Option 2" )
374+ self .groupy_by_options = {
375+ grp : QAction (f"Groupy by { grp } " , self )
376+ for grp in ["observable" , "dataset" , "condition" ]
377+ }
378+ for grp , action in self .groupy_by_options .items ():
379+ action .setCheckable (True )
380+ action .triggered .connect (lambda _ , grp = grp : self .manager .set_option (grp ))
381+ self .settings_menu .addAction (action )
382+ self .manager .option_changed .connect (self .update_checked_state )
383+ self .update_checked_state (self .manager .get_option ())
314384 self .settings_btn .setMenu (self .settings_menu )
315385
316386 self .addWidget (self .settings_btn )
387+
388+ def update_checked_state (self , selected_option ):
389+ for action in self .groupy_by_options .values ():
390+ action .setChecked (action .text () == f"Groupy by { selected_option } " )
0 commit comments