Skip to content

Commit f003e25

Browse files
committed
feat: enhance JSON handling for NumPy and pandas types in test utilities and query results
1 parent cb3a3b9 commit f003e25

2 files changed

Lines changed: 36 additions & 4 deletions

File tree

src/vfbquery/test_utils.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,36 @@
11
import pandas as pd
2+
import json
3+
import numpy as np
24
from typing import Any, Dict, Union
35

6+
# Custom JSON encoder to handle NumPy and pandas types
7+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy / pandas values to native Python types.

    Pass it as ``cls=NumpyEncoder`` to ``json.dump`` / ``json.dumps``.
    """

    def default(self, obj):
        """Return a JSON-serialisable equivalent of *obj*.

        Conversions performed:
          * ``np.integer``  -> ``int``
          * ``np.floating`` -> ``float``
          * ``np.ndarray``  -> ``list`` (via ``tolist()``)
          * ``np.bool_``    -> ``bool``
          * any other object with a callable ``item()`` -> ``obj.item()``;
            if ``item()`` raises ``ValueError`` and the object also has a
            callable ``tolist()``, ``obj.tolist()`` is returned instead.

        Anything else is delegated to ``json.JSONEncoder.default``, which
        raises ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.bool_):
            return bool(obj)

        # Duck-typed fallback for scalar-like objects exposing item().
        item = getattr(obj, 'item', None)
        if callable(item):
            try:
                return item()
            except ValueError:
                # item() is only defined for exactly one element; containers
                # with more (or zero) elements are emitted as a list instead
                # of leaking ValueError out of json.dumps.
                to_list = getattr(obj, 'tolist', None)
                if callable(to_list):
                    return to_list()
                raise

        return super().default(obj)
20+
21+
def safe_to_dict(df):
    """Convert a DataFrame to a list of row dicts holding native Python scalars.

    Integer (signed, unsigned and nullable), floating-point and boolean
    columns are cast to ``object`` dtype before ``to_dict("records")`` is
    called, so their values are stored as plain Python ``int`` / ``float`` /
    ``bool`` objects rather than NumPy scalars. Columns of any other dtype
    are passed through unchanged.

    The input DataFrame is not modified. If *df* is not a ``pd.DataFrame``
    it is returned as-is.
    """
    if not isinstance(df, pd.DataFrame):
        return df

    dtype_checks = (
        pd.api.types.is_integer_dtype,
        pd.api.types.is_float_dtype,
        pd.api.types.is_bool_dtype,
    )
    # Dtype predicates (rather than matching on the dtype name prefix) also
    # cover 'uint*', nullable 'Int64' / 'Float64' and 'bool' columns.
    to_object = {
        col: object
        for col, dtype in df.dtypes.items()
        if any(check(dtype) for check in dtype_checks)
    }
    # astype() returns a new DataFrame, so no explicit copy is needed.
    return df.astype(to_object).to_dict("records")
33+
434
def safe_extract_row(result: Any, index: int = 0) -> Dict:
535
"""
636
Safely extract a row from a pandas DataFrame or return the object itself if not a DataFrame.
@@ -11,7 +41,9 @@ def safe_extract_row(result: Any, index: int = 0) -> Dict:
1141
"""
1242
if isinstance(result, pd.DataFrame):
1343
if not result.empty and len(result.index) > index:
14-
return result.iloc[index].to_dict()
44+
# Convert to dict using safe method to handle numpy types
45+
row_series = result.iloc[index]
46+
return {col: (val.item() if hasattr(val, 'item') else val) for col, val in row_series.items()}
1547
else:
1648
return {}
1749
return result
@@ -28,8 +60,8 @@ def patch_vfb_connect_query_wrapper():
2860
def patched_get_term_info(self, terms, *args, **kwargs):
2961
result = original_get_term_info(self, terms, *args, **kwargs)
3062
if isinstance(result, pd.DataFrame):
31-
# Return list of row dictionaries instead of DataFrame
32-
return [row.to_dict() for i, row in result.iterrows()]
63+
# Return list of row dictionaries instead of DataFrame using safe conversion
64+
return safe_to_dict(result)
3365
return result
3466

3567
NeoQueryWrapper._get_TermInfo = patched_get_term_info

src/vfbquery/vfb_queries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1347,7 +1347,7 @@ def fill_query_results(term_info):
13471347
filtered_item = item
13481348
filtered_result.append(filtered_item)
13491349
elif isinstance(result, pd.DataFrame):
1350-
filtered_result = result[query['preview_columns']].to_dict('records')
1350+
filtered_result = safe_to_dict(result[query['preview_columns']])
13511351
else:
13521352
print(f"Unsupported result format for filtering columns in {query['function']}")
13531353

0 commit comments

Comments
 (0)