Skip to content

Commit 59e6541

Browse files
committed
support multiple triggers per webhook_id
1 parent 7c2de98 commit 59e6541

1 file changed

Lines changed: 44 additions & 16 deletions

File tree

custom_components/pyscript/decorators/webhook.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Webhook decorator."""
22

3+
from __future__ import annotations
4+
35
import logging
6+
from typing import ClassVar
47

58
from aiohttp import hdrs
69
import voluptuous as vol
@@ -36,6 +39,8 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD
3639
local_only: bool
3740
methods: set[str]
3841

42+
webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {}
43+
3944
async def validate(self):
4045
"""Validate the webhook trigger configuration."""
4146
await super().validate()
@@ -44,7 +49,8 @@ async def validate(self):
4449
if len(self.args) == 2:
4550
self.create_expression(self.args[1])
4651

47-
async def _handler(self, hass, webhook_id, request):
52+
@staticmethod
53+
async def _handler(_hass, webhook_id, request):
4854
func_args = {
4955
"trigger_type": "webhook",
5056
"webhook_id": webhook_id,
@@ -57,28 +63,50 @@ async def _handler(self, hass, webhook_id, request):
5763
payload_multidict = await request.post()
5864
func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()}
5965

60-
if self.has_expression():
61-
if not await self.check_expression_vars(func_args):
62-
return
63-
64-
await self.dispatch(DispatchData(func_args))
66+
for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy():
67+
trigger_args = func_args.copy()
68+
if trigger.has_expression():
69+
if not await trigger.check_expression_vars(trigger_args):
70+
continue
71+
await trigger.dispatch(DispatchData(trigger_args))
72+
73+
@staticmethod
74+
def _add_trigger(trigger: WebhookTriggerDecorator) -> None:
75+
webhook_id = trigger.webhook_id
76+
if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers:
77+
webhook.async_register(
78+
trigger.dm.hass,
79+
"pyscript", # DOMAIN
80+
"pyscript", # NAME
81+
webhook_id,
82+
WebhookTriggerDecorator._handler,
83+
local_only=trigger.local_only,
84+
allowed_methods=trigger.methods,
85+
)
86+
WebhookTriggerDecorator.webhook_id2triggers[webhook_id] = set()
87+
88+
WebhookTriggerDecorator.webhook_id2triggers[webhook_id].add(trigger)
89+
90+
@staticmethod
91+
def _remove_trigger(trigger: WebhookTriggerDecorator) -> None:
92+
webhook_id = trigger.webhook_id
93+
triggers = WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id)
94+
if not triggers:
95+
return
96+
97+
triggers.discard(trigger)
98+
if len(triggers) == 0:
99+
webhook.async_unregister(trigger.dm.hass, webhook_id)
100+
del WebhookTriggerDecorator.webhook_id2triggers[webhook_id]
65101

66102
async def start(self):
67103
"""Start the webhook trigger."""
68104
await super().start()
69-
webhook.async_register(
70-
self.dm.hass,
71-
"pyscript", # DOMAIN
72-
"pyscript", # NAME
73-
self.webhook_id,
74-
self._handler,
75-
local_only=self.local_only,
76-
allowed_methods=self.methods,
77-
)
105+
self._add_trigger(self)
78106

79107
_LOGGER.debug("webhook trigger %s listening on id %s", self.dm.name, self.webhook_id)
80108

81109
async def stop(self):
82110
"""Stop the webhook trigger."""
83111
await super().stop()
84-
webhook.async_unregister(self.dm.hass, self.webhook_id)
112+
self._remove_trigger(self)

0 commit comments

Comments
 (0)