-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapplication.py
More file actions
110 lines (88 loc) · 4.26 KB
/
application.py
File metadata and controls
110 lines (88 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import asyncio
from typing import Any, TypedDict
from typing_extensions import NotRequired, Required, Unpack
from asyncapi_python.kernel.document.operation import Operation
from asyncapi_python.kernel.wire import AbstractWireFactory
from .codec import CodecFactory
from .endpoint import AbstractEndpoint, EndpointFactory
from .endpoint.abc import EndpointParams
class BaseApplication:
    """Base application that owns a set of endpoints and drives their lifecycle.

    Endpoints are built from AsyncAPI operations through ``EndpointFactory``
    using the wire factory, codec factory and endpoint parameters supplied to
    the constructor. ``start()`` and ``stop()`` act on every registered
    endpoint together.
    """

    class Inputs(TypedDict):
        """Keyword arguments accepted by ``BaseApplication.__init__``."""

        wire_factory: Required[AbstractWireFactory[Any, Any]]
        codec_factory: Required[CodecFactory[Any, Any]]
        # Optional; __init__ falls back to an empty dict when omitted.
        endpoint_params: NotRequired[EndpointParams]

    def __init__(self, **kwargs: Unpack[Inputs]) -> None:
        """Store the factories and parameters used to build endpoints.

        Args:
            **kwargs: See ``BaseApplication.Inputs``. ``wire_factory`` and
                ``codec_factory`` are required; ``endpoint_params`` defaults
                to an empty dict.
        """
        self.__endpoints: set[AbstractEndpoint] = set()
        self.__wire_factory: AbstractWireFactory[Any, Any] = kwargs["wire_factory"]
        self.__codec_factory: CodecFactory[Any, Any] = kwargs["codec_factory"]
        self.__endpoint_params: EndpointParams = kwargs.get("endpoint_params", {})
        # Created by start(blocking=True); set by stop() to release that wait.
        self._stop_event: asyncio.Event | None = None
        # Not used within this class.
        # NOTE(review): presumably reserved for subclasses — verify.
        self._monitor_task: asyncio.Task[None] | None = None
        # Created by start(blocking=True); resolved by _propagate_exception().
        self._exception_future: asyncio.Future[Exception] | None = None

    def _register_endpoint(self, op: Operation) -> AbstractEndpoint:
        """Create an endpoint for *op*, register it, and return it.

        The endpoint is built with this application's wire factory, codec
        factory and endpoint parameters.
        """
        endpoint = EndpointFactory.create(
            operation=op,
            wire_factory=self.__wire_factory,
            codec_factory=self.__codec_factory,
            endpoint_params=self.__endpoint_params,
        )
        self.__endpoints.add(endpoint)
        return endpoint

    async def start(self, *, blocking: bool = False) -> None:
        """Start all endpoints in the application.

        Args:
            blocking: If True, block until stop() is called, an endpoint
                reports an exception through its exception callback, or the
                calling task is cancelled. If False (default), return
                immediately after starting endpoints.

        Raises:
            Exception: In blocking mode, the exception reported by an
                endpoint; the application is stopped before it is re-raised.
            asyncio.CancelledError: In blocking mode, if the calling task is
                cancelled; the application is stopped before re-raising.
        """
        stop_event: asyncio.Event | None = None
        exception_future: asyncio.Future[Exception] | None = None
        if blocking:
            # Create the signalling primitives *before* starting endpoints so
            # that an exception reported, or a stop() issued, while endpoints
            # are still starting is not lost.
            stop_event = asyncio.Event()
            exception_future = asyncio.get_running_loop().create_future()
            self._stop_event = stop_event
            self._exception_future = exception_future

        await asyncio.gather(
            *(
                e.start(exception_callback=self._propagate_exception)
                for e in self.__endpoints
            )
        )

        if stop_event is not None and exception_future is not None:
            await self._wait_for_stop_or_exception(stop_event, exception_future)

    async def _wait_for_stop_or_exception(
        self,
        stop_event: asyncio.Event,
        exception_future: asyncio.Future[Exception],
    ) -> None:
        """Wait until *stop_event* is set or *exception_future* completes.

        Stops the application and re-raises if an endpoint reported an
        exception, or if the waiting task is cancelled.
        """
        stop_task = asyncio.create_task(stop_event.wait())
        try:
            await asyncio.wait(
                [stop_task, exception_future], return_when=asyncio.FIRST_COMPLETED
            )
        except asyncio.CancelledError:
            # The caller cancelled the blocking start(): shut down gracefully.
            await self.stop()
            raise
        finally:
            # asyncio.wait() never cancels the awaitables it was given, so
            # cancel whatever is still pending (notably on cancellation) to
            # avoid leaving a task or future dangling.
            stop_task.cancel()
            if not exception_future.done():
                exception_future.cancel()

        # After the finally block the future is either cancelled (stop won)
        # or holds the exception an endpoint reported.
        if not exception_future.cancelled():
            exc = exception_future.result()
            if exc is not None:
                await self.stop()
                raise exc

    async def stop(self) -> None:
        """Stop all endpoints and release a pending ``start(blocking=True)``."""
        try:
            await asyncio.gather(*(e.stop() for e in self.__endpoints))
        finally:
            # Signal the blocking start() even if an endpoint failed to stop,
            # otherwise it would keep waiting forever.
            if self._stop_event is not None:
                self._stop_event.set()

    def _add_endpoint(self, endpoint: AbstractEndpoint) -> None:
        """Add an already-constructed endpoint to this application."""
        self.__endpoints.add(endpoint)

    def _propagate_exception(self, exception: Exception) -> None:
        """Propagate an exception from an endpoint to the application level.

        Only the first exception is kept, and only while a blocking start()
        has an unresolved exception future. Otherwise the exception is not
        surfaced by this class.
        """
        if self._exception_future is not None and not self._exception_future.done():
            self._exception_future.set_result(exception)


# Public API of this module: only the application base class is exported.
__all__ = ["BaseApplication"]