@@ -251,58 +251,235 @@ def load_ablation_results():
251251
252252 fig_arch = go .Figure ()
253253
254- # Bright colors that work on dark backgrounds
255- ARCH_BLUE = "#5BA3E6"
256- ARCH_TEAL = "#3DBFA0"
257- ARCH_RED = "#E07050"
258- ARCH_LIGHT = "#B0BEC5"
259- ARCH_WHITE = "#E0E4E8"
260-
261- boxes = [
262- ( 0.5 , 3 , "Forecast input \n (per-node values)" , ARCH_LIGHT ),
263- ( 0.5 , 2 , "Physics constraint layer \n (voltage, capacity, ramp rate)" , ARCH_BLUE ),
264- ( 2.5 , 2 , "Early exit? \n physics > 0.9" , ARCH_RED ),
265- ( 0.5 , 1 , "GNN verifier \n (GATv2Conv, 3 layers, 4 heads)" , ARCH_TEAL ),
266- ( 0.5 , 0 , "Cascade logic layer \n (2-hop neighbor propagation)" , ARCH_WHITE ),
267- ( 3.5 , 0.5 , "Ensemble score \n w_p=0.4, w_g=0.4, w_c=0.2" , ARCH_WHITE ),
268- ]
269-
270- for x , y , text , color in boxes :
254+ # ── Palette ──
255+ A_BLUE = "#5BA3E6" # Physics layer
256+ A_TEAL = "#3DBFA0" # GNN layer
257+ A_AMBER = "#D4A843" # Cascade layer
258+ A_RED = "#E07050" # Early-exit / decision
259+ A_GREY = "#8899AA" # Arrows, secondary text
260+ A_LIGHT = "#C8D0D8" # Input/output text
261+ A_WHITE = "#E8ECF0" # Primary text
262+ A_DIM = "#5A6674" # Faint guides
263+ A_ENSEMBLE = "#9B7FD4" # Ensemble
264+
265+ # ── Coordinate system ──
266+ # X: 0..16 Y: 0..22 (top=22)
267+ XR = [ - 0.5 , 16.5 ]
268+ YR = [ - 0.8 , 22.5 ]
269+
270+ def _rect ( x0 , y0 , x1 , y1 , color , opacity = 0.10 , width = 1.5 , dash = None ) :
271271 fig_arch .add_shape (
272- type = "rect" ,
273- x0 = x - 0.9 , y0 = y - 0.35 , x1 = x + 0.9 , y1 = y + 0.35 ,
274- fillcolor = color , opacity = 0.15 ,
275- line = dict (color = color , width = 1.5 ),
272+ type = "rect" , x0 = x0 , y0 = y0 , x1 = x1 , y1 = y1 ,
273+ fillcolor = color , opacity = opacity ,
274+ line = dict (color = color , width = width , dash = dash ),
276275 )
276+
277+ def _label (x , y , text , color = A_WHITE , size = 11 , bold = False , anchor = "middle" ):
278+ prefix = "<b>" if bold else ""
279+ suffix = "</b>" if bold else ""
277280 fig_arch .add_annotation (
278- x = x , y = y , text = text ,
279- showarrow = False , font = dict (size = 10 , color = color ),
281+ x = x , y = y , text = f"{ prefix } { text } { suffix } " ,
282+ showarrow = False , font = dict (size = size , color = color ),
283+ xanchor = anchor if anchor != "middle" else "center" ,
280284 )
281285
282- # Arrows
283- arrows = [
284- (0.5 , 2.65 , 0.5 , 2.35 ), # input -> physics
285- (1.4 , 2 , 1.6 , 2 ), # physics -> early exit
286- (0.5 , 1.65 , 0.5 , 1.35 ), # physics -> GNN
287- (0.5 , 0.65 , 0.5 , 0.35 ), # GNN -> cascade
288- (1.4 , 0 , 2.6 , 0.35 ), # cascade -> ensemble
289- (1.4 , 1 , 2.6 , 0.65 ), # GNN -> ensemble
290- (3.4 , 2 , 3.5 , 0.85 ), # early exit -> ensemble
291- ]
292- for x0 , y0 , x1 , y1 in arrows :
286+ def _arrow (x0 , y0 , x1 , y1 , color = A_GREY , width = 1.5 , dash = None ):
293287 fig_arch .add_annotation (
294288 x = x1 , y = y1 , ax = x0 , ay = y0 ,
295289 xref = "x" , yref = "y" , axref = "x" , ayref = "y" ,
296- showarrow = True ,
297- arrowhead = 2 , arrowsize = 1 , arrowwidth = 1.5 , arrowcolor = ARCH_LIGHT ,
290+ showarrow = True , arrowhead = 2 , arrowsize = 1.2 ,
291+ arrowwidth = width , arrowcolor = color ,
298292 )
299293
294+ # ================================================================
295+ # INPUT ROW (y ~ 21)
296+ # ================================================================
297+ _rect (1.5 , 20.4 , 5.5 , 21.6 , A_LIGHT , opacity = 0.08 )
298+ _label (3.5 , 21.2 , "Forecast input" , A_LIGHT , size = 12 , bold = True )
299+ _label (3.5 , 20.7 , "f(n): per-node values, n = 44 nodes" , A_GREY , size = 9 )
300+
301+ _rect (10.5 , 20.4 , 14.5 , 21.6 , A_LIGHT , opacity = 0.08 )
302+ _label (12.5 , 21.2 , "SSEN graph topology" , A_LIGHT , size = 12 , bold = True )
303+ _label (12.5 , 20.7 , "G(V, E): 44 nodes, 60 edges, 3 types" , A_GREY , size = 9 )
304+
305+ # Arrows from inputs down
306+ _arrow (3.5 , 20.4 , 3.5 , 19.6 ) # forecast -> physics
307+ _arrow (12.5 , 20.4 , 12.5 , 19.0 ) # graph -> right side (long, to GNN)
308+
309+ # ================================================================
310+ # LAYER 1: PHYSICS CONSTRAINTS (y ~ 16-19.5)
311+ # ================================================================
312+ _rect (0.3 , 15.6 , 10.0 , 19.5 , A_BLUE , opacity = 0.06 , width = 2 )
313+ _label (0.8 , 19.1 , "Layer 1: Physics constraints" , A_BLUE , size = 13 , bold = True , anchor = "left" )
314+ _label (6.5 , 19.1 , "Tolerance band scoring" , A_GREY , size = 9 )
315+
316+ # Sub-components
317+ _rect (0.8 , 16.8 , 3.5 , 18.5 , A_BLUE , opacity = 0.12 )
318+ _label (2.15 , 18.1 , "Voltage" , A_BLUE , size = 11 , bold = True )
319+ _label (2.15 , 17.65 , "BS EN 50160" , A_GREY , size = 8 )
320+ _label (2.15 , 17.25 , "230V nominal" , A_WHITE , size = 9 )
321+ _label (2.15 , 16.9 , "Safe: -6% / +8%" , A_GREY , size = 8 )
322+
323+ _rect (3.9 , 16.8 , 6.6 , 18.5 , A_BLUE , opacity = 0.12 )
324+ _label (5.25 , 18.1 , "Capacity" , A_BLUE , size = 11 , bold = True )
325+ _label (5.25 , 17.65 , "BS 7671:2018" , A_GREY , size = 8 )
326+ _label (5.25 , 17.25 , "15 kW typical" , A_WHITE , size = 9 )
327+ _label (5.25 , 16.9 , "100 kW absolute max" , A_GREY , size = 8 )
328+
329+ _rect (7.0 , 16.8 , 9.6 , 18.5 , A_BLUE , opacity = 0.12 )
330+ _label (8.3 , 18.1 , "Ramp rate" , A_BLUE , size = 11 , bold = True )
331+ _label (8.3 , 17.65 , "Rate of change" , A_GREY , size = 8 )
332+ _label (8.3 , 17.25 , "3.5 kW/interval warn" , A_WHITE , size = 9 )
333+ _label (8.3 , 16.9 , "5.0 kW/interval max" , A_GREY , size = 8 )
334+
335+ # Output annotation
336+ _label (5.15 , 16.15 , "Output: severity scores per node [0, 1]" , A_BLUE , size = 9 )
337+ _label (5.15 , 15.75 , "Combined = max(voltage, capacity, ramp) per node" , A_DIM , size = 8 )
338+
339+ # ================================================================
340+ # EARLY-EXIT DECISION (y ~ 13.5-15.5)
341+ # ================================================================
342+ # Diamond-style decision box
343+ _rect (2.8 , 13.5 , 7.2 , 15.2 , A_RED , opacity = 0.10 , width = 2 , dash = "dot" )
344+ _label (5.0 , 14.7 , "Early-exit decision" , A_RED , size = 12 , bold = True )
345+ _label (5.0 , 14.2 , "severity > 0.9 ?" , A_WHITE , size = 11 )
346+ _label (5.0 , 13.7 , "Auto-detect: voltage scoring skipped if values < 103V" , A_DIM , size = 8 )
347+
348+ # Arrow from physics down to decision
349+ _arrow (5.0 , 15.6 , 5.0 , 15.2 )
350+
351+ # YES path -- skip GNN, go right to ensemble
352+ _label (8.4 , 14.7 , "YES" , A_RED , size = 10 , bold = True )
353+ _label (8.4 , 14.3 , "Skip GNN" , A_RED , size = 9 )
354+ _arrow (7.2 , 14.5 , 8.0 , 14.5 , color = A_RED , width = 2 )
355+
356+ # Arrow from YES to ensemble (right side, curves down)
357+ _rect (10.5 , 13.8 , 14.5 , 15.2 , A_RED , opacity = 0.06 , dash = "dot" )
358+ _label (12.5 , 14.7 , "Early-exit path" , A_RED , size = 10 , bold = True )
359+ _label (12.5 , 14.2 , "Use physics score only" , A_GREY , size = 9 )
360+ _label (12.5 , 13.9 , "Weights become (1.0, 0.0, 0.0)" , A_DIM , size = 8 )
361+ _arrow (8.8 , 14.5 , 10.5 , 14.5 , color = A_RED , width = 1.5 , dash = "dot" )
362+
363+ # NO path -- continue to GNN
364+ _label (5.0 , 13.1 , "NO: continue" , A_TEAL , size = 9 )
365+ _arrow (5.0 , 13.5 , 5.0 , 12.6 , color = A_TEAL , width = 2 )
366+
367+ # ================================================================
368+ # LAYER 2: GNN VERIFIER (y ~ 8.5-12.5)
369+ # ================================================================
370+ _rect (0.3 , 8.2 , 10.0 , 12.5 , A_TEAL , opacity = 0.06 , width = 2 )
371+ _label (0.8 , 12.1 , "Layer 2: GNN verifier" , A_TEAL , size = 13 , bold = True , anchor = "left" )
372+ _label (6.5 , 12.1 , "GATVerifier" , A_GREY , size = 9 )
373+
374+ # Graph topology input arrow from right
375+ _arrow (12.5 , 13.8 , 10.0 , 11.0 , color = A_GREY , width = 1 )
376+ _label (12.0 , 12.3 , "edge_index, node_type" , A_GREY , size = 8 )
377+
378+ # Sub-components (2x2 grid)
379+ _rect (0.8 , 10.0 , 4.8 , 11.7 , A_TEAL , opacity = 0.12 )
380+ _label (2.8 , 11.3 , "GATv2Conv attention" , A_TEAL , size = 11 , bold = True )
381+ _label (2.8 , 10.85 , "3 layers, 4 heads per layer" , A_WHITE , size = 9 )
382+ _label (2.8 , 10.5 , "Dynamic attention (not static GAT)" , A_GREY , size = 8 )
383+ _label (2.8 , 10.15 , "concat=True, 64 hidden channels" , A_GREY , size = 8 )
384+
385+ _rect (5.2 , 10.0 , 9.6 , 11.7 , A_TEAL , opacity = 0.12 )
386+ _label (7.4 , 11.3 , "Oversmoothing prevention" , A_TEAL , size = 11 , bold = True )
387+ _label (7.4 , 10.85 , "GCNII-style initial residual" , A_WHITE , size = 9 )
388+ _label (7.4 , 10.5 , "Learnable alpha per layer" , A_GREY , size = 8 )
389+ _label (7.4 , 10.15 , "Preserves node distinguishability" , A_GREY , size = 8 )
390+
391+ _rect (0.8 , 8.5 , 4.8 , 9.7 , A_TEAL , opacity = 0.12 )
392+ _label (2.8 , 9.3 , "Temporal encoder" , A_TEAL , size = 11 , bold = True )
393+ _label (2.8 , 8.9 , "1D-Conv, captures local patterns" , A_WHITE , size = 9 )
394+ _label (2.8 , 8.6 , "5 temporal features per node" , A_GREY , size = 8 )
395+
396+ _rect (5.2 , 8.5 , 9.6 , 9.7 , A_TEAL , opacity = 0.12 )
397+ _label (7.4 , 9.3 , "Output head" , A_TEAL , size = 11 , bold = True )
398+ _label (7.4 , 8.9 , "Sigmoid activation -> [0, 1]" , A_WHITE , size = 9 )
399+ _label (7.4 , 8.6 , "Per-node anomaly probability" , A_GREY , size = 8 )
400+
401+ # ================================================================
402+ # LAYER 3: CASCADE LOGIC (y ~ 4.5-7.8)
403+ # ================================================================
404+ _arrow (5.0 , 8.2 , 5.0 , 7.8 ) # GNN -> cascade
405+ _arrow (12.5 , 13.8 , 12.5 , 7.8 , color = A_GREY , width = 1 ) # graph -> cascade
406+ _label (13.0 , 10.5 , "Graph" , A_GREY , size = 8 )
407+ _label (13.0 , 10.1 , "topology" , A_GREY , size = 8 )
408+
409+ _rect (0.3 , 4.5 , 14.5 , 7.8 , A_AMBER , opacity = 0.06 , width = 2 )
410+ _label (0.8 , 7.4 , "Layer 3: Cascade logic" , A_AMBER , size = 13 , bold = True , anchor = "left" )
411+ _label (6.5 , 7.4 , "Neighbor propagation scoring" , A_GREY , size = 9 )
412+
413+ _rect (0.8 , 5.0 , 4.8 , 7.0 , A_AMBER , opacity = 0.12 )
414+ _label (2.8 , 6.6 , "BFS propagation" , A_AMBER , size = 11 , bold = True )
415+ _label (2.8 , 6.2 , "2-hop neighborhood traversal" , A_WHITE , size = 9 )
416+ _label (2.8 , 5.85 , "Decay: 0.7 per hop" , A_WHITE , size = 9 )
417+ _label (2.8 , 5.45 , "Hop 0: 1.0 Hop 1: 0.70 Hop 2: 0.49" , A_GREY , size = 8 )
418+ _label (2.8 , 5.1 , "Max 30% of nodes affected" , A_GREY , size = 8 )
419+
420+ _rect (5.2 , 5.0 , 9.6 , 7.0 , A_AMBER , opacity = 0.12 )
421+ _label (7.4 , 6.6 , "Anomaly aggregation" , A_AMBER , size = 11 , bold = True )
422+ _label (7.4 , 6.2 , "Score = f(neighbor anomalies)" , A_WHITE , size = 9 )
423+ _label (7.4 , 5.85 , "High score = neighbors also anomalous" , A_GREY , size = 8 )
424+ _label (7.4 , 5.45 , "Distinguishes isolated spikes" , A_GREY , size = 8 )
425+ _label (7.4 , 5.1 , "from cascading failures" , A_GREY , size = 8 )
426+
427+ _rect (10.0 , 5.0 , 14.2 , 7.0 , A_AMBER , opacity = 0.12 )
428+ _label (12.1 , 6.6 , "Adjacency construction" , A_AMBER , size = 11 , bold = True )
429+ _label (12.1 , 6.2 , "From edge_index (COO format)" , A_WHITE , size = 9 )
430+ _label (12.1 , 5.85 , "Primary -> Secondary -> LV" , A_GREY , size = 8 )
431+ _label (12.1 , 5.45 , "Bidirectional edges" , A_GREY , size = 8 )
432+
433+ # ================================================================
434+ # ENSEMBLE (y ~ 1.5-4)
435+ # ================================================================
436+ _arrow (5.0 , 4.5 , 5.0 , 4.0 ) # cascade -> ensemble
437+ _arrow (12.5 , 13.8 , 14.0 , 4.0 , color = A_RED , width = 1 , dash = "dot" ) # early-exit to ensemble
438+
439+ _rect (0.3 , 1.2 , 14.5 , 4.0 , A_ENSEMBLE , opacity = 0.06 , width = 2 )
440+ _label (0.8 , 3.6 , "Ensemble combination" , A_ENSEMBLE , size = 13 , bold = True , anchor = "left" )
441+
442+ # Weight boxes
443+ _rect (0.8 , 1.6 , 4.2 , 3.2 , A_BLUE , opacity = 0.10 )
444+ _label (2.5 , 2.85 , "Physics score" , A_BLUE , size = 10 , bold = True )
445+ _label (2.5 , 2.4 , "w_p = 0.4" , A_WHITE , size = 12 , bold = True )
446+ _label (2.5 , 1.95 , "Constraint violations" , A_GREY , size = 8 )
447+
448+ _rect (4.6 , 1.6 , 8.0 , 3.2 , A_TEAL , opacity = 0.10 )
449+ _label (6.3 , 2.85 , "GNN score" , A_TEAL , size = 10 , bold = True )
450+ _label (6.3 , 2.4 , "w_g = 0.4" , A_WHITE , size = 12 , bold = True )
451+ _label (6.3 , 1.95 , "Learned patterns" , A_GREY , size = 8 )
452+
453+ _rect (8.4 , 1.6 , 11.6 , 3.2 , A_AMBER , opacity = 0.10 )
454+ _label (10.0 , 2.85 , "Cascade score" , A_AMBER , size = 10 , bold = True )
455+ _label (10.0 , 2.4 , "w_c = 0.2" , A_WHITE , size = 12 , bold = True )
456+ _label (10.0 , 1.95 , "Topology propagation" , A_GREY , size = 8 )
457+
458+ _rect (12.0 , 1.6 , 14.2 , 3.2 , A_RED , opacity = 0.08 , dash = "dot" )
459+ _label (13.1 , 2.85 , "Early-exit" , A_RED , size = 10 , bold = True )
460+ _label (13.1 , 2.4 , "(1, 0, 0)" , A_WHITE , size = 11 , bold = True )
461+ _label (13.1 , 1.95 , "Physics only" , A_GREY , size = 8 )
462+
463+ _label (7.5 , 1.35 , "combined = w_p * physics + w_g * gnn + w_c * cascade (per node)" , A_DIM , size = 9 )
464+
465+ # ================================================================
466+ # OUTPUT (y ~ 0)
467+ # ================================================================
468+ _arrow (7.5 , 1.2 , 7.5 , 0.6 )
469+ _rect (4.5 , - 0.2 , 10.5 , 0.6 , A_LIGHT , opacity = 0.08 )
470+ _label (7.5 , 0.4 , "Verification reward" , A_LIGHT , size = 12 , bold = True )
471+ _label (7.5 , 0.0 , "r in [-1, +1] | FN penalty ratio: 2.0" , A_GREY , size = 9 )
472+
473+ # ================================================================
474+ # Layout
475+ # ================================================================
300476 fig_arch .update_layout (
301477 template = "plotly_dark" ,
302- height = 350 ,
303- margin = dict (l = 10 , r = 10 , t = 10 , b = 10 ),
304- xaxis = dict (showgrid = False , showticklabels = False , zeroline = False , range = [- 0.8 , 5 ]),
305- yaxis = dict (showgrid = False , showticklabels = False , zeroline = False , range = [- 0.6 , 3.6 ]),
478+ height = 900 ,
479+ margin = dict (l = 5 , r = 5 , t = 5 , b = 5 ),
480+ xaxis = dict (showgrid = False , showticklabels = False , zeroline = False , range = XR ),
481+ yaxis = dict (showgrid = False , showticklabels = False , zeroline = False , range = YR ,
482+ scaleanchor = "x" , scaleratio = 1 ),
306483 paper_bgcolor = "rgba(0,0,0,0)" ,
307484 plot_bgcolor = "rgba(0,0,0,0)" ,
308485 )
0 commit comments