Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion bird/llm/src/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,35 @@ def execute_sql(predicted_sql,ground_truth, db_path):
cursor.execute(ground_truth)
ground_truth_res = cursor.fetchall()
res = 0
if set(predicted_res) == set(ground_truth_res):


# if set(predicted_res) == set(ground_truth_res):
# res = 1

"""
Instead of using set(predicted_res) == set(ground_truth_res), we compare the normalized and sorted results.
This approach is better because:
- It preserves duplicate rows: set() removes duplicates, but SQL result sets may contain duplicate rows that are semantically important.
- It handles mixed data types: By converting all elements to strings, we avoid type comparison issues (e.g., comparing int and str).
- It is order-insensitive within tuples: Sorting each tuple's elements allows for comparison even if the order of columns differs, which is useful when the predicted query selects the same columns in a different order.
- It is order-insensitive for rows: Sorting the list of tuples means the order of rows does not affect the comparison.

Example:
Consider the query:
SELECT name, age FROM users;
and a predicted query that selects the same columns in a different order:
SELECT age, name FROM users;
The column order is swapped, but the data is the same. A plain set() comparison treats these rows as different, because tuples are compared position by position,
but normalizing as done here allows us to compare the content regardless of column order, which is useful for
queries where column order is not semantically important.
"""
if len(predicted_res) != len(ground_truth_res):
res = 0
# Normalize each tuple by sorting its elements, converting them to strings
normalized_a = sorted([tuple(sorted(map(str, t))) for t in predicted_res])
normalized_b = sorted([tuple(sorted(map(str, t))) for t in ground_truth_res])

if normalized_a == normalized_b:
res = 1
return res

Expand Down
27 changes: 27 additions & 0 deletions bird/llm/src/evaluation_ves.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,34 @@ def iterated_execute_sql(predicted_sql,ground_truth,db_path,iterate_num):
cursor.execute(ground_truth)
ground_truth_res = cursor.fetchall()
time_ratio = 0
"""
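Instead of using set(predicted_res) == set(ground_truth_res), the intent is to compare the normalized and sorted results (note: that comparison is currently commented out below, and the set-based check is still the one in effect).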
This approach is better because:
- It preserves duplicate rows: set() removes duplicates, but SQL result sets may contain duplicate rows that are semantically important.
- It handles mixed data types: By converting all elements to strings, we avoid type comparison issues (e.g., comparing int and str).
- It is order-insensitive within tuples: Sorting each tuple's elements allows for comparison even if the order of columns differs, which is useful when the predicted query selects the same columns in a different order.
- It is order-insensitive for rows: Sorting the list of tuples means the order of rows does not affect the comparison.

Example:
Consider the query:
SELECT name, age FROM users;
and a predicted query that selects the same columns in a different order:
SELECT age, name FROM users;
The column order is swapped, but the data is the same. A plain set() comparison treats these rows as different, because tuples are compared position by position,
but normalizing as described here allows us to compare the content regardless of column order, which is useful for
queries where column order is not semantically important.
"""
if set(predicted_res) == set(ground_truth_res):

# if len(predicted_res) != len(ground_truth_res):
# return 0
# # Normalize each tuple by sorting its elements, converting them to strings
# # to handle mixed data types (e.g., strings and integers).
# normalized_predicted = sorted([tuple(sorted(map(str, t))) for t in predicted_res])
# normalized_ground_truth = sorted([tuple(sorted(map(str, t))) for t in ground_truth_res])

# if normalized_predicted == normalized_ground_truth:

for i in range(iterate_num):
predicted_time = execute_sql(predicted_sql, db_path)
ground_truth_time = execute_sql(ground_truth, db_path)
Expand Down