Skip to content

Commit 2697d89

Browse files
kesmit13claude
andcommitted
Add msgpack timestamp support and fix C extension error handling
- Add timestamp=3 to msgpack.unpackb() to convert msgpack Timestamps to Python datetime objects - Add datetime=True to msgpack.packb() to properly serialize datetime objects back to msgpack Timestamps - Fix segfault in accel.c dump_rowdat_1 by adding NULL check after apply_transformer call - Add unit tests for msgpack timestamp functionality - Add UDF test functions for timestamp handling across numpy, pandas, polars, list, and non-vector types - Fix .gitignore to only ignore test*.py at root level Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 3396416 commit 2697d89

File tree

11 files changed

+639
-5
lines changed

11 files changed

+639
-5
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ dev-docs
8383
**/.ipynb_checkpoints
8484
**/.benchmarks
8585
*.ipynb
86-
test*.py
86+
/test*.py
8787
certs
8888
**/*.prof
8989
**/*.pprof

accel.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4639,6 +4639,7 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
46394639
while ((py_item = PyIter_Next(py_row_iter))) {
46404640

46414641
py_item = apply_transformer(py_transformers[i], py_item);
4642+
if (!py_item) goto error;
46424643
is_null = (uint8_t)(py_item == Py_None);
46434644

46444645
CHECKMEM(1);

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ test = [
5151
"fastapi",
5252
"ipython",
5353
"jupysql",
54+
"msgpack>=1.0.0",
5455
"pandas",
5556
"parameterized",
5657
"polars",

singlestoredb/functions/typing/__init__.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
from typing import TypeVar
1313
from typing import Union
1414

15+
try:
16+
import msgpack
17+
_has_msgpack = True
18+
except ImportError:
19+
msgpack = None # type: ignore[assignment]
20+
_has_msgpack = False
21+
1522
try:
1623
from typing import TypeVarTuple # type: ignore
1724
from typing import Unpack # type: ignore
@@ -153,10 +160,86 @@ def json_or_null_loads(v: Optional[str], **kwargs: Any) -> Optional[Any]:
153160
]
154161

155162

163+
def msgpack_or_null_dumps(v: Optional[Any], **kwargs: Any) -> Optional[bytes]:
164+
"""
165+
Serialize a Python object to MessagePack bytes or None.
166+
167+
Parameters
168+
----------
169+
v : Optional[Any]
170+
The Python object to serialize. If None or empty, the function returns None.
171+
**kwargs : Any
172+
Additional keyword arguments to pass to `msgpack.packb`.
173+
174+
Returns
175+
-------
176+
Optional[bytes]
177+
The MessagePack bytes representation of the input object,
178+
or None if the input is None or empty.
179+
180+
Raises
181+
------
182+
ImportError
183+
If msgpack is not installed.
184+
185+
"""
186+
if not _has_msgpack:
187+
raise ImportError('msgpack is required for MessagePack serialization')
188+
if not v:
189+
return None
190+
return msgpack.packb(v, datetime=True, **kwargs)
191+
192+
193+
# Force numpy dtype to 'object' to avoid issues with
194+
# numpy trying to infer the dtype and creating multidimensional arrays
195+
# instead of an array of Python objects.
196+
@output_type('object')
197+
def msgpack_or_null_loads(v: Optional[bytes], **kwargs: Any) -> Optional[Any]:
198+
"""
199+
Deserialize MessagePack bytes to a Python object or None.
200+
201+
Parameters
202+
----------
203+
v : Optional[bytes]
204+
The MessagePack bytes to deserialize. If None or empty,
205+
the function returns None.
206+
**kwargs : Any
207+
Additional keyword arguments to pass to `msgpack.unpackb`.
208+
209+
Returns
210+
-------
211+
Optional[Any]
212+
The Python object represented by the MessagePack bytes,
213+
or None if the input is None or empty.
214+
215+
Raises
216+
------
217+
ImportError
218+
If msgpack is not installed.
219+
220+
"""
221+
if not _has_msgpack:
222+
raise ImportError('msgpack is required for MessagePack deserialization')
223+
if not v:
224+
return None
225+
return msgpack.unpackb(v, raw=False, strict_map_key=False, timestamp=3, **kwargs)
226+
227+
228+
MessagePack: TypeAlias = Annotated[
229+
Union[Dict[str, Any], List[Any], int, float, str, bool, bytes, None],
230+
UDFAttrs(
231+
sql_type=sql_types.BLOB(nullable=False),
232+
args_transformer=msgpack_or_null_loads,
233+
returns_transformer=msgpack_or_null_dumps,
234+
),
235+
]
236+
237+
156238
__all__ = [
157239
'Table',
158240
'Masked',
159241
'JSON',
242+
'MessagePack',
160243
'UDFAttrs',
161244
'Transformer',
162245
]

singlestoredb/functions/typing/numpy.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
except ImportError:
1212
from typing_extensions import TypeAlias # type: ignore
1313

14+
from . import msgpack_or_null_dumps
15+
from . import msgpack_or_null_loads
1416
from . import UDFAttrs
1517
from . import json_or_null_dumps
1618
from . import json_or_null_loads
@@ -107,4 +109,30 @@ def default(self, obj: Any) -> Any:
107109
]
108110

