-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsql_metric.py
More file actions
47 lines (40 loc) · 1.82 KB
/
sql_metric.py
File metadata and controls
47 lines (40 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""SQL evaluation metrics."""
import sqlite3
from collections import Counter

import dspy

from db import create_db
from query_timeout import execute_query_with_timeout
# Per-query execution timeout, applied to both the gold and the predicted SQL.
_QUERY_TIMEOUT_SECONDS = 10.0


def _result_multiset(rows):
    """Normalize query result rows into an order-insensitive multiset.

    Each row becomes a frozenset of ``(value, count)`` pairs, so column order
    is ignored but repeated values inside a row still count (``(1, 1)`` does
    not equal ``(1,)``). The normalized rows are tallied in a ``Counter`` so
    row order is ignored but duplicate rows still count.

    Args:
        rows: Iterable of rows, each an iterable of hashable cell values
            (assumes plain sqlite3 values such as int/float/str/bytes/None).

    Returns:
        A ``Counter`` mapping each normalized row to its multiplicity.
    """
    return Counter(frozenset(Counter(row).items()) for row in rows)


def sql_correctness_metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    """Score a predicted SQL query by executing it and comparing with the gold SQL.

    Both queries run against the database returned by ``create_db()``, each
    with a 10-second timeout. Results match when they contain the same rows
    regardless of row order and column order; duplicate rows and repeated
    values within a row must still agree.

    Args:
        example: Reference example; ``example.sql_query`` is the gold SQL.
        prediction: Model output; ``prediction.sql_query`` is the SQL to grade.
        trace, pred_name, pred_trace: Unused here; accepted so the function
            fits the metric call signature its caller expects (presumably
            DSPy evaluation/optimizers — confirm).

    Returns:
        ``dspy.Prediction`` with ``score`` (1.0 on match, 0.0 otherwise) and a
        human-readable ``feedback`` string explaining the score.

    Raises:
        AssertionError: If the gold SQL fails to execute. This signals a
            dataset problem rather than a model error.
    """
    pred_sql = getattr(prediction, "sql_query", None)
    # An empty statement executes to zero rows, which would spuriously "match"
    # any gold query with an empty result, so treat it as a missing query.
    if pred_sql is None or (isinstance(pred_sql, str) and not pred_sql.strip()):
        return dspy.Prediction(score=0.0, feedback="No SQL query")

    conn = create_db()
    try:
        # Run the gold query first: the predicted SQL is arbitrary model output
        # executed on this same connection and could modify the data
        # (INSERT/UPDATE/DELETE/DROP), corrupting the reference result.
        gold_results, gold_error = execute_query_with_timeout(
            conn, example.sql_query, timeout_seconds=_QUERY_TIMEOUT_SECONDS
        )
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if gold_results is None:
            raise AssertionError(f"Gold SQL failed: {gold_error}")

        pred_results, pred_error = execute_query_with_timeout(
            conn, pred_sql, timeout_seconds=_QUERY_TIMEOUT_SECONDS
        )
        if pred_results is None:
            return dspy.Prediction(
                score=0.0, feedback=f"Predicted SQL failed: {pred_error}"
            )

        if _result_multiset(pred_results) == _result_multiset(gold_results):
            return dspy.Prediction(score=1.0, feedback="Match")

        # Show a few rows of each side to make the mismatch actionable.
        feedback = (
            f"Different results | "
            f"Pred SQL: {pred_sql} | "
            f"Gold SQL: {example.sql_query} | "
            f"Pred rows ({len(pred_results)}): {pred_results[:3]} | "
            f"Gold rows ({len(gold_results)}): {gold_results[:3]}"
        )
        return dspy.Prediction(score=0.0, feedback=feedback)
    finally:
        conn.close()