Skip to content

Commit 65b180b

Browse files
Merge branch 'udf-apis' into users/snarayanan/udfs
2 parents 7ccbd14 + 52dd5cd commit 65b180b

File tree

4 files changed

+137
-8
lines changed

4 files changed

+137
-8
lines changed

singlestoredb/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,18 @@
407407
environ=['SINGLESTOREDB_EXT_FUNC_LOG_LEVEL'],
408408
)
409409

410+
register_option(
411+
'external_function.name_prefix', 'string', check_str, '',
412+
'Prefix to add to external function names.',
413+
environ=['SINGLESTOREDB_EXT_FUNC_NAME_PREFIX'],
414+
)
415+
416+
register_option(
417+
'external_function.name_suffix', 'string', check_str, '',
418+
'Suffix to add to external function names.',
419+
environ=['SINGLESTOREDB_EXT_FUNC_NAME_SUFFIX'],
420+
)
421+
410422
register_option(
411423
'external_function.connection', 'string', check_str,
412424
os.environ.get('SINGLESTOREDB_URL') or None,

singlestoredb/functions/ext/asgi.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ class Application(object):
468468
# Valid URL paths
469469
invoke_path = ('invoke',)
470470
show_create_function_path = ('show', 'create_function')
471+
show_function_info_path = ('show', 'function_info')
471472

472473
def __init__(
473474
self,
@@ -488,6 +489,8 @@ def __init__(
488489
link_name: Optional[str] = get_option('external_function.link_name'),
489490
link_config: Optional[Dict[str, Any]] = None,
490491
link_credentials: Optional[Dict[str, Any]] = None,
492+
name_prefix: str = get_option('external_function.name_prefix'),
493+
name_suffix: str = get_option('external_function.name_suffix'),
491494
) -> None:
492495
if link_name and (link_config or link_credentials):
493496
raise ValueError(
@@ -544,6 +547,7 @@ def __init__(
544547
if not hasattr(x, '_singlestoredb_attrs'):
545548
continue
546549
name = x._singlestoredb_attrs.get('name', x.__name__)
550+
name = f'{name_prefix}{name}{name_suffix}'
547551
external_functions[x.__name__] = x
548552
func, info = make_func(name, x)
549553
endpoints[name.encode('utf-8')] = func, info
@@ -559,6 +563,7 @@ def __init__(
559563
# Add endpoint for each exported function
560564
for name, alias in get_func_names(func_names):
561565
item = getattr(pkg, name)
566+
alias = f'{name_prefix}{name}{name_suffix}'
562567
external_functions[name] = item
563568
func, info = make_func(alias, item)
564569
endpoints[alias.encode('utf-8')] = func, info
@@ -571,12 +576,14 @@ def __init__(
571576
if not hasattr(x, '_singlestoredb_attrs'):
572577
continue
573578
name = x._singlestoredb_attrs.get('name', x.__name__)
579+
name = f'{name_prefix}{name}{name_suffix}'
574580
external_functions[x.__name__] = x
575581
func, info = make_func(name, x)
576582
endpoints[name.encode('utf-8')] = func, info
577583

578584
else:
579585
alias = funcs.__name__
586+
alias = f'{name_prefix}{alias}{name_suffix}'
580587
external_functions[funcs.__name__] = funcs
581588
func, info = make_func(alias, funcs)
582589
endpoints[alias.encode('utf-8')] = func, info
@@ -671,6 +678,12 @@ async def __call__(
671678

672679
await send(self.text_response_dict)
673680

681+
# Return function info
682+
elif method == 'GET' and path == self.show_function_info_path:
683+
functions = self.get_function_info()
684+
body = json.dumps(dict(functions=functions)).encode('utf-8')
685+
await send(self.text_response_dict)
686+
674687
# Path not found
675688
else:
676689
body = b''
@@ -715,14 +728,77 @@ def _locate_app_functions(self, cur: Any) -> Tuple[Set[str], Set[str]]:
715728
# See if function URL matches url
716729
cur.execute(f'SHOW CREATE FUNCTION `{name}`')
717730
for fname, _, code, *_ in list(cur):
718-
m = re.search(r" (?:\w+) SERVICE '([^']+)'", code)
731+
m = re.search(r" (?:\w+) (?:SERVICE|MANAGED) '([^']+)'", code)
719732
if m and m.group(1) == self.url:
720733
funcs.add(fname)
721734
if link and re.match(r'^py_ext_func_link_\S{14}$', link):
722735
links.add(link)
723736
return funcs, links
724737

725-
def show_create_functions(
738+
def get_function_info(
739+
self,
740+
func_name: Optional[str] = None,
741+
) -> Dict[str, Any]:
742+
"""
743+
Return the functions and function signature information.
744+
745+
Returns
746+
-------
747+
Dict[str, Any]
748+
749+
"""
750+
returns: Dict[str, Any] = {}
751+
functions = {}
752+
753+
for key, (_, info) in self.endpoints.items():
754+
if not func_name or key == func_name:
755+
sig = info['signature']
756+
args = []
757+
758+
# Function arguments
759+
for a in sig.get('args', []):
760+
dtype = a['dtype']
761+
nullable = '?' in dtype
762+
args.append(
763+
dict(
764+
name=a['name'],
765+
dtype=dtype.replace('?', ''),
766+
nullable=nullable,
767+
),
768+
)
769+
770+
# Record / table return types
771+
if sig['returns']['dtype'].startswith('tuple['):
772+
fields = []
773+
dtypes = sig['returns']['dtype'][6:-1].split(',')
774+
field_names = sig['returns']['field_names']
775+
for i, dtype in enumerate(dtypes):
776+
nullable = '?' in dtype
777+
dtype = dtype.replace('?', '')
778+
fields.append(
779+
dict(
780+
name=field_names[i],
781+
dtype=dtype,
782+
nullable=nullable,
783+
),
784+
)
785+
returns = dict(
786+
dtype='table' if info['function_type'] == 'tvf' else 'struct',
787+
fields=fields,
788+
)
789+
790+
# Atomic return types
791+
else:
792+
returns = dict(
793+
dtype=sig['returns'].get('dtype').replace('?', ''),
794+
nullable='?' in sig['returns'].get('dtype', ''),
795+
)
796+
797+
functions[sig['name']] = dict(args=args, returns=returns)
798+
799+
return functions
800+
801+
def get_create_functions(
726802
self,
727803
replace: bool = False,
728804
) -> List[str]:
@@ -790,7 +866,7 @@ def register_functions(
790866
cur.execute(f'DROP FUNCTION IF EXISTS `{fname}`')
791867
for link in links:
792868
cur.execute(f'DROP LINK {link}')
793-
for func in self.show_create_functions(replace=replace):
869+
for func in self.get_create_functions(replace=replace):
794870
cur.execute(func)
795871

796872
def drop_functions(
@@ -1118,6 +1194,22 @@ def main(argv: Optional[List[str]] = None) -> None:
11181194
),
11191195
help='logging level',
11201196
)
1197+
parser.add_argument(
1198+
'--name-prefix', metavar='name_prefix',
1199+
default=defaults.get(
1200+
'name_prefix',
1201+
get_option('external_function.name_prefix'),
1202+
),
1203+
help='Prefix to add to function names',
1204+
)
1205+
parser.add_argument(
1206+
'--name-suffix', metavar='name_suffix',
1207+
default=defaults.get(
1208+
'name_suffix',
1209+
get_option('external_function.name_suffix'),
1210+
),
1211+
help='Suffix to add to function names',
1212+
)
11211213
parser.add_argument(
11221214
'functions', metavar='module.or.func.path', nargs='*',
11231215
help='functions or modules to export in UDF server',
@@ -1210,9 +1302,11 @@ def main(argv: Optional[List[str]] = None) -> None:
12101302
link_config=json.loads(args.link_config) or None,
12111303
link_credentials=json.loads(args.link_credentials) or None,
12121304
app_mode='remote',
1305+
name_prefix=args.name_prefix,
1306+
name_suffix=args.name_suffix,
12131307
)
12141308

1215-
funcs = app.show_create_functions(replace=args.replace_existing)
1309+
funcs = app.get_create_functions(replace=args.replace_existing)
12161310
if not funcs:
12171311
raise RuntimeError('no functions specified')
12181312

singlestoredb/functions/ext/mmap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def main(argv: Optional[List[str]] = None) -> None:
338338
app_mode='collocated',
339339
)
340340

341-
funcs = app.show_create_functions(replace=args.replace_existing)
341+
funcs = app.get_create_functions(replace=args.replace_existing)
342342
if not funcs:
343343
raise RuntimeError('no functions specified')
344344

singlestoredb/functions/signature.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,13 +487,15 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
487487
'missing annotations for {} in {}'
488488
.format(', '.join(spec_diff), name),
489489
)
490+
490491
elif isinstance(args_overrides, dict):
491492
for s in spec_diff:
492493
if s not in args_overrides:
493494
raise TypeError(
494495
'missing annotations for {} in {}'
495496
.format(', '.join(spec_diff), name),
496497
)
498+
497499
elif isinstance(args_overrides, list):
498500
if len(arg_names) != len(args_overrides):
499501
raise TypeError(
@@ -502,6 +504,7 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
502504
)
503505

504506
for i, arg in enumerate(arg_names):
507+
505508
if isinstance(args_overrides, list):
506509
sql = args_overrides[i]
507510
arg_type = sql_to_dtype(sql)
@@ -528,6 +531,7 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
528531
if isinstance(returns_overrides, str):
529532
sql = returns_overrides
530533
out_type = sql_to_dtype(sql)
534+
531535
elif isinstance(returns_overrides, list):
532536
if not output_fields:
533537
output_fields = [
@@ -540,29 +544,35 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
540544
sql = dtype_to_sql(
541545
out_type, function_type=function_type, field_names=output_fields,
542546
)
547+
543548
elif dataclasses.is_dataclass(returns_overrides):
544549
out_type = collapse_dtypes([
545550
classify_dtype(x)
546551
for x in simplify_dtype([x.type for x in returns_overrides.fields])
547552
])
553+
output_fields = [x.name for x in returns_overrides.fields]
548554
sql = dtype_to_sql(
549555
out_type,
550556
function_type=function_type,
551-
field_names=[x.name for x in returns_overrides.fields],
557+
field_names=output_fields,
552558
)
559+
553560
elif has_pydantic and inspect.isclass(returns_overrides) \
554561
and issubclass(returns_overrides, pydantic.BaseModel):
555562
out_type = collapse_dtypes([
556563
classify_dtype(x)
557564
for x in simplify_dtype([x for x in returns_overrides.model_fields.values()])
558565
])
566+
output_fields = [x for x in returns_overrides.model_fields.keys()]
559567
sql = dtype_to_sql(
560568
out_type,
561569
function_type=function_type,
562-
field_names=[x for x in returns_overrides.model_fields.keys()],
570+
field_names=output_fields,
563571
)
572+
564573
elif returns_overrides is not None and not isinstance(returns_overrides, str):
565574
raise TypeError(f'unrecognized type for return value: {returns_overrides}')
575+
566576
else:
567577
if not output_fields:
568578
if dataclasses.is_dataclass(signature.return_annotation):
@@ -572,13 +582,26 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
572582
elif has_pydantic and inspect.isclass(signature.return_annotation) \
573583
and issubclass(signature.return_annotation, pydantic.BaseModel):
574584
output_fields = list(signature.return_annotation.model_fields.keys())
585+
575586
out_type = collapse_dtypes([
576587
classify_dtype(x) for x in simplify_dtype(signature.return_annotation)
577588
])
589+
590+
if not output_fields:
591+
output_fields = [
592+
string.ascii_letters[i] for i in range(out_type.count(',')+1)
593+
]
594+
578595
sql = dtype_to_sql(
579596
out_type, function_type=function_type, field_names=output_fields,
580597
)
581-
out['returns'] = dict(dtype=out_type, sql=sql, default=None)
598+
599+
out['returns'] = dict(
600+
dtype=out_type,
601+
sql=sql,
602+
default=None,
603+
field_names=output_fields,
604+
)
582605

583606
copied_keys = ['database', 'environment', 'packages', 'resources', 'replace']
584607
for key in copied_keys:

0 commit comments

Comments
 (0)