Skip to content

Commit 5039ec2

Browse files
committed
Fix masks in table results
1 parent be0e64a commit 5039ec2

File tree

5 files changed

+175
-14
lines changed

5 files changed

+175
-14
lines changed

singlestoredb/functions/ext/asgi.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,13 @@ def get_dataframe_columns(df: Any) -> List[Any]:
182182
df = df[0]
183183
else:
184184
return list(df)
185+
186+
if isinstance(df, Masked):
187+
return [df]
188+
185189
if isinstance(df, tuple):
186190
return list(df)
191+
187192
rtype = str(type(df)).lower()
188193
if 'dataframe' in rtype:
189194
return [df[x] for x in df.columns]
@@ -195,6 +200,7 @@ def get_dataframe_columns(df: Any) -> List[Any]:
195200
return [df]
196201
elif 'tuple' in rtype:
197202
return list(df)
203+
198204
raise TypeError(
199205
'Unsupported data type for dataframe columns: '
200206
f'{rtype}',
@@ -292,7 +298,10 @@ async def do_func(
292298
async def do_func( # type: ignore
293299
row_ids: Sequence[int],
294300
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
295-
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
301+
) -> Tuple[
302+
Sequence[int],
303+
List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
304+
]:
296305
'''Call function on given cols of data.'''
297306
# NOTE: There is no way to determine which row ID belongs to
298307
# each result row, so we just have to use the same
@@ -310,7 +319,10 @@ def build_tuple(x: Any) -> Any:
310319
res = get_dataframe_columns(func())
311320

312321
# Generate row IDs
313-
row_ids = array_cls([row_ids[0]] * len(res[0]))
322+
if isinstance(res[0], Masked):
323+
row_ids = array_cls([row_ids[0]] * len(res[0][0]))
324+
else:
325+
row_ids = array_cls([row_ids[0]] * len(res[0]))
314326

315327
return row_ids, [build_tuple(x) for x in res]
316328

@@ -331,7 +343,10 @@ async def do_func(
331343
async def do_func( # type: ignore
332344
row_ids: Sequence[int],
333345
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
334-
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
346+
) -> Tuple[
347+
Sequence[int],
348+
List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
349+
]:
335350
'''Call function on given cols of data.'''
336351
row_ids = array_cls(row_ids)
337352

singlestoredb/functions/signature.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -817,14 +817,16 @@ def get_schema(
817817
if len(args) != 1:
818818
raise TypeError(
819819
'only one list is supported within a table; to '
820-
'return multiple columns, use a NamedTuple, dataclass, '
821-
'TypedDict, or pydantic model',
820+
'return multiple columns, use a tuple, NamedTuple, '
821+
'dataclass, TypedDict, or pydantic model',
822822
)
823823
spec = typing.get_args(args[0])[0]
824824
data_format = 'list'
825825

826-
elif not all([utils.is_vector(x) for x in args]):
827-
# TODO: Don't fail if types are specified in np.ndarrays
826+
elif all([utils.is_vector(x, include_masks=True) for x in args]):
827+
pass
828+
829+
else:
828830
raise TypeError(
829831
'return type for TVF must be a list, DataFrame / Table, '
830832
'or tuple of vectors',
@@ -970,7 +972,7 @@ def get_schema(
970972
# return types or parameter types
971973
if out_overrides and len(typing.get_args(spec)) != len(out_overrides):
972974
raise ValueError(
973-
'number of {mode} types does not match the number of '
975+
f'number of {mode} types does not match the number of '
974976
'overrides specified',
975977
)
976978

@@ -1312,14 +1314,14 @@ def dtype_to_sql(
13121314
13131315
"""
13141316
nullable = ' NOT NULL'
1315-
if force_nullable:
1316-
nullable = ' NULL'
1317-
elif dtype.endswith('?'):
1317+
if dtype.endswith('?'):
13181318
nullable = ' NULL'
13191319
dtype = dtype[:-1]
13201320
elif '|null' in dtype:
13211321
nullable = ' NULL'
13221322
dtype = dtype.replace('|null', '')
1323+
elif force_nullable:
1324+
nullable = ' NULL'
13231325

13241326
if dtype == 'null':
13251327
nullable = ''

singlestoredb/functions/utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any
77
from typing import Dict
88

9+
from .typing import Masked
910

1011
if sys.version_info >= (3, 10):
1112
_UNION_TYPES = {typing.Union, types.UnionType}
@@ -16,6 +17,15 @@
1617
is_dataclass = dataclasses.is_dataclass
1718

1819

20+
def is_masked(obj: Any) -> bool:
    """
    Check if an object is a subscripted ``Masked`` type.

    Only objects for which ``typing.get_origin`` reports an origin
    (i.e., subscripted generics such as ``Masked[...]``) can match;
    anything without an origin is reported as not masked.

    Parameters
    ----------
    obj : Any
        Object to inspect

    Returns
    -------
    bool

    """
    base = typing.get_origin(obj)

    # No origin means `obj` is not a subscripted generic at all
    if base is None:
        return False

    if base is Masked:
        return True

    # Also accept generics whose origin is a subclass of Masked
    return inspect.isclass(base) and issubclass(base, Masked)
27+
28+
1929
def is_union(x: Any) -> bool:
    """
    Check if the object is a Union.

    The typing origin of ``x`` is compared against the union types
    collected in ``_UNION_TYPES``.

    """
    origin = typing.get_origin(x)
    return origin in _UNION_TYPES
@@ -77,12 +87,13 @@ def is_dataframe(obj: Any) -> bool:
7787
return False
7888

7989

80-
def is_vector(obj: Any, include_masks: bool = False) -> bool:
    """
    Check if an object is a vector type.

    Parameters
    ----------
    obj : Any
        Object to check
    include_masks : bool, optional
        Should ``Masked`` types be treated as vectors as well?
        When False (the default), only pandas Series, polars Series,
        pyarrow arrays, and numpy arrays qualify.

    Returns
    -------
    bool

    """
    if (
        is_pandas_series(obj)
        or is_polars_series(obj)
        or is_pyarrow_array(obj)
        or is_numpy(obj)
    ):
        return True

    # Masked types only count when the caller explicitly opts in;
    # previously the flag was accepted but never consulted.
    return bool(include_masks and is_masked(obj))
8697

8798

8899
def get_data_format(obj: Any) -> str:

singlestoredb/tests/ext_funcs/__init__.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,3 +580,45 @@ def vec_function_ints(
580580
x: npt.NDArray[np.int_], y: npt.NDArray[np.int_],
581581
) -> npt.NDArray[np.int_]:
582582
return x * y
583+
584+
585+
# Output column names and types passed as `returns=` to the
# `vec_function_df` UDF.
class DFOutputs(typing.NamedTuple):
    # 16-bit integer result column
    res: np.int16
    # 64-bit float result column
    res2: np.float64
588+
589+
590+
# Table-valued test function: the inputs are accepted but not used;
# a fixed three-row, two-column DataFrame is always returned.
@udf(args=VecInputs, returns=DFOutputs)
def vec_function_df(
    x: npt.NDArray[np.int_], y: npt.NDArray[np.int_],
) -> Table[pd.DataFrame]:
    columns = {
        'res': [1, 2, 3],
        'res2': [1.1, 2.2, 3.3],
    }
    return pd.DataFrame(columns)
595+
596+
597+
# Output column names and types passed as `returns=` to the
# `vec_function_ints_masked` UDF.
class MaskOutputs(typing.NamedTuple):
    # Optional 16-bit integer result column
    res: Optional[np.int16]
599+
600+
601+
# Table-valued test function with masked inputs and a single masked
# output column.
@udf(args=VecInputs, returns=MaskOutputs)
def vec_function_ints_masked(
    x: Masked[npt.NDArray[np.int_]], y: Masked[npt.NDArray[np.int_]],
) -> Table[Masked[npt.NDArray[np.int_]]]:
    # Each masked input unpacks into its data and its null mask
    lhs, lhs_mask = x
    rhs, rhs_mask = y

    # Element-wise product of the data; the output mask is the OR of
    # both input masks
    product = lhs * rhs
    mask = lhs_mask | rhs_mask

    return Table(Masked(product, mask))
608+
609+
610+
# Output column names and types passed as `returns=` to the
# `vec_function_ints_masked2` UDF.
class MaskOutputs2(typing.NamedTuple):
    # First optional 16-bit integer result column
    res: Optional[np.int16]
    # Second optional 16-bit integer result column
    res2: Optional[np.int16]
613+
614+
615+
# Table-valued test function with masked inputs and two masked output
# columns that carry the same values.
@udf(args=VecInputs, returns=MaskOutputs2)
def vec_function_ints_masked2(
    x: Masked[npt.NDArray[np.int_]], y: Masked[npt.NDArray[np.int_]],
) -> Table[Masked[npt.NDArray[np.int_]], Masked[npt.NDArray[np.int_]]]:
    # Each masked input unpacks into its data and its null mask
    lhs, lhs_mask = x
    rhs, rhs_mask = y

    def make_column() -> Any:
        # Computed anew on every call so the two output columns do
        # not share underlying objects
        return Masked(lhs * rhs, lhs_mask | rhs_mask)

    return Table(make_column(), make_column())

singlestoredb/tests/test_ext_func.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1286,6 +1286,97 @@ def test_vec_function(self):
12861286
assert [tuple(x) for x in self.cur] == [(50.0,)]
12871287

12881288
def test_vec_function_ints(self):
    # Integer vector UDF: 5 * 10 should come back as a single row
    self.cur.execute('select vec_function_ints(5, 10) as res')

    rows = [tuple(row) for row in self.cur]
    assert rows == [(50,)]
1292+
1293+
def test_vec_function_df(self):
    # The TVF returns a fixed three-row DataFrame
    self.cur.execute('select * from vec_function_df(5, 10)')

    rows = list(self.cur)
    assert rows == [(1, 1.1), (2, 2.2), (3, 3.3)]

    # Result metadata: two non-nullable columns with the declared types
    desc = self.cur.description
    assert len(desc) == 2

    expected_columns = [
        ('res', ft.SHORT),
        ('res2', ft.DOUBLE),
    ]
    for idx, (name, type_code) in enumerate(expected_columns):
        column = desc[idx]
        assert column.name == name
        assert column.type_code == type_code
        assert column.null_ok is False
1312+
1313+
def test_vec_function_ints_masked(self):
    # (query, expected rows): a NULL on either side gives a NULL result
    cases = [
        ('select * from vec_function_ints_masked(5, 10)', [(50,)]),
        ('select * from vec_function_ints_masked(NULL, 10)', [(None,)]),
        ('select * from vec_function_ints_masked(5, NULL)', [(None,)]),
    ]

    for query, expected in cases:
        self.cur.execute(query)

        assert [tuple(row) for row in self.cur] == expected

        # The single output column is a nullable SHORT in every case
        desc = self.cur.description
        assert len(desc) == 1
        column = desc[0]
        assert column.name == 'res'
        assert column.type_code == ft.SHORT
        assert column.null_ok is True
1343+
1344+
def test_vec_function_ints_masked2(self):
    # (query, expected rows): a NULL on either side nulls both columns
    cases = [
        ('select * from vec_function_ints_masked2(5, 10)', [(50, 50)]),
        (
            'select * from vec_function_ints_masked2(NULL, 10)',
            [(None, None)],
        ),
        (
            'select * from vec_function_ints_masked2(5, NULL)',
            [(None, None)],
        ),
    ]
    column_names = ['res', 'res2']

    for query, expected in cases:
        self.cur.execute(query)

        assert [tuple(row) for row in self.cur] == expected

        # Both output columns are nullable SHORTs in every case
        desc = self.cur.description
        assert len(desc) == 2
        for idx, name in enumerate(column_names):
            column = desc[idx]
            assert column.name == name
            assert column.type_code == ft.SHORT
            assert column.null_ok is True

0 commit comments

Comments
 (0)