Skip to content

Commit 86fe403

Browse files
test(aggregation): Improve interactive plotter (#641)
* Add factories for aggregators * Improve display of length and angles of gradients * Add different versions of AlignedMTL to be selected
1 parent 012b1ba commit 86fe403

File tree

2 files changed

+116
-35
lines changed

2 files changed

+116
-35
lines changed

tests/plots/_utils.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections.abc import Callable
2+
13
import numpy as np
24
import torch
35
from plotly import graph_objects as go
@@ -7,14 +9,22 @@
79

810

911
class Plotter:
10-
def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None:
11-
self.aggregators = aggregators
12+
def __init__(
13+
self,
14+
aggregator_factories: dict[str, Callable[[], Aggregator]],
15+
selected_keys: list[str],
16+
matrix: torch.Tensor,
17+
seed: int = 0,
18+
) -> None:
19+
self._aggregator_factories = aggregator_factories
20+
self.selected_keys = selected_keys
1221
self.matrix = matrix
1322
self.seed = seed
1423

1524
def make_fig(self) -> Figure:
1625
torch.random.manual_seed(self.seed)
17-
results = [agg(self.matrix) for agg in self.aggregators]
26+
aggregators = [self._aggregator_factories[key]() for key in self.selected_keys]
27+
results = [agg(self.matrix) for agg in aggregators]
1828

1929
fig = go.Figure()
2030

@@ -23,14 +33,19 @@ def make_fig(self) -> Figure:
2333
fig.add_trace(cone)
2434

2535
for i in range(len(self.matrix)):
26-
scatter = make_vector_scatter(self.matrix[i], "blue", f"g{i + 1}")
36+
scatter = make_vector_scatter(
37+
self.matrix[i],
38+
"blue",
39+
f"g{i + 1}",
40+
textposition="top right",
41+
)
2742
fig.add_trace(scatter)
2843

2944
for i in range(len(results)):
3045
scatter = make_vector_scatter(
3146
results[i],
3247
"black",
33-
str(self.aggregators[i]),
48+
self.selected_keys[i],
3449
showlegend=True,
3550
dash=True,
3651
)

tests/plots/interactive_plotter.py

Lines changed: 96 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import logging
22
import os
33
import webbrowser
4+
from collections.abc import Callable
45
from threading import Timer
56

67
import numpy as np
78
import torch
89
from dash import Dash, Input, Output, callback, dcc, html
910
from plotly.graph_objs import Figure
11+
from typing_extensions import Unpack
1012

1113
from plots._utils import Plotter, angle_to_coord, coord_to_angle
1214
from torchjd.aggregation import (
1315
IMTLG,
1416
MGDA,
17+
Aggregator,
1518
AlignedMTL,
1619
CAGrad,
1720
ConFIG,
@@ -31,6 +34,14 @@
3134
MAX_LENGTH = 25.0
3235

3336

37+
def _format_angle_display(angle: float) -> str:
38+
return f"{np.degrees(angle):.1f}°"
39+
40+
41+
def _format_length_display(r: float) -> str:
42+
return f"{r:.2f}"
43+
44+
3445
def main() -> None:
3546
log = logging.getLogger("werkzeug")
3647
log.setLevel(logging.CRITICAL)
@@ -43,27 +54,30 @@ def main() -> None:
4354
],
4455
)
4556

46-
aggregators = [
47-
AlignedMTL(),
48-
CAGrad(c=0.5),
49-
ConFIG(),
50-
DualProj(),
51-
GradDrop(),
52-
GradVac(),
53-
IMTLG(),
54-
Mean(),
55-
MGDA(),
56-
NashMTL(n_tasks=matrix.shape[0]),
57-
PCGrad(),
58-
Random(),
59-
Sum(),
60-
TrimmedMean(trim_number=1),
61-
UPGrad(),
62-
]
63-
64-
aggregators_dict = {str(aggregator): aggregator for aggregator in aggregators}
65-
66-
plotter = Plotter([], matrix)
57+
n_tasks = matrix.shape[0]
58+
aggregator_factories: dict[str, Callable[[], Aggregator]] = {
59+
"AlignedMTL-min": lambda: AlignedMTL(scale_mode="min"),
60+
"AlignedMTL-median": lambda: AlignedMTL(scale_mode="median"),
61+
"AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"),
62+
str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5),
63+
str(ConFIG()): lambda: ConFIG(),
64+
str(DualProj()): lambda: DualProj(),
65+
str(GradDrop()): lambda: GradDrop(),
66+
str(GradVac()): lambda: GradVac(),
67+
str(IMTLG()): lambda: IMTLG(),
68+
str(Mean()): lambda: Mean(),
69+
str(MGDA()): lambda: MGDA(),
70+
str(NashMTL(n_tasks=n_tasks)): lambda: NashMTL(n_tasks=n_tasks),
71+
str(PCGrad()): lambda: PCGrad(),
72+
str(Random()): lambda: Random(),
73+
str(Sum()): lambda: Sum(),
74+
str(TrimmedMean(trim_number=1)): lambda: TrimmedMean(trim_number=1),
75+
str(UPGrad()): lambda: UPGrad(),
76+
}
77+
78+
aggregator_strings = list(aggregator_factories.keys())
79+
80+
plotter = Plotter(aggregator_factories, [], matrix)
6781

