Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions singlestoredb/functions/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,20 @@ def is_valid_type(obj: Any) -> bool:
return False


def is_valid_callable(obj: Any) -> bool:
def is_sqlstr_callable(obj: Any) -> bool:
"""Check if the object is a valid callable for a parameter type."""
if not callable(obj):
return False

returns = utils.get_annotations(obj).get('return', None)

if inspect.isclass(returns) and issubclass(returns, str):
if inspect.isclass(returns) and issubclass(returns, SQLString):
return True

raise TypeError(
f'callable {obj} must return a str, '
f'but got {returns}',
)
return False


def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
def expand_types(args: Any) -> Optional[List[Any]]:
"""Expand the types for the function arguments / return values."""
if args is None:
return None
Expand All @@ -70,28 +67,32 @@ def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
if isinstance(args, str):
return [args]

# General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
elif is_valid_type(args):
return args

# List of SQL strings or callables
elif isinstance(args, list):
new_args = []
new_args: List[Any] = []
for arg in args:
if isinstance(arg, str):
new_args.append(arg)
elif callable(arg):
elif is_sqlstr_callable(arg):
new_args.append(arg())
elif type(arg) is type:
new_args.append(arg)
elif is_valid_type(arg):
new_args.append(arg)
else:
raise TypeError(f'unrecognized type for parameter: {arg}')
return new_args

# Callable that returns a SQL string
elif is_valid_callable(args):
out = args()
if not isinstance(out, str):
raise TypeError(f'unrecognized type for parameter: {args}')
return [out]
elif is_sqlstr_callable(args):
return [args()]

# General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
elif is_valid_type(args):
return [args]

elif type(args) is type:
return [args]

raise TypeError(f'unrecognized type for parameter: {args}')

Expand Down
Loading