diff --git a/backend/python-eval-function/src/accuracy_evaluator.py b/backend/python-eval-function/src/accuracy_evaluator.py index a2785a3..a73538b 100644 --- a/backend/python-eval-function/src/accuracy_evaluator.py +++ b/backend/python-eval-function/src/accuracy_evaluator.py @@ -1,7 +1,9 @@ import logging -from typing import Optional, List, Dict +from typing import Optional, List, Mapping from dataclasses import dataclass +from metrics import is_metric_enabled + logger = logging.getLogger(__name__) @@ -29,7 +31,8 @@ def __init__(self): def calculate_accuracy_metrics( self, predictions: List[str], - references: Optional[List[str]] + references: Optional[List[str]], + selected: Optional[Mapping[str, bool]] = None, ) -> Optional[AccuracyMetrics]: if references is None or len(references) == 0: logger.info("No reference outputs available, skipping accuracy metrics") @@ -38,18 +41,26 @@ def calculate_accuracy_metrics( if all(ref is None or ref == "" for ref in references): logger.info("All reference outputs are empty, skipping accuracy metrics") return None - - logger.info(f"Calculating accuracy metrics for {len(predictions)} predictions") - + + logger.info( + f"Calculating accuracy metrics for {len(predictions)} predictions " + f"(selected={dict(selected) if selected is not None else 'all'})" + ) + try: metrics = AccuracyMetrics() - - metrics.bleu = self._calculate_bleu(predictions, references) - metrics.rouge = self._calculate_rouge(predictions, references) - metrics.meteor = self._calculate_meteor(predictions, references) - metrics.levenshtein = self._calculate_levenshtein(predictions, references) - metrics.bertscore = self._calculate_bertscore(predictions, references) - + + if is_metric_enabled(selected, 'bleu'): + metrics.bleu = self._calculate_bleu(predictions, references) + if is_metric_enabled(selected, 'rouge'): + metrics.rouge = self._calculate_rouge(predictions, references) + if is_metric_enabled(selected, 'meteor'): + metrics.meteor = self._calculate_meteor(predictions, references) + if is_metric_enabled(selected, 'levenshtein'): + metrics.levenshtein = self._calculate_levenshtein(predictions, references) + if is_metric_enabled(selected, 'bertscore'): + metrics.bertscore = self._calculate_bertscore(predictions, references) + def fmt(v): return f"{v:.4f}" if v is not None else "N/A" logger.info( f"Accuracy metrics calculated - " @@ -57,9 +68,9 @@ def fmt(v): return f"{v:.4f}" if v is not None else "N/A" f"METEOR={fmt(metrics.meteor)}, Levenshtein={fmt(metrics.levenshtein)}, " f"BERTScore={fmt(metrics.bertscore)}" ) - + return metrics - + except Exception as e: logger.error(f"Error calculating accuracy metrics: {e}", exc_info=True) return None diff --git a/backend/python-eval-function/src/classification_evaluator.py b/backend/python-eval-function/src/classification_evaluator.py index 44a64d7..69ea449 100644 --- a/backend/python-eval-function/src/classification_evaluator.py +++ b/backend/python-eval-function/src/classification_evaluator.py @@ -1,7 +1,9 @@ import logging -from typing import Optional, List +from typing import Optional, List, Mapping from dataclasses import dataclass +from metrics import is_metric_enabled + logger = logging.getLogger(__name__) @@ -13,6 +15,14 @@ class ClassificationMetrics: f1_macro: Optional[float] = None f1_weighted: Optional[float] = None +_CLASSIFICATION_KEYS: tuple[str, ...] 
= ( + 'classification_accuracy', + 'precision_macro', + 'recall_macro', + 'f1_macro', + 'f1_weighted', +) + def normalize_prediction(prediction: str, valid_classes: List[str]) -> str: cleaned = prediction.strip() @@ -38,6 +48,7 @@ def calculate_classification_metrics( predictions: List[str], references: List[str], valid_classes: Optional[List[str]] = None, + selected: Optional[Mapping[str, bool]] = None, ) -> Optional[ClassificationMetrics]: if not references or not predictions: logger.info("No references or predictions, skipping classification metrics") @@ -47,6 +58,12 @@ def calculate_classification_metrics( logger.info("All references are empty, skipping classification metrics") return None + if selected is not None and not any( + is_metric_enabled(selected, k) for k in _CLASSIFICATION_KEYS + ): + logger.info("All classification metrics disabled, skipping computation") + return ClassificationMetrics() + if valid_classes is None: valid_classes = list(set(references)) @@ -56,7 +73,7 @@ def calculate_classification_metrics( logger.info( f"Calculating classification metrics for {len(normalized_preds)} predictions " - f"across {len(valid_classes)} classes" + f"across {len(valid_classes)} classes (selected={dict(selected) if selected is not None else 'all'})" ) try: @@ -86,11 +103,21 @@ def calculate_classification_metrics( ) metrics = ClassificationMetrics( - accuracy=round(acc, 4), - precision_macro=round(precision, 4), - recall_macro=round(recall, 4), - f1_macro=round(f1, 4), - f1_weighted=round(f1_w, 4), + accuracy=round(acc, 4) + if is_metric_enabled(selected, 'classification_accuracy') + else None, + precision_macro=round(precision, 4) + if is_metric_enabled(selected, 'precision_macro') + else None, + recall_macro=round(recall, 4) + if is_metric_enabled(selected, 'recall_macro') + else None, + f1_macro=round(f1, 4) + if is_metric_enabled(selected, 'f1_macro') + else None, + f1_weighted=round(f1_w, 4) + if is_metric_enabled(selected, 'f1_weighted') + else None, ) logger.info( diff --git a/backend/python-eval-function/src/dynamodb_service.py b/backend/python-eval-function/src/dynamodb_service.py index 4841e83..63504fa 100644 --- a/backend/python-eval-function/src/dynamodb_service.py +++ b/backend/python-eval-function/src/dynamodb_service.py @@ -7,6 +7,8 @@ import boto3 from botocore.exceptions import ClientError +from metrics import normalize_metrics_config + logger = logging.getLogger(__name__) @@ -28,18 +30,32 @@ def load_job(self, evaluation_id: str) -> Dict[str, Any]: item = response['Item'] models = json.loads(item.get('models', '[]')) weights = json.loads(item.get('weights', '{}')) - + raw_metrics = item.get('metrics') + stored_metrics: Optional[Dict[str, Any]] = None + if raw_metrics: + try: + stored_metrics = json.loads(raw_metrics) + except (TypeError, ValueError) as parse_err: + logger.warning( + f"Failed to parse stored metrics config " + f"({parse_err}); defaulting to all metrics enabled" + ) + metrics = normalize_metrics_config(stored_metrics) + job_config = { 'evaluation_id': item['evaluation_id'], 'dataset_id': item['dataset_id'], 'models': models, 'weights': weights, + 'metrics': metrics, 'status': item.get('status', 'pending'), 'created_at': item.get('created_at', ''), 'total_samples': item.get('total_samples') } - - logger.info(f"Loaded job config: {len(models)} models, weights: {weights}") + + logger.info( + f"Loaded job config: {len(models)} models, weights: {weights}, metrics: {metrics}" + ) return job_config except ClientError as e: diff --git 
a/backend/python-eval-function/src/geval_evaluator.py b/backend/python-eval-function/src/geval_evaluator.py index 0665200..8a6d581 100644 --- a/backend/python-eval-function/src/geval_evaluator.py +++ b/backend/python-eval-function/src/geval_evaluator.py @@ -99,6 +99,8 @@ def evaluate( predictions: List[str], references: Optional[List[str]] = None, task_type: str = "summarization", + compute_reasoning: bool = True, + compute_faithfulness: bool = True, ) -> GEvalMetrics: if not inputs or not predictions: logger.warning("Empty inputs or predictions, skipping G-Eval") @@ -108,10 +110,22 @@ def evaluate( logger.error("inputs and predictions length mismatch, skipping G-Eval") return GEvalMetrics() + if not compute_reasoning and not compute_faithfulness: + logger.info("Both G-Eval metrics disabled, skipping") + return GEvalMetrics() + try: model = self._build_judge_model() - reasoning_metric = self._build_reasoning_metric(model, task_type) - faithfulness_metric = self._build_faithfulness_metric(model, task_type) + reasoning_metric = ( + self._build_reasoning_metric(model, task_type) + if compute_reasoning + else None + ) + faithfulness_metric = ( + self._build_faithfulness_metric(model, task_type) + if compute_faithfulness + else None + ) except Exception as e: logger.error(f"Failed to initialize G-Eval components: {e}", exc_info=True) return GEvalMetrics() @@ -131,23 +145,25 @@ def evaluate( actual_output=pred, ) - try: - reasoning_metric.measure(test_case) - reasoning_scores.append(reasoning_metric.score) - logger.debug( - f"[{idx}] Reasoning score={reasoning_metric.score:.4f} reason={reasoning_metric.reason}" - ) - except Exception as e: - logger.warning(f"[{idx}] Reasoning metric failed: {e}") - - try: - faithfulness_metric.measure(test_case) - faithfulness_scores.append(faithfulness_metric.score) - logger.debug( - f"[{idx}] Faithfulness score={faithfulness_metric.score:.4f} reason={faithfulness_metric.reason}" - ) - except Exception as e: - logger.warning(f"[{idx}] Faithfulness metric failed: {e}") + if reasoning_metric is not None: + try: + reasoning_metric.measure(test_case) + reasoning_scores.append(reasoning_metric.score) + logger.debug( + f"[{idx}] Reasoning score={reasoning_metric.score:.4f} reason={reasoning_metric.reason}" + ) + except Exception as e: + logger.warning(f"[{idx}] Reasoning metric failed: {e}") + + if faithfulness_metric is not None: + try: + faithfulness_metric.measure(test_case) + faithfulness_scores.append(faithfulness_metric.score) + logger.debug( + f"[{idx}] Faithfulness score={faithfulness_metric.score:.4f} reason={faithfulness_metric.reason}" + ) + except Exception as e: + logger.warning(f"[{idx}] Faithfulness metric failed: {e}") result = GEvalMetrics() diff --git a/backend/python-eval-function/src/main.py b/backend/python-eval-function/src/main.py index 662f58a..ebfdd3f 100644 --- a/backend/python-eval-function/src/main.py +++ b/backend/python-eval-function/src/main.py @@ -74,8 +74,18 @@ def main(): dataset_id = job_config['dataset_id'] models = job_config['models'] weights = job_config['weights'] - - logger.info(f"Job config loaded - dataset: {dataset_id}, models: {len(models)}") + # `metrics_config` is already normalized by DynamoDBService.load_job — + # every key present, all bool. 
+ metrics_config = job_config.get('metrics') or {} + + from metrics import is_metric_enabled + compute_geval_reasoning = is_metric_enabled(metrics_config, 'geval_reasoning') + compute_geval_faithfulness = is_metric_enabled(metrics_config, 'geval_faithfulness') + + logger.info( + f"Job config loaded - dataset: {dataset_id}, models: {len(models)}, " + f"metrics={metrics_config or 'all'}" + ) from dataset_loader import DatasetLoader dataset_loader = DatasetLoader() @@ -134,7 +144,9 @@ def main(): predictions = [invocation_results[i].response_text for i in successful_indices] references = [all_references[i] for i in successful_indices] if all_references else None - accuracy_metrics = accuracy_evaluator.calculate_accuracy_metrics(predictions, references) + accuracy_metrics = accuracy_evaluator.calculate_accuracy_metrics( + predictions, references, selected=metrics_config + ) accuracy_results[model_id] = accuracy_metrics or AccuracyMetrics() logger.info(f"Summarization accuracy complete for {len(accuracy_results)} models") @@ -149,7 +161,8 @@ def main(): references = [all_references[i] for i in successful_indices] if all_references else None cls_metrics = classification_evaluator.calculate_classification_metrics( - predictions, references, valid_classes=unique_classes + predictions, references, valid_classes=unique_classes, + selected=metrics_config, ) acc = AccuracyMetrics() @@ -163,28 +176,40 @@ def main(): logger.info(f"Classification accuracy complete for {len(accuracy_results)} models") - from geval_evaluator import GEvalEvaluator - geval_evaluator = GEvalEvaluator() - - for model_id, invocation_results in results_by_model.items(): - successful_indices = [i for i, r in enumerate(invocation_results) if r.error is None] - predictions = [invocation_results[i].response_text for i in successful_indices] - inputs = [dataset.documents[i] for i in successful_indices] - - logger.info(f"Running G-Eval for model {model_id} on {len(predictions)} samples") - geval_metrics = geval_evaluator.evaluate(inputs, predictions, task_type=task_type) - - acc = accuracy_results.get(model_id) - if acc is None: - acc = AccuracyMetrics() - accuracy_results[model_id] = acc - - acc.geval_reasoning = geval_metrics.reasoning - acc.geval_faithfulness = geval_metrics.faithfulness - - logger.info(f"G-Eval complete for {model_id}") - - logger.info(f"G-Eval evaluation complete for all models") + if compute_geval_reasoning or compute_geval_faithfulness: + from geval_evaluator import GEvalEvaluator + geval_evaluator = GEvalEvaluator() + + for model_id, invocation_results in results_by_model.items(): + successful_indices = [i for i, r in enumerate(invocation_results) if r.error is None] + predictions = [invocation_results[i].response_text for i in successful_indices] + inputs = [dataset.documents[i] for i in successful_indices] + + logger.info( + f"Running G-Eval for model {model_id} on {len(predictions)} samples " + f"(reasoning={compute_geval_reasoning}, faithfulness={compute_geval_faithfulness})" + ) + geval_metrics = geval_evaluator.evaluate( + inputs, predictions, task_type=task_type, + compute_reasoning=compute_geval_reasoning, + compute_faithfulness=compute_geval_faithfulness, + ) + + acc = accuracy_results.get(model_id) + if acc is None: + acc = AccuracyMetrics() + accuracy_results[model_id] = acc + + if compute_geval_reasoning: + acc.geval_reasoning = geval_metrics.reasoning + if compute_geval_faithfulness: + acc.geval_faithfulness = geval_metrics.faithfulness + + logger.info(f"G-Eval complete for {model_id}") + + 
logger.info("G-Eval evaluation complete for all models") + else: + logger.info("G-Eval disabled via metrics config, skipping") from cost_calculator import CostCalculator cost_calculator = CostCalculator() diff --git a/backend/python-eval-function/src/metrics.py b/backend/python-eval-function/src/metrics.py new file mode 100644 index 0000000..fad56c2 --- /dev/null +++ b/backend/python-eval-function/src/metrics.py @@ -0,0 +1,35 @@ +from typing import Dict, Mapping, Optional + +METRIC_KEYS: tuple[str, ...] = ( + # Algorithmic — summarization + 'bleu', + 'rouge', + 'meteor', + 'levenshtein', + 'bertscore', + # Algorithmic — classification + 'classification_accuracy', + 'precision_macro', + 'recall_macro', + 'f1_macro', + 'f1_weighted', + # LLM-as-judge + 'geval_reasoning', + 'geval_faithfulness', +) + + +def is_metric_enabled( + selected: Optional[Mapping[str, bool]], + key: str, +) -> bool: + if selected is None: + return True + return bool(selected.get(key, True)) + + +def normalize_metrics_config( + stored: Optional[Mapping[str, object]], +) -> Dict[str, bool]: + source: Mapping[str, object] = stored or {} + return {key: bool(source.get(key, True)) for key in METRIC_KEYS} diff --git a/backend/src/handlers/EvaluationLaunch/EvaluationLaunchAdapter.ts b/backend/src/handlers/EvaluationLaunch/EvaluationLaunchAdapter.ts index 283a89e..c130c7f 100644 --- a/backend/src/handlers/EvaluationLaunch/EvaluationLaunchAdapter.ts +++ b/backend/src/handlers/EvaluationLaunch/EvaluationLaunchAdapter.ts @@ -3,12 +3,21 @@ import type { APIGatewayProxyEventV2, APIGatewayProxyResultV2, } from 'aws-lambda'; -import { z } from 'zod'; +import { z, ZodRawShape } from 'zod'; +import { METRIC_KEYS } from '../../models/Evaluation'; import { tokenEvaluationLaunchUseCase } from '../../useCases/EvaluationLaunch/EvaluationLaunchUseCase'; import { handleHttpRequest } from '../api/handleHttpRequest'; import { parseApiEvent } from '../api/parseApiEvent'; +const MetricsConfigSchema = z + .object( + Object.fromEntries( + METRIC_KEYS.map((key) => [key, z.boolean().optional()]), + ) as ZodRawShape, + ) + .optional(); + const EvaluationRequestSchema = z.object({ dataset_id: z.string().min(1), models: z @@ -32,6 +41,7 @@ const EvaluationRequestSchema = z.object({ cost: z.number().optional(), }) .optional(), + metrics: MetricsConfigSchema, }); export class EvaluationLaunchAdapter { diff --git a/backend/src/models/Evaluation.ts b/backend/src/models/Evaluation.ts index 60e7d2e..92b4ab0 100644 --- a/backend/src/models/Evaluation.ts +++ b/backend/src/models/Evaluation.ts @@ -9,10 +9,49 @@ export interface WeightConfig { cost: number; } +export const METRIC_KEYS = [ + // Algorithmic — summarization + 'bleu', + 'rouge', + 'meteor', + 'levenshtein', + 'bertscore', + // Algorithmic — classification + 'classification_accuracy', + 'precision_macro', + 'recall_macro', + 'f1_macro', + 'f1_weighted', + // LLM-as-judge + 'geval_reasoning', + 'geval_faithfulness', +] as const; + +export type MetricKey = (typeof METRIC_KEYS)[number]; + +export type MetricsConfig = Record; + +export const DEFAULT_METRICS_CONFIG: MetricsConfig = Object.fromEntries( + METRIC_KEYS.map((k) => [k, true]), +) as MetricsConfig; + +export function resolveMetricsConfig( + metrics?: Partial, +): MetricsConfig { + const result = { ...DEFAULT_METRICS_CONFIG }; + if (!metrics) return result; + for (const key of METRIC_KEYS) { + const provided = metrics[key]; + if (provided !== undefined) result[key] = provided; + } + return result; +} + export interface EvaluationRequest { 
   dataset_id: string;
   models: ModelConfig[];
   weights?: Partial<WeightConfig>;
+  metrics?: Partial<MetricsConfig>;
 }
 
 export interface EvaluationJob {
@@ -20,6 +59,7 @@
   dataset_id: string;
   models: ModelConfig[];
   weights: WeightConfig;
+  metrics: MetricsConfig;
   status: JobStatus;
   progress: number;
   current_model?: string;
diff --git a/backend/src/services/EvaluationJobsRepository/EvaluationJobsRepository.ts b/backend/src/services/EvaluationJobsRepository/EvaluationJobsRepository.ts
index 446cc54..a911806 100644
--- a/backend/src/services/EvaluationJobsRepository/EvaluationJobsRepository.ts
+++ b/backend/src/services/EvaluationJobsRepository/EvaluationJobsRepository.ts
@@ -8,8 +8,10 @@ import {
 import { createInjectionToken, inject } from '@trackit.io/di-container';
 import { randomUUID } from 'crypto';
 import {
+  DEFAULT_METRICS_CONFIG,
   EvaluationJob,
   JobStatus,
+  MetricsConfig,
   ModelConfig,
   ModelResult,
   Recommendation,
@@ -21,6 +23,7 @@ export type EvaluationJobsRepository = {
     datasetId: string,
     models: ModelConfig[],
     weights: WeightConfig,
+    metrics: MetricsConfig,
   ): Promise<EvaluationJob>;
 
   updateEvaluation(
@@ -51,6 +54,7 @@ class EvaluationJobsRepositoryImpl implements EvaluationJobsRepository {
     datasetId: string,
     models: ModelConfig[],
     weights: WeightConfig,
+    metrics: MetricsConfig,
   ): Promise<EvaluationJob> {
     const evaluationId = randomUUID();
     const now = new Date().toISOString();
@@ -60,6 +64,7 @@ class EvaluationJobsRepositoryImpl implements EvaluationJobsRepository {
       dataset_id: datasetId,
       models,
       weights,
+      metrics,
       status: 'pending',
       progress: 0,
       created_at: now,
@@ -74,6 +79,7 @@ class EvaluationJobsRepositoryImpl implements EvaluationJobsRepository {
       dataset_id: { S: job.dataset_id },
       models: { S: JSON.stringify(job.models) },
       weights: { S: JSON.stringify(job.weights) },
+      metrics: { S: JSON.stringify(job.metrics) },
       status: { S: job.status },
       progress: { N: job.progress.toString() },
       created_at: { S: job.created_at },
@@ -176,6 +182,9 @@ class EvaluationJobsRepositoryImpl implements EvaluationJobsRepository {
       dataset_id: item['dataset_id'].S!,
       models: JSON.parse(item['models'].S!) as ModelConfig[],
       weights: JSON.parse(item['weights'].S!) as WeightConfig,
+      metrics: item['metrics']?.S
+        ? (JSON.parse(item['metrics'].S) as MetricsConfig)
+        : { ...DEFAULT_METRICS_CONFIG },
       status: item['status'].S! as JobStatus,
       progress: Number(item['progress'].N ??
'0'), current_model: item['current_model']?.S, diff --git a/backend/src/useCases/EvaluationLaunch/EvaluationLaunchUseCase.test.ts b/backend/src/useCases/EvaluationLaunch/EvaluationLaunchUseCase.test.ts index 045f104..056afe0 100644 --- a/backend/src/useCases/EvaluationLaunch/EvaluationLaunchUseCase.test.ts +++ b/backend/src/useCases/EvaluationLaunch/EvaluationLaunchUseCase.test.ts @@ -1,8 +1,13 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; -import type { EvaluationRequest } from '../../models/Evaluation.js'; +import { + DEFAULT_METRICS_CONFIG, + type EvaluationRequest, +} from '../../models/Evaluation.js'; import { FakeBedrockModelValidationService } from '../../services/BedrockModelValidationService/FakeBedrockModelValidationService.js'; import { FakeEvaluationLaunchUseCase } from './FakeEvaluationLaunchUseCase'; +const ALL_METRICS_ENABLED = { ...DEFAULT_METRICS_CONFIG }; + // Mock dependencies const mockEvaluationJobsRepository = { createEvaluation: vi.fn(), @@ -64,6 +69,7 @@ describe('EvaluationLaunchUseCase - Weight Configuration', () => { latency: 0.33, cost: 0.34, }, + ALL_METRICS_ENABLED, ); }); @@ -91,10 +97,81 @@ describe('EvaluationLaunchUseCase - Weight Configuration', () => { latency: 0.33, cost: 0.34, }, + ALL_METRICS_ENABLED, ); }); }); + describe('Metrics toggles', () => { + it('should default to all metrics enabled when metrics not provided', async () => { + const request: EvaluationRequest = { + dataset_id: 'test-dataset-id', + models: [{ type: 'default', identifier: 'claude-sonnet' }], + }; + + await useCase.launchEvaluation(request); + + const metricsArg = + mockEvaluationJobsRepository.createEvaluation.mock.calls[0][3]; + expect(metricsArg).toEqual(ALL_METRICS_ENABLED); + }); + + it('should respect explicitly disabled metrics', async () => { + const request: EvaluationRequest = { + dataset_id: 'test-dataset-id', + models: [{ type: 'default', identifier: 'claude-sonnet' }], + metrics: { + bertscore: false, + geval_reasoning: false, + geval_faithfulness: false, + }, + }; + + await useCase.launchEvaluation(request); + + const metricsArg = + mockEvaluationJobsRepository.createEvaluation.mock.calls[0][3]; + expect(metricsArg).toEqual({ + ...ALL_METRICS_ENABLED, + bertscore: false, + geval_reasoning: false, + geval_faithfulness: false, + }); + }); + + it('should fill in defaults for partially provided metrics', async () => { + const request: EvaluationRequest = { + dataset_id: 'test-dataset-id', + models: [{ type: 'default', identifier: 'claude-sonnet' }], + metrics: { geval_reasoning: false, geval_faithfulness: false }, + }; + + await useCase.launchEvaluation(request); + + const metricsArg = + mockEvaluationJobsRepository.createEvaluation.mock.calls[0][3]; + expect(metricsArg).toEqual({ + ...ALL_METRICS_ENABLED, + geval_reasoning: false, + geval_faithfulness: false, + }); + }); + + it('should support disabling a single algorithmic metric', async () => { + const request: EvaluationRequest = { + dataset_id: 'test-dataset-id', + models: [{ type: 'default', identifier: 'claude-sonnet' }], + metrics: { bleu: false }, + }; + + await useCase.launchEvaluation(request); + + const metricsArg = + mockEvaluationJobsRepository.createEvaluation.mock.calls[0][3]; + expect(metricsArg).toEqual({ ...ALL_METRICS_ENABLED, bleu: false }); + }); + }); + describe('Negative weight rejection', () => { it('should reject negative accuracy weight', async () => { const request: EvaluationRequest = { @@ -216,6 +293,7 @@ describe('EvaluationLaunchUseCase - Weight Configuration', () => 
{ latency: 0.33, cost: 0.34, }, + ALL_METRICS_ENABLED, ); }); @@ -349,11 +427,16 @@ describe('EvaluationLaunchUseCase - Weight Configuration', () => { expect( mockEvaluationJobsRepository.createEvaluation, - ).toHaveBeenCalledWith('test-dataset-id', request.models, { - accuracy: 0.33, - latency: 0.33, - cost: 0.34, - }); + ).toHaveBeenCalledWith( + 'test-dataset-id', + request.models, + { + accuracy: 0.33, + latency: 0.33, + cost: 0.34, + }, + ALL_METRICS_ENABLED, + ); }); it('should accept a mix of default and custom models', async () => { @@ -386,6 +469,7 @@ describe('EvaluationLaunchUseCase - Weight Configuration', () => { latency: 0.33, cost: 0.34, }, + ALL_METRICS_ENABLED, ); }); diff --git a/backend/src/useCases/EvaluationLaunch/EvaluationLaunchUseCase.ts b/backend/src/useCases/EvaluationLaunch/EvaluationLaunchUseCase.ts index b715277..79ea6b6 100644 --- a/backend/src/useCases/EvaluationLaunch/EvaluationLaunchUseCase.ts +++ b/backend/src/useCases/EvaluationLaunch/EvaluationLaunchUseCase.ts @@ -5,6 +5,7 @@ import { EvaluationJob, EvaluationRequest, ModelConfig, + resolveMetricsConfig, WeightConfig, } from '../../models/Evaluation'; import { tokenBedrockModelValidationService } from '../../services/BedrockModelValidationService/BedrockModelValidationService'; @@ -33,11 +34,13 @@ class EvaluationLaunchUseCaseImpl implements EvaluationLaunchUseCase { ); const normalizedWeights = this.normalizeWeights(request.weights); + const metrics = resolveMetricsConfig(request.metrics); const job = await this.evaluationJobsRepository.createEvaluation( request.dataset_id, modelsToPersist, normalizedWeights, + metrics, ); await this.fargateService.launchTask(job.evaluation_id); diff --git a/backend/src/useCases/EvaluationLaunch/FakeEvaluationLaunchUseCase.ts b/backend/src/useCases/EvaluationLaunch/FakeEvaluationLaunchUseCase.ts index 60865a0..2b57f1a 100644 --- a/backend/src/useCases/EvaluationLaunch/FakeEvaluationLaunchUseCase.ts +++ b/backend/src/useCases/EvaluationLaunch/FakeEvaluationLaunchUseCase.ts @@ -1,4 +1,5 @@ -import type { +import { + resolveMetricsConfig, EvaluationJob, EvaluationRequest, ModelConfig, @@ -40,11 +41,13 @@ export class FakeEvaluationLaunchUseCase implements EvaluationLaunchUseCase { ); const normalizedWeights = this.normalizeWeights(request.weights); + const metrics = resolveMetricsConfig(request.metrics); const job = await this.evaluationJobsRepository.createEvaluation( request.dataset_id, modelsToPersist, normalizedWeights, + metrics, ); await this.fargateService.launchTask(job.evaluation_id); diff --git a/frontend/README.md b/frontend/README.md index bfb7321..172796e 100644 --- a/frontend/README.md +++ b/frontend/README.md @@ -26,7 +26,8 @@ frontend/ │ │ ├── evaluator/ # Step components for the evaluation workflow │ │ │ ├── MetricsWeights.tsx # Step 1 — tune accuracy / latency / cost weights │ │ │ ├── ModelSelection.tsx # Step 2 — pick Bedrock models to evaluate -│ │ │ ├── DatasetUpload.tsx # Step 3 — upload CSV or JSONL dataset +│ │ │ ├── DatasetUpload.tsx # Step 3 — upload dataset and pick metrics to compute +│ │ │ ├── MetricsPicker.tsx # Task-aware picker for which metrics to compute │ │ │ ├── ProgressView.tsx # Step 4a — real-time job progress │ │ │ ├── ResultsView.tsx # Step 4b — ranked results and recommendation │ │ │ └── StepIndicator.tsx # Navigation breadcrumb for the stepper diff --git a/frontend/src/components/evaluator/DatasetUpload.tsx b/frontend/src/components/evaluator/DatasetUpload.tsx index 72cb44f..c676b63 100644 --- 
a/frontend/src/components/evaluator/DatasetUpload.tsx
+++ b/frontend/src/components/evaluator/DatasetUpload.tsx
@@ -1,6 +1,11 @@
+import {
+  MetricsPicker,
+  type MetricsPickerTaskType,
+} from '@/components/evaluator/MetricsPicker';
 import { Button } from '@/components/ui/button';
 import { useUploadDataset } from '@/hooks/useEvaluation';
 import { cn } from '@/lib/utils';
+import type { MetricsToggles } from '@/types/evaluation';
 import { motion } from 'framer-motion';
 import {
   AlertCircle,
@@ -19,6 +24,8 @@ interface DatasetUploadProps {
   onStartEvaluation: () => void;
   onUploadSuccess: (data: { dataset_id: string; sample_count: number }) => void;
   isStarting?: boolean;
+  metrics: MetricsToggles;
+  onMetricsChange: (metrics: MetricsToggles) => void;
 }
 
 type TaskType = 'summarization' | 'classification';
@@ -76,12 +83,23 @@ const TASK_TYPES: {
 type FormatTab = 'csv' | 'jsonl';
 
+function resolveDetectedTask(data: {
+  has_summary: boolean;
+  has_class: boolean;
+}): MetricsPickerTaskType | undefined {
+  if (data.has_summary) return 'summarization';
+  if (data.has_class) return 'classification';
+  return undefined;
+}
+
 export function DatasetUpload({
   file,
   onChange,
   onStartEvaluation,
   onUploadSuccess,
   isStarting = false,
+  metrics,
+  onMetricsChange,
 }: DatasetUploadProps) {
   const [dragOver, setDragOver] = useState(false);
   const [activeTask, setActiveTask] = useState('summarization');
@@ -344,6 +362,16 @@ export function DatasetUpload({
         )}
 
+        {uploadMutation.isSuccess && (
[The remaining added lines of this hunk, and the hunk for the new file frontend/src/components/evaluator/MetricsPicker.tsx, were lost in extraction. The surviving fragments show DatasetUpload rendering the new <MetricsPicker /> once an upload succeeds, and the picker looping over groups of metrics and reading `metrics[metric.key]` to toggle each one; a hedged sketch of the component follows.]
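Because the MetricsPicker.tsx hunk itself cannot be recovered, the following is a minimal sketch of the component's likely shape, inferred from the surviving render fragment, the props DatasetUpload passes (`metrics`, `onMetricsChange`), and the `MetricKey`/`MetricsToggles` types added to `frontend/src/types/evaluation.ts`. The `GROUPS` constant, the group labels, the `detectedTask` prop name, and all markup are assumptions, not the author's code:

```tsx
// MetricsPicker.tsx (sketch only; grouping, labels, and markup are assumed)
import type { MetricKey, MetricsToggles } from '@/types/evaluation';

export type MetricsPickerTaskType = 'summarization' | 'classification';

interface MetricsPickerProps {
  metrics: MetricsToggles;
  onMetricsChange: (metrics: MetricsToggles) => void;
  detectedTask?: MetricsPickerTaskType; // assumed prop name
}

// Hypothetical grouping; the real component may group or label metrics differently.
const GROUPS: { label: string; metrics: { key: MetricKey; label: string }[] }[] = [
  {
    label: 'Summarization (algorithmic)',
    metrics: [
      { key: 'bleu', label: 'BLEU' },
      { key: 'rouge', label: 'ROUGE' },
      { key: 'meteor', label: 'METEOR' },
      { key: 'levenshtein', label: 'Levenshtein' },
      { key: 'bertscore', label: 'BERTScore' },
    ],
  },
  {
    label: 'Classification (algorithmic)',
    metrics: [
      { key: 'classification_accuracy', label: 'Accuracy' },
      { key: 'precision_macro', label: 'Precision (macro)' },
      { key: 'recall_macro', label: 'Recall (macro)' },
      { key: 'f1_macro', label: 'F1 (macro)' },
      { key: 'f1_weighted', label: 'F1 (weighted)' },
    ],
  },
  {
    label: 'LLM-as-judge',
    metrics: [
      { key: 'geval_reasoning', label: 'G-Eval reasoning' },
      { key: 'geval_faithfulness', label: 'G-Eval faithfulness' },
    ],
  },
];

export function MetricsPicker({
  metrics,
  onMetricsChange,
  detectedTask,
}: MetricsPickerProps) {
  // Flip one key while leaving every other toggle untouched.
  const toggle = (key: MetricKey) =>
    onMetricsChange({ ...metrics, [key]: !metrics[key] });

  return (
    <div>
      {detectedTask && <p>Detected task: {detectedTask}</p>}
      {GROUPS.map((group) => (
        <section key={group.label}>
          <h4>{group.label}</h4>
          {group.metrics.map((metric) => {
            const enabled = metrics[metric.key];
            return (
              <button
                key={metric.key}
                type="button"
                aria-pressed={enabled}
                onClick={() => toggle(metric.key)}
              >
                {metric.label}
              </button>
            );
          })}
        </section>
      ))}
    </div>
  );
}
```

DatasetUpload presumably feeds `resolveDetectedTask(uploadMutation.data)` into the task prop so the picker can emphasize the metric group matching the uploaded dataset, which would explain why that helper was added.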
diff --git a/frontend/src/pages/Index.tsx b/frontend/src/pages/Index.tsx
index bf70b23..e02af82 100644
--- a/frontend/src/pages/Index.tsx
+++ b/frontend/src/pages/Index.tsx
@@ -11,6 +11,7 @@ import {
   useEvaluationStatus,
 } from '@/hooks/useEvaluation';
 import type { EvaluationConfig } from '@/types/evaluation';
+import { DEFAULT_METRICS_TOGGLES } from '@/types/evaluation';
 import { AlertCircle, ArrowLeft, ArrowRight, RotateCcw } from 'lucide-react';
 import { useCallback, useEffect, useState } from 'react';
 
@@ -21,6 +22,7 @@ export default function Index() {
   const [step, setStep] = useState(0);
   const [config, setConfig] = useState({
     weights: { accuracy: 40, cost: 30, latency: 30 },
+    metrics: { ...DEFAULT_METRICS_TOGGLES },
     selectedModels: [],
     datasetFile: null,
   });
@@ -74,6 +76,7 @@ export default function Index() {
         latency: config.weights.latency / 100,
         cost: config.weights.cost / 100,
       },
+      metrics: config.metrics,
     },
     {
       onSuccess: (data) => {
@@ -90,6 +93,7 @@ export default function Index() {
     datasetId,
     config.selectedModels,
     config.weights,
+    config.metrics,
     createEvaluationMutation,
   ]);
 
@@ -105,6 +109,7 @@ export default function Index() {
     setStep(0);
     setConfig({
       weights: { accuracy: 40, cost: 30, latency: 30 },
+      metrics: { ...DEFAULT_METRICS_TOGGLES },
      selectedModels: [],
       datasetFile: null,
     });
@@ -153,6 +158,10 @@ export default function Index() {
             onStartEvaluation={handleStartEvaluation}
             onUploadSuccess={handleUploadSuccess}
             isStarting={createEvaluationMutation.isPending}
+            metrics={config.metrics}
+            onMetricsChange={(m) =>
+              setConfig({ ...config, metrics: m })
+            }
           />
         )}
diff --git a/frontend/src/types/evaluation.ts b/frontend/src/types/evaluation.ts
index 996d431..45c4833 100644
--- a/frontend/src/types/evaluation.ts
+++ b/frontend/src/types/evaluation.ts
@@ -4,6 +4,32 @@ export interface MetricsWeights {
   latency: number;
 }
 
+export const METRIC_KEYS = [
+  // Algorithmic — summarization
+  'bleu',
+  'rouge',
+  'meteor',
+  'levenshtein',
+  'bertscore',
+  // Algorithmic — classification
+  'classification_accuracy',
+  'precision_macro',
+  'recall_macro',
+  'f1_macro',
+  'f1_weighted',
+  // LLM-as-judge
+  'geval_reasoning',
+  'geval_faithfulness',
+] as const;
+
+export type MetricKey = (typeof METRIC_KEYS)[number];
+
+export type MetricsToggles = Record<MetricKey, boolean>;
+
+export const DEFAULT_METRICS_TOGGLES: MetricsToggles = Object.fromEntries(
+  METRIC_KEYS.map((k) => [k, true]),
+) as MetricsToggles;
+
 export interface ModelOption {
   id: string;
   name: string;
@@ -14,6 +40,7 @@ export interface ModelOption {
 
 export interface EvaluationConfig {
   weights: MetricsWeights;
+  metrics: MetricsToggles;
   selectedModels: string[];
   datasetFile: File | null;
 }
@@ -37,21 +64,21 @@ export interface EvaluationResult {
 export const AVAILABLE_MODELS: ModelOption[] = [
   // Amazon Nova
   {
-    id: 'us.amazon.nova-pro-v1:0',
+    id: 'amazon.nova-pro-v1:0',
     name: 'Nova Pro',
     provider: 'Amazon',
     contextWindow: '300K',
     costPer1kTokens: 0.0008,
   },
   {
-    id: 'us.amazon.nova-lite-v1:0',
+    id: 'amazon.nova-lite-v1:0',
     name: 'Nova Lite',
     provider: 'Amazon',
     contextWindow: '300K',
     costPer1kTokens: 0.00006,
   },
   {
-    id: 'us.amazon.nova-micro-v1:0',
+    id: 'amazon.nova-micro-v1:0',
     name: 'Nova Micro',
     provider: 'Amazon',
     contextWindow: '128K',
@@ -59,21 +86,21 @@ export const AVAILABLE_MODELS: ModelOption[] = [
   },
   // Anthropic
   {
-    id: 'us.anthropic.claude-opus-4-6-v1',
+    id: 'anthropic.claude-opus-4-6-v1',
     name: 'Claude Opus 4.6',
     provider: 'Anthropic',
     contextWindow: '200K',
     costPer1kTokens: 0.005,
   },
   {
-    id:
'us.anthropic.claude-sonnet-4-5-20250929-v1:0', + id: 'anthropic.claude-sonnet-4-5-20250929-v1:0', name: 'Claude Sonnet 4.5', provider: 'Anthropic', contextWindow: '200K', costPer1kTokens: 0.003, }, { - id: 'us.anthropic.claude-haiku-4-5-20251001-v1:0', + id: 'anthropic.claude-haiku-4-5-20251001-v1:0', name: 'Claude Haiku 4.5', provider: 'Anthropic', contextWindow: '200K', @@ -85,6 +112,7 @@ export interface CreateEvaluationRequest { dataset_id: string; models: { type: 'default' | 'custom'; identifier: string }[]; weights: { accuracy: number; latency: number; cost: number }; + metrics?: MetricsToggles; } export interface DatasetUploadData {