109111

110-
__all__ = ['array'] + [x for x in globals().keys() if x.endswith('Array')]
112+
def msgpack_numpy_default(obj: Any) -> Any:
113+
"""Default function for msgpack that handles numpy types."""
114+
if isinstance(obj, np.integer):
115+
return int(obj)
116+
elif isinstance(obj, np.floating):
117+
return float(obj)
118+
elif isinstance(obj, np.ndarray):
119+
return obj.tolist()
120+
raise TypeError(f'Object of type {type(obj)} is not msgpack serializable')
121+
122+
123+
MessagePackArray: TypeAlias = Annotated[
124+
npt.NDArray[np.object_],
125+
UDFAttrs(
126+
sql_type=sql_types.BLOB(nullable=False),
127+
args_transformer=msgpack_or_null_loads,
128+
returns_transformer=lambda x: msgpack_or_null_dumps(
129+
x, default=msgpack_numpy_default,
130+
),
131+
),
132+
]
133+
134+
135+
__all__ = ['array'] + [
136+
x for x in globals().keys()
137+
if x.endswith('Array')
138+
]

singlestoredb/functions/typing/pandas.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
except ImportError:
1313
from typing_extensions import TypeAlias # type: ignore
1414

15+
from . import msgpack_or_null_dumps
16+
from . import msgpack_or_null_loads
1517
from . import UDFAttrs
1618
from . import json_or_null_dumps
1719
from . import json_or_null_loads
@@ -109,4 +111,34 @@ def default(self, obj: Any) -> Any:
109111
]
110112

111113

112-
__all__ = ['DataFrame'] + [x for x in globals().keys() if x.endswith('Series')]
114+
def msgpack_pandas_default(obj: Any) -> Any:
115+
"""Default function for msgpack that handles pandas types."""
116+
if hasattr(obj, 'dtype') and hasattr(obj, 'tolist'):
117+
# Handle pandas Series and numpy arrays
118+
return obj.tolist()
119+
elif isinstance(obj, np.integer):
120+
return int(obj)
121+
elif isinstance(obj, np.floating):
122+
return float(obj)
123+
elif hasattr(obj, 'item'):
124+
# Handle pandas scalar types
125+
return obj.item()
126+
raise TypeError(f'Object of type {type(obj)} is not msgpack serializable')
127+
128+
129+
MessagePackSeries: TypeAlias = Annotated[
130+
pd.Series,
131+
UDFAttrs(
132+
sql_type=sql_types.BLOB(nullable=False),
133+
args_transformer=msgpack_or_null_loads,
134+
returns_transformer=lambda x: msgpack_or_null_dumps(
135+
x, default=msgpack_pandas_default,
136+
),
137+
),
138+
]
139+
140+
141+
__all__ = ['DataFrame'] + [
142+
x for x in globals().keys()
143+
if x.endswith('Series')
144+
]

singlestoredb/functions/typing/polars.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
except ImportError:
1212
from typing_extensions import TypeAlias # type: ignore
1313

