Skip to content

Commit c4c865f

Browse files
committed
Code cleanup
1 parent dd3f041 commit c4c865f

File tree

1 file changed

+188
-94
lines changed

1 file changed

+188
-94
lines changed

singlestoredb/functions/ext/asgi.py

Lines changed: 188 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,192 @@ def get_masked_params(func: Callable[..., Any]) -> List[bool]:
246246
return [typing.get_origin(x.annotation) is Masked for x in params.values()]
247247

248248

249+
def build_tuple(x: Any) -> Any:
    """
    Normalize a value into tuple form.

    Parameters
    ----------
    x : Any
        The value to convert. ``Masked`` instances are expanded with
        ``tuple(x)``; anything else is paired with ``None``.

    Returns
    -------
    tuple
        ``tuple(x)`` when ``x`` is a ``Masked`` instance, otherwise
        ``(x, None)``.

    """
    if isinstance(x, Masked):
        return tuple(x)
    # Non-masked values get ``None`` in the second slot
    # (presumably "no mask" -- confirm against the Masked definition).
    return (x, None)
252+
253+
254+
def build_udf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Create the async endpoint that wraps a UDF.

    For the row-based formats (``'scalar'`` and ``'list'``) the endpoint
    is built here; every other format is delegated to
    ``build_vector_udf_endpoint``.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    # Anything that is not row-based is handled by the vector builder.
    if returns_data_format not in ['scalar', 'list']:
        return build_vector_udf_endpoint(func, returns_data_format)

    async def do_func(
        row_ids: Sequence[int],
        rows: Sequence[Sequence[Any]],
    ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
        '''Call function on given rows of data.'''
        results = []
        # zip() over a single iterable wraps each result in a 1-tuple
        # before it is handed to as_tuple.
        for item in zip(func_map(func, rows)):
            results.append(as_tuple(item))
        return row_ids, results

    return do_func
286+
287+
288+
def build_vector_udf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Create the async endpoint that wraps a column-based (vector) UDF.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values; used to look up the array
        class that the row IDs are converted to

    Returns
    -------
    Callable
        The function endpoint

    """
    masks = get_masked_params(func)
    array_cls = get_array_class(returns_data_format)

    async def do_func(
        row_ids: Sequence[int],
        cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ) -> Tuple[
        Sequence[int],
        List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ]:
        '''Call function on given columns of data.'''
        row_ids = array_cls(row_ids)

        # Build the positional arguments: parameters flagged in `masks`
        # receive the whole column entry, the others only its first item.
        # With no usable columns the function is called without arguments.
        args: List[Any] = []
        if cols and cols[0]:
            args = [
                col if is_masked else col[0]
                for col, is_masked in zip(cols, masks)
            ]
        out = func(*args)

        if isinstance(out, Masked):
            # Single masked value
            columns = [tuple(out)]
        elif isinstance(out, tuple):
            # Multiple return values
            columns = [build_tuple(value) for value in out]
        else:
            # Single return value
            columns = [(out, None)]

        return row_ids, columns

    return do_func
339+
340+
341+
def build_tvf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Build a TVF endpoint for scalar / list types (row-based).

    Formats other than ``'scalar'`` and ``'list'`` are delegated to
    ``build_vector_tvf_endpoint``.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    if returns_data_format in ['scalar', 'list']:

        async def do_func(
            row_ids: Sequence[int],
            rows: Sequence[Sequence[Any]],
        ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
            '''Call function on given rows of data.'''
            out_ids: List[int] = []
            out = []
            # Call function on each row of data
            for row_id, res in zip(row_ids, func_map(func, rows)):
                out.extend(as_list_of_tuples(res))
                # `row_id` is the ID itself (an element of `row_ids`), so
                # it must not be used as an index into `row_ids`; doing so
                # is only correct when the IDs happen to be 0..n-1.
                # Repeat it once per output row this input row produced.
                out_ids.extend([row_id] * (len(out) - len(out_ids)))
            return out_ids, out

        return do_func

    return build_vector_tvf_endpoint(func, returns_data_format)
379+
380+
381+
def build_vector_tvf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Create the async endpoint that wraps a column-based (vector) TVF.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values; used to look up the array
        class that the generated row IDs are stored in

    Returns
    -------
    Callable
        The function endpoint

    """
    masks = get_masked_params(func)
    array_cls = get_array_class(returns_data_format)

    async def do_func(
        row_ids: Sequence[int],
        cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ) -> Tuple[
        Sequence[int],
        List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ]:
        '''Call function on given columns of data.'''
        # NOTE: There is no way to determine which row ID belongs to
        #       each result row, so we just have to use the same
        #       row ID for all rows in the result.

        # Parameters flagged in `masks` receive the whole column entry,
        # the others only its first item; with no usable columns the
        # function is called without arguments.
        args: List[Any] = []
        if cols and cols[0]:
            args = [
                col if is_masked else col[0]
                for col, is_masked in zip(cols, masks)
            ]
        res = get_dataframe_columns(func(*args))

        # The first result column determines how many row IDs to emit.
        first_col = res[0]
        first_is_masked = isinstance(first_col, Masked)
        first_id = row_ids[0]
        if first_is_masked:
            n_rows = len(first_col[0])
        else:
            n_rows = len(first_col)
        row_ids = array_cls([first_id] * n_rows)

        return row_ids, [build_tuple(col) for col in res]

    return do_func
433+
434+
249435
def make_func(
250436
name: str,
251437
func: Callable[..., Any],
@@ -273,102 +459,10 @@ def make_func(
273459
args_data_format = sig.get('args_data_format', 'scalar')
274460
returns_data_format = sig.get('returns_data_format', 'scalar')
275461

276-
masks = get_masked_params(func)
277-
278462
if function_type == 'tvf':
279-
# Scalar / list types (row-based)
280-
if returns_data_format in ['scalar', 'list']:
281-
async def do_func(
282-
row_ids: Sequence[int],
283-
rows: Sequence[Sequence[Any]],
284-
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
285-
'''Call function on given rows of data.'''
286-
out_ids: List[int] = []
287-
out = []
288-
# Call function on each row of data
289-
for i, res in zip(row_ids, func_map(func, rows)):
290-
out.extend(as_list_of_tuples(res))
291-
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
292-
return out_ids, out
293-
294-
# Vector formats (column-based)
295-
else:
296-
array_cls = get_array_class(returns_data_format)
297-
298-
async def do_func( # type: ignore
299-
row_ids: Sequence[int],
300-
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
301-
) -> Tuple[
302-
Sequence[int],
303-
List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
304-
]:
305-
'''Call function on given cols of data.'''
306-
# NOTE: There is no way to determine which row ID belongs to
307-
# each result row, so we just have to use the same
308-
# row ID for all rows in the result.
309-
310-
def build_tuple(x: Any) -> Any:
311-
return tuple(x) if isinstance(x, Masked) else (x, None)
312-
313-
# Call function on each column of data
314-
if cols and cols[0]:
315-
res = get_dataframe_columns(
316-
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
317-
)
318-
else:
319-
res = get_dataframe_columns(func())
320-
321-
# Generate row IDs
322-
if isinstance(res[0], Masked):
323-
row_ids = array_cls([row_ids[0]] * len(res[0][0]))
324-
else:
325-
row_ids = array_cls([row_ids[0]] * len(res[0]))
326-
327-
return row_ids, [build_tuple(x) for x in res]
328-
463+
do_func = build_tvf_endpoint(func, returns_data_format)
329464
else:
330-
# Scalar / list types (row-based)
331-
if returns_data_format in ['scalar', 'list']:
332-
async def do_func(
333-
row_ids: Sequence[int],
334-
rows: Sequence[Sequence[Any]],
335-
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
336-
'''Call function on given rows of data.'''
337-
return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))]
338-
339-
# Vector formats (column-based)
340-
else:
341-
array_cls = get_array_class(returns_data_format)
342-
343-
async def do_func( # type: ignore
344-
row_ids: Sequence[int],
345-
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
346-
) -> Tuple[
347-
Sequence[int],
348-
List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
349-
]:
350-
'''Call function on given cols of data.'''
351-
row_ids = array_cls(row_ids)
352-
353-
def build_tuple(x: Any) -> Any:
354-
return tuple(x) if isinstance(x, Masked) else (x, None)
355-
356-
# Call the function with `cols` as the function parameters
357-
if cols and cols[0]:
358-
out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
359-
else:
360-
out = func()
361-
362-
# Single masked value
363-
if isinstance(out, Masked):
364-
return row_ids, [tuple(out)]
365-
366-
# Multiple return values
367-
if isinstance(out, tuple):
368-
return row_ids, [build_tuple(x) for x in out]
369-
370-
# Single return value
371-
return row_ids, [(out, None)]
465+
do_func = build_udf_endpoint(func, returns_data_format)
372466

373467
do_func.__name__ = name
374468
do_func.__doc__ = func.__doc__

0 commit comments

Comments
 (0)