|
8 | 8 |
|
9 | 9 | import logging |
10 | 10 | from pathlib import Path |
11 | | -from typing import TYPE_CHECKING |
| 11 | +from typing import Any, TYPE_CHECKING |
12 | 12 |
|
13 | 13 | import numpy as np |
14 | 14 | import pandas as pd |
@@ -93,68 +93,85 @@ def evaluate_on_sample( |
93 | 93 | metric: str, |
94 | 94 | target_columns: list[str], |
95 | 95 | group_column: str | None = None, |
96 | | -) -> float: |
| 96 | + train_sample_uri: str | None = None, |
| 97 | +) -> tuple[float, float | None]: |
97 | 98 | """ |
98 | | - Evaluate model on sample (fast). |
| 99 | + Evaluate model on validation sample, optionally also on training sample. |
99 | 100 |
|
100 | 101 | Args: |
101 | 102 | spark: SparkSession |
102 | | - sample_uri: Sample data URI |
| 103 | + sample_uri: Validation sample data URI |
103 | 104 | model_artifacts_path: Path to model artifacts |
104 | 105 | model_type: "xgboost", "catboost", "lightgbm", "keras", or "pytorch" |
105 | 106 | metric: Metric name |
106 | 107 | target_columns: Target column names |
107 | 108 | group_column: Optional group column for ranking metrics (query_id, session_id) |
| 109 | + train_sample_uri: Optional training sample URI (for train/val gap computation) |
108 | 110 |
|
109 | 111 | Returns: |
110 | | - Performance value |
| 112 | + Tuple of (val_performance, train_performance). train_performance is None |
| 113 | + when train_sample_uri is not provided. |
111 | 114 | """ |
| 115 | + predictor = _load_predictor(model_artifacts_path, model_type) |
| 116 | + val_performance = _evaluate_predictor(spark, predictor, sample_uri, metric, target_columns, group_column) |
| 117 | + logger.info(f"Val sample performance ({metric}): {val_performance:.4f}") |
| 118 | + |
| 119 | + # TODO: Computing secondary metrics (e.g. per-class breakdown, calibration) per solution during search. |
| 120 | + train_performance = None |
| 121 | + if train_sample_uri: |
| 122 | + train_performance = _evaluate_predictor( |
| 123 | + spark, predictor, train_sample_uri, metric, target_columns, group_column |
| 124 | + ) |
| 125 | + gap = train_performance - val_performance |
| 126 | + logger.info(f"Train sample performance ({metric}): {train_performance:.4f} (train-val gap: {gap:+.4f})") |
112 | 127 |
|
113 | | - logger.info(f"Evaluating on sample with metric: {metric}") |
114 | | - |
115 | | - # Load Sample |
116 | | - sample_df = spark.read.parquet(sample_uri).toPandas() |
117 | | - |
118 | | - # Extract group IDs if ranking task |
119 | | - group_ids = sample_df[group_column].values if group_column and group_column in sample_df.columns else None |
120 | | - |
121 | | - # Use column names instead of positional indexing to handle target columns in any position |
122 | | - columns_to_drop = list(target_columns) |
123 | | - if group_column and group_column in sample_df.columns: |
124 | | - columns_to_drop.append(group_column) |
| 128 | + return val_performance, train_performance |
125 | 129 |
|
126 | | - X_sample = sample_df.drop(columns=columns_to_drop) |
127 | | - y_sample = sample_df[target_columns[0]] |
128 | 130 |
|
129 | | - # Load Predictor |
| 131 | +def _load_predictor(model_artifacts_path: Path, model_type: str) -> Any: |
| 132 | + """Load the appropriate predictor for a model type.""" |
130 | 133 | if model_type == ModelType.XGBOOST: |
131 | 134 | from plexe.templates.inference.xgboost_predictor import XGBoostPredictor |
132 | 135 |
|
133 | | - predictor = XGBoostPredictor(str(model_artifacts_path)) |
| 136 | + return XGBoostPredictor(str(model_artifacts_path)) |
134 | 137 | elif model_type == ModelType.CATBOOST: |
135 | 138 | from plexe.templates.inference.catboost_predictor import CatBoostPredictor |
136 | 139 |
|
137 | | - predictor = CatBoostPredictor(str(model_artifacts_path)) |
| 140 | + return CatBoostPredictor(str(model_artifacts_path)) |
138 | 141 | elif model_type == ModelType.LIGHTGBM: |
139 | 142 | from plexe.templates.inference.lightgbm_predictor import LightGBMPredictor |
140 | 143 |
|
141 | | - predictor = LightGBMPredictor(str(model_artifacts_path)) |
| 144 | + return LightGBMPredictor(str(model_artifacts_path)) |
142 | 145 | elif model_type == ModelType.KERAS: |
143 | 146 | from plexe.templates.inference.keras_predictor import KerasPredictor |
144 | 147 |
|
145 | | - predictor = KerasPredictor(str(model_artifacts_path)) |
| 148 | + return KerasPredictor(str(model_artifacts_path)) |
146 | 149 | else: |
147 | 150 | from plexe.templates.inference.pytorch_predictor import PyTorchPredictor |
148 | 151 |
|
149 | | - predictor = PyTorchPredictor(str(model_artifacts_path)) |
| 152 | + return PyTorchPredictor(str(model_artifacts_path)) |
150 | 153 |
|
151 | | - # Predict and compute metric on predictions |
152 | | - predictions = predictor.predict(X_sample)["prediction"].values |
153 | | - performance = compute_metric(y_sample, predictions, metric, group_ids=group_ids) |
154 | 154 |
|
155 | | - logger.info(f"Sample performance ({metric}): {performance:.4f}") |
| 155 | +def _evaluate_predictor( |
| 156 | + spark: "SparkSession", |
| 157 | + predictor: Any, |
| 158 | + data_uri: str, |
| 159 | + metric: str, |
| 160 | + target_columns: list[str], |
| 161 | + group_column: str | None, |
| 162 | +) -> float: |
| 163 | + """Run predictor on a dataset and compute metric.""" |
| 164 | + df = spark.read.parquet(data_uri).toPandas() |
| 165 | + group_ids = df[group_column].values if group_column and group_column in df.columns else None |
| 166 | + |
| 167 | + columns_to_drop = list(target_columns) |
| 168 | + if group_column and group_column in df.columns: |
| 169 | + columns_to_drop.append(group_column) |
156 | 170 |
|
157 | | - return performance |
| 171 | + X = df.drop(columns=columns_to_drop) |
| 172 | + y = df[target_columns[0]] |
| 173 | + predictions = predictor.predict(X)["prediction"].values |
| 174 | + return compute_metric(y, predictions, metric, group_ids=group_ids) |
158 | 175 |
|
159 | 176 |
|
160 | 177 | def compute_metric_hardcoded(y_true, y_pred, metric_name: str) -> float: |
|
0 commit comments