Skip to content

Commit d0d301e

Browse files
feat: add train/val gap tracking and bootstrap insights (#185)
* feat: add train/val gap tracking and bootstrap insights * fix: clarify TODO wording and add return type annotation * chore: bump version to 1.4.1 * fix: add type annotations to _evaluate_predictor
1 parent d5032d0 commit d0d301e

11 files changed

Lines changed: 179 additions & 43 deletions

File tree

plexe/CODE_INDEX.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Code Index: plexe
22

3-
> Generated on 2026-03-02 22:03:39
3+
> Generated on 2026-03-03 00:06:47
44
55
Code structure and public interface documentation for the **plexe** package.
66

@@ -222,7 +222,7 @@ Helper functions for workflow.
222222

223223
**Functions:**
224224
- `select_viable_model_types(data_layout: DataLayout, selected_frameworks: list[str] | None) -> list[str]` - Select viable model types using three-tier filtering.
225-
- `evaluate_on_sample(spark: SparkSession, sample_uri: str, model_artifacts_path: Path, model_type: str, metric: str, target_columns: list[str], group_column: str | None) -> float` - Evaluate model on sample (fast).
225+
- `evaluate_on_sample(spark: SparkSession, sample_uri: str, model_artifacts_path: Path, model_type: str, metric: str, target_columns: list[str], group_column: str | None, train_sample_uri: str | None) -> tuple[float, float | None]` - Evaluate model on validation sample, optionally also on training sample.
226226
- `compute_metric_hardcoded(y_true, y_pred, metric_name: str) -> float` - Compute metric using hardcoded sklearn implementations.
227227
- `compute_metric(y_true, y_pred, metric_name: str, group_ids) -> float` - Compute metric value.
228228

plexe/agents/hypothesiser.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,13 @@ def _summarize_node(node) -> str:
189189
if not node:
190190
return "No node"
191191

192-
summary = f" Performance: {node.performance:.4f}\n" if node.performance else " Performance: N/A\n"
192+
if node.performance is not None:
193+
summary = f" Val Performance: {node.performance:.4f}\n"
194+
if node.train_performance is not None:
195+
gap = node.train_performance - node.performance
196+
summary += f" Train Performance: {node.train_performance:.4f} (train-val gap: {gap:+.4f})\n"
197+
else:
198+
summary = " Performance: N/A\n"
193199

194200
if node.plan:
195201
summary += f" Features: {node.plan.features.strategy}\n"
@@ -236,6 +242,11 @@ def _summarize_history(history: list[dict]) -> str:
236242
status = f"✓ {perf:.4f}" if success and perf else ("✗ FAILED" if not success else "pending")
237243
summary += f" Solution {solution_id} ({stage}): {status}\n"
238244

245+
train_perf = entry.get("train_performance")
246+
if success and perf and train_perf is not None:
247+
gap = train_perf - perf
248+
summary += f" Train/val gap: {gap:+.4f}\n"
249+
239250
if entry.get("error"):
240251
summary += f" Error: {entry['error'][:80]}\n"
241252

plexe/agents/insight_extractor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@ def _summarize_results(self, parent_perf: float | None) -> str:
209209
if perf_delta is not None:
210210
perf_pct = (perf_delta / parent_perf * 100) if parent_perf else 0
211211
perf_str += f" ({perf_delta:+.4f}, {perf_pct:+.1f}%)"
212-
summary += f" Performance: {perf_str}\n"
212+
summary += f" Val Performance: {perf_str}\n"
213+
if sol.train_performance is not None:
214+
gap = sol.train_performance - sol.performance
215+
summary += f" Train Performance: {sol.train_performance:.4f} (train-val gap: {gap:+.4f})\n"
213216
else:
214217
summary += f" Status: {status}\n"
215218

plexe/helpers.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import logging
1010
from pathlib import Path
11-
from typing import TYPE_CHECKING
11+
from typing import Any, TYPE_CHECKING
1212

1313
import numpy as np
1414
import pandas as pd
@@ -93,68 +93,85 @@ def evaluate_on_sample(
9393
metric: str,
9494
target_columns: list[str],
9595
group_column: str | None = None,
96-
) -> float:
96+
train_sample_uri: str | None = None,
97+
) -> tuple[float, float | None]:
9798
"""
98-
Evaluate model on sample (fast).
99+
Evaluate model on validation sample, optionally also on training sample.
99100
100101
Args:
101102
spark: SparkSession
102-
sample_uri: Sample data URI
103+
sample_uri: Validation sample data URI
103104
model_artifacts_path: Path to model artifacts
104105
model_type: "xgboost", "catboost", "lightgbm", "keras", or "pytorch"
105106
metric: Metric name
106107
target_columns: Target column names
107108
group_column: Optional group column for ranking metrics (query_id, session_id)
109+
train_sample_uri: Optional training sample URI (for train/val gap computation)
108110
109111
Returns:
110-
Performance value
112+
Tuple of (val_performance, train_performance). train_performance is None
113+
when train_sample_uri is not provided.
111114
"""
115+
predictor = _load_predictor(model_artifacts_path, model_type)
116+
val_performance = _evaluate_predictor(spark, predictor, sample_uri, metric, target_columns, group_column)
117+
logger.info(f"Val sample performance ({metric}): {val_performance:.4f}")
118+
119+
# TODO: Computing secondary metrics (e.g. per-class breakdown, calibration) per solution during search.
120+
train_performance = None
121+
if train_sample_uri:
122+
train_performance = _evaluate_predictor(
123+
spark, predictor, train_sample_uri, metric, target_columns, group_column
124+
)
125+
gap = train_performance - val_performance
126+
logger.info(f"Train sample performance ({metric}): {train_performance:.4f} (train-val gap: {gap:+.4f})")
112127

