Skip to content

Commit 114c73a

Browse files
committed
Big refactoring of the way parameters / return values work
1 parent edc6893 commit 114c73a

File tree

7 files changed

+914
-537
lines changed

7 files changed

+914
-537
lines changed

accel.c

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#define NUMPY_DATETIME 13
3737
#define NUMPY_OBJECT 14
3838
#define NUMPY_BYTES 15
39+
#define NUMPY_FIXED_STRING 16
3940

4041
#define MYSQL_FLAG_NOT_NULL 1
4142
#define MYSQL_FLAG_PRI_KEY 2
@@ -2745,6 +2746,13 @@ static NumpyColType get_numpy_col_type(PyObject *py_array) {
27452746
goto error;
27462747
}
27472748
break;
2749+
case 'U':
2750+
out.type = NUMPY_FIXED_STRING;
2751+
out.length = (Py_ssize_t)strtol(str + 2, NULL, 10);
2752+
if (out.length < 0) {
2753+
goto error;
2754+
}
2755+
break;
27482756
default:
27492757
goto error;
27502758
}
@@ -3845,7 +3853,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
38453853
case MYSQL_TYPE_MEDIUM_BLOB:
38463854
case MYSQL_TYPE_LONG_BLOB:
38473855
case MYSQL_TYPE_BLOB:
3848-
if (col_types[i].type != NUMPY_OBJECT) {
3856+
if (col_types[i].type != NUMPY_OBJECT && col_types[i].type != NUMPY_FIXED_STRING) {
38493857
PyErr_SetString(PyExc_ValueError, "unsupported numpy data type for character output types");
38503858
goto error;
38513859
}
@@ -3856,6 +3864,24 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
38563864
memcpy(out+out_idx, &i64, 8);
38573865
out_idx += 8;
38583866

3867+
} else if (col_types[i].type == NUMPY_FIXED_STRING) {
3868+
void *bytes = (void*)(cols[i] + j * 8);
3869+
3870+
if (bytes == NULL) {
3871+
CHECKMEM(8);
3872+
i64 = 0;
3873+
memcpy(out+out_idx, &i64, 8);
3874+
out_idx += 8;
3875+
} else {
3876+
Py_ssize_t str_l = strnlen(bytes, col_types[i].length);
3877+
CHECKMEM(8+str_l);
3878+
i64 = str_l;
3879+
memcpy(out+out_idx, &i64, 8);
3880+
out_idx += 8;
3881+
memcpy(out+out_idx, bytes, str_l);
3882+
out_idx += str_l;
3883+
}
3884+
38593885
} else {
38603886
u64 = *(uint64_t*)(cols[i] + j * 8);
38613887

Lines changed: 41 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -1,142 +1,56 @@
1-
import dataclasses
2-
import datetime
1+
from __future__ import annotations
2+
33
import functools
44
import inspect
55
from typing import Any
66
from typing import Callable
7-
from typing import Dict
87
from typing import List
98
from typing import Optional
10-
from typing import Tuple
9+
from typing import Type
1110
from typing import Union
1211

13-
from . import dtypes
14-
from .dtypes import DataType
15-
from .signature import simplify_dtype
16-
17-
try:
18-
import pydantic
19-
has_pydantic = True
20-
except ImportError:
21-
has_pydantic = False
22-
23-
python_type_map: Dict[Any, Callable[..., str]] = {
24-
str: dtypes.TEXT,
25-
int: dtypes.BIGINT,
26-
float: dtypes.DOUBLE,
27-
bool: dtypes.BOOL,
28-
bytes: dtypes.BINARY,
29-
bytearray: dtypes.BINARY,
30-
datetime.datetime: dtypes.DATETIME,
31-
datetime.date: dtypes.DATE,
32-
datetime.timedelta: dtypes.TIME,
33-
}
34-
35-
36-
def listify(x: Any) -> List[Any]:
37-
"""Make sure sure value is a list."""
38-
if x is None:
39-
return []
40-
if isinstance(x, (list, tuple, set)):
41-
return list(x)
42-
return [x]
43-
44-
45-
def process_annotation(annotation: Any) -> Tuple[Any, bool]:
46-
types = simplify_dtype(annotation)
47-
if isinstance(types, list):
48-
nullable = False
49-
if type(None) in types:
50-
nullable = True
51-
types = [x for x in types if x is not type(None)]
52-
if len(types) > 1:
53-
raise ValueError(f'multiple types not supported: {annotation}')
54-
return types[0], nullable
55-
return types, True
5612

13+
ParameterType = Union[
14+
str,
15+
Callable[..., str],
16+
List[Union[str, Callable[..., str]]],
17+
Type[Any],
18+
]
5719

58-
def process_types(params: Any) -> Any:
59-
if params is None:
60-
return params, []
61-
62-
elif isinstance(params, (list, tuple)):
63-
params = list(params)
64-
for i, item in enumerate(params):
65-
if params[i] in python_type_map:
66-
params[i] = python_type_map[params[i]]()
67-
elif callable(item):
68-
params[i] = item()
69-
for item in params:
70-
if not isinstance(item, str):
71-
raise TypeError(f'unrecognized type for parameter: {item}')
72-
return params, []
73-
74-
elif isinstance(params, dict):
75-
names = []
76-
params = dict(params)
77-
for k, v in list(params.items()):
78-
names.append(k)
79-
if params[k] in python_type_map:
80-
params[k] = python_type_map[params[k]]()
81-
elif callable(v):
82-
params[k] = v()
83-
for item in params.values():
84-
if not isinstance(item, str):
85-
raise TypeError(f'unrecognized type for parameter: {item}')
86-
return params, names
87-
88-
elif dataclasses.is_dataclass(params):
89-
names = []
90-
out = []
91-
for item in dataclasses.fields(params):
92-
typ, nullable = process_annotation(item.type)
93-
sql_type = process_types(typ)[0]
94-
if not nullable:
95-
sql_type = sql_type.replace('NULL', 'NOT NULL')
96-
out.append(sql_type)
97-
names.append(item.name)
98-
return out, names
99-
100-
elif has_pydantic and inspect.isclass(params) \
101-
and issubclass(params, pydantic.BaseModel):
102-
names = []
103-
out = []
104-
for name, item in params.model_fields.items():
105-
typ, nullable = process_annotation(item.annotation)
106-
sql_type = process_types(typ)[0]
107-
if not nullable:
108-
sql_type = sql_type.replace('NULL', 'NOT NULL')
109-
out.append(sql_type)
110-
names.append(name)
111-
return out, names
112-
113-
elif params in python_type_map:
114-
return python_type_map[params](), []
20+
ReturnType = ParameterType
11521

116-
elif callable(params):
117-
return params(), []
11822

119-
elif isinstance(params, str):
120-
return params, []
23+
def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
24+
"""Expand the types for the function arguments / return values."""
25+
if args is None:
26+
return None
12127

122-
raise TypeError(f'unrecognized data type for args: {params}')
28+
# SQL string
29+
if isinstance(args, str):
30+
return [args]
12331

32+
# General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
33+
elif inspect.isclass(args):
34+
return args
12435

125-
ParameterType = Union[
126-
str,
127-
List[str],
128-
Dict[str, str],
129-
'pydantic.BaseModel',
130-
type,
131-
]
36+
# Callable that returns a SQL string
37+
elif callable(args):
38+
out = args()
39+
if not isinstance(out, str):
40+
raise TypeError(f'unrecognized type for parameter: {args}')
41+
return [out]
13242

133-
ReturnType = Union[
134-
str,
135-
List[DataType],
136-
List[type],
137-
'pydantic.BaseModel',
138-
type,
139-
]
43+
# List of SQL strings or callables
44+
else:
45+
new_args = []
46+
for arg in args:
47+
if isinstance(arg, str):
48+
new_args.append(arg)
49+
elif callable(arg):
50+
new_args.append(arg())
51+
else:
52+
raise TypeError(f'unrecognized type for parameter: {arg}')
53+
return new_args
14054

14155

14256
def _func(
@@ -145,40 +59,18 @@ def _func(
14559
name: Optional[str] = None,
14660
args: Optional[ParameterType] = None,
14761
returns: Optional[ReturnType] = None,
148-
data_format: Optional[str] = None,
14962
include_masks: bool = False,
15063
function_type: str = 'udf',
151-
output_fields: Optional[List[str]] = None,
15264
) -> Callable[..., Any]:
15365
"""Generic wrapper for UDF and TVF decorators."""
154-
args, _ = process_types(args)
155-
returns, fields = process_types(returns)
156-
157-
if not output_fields and fields:
158-
output_fields = fields
159-
160-
if isinstance(returns, list) \
161-
and isinstance(output_fields, list) \
162-
and len(output_fields) != len(returns):
163-
raise ValueError(
164-
'The number of output fields must match the number of return types',
165-
)
166-
167-
if include_masks and data_format == 'python':
168-
raise RuntimeError(
169-
'include_masks is only valid when using '
170-
'vectors for input parameters',
171-
)
17266

17367
_singlestoredb_attrs = { # type: ignore
17468
k: v for k, v in dict(
17569
name=name,
176-
args=args,
177-
returns=returns,
178-
data_format=data_format,
70+
args=expand_types(args),
71+
returns=expand_types(returns),
17972
include_masks=include_masks,
18073
function_type=function_type,
181-
output_fields=output_fields or None,
18274
).items() if v is not None
18375
}
18476

@@ -207,7 +99,6 @@ def udf(
20799
name: Optional[str] = None,
208100
args: Optional[ParameterType] = None,
209101
returns: Optional[ReturnType] = None,
210-
data_format: Optional[str] = None,
211102
include_masks: bool = False,
212103
) -> Callable[..., Any]:
213104
"""
@@ -219,7 +110,7 @@ def udf(
219110
The UDF to apply parameters to
220111
name : str, optional
221112
The name to use for the UDF in the database
222-
args : str | Callable | List[str | Callable] | Dict[str, str | Callable], optional
113+
args : str | Callable | List[str | Callable], optional
223114
Specifies the data types of the function arguments. Typically,
224115
the function data types are derived from the function parameter
225116
annotations. These annotations can be overridden. If the function
@@ -235,8 +126,6 @@ def udf(
235126
returns : str, optional
236127
Specifies the return data type of the function. If not specified,
237128
the type annotation from the function is used.
238-
data_format : str, optional
239-
The data format of each parameter: python, pandas, arrow, polars
240129
include_masks : bool, optional
241130
Should boolean masks be included with each input parameter to indicate
242131
which elements are NULL? This is only used when input parameters are
@@ -252,27 +141,18 @@ def udf(
252141
name=name,
253142
args=args,
254143
returns=returns,
255-
data_format=data_format,
256144
include_masks=include_masks,
257145
function_type='udf',
258146
)
259147

260148

261-
udf.pandas = functools.partial(udf, data_format='pandas') # type: ignore
262-
udf.polars = functools.partial(udf, data_format='polars') # type: ignore
263-
udf.arrow = functools.partial(udf, data_format='arrow') # type: ignore
264-
udf.numpy = functools.partial(udf, data_format='numpy') # type: ignore
265-
266-
267149
def tvf(
268150
func: Optional[Callable[..., Any]] = None,
269151
*,
270152
name: Optional[str] = None,
271153
args: Optional[ParameterType] = None,
272154
returns: Optional[ReturnType] = None,
273-
data_format: Optional[str] = None,
274155
include_masks: bool = False,
275-
output_fields: Optional[List[str]] = None,
276156
) -> Callable[..., Any]:
277157
"""
278158
Apply attributes to a TVF.
@@ -283,7 +163,7 @@ def tvf(
283163
The TVF to apply parameters to
284164
name : str, optional
285165
The name to use for the TVF in the database
286-
args : str | Callable | List[str | Callable] | Dict[str, str | Callable], optional
166+
args : str | Callable | List[str | Callable], optional
287167
Specifies the data types of the function arguments. Typically,
288168
the function data types are derived from the function parameter
289169
annotations. These annotations can be overridden. If the function
@@ -299,15 +179,10 @@ def tvf(
299179
returns : str, optional
300180
Specifies the return data type of the function. If not specified,
301181
the type annotation from the function is used.
302-
data_format : str, optional
303-
The data format of each parameter: python, pandas, arrow, polars
304182
include_masks : bool, optional
305183
Should boolean masks be included with each input parameter to indicate
306184
which elements are NULL? This is only used when input parameters are
307185
configured to a vector type (numpy, pandas, polars, arrow).
308-
output_fields : List[str], optional
309-
The names of the output fields for the TVF. If not specified, the
310-
names are generated.
311186
312187
Returns
313188
-------
@@ -319,14 +194,6 @@ def tvf(
319194
name=name,
320195
args=args,
321196
returns=returns,
322-
data_format=data_format,
323197
include_masks=include_masks,
324198
function_type='tvf',
325-
output_fields=output_fields,
326199
)
327-
328-
329-
tvf.pandas = functools.partial(tvf, data_format='pandas') # type: ignore
330-
tvf.polars = functools.partial(tvf, data_format='polars') # type: ignore
331-
tvf.arrow = functools.partial(tvf, data_format='arrow') # type: ignore
332-
tvf.numpy = functools.partial(tvf, data_format='numpy') # type: ignore

0 commit comments

Comments
 (0)