Skip to content

Commit 5602674

Browse files
group_by (#96)
* Shared conditions * Better highlighting, more robust. current issues: - not errorbar redoing - not clear how barplots would be plotted - highlighting with multiple observables in one plot seems difficult.
1 parent dffb2c8 commit 5602674

File tree

2 files changed

+117
-29
lines changed

2 files changed

+117
-29
lines changed

src/petab_gui/controllers/mother_controller.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import qtawesome as qta
1010
import yaml
11-
from PySide6.QtCore import Qt, QUrl
11+
from PySide6.QtCore import Qt, QTimer, QUrl
1212
from PySide6.QtGui import QAction, QDesktopServices, QKeySequence, QUndoStack
1313
from PySide6.QtWidgets import (
1414
QFileDialog,
@@ -231,13 +231,19 @@ def setup_connections(self):
231231
self.sbml_controller.overwritten_model.connect(
232232
self.parameter_controller.update_handler_sbml
233233
)
234-
# overwrite signals
234+
# Plotting update. Regulated through a Timer
235+
self._plot_update_timer = QTimer()
236+
self._plot_update_timer.setSingleShot(True)
237+
self._plot_update_timer.setInterval(0)
238+
self._plot_update_timer.timeout.connect(self.init_plotter)
235239
for controller in [
236-
# self.measurement_controller,
237-
self.condition_controller
240+
self.measurement_controller,
241+
self.condition_controller,
242+
self.visualization_controller,
243+
self.simulation_controller,
238244
]:
239245
controller.overwritten_df.connect(
240-
self.init_plotter
246+
self._schedule_plot_update
241247
)
242248

243249
def setup_actions(self):
@@ -871,19 +877,23 @@ def init_plotter(self):
871877
self.plotter = self.view.plot_dock
872878
self.plotter.highlighter.click_callback = self._on_plot_point_clicked
873879

874-
def _on_plot_point_clicked(self, x, y, label):
880+
def _on_plot_point_clicked(self, x, y, label, data_type):
875881
# Extract observable ID from label, if formatted like 'obsId (label)'
876-
meas_proxy = self.measurement_controller.proxy_model
882+
proxy = self.measurement_controller.proxy_model
883+
view = self.measurement_controller.view.table_view
884+
if data_type == "simulation":
885+
proxy = self.simulation_controller.proxy_model
886+
view = self.simulation_controller.view.table_view
877887
obs = label
878888

879889
x_axis_col = "time"
880-
y_axis_col = "measurement"
890+
y_axis_col = data_type
881891
observable_col = "observableId"
882892

883893
def column_index(name):
884-
for col in range(meas_proxy.columnCount()):
894+
for col in range(proxy.columnCount()):
885895
if (
886-
meas_proxy.headerData(col, Qt.Horizontal)
896+
proxy.headerData(col, Qt.Horizontal)
887897
== name
888898
):
889899
return col
@@ -893,16 +903,16 @@ def column_index(name):
893903
y_col = column_index(y_axis_col)
894904
obs_col = column_index(observable_col)
895905

896-
for row in range(meas_proxy.rowCount()):
897-
row_obs = meas_proxy.index(row, obs_col).data()
898-
row_x = meas_proxy.index(row, x_col).data()
899-
row_y = meas_proxy.index(row, y_col).data()
906+
for row in range(proxy.rowCount()):
907+
row_obs = proxy.index(row, obs_col).data()
908+
row_x = proxy.index(row, x_col).data()
909+
row_y = proxy.index(row, y_col).data()
900910
try:
901911
row_x, row_y = float(row_x), float(row_y)
902912
except ValueError:
903913
continue
904914
if row_obs == obs and row_x == x and row_y == y:
905-
self.measurement_controller.view.table_view.selectRow(row)
915+
view.selectRow(row)
906916
break
907917

908918
def _on_table_selection_changed(self, selected, deselected):
@@ -919,3 +929,7 @@ def _on_simulation_selection_changed(self, selected, deselected):
919929
proxy=self.simulation_controller.proxy_model,
920930
y_axis_col="simulation"
921931
)
932+
933+
def _schedule_plot_update(self):
934+
"""Start the plot schedule timer."""
935+
self._plot_update_timer.start()

src/petab_gui/views/simple_plot_view.py

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from matplotlib import pyplot as plt
55
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
66
from matplotlib.backends.backend_qtagg import NavigationToolbar2QT
7+
from matplotlib.container import ErrorbarContainer
78
from PySide6.QtCore import QObject, QRunnable, Qt, QThreadPool, QTimer, Signal
9+
from PySide6.QtGui import QAction
810
from PySide6.QtWidgets import (
911
QDockWidget,
1012
QMenu,
@@ -22,12 +24,13 @@ class PlotWorkerSignals(QObject):
2224

2325

2426
class PlotWorker(QRunnable):
25-
def __init__(self, vis_df, cond_df, meas_df, sim_df):
27+
def __init__(self, vis_df, cond_df, meas_df, sim_df, group_by):
2628
super().__init__()
2729
self.vis_df = vis_df
2830
self.cond_df = cond_df
2931
self.meas_df = meas_df
3032
self.sim_df = sim_df
33+
self.group_by = group_by
3134
self.signals = PlotWorkerSignals()
3235

3336
def run(self):
@@ -61,6 +64,7 @@ def run(self):
6164
self.cond_df,
6265
measurements_df=self.meas_df,
6366
simulations_df=sim_df,
67+
group_by=self.group_by,
6468
)
6569
fig = plt.gcf()
6670
fig.subplots_adjust(left=0.12, bottom=0.15, right=0.95, top=0.9, wspace=0.3, hspace=0.4)
@@ -77,6 +81,7 @@ class MeasurementPlotter(QDockWidget):
7781
def __init__(self, parent=None):
7882
super().__init__("Measurement Plot", parent)
7983
self.setObjectName("plot_dock")
84+
self.options_manager = ToolbarOptionManager()
8085

8186
self.meas_proxy = None
8287
self.sim_proxy = None
@@ -102,6 +107,7 @@ def initialize(self, meas_proxy, sim_proxy, cond_proxy):
102107
self.vis_df = None
103108

104109
# Connect data changes
110+
self.options_manager.option_changed.connect(self._debounced_plot)
105111
self.meas_proxy.dataChanged.connect(self._debounced_plot)
106112
self.meas_proxy.rowsInserted.connect(self._debounced_plot)
107113
self.meas_proxy.rowsRemoved.connect(self._debounced_plot)
@@ -121,16 +127,26 @@ def plot_it(self):
121127
measurements_df = proxy_to_dataframe(self.meas_proxy)
122128
simulations_df = proxy_to_dataframe(self.sim_proxy)
123129
conditions_df = proxy_to_dataframe(self.cond_proxy)
130+
group_by = self.options_manager.get_option()
131+
# group_by different value in petab.visualize
132+
if group_by == "condition":
133+
group_by = "simulation"
124134

125135
worker = PlotWorker(
126-
self.vis_df, conditions_df, measurements_df, simulations_df
136+
self.vis_df,
137+
conditions_df,
138+
measurements_df,
139+
simulations_df,
140+
group_by
127141
)
128142
worker.signals.finished.connect(self._update_tabs)
129143
QThreadPool.globalInstance().start(worker)
130144

131145
def _update_tabs(self, fig: plt.Figure):
132146
# Clean previous tabs
133147
self.tab_widget.clear()
148+
# Clear Highlighter
149+
self.highlighter.clear_highlight()
134150
if fig is None:
135151
# Fallback: show one empty plot tab
136152
empty_fig, _ = plt.subplots()
@@ -164,11 +180,18 @@ def _update_tabs(self, fig: plt.Figure):
164180
for idx, ax in enumerate(fig.axes):
165181
# Create a new figure and copy Axes content
166182
sub_fig, sub_ax = plt.subplots(constrained_layout=True)
167-
for line in ax.get_lines():
183+
handles, labels = ax.get_legend_handles_labels()
184+
for handle, label in zip(handles, labels, strict=False):
185+
if isinstance(handle, ErrorbarContainer):
186+
line = handle.lines[0]
187+
elif isinstance(handle, plt.Line2D):
188+
line = handle
189+
else:
190+
continue
168191
sub_ax.plot(
169192
line.get_xdata(),
170193
line.get_ydata(),
171-
label=line.get_label(),
194+
label=label,
172195
linestyle=line.get_linestyle(),
173196
marker=line.get_marker(),
174197
color=line.get_color(),
@@ -178,9 +201,7 @@ def _update_tabs(self, fig: plt.Figure):
178201
sub_ax.set_title(ax.get_title())
179202
sub_ax.set_xlabel(ax.get_xlabel())
180203
sub_ax.set_ylabel(ax.get_ylabel())
181-
handles, labels = ax.get_legend_handles_labels()
182-
if handles:
183-
sub_ax.legend(handles=handles, labels=labels, loc="best")
204+
sub_ax.legend()
184205

185206
sub_canvas = FigureCanvas(sub_fig)
186207
sub_toolbar = CustomNavigationToolbar(sub_canvas, self)
@@ -202,20 +223,18 @@ def _update_tabs(self, fig: plt.Figure):
202223
obs_id = f"subplot_{idx}"
203224

204225
self.observable_to_subplot[obs_id] = idx
205-
# Also register the original ax from the full figure (main tab)
206226
self.highlighter.register_subplot(ax, idx)
207227
# Register subplot canvas
208228
self.highlighter.register_subplot(sub_ax, idx)
229+
# Also register the original ax from the full figure (main tab)
209230
self.highlighter.connect_picking(sub_canvas)
210231

211232
def highlight_from_selection(self, selected_rows: list[int], proxy=None, y_axis_col="measurement"):
212233
proxy = proxy or self.meas_proxy
213234
if not proxy:
214235
return
215236

216-
# x_axis_col = self.x_axis_selector.currentText()
217237
x_axis_col = "time"
218-
y_axis_col = "measurement" if proxy == self.meas_proxy else "simulation"
219238
observable_col = "observableId"
220239

221240
def column_index(name):
@@ -260,6 +279,9 @@ def __init__(self):
260279
self.point_index_map = {} # (subplot index, observableId, x, y) → row index
261280
self.click_callback = None
262281

282+
def clear_highlight(self):
283+
self.highlight_scatters = defaultdict(list)
284+
263285
def register_subplot(self, ax, subplot_idx):
264286
scatter = ax.scatter(
265287
[], [], s=80, edgecolors='black', facecolors='none', zorder=5
@@ -293,24 +315,76 @@ def _on_pick(self, event):
293315
ax = artist.axes
294316

295317
# Try to recover the label from the legend (handle → label mapping)
296-
label = ax.get_legend().texts[1].get_text().split()[-1]
318+
handles, labels = ax.get_legend_handles_labels()
319+
label = None
320+
for h, l in zip(handles, labels, strict=False):
321+
if h is artist:
322+
label_parts = l.split()
323+
if label_parts[-1] == "simulation":
324+
data_type = "simulation"
325+
label = label_parts[-2]
326+
else:
327+
data_type = "measurement"
328+
label = label_parts[-1]
329+
break
297330

298331
for i in ind:
299332
x = xdata[i]
300333
y = ydata[i]
301-
self.click_callback(x, y, label)
334+
self.click_callback(x, y, label, data_type)
335+
336+
337+
class ToolbarOptionManager(QObject):
338+
"""A Manager, synchronizing the selected option across all toolbars."""
339+
340+
option_changed = Signal(str)
341+
_instance = None
342+
_initialized = False
343+
344+
def __new__(cls):
345+
if cls._instance is None:
346+
cls._instance = super(ToolbarOptionManager, cls).__new__(cls)
347+
return cls._instance
348+
349+
def __init__(self):
350+
# Ensure QObject.__init__ runs only once
351+
if not self._initialized:
352+
super().__init__()
353+
self._selected_option = "observable"
354+
ToolbarOptionManager._initialized = True
355+
356+
def set_option(self, option):
357+
if option != self._selected_option:
358+
self._selected_option = option
359+
self.option_changed.emit(option)
360+
361+
def get_option(self):
362+
return self._selected_option
302363

303364

304365
class CustomNavigationToolbar(NavigationToolbar2QT):
305366
def __init__(self, canvas, parent):
306367
super().__init__(canvas, parent)
368+
self.manager = ToolbarOptionManager()
307369

308370
self.settings_btn = QToolButton(self)
309371
self.settings_btn.setIcon(qta.icon("mdi6.cog-outline"))
310372
self.settings_btn.setPopupMode(QToolButton.InstantPopup)
311373
self.settings_menu = QMenu(self.settings_btn)
312-
self.settings_menu.addAction("Option 1")
313-
self.settings_menu.addAction("Option 2")
374+
self.groupy_by_options = {
375+
grp: QAction(f"Groupy by {grp}", self)
376+
for grp in ["observable", "dataset", "condition"]
377+
}
378+
for grp, action in self.groupy_by_options.items():
379+
action.setCheckable(True)
380+
action.triggered.connect(lambda _, grp=grp: self.manager.set_option(grp))
381+
self.settings_menu.addAction(action)
382+
self.manager.option_changed.connect(self.update_checked_state)
383+
self.update_checked_state(self.manager.get_option())
314384
self.settings_btn.setMenu(self.settings_menu)
315385

316386
self.addWidget(self.settings_btn)
387+
388+
def update_checked_state(self, selected_option):
389+
for action in self.groupy_by_options.values():
390+
action.setChecked(action.text() == f"Groupy by {selected_option}")

0 commit comments

Comments
 (0)