6882
app = Dash(__name__)
6983

@@ -98,7 +112,6 @@ def main() -> None:
98112
gradient_slider_inputs.append(Input(angle_input, "value"))
99113
gradient_slider_inputs.append(Input(r_input, "value"))
100114

101-
aggregator_strings = [str(aggregator) for aggregator in aggregators]
102115
checklist = dcc.Checklist(aggregator_strings, [], id="aggregator-checklist")
103116

104117
control_div = html.Div(
@@ -117,32 +130,40 @@ def update_seed(value: int) -> Figure:
117130
plotter.seed = value
118131
return plotter.make_fig()
119132

133+
n_gradients = len(matrix)
134+
gradient_value_outputs: list[Output] = []
135+
for i in range(n_gradients):
136+
gradient_value_outputs.append(Output(f"g{i + 1}-angle-value", "children"))
137+
gradient_value_outputs.append(Output(f"g{i + 1}-length-value", "children"))
138+
120139
@callback(
121140
Output("aggregations-fig", "figure", allow_duplicate=True),
141+
*gradient_value_outputs,
122142
*gradient_slider_inputs,
123143
prevent_initial_call=True,
124144
)
125-
def update_gradient_coordinate(*values: str) -> Figure:
145+
def update_gradient_coordinate(*values: str) -> tuple[Figure, Unpack[tuple[str, ...]]]:
126146
values_ = [float(value) for value in values]
127147

148+
display_parts: list[str] = []
128149
for j in range(len(values_) // 2):
129150
angle = values_[2 * j]
130151
r = values_[2 * j + 1]
131152
x, y = angle_to_coord(angle, r)
132153
plotter.matrix[j, 0] = x
133154
plotter.matrix[j, 1] = y
155+
display_parts.append(_format_angle_display(angle))
156+
display_parts.append(_format_length_display(r))
134157

135-
return plotter.make_fig()
158+
return (plotter.make_fig(), *display_parts)
136159

137160
@callback(
138161
Output("aggregations-fig", "figure", allow_duplicate=True),
139162
Input("aggregator-checklist", "value"),
140163
prevent_initial_call=True,
141164
)
142165
def update_aggregators(value: list[str]) -> Figure:
143-
aggregator_keys = value
144-
new_aggregators = [aggregators_dict[key] for key in aggregator_keys]
145-
plotter.aggregators = new_aggregators
166+
plotter.selected_keys = list(value)
146167
return plotter.make_fig()
147168

148169
Timer(1, open_browser).start()
@@ -175,11 +196,56 @@ def make_gradient_div(
175196
style={"width": "250px"},
176197
)
177198

199+
label_style: dict[str, str | int] = {
200+
"display": "inline-block",
201+
"width": "52px",
202+
"margin-right": "8px",
203+
"vertical-align": "middle",
204+
}
205+
value_style: dict[str, str] = {
206+
"display": "inline-block",
207+
"margin-left": "10px",
208+
"min-width": "140px",
209+
"font-family": "monospace",
210+
"font-size": "13px",
211+
"vertical-align": "middle",
212+
}
213+
row_style: dict[str, str] = {"display": "block", "margin-bottom": "6px"}
178214
div = html.Div(
179215
[
180-
html.P(f"g{i + 1}", style={"display": "inline-block", "margin-right": 20}),
181-
angle_input,
182-
r_input,
216+
dcc.Markdown(
217+
f"$g_{{{i + 1}}}$",
218+
mathjax=True,
219+
style={
220+
"margin": "0 0 6px 0",
221+
"font-weight": "bold",
222+
"display": "block",
223+
},
224+
),
225+
html.Div(
226+
[
227+
html.Span("Angle", style=label_style),
228+
angle_input,
229+
html.Span(
230+
id=f"g{i + 1}-angle-value",
231+
children=_format_angle_display(angle),
232+
style=value_style,
233+
),
234+
],
235+
style=row_style,
236+
),
237+
html.Div(
238+
[
239+
html.Span("Length", style=label_style),
240+
r_input,
241+
html.Span(
242+
id=f"g{i + 1}-length-value",
243+
children=_format_length_display(r),
244+
style=value_style,
245+
),
246+
],
247+
style={**row_style, "margin-bottom": "12px"},
248+
),
183249
],
184250
)
185251
return div, angle_input, r_input

0 commit comments

Comments
 (0)