-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy path_context_streams.py
More file actions
122 lines (88 loc) · 3.8 KB
/
_context_streams.py
File metadata and controls
122 lines (88 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Context-aware memory stream wrappers.
anyio memory streams do not propagate ``contextvars.Context`` across task
boundaries. These thin wrappers capture the sender's context at ``send()``
time and expose it on the receive side via ``last_context``, so consumers
can restore it with ``ctx.run(handler, item)``.
The iteration interface is unchanged (yields ``T``, not tuples), keeping
these wrappers duck-type compatible with plain ``MemoryObjectSendStream``
and ``MemoryObjectReceiveStream``.
"""
from __future__ import annotations
import contextvars
from types import TracebackType
from typing import Any, Generic, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
# Type of the items that callers send and receive through the wrappers below.
T = TypeVar("T")

# Internal payload carried through the underlying raw stream: the
# ``contextvars.Context`` associated with an item, paired with the item itself.
_Envelope = tuple[contextvars.Context, T]
class ContextSendStream(Generic[T]):
    """Sending half of a context-aware memory stream.

    Each item is forwarded to the wrapped stream as a ``(context, item)``
    pair. ``send()`` pairs the item with a fresh ``contextvars.copy_context()``
    snapshot; ``send_with_context()`` pairs it with a context supplied by the
    caller.
    """

    __slots__ = ("_inner",)

    def __init__(self, inner: MemoryObjectSendStream[_Envelope[T]]) -> None:
        """Wrap ``inner``, the raw stream that carries ``(context, item)`` pairs."""
        self._inner = inner

    async def send(self, item: T) -> None:
        """Send ``item`` together with a snapshot of the current context."""
        snapshot = contextvars.copy_context()
        envelope = (snapshot, item)
        await self._inner.send(envelope)

    async def send_with_context(self, context: contextvars.Context, item: T) -> None:
        """Send ``item`` together with the explicitly provided ``context``."""
        envelope = (context, item)
        await self._inner.send(envelope)

    def close(self) -> None:
        """Close the wrapped stream synchronously."""
        self._inner.close()

    async def aclose(self) -> None:
        """Close the wrapped stream asynchronously."""
        await self._inner.aclose()

    def clone(self) -> ContextSendStream[T]:  # pragma: no cover
        """Return a new wrapper around a clone of the wrapped stream."""
        cloned_inner = self._inner.clone()
        return ContextSendStream(cloned_inner)

    async def __aenter__(self) -> ContextSendStream[T]:
        """Enter the async context manager, returning this wrapper."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Close the stream on exit; exceptions are never suppressed."""
        await self.aclose()
        return None
class ContextReceiveStream(Generic[T]):
    """Receive-side wrapper that yields ``T`` and stores the sender's context in ``last_context``.

    The wrapped stream carries ``(context, item)`` envelopes. ``receive()``
    unpacks one envelope, records its context on ``last_context`` and returns
    only the item, so iteration yields plain items rather than tuples.
    ``last_context`` is ``None`` until the first item has been received and is
    overwritten by every later receive.
    """

    __slots__ = ("_inner", "last_context")

    def __init__(self, inner: MemoryObjectReceiveStream[_Envelope[T]]) -> None:
        """Wrap ``inner``, the raw stream of ``(context, item)`` envelopes."""
        self._inner = inner
        # Context that accompanied the most recently received item.
        self.last_context: contextvars.Context | None = None

    async def receive(self) -> T:
        """Receive the next item and remember the context it was sent with.

        Returns:
            The item part of the next envelope; its context is stored on
            ``last_context``.
        """
        ctx, item = await self._inner.receive()
        self.last_context = ctx
        return item

    def close(self) -> None:
        """Close the wrapped stream synchronously."""
        self._inner.close()

    async def aclose(self) -> None:
        """Close the wrapped stream asynchronously."""
        await self._inner.aclose()

    def clone(self) -> ContextReceiveStream[T]:  # pragma: no cover
        """Return a new wrapper around a clone of the wrapped stream.

        The new wrapper starts with ``last_context`` set to ``None``.
        """
        return ContextReceiveStream(self._inner.clone())

    def __aiter__(self) -> ContextReceiveStream[T]:
        """Return this wrapper as its own async iterator."""
        return self

    async def __anext__(self) -> T:
        """Return the next item, ending iteration once the stream is exhausted.

        Raises:
            StopAsyncIteration: when ``receive()`` raises ``anyio.EndOfStream``.
        """
        try:
            return await self.receive()
        except anyio.EndOfStream:
            # ``from None`` keeps the internal EndOfStream out of the exception
            # context: it is the expected end-of-iteration signal, not an error
            # that occurred while handling another one.
            raise StopAsyncIteration from None

    async def __aenter__(self) -> ContextReceiveStream[T]:
        """Enter the async context manager, returning this wrapper."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Close the stream on exit; exceptions are never suppressed."""
        await self.aclose()
        return None
class create_context_streams(
    tuple[ContextSendStream[T], ContextReceiveStream[T]],
):
    """Build a connected pair of context-aware memory object streams.

    Declared as a ``tuple`` subclass so that it can be subscripted, e.g.
    ``create_context_streams[T](n)``, in the same style as anyio's
    ``create_memory_object_stream``. Calling it returns a plain
    ``(send_stream, receive_stream)`` tuple of wrappers.

    ``max_buffer_size`` is passed unchanged to
    ``anyio.create_memory_object_stream``.
    """

    def __new__(cls, max_buffer_size: float = 0) -> tuple[ContextSendStream[T], ContextReceiveStream[T]]:  # type: ignore[type-var]
        inner_send: MemoryObjectSendStream[Any]
        inner_receive: MemoryObjectReceiveStream[Any]
        inner_send, inner_receive = anyio.create_memory_object_stream(max_buffer_size)

        # Wrap each raw endpoint so envelopes stay hidden from callers.
        send_stream = ContextSendStream(inner_send)
        receive_stream = ContextReceiveStream(inner_receive)
        return (send_stream, receive_stream)