1313import sys
1414import traceback
1515import types
16+ import typing
1617from datetime import datetime
1718from datetime import timezone
1819from typing import Any
2223from typing import Optional
2324from typing import Tuple
2425
26+ from singlestoredb .functions .ext import rowdat_1 as _rowdat_1
2527from singlestoredb .functions .ext .rowdat_1 import dump as _dump_rowdat_1
2628from singlestoredb .functions .ext .rowdat_1 import load as _load_rowdat_1
2729from singlestoredb .functions .signature import get_signature
30+ from singlestoredb .functions .typing import Masked
2831from singlestoredb .mysql .constants import FIELD_TYPE as ft
2932
3033_accel_error : Optional [str ] = None
@@ -453,6 +456,96 @@ def _register_function(
453456 }
454457
455458
def _get_masked_params(func: Callable[..., Any]) -> List[bool]:
    """
    Flag which parameters of ``func`` are annotated as ``Masked[...]``.

    Parameters
    ----------
    func : Callable
        Function whose signature is inspected

    Returns
    -------
    List[bool]
        One entry per parameter, in signature order. An entry is True when
        the origin of the parameter's annotation is ``Masked`` (the
        parameter expects a ``(data, mask)`` tuple) and False when the
        parameter expects just the data.

    """
    flags: List[bool] = []
    for param in inspect.signature(func).parameters.values():
        # Only a subscripted ``Masked[...]`` annotation has ``Masked`` as
        # its origin; unannotated parameters yield an origin of None.
        origin = typing.get_origin(param.annotation)
        flags.append(origin is Masked)
    return flags
463+
464+
465+ def _get_vector_loader (fmt : str ) -> Callable [..., Any ]:
466+ """Return the appropriate rowdat_1 loader for the given data format."""
467+ loaders : Dict [str , str ] = {
468+ 'numpy' : 'load_numpy' ,
469+ 'pandas' : 'load_pandas' ,
470+ 'polars' : 'load_polars' ,
471+ 'arrow' : 'load_arrow' ,
472+ 'list' : 'load_list' ,
473+ }
474+ attr = loaders .get (fmt )
475+ if attr is None :
476+ raise ValueError (f'unsupported vector data format: { fmt !r} ' )
477+ return getattr (_rowdat_1 , attr )
478+
479+
480+ def _get_vector_dumper (fmt : str ) -> Callable [..., Any ]:
481+ """Return the appropriate rowdat_1 dumper for the given data format."""
482+ dumpers : Dict [str , str ] = {
483+ 'numpy' : 'dump_numpy' ,
484+ 'pandas' : 'dump_pandas' ,
485+ 'polars' : 'dump_polars' ,
486+ 'arrow' : 'dump_arrow' ,
487+ 'list' : 'dump_list' ,
488+ }
489+ attr = dumpers .get (fmt )
490+ if attr is None :
491+ raise ValueError (f'unsupported vector data format: { fmt !r} ' )
492+ return getattr (_rowdat_1 , attr )
493+
494+
495+ def _normalize_vector_output (
496+ out : Any ,
497+ num_returns : int ,
498+ ) -> List [Tuple [Any , Any ]]:
499+ """Normalize vectorized UDF output to List[(data, mask_or_None)]."""
500+ if num_returns == 1 :
501+ if isinstance (out , tuple ) and len (out ) == 2 :
502+ # Could be a Masked (data, mask) or a 2-element tuple of columns
503+ # Check if it looks like Masked: second element is a boolean mask
504+ import numpy as np
505+ if hasattr (out [1 ], 'dtype' ) and out [1 ].dtype == np .bool_ :
506+ return [out ]
507+ return [(out , None )]
508+
509+ # Multiple return columns
510+ if not isinstance (out , (tuple , list )):
511+ raise TypeError (
512+ f'vectorized UDF with { num_returns } return columns must '
513+ f'return a tuple or list, got { type (out ).__name__ } ' ,
514+ )
515+ result_cols = []
516+ for x in out :
517+ if isinstance (x , tuple ) and len (x ) == 2 :
518+ result_cols .append (x )
519+ else :
520+ result_cols .append ((x , None ))
521+ return result_cols
522+
523+
def _call_function_vector(
    func: Callable[..., Any],
    arg_types: List[Tuple[str, int]],
    return_types: List[int],
    input_data: bytes,
    args_data_format: str,
    returns_data_format: str,
    masks: List[bool],
) -> bytes:
    """
    Invoke a vectorized UDF on columnar input and serialize its result.

    Parameters
    ----------
    func : Callable
        The UDF to call with one argument per input column
    arg_types : List[Tuple[str, int]]
        Column specification handed to the loader
    return_types : List[int]
        Return column types handed to the dumper; the length is used as
        the expected number of return columns
    input_data : bytes
        Serialized input handed to the loader
    args_data_format : str
        Data format name used to select the loader
    returns_data_format : str
        Data format name used to select the dumper
    masks : List[bool]
        Per-parameter flags; True means the parameter receives the whole
        column entry, False means it receives only element 0 of it

    Returns
    -------
    bytes
        The dumper's output converted to ``bytes``

    """
    # Resolve both codecs up front so an unsupported format is rejected
    # before any input is decoded.
    load = _get_vector_loader(args_data_format)
    dump = _get_vector_dumper(returns_data_format)

    row_ids, columns = load(arg_types, input_data)

    # Masked params get the full (data, mask) tuple, others get just data
    call_args = []
    for column, wants_mask in zip(columns, masks):
        call_args.append(column if wants_mask else column[0])

    raw_result = func(*call_args)
    normalized = _normalize_vector_output(raw_result, len(return_types))

    return bytes(dump(return_types, row_ids, normalized))
547+
548+
456549def call_function (
457550 registry : FunctionRegistry ,
458551 name : str ,
@@ -470,8 +563,24 @@ def call_function(
470563 func = func_info ['func' ]
471564 arg_types = func_info ['arg_types' ]
472565 return_types = func_info ['return_types' ]
566+ sig = func_info ['signature' ]
567+
568+ args_data_format = sig .get ('args_data_format' ) or 'scalar'
569+ returns_data_format = sig .get ('returns_data_format' ) or 'scalar'
473570
474571 try :
572+ # Vector path: columnar processing
573+ if args_data_format not in ('scalar' ,):
574+ masks = func_info .get ('_masks' )
575+ if masks is None :
576+ masks = _get_masked_params (func )
577+ func_info ['_masks' ] = masks
578+ return _call_function_vector (
579+ func , arg_types , return_types , input_data ,
580+ args_data_format , returns_data_format , masks ,
581+ )
582+
583+ # Scalar path: row-by-row processing
475584 if _has_accel :
476585 return _call_function_accel (
477586 colspec = arg_types ,
0 commit comments