Skip to content

Commit 321aa5a

Browse files
JayNewstromclaude
andcommitted
allow @webhook_trigger functions to control the http response
Webhook handlers now wait for the decorated function and use its return value to build the response: int -> Response(status=...), aiohttp.web.Response -> as-is, None -> default 200. When multiple triggers share a webhook_id, the first non-None return wins. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent d5f4041 commit 321aa5a

7 files changed

Lines changed: 328 additions & 7 deletions

File tree

custom_components/pyscript/decorator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ async def _call(self, data: DispatchData) -> None:
264264
# notify handlers with "None"
265265
for result_handler_dec in result_handlers:
266266
await result_handler_dec.handle_call_result(data, None)
267+
data.set_result(None)
267268
return
268269
# Fire an event indicating that pyscript is running
269270
# Note: the event must have an entity_id for logbook to work correctly.
@@ -279,7 +280,9 @@ async def _call(self, data: DispatchData) -> None:
279280
result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args)
280281
for result_handler_dec in result_handlers:
281282
await result_handler_dec.handle_call_result(data, result)
283+
data.set_result(result)
282284
except Exception as e:
285+
data.set_result(None)
283286
await self.handle_exception(e)
284287

285288
async def dispatch(self, data: DispatchData) -> None:
@@ -290,6 +293,7 @@ async def dispatch(self, data: DispatchData) -> None:
290293
for dec in decorators:
291294
if await dec.handle_dispatch(data) is False:
292295
self.logger.debug("Trigger not active due to %s", dec)
296+
data.set_result(None)
293297
return
294298

295299
action_ast_ctx = AstEval(

custom_components/pyscript/decorator_abc.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6+
import asyncio
67
from dataclasses import dataclass, field
78
from enum import StrEnum
89
import logging
@@ -46,6 +47,16 @@ class DispatchData:
4647
call_ast_ctx: AstEval | None = field(default=None, kw_only=True)
4748
hass_context: Context | None = field(default=None, kw_only=True)
4849

50+
# When set, the dispatch pipeline resolves this future with the
51+
# decorated function's return value. Resolved with None if the
52+
# function is skipped (guard rejection) or raises.
53+
result_future: asyncio.Future[Any] | None = field(default=None, kw_only=True)
54+
55+
def set_result(self, value: Any) -> None:
56+
"""Resolve result_future with value if it is still pending."""
57+
if self.result_future is not None and not self.result_future.done():
58+
self.result_future.set_result(value)
59+
4960

5061
class Decorator(ABC):
5162
"""Generic decorator abstraction."""

custom_components/pyscript/decorators/webhook.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import logging
6-
from typing import ClassVar
7+
from typing import Any, ClassVar
78

8-
from aiohttp import hdrs
9+
from aiohttp import hdrs, web
910
import voluptuous as vol
1011

1112
from homeassistant.components import webhook
@@ -32,12 +33,14 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD
3233
{
3334
vol.Optional("local_only", default=True): cv.boolean,
3435
vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]),
36+
vol.Optional("sets_http_response_code", default=False): cv.boolean,
3537
}
3638
)
3739

3840
webhook_id: str
3941
local_only: bool
4042
methods: set[str]
43+
sets_http_response_code: bool
4144

4245
webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {}
4346

@@ -50,7 +53,7 @@ async def validate(self):
5053
self.create_expression(self.args[1])
5154

5255
@staticmethod
53-
async def _handler(_hass, webhook_id, request):
56+
async def _handler(hass, webhook_id, request):
5457
func_args = {
5558
"trigger_type": "webhook",
5659
"webhook_id": webhook_id,
@@ -64,17 +67,59 @@ async def _handler(_hass, webhook_id, request):
6467
payload_multidict = await request.post()
6568
func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()}
6669

