Skip to content

Commit 7ec2c24

Browse files
committed
Fix imports; change binary from hex to base64
1 parent 858e652 commit 7ec2c24

File tree

4 files changed

+63
-35
lines changed

4 files changed

+63
-35
lines changed

singlestoredb/functions/dtypes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
import base64
23
import datetime
34
import decimal
45
import re
@@ -106,7 +107,7 @@ def bytestr(x: Any) -> Optional[bytes]:
106107
return x
107108
if isinstance(x, bytes):
108109
return x
109-
return bytes.fromhex(x)
110+
return base64.b64decode(x)
110111

111112

112113
PYTHON_CONVERTERS = {

singlestoredb/functions/ext/json.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
import base64
23
import json
34
from typing import Any
45
from typing import List
@@ -35,7 +36,7 @@ class JSONEncoder(json.JSONEncoder):
3536

3637
def default(self, obj: Any) -> Any:
3738
if isinstance(obj, bytes):
38-
return obj.hex()
39+
return base64.b64encode(obj).decode('utf-8')
3940
return json.JSONEncoder.default(self, obj)
4041

4142

@@ -130,6 +131,8 @@ def load_pandas(
130131
Tuple[pd.Series[int], List[pd.Series[Any]]
131132
132133
'''
134+
import numpy as np
135+
import pandas as pd
133136
row_ids, cols = _load_vectors(colspec, data)
134137
index = pd.Series(row_ids, dtype=np.longlong)
135138
return index, \
@@ -164,6 +167,7 @@ def load_polars(
164167
Tuple[polars.Series[int], List[polars.Series[Any]]
165168
166169
'''
170+
import polars as pl
167171
row_ids, cols = _load_vectors(colspec, data)
168172
return pl.Series(None, row_ids, dtype=pl.Int64), \
169173
[
@@ -194,6 +198,7 @@ def load_numpy(
194198
Tuple[np.ndarray[int], List[np.ndarray[Any]]
195199
196200
'''
201+
import numpy as np
197202
row_ids, cols = _load_vectors(colspec, data)
198203
return np.asarray(row_ids, dtype=np.longlong), \
199204
[
@@ -224,6 +229,7 @@ def load_arrow(
224229
Tuple[pyarrow.Array[int], List[pyarrow.Array[Any]]
225230
226231
'''
232+
import pyarrow as pa
227233
row_ids, cols = _load_vectors(colspec, data)
228234
return pa.array(row_ids, type=pa.int64()), \
229235
[

singlestoredb/functions/ext/rowdat_1.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
except ImportError:
3636
pass
3737
try:
38-
import pyarrow.compute as pc
38+
import pyarrow.compute as pc # noqa: F401
3939
except ImportError:
4040
pass
4141

@@ -205,6 +205,7 @@ def _load_pandas(
205205
Tuple[pd.Series[int], List[Tuple[pd.Series[Any], pd.Series[bool]]]]
206206
207207
'''
208+
import numpy as np
208209
import pandas as pd
209210

210211
row_ids, cols = _load_vectors(colspec, data)
@@ -558,6 +559,7 @@ def _load_pandas_accel(
558559
if not has_accel:
559560
raise RuntimeError('could not load SingleStoreDB extension')
560561

562+
import numpy as np
561563
import pandas as pd
562564

563565
numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data)
@@ -663,8 +665,11 @@ def _create_arrow_mask(
663665
data: 'pa.Array[Any]',
664666
mask: 'pa.Array[pa.bool_]',
665667
) -> 'pa.Array[pa.bool_]':
668+
import pyarrow.compute as pc # noqa: F811
669+
666670
if mask is None:
667671
return data.is_null().to_numpy(zero_copy_only=False)
672+
668673
return pc.or_(data.is_null(), mask.is_null()).to_numpy(zero_copy_only=False)
669674

670675

singlestoredb/functions/utils.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,6 @@
66
from typing import Any
77
from typing import Dict
88

9-
try:
10-
import numpy as np
11-
has_numpy = True
12-
except ImportError:
13-
has_numpy = False
14-
159

1610
if sys.version_info >= (3, 10):
1711
_UNION_TYPES = {typing.Union, types.UnionType}
@@ -36,29 +30,51 @@ def get_annotations(obj: Any) -> Dict[str, Any]:
3630
return getattr(obj, '__annotations__', {})
3731

3832

33+
def get_module(obj: Any) -> str:
34+
"""Get the module of an object."""
35+
module = getattr(obj, '__module__', '').split('.')
36+
if module:
37+
return module[0]
38+
return ''
39+
40+
41+
def get_type_name(obj: Any) -> str:
42+
"""Get the type name of an object."""
43+
if hasattr(obj, '__name__'):
44+
return obj.__name__
45+
if hasattr(obj, '__class__'):
46+
return obj.__class__.__name__
47+
return ''
48+
49+
3950
def is_numpy(obj: Any) -> bool:
4051
"""Check if an object is a numpy array."""
41-
if is_union(obj):
42-
obj = typing.get_args(obj)[0]
43-
if not has_numpy:
44-
return False
4552
if inspect.isclass(obj):
46-
return obj is np.ndarray
47-
if typing.get_origin(obj) is np.ndarray:
48-
return True
49-
return isinstance(obj, np.ndarray)
53+
if get_module(obj) == 'numpy':
54+
return get_type_name(obj) == 'ndarray'
55+
56+
origin = typing.get_origin(obj)
57+
if get_module(origin) == 'numpy':
58+
if get_type_name(origin) == 'ndarray':
59+
return True
60+
61+
dtype = type(obj)
62+
if get_module(dtype) == 'numpy':
63+
return get_type_name(dtype) == 'ndarray'
64+
65+
return False
5066

5167

5268
def is_dataframe(obj: Any) -> bool:
5369
"""Check if an object is a DataFrame."""
5470
# Cheating here a bit so we don't have to import pandas / polars / pyarrow:
5571
# unless we absolutely need to
56-
if getattr(obj, '__module__', '').startswith('pandas.'):
57-
return getattr(obj, '__name__', '') == 'DataFrame'
58-
if getattr(obj, '__module__', '').startswith('polars.'):
59-
return getattr(obj, '__name__', '') == 'DataFrame'
60-
if getattr(obj, '__module__', '').startswith('pyarrow.'):
61-
return getattr(obj, '__name__', '') == 'Table'
72+
if get_module(obj) == 'pandas':
73+
return get_type_name(obj) == 'DataFrame'
74+
if get_module(obj) == 'polars':
75+
return get_type_name(obj) == 'DataFrame'
76+
if get_module(obj) == 'pyarrow':
77+
return get_type_name(obj) == 'Table'
6278
return False
6379

6480

@@ -74,13 +90,13 @@ def get_data_format(obj: Any) -> str:
7490
"""Return the data format of the DataFrame / Table / vector."""
7591
# Cheating here a bit so we don't have to import pandas / polars / pyarrow
7692
# unless we absolutely need to
77-
if getattr(obj, '__module__', '').startswith('pandas.'):
93+
if get_module(obj) == 'pandas':
7894
return 'pandas'
79-
if getattr(obj, '__module__', '').startswith('polars.'):
95+
if get_module(obj) == 'polars':
8096
return 'polars'
81-
if getattr(obj, '__module__', '').startswith('pyarrow.'):
97+
if get_module(obj) == 'pyarrow':
8298
return 'arrow'
83-
if getattr(obj, '__module__', '').startswith('numpy.'):
99+
if get_module(obj) == 'numpy':
84100
return 'numpy'
85101
if isinstance(obj, list):
86102
return 'list'
@@ -92,8 +108,8 @@ def is_pandas_series(obj: Any) -> bool:
92108
if is_union(obj):
93109
obj = typing.get_args(obj)[0]
94110
return (
95-
getattr(obj, '__module__', '').startswith('pandas.') and
96-
getattr(obj, '__name__', '') == 'Series'
111+
get_module(obj) == 'pandas' and
112+
get_type_name(obj) == 'Series'
97113
)
98114

99115

@@ -102,8 +118,8 @@ def is_polars_series(obj: Any) -> bool:
102118
if is_union(obj):
103119
obj = typing.get_args(obj)[0]
104120
return (
105-
getattr(obj, '__module__', '').startswith('polars.') and
106-
getattr(obj, '__name__', '') == 'Series'
121+
get_module(obj) == 'polars' and
122+
get_type_name(obj) == 'Series'
107123
)
108124

109125

@@ -112,8 +128,8 @@ def is_pyarrow_array(obj: Any) -> bool:
112128
if is_union(obj):
113129
obj = typing.get_args(obj)[0]
114130
return (
115-
getattr(obj, '__module__', '').startswith('pyarrow.') and
116-
getattr(obj, '__name__', '') == 'Array'
131+
get_module(obj) == 'pyarrow' and
132+
get_type_name(obj) == 'Array'
117133
)
118134

119135

@@ -147,6 +163,6 @@ def is_pydantic(obj: Any) -> bool:
147163
# the class is a subclass
148164
return bool([
149165
x for x in inspect.getmro(obj)
150-
if x.__module__.startswith('pydantic.')
151-
and x.__name__ == 'BaseModel'
166+
if get_module(x) == 'pydantic'
167+
and get_type_name(x) == 'BaseModel'
152168
])

0 commit comments

Comments
 (0)