Skip to content

Commit 8313d65

Browse files
committed
Add per_session_context support for golden eval
Thread per_session_context through quality_report.py into classify_sessions_via_api() so golden eval expected answers can be injected into the judge prompt per session.
1 parent c823672 commit 8313d65

2 files changed

Lines changed: 15 additions & 1 deletion

File tree

scripts/quality_report.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,7 @@ def run_evaluation_from_conversations(
11581158
concurrency=10,
11591159
tag_turns=False,
11601160
eval_config=None,
1161+
per_session_context=None,
11611162
):
11621163
"""Evaluate local conversations without BigQuery.
11631164
@@ -1173,6 +1174,8 @@ def run_evaluation_from_conversations(
11731174
concurrency: Max parallel API calls (default 10).
11741175
tag_turns: When True, run the full turn tagger to classify each user
11751176
turn and identify correction boundaries / sub-trajectories.
1177+
per_session_context: Optional dict mapping session_id to additional
1178+
context string for the judge prompt (e.g. matched golden eval).
11761179
11771180
Returns:
11781181
Dict with ``report`` (CategoricalEvaluationReport) and
@@ -1213,6 +1216,7 @@ def run_evaluation_from_conversations(
12131216
async def _run_all():
12141217
classify_task = classify_sessions_via_api(
12151218
transcripts, cat_config, model,
1219+
per_session_context=per_session_context,
12161220
)
12171221
resolve_task = _build_resolved_map_from_conversations(
12181222
conversations, model, concurrency=concurrency,
@@ -1238,6 +1242,7 @@ def generate_quality_report_from_conversations(
12381242
concurrency=10,
12391243
tag_turns=False,
12401244
trajectory_samples=0,
1245+
per_session_context=None,
12411246
) -> dict:
12421247
"""Evaluate local conversations and return a structured quality report.
12431248
@@ -1253,6 +1258,8 @@ def generate_quality_report_from_conversations(
12531258
tag_turns: When True, run the full turn tagger to add per-turn tags,
12541259
correction boundaries, and sub-trajectories to the output.
12551260
trajectory_samples: Number of execution traces to fetch from BigQuery.
1261+
per_session_context: Optional dict mapping session_id to additional
1262+
context string for the judge prompt (e.g. matched golden eval).
12561263
12571264
Returns:
12581265
Dict with ``summary`` and ``sessions`` keys.
@@ -1263,6 +1270,7 @@ def generate_quality_report_from_conversations(
12631270
result = run_evaluation_from_conversations(
12641271
conversations, model=model, config_path=config_path,
12651272
concurrency=concurrency, tag_turns=tag_turns,
1273+
per_session_context=per_session_context,
12661274
)
12671275
elapsed = time.time() - t0
12681276

src/bigquery_agent_analytics/categorical_evaluator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,7 @@ async def classify_sessions_via_api(
832832
transcripts: dict[str, str],
833833
config: CategoricalEvaluationConfig,
834834
endpoint: str = DEFAULT_ENDPOINT,
835+
per_session_context: dict[str, str] | None = None,
835836
) -> list[CategoricalSessionResult]:
836837
"""Classifies sessions using the Gemini API (fallback).
837838
@@ -843,6 +844,8 @@ async def classify_sessions_via_api(
843844
transcripts: Maps ``session_id`` to transcript text.
844845
config: Categorical evaluation configuration.
845846
endpoint: Model endpoint name.
847+
per_session_context: Optional per-session context to inject into the
848+
judge prompt (e.g. matched golden eval expected answers).
846849
847850
Returns:
848851
One ``CategoricalSessionResult`` per session.
@@ -861,7 +864,10 @@ async def classify_sessions_via_api(
861864
if len(text) > 25000:
862865
text = text[:25000] + "\n... [truncated]"
863866

864-
full_prompt = prompt_prefix + "\n\nTranscript:\n" + text
867+
session_ctx = ""
868+
if per_session_context and sid in per_session_context:
869+
session_ctx = "\n\n" + per_session_context[sid]
870+
full_prompt = prompt_prefix + session_ctx + "\n\nTranscript:\n" + text
865871

866872
try:
867873
response = await client.aio.models.generate_content(

0 commit comments

Comments
 (0)