11import logging
22import os
33import webbrowser
4+ from collections .abc import Callable
45from threading import Timer
56
67import numpy as np
78import torch
89from dash import Dash , Input , Output , callback , dcc , html
910from plotly .graph_objs import Figure
11+ from typing_extensions import Unpack
1012
1113from plots ._utils import Plotter , angle_to_coord , coord_to_angle
1214from torchjd .aggregation import (
1315 IMTLG ,
1416 MGDA ,
17+ Aggregator ,
1518 AlignedMTL ,
1619 CAGrad ,
1720 ConFIG ,
3134MAX_LENGTH = 25.0
3235
3336
37+ def _format_angle_display (angle : float ) -> str :
38+ return f"{ np .degrees (angle ):.1f} °"
39+
40+
41+ def _format_length_display (r : float ) -> str :
42+ return f"{ r :.2f} "
43+
44+
3445def main () -> None :
3546 log = logging .getLogger ("werkzeug" )
3647 log .setLevel (logging .CRITICAL )
@@ -43,27 +54,30 @@ def main() -> None:
4354 ],
4455 )
4556
46- aggregators = [
47- AlignedMTL (),
48- CAGrad (c = 0.5 ),
49- ConFIG (),
50- DualProj (),
51- GradDrop (),
52- GradVac (),
53- IMTLG (),
54- Mean (),
55- MGDA (),
56- NashMTL (n_tasks = matrix .shape [0 ]),
57- PCGrad (),
58- Random (),
59- Sum (),
60- TrimmedMean (trim_number = 1 ),
61- UPGrad (),
62- ]
63-
64- aggregators_dict = {str (aggregator ): aggregator for aggregator in aggregators }
65-
66- plotter = Plotter ([], matrix )
57+ n_tasks = matrix .shape [0 ]
58+ aggregator_factories : dict [str , Callable [[], Aggregator ]] = {
59+ "AlignedMTL-min" : lambda : AlignedMTL (scale_mode = "min" ),
60+ "AlignedMTL-median" : lambda : AlignedMTL (scale_mode = "median" ),
61+ "AlignedMTL-RMSE" : lambda : AlignedMTL (scale_mode = "rmse" ),
62+ str (CAGrad (c = 0.5 )): lambda : CAGrad (c = 0.5 ),
63+ str (ConFIG ()): lambda : ConFIG (),
64+ str (DualProj ()): lambda : DualProj (),
65+ str (GradDrop ()): lambda : GradDrop (),
66+ str (GradVac ()): lambda : GradVac (),
67+ str (IMTLG ()): lambda : IMTLG (),
68+ str (Mean ()): lambda : Mean (),
69+ str (MGDA ()): lambda : MGDA (),
70+ str (NashMTL (n_tasks = n_tasks )): lambda : NashMTL (n_tasks = n_tasks ),
71+ str (PCGrad ()): lambda : PCGrad (),
72+ str (Random ()): lambda : Random (),
73+ str (Sum ()): lambda : Sum (),
74+ str (TrimmedMean (trim_number = 1 )): lambda : TrimmedMean (trim_number = 1 ),
75+ str (UPGrad ()): lambda : UPGrad (),
76+ }
77+
78+ aggregator_strings = list (aggregator_factories .keys ())
79+
80+ plotter = Plotter (aggregator_factories , [], matrix )
6781
6882 app = Dash (__name__ )
6983
@@ -98,7 +112,6 @@ def main() -> None:
98112 gradient_slider_inputs .append (Input (angle_input , "value" ))
99113 gradient_slider_inputs .append (Input (r_input , "value" ))
100114
101- aggregator_strings = [str (aggregator ) for aggregator in aggregators ]
102115 checklist = dcc .Checklist (aggregator_strings , [], id = "aggregator-checklist" )
103116
104117 control_div = html .Div (
@@ -117,32 +130,40 @@ def update_seed(value: int) -> Figure:
117130 plotter .seed = value
118131 return plotter .make_fig ()
119132
133+ n_gradients = len (matrix )
134+ gradient_value_outputs : list [Output ] = []
135+ for i in range (n_gradients ):
136+ gradient_value_outputs .append (Output (f"g{ i + 1 } -angle-value" , "children" ))
137+ gradient_value_outputs .append (Output (f"g{ i + 1 } -length-value" , "children" ))
138+
120139 @callback (
121140 Output ("aggregations-fig" , "figure" , allow_duplicate = True ),
141+ * gradient_value_outputs ,
122142 * gradient_slider_inputs ,
123143 prevent_initial_call = True ,
124144 )
125- def update_gradient_coordinate (* values : str ) -> Figure :
145+ def update_gradient_coordinate (* values : str ) -> tuple [ Figure , Unpack [ tuple [ str , ...]]] :
126146 values_ = [float (value ) for value in values ]
127147
148+ display_parts : list [str ] = []
128149 for j in range (len (values_ ) // 2 ):
129150 angle = values_ [2 * j ]
130151 r = values_ [2 * j + 1 ]
131152 x , y = angle_to_coord (angle , r )
132153 plotter .matrix [j , 0 ] = x
133154 plotter .matrix [j , 1 ] = y
155+ display_parts .append (_format_angle_display (angle ))
156+ display_parts .append (_format_length_display (r ))
134157
135- return plotter .make_fig ()
158+ return ( plotter .make_fig (), * display_parts )
136159
137160 @callback (
138161 Output ("aggregations-fig" , "figure" , allow_duplicate = True ),
139162 Input ("aggregator-checklist" , "value" ),
140163 prevent_initial_call = True ,
141164 )
142165 def update_aggregators (value : list [str ]) -> Figure :
143- aggregator_keys = value
144- new_aggregators = [aggregators_dict [key ] for key in aggregator_keys ]
145- plotter .aggregators = new_aggregators
166+ plotter .selected_keys = list (value )
146167 return plotter .make_fig ()
147168
148169 Timer (1 , open_browser ).start ()
@@ -175,11 +196,56 @@ def make_gradient_div(
175196 style = {"width" : "250px" },
176197 )
177198
199+ label_style : dict [str , str | int ] = {
200+ "display" : "inline-block" ,
201+ "width" : "52px" ,
202+ "margin-right" : "8px" ,
203+ "vertical-align" : "middle" ,
204+ }
205+ value_style : dict [str , str ] = {
206+ "display" : "inline-block" ,
207+ "margin-left" : "10px" ,
208+ "min-width" : "140px" ,
209+ "font-family" : "monospace" ,
210+ "font-size" : "13px" ,
211+ "vertical-align" : "middle" ,
212+ }
213+ row_style : dict [str , str ] = {"display" : "block" , "margin-bottom" : "6px" }
178214 div = html .Div (
179215 [
180- html .P (f"g{ i + 1 } " , style = {"display" : "inline-block" , "margin-right" : 20 }),
181- angle_input ,
182- r_input ,
216+ dcc .Markdown (
217+ f"$g_{{{ i + 1 } }}$" ,
218+ mathjax = True ,
219+ style = {
220+ "margin" : "0 0 6px 0" ,
221+ "font-weight" : "bold" ,
222+ "display" : "block" ,
223+ },
224+ ),
225+ html .Div (
226+ [
227+ html .Span ("Angle" , style = label_style ),
228+ angle_input ,
229+ html .Span (
230+ id = f"g{ i + 1 } -angle-value" ,
231+ children = _format_angle_display (angle ),
232+ style = value_style ,
233+ ),
234+ ],
235+ style = row_style ,
236+ ),
237+ html .Div (
238+ [
239+ html .Span ("Length" , style = label_style ),
240+ r_input ,
241+ html .Span (
242+ id = f"g{ i + 1 } -length-value" ,
243+ children = _format_length_display (r ),
244+ style = value_style ,
245+ ),
246+ ],
247+ style = {** row_style , "margin-bottom" : "12px" },
248+ ),
183249 ],
184250 )
185251 return div , angle_input , r_input
0 commit comments