Skip to content

Commit 2d6ff9d

Browse files
committed
Add null masks
1 parent f337bc6 commit 2d6ff9d

File tree

7 files changed

+114
-135
lines changed

7 files changed

+114
-135
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ install_requires =
2525
requests
2626
setuptools
2727
sqlparams
28-
typing-extensions<=4.13.2
2928
wheel
3029
tomli>=1.1.0;python_version < '3.11'
30+
typing-extensions<=4.13.2;python_version < '3.11'
3131
python_requires = >=3.8
3232
include_package_data = True
3333
tests_require =

singlestoredb/functions/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from .decorator import udf # noqa: F401
2-
from .decorator import udf_with_null_masks # noqa: F401
32
from .typing import Masked # noqa: F401
43
from .typing import Table # noqa: F401

singlestoredb/functions/decorator.py

Lines changed: 5 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22
import inspect
3-
import typing
43
from typing import Any
54
from typing import Callable
65
from typing import List
@@ -10,7 +9,6 @@
109

1110
from . import utils
1211
from .dtypes import SQLString
13-
from .typing import Masked
1412

1513

1614
ParameterType = Union[
@@ -61,27 +59,6 @@ def is_valid_callable(obj: Any) -> bool:
6159
)
6260

6361

64-
def verify_mask(obj: Any) -> bool:
65-
"""Verify that the object is a tuple of two vector types."""
66-
if not typing.get_origin(obj) is Masked:
67-
raise TypeError(
68-
f'expected a Masked type, but got {type(obj)}',
69-
)
70-
return True
71-
72-
73-
def verify_masks(obj: Callable[..., Any]) -> bool:
74-
"""Verify that the function parameters and return value are all masks."""
75-
ann = utils.get_annotations(obj)
76-
for name, value in ann.items():
77-
if not verify_mask(value):
78-
raise TypeError(
79-
f'Expected a vector type for the parameter {name} '
80-
f'in function {obj.__name__}, but got {value}',
81-
)
82-
return True
83-
84-
8562
def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
8663
"""Expand the types for the function arguments / return values."""
8764
if args is None:
@@ -123,7 +100,6 @@ def _func(
123100
name: Optional[str] = None,
124101
args: Optional[ParameterType] = None,
125102
returns: Optional[ReturnType] = None,
126-
with_null_masks: bool = False,
127103
) -> Callable[..., Any]:
128104
"""Generic wrapper for UDF and TVF decorators."""
129105

@@ -132,7 +108,6 @@ def _func(
132108
name=name,
133109
args=expand_types(args),
134110
returns=expand_types(returns),
135-
with_null_masks=with_null_masks,
136111
).items() if v is not None
137112
}
138113

