Skip to content

Commit be0e64a

Browse files
committed
Fix Table wrappers
1 parent 67b4641 commit be0e64a

File tree

4 files changed

+133
-12
lines changed

4 files changed

+133
-12
lines changed

singlestoredb/functions/ext/asgi.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from ..signature import get_signature
6666
from ..signature import signature_to_sql
6767
from ..typing import Masked
68+
from ..typing import Table
6869

6970
try:
7071
import cloudpickle
@@ -159,18 +160,28 @@ def as_tuple(x: Any) -> Any:
159160

160161
def as_list_of_tuples(x: Any) -> Any:
161162
"""Convert object to a list of tuples."""
163+
if isinstance(x, Table):
164+
x = x[0]
162165
if isinstance(x, (list, tuple)) and len(x) > 0:
166+
if isinstance(x[0], (list, tuple)):
167+
return x
163168
if has_pydantic and isinstance(x[0], BaseModel):
164169
return [tuple(y.model_dump().values()) for y in x]
165170
if dataclasses.is_dataclass(x[0]):
166171
return [dataclasses.astuple(y) for y in x]
167172
if isinstance(x[0], dict):
168173
return [tuple(y.values()) for y in x]
174+
return [(y,) for y in x]
169175
return x
170176

171177

172178
def get_dataframe_columns(df: Any) -> List[Any]:
173179
"""Return columns of data from a dataframe/table."""
180+
if isinstance(df, Table):
181+
if len(df) == 1:
182+
df = df[0]
183+
else:
184+
return list(df)
174185
if isinstance(df, tuple):
175186
return list(df)
176187
rtype = str(type(df)).lower()
@@ -259,8 +270,8 @@ def make_func(
259270
masks = get_masked_params(func)
260271

261272
if function_type == 'tvf':
262-
# Scalar (Python) types
263-
if returns_data_format == 'scalar':
273+
# Scalar / list types (row-based)
274+
if returns_data_format in ['scalar', 'list']:
264275
async def do_func(
265276
row_ids: Sequence[int],
266277
rows: Sequence[Sequence[Any]],
@@ -274,7 +285,7 @@ async def do_func(
274285
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
275286
return out_ids, out
276287

277-
# Vector formats
288+
# Vector formats (column-based)
278289
else:
279290
array_cls = get_array_class(returns_data_format)
280291

@@ -304,16 +315,16 @@ def build_tuple(x: Any) -> Any:
304315
return row_ids, [build_tuple(x) for x in res]
305316

306317
else:
307-
# Scalar (Python) types
308-
if returns_data_format == 'scalar':
318+
# Scalar / list types (row-based)
319+
if returns_data_format in ['scalar', 'list']:
309320
async def do_func(
310321
row_ids: Sequence[int],
311322
rows: Sequence[Sequence[Any]],
312323
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
313324
'''Call function on given rows of data.'''
314325
return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))]
315326

316-
# Vector formats
327+
# Vector formats (column-based)
317328
else:
318329
array_cls = get_array_class(returns_data_format)
319330

@@ -471,8 +482,8 @@ class Application(object):
471482
response=rowdat_1_response_dict,
472483
),
473484
(b'application/octet-stream', b'1.0', 'list'): dict(
474-
load=rowdat_1.load_list,
475-
dump=rowdat_1.dump_list,
485+
load=rowdat_1.load,
486+
dump=rowdat_1.dump,
476487
response=rowdat_1_response_dict,
477488
),
478489
(b'application/octet-stream', b'1.0', 'pandas'): dict(
@@ -501,8 +512,8 @@ class Application(object):
501512
response=json_response_dict,
502513
),
503514
(b'application/json', b'1.0', 'list'): dict(
504-
load=jdata.load_list,
505-
dump=jdata.dump_list,
515+
load=jdata.load,
516+
dump=jdata.dump,
506517
response=json_response_dict,
507518
),
508519
(b'application/json', b'1.0', 'pandas'): dict(

singlestoredb/functions/signature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ def get_schema(
10021002
f'{", ".join(out_data_formats)}',
10031003
)
10041004

1005-
if out_data_formats:
1005+
if data_format != 'list' and out_data_formats:
10061006
data_format = out_data_formats[0]
10071007

10081008
# Since the colspec was computed by get_schema already, don't go

singlestoredb/tests/ext_funcs/__init__.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
#!/usr/bin/env python3
22
# mypy: disable-error-code="type-arg"
3+
import typing
4+
from typing import List
5+
from typing import NamedTuple
36
from typing import Optional
7+
from typing import Tuple
48

59
import numpy as np
610
import numpy.typing as npt
711
import pandas as pd
812
import polars as pl
913
import pyarrow as pa
1014

15+
import singlestoredb.functions.dtypes as dt
1116
from singlestoredb.functions import Masked
17+
from singlestoredb.functions import Table
1218
from singlestoredb.functions import udf
1319
from singlestoredb.functions.dtypes import BIGINT
1420
from singlestoredb.functions.dtypes import BLOB
@@ -18,7 +24,6 @@
1824
from singlestoredb.functions.dtypes import SMALLINT
1925
from singlestoredb.functions.dtypes import TEXT
2026
from singlestoredb.functions.dtypes import TINYINT
21-
from singlestoredb.functions.typing import Table
2227

2328

2429
@udf
@@ -525,3 +530,53 @@ def numpy_fixed_binary() -> Table[npt.NDArray[np.bytes_]]:
525530
@udf
526531
def no_args_no_return_value() -> None:
527532
pass
533+
534+
535+
@udf
536+
def table_function(n: int) -> Table[List[int]]:
537+
return Table([10] * n)
538+
539+
540+
@udf(
541+
returns=[
542+
dt.INT(name='c_int', nullable=False),
543+
dt.DOUBLE(name='c_float', nullable=False),
544+
dt.TEXT(name='c_str', nullable=False),
545+
],
546+
)
547+
def table_function_tuple(n: int) -> Table[List[Tuple[int, float, str]]]:
548+
return Table([(10, 10.0, 'ten')] * n)
549+
550+
551+
class MyTable(NamedTuple):
552+
c_int: int
553+
c_float: float
554+
c_str: str
555+
556+
557+
@udf
558+
def table_function_struct(n: int) -> Table[List[MyTable]]:
559+
return Table([MyTable(10, 10.0, 'ten')] * n)
560+
561+
562+
@udf
563+
def vec_function(
564+
x: npt.NDArray[np.float64], y: npt.NDArray[np.float64],
565+
) -> npt.NDArray[np.float64]:
566+
return x * y
567+
568+
569+
class VecInputs(typing.NamedTuple):
570+
x: np.int8
571+
y: np.int8
572+
573+
574+
class VecOutputs(typing.NamedTuple):
575+
res: np.int16
576+
577+
578+
@udf(args=VecInputs, returns=VecOutputs)
579+
def vec_function_ints(
580+
x: npt.NDArray[np.int_], y: npt.NDArray[np.int_],
581+
) -> npt.NDArray[np.int_]:
582+
return x * y

singlestoredb/tests/test_ext_func.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,3 +1234,58 @@ def test_no_args_no_return_value(self):
12341234
assert desc[0].name == 'res'
12351235
assert desc[0].type_code == ft.TINY
12361236
assert desc[0].null_ok is True
1237+
1238+
def test_table_function(self):
1239+
self.cur.execute('select * from table_function(5)')
1240+
1241+
assert [x[0] for x in self.cur] == [10, 10, 10, 10, 10]
1242+
1243+
desc = self.cur.description
1244+
assert len(desc) == 1
1245+
assert desc[0].name == 'a'
1246+
assert desc[0].type_code == ft.LONGLONG
1247+
assert desc[0].null_ok is False
1248+
1249+
def test_table_function_tuple(self):
1250+
self.cur.execute('select * from table_function_tuple(3)')
1251+
1252+
out = list(self.cur)
1253+
1254+
assert out == [
1255+
(10, 10.0, 'ten'),
1256+
(10, 10.0, 'ten'),
1257+
(10, 10.0, 'ten'),
1258+
]
1259+
1260+
desc = self.cur.description
1261+
assert len(desc) == 3
1262+
assert desc[0].name == 'c_int'
1263+
assert desc[1].name == 'c_float'
1264+
assert desc[2].name == 'c_str'
1265+
1266+
def test_table_function_struct(self):
1267+
self.cur.execute('select * from table_function_struct(3)')
1268+
1269+
out = list(self.cur)
1270+
1271+
assert out == [
1272+
(10, 10.0, 'ten'),
1273+
(10, 10.0, 'ten'),
1274+
(10, 10.0, 'ten'),
1275+
]
1276+
1277+
desc = self.cur.description
1278+
assert len(desc) == 3
1279+
assert desc[0].name == 'c_int'
1280+
assert desc[1].name == 'c_float'
1281+
assert desc[2].name == 'c_str'
1282+
1283+
def test_vec_function(self):
1284+
self.cur.execute('select vec_function(5, 10) as res')
1285+
1286+
assert [tuple(x) for x in self.cur] == [(50.0,)]
1287+
1288+
def test_vec_function_ints(self):
1289+
self.cur.execute('select vec_function(5, 10) as res')
1290+
1291+
assert [tuple(x) for x in self.cur] == [(50,)]

0 commit comments

Comments
 (0)