@@ -246,6 +246,192 @@ def get_masked_params(func: Callable[..., Any]) -> List[bool]:
246246 return [typing .get_origin (x .annotation ) is Masked for x in params .values ()]
247247
248248
249+ def build_tuple (x : Any ) -> Any :
250+ """Convert object to tuple."""
251+ return tuple (x ) if isinstance (x , Masked ) else (x , None )
252+
253+
254+ def build_udf_endpoint (
255+ func : Callable [..., Any ],
256+ returns_data_format : str ,
257+ ) -> Callable [..., Any ]:
258+ """
259+ Build a UDF endpoint for scalar / list types (row-based).
260+
261+ Parameters
262+ ----------
263+ func : Callable
264+ The function to call as the endpoint
265+ returns_data_format : str
266+ The format of the return values
267+
268+ Returns
269+ -------
270+ Callable
271+ The function endpoint
272+
273+ """
274+ if returns_data_format in ['scalar' , 'list' ]:
275+
276+ async def do_func (
277+ row_ids : Sequence [int ],
278+ rows : Sequence [Sequence [Any ]],
279+ ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
280+ '''Call function on given rows of data.'''
281+ return row_ids , [as_tuple (x ) for x in zip (func_map (func , rows ))]
282+
283+ return do_func
284+
285+ return build_vector_udf_endpoint (func , returns_data_format )
286+
287+
288+ def build_vector_udf_endpoint (
289+ func : Callable [..., Any ],
290+ returns_data_format : str ,
291+ ) -> Callable [..., Any ]:
292+ """
293+ Build a UDF endpoint for vector formats (column-based).
294+
295+ Parameters
296+ ----------
297+ func : Callable
298+ The function to call as the endpoint
299+ returns_data_format : str
300+ The format of the return values
301+
302+ Returns
303+ -------
304+ Callable
305+ The function endpoint
306+
307+ """
308+ masks = get_masked_params (func )
309+ array_cls = get_array_class (returns_data_format )
310+
311+ async def do_func (
312+ row_ids : Sequence [int ],
313+ cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
314+ ) -> Tuple [
315+ Sequence [int ],
316+ List [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
317+ ]:
318+ '''Call function on given columns of data.'''
319+ row_ids = array_cls (row_ids )
320+
321+ # Call the function with `cols` as the function parameters
322+ if cols and cols [0 ]:
323+ out = func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
324+ else :
325+ out = func ()
326+
327+ # Single masked value
328+ if isinstance (out , Masked ):
329+ return row_ids , [tuple (out )]
330+
331+ # Multiple return values
332+ if isinstance (out , tuple ):
333+ return row_ids , [build_tuple (x ) for x in out ]
334+
335+ # Single return value
336+ return row_ids , [(out , None )]
337+
338+ return do_func
339+
340+
341+ def build_tvf_endpoint (
342+ func : Callable [..., Any ],
343+ returns_data_format : str ,
344+ ) -> Callable [..., Any ]:
345+ """
346+ Build a TVF endpoint for scalar / list types (row-based).
347+
348+ Parameters
349+ ----------
350+ func : Callable
351+ The function to call as the endpoint
352+ returns_data_format : str
353+ The format of the return values
354+
355+ Returns
356+ -------
357+ Callable
358+ The function endpoint
359+
360+ """
361+ if returns_data_format in ['scalar' , 'list' ]:
362+
363+ async def do_func (
364+ row_ids : Sequence [int ],
365+ rows : Sequence [Sequence [Any ]],
366+ ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
367+ '''Call function on given rows of data.'''
368+ out_ids : List [int ] = []
369+ out = []
370+ # Call function on each row of data
371+ for i , res in zip (row_ids , func_map (func , rows )):
372+ out .extend (as_list_of_tuples (res ))
373+ out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
374+ return out_ids , out
375+
376+ return do_func
377+
378+ return build_vector_tvf_endpoint (func , returns_data_format )
379+
380+
381+ def build_vector_tvf_endpoint (
382+ func : Callable [..., Any ],
383+ returns_data_format : str ,
384+ ) -> Callable [..., Any ]:
385+ """
386+ Build a TVF endpoint for vector formats (column-based).
387+
388+ Parameters
389+ ----------
390+ func : Callable
391+ The function to call as the endpoint
392+ returns_data_format : str
393+ The format of the return values
394+
395+ Returns
396+ -------
397+ Callable
398+ The function endpoint
399+
400+ """
401+ masks = get_masked_params (func )
402+ array_cls = get_array_class (returns_data_format )
403+
404+ async def do_func (
405+ row_ids : Sequence [int ],
406+ cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
407+ ) -> Tuple [
408+ Sequence [int ],
409+ List [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
410+ ]:
411+ '''Call function on given columns of data.'''
412+ # NOTE: There is no way to determine which row ID belongs to
413+ # each result row, so we just have to use the same
414+ # row ID for all rows in the result.
415+
416+ # Call function on each column of data
417+ if cols and cols [0 ]:
418+ res = get_dataframe_columns (
419+ func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
420+ )
421+ else :
422+ res = get_dataframe_columns (func ())
423+
424+ # Generate row IDs
425+ if isinstance (res [0 ], Masked ):
426+ row_ids = array_cls ([row_ids [0 ]] * len (res [0 ][0 ]))
427+ else :
428+ row_ids = array_cls ([row_ids [0 ]] * len (res [0 ]))
429+
430+ return row_ids , [build_tuple (x ) for x in res ]
431+
432+ return do_func
433+
434+
249435def make_func (
250436 name : str ,
251437 func : Callable [..., Any ],
@@ -273,102 +459,10 @@ def make_func(
273459 args_data_format = sig .get ('args_data_format' , 'scalar' )
274460 returns_data_format = sig .get ('returns_data_format' , 'scalar' )
275461
276- masks = get_masked_params (func )
277-
278462 if function_type == 'tvf' :
279- # Scalar / list types (row-based)
280- if returns_data_format in ['scalar' , 'list' ]:
281- async def do_func (
282- row_ids : Sequence [int ],
283- rows : Sequence [Sequence [Any ]],
284- ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
285- '''Call function on given rows of data.'''
286- out_ids : List [int ] = []
287- out = []
288- # Call function on each row of data
289- for i , res in zip (row_ids , func_map (func , rows )):
290- out .extend (as_list_of_tuples (res ))
291- out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
292- return out_ids , out
293-
294- # Vector formats (column-based)
295- else :
296- array_cls = get_array_class (returns_data_format )
297-
298- async def do_func ( # type: ignore
299- row_ids : Sequence [int ],
300- cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
301- ) -> Tuple [
302- Sequence [int ],
303- List [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
304- ]:
305- '''Call function on given cols of data.'''
306- # NOTE: There is no way to determine which row ID belongs to
307- # each result row, so we just have to use the same
308- # row ID for all rows in the result.
309-
310- def build_tuple (x : Any ) -> Any :
311- return tuple (x ) if isinstance (x , Masked ) else (x , None )
312-
313- # Call function on each column of data
314- if cols and cols [0 ]:
315- res = get_dataframe_columns (
316- func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
317- )
318- else :
319- res = get_dataframe_columns (func ())
320-
321- # Generate row IDs
322- if isinstance (res [0 ], Masked ):
323- row_ids = array_cls ([row_ids [0 ]] * len (res [0 ][0 ]))
324- else :
325- row_ids = array_cls ([row_ids [0 ]] * len (res [0 ]))
326-
327- return row_ids , [build_tuple (x ) for x in res ]
328-
463+ do_func = build_tvf_endpoint (func , returns_data_format )
329464 else :
330- # Scalar / list types (row-based)
331- if returns_data_format in ['scalar' , 'list' ]:
332- async def do_func (
333- row_ids : Sequence [int ],
334- rows : Sequence [Sequence [Any ]],
335- ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
336- '''Call function on given rows of data.'''
337- return row_ids , [as_tuple (x ) for x in zip (func_map (func , rows ))]
338-
339- # Vector formats (column-based)
340- else :
341- array_cls = get_array_class (returns_data_format )
342-
343- async def do_func ( # type: ignore
344- row_ids : Sequence [int ],
345- cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
346- ) -> Tuple [
347- Sequence [int ],
348- List [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
349- ]:
350- '''Call function on given cols of data.'''
351- row_ids = array_cls (row_ids )
352-
353- def build_tuple (x : Any ) -> Any :
354- return tuple (x ) if isinstance (x , Masked ) else (x , None )
355-
356- # Call the function with `cols` as the function parameters
357- if cols and cols [0 ]:
358- out = func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
359- else :
360- out = func ()
361-
362- # Single masked value
363- if isinstance (out , Masked ):
364- return row_ids , [tuple (out )]
365-
366- # Multiple return values
367- if isinstance (out , tuple ):
368- return row_ids , [build_tuple (x ) for x in out ]
369-
370- # Single return value
371- return row_ids , [(out , None )]
465+ do_func = build_udf_endpoint (func , returns_data_format )
372466
373467 do_func .__name__ = name
374468 do_func .__doc__ = func .__doc__
0 commit comments