113-
logger.info(f"Evaluating on sample with metric: {metric}")
114-
115-
# Load Sample
116-
sample_df = spark.read.parquet(sample_uri).toPandas()
117-
118-
# Extract group IDs if ranking task
119-
group_ids = sample_df[group_column].values if group_column and group_column in sample_df.columns else None
120-
121-
# Use column names instead of positional indexing to handle target columns in any position
122-
columns_to_drop = list(target_columns)
123-
if group_column and group_column in sample_df.columns:
124-
columns_to_drop.append(group_column)
128+
return val_performance, train_performance
125129

126-
X_sample = sample_df.drop(columns=columns_to_drop)
127-
y_sample = sample_df[target_columns[0]]
128130

129-
# Load Predictor
131+
def _load_predictor(model_artifacts_path: Path, model_type: str) -> Any:
132+
"""Load the appropriate predictor for a model type."""
130133
if model_type == ModelType.XGBOOST:
131134
from plexe.templates.inference.xgboost_predictor import XGBoostPredictor
132135

133-
predictor = XGBoostPredictor(str(model_artifacts_path))
136+
return XGBoostPredictor(str(model_artifacts_path))
134137
elif model_type == ModelType.CATBOOST:
135138
from plexe.templates.inference.catboost_predictor import CatBoostPredictor
136139

137-
predictor = CatBoostPredictor(str(model_artifacts_path))
140+
return CatBoostPredictor(str(model_artifacts_path))
138141
elif model_type == ModelType.LIGHTGBM:
139142
from plexe.templates.inference.lightgbm_predictor import LightGBMPredictor
140143

141-
predictor = LightGBMPredictor(str(model_artifacts_path))
144+
return LightGBMPredictor(str(model_artifacts_path))
142145
elif model_type == ModelType.KERAS:
143146
from plexe.templates.inference.keras_predictor import KerasPredictor
144147

145-
predictor = KerasPredictor(str(model_artifacts_path))
148+
return KerasPredictor(str(model_artifacts_path))
146149
else:
147150
from plexe.templates.inference.pytorch_predictor import PyTorchPredictor
148151

149-
predictor = PyTorchPredictor(str(model_artifacts_path))
152+
return PyTorchPredictor(str(model_artifacts_path))
150153

151-
# Predict and compute metric on predictions
152-
predictions = predictor.predict(X_sample)["prediction"].values
153-
performance = compute_metric(y_sample, predictions, metric, group_ids=group_ids)
154154

155-
logger.info(f"Sample performance ({metric}): {performance:.4f}")
155+
def _evaluate_predictor(
156+
spark: "SparkSession",
157+
predictor: Any,
158+
data_uri: str,
159+
metric: str,
160+
target_columns: list[str],
161+
group_column: str | None,
162+
) -> float:
163+
"""Run predictor on a dataset and compute metric."""
164+
df = spark.read.parquet(data_uri).toPandas()
165+
group_ids = df[group_column].values if group_column and group_column in df.columns else None
166+
167+
columns_to_drop = list(target_columns)
168+
if group_column and group_column in df.columns:
169+
columns_to_drop.append(group_column)
156170

