3737GREY_TEXT = "#5A6270"
3838WHITE = "#FFFFFF"
3939
40- PLOTLY_TEMPLATE = "plotly_white "
41- CHART_COLORS = [STEEL , TEAL , RED , NAVY , "#7B8794" , "#C4A35A" , "#6C4F9C" , "#2CA58D" ]
40+ PLOTLY_TEMPLATE = "plotly_dark "
41+ CHART_COLORS = [STEEL , TEAL , RED , "#B0BEC5" , "#7B8794" , "#C4A35A" , "#6C4F9C" , "#2CA58D" ]
4242
43- # Minimal CSS -- no gradients, no glassmorphism
43+ # Common layout kwargs for transparent plotly charts on dark Streamlit
44+ DARK_LAYOUT = dict (
45+ paper_bgcolor = "rgba(0,0,0,0)" ,
46+ plot_bgcolor = "rgba(0,0,0,0)" ,
47+ )
48+
49+ # Minimal CSS -- let Streamlit dark theme handle colors
4450st .markdown (
45- f """
51+ """
4652 <style>
47- .stApp {{ background-color: { WHITE } ; }}
48- section[data-testid="stSidebar"] {{ background-color: { GREY_BG } ; }}
49- h1, h2, h3 {{ color: { NAVY } ; font-weight: 500; }}
50- .stMetric label {{ color: { GREY_TEXT } ; }}
51- div[data-testid="stMetricValue"] {{ color: { NAVY } ; }}
52- .stTabs [data-baseweb="tab"] {{ color: { NAVY } ; }}
53+ h1, h2, h3 { font-weight: 500; }
5354 </style>
5455 """ ,
5556 unsafe_allow_html = True ,
@@ -186,11 +187,11 @@ def load_ablation_results():
186187 fig_tl .add_trace (go .Scatter (
187188 x = xs , y = [0 ] * len (xs ),
188189 mode = "lines+markers+text" ,
189- marker = dict (size = 14 , color = NAVY , symbol = "circle" ),
190- line = dict (color = NAVY , width = 2 ),
190+ marker = dict (size = 14 , color = STEEL , symbol = "circle" ),
191+ line = dict (color = STEEL , width = 2 ),
191192 text = labels ,
192193 textposition = "top center" ,
193- textfont = dict (size = 11 , color = NAVY ),
194+ textfont = dict (size = 11 , color = "#E0E4E8" ),
194195 hovertext = descriptions ,
195196 hoverinfo = "text" ,
196197 showlegend = False ,
@@ -199,7 +200,7 @@ def load_ablation_results():
199200 for i , desc in enumerate (descriptions ):
200201 fig_tl .add_annotation (
201202 x = i , y = - 0.15 , text = desc ,
202- showarrow = False , font = dict (size = 10 , color = GREY_TEXT ),
203+ showarrow = False , font = dict (size = 10 , color = "#A0AAB4" ),
203204 xanchor = "center" ,
204205 )
205206
@@ -209,6 +210,7 @@ def load_ablation_results():
209210 margin = dict (l = 20 , r = 20 , t = 10 , b = 60 ),
210211 xaxis = dict (showgrid = False , showticklabels = False , zeroline = False ),
211212 yaxis = dict (showgrid = False , showticklabels = False , zeroline = False , range = [- 0.4 , 0.3 ]),
213+ ** DARK_LAYOUT ,
212214 )
213215 st .plotly_chart (fig_tl , use_container_width = True )
214216
@@ -219,20 +221,27 @@ def load_ablation_results():
219221
220222 fig_arch = go .Figure ()
221223
224+ # Bright colors that work on dark backgrounds
225+ ARCH_BLUE = "#5BA3E6"
226+ ARCH_TEAL = "#3DBFA0"
227+ ARCH_RED = "#E07050"
228+ ARCH_LIGHT = "#B0BEC5"
229+ ARCH_WHITE = "#E0E4E8"
230+
222231 boxes = [
223- (0.5 , 3 , "Forecast input\n (per-node values)" , GREY_TEXT ),
224- (0.5 , 2 , "Physics constraint layer\n (voltage, capacity, ramp rate)" , STEEL ),
225- (2.5 , 2 , "Early exit?\n physics > 0.9" , RED ),
226- (0.5 , 1 , "GNN verifier\n (GATv2Conv, 3 layers, 4 heads)" , TEAL ),
227- (0.5 , 0 , "Cascade logic layer\n (2-hop neighbor propagation)" , NAVY ),
228- (3.5 , 0.5 , "Ensemble score\n w_p=0.4, w_g=0.4, w_c=0.2" , NAVY ),
232+ (0.5 , 3 , "Forecast input\n (per-node values)" , ARCH_LIGHT ),
233+ (0.5 , 2 , "Physics constraint layer\n (voltage, capacity, ramp rate)" , ARCH_BLUE ),
234+ (2.5 , 2 , "Early exit?\n physics > 0.9" , ARCH_RED ),
235+ (0.5 , 1 , "GNN verifier\n (GATv2Conv, 3 layers, 4 heads)" , ARCH_TEAL ),
236+ (0.5 , 0 , "Cascade logic layer\n (2-hop neighbor propagation)" , ARCH_WHITE ),
237+ (3.5 , 0.5 , "Ensemble score\n w_p=0.4, w_g=0.4, w_c=0.2" , ARCH_WHITE ),
229238 ]
230239
231240 for x , y , text , color in boxes :
232241 fig_arch .add_shape (
233242 type = "rect" ,
234243 x0 = x - 0.9 , y0 = y - 0.35 , x1 = x + 0.9 , y1 = y + 0.35 ,
235- fillcolor = color , opacity = 0.12 ,
244+ fillcolor = color , opacity = 0.15 ,
236245 line = dict (color = color , width = 1.5 ),
237246 )
238247 fig_arch .add_annotation (
@@ -255,15 +264,17 @@ def load_ablation_results():
255264 x = x1 , y = y1 , ax = x0 , ay = y0 ,
256265 xref = "x" , yref = "y" , axref = "x" , ayref = "y" ,
257266 showarrow = True ,
258- arrowhead = 2 , arrowsize = 1 , arrowwidth = 1.5 , arrowcolor = GREY_TEXT ,
267+ arrowhead = 2 , arrowsize = 1 , arrowwidth = 1.5 , arrowcolor = ARCH_LIGHT ,
259268 )
260269
261270 fig_arch .update_layout (
262- template = PLOTLY_TEMPLATE ,
271+ template = "plotly_dark" ,
263272 height = 350 ,
264273 margin = dict (l = 10 , r = 10 , t = 10 , b = 10 ),
265274 xaxis = dict (showgrid = False , showticklabels = False , zeroline = False , range = [- 0.8 , 5 ]),
266275 yaxis = dict (showgrid = False , showticklabels = False , zeroline = False , range = [- 0.6 , 3.6 ]),
276+ paper_bgcolor = "rgba(0,0,0,0)" ,
277+ plot_bgcolor = "rgba(0,0,0,0)" ,
267278 )
268279 st .plotly_chart (fig_arch , use_container_width = True )
269280
@@ -351,6 +362,7 @@ def load_ablation_results():
351362 xaxis_title = "Interval (30 min)" ,
352363 yaxis_title = "Value (kW)" ,
353364 legend = dict (orientation = "h" , yanchor = "bottom" , y = 1.02 ),
365+ ** DARK_LAYOUT ,
354366 )
355367 st .plotly_chart (fig_fc , use_container_width = True )
356368
@@ -381,6 +393,7 @@ def load_ablation_results():
381393 margin = dict (l = 80 , r = 40 , t = 10 , b = 30 ),
382394 xaxis = dict (range = [0 , 1.15 ], title = "Score" ),
383395 yaxis = dict (autorange = "reversed" ),
396+ ** DARK_LAYOUT ,
384397 )
385398 st .plotly_chart (fig_bars , use_container_width = True )
386399
@@ -434,9 +447,10 @@ def load_ablation_results():
434447 st .error (f"Failed to load graph: { e } " )
435448 st .stop ()
436449
437- num_nodes = graph_data .num_nodes
438- num_edges = graph_data .edge_index .shape [1 ]
439450 node_types = graph_data .node_type .numpy ()
451+ # Use node_type length as authoritative count (may differ from num_nodes)
452+ num_nodes = len (node_types )
453+ num_edges = graph_data .edge_index .shape [1 ]
440454
441455 # Stats
442456 s1 , s2 , s3 , s4 = st .columns (4 )
@@ -448,7 +462,6 @@ def load_ablation_results():
448462 st .divider ()
449463
450464 # Build layout with hierarchy
451- # Use a simple spring layout based on adjacency
452465 edge_index = graph_data .edge_index .numpy ()
453466
454467 # Assign positions by type for clear hierarchy
@@ -470,6 +483,10 @@ def load_ablation_results():
470483 st .session_state .anomaly_scores = np .zeros (num_nodes )
471484
472485 anomaly_scores = st .session_state .anomaly_scores
486+ # Ensure anomaly_scores matches current node count
487+ if len (anomaly_scores ) != num_nodes :
488+ anomaly_scores = np .zeros (num_nodes )
489+ st .session_state .anomaly_scores = anomaly_scores
473490
474491 # Cascade injection button
475492 col_btn , col_info = st .columns ([1 , 3 ])
@@ -507,17 +524,18 @@ def load_ablation_results():
507524 # Build network figure
508525 fig_net = go .Figure ()
509526
510- # Draw edges
527+ # Draw edges (skip any that reference out-of-bounds nodes)
511528 edge_x , edge_y = [], []
512529 for i in range (edge_index .shape [1 ]):
513530 src , dst = edge_index [0 , i ], edge_index [1 , i ]
514- edge_x .extend ([pos_x [src ], pos_x [dst ], None ])
515- edge_y .extend ([pos_y [src ], pos_y [dst ], None ])
531+ if src < num_nodes and dst < num_nodes :
532+ edge_x .extend ([pos_x [src ], pos_x [dst ], None ])
533+ edge_y .extend ([pos_y [src ], pos_y [dst ], None ])
516534
517535 fig_net .add_trace (go .Scatter (
518536 x = edge_x , y = edge_y ,
519537 mode = "lines" ,
520- line = dict (width = 0.8 , color = "#C0C8D0 " ),
538+ line = dict (width = 0.8 , color = "#4A5568 " ),
521539 hoverinfo = "none" ,
522540 showlegend = False ,
523541 ))
@@ -538,7 +556,7 @@ def load_ablation_results():
538556 b = int (29 * (1 - s ))
539557 node_colors .append (f"rgb({ r } ,{ g } ,{ b } )" )
540558 else :
541- node_colors .append ({0 : NAVY , 1 : STEEL , 2 : TEAL }[ntype ])
559+ node_colors .append ({0 : "#B0BEC5" , 1 : STEEL , 2 : TEAL }[ntype ])
542560
543561 hover_texts = [
544562 f"Node { idx } ({ type_names [ntype ]} )\n Anomaly: { anomaly_scores [idx ]:.2f} "
@@ -566,6 +584,7 @@ def load_ablation_results():
566584 xaxis = dict (showgrid = False , showticklabels = False , zeroline = False ),
567585 yaxis = dict (showgrid = False , showticklabels = False , zeroline = False ),
568586 legend = dict (orientation = "h" , yanchor = "bottom" , y = 1.02 , xanchor = "center" , x = 0.5 ),
587+ ** DARK_LAYOUT ,
569588 )
570589 st .plotly_chart (fig_net , use_container_width = True )
571590
@@ -611,6 +630,7 @@ def load_ablation_results():
611630 margin = dict (l = 140 , r = 60 , t = 10 , b = 40 ),
612631 xaxis = dict (range = [0 , 1.12 ], title = "ROC-AUC" ),
613632 yaxis = dict (autorange = "reversed" ),
633+ ** DARK_LAYOUT ,
614634 )
615635 st .plotly_chart (fig_roc , use_container_width = True )
616636
@@ -699,6 +719,7 @@ def load_ablation_results():
699719 overlaying = "y" , range = [0 , max (sweep_latency ) * 1.3 ],
700720 ),
701721 legend = dict (orientation = "h" , yanchor = "bottom" , y = 1.02 ),
722+ ** DARK_LAYOUT ,
702723 )
703724 st .plotly_chart (fig_sweep , use_container_width = True )
704725
@@ -854,6 +875,7 @@ def load_ablation_results():
854875 margin = dict (l = 50 , r = 20 , t = 40 , b = 40 ),
855876 xaxis_title = "Episode" ,
856877 yaxis_title = "Reward" ,
878+ ** DARK_LAYOUT ,
857879 )
858880 st .plotly_chart (fig_rew , use_container_width = True )
859881
@@ -874,6 +896,7 @@ def load_ablation_results():
874896 margin = dict (l = 50 , r = 20 , t = 40 , b = 40 ),
875897 xaxis_title = "Episode" ,
876898 yaxis_title = "Loss" ,
899+ ** DARK_LAYOUT ,
877900 )
878901 st .plotly_chart (fig_loss , use_container_width = True )
879902
@@ -895,6 +918,7 @@ def load_ablation_results():
895918 margin = dict (l = 50 , r = 20 , t = 40 , b = 40 ),
896919 xaxis_title = "Scenario type" ,
897920 yaxis_title = "Count" ,
921+ ** DARK_LAYOUT ,
898922 )
899923 st .plotly_chart (fig_dist , use_container_width = True )
900924
0 commit comments