Skip to content

Commit 55f9083

Browse files
committed
fix(dashboard): resolve GNN dimension mismatch and training metrics
1. Tab 2: the GNN checkpoint was trained on 44 nodes, but GridGraphBuilder produces 49. Added load_gnn_graph(), which slices node_type, features, and edges down to 44 nodes, so HybridVerifierAgent now receives graph data with compatible dimensions. Verified: reward=0.2008, with all three layer scores non-zero. 2. Tab 5: SelfPlayTrainer now uses the same 44-node graph. Corrected the metric keys: avg_verification_reward (was mean_reward), avg_solver_loss (was solver_loss), and the scenarios list (was the scenario_type string). Training now produces real, non-zero values and the correct scenario-type names.
1 parent 37b4bae commit 55f9083

1 file changed

Lines changed: 53 additions & 19 deletions

File tree

app.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,36 @@ def load_graph_data():
9191
return builder.build_from_metadata(df), df
9292

9393

GNN_NUM_NODES = 44  # GNN checkpoint was trained on 44-node graphs


@st.cache_resource
def load_gnn_graph():
    """Build a 44-node graph compatible with the trained GNN checkpoint.

    GridGraphBuilder produces 49 node_type/feature entries but the GNN
    was trained on 44 nodes. We slice node tensors to the first
    GNN_NUM_NODES entries and drop edges touching removed nodes so the
    resulting Data object matches the checkpoint's input dimensions.

    Returns:
        torch_geometric.data.Data with x, edge_index, node_type (and
        edge_attr, when the builder set one) restricted to the first
        44 nodes.
    """
    from torch_geometric.data import Data

    # Metadata frame is not needed here; discard it explicitly.
    full_graph, _ = load_graph_data()
    n = GNN_NUM_NODES

    # Slice node attributes to the first n nodes.
    # NOTE(review): this assumes the builder's first 44 nodes line up
    # with the checkpoint's training-time node ordering — confirm
    # against how the training graph was constructed.
    node_type = full_graph.node_type[:n]
    x = full_graph.x[:n] if full_graph.x is not None else None

    # Filter edges: keep only those where both src and dst < n.
    ei = full_graph.edge_index
    mask = (ei[0] < n) & (ei[1] < n)
    edge_index = ei[:, mask]

    data = Data(x=x, edge_index=edge_index, node_type=node_type)

    # Carry edge attributes for the surviving edges, if present.
    # Dropping them while keeping edge_index would silently
    # desynchronize edge features for edge-aware GNN layers.
    edge_attr = getattr(full_graph, "edge_attr", None)
    if edge_attr is not None:
        data.edge_attr = edge_attr[mask]

    # Pin num_nodes explicitly: with isolated nodes possible after the
    # edge filter, PyG cannot always infer the node count from edges.
    data.num_nodes = n
    return data
123+
94124
@st.cache_resource
95125
def load_hybrid_verifier(_graph_data):
96126
"""Load HybridVerifierAgent with trained GNN checkpoint."""
@@ -286,8 +316,8 @@ def load_ablation_results():
286316
st.header("Live anomaly detection")
287317

288318
try:
289-
graph_data, _meta_df = load_graph_data()
290-
verifier = load_hybrid_verifier(graph_data)
319+
gnn_graph = load_gnn_graph()
320+
verifier = load_hybrid_verifier(gnn_graph)
291321
proposer = load_proposer()
292322
except Exception as e:
293323
st.error(f"Failed to load models: {e}")
@@ -311,7 +341,7 @@ def load_ablation_results():
311341
scenario = proposer.propose_scenario(
312342
context,
313343
forecast_horizon=horizon,
314-
graph_data=graph_data,
344+
graph_data=gnn_graph,
315345
)
316346
# Override type and magnitude
317347
scenario.scenario_type = scenario_type
@@ -320,10 +350,9 @@ def load_ablation_results():
320350
# Apply scenario to create forecast
321351
forecast_1d = scenario.apply_to_timeseries(context[:horizon])
322352

323-
# Ensure forecast covers enough nodes for the verifier
324-
n_eval = len(graph_data.node_type)
325-
eval_input = np.zeros(n_eval)
326-
n_copy = min(len(forecast_1d), n_eval)
353+
# Pad or trim to GNN node count (44)
354+
eval_input = np.zeros(GNN_NUM_NODES)
355+
n_copy = min(len(forecast_1d), GNN_NUM_NODES)
327356
eval_input[:n_copy] = forecast_1d[:n_copy]
328357

329358
# Run through hybrid verifier
@@ -412,8 +441,7 @@ def load_ablation_results():
412441
)
413442

414443
early_exits = breakdown.get("early_exit_count", 0)
415-
total_nodes = len(graph_data.node_type)
416-
st.markdown(f"**Early exits:** {early_exits}/{total_nodes} nodes")
444+
st.markdown(f"**Early exits:** {early_exits}/{GNN_NUM_NODES} nodes")
417445

418446
# -- Expandable details --
419447
with st.expander("Raw verification details"):
@@ -820,7 +848,7 @@ def load_ablation_results():
820848
from fyp.selfplay.verifier import VerifierAgent
821849
from fyp.selfplay.trainer import SelfPlayTrainer
822850

823-
graph_data_train, _ = load_graph_data()
851+
gnn_graph_train = load_gnn_graph()
824852

825853
with st.spinner("Initializing agents..."):
826854
proposer = ProposerAgent(
@@ -835,7 +863,7 @@ def load_ablation_results():
835863
proposer=proposer,
836864
solver=solver,
837865
verifier=base_verifier,
838-
graph_data=graph_data_train,
866+
graph_data=gnn_graph_train,
839867
)
840868

841869
progress = st.progress(0)
@@ -858,13 +886,19 @@ def load_ablation_results():
858886

859887
try:
860888
metrics = trainer.train_episode(batch)
861-
rewards_history.append(metrics.get("mean_reward", 0))
862-
solver_losses.append(metrics.get("solver_loss", 0))
889+
rewards_history.append(
890+
metrics.get("avg_verification_reward", 0)
891+
)
892+
solver_losses.append(
893+
metrics.get("avg_solver_loss", 0)
894+
)
863895
curriculum_levels.append(
864-
metrics.get("curriculum_level", metrics.get("difficulty", 0))
896+
metrics.get("scenario_diversity", 0)
865897
)
866-
scenario_types.append(
867-
metrics.get("scenario_type", "UNKNOWN")
898+
# scenarios is a list of types per batch item
899+
ep_scenarios = metrics.get("scenarios", [])
900+
scenario_types.extend(
901+
[str(s) for s in ep_scenarios] if ep_scenarios else ["UNKNOWN"]
868902
)
869903
except Exception as ep_err:
870904
rewards_history.append(0)
@@ -929,17 +963,17 @@ def load_ablation_results():
929963
fig_cur.add_trace(go.Scatter(
930964
x=eps_x, y=curriculum_levels,
931965
mode="lines+markers",
932-
name="Curriculum level",
966+
name="Scenario diversity",
933967
line=dict(color=STEEL, width=2),
934968
marker=dict(size=6),
935969
))
936970
fig_cur.update_layout(
937971
template=PLOTLY_TEMPLATE,
938-
title="Proposer difficulty per episode",
972+
title="Scenario diversity per episode",
939973
height=300,
940974
margin=dict(l=50, r=20, t=40, b=40),
941975
xaxis_title="Episode",
942-
yaxis_title="Difficulty",
976+
yaxis_title="Diversity (unique types / total)",
943977
**DARK_LAYOUT,
944978
)
945979
st.plotly_chart(fig_cur, use_container_width=True)

0 commit comments

Comments (0)