14+
from . import msgpack_or_null_dumps
15+
from . import msgpack_or_null_loads
1416
from . import UDFAttrs
1517
from . import json_or_null_dumps
1618
from . import json_or_null_loads
@@ -115,4 +117,41 @@ def default(self, obj: Any) -> Any:
115117
]
116118

117119

118-
__all__ = ['DataFrame'] + [x for x in globals().keys() if x.endswith('Series')]
120+
def msgpack_polars_default(obj: Any) -> Any:
121+
"""Default function for msgpack that handles polars types."""
122+
if isinstance(obj, pl.Series):
123+
# Convert Polars Series to Python list
124+
return obj.to_list()
125+
elif hasattr(obj, 'dtype') and \
126+
str(obj.dtype).startswith(('Int', 'UInt', 'Float')):
127+
# Handle Polars scalar integer and float types
128+
return obj.item() if hasattr(obj, 'item') else obj
129+
elif isinstance(
130+
obj, (
131+
pl.datatypes.Int8, pl.datatypes.Int16, pl.datatypes.Int32,
132+
pl.datatypes.Int64, pl.datatypes.UInt8, pl.datatypes.UInt16,
133+
pl.datatypes.UInt32, pl.datatypes.UInt64,
134+
),
135+
):
136+
return int(obj)
137+
elif isinstance(obj, (pl.datatypes.Float32, pl.datatypes.Float64)):
138+
return float(obj)
139+
raise TypeError(f'Object of type {type(obj)} is not msgpack serializable')
140+
141+
142+
MessagePackSeries: TypeAlias = Annotated[
143+
pl.Series,
144+
UDFAttrs(
145+
sql_type=sql_types.BLOB(nullable=False),
146+
args_transformer=msgpack_or_null_loads,
147+
returns_transformer=lambda x: msgpack_or_null_dumps(
148+
x, default=msgpack_polars_default,
149+
),
150+
),
151+
]
152+
153+
154+
__all__ = ['DataFrame'] + [
155+
x for x in globals().keys()
156+
if x.endswith('Series')
157+
]

singlestoredb/functions/typing/pyarrow.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
except ImportError:
1313
from typing_extensions import TypeAlias # type: ignore
1414

15+
from . import msgpack_or_null_dumps
16+
from . import msgpack_or_null_loads # noqa: F401
1517
from . import UDFAttrs
1618
from . import json_or_null_dumps
1719
from . import json_or_null_loads # noqa: F401
@@ -113,4 +115,38 @@ def default(self, obj: Any) -> Any:
113115
]
114116

115117

116-
__all__ = ['Table', 'array'] + [x for x in globals().keys() if x.endswith('Array')]
118+
def msgpack_pyarrow_default(obj: Any) -> Any:
119+
"""Default function for msgpack that handles pyarrow types."""
120+
if hasattr(obj, 'as_py'):
121+
# Handle PyArrow scalar types (including individual ints and floats)
122+
return obj.as_py()
123+
elif isinstance(obj, pa.Array):
124+
# Convert PyArrow Array to Python list
125+
return obj.to_pylist()
126+
elif isinstance(obj, pa.Table):
127+
# Convert PyArrow Table to list of dictionaries
128+
return obj.to_pydict()
129+
raise TypeError(f'Object of type {type(obj)} is not msgpack serializable')
130+
131+
132+
#
133+
# NOTE: We don't use input_transformer=msgpack_or_null_loads because it doesn't handle
134+
# all cases (e.g., when the input is already a dict/list).
135+
#
136+
137+
MessagePackArray: TypeAlias = Annotated[
138+
pa.Array,
139+
UDFAttrs(
140+
sql_type=sql_types.BLOB(nullable=True),
141+
# input_transformer=msgpack_or_null_loads,
142+
returns_transformer=lambda x: msgpack_or_null_dumps(
143+
x, default=msgpack_pyarrow_default,
144+
),
145+
),
146+
]
147+
148+
149+
__all__ = ['Table', 'array'] + [
150+
x for x in globals().keys()
151+
if x.endswith('Array')
152+
]

0 commit comments

Comments
 (0)