Skip to content
5 changes: 4 additions & 1 deletion src/functions_framework/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import click

from functions_framework import create_app
from functions_framework import _function_registry, create_app
from functions_framework._http import create_server


Expand All @@ -39,6 +39,9 @@
help="Use ASGI server for function execution",
)
def _cli(target, source, signature_type, host, port, debug, asgi):
if not asgi and target in _function_registry.ASGI_FUNCTIONS:
asgi = True

if asgi: # pragma: no cover
from functions_framework.aio import create_asgi_app

Expand Down
4 changes: 4 additions & 0 deletions src/functions_framework/_function_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
# Keys are the user function name, values are the type of the function input
INPUT_TYPE_MAP = {}

# ASGI_FUNCTIONS stores function names that require ASGI mode.
# Functions decorated with @aio.http or @aio.cloud_event are added here.
ASGI_FUNCTIONS = set()


def get_user_function(source, source_module, target):
"""Returns user function, raises exception for invalid function."""
Expand Down
2 changes: 2 additions & 0 deletions src/functions_framework/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def cloud_event(func: CloudEventFunction) -> CloudEventFunction:
_function_registry.REGISTRY_MAP[func.__name__] = (
_function_registry.CLOUDEVENT_SIGNATURE_TYPE
)
_function_registry.ASGI_FUNCTIONS.add(func.__name__)
if inspect.iscoroutinefunction(func):

@functools.wraps(func)
Expand All @@ -82,6 +83,7 @@ def http(func: HTTPFunction) -> HTTPFunction:
_function_registry.REGISTRY_MAP[func.__name__] = (
_function_registry.HTTP_SIGNATURE_TYPE
)
_function_registry.ASGI_FUNCTIONS.add(func.__name__)

if inspect.iscoroutinefunction(func):

Expand Down
48 changes: 48 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from click.testing import CliRunner

import functions_framework
import functions_framework._function_registry as _function_registry
import functions_framework.aio

from functions_framework._cli import _cli

Expand Down Expand Up @@ -124,3 +126,49 @@ def test_asgi_cli(monkeypatch):
assert result.exit_code == 0
assert create_asgi_app.calls == [pretend.call("foo", None, "http")]
assert asgi_server.run.calls == [pretend.call("0.0.0.0", 8080)]


@pytest.fixture
def clean_registry():
"""Save and restore function registry state."""
original_asgi = _function_registry.ASGI_FUNCTIONS.copy()
_function_registry.ASGI_FUNCTIONS.clear()
yield
_function_registry.ASGI_FUNCTIONS.clear()
_function_registry.ASGI_FUNCTIONS.update(original_asgi)


def test_auto_asgi_for_aio_decorated_functions(monkeypatch, clean_registry):
_function_registry.ASGI_FUNCTIONS.add("my_aio_func")

asgi_app = pretend.stub()
create_asgi_app = pretend.call_recorder(lambda *a, **k: asgi_app)
aio_module = pretend.stub(create_asgi_app=create_asgi_app)
monkeypatch.setitem(sys.modules, "functions_framework.aio", aio_module)

asgi_server = pretend.stub(run=pretend.call_recorder(lambda host, port: None))
create_server = pretend.call_recorder(lambda app, debug: asgi_server)
monkeypatch.setattr(functions_framework._cli, "create_server", create_server)

runner = CliRunner()
result = runner.invoke(_cli, ["--target", "my_aio_func"])

assert create_asgi_app.calls == [pretend.call("my_aio_func", None, "http")]
assert asgi_server.run.calls == [pretend.call("0.0.0.0", 8080)]


def test_no_auto_asgi_for_regular_functions(monkeypatch, clean_registry):

app = pretend.stub()
create_app = pretend.call_recorder(lambda *a, **k: app)
monkeypatch.setattr(functions_framework._cli, "create_app", create_app)

flask_server = pretend.stub(run=pretend.call_recorder(lambda host, port: None))
create_server = pretend.call_recorder(lambda app, debug: flask_server)
monkeypatch.setattr(functions_framework._cli, "create_server", create_server)

runner = CliRunner()
result = runner.invoke(_cli, ["--target", "regular_func"])

assert create_app.calls == [pretend.call("regular_func", None, "http")]
assert flask_server.run.calls == [pretend.call("0.0.0.0", 8080)]
48 changes: 48 additions & 0 deletions tests/test_decorator_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from cloudevents import conversion as ce_conversion
from cloudevents.http import CloudEvent

import functions_framework._function_registry as registry

# Conditional import for Starlette
if sys.version_info >= (3, 8):
from starlette.testclient import TestClient as StarletteTestClient
Expand Down Expand Up @@ -128,3 +130,49 @@ def test_aio_http_dict_response():
resp = client.post("/")
assert resp.status_code == 200
assert resp.json() == {"message": "hello", "count": 42, "success": True}


@pytest.fixture
def clean_registry():
"""Save and restore registry state."""
original_registry_map = registry.REGISTRY_MAP.copy()
original_asgi_functions = registry.ASGI_FUNCTIONS.copy()
registry.REGISTRY_MAP.clear()
registry.ASGI_FUNCTIONS.clear()

yield

registry.REGISTRY_MAP.clear()
registry.REGISTRY_MAP.update(original_registry_map)
registry.ASGI_FUNCTIONS.clear()
registry.ASGI_FUNCTIONS.update(original_asgi_functions)


def test_aio_decorators_register_asgi_functions(clean_registry):
"""Test that @aio decorators add function names to ASGI_FUNCTIONS registry."""
from functions_framework.aio import cloud_event, http

@http
async def test_http_func(request):
return "test"

@cloud_event
async def test_cloud_event_func(event):
pass

assert "test_http_func" in registry.ASGI_FUNCTIONS
assert "test_cloud_event_func" in registry.ASGI_FUNCTIONS

assert registry.REGISTRY_MAP["test_http_func"] == "http"
assert registry.REGISTRY_MAP["test_cloud_event_func"] == "cloudevent"

@http
def test_http_sync(request):
return "sync"

@cloud_event
def test_cloud_event_sync(event):
pass

assert "test_http_sync" in registry.ASGI_FUNCTIONS
assert "test_cloud_event_sync" in registry.ASGI_FUNCTIONS
Loading