-
Notifications
You must be signed in to change notification settings - Fork 137
Expand file tree
/
Copy pathtrigger_factory.py
More file actions
185 lines (148 loc) · 7.31 KB
/
trigger_factory.py
File metadata and controls
185 lines (148 loc) · 7.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import logging
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any

from dify_plugin.core.runtime import Session
from dify_plugin.entities.provider_config import CredentialType
from dify_plugin.entities.trigger import (
    EventConfiguration,
    TriggerProviderConfiguration,
    TriggerSubscriptionConstructorRuntime,
)
from dify_plugin.interfaces.trigger import Event, EventRuntime, Trigger, TriggerRuntime, TriggerSubscriptionConstructor
@dataclass(slots=True)
class _TriggerProviderEntry:
"""Internal container storing metadata associated with a trigger provider."""
configuration: TriggerProviderConfiguration
provider_cls: type[Trigger]
subscription_constructor_cls: type[TriggerSubscriptionConstructor] | None
events: dict[str, tuple[EventConfiguration, type[Event]]]
class TriggerProviderRegistration:
    """Handle for adding events to one already-created provider entry.

    The handle mutates the ``events`` dict of the entry it was constructed
    with, in place, so anything holding the same entry sees new events.
    """

    def __init__(self, entry: _TriggerProviderEntry) -> None:
        self._entry = entry

    def register_trigger(
        self,
        *,
        name: str,
        configuration: EventConfiguration,
        trigger_cls: type[Event],
    ) -> None:
        """Attach ``trigger_cls`` and its ``configuration`` under event ``name``.

        Raises:
            ValueError: If an event called ``name`` is already registered for
                this provider; the existing registration is left untouched.
        """
        known_events = self._entry.events
        if name not in known_events:
            known_events[name] = (configuration, trigger_cls)
            return

        owner = self._entry.configuration.identity.name
        raise ValueError(f"Event `{name}` is already registered for provider `{owner}`")