70+
response_future: asyncio.Future[Any] | None = None
71+
futures: list[asyncio.Future[Any]] = []
6772
for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy():
6873
trigger_args = func_args.copy()
6974
if trigger.has_expression():
7075
if not await trigger.check_expression_vars(trigger_args):
7176
continue
72-
await trigger.dispatch(DispatchData(trigger_args))
77+
future: asyncio.Future[Any] = hass.loop.create_future()
78+
if trigger.sets_http_response_code:
79+
response_future = future
80+
futures.append(future)
81+
await trigger.dispatch(DispatchData(trigger_args, result_future=future))
82+
83+
if not futures:
84+
return None
85+
86+
await asyncio.gather(*futures, return_exceptions=True)
87+
88+
if response_future is None:
89+
return None
90+
return WebhookTriggerDecorator.coerce_response(response_future.result())
91+
92+
@staticmethod
93+
def coerce_response(value: Any) -> web.Response | None:
94+
"""Convert a webhook function return value to an aiohttp Response."""
95+
if value is None:
96+
return None
97+
if isinstance(value, web.Response):
98+
return value
99+
# bool is a subclass of int; reject it so True/False don't become 1/0 status codes.
100+
if isinstance(value, int) and not isinstance(value, bool):
101+
return web.Response(status=value)
102+
_LOGGER.warning(
103+
"webhook function returned unsupported type %s; expected int status code or aiohttp.web.Response",
104+
type(value).__name__,
105+
)
106+
return None
73107

