Skip to content

Commit cb3a3b9

Browse files
committed
feat: add custom JSON encoder for NumPy and pandas types and update serialization methods
1 parent 9f09554 commit cb3a3b9

2 files changed

Lines changed: 59 additions & 9 deletions

File tree

src/vfbquery/term_info_queries.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
import re
22
import json
3+
import numpy as np
4+
5+
# Custom JSON encoder to handle NumPy and pandas types
6+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy values into built-in Python values.

    Pass it as ``cls=NumpyEncoder`` to ``json.dumps``; the standard encoder
    then calls :meth:`default` for every object it cannot serialize itself.
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for ``obj``.

        NumPy integers, floats and booleans are converted with ``int``,
        ``float`` and ``bool``; arrays are converted with ``tolist()``.
        Any other object exposing an ``item`` attribute is converted by
        calling ``obj.item()``. Everything else is delegated to the base
        class implementation.
        """
        # Ordered (type, converter) pairs; the first matching type wins.
        conversions = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, lambda array: array.tolist()),
            (np.bool_, bool),
        )
        for numpy_type, convert in conversions:
            if isinstance(obj, numpy_type):
                return convert(obj)

        # Duck-typed fallback for remaining scalar-like objects.
        if hasattr(obj, 'item'):
            return obj.item()

        return super().default(obj)
319
import requests
420
from dataclasses import dataclass
521
from dataclasses_json import dataclass_json
@@ -15,7 +31,7 @@ class Coordinates:
1531
Z: float
1632

1733
def __str__(self):
18-
return json.dumps([str(self.X), str(self.Y), str(self.Z)])
34+
return json.dumps([str(self.X), str(self.Y), str(self.Z)], cls=NumpyEncoder)
1935

2036

2137
class CoordinatesFactory:
@@ -1062,7 +1078,7 @@ def serialize_term_info_to_json(vfb_term: VfbTerminfo, show_types=False) -> str:
10621078
:return: json string representation of the term info object
10631079
"""
10641080
term_info_dict = serialize_term_info_to_dict(vfb_term, show_types)
1065-
return json.dumps(term_info_dict, indent=4)
1081+
return json.dumps(term_info_dict, indent=4, cls=NumpyEncoder)
10661082

10671083

10681084
def process(term_info_response: dict, variable, loaded_template: Optional[str] = None, show_types=False) -> dict:

src/vfbquery/vfb_queries.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,35 @@
88
import pandas as pd
99
from marshmallow import ValidationError
1010
import json
11+
import numpy as np
12+
13+
# Custom JSON encoder to handle NumPy and pandas types
14+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy values into built-in Python values.

    Pass it as ``cls=NumpyEncoder`` to ``json.dumps``; the standard encoder
    then calls :meth:`default` for every object it cannot serialize itself.
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for ``obj``.

        NumPy integers, floats and booleans are converted with ``int``,
        ``float`` and ``bool``; arrays are converted with ``tolist()``.
        Any other object exposing an ``item`` attribute is converted by
        calling ``obj.item()``. Everything else is delegated to the base
        class implementation.
        """
        # Ordered (type, converter) pairs; the first matching type wins.
        conversions = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, lambda array: array.tolist()),
            (np.bool_, bool),
        )
        for numpy_type, convert in conversions:
            if isinstance(obj, numpy_type):
                return convert(obj)

        # Duck-typed fallback for remaining scalar-like objects.
        if hasattr(obj, 'item'):
            return obj.item()

        return super().default(obj)
27+
28+
def safe_to_dict(df):
    """Convert a DataFrame to a list of row dicts with native Python values.

    Integer, floating-point and boolean columns are cast to ``object`` dtype
    before ``to_dict("records")`` is called, so the resulting rows hold plain
    Python ``int`` / ``float`` / ``bool`` values rather than NumPy scalars.

    Args:
        df: The DataFrame to convert. Any value that is not a
            ``pandas.DataFrame`` is returned unchanged.

    Returns:
        A list with one dict per row (column name -> value) when ``df`` is a
        DataFrame; otherwise ``df`` itself. The input frame is not modified.
    """
    if not isinstance(df, pd.DataFrame):
        return df

    dtype_checks = (
        pd.api.types.is_integer_dtype,
        pd.api.types.is_float_dtype,
        pd.api.types.is_bool_dtype,
    )
    # Classify by dtype kind instead of by name prefix: a prefix test on
    # "int"/"float" skips unsigned ("uint8"), nullable ("Int64", "Float64")
    # and boolean columns.
    object_casts = {
        column: object
        for column, dtype in df.dtypes.items()
        if any(check(dtype) for check in dtype_checks)
    }
    # ``astype`` returns a new frame, so the caller's DataFrame is untouched.
    return df.astype(object_casts).to_dict("records")
1140

1241
# Lazy import for dict_cursor to avoid GUI library issues
1342
def get_dict_cursor():
@@ -780,8 +809,13 @@ def ListAllAvailableImages_to_schema(name, take_default):
780809
return Query(query=query, label=label, function=function, takes=takes, preview=preview, preview_columns=preview_columns)
781810

782811
def serialize_solr_output(results):
783-
# Serialize the sanitized dictionary to JSON
784-
json_string = json.dumps(results.docs[0], ensure_ascii=False)
812+
# Create a copy of the document and remove Solr-specific fields
813+
doc = dict(results.docs[0])
814+
# Remove the _version_ field which can cause serialization issues with large integers
815+
doc.pop('_version_', None)
816+
817+
# Serialize the sanitized dictionary to JSON using NumpyEncoder
818+
json_string = json.dumps(doc, ensure_ascii=False, cls=NumpyEncoder)
785819
json_string = json_string.replace('\\', '')
786820
json_string = json_string.replace('"{', '{')
787821
json_string = json_string.replace('}"', '}')
@@ -914,7 +948,7 @@ def get_instances(short_form: str, return_dataframe=True, limit: int = -1):
914948
"thumbnail"
915949
]
916950
}
917-
for row in df.to_dict("records")
951+
for row in safe_to_dict(df)
918952
],
919953
"count": total_count
920954
}
@@ -1002,7 +1036,7 @@ def get_templates(limit: int = -1, return_dataframe: bool = False):
10021036
"license"
10031037
]
10041038
}
1005-
for row in df.to_dict("records")
1039+
for row in safe_to_dict(df)
10061040
],
10071041
"count": total_count
10081042
}
@@ -1118,7 +1152,7 @@ def get_similar_neurons(neuron, similarity_score='NBLAST_score', return_datafram
11181152
"thumbnail"
11191153
]
11201154
}
1121-
for row in df.to_dict("records")
1155+
for row in safe_to_dict(df)
11221156
],
11231157
"count": total_count
11241158
}
@@ -1228,7 +1262,7 @@ def get_individual_neuron_inputs(neuron_short_form: str, return_dataframe=True,
12281262
"Images"
12291263
]
12301264
}
1231-
for row in df.to_dict("records")
1265+
for row in safe_to_dict(df)
12321266
],
12331267
"count": total_count
12341268
}
@@ -1248,7 +1282,7 @@ def get_individual_neuron_inputs(neuron_short_form: str, return_dataframe=True,
12481282
"Weight",
12491283
]
12501284
}
1251-
for row in df.to_dict("records")
1285+
for row in safe_to_dict(df)
12521286
],
12531287
"count": total_count
12541288
}

0 commit comments

Comments
 (0)