@@ -91,6 +91,36 @@ def load_graph_data():
9191 return builder .build_from_metadata (df ), df
9292
9393
94+ GNN_NUM_NODES = 44 # GNN checkpoint was trained on 44-node graphs
95+
96+
97+ @st .cache_resource
98+ def load_gnn_graph ():
99+ """Build a 44-node graph compatible with the trained GNN checkpoint.
100+
101+ GridGraphBuilder produces 49 node_type/feature entries but the GNN
102+ was trained on 44 nodes. We slice tensors and filter edges to
103+ produce a Data object the GNN can consume without dimension errors.
104+ """
105+ from torch_geometric .data import Data
106+
107+ full_graph , meta_df = load_graph_data ()
108+ n = GNN_NUM_NODES
109+
110+ # Slice node attributes to first n nodes
111+ node_type = full_graph .node_type [:n ]
112+ x = full_graph .x [:n ] if full_graph .x is not None else None
113+
114+ # Filter edges: keep only those where both src and dst < n
115+ ei = full_graph .edge_index
116+ mask = (ei [0 ] < n ) & (ei [1 ] < n )
117+ edge_index = ei [:, mask ]
118+
119+ data = Data (x = x , edge_index = edge_index , node_type = node_type )
120+ data .num_nodes = n
121+ return data
122+
123+
94124@st .cache_resource
95125def load_hybrid_verifier (_graph_data ):
96126 """Load HybridVerifierAgent with trained GNN checkpoint."""
@@ -286,8 +316,8 @@ def load_ablation_results():
286316 st .header ("Live anomaly detection" )
287317
288318 try :
289- graph_data , _meta_df = load_graph_data ()
290- verifier = load_hybrid_verifier (graph_data )
319+ gnn_graph = load_gnn_graph ()
320+ verifier = load_hybrid_verifier (gnn_graph )
291321 proposer = load_proposer ()
292322 except Exception as e :
293323 st .error (f"Failed to load models: { e } " )
@@ -311,7 +341,7 @@ def load_ablation_results():
311341 scenario = proposer .propose_scenario (
312342 context ,
313343 forecast_horizon = horizon ,
314- graph_data = graph_data ,
344+ graph_data = gnn_graph ,
315345 )
316346 # Override type and magnitude
317347 scenario .scenario_type = scenario_type
@@ -320,10 +350,9 @@ def load_ablation_results():
320350 # Apply scenario to create forecast
321351 forecast_1d = scenario .apply_to_timeseries (context [:horizon ])
322352
323- # Ensure forecast covers enough nodes for the verifier
324- n_eval = len (graph_data .node_type )
325- eval_input = np .zeros (n_eval )
326- n_copy = min (len (forecast_1d ), n_eval )
353+ # Pad or trim to GNN node count (44)
354+ eval_input = np .zeros (GNN_NUM_NODES )
355+ n_copy = min (len (forecast_1d ), GNN_NUM_NODES )
327356 eval_input [:n_copy ] = forecast_1d [:n_copy ]
328357
329358 # Run through hybrid verifier
@@ -412,8 +441,7 @@ def load_ablation_results():
412441 )
413442
414443 early_exits = breakdown .get ("early_exit_count" , 0 )
415- total_nodes = len (graph_data .node_type )
416- st .markdown (f"**Early exits:** { early_exits } /{ total_nodes } nodes" )
444+ st .markdown (f"**Early exits:** { early_exits } /{ GNN_NUM_NODES } nodes" )
417445
418446 # -- Expandable details --
419447 with st .expander ("Raw verification details" ):
@@ -820,7 +848,7 @@ def load_ablation_results():
820848 from fyp .selfplay .verifier import VerifierAgent
821849 from fyp .selfplay .trainer import SelfPlayTrainer
822850
823- graph_data_train , _ = load_graph_data ()
851+ gnn_graph_train = load_gnn_graph ()
824852
825853 with st .spinner ("Initializing agents..." ):
826854 proposer = ProposerAgent (
@@ -835,7 +863,7 @@ def load_ablation_results():
835863 proposer = proposer ,
836864 solver = solver ,
837865 verifier = base_verifier ,
838- graph_data = graph_data_train ,
866+ graph_data = gnn_graph_train ,
839867 )
840868
841869 progress = st .progress (0 )
@@ -858,13 +886,19 @@ def load_ablation_results():
858886
859887 try :
860888 metrics = trainer .train_episode (batch )
861- rewards_history .append (metrics .get ("mean_reward" , 0 ))
862- solver_losses .append (metrics .get ("solver_loss" , 0 ))
889+ rewards_history .append (
890+ metrics .get ("avg_verification_reward" , 0 )
891+ )
892+ solver_losses .append (
893+ metrics .get ("avg_solver_loss" , 0 )
894+ )
863895 curriculum_levels .append (
864- metrics .get ("curriculum_level " , metrics . get ( "difficulty" , 0 ) )
896+ metrics .get ("scenario_diversity " , 0 )
865897 )
866- scenario_types .append (
867- metrics .get ("scenario_type" , "UNKNOWN" )
898+ # scenarios is a list of types per batch item
899+ ep_scenarios = metrics .get ("scenarios" , [])
900+ scenario_types .extend (
901+ [str (s ) for s in ep_scenarios ] if ep_scenarios else ["UNKNOWN" ]
868902 )
869903 except Exception as ep_err :
870904 rewards_history .append (0 )
@@ -929,17 +963,17 @@ def load_ablation_results():
929963 fig_cur .add_trace (go .Scatter (
930964 x = eps_x , y = curriculum_levels ,
931965 mode = "lines+markers" ,
932- name = "Curriculum level " ,
966+ name = "Scenario diversity " ,
933967 line = dict (color = STEEL , width = 2 ),
934968 marker = dict (size = 6 ),
935969 ))
936970 fig_cur .update_layout (
937971 template = PLOTLY_TEMPLATE ,
938- title = "Proposer difficulty per episode" ,
972+ title = "Scenario diversity per episode" ,
939973 height = 300 ,
940974 margin = dict (l = 50 , r = 20 , t = 40 , b = 40 ),
941975 xaxis_title = "Episode" ,
942- yaxis_title = "Difficulty " ,
976+ yaxis_title = "Diversity (unique types / total) " ,
943977 ** DARK_LAYOUT ,
944978 )
945979 st .plotly_chart (fig_cur , use_container_width = True )
0 commit comments