class TriggerFactory:
    """Registry that produces trigger-related runtime instances on demand.

    Providers are registered once under ``configuration.identity.name``. For
    each provider the factory keeps its configuration, provider class,
    optional subscription-constructor class and event classes; the ``get_*``
    methods instantiate those classes with caller-supplied runtimes. Every
    per-provider accessor raises ``ValueError`` when the provider name is
    unknown.
    """

    def __init__(self) -> None:
        # Provider name -> runtime metadata. Using a dict allows O(1) lookups when
        # resolving provider classes during request handling.
        self._providers: dict[str, _TriggerProviderEntry] = {}

    def register_trigger_provider(
        self,
        *,
        configuration: TriggerProviderConfiguration,
        provider_cls: type[Trigger],
        subscription_constructor_cls: type[TriggerSubscriptionConstructor] | None,
        events: Mapping[str, tuple[EventConfiguration, type[Event]]],
    ) -> TriggerProviderRegistration:
        """Register a trigger provider and its runtime classes.

        Args:
            configuration: Provider configuration; ``configuration.identity.name``
                is the registry key.
            provider_cls: Class instantiated by ``get_trigger_provider``.
            subscription_constructor_cls: Class instantiated by
                ``get_subscription_constructor``, or ``None`` if the provider
                has no subscription constructor.
            events: Initial events, mapping event name to
                ``(EventConfiguration, Event subclass)``.

        Returns:
            A ``TriggerProviderRegistration`` bound to the new entry, through
            which more events can be added later.

        Raises:
            ValueError: If a provider with the same name is already registered.
        """
        # Each provider can only be registered once to avoid conflicting runtime
        # definitions when multiple plugins try to use the same identifier.
        provider_name = configuration.identity.name
        if provider_name in self._providers:
            raise ValueError(f"Trigger provider `{provider_name}` is already registered")

        entry = _TriggerProviderEntry(
            configuration=configuration,
            provider_cls=provider_cls,
            subscription_constructor_cls=subscription_constructor_cls,
            events={},
        )
        registration = TriggerProviderRegistration(entry)

        # Pre-populate the entry with events that were already discovered during
        # plugin loading. Providers can keep adding more events afterwards by
        # calling ``registration.register_trigger``.
        for name, (event_config, event_cls) in events.items():
            registration.register_trigger(
                name=name,
                configuration=event_config,
                trigger_cls=event_cls,
            )

        # Publish the entry only after the initial events were accepted: if
        # unpacking or registering one of them raises, no half-populated
        # provider is left behind to make a retry fail with "already registered".
        self._providers[provider_name] = entry
        return registration

    # ------------------------------------------------------------------
    # Provider factories
    # ------------------------------------------------------------------

    def get_trigger_provider(
        self,
        provider_name: str,
        session: Session,
        credentials: Mapping[str, Any] | None,
        credential_type: CredentialType | None,
    ) -> Trigger:
        """Instantiate the trigger provider implementation for ``provider_name``.

        The provider class receives a ``TriggerRuntime`` built from
        ``session`` and ``credentials``; a missing (falsy) ``credential_type``
        falls back to ``CredentialType.UNAUTHORIZED``.

        Raises:
            ValueError: If the provider is not registered.
        """
        entry = self._get_entry(provider_name)
        return entry.provider_cls(
            runtime=TriggerRuntime(
                session=session,
                credential_type=credential_type or CredentialType.UNAUTHORIZED,
                credentials=credentials,
            )
        )

    def get_provider_cls(self, provider_name: str) -> type[Trigger]:
        """Return the registered provider class (not an instance)."""
        return self._get_entry(provider_name).provider_cls

    def has_subscription_constructor(self, provider_name: str) -> bool:
        """Return whether the provider registered a subscription constructor."""
        return self._get_entry(provider_name).subscription_constructor_cls is not None

    def get_subscription_constructor(
        self,
        provider_name: str,
        runtime: TriggerSubscriptionConstructorRuntime,
    ) -> TriggerSubscriptionConstructor:
        """Instantiate the provider's subscription constructor with ``runtime``.

        Raises:
            ValueError: If the provider is not registered or did not define a
                subscription constructor.
        """
        constructor_cls = self._get_entry(provider_name).subscription_constructor_cls
        # ``is None`` (rather than truthiness) matches has_subscription_constructor.
        if constructor_cls is None:
            raise ValueError(f"Trigger provider `{provider_name}` does not define a subscription constructor")
        return constructor_cls(runtime)

    def get_subscription_constructor_cls(self, provider_name: str) -> type[TriggerSubscriptionConstructor] | None:
        """Return the subscription constructor class, or ``None`` if none was registered."""
        return self._get_entry(provider_name).subscription_constructor_cls

    # ------------------------------------------------------------------
    # Event factories
    # ------------------------------------------------------------------

    def get_trigger_event_handler_safely(self, provider_name: str, event: str, runtime: EventRuntime) -> Event | None:
        """Best-effort variant of ``get_trigger_event_handler``.

        Returns ``None`` instead of raising when the provider is unknown, the
        event is not registered, or the event class's constructor raises.
        Failures other than an unregistered event are logged with traceback.
        """
        try:
            entry = self._get_entry(provider_name)
            registered = entry.events.get(event)
            if registered is None:
                return None
            _, event_cls = registered
            return event_cls(runtime)
        except Exception:
            # Deliberately swallowed (callers expect ``None``), but reported via
            # logging rather than print(): the traceback is preserved and output
            # goes through configured handlers instead of stdout.
            logging.getLogger(__name__).exception(
                "Error getting trigger event handler for provider `%s`, event `%s`",
                provider_name,
                event,
            )
            return None

    def get_trigger_event_handler(self, provider_name: str, event: str, runtime: EventRuntime) -> Event:
        """Instantiate the event class registered as ``event`` for the provider.

        Raises:
            ValueError: If the provider or the event is not registered.
        """
        registered = self._get_entry(provider_name).events.get(event)
        if registered is None:
            raise ValueError(f"Event `{event}` not found in provider `{provider_name}`")
        _, event_cls = registered
        return event_cls(runtime)

    def get_trigger_configuration(self, provider_name: str, event: str) -> EventConfiguration | None:
        """Return the configuration registered for ``event``, or ``None`` if absent.

        Raises:
            ValueError: If the provider is not registered.
        """
        registered = self._get_entry(provider_name).events.get(event)
        return None if registered is None else registered[0]

    def iter_events(self, provider_name: str) -> Mapping[str, tuple[EventConfiguration, type[Event]]]:
        """Return a shallow copy of the registered events for inspection."""
        # Returning a copy ensures callers cannot mutate the internal registry
        # inadvertently, while still providing a dictionary-like interface for
        # tooling and API handlers that need to enumerate events.
        return dict(self._get_entry(provider_name).events)

    def get_configuration(self, provider_name: str) -> TriggerProviderConfiguration:
        """Return the configuration the provider was registered with."""
        return self._get_entry(provider_name).configuration

    def _get_entry(self, provider_name: str) -> _TriggerProviderEntry:
        """Look up a provider entry, translating ``KeyError`` into ``ValueError``."""
        try:
            return self._providers[provider_name]
        except KeyError as exc:  # pragma: no cover - defensive branch
            raise ValueError(f"Trigger provider `{provider_name}` not found") from exc