74108
@staticmethod
75109
def _add_trigger(trigger: WebhookTriggerDecorator) -> None:
76110
webhook_id = trigger.webhook_id
77-
if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers:
111+
existing = WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id)
112+
if (
113+
trigger.sets_http_response_code
114+
and existing is not None
115+
and any(t.sets_http_response_code for t in existing)
116+
):
117+
raise ValueError(
118+
f"webhook_id '{webhook_id}' already has a @webhook_trigger with "
119+
f"sets_http_response_code=True; only one is allowed"
120+
)
121+
122+
if existing is None:
78123
webhook.async_register(
79124
trigger.dm.hass,
80125
"pyscript", # DOMAIN

custom_components/pyscript/stubs/pyscript_builtins.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def webhook_trigger(
127127
str_expr: str | None = None,
128128
local_only: bool = True,
129129
methods: set[SUPPORTED_METHODS] | list[SUPPORTED_METHODS] = {"POST", "PUT"},
130+
sets_http_response_code: bool = False,
130131
kwargs: dict | None = None,
131132
) -> Callable[..., Any]:
132133
"""Trigger when a request is made to a webhook endpoint.
@@ -136,6 +137,7 @@ def webhook_trigger(
136137
str_expr: Optional expression evaluated against ``trigger_type``, ``webhook_id``, ``request``, and ``payload``.
137138
local_only: If False, allow requests from anywhere on the internet.
138139
methods: HTTP methods to allow.
140+
sets_http_response_code: If True, the function's return value drives the HTTP response (``int`` status code or ``aiohttp.web.Response``); at most one trigger per ``webhook_id`` may set this.
139141
kwargs: Extra keyword arguments merged into each invocation.
140142
141143
Trigger kwargs include ``trigger_type="webhook"``, ``webhook_id``, the parsed payload fields, and ``request`` (the underlying ``aiohttp.web.Request``).

docs/reference.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,18 @@ To validate an HMAC signature on incoming requests, declare ``request`` in the f
915915
return
916916
log.info(f"verified webhook: {payload}")
917917
918+
To control the HTTP response sent back to the webhook caller, opt in by passing ``sets_http_response_code=True``. The flagged function's return value then drives the response: ``None`` produces a ``200 OK``, an ``int`` sends back a response with that status code, and an ``aiohttp.web.Response`` allows full control over the body and headers. Return values from triggers without the flag are ignored. For example:
919+
920+
.. code:: python
921+
922+
@webhook_trigger("myid", sets_http_response_code=True)
923+
def webhook_check(payload):
924+
if "token" not in payload:
925+
return 401
926+
return 204
927+
928+
At most one ``@webhook_trigger`` per ``webhook_id`` may set ``sets_http_response_code=True``; declaring more than one is an error at setup time. The webhook handler waits for all decorated function(s) for the ``webhook_id`` to finish before responding, so use ``task.create()`` to fire-and-forget any long-running work.
929+
918930
NOTE: A webhook_id can only be used by either a built-in Home Assistant automation or pyscript, but not both. Trying to use the same webhook_id in both will result in an error.
919931

920932
@state_active

tests/test_decorator_manager.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import logging
67
from typing import ClassVar
78
from unittest.mock import patch
@@ -353,9 +354,15 @@ def make_dispatch_data(
353354
*,
354355
call_ast_ctx: DummyCallAstCtx | None = None,
355356
hass_context: Context | None = None,
357+
result_future: asyncio.Future | None = None,
356358
) -> DispatchData:
357359
"""Build DispatchData from test doubles."""
358-
return DispatchData(func_args, call_ast_ctx=call_ast_ctx, hass_context=hass_context)
360+
return DispatchData(
361+
func_args,
362+
call_ast_ctx=call_ast_ctx,
363+
hass_context=hass_context,
364+
result_future=result_future,
365+
)
359366

360367

361368
def setup_global_context_function_hass(hass: HomeAssistant, config_data: dict | None = None) -> None:
@@ -599,6 +606,77 @@ async def test_function_decorator_manager_logs_call_exception(hass):
599606
assert str(ast_ctx.logged_exceptions[0]) == "decorated call failed"
600607

601608

609+
@pytest.mark.asyncio
610+
async def test_function_decorator_manager_result_future_success(hass):
611+
"""Successful calls should resolve result_future with the function's return value."""
612+
DecoratorManager.hass = hass
613+
manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar())
614+
call_ast_ctx = DummyCallAstCtx(result=201)
615+
future: asyncio.Future = hass.loop.create_future()
616+
617+
with patch.object(Function, "store_hass_context"):
618+
await call_function_manager(
619+
manager,
620+
make_dispatch_data(
621+
{"arg1": 1},
622+
call_ast_ctx=call_ast_ctx,
623+
hass_context=Context(id="call-parent"),
624+
result_future=future,
625+
),
626+
)
627+
await hass.async_block_till_done()
628+
629+
assert future.done()
630+
assert future.result() == 201
631+
632+
633+
@pytest.mark.asyncio
634+
async def test_function_decorator_manager_result_future_cancel(hass):
635+
"""When a call handler cancels, result_future should resolve to None."""
636+
DecoratorManager.hass = hass
637+
manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar())
638+
manager.add(make_cancel_call_handler())
639+
future: asyncio.Future = hass.loop.create_future()
640+
641+
await call_function_manager(
642+
manager,
643+
make_dispatch_data(
644+
{"arg1": 1},
645+
call_ast_ctx=DummyCallAstCtx(result="unused"),
646+
hass_context=Context(id="call-parent"),
647+
result_future=future,
648+
),
649+
)
650+
651+
assert future.done()
652+
assert future.result() is None
653+
654+
655+
@pytest.mark.asyncio
656+
async def test_function_decorator_manager_result_future_exception(hass):
657+
"""When the decorated function raises, result_future should resolve to None."""
658+
DecoratorManager.hass = hass
659+
ast_ctx = DummyAstCtx()
660+
manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar())
661+
call_ast_ctx = DummyCallAstCtx(exc=RuntimeError("boom"))
662+
future: asyncio.Future = hass.loop.create_future()
663+
664+
with patch.object(Function, "store_hass_context"):
665+
await call_function_manager(
666+
manager,
667+
make_dispatch_data(
668+
{"arg1": 1},
669+
call_ast_ctx=call_ast_ctx,
670+
hass_context=Context(id="call-parent"),
671+
result_future=future,
672+
),
673+
)
674+
675+
assert future.done()
676+
assert future.result() is None
677+
assert len(ast_ctx.logged_exceptions) == 1
678+
679+
602680
def test_decorator_registry_register_requires_name():
603681
"""Registry should reject decorators without a declared name."""
604682

0 commit comments

Comments
 (0)