157-
return performance
171+
X = df.drop(columns=columns_to_drop)
172+
y = df[target_columns[0]]
173+
predictions = predictor.predict(X)["prediction"].values
174+
return compute_metric(y, predictions, metric, group_ids=group_ids)
158175

159176

160177
def compute_metric_hardcoded(y_true, y_pred, metric_name: str) -> float:

plexe/models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ class Solution:
294294

295295
# Execution Results
296296
model_artifacts_path: Path | None = None
297-
performance: float | None = None
297+
performance: float | None = None # Validation set metric
298+
train_performance: float | None = None # Training set metric (for overfitting/underfitting detection)
298299
training_time: float | None = None
299300

300301
# Tree Structure
@@ -359,6 +360,7 @@ def to_dict(self) -> dict:
359360
"parent_solution_id": self.parent.solution_id if self.parent else None,
360361
"child_solution_ids": [c.solution_id for c in self.children],
361362
"performance": self.performance,
363+
"train_performance": self.train_performance,
362364
"training_time": self.training_time,
363365
"is_buggy": self.is_buggy,
364366
"error": self.error,
@@ -398,6 +400,7 @@ def from_dict(d: dict, all_solutions: dict[int, "Solution"]) -> "Solution":
398400
model_type=d["model_type"],
399401
model_artifacts_path=Path(d["model_artifacts_path"]) if d.get("model_artifacts_path") else None,
400402
performance=d.get("performance"),
403+
train_performance=d.get("train_performance"),
401404
training_time=d.get("training_time"),
402405
parent=None, # Will be linked in second pass
403406
children=[], # Will be linked in second pass

plexe/search/journal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def get_history(self, limit: int = 10) -> list[dict]:
130130
"stage": _compute_stage(node),
131131
"success": not node.is_buggy,
132132
"performance": node.performance,
133+
"train_performance": node.train_performance,
133134
"error": node.error,
134135
}
135136

