Skip to content

Commit 3a5c409

Browse files
kesmit13claude
andcommitted
Add vector/columnar UDF dispatch to plugin call_function()
Detect vectorized functions (numpy, pandas, polars, arrow) via args_data_format in the function signature and route through the existing C-accelerated load_rowdat_1_numpy/dump_rowdat_1_numpy infrastructure instead of the per-row scalar path. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 45d79c8 commit 3a5c409

1 file changed

Lines changed: 109 additions & 0 deletions

File tree

singlestoredb/functions/ext/plugin/registry.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414
import traceback
1515
import types
16+
import typing
1617
from datetime import datetime
1718
from datetime import timezone
1819
from typing import Any
@@ -22,9 +23,11 @@
2223
from typing import Optional
2324
from typing import Tuple
2425

26+
from singlestoredb.functions.ext import rowdat_1 as _rowdat_1
2527
from singlestoredb.functions.ext.rowdat_1 import dump as _dump_rowdat_1
2628
from singlestoredb.functions.ext.rowdat_1 import load as _load_rowdat_1
2729
from singlestoredb.functions.signature import get_signature
30+
from singlestoredb.functions.typing import Masked
2831
from singlestoredb.mysql.constants import FIELD_TYPE as ft
2932

3033
_accel_error: Optional[str] = None
@@ -453,6 +456,96 @@ def _register_function(
453456
}
454457

455458

459+
def _get_masked_params(func: Callable[..., Any]) -> List[bool]:
460+
"""Determine which parameters expect (data, mask) tuples vs just data."""
461+
params = inspect.signature(func).parameters
462+
return [typing.get_origin(x.annotation) is Masked for x in params.values()]
463+
464+
465+
def _get_vector_loader(fmt: str) -> Callable[..., Any]:
466+
"""Return the appropriate rowdat_1 loader for the given data format."""
467+
loaders: Dict[str, str] = {
468+
'numpy': 'load_numpy',
469+
'pandas': 'load_pandas',
470+
'polars': 'load_polars',
471+
'arrow': 'load_arrow',
472+
'list': 'load_list',
473+
}
474+
attr = loaders.get(fmt)
475+
if attr is None:
476+
raise ValueError(f'unsupported vector data format: {fmt!r}')
477+
return getattr(_rowdat_1, attr)
478+
479+
480+
def _get_vector_dumper(fmt: str) -> Callable[..., Any]:
481+
"""Return the appropriate rowdat_1 dumper for the given data format."""
482+
dumpers: Dict[str, str] = {
483+
'numpy': 'dump_numpy',
484+
'pandas': 'dump_pandas',
485+
'polars': 'dump_polars',
486+
'arrow': 'dump_arrow',
487+
'list': 'dump_list',
488+
}
489+
attr = dumpers.get(fmt)
490+
if attr is None:
491+
raise ValueError(f'unsupported vector data format: {fmt!r}')
492+
return getattr(_rowdat_1, attr)
493+
494+
495+
def _normalize_vector_output(
496+
out: Any,
497+
num_returns: int,
498+
) -> List[Tuple[Any, Any]]:
499+
"""Normalize vectorized UDF output to List[(data, mask_or_None)]."""
500+
if num_returns == 1:
501+
if isinstance(out, tuple) and len(out) == 2:
502+
# Could be a Masked (data, mask) or a 2-element tuple of columns
503+
# Check if it looks like Masked: second element is a boolean mask
504+
import numpy as np
505+
if hasattr(out[1], 'dtype') and out[1].dtype == np.bool_:
506+
return [out]
507+
return [(out, None)]
508+
509+
# Multiple return columns
510+
if not isinstance(out, (tuple, list)):
511+
raise TypeError(
512+
f'vectorized UDF with {num_returns} return columns must '
513+
f'return a tuple or list, got {type(out).__name__}',
514+
)
515+
result_cols = []
516+
for x in out:
517+
if isinstance(x, tuple) and len(x) == 2:
518+
result_cols.append(x)
519+
else:
520+
result_cols.append((x, None))
521+
return result_cols
522+
523+
524+
def _call_function_vector(
525+
func: Callable[..., Any],
526+
arg_types: List[Tuple[str, int]],
527+
return_types: List[int],
528+
input_data: bytes,
529+
args_data_format: str,
530+
returns_data_format: str,
531+
masks: List[bool],
532+
) -> bytes:
533+
"""Call a vectorized UDF with columnar data."""
534+
loader = _get_vector_loader(args_data_format)
535+
dumper = _get_vector_dumper(returns_data_format)
536+
537+
row_ids, cols = loader(arg_types, input_data)
538+
539+
# Masked params get the full (data, mask) tuple, others get just data
540+
func_args = [col if m else col[0] for col, m in zip(cols, masks)]
541+
542+
out = func(*func_args)
543+
544+
result_cols = _normalize_vector_output(out, len(return_types))
545+
546+
return bytes(dumper(return_types, row_ids, result_cols))
547+
548+
456549
def call_function(
457550
registry: FunctionRegistry,
458551
name: str,
@@ -470,8 +563,24 @@ def call_function(
470563
func = func_info['func']
471564
arg_types = func_info['arg_types']
472565
return_types = func_info['return_types']
566+
sig = func_info['signature']
567+
568+
args_data_format = sig.get('args_data_format') or 'scalar'
569+
returns_data_format = sig.get('returns_data_format') or 'scalar'
473570

474571
try:
572+
# Vector path: columnar processing
573+
if args_data_format not in ('scalar',):
574+
masks = func_info.get('_masks')
575+
if masks is None:
576+
masks = _get_masked_params(func)
577+
func_info['_masks'] = masks
578+
return _call_function_vector(
579+
func, arg_types, return_types, input_data,
580+
args_data_format, returns_data_format, masks,
581+
)
582+
583+
# Scalar path: row-by-row processing
475584
if _has_accel:
476585
return _call_function_accel(
477586
colspec=arg_types,

0 commit comments

Comments
 (0)