@@ -141,8 +116,6 @@ def _func(
141116
# in at that time.
142117
if func is None:
143118
def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
144-
if with_null_masks:
145-
verify_masks(func)
146119

147120
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
148121
return func(*args, **kwargs) # type: ignore
@@ -153,9 +126,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
153126

154127
return decorate
155128

156-
if with_null_masks:
157-
verify_masks(func)
158-
159129
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
160130
return func(*args, **kwargs) # type: ignore
161131

@@ -180,54 +150,7 @@ def udf(
180150
The UDF to apply parameters to
181151
name : str, optional
182152
The name to use for the UDF in the database
183-
args : str | Callable | List[str | Callable], optional
184-
Specifies the data types of the function arguments. Typically,
185-
the function data types are derived from the function parameter
186-
annotations. These annotations can be overridden. If the function
187-
takes a single type for all parameters, `args` can be set to a
188-
SQL string describing all parameters. If the function takes more
189-
than one parameter and all of the parameters are being manually
190-
defined, a list of SQL strings may be used (one for each parameter).
191-
A dictionary of SQL strings may be used to specify a parameter type
192-
for a subset of parameters; the keys are the names of the
193-
function parameters. Callables may also be used for datatypes. This
194-
is primarily for using the functions in the ``dtypes`` module that
195-
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
196-
returns : str, optional
197-
Specifies the return data type of the function. If not specified,
198-
the type annotation from the function is used.
199-
200-
Returns
201-
-------
202-
Callable
203-
204-
"""
205-
return _func(
206-
func=func,
207-
name=name,
208-
args=args,
209-
returns=returns,
210-
with_null_masks=False,
211-
)
212-
213-
214-
def udf_with_null_masks(
215-
func: Optional[Callable[..., Any]] = None,
216-
*,
217-
name: Optional[str] = None,
218-
args: Optional[ParameterType] = None,
219-
returns: Optional[ReturnType] = None,
220-
) -> Callable[..., Any]:
221-
"""
222-
Define a user-defined function (UDF) with null masks.
223-
224-
Parameters
225-
----------
226-
func : callable, optional
227-
The UDF to apply parameters to
228-
name : str, optional
229-
The name to use for the UDF in the database
230-
args : str | Callable | List[str | Callable], optional
153+
args : str | Type | Callable | List[str | Callable], optional
231154
Specifies the data types of the function arguments. Typically,
232155
the function data types are derived from the function parameter
233156
annotations. These annotations can be overridden. If the function
@@ -240,9 +163,10 @@ def udf_with_null_masks(
240163
function parameters. Callables may also be used for datatypes. This
241164
is primarily for using the functions in the ``dtypes`` module that
242165
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
243-
returns : str, optional
244-
Specifies the return data type of the function. If not specified,
245-
the type annotation from the function is used.
166+
returns : str | Type | Callable | List[str | Callable] | Table, optional
167+
Specifies the return data type of the function. This parameter
168+
works the same way as `args`. If the function is a table-valued
169+
function, the return type should be a `Table` object.
246170
247171
Returns
248172
-------
@@ -254,5 +178,4 @@ def udf_with_null_masks(
254178
name=name,
255179
args=args,
256180
returns=returns,
257-
with_null_masks=True,
258181
)

singlestoredb/functions/ext/asgi.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import asyncio
2727
import dataclasses
2828
import importlib.util
29+
import inspect
2930
import io
3031
import itertools
3132
import json
@@ -36,6 +37,7 @@
3637
import sys
3738
import tempfile
3839
import textwrap
40+
import typing
3941
import urllib
4042
import zipfile
4143
import zipimport
@@ -62,6 +64,7 @@
6264
from ...mysql.constants import FIELD_TYPE as ft
6365
from ..signature import get_signature
6466
from ..signature import signature_to_sql
67+
from ..typing import Masked
6568

6669
try:
6770
import cloudpickle
@@ -207,6 +210,25 @@ def get_array_class(data_format: str) -> Callable[..., Any]:
207210
return array_cls
208211

209212

213+
def get_masked_params(func: Callable[..., Any]) -> List[bool]:
214+
"""
215+
Get the list of masked parameters for the function.
216+
217+
Parameters
218+
----------
219+
func : Callable
220+
The function to call as the endpoint
221+
222+
Returns
223+
-------
224+
List[bool]
225+
Boolean list of masked parameters
226+
227+
"""
228+
params = inspect.signature(func).parameters
229+
return [typing.get_origin(x.annotation) is Masked for x in params.values()]
230+
231+
210232
def make_func(
211233
name: str,
212234
func: Callable[..., Any],
@@ -226,8 +248,6 @@ def make_func(
226248
(Callable, Dict[str, Any])
227249
228250
"""
229-
attrs = getattr(func, '_singlestoredb_attrs', {})
230-
with_null_masks = attrs.get('with_null_masks', False)
231251
info: Dict[str, Any] = {}
232252

233253
sig = get_signature(func, func_name=name)
@@ -236,6 +256,8 @@ def make_func(
236256
args_data_format = sig.get('args_data_format', 'scalar')
237257
returns_data_format = sig.get('returns_data_format', 'scalar')
238258

259+
masks = get_masked_params(func)
260+
239261
if function_type == 'tvf':
240262
# Scalar (Python) types
241263
if returns_data_format == 'scalar':
@@ -265,24 +287,21 @@ async def do_func( # type: ignore
265287
# each result row, so we just have to use the same
266288
# row ID for all rows in the result.
267289

268-
# If `with_null_masks` is set, the function is expected to return
269-
# a tuple of (data, mask) for each column.
270-
if with_null_masks:
271-
out = func(*cols)
272-
assert isinstance(out, tuple)
273-
row_ids = array_cls([row_ids[0]] * len(out[0][0]))
274-
return row_ids, [out]
290+
def build_tuple(x: Any) -> Any:
291+
return tuple(x) if isinstance(x, Masked) else (x, None)
275292

276293
# Call function on each column of data
277294
if cols and cols[0]:
278-
res = get_dataframe_columns(func(*[x[0] for x in cols]))
295+
res = get_dataframe_columns(
296+
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
297+
)
279298
else:
280299
res = get_dataframe_columns(func())
281300

282301
# Generate row IDs
283302
row_ids = array_cls([row_ids[0]] * len(res[0]))
284303

285-
return row_ids, [(x, None) for x in res]
304+
return row_ids, [build_tuple(x) for x in res]
286305

287306
else:
288307
# Scalar (Python) types
@@ -305,22 +324,22 @@ async def do_func( # type: ignore
305324
'''Call function on given cols of data.'''
306325
row_ids = array_cls(row_ids)
307326

308-
# If `with_null_masks` is set, the function is expected to return
309-
# a tuple of (data, mask) for each column.`
310-
if with_null_masks:
311-
out = func(*cols)
312-
assert isinstance(out, tuple)
313-
return row_ids, [out]
327+
def build_tuple(x: Any) -> Any:
328+
return tuple(x) if isinstance(x, Masked) else (x, None)
314329

315330
# Call the function with `cols` as the function parameters
316331
if cols and cols[0]:
317-
out = func(*[x[0] for x in cols])
332+
out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
318333
else:
319334
out = func()
320335

336+
# Single masked value
337+
if isinstance(out, Masked):
338+
return row_ids, [tuple(out)]
339+
321340
# Multiple return values
322341
if isinstance(out, tuple):
323-
return row_ids, [(x, None) for x in out]
342+
return row_ids, [build_tuple(x) for x in out]
324343

325344
# Single return value
326345
return row_ids, [(out, None)]

0 commit comments

Comments
 (0)