plexe/workflow.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,11 @@ def build_model(
499499
logger.error("No search journal or good nodes available for fallback")
500500
logger.error("Proceeding with failed model (will be marked as FAILED)")
501501

502+
# TODO(EVAL_REFINEMENT_LOOP): When verdict is CONDITIONAL_PASS or FAIL with actionable
503+
# HIGH-priority recommendations, summarize evaluation findings into structured feedback,
504+
# loop back to search_models() for 1-3 targeted iterations, then re-evaluate.
505+
# This creates a closed loop: search -> evaluate -> refine -> re-evaluate.
506+
502507
# Phase 6: Package Final Deliverables
503508
if start_phase <= 6:
504509
with tracer.start_as_current_span("Phase 6: Package Final Model"):
@@ -1260,19 +1265,21 @@ def _execute_variant(
12601265
(src_dir / "__init__.py").write_text("# Auto-generated\n")
12611266

12621267
# Evaluate
1263-
performance = evaluate_on_sample(
1268+
performance, train_performance = evaluate_on_sample(
12641269
spark=spark,
12651270
sample_uri=variant_context.val_sample_uri,
12661271
model_artifacts_path=model_artifacts_path,
12671272
model_type=plan.model.model_type,
12681273
metric=variant_context.metric.name,
12691274
target_columns=variant_context.output_targets,
12701275
group_column=variant_context.group_column,
1276+
train_sample_uri=variant_context.train_sample_uri,
12711277
)
12721278

12731279
# Update solution
12741280
new_solution.model_artifacts_path = model_artifacts_path
12751281
new_solution.performance = float(performance)
1282+
new_solution.train_performance = float(train_performance) if train_performance is not None else None
12761283
new_solution.training_time = time.time() - start_time
12771284
new_solution.is_buggy = False
12781285

@@ -1405,6 +1412,17 @@ def search_models(
14051412

14061413
logger.info(f"Generated {len(plans)} bootstrap plan(s)")
14071414

1415+
# Synthetic hypothesis for insight extraction on bootstrap results
1416+
hypothesis = Hypothesis(
1417+
expand_solution_id=-1,
1418+
focus="both",
1419+
vary="bootstrap_initial_solutions",
1420+
num_variants=len(plans),
1421+
rationale="Initial diverse solutions to seed the search tree",
1422+
keep_from_parent=[],
1423+
expected_impact="Establish baseline performance range across strategies",
1424+
)
1425+
14081426
except Exception as e:
14091427
logger.error(f"Bootstrap planning failed: {e}")
14101428
continue # Skip this iteration
@@ -1506,10 +1524,9 @@ def search_models(
15061524
solution_id_counter += len(variant_ids)
15071525

15081526
# ============================================
1509-
# Step 2e: Extract Insights from Variants (skip in bootstrap mode)
1527+
# Step 2e: Extract Insights from Variants
15101528
# ============================================
1511-
if variant_solutions and expand_solution_id is not None:
1512-
# Only extract insights when we have a hypothesis to learn from
1529+
if variant_solutions:
15131530
try:
15141531
InsightExtractorAgent(
15151532
hypothesis=hypothesis,
@@ -1699,7 +1716,7 @@ def retrain_on_full_dataset(
16991716
# ============================================
17001717
logger.info("Evaluating final model on full validation set...")
17011718

1702-
final_val_performance = evaluate_on_sample(
1719+
final_val_performance, _ = evaluate_on_sample(
17031720
spark=spark,
17041721
sample_uri=context.val_uri, # ← FULL validation set
17051722
model_artifacts_path=final_artifacts_path,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "plexe"
3-
version = "1.4.0"
3+
version = "1.4.1"
44
description = "An agentic framework for building ML models from natural language"
55
authors = [
66
"Marcello De Bernardi <mdebernardi@plexe.ai>",

tests/CODE_INDEX.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Code Index: tests
22

3-
> Generated on 2026-03-02 22:03:39
3+
> Generated on 2026-03-03 00:06:47
44
55
Test suite structure and test case documentation.
66

@@ -122,6 +122,8 @@ Unit tests for SearchJournal.
122122
- `test_journal_get_history()` - Test history returns recent entries.
123123
- `test_journal_improvement_trend_improving()` - Test improvement trend with steadily improving solutions.
124124
- `test_journal_improvement_trend_insufficient_data()` - Test improvement trend with fewer than 2 successful solutions.
125+
- `test_journal_get_history_includes_train_performance()` - get_history should include train_performance when set on a solution.
126+
- `test_journal_get_history_train_performance_none()` - get_history should include train_performance=None when not set.
125127

126128
---
127129
## `unit/search/test_tree_policy_determinism.py`
@@ -208,6 +210,10 @@ Unit tests for core model dataclasses.
208210

209211
**Functions:**
210212
- `test_build_context_update_and_unknown_key()` - Update should set known fields and reject unknown keys.
213+
- `test_solution_train_performance_defaults_to_none()` - New field should default to None for backward compatibility.
214+
- `test_solution_to_dict_includes_train_performance()` - to_dict should serialize train_performance.
215+
- `test_solution_from_dict_backward_compatible()` - Old checkpoints missing train_performance should deserialize cleanly.
216+
- `test_solution_from_dict_with_train_performance()` - Checkpoints with train_performance should round-trip correctly.
211217

212218
---
213219
## `unit/test_submission_pytorch.py`

tests/unit/search/test_journal.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,28 @@ def test_journal_improvement_trend_insufficient_data():
170170
trend = journal.get_improvement_trend()
171171

172172
assert trend == 0.0
173+
174+
175+
# ============================================
176+
# Train Performance in History Tests
177+
# ============================================
178+
179+
180+
def test_journal_get_history_includes_train_performance():
181+
"""get_history should include train_performance when set on a solution."""
182+
journal = SearchJournal()
183+
sol = _make_solution(0, performance=0.85)
184+
sol.train_performance = 0.92
185+
journal.add_node(sol)
186+
187+
history = journal.get_history()
188+
assert history[0]["train_performance"] == 0.92
189+
190+
191+
def test_journal_get_history_train_performance_none():
192+
"""get_history should include train_performance=None when not set."""
193+
journal = SearchJournal()
194+
journal.add_node(_make_solution(0, performance=0.85))
195+
196+
history = journal.get_history()
197+
assert history[0]["train_performance"] is None

0 commit comments

Comments
 (0)