|
139 | 139 | import time |
140 | 140 | import uuid |
141 | 141 | from collections.abc import Mapping as _Mapping |
142 | | -from contextvars import ContextVar |
| 142 | +from contextvars import ContextVar, Token |
143 | 143 | from typing import ( |
144 | 144 | TYPE_CHECKING, |
145 | 145 | Any, |
|
154 | 154 | TypeVar, |
155 | 155 | ) |
156 | 156 |
|
157 | | -from _contextvars import Token |
158 | | - |
159 | 157 | from bson.binary import Binary |
160 | 158 | from bson.int64 import Int64 |
161 | 159 | from bson.timestamp import Timestamp |
@@ -193,6 +191,24 @@ def __init__(self, session: AsyncClientSession, client_id: int): |
193 | 191 | self.client_id = client_id |
194 | 192 |
|
195 | 193 |
|
| 194 | +class AsyncBoundSessionContext: |
| 195 | + """Context manager returned by AsyncClientSession.bind() that manages bound state.""" |
| 196 | + |
| 197 | + def __init__(self, session: AsyncClientSession) -> None: |
| 198 | + self._session = session |
| 199 | + self._session_token: Optional[Token[_AsyncBoundClientSession]] = None |
| 200 | + |
| 201 | + async def __aenter__(self) -> AsyncClientSession: |
| 202 | + bound_session = _AsyncBoundClientSession(self._session, id(self._session._client)) |
| 203 | + self._session_token = _SESSION.set(bound_session) # type: ignore[assignment] |
| 204 | + return self._session |
| 205 | + |
| 206 | + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: |
| 207 | + if self._session_token: |
| 208 | + _SESSION.reset(self._session_token) # type: ignore[arg-type] |
| 209 | + self._session_token = None |
| 210 | + |
| 211 | + |
196 | 212 | class SessionOptions: |
197 | 213 | """Options for a new :class:`AsyncClientSession`. |
198 | 214 |
|
@@ -528,9 +544,6 @@ def __init__( |
528 | 544 | self._attached_to_cursor = False |
529 | 545 | # Should we leave the session alive when the cursor is closed? |
530 | 546 | self._leave_alive = False |
531 | | - # Is this session bound to a context manager scope? |
532 | | - self._bound = False |
533 | | - self._session_token: Optional[Token[_AsyncBoundClientSession]] = None |
534 | 547 |
|
535 | 548 | async def end_session(self) -> None: |
536 | 549 | """Finish this session. If a transaction has started, abort it. |
@@ -561,23 +574,18 @@ def _check_ended(self) -> None: |
561 | 574 | if self._server_session is None: |
562 | 575 | raise InvalidOperation("Cannot use ended session") |
563 | 576 |
|
564 | | - def bind(self) -> AsyncClientSession: |
565 | | - self._bound = True |
566 | | - return self |
| 577 | + def bind(self) -> AsyncBoundSessionContext: |
| 578 | + """Bind this session so it is implicitly passed to all database operations within the returned context. |
| 579 | +
|
| 580 | + .. versionadded:: 4.17 |
| 581 | + """ |
| 582 | + return AsyncBoundSessionContext(self) |
567 | 583 |
|
568 | 584 | async def __aenter__(self) -> AsyncClientSession: |
569 | | - if self._bound: |
570 | | - bound_session = _AsyncBoundClientSession(self, id(self._client)) |
571 | | - self._session_token = _SESSION.set(bound_session) # type: ignore[assignment] |
572 | 585 | return self |
573 | 586 |
|
574 | 587 | async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: |
575 | | - if self._session_token: |
576 | | - _SESSION.reset(self._session_token) # type: ignore[arg-type] |
577 | | - self._session_token = None |
578 | | - self._bound = False |
579 | | - else: |
580 | | - await self._end_session(lock=True) |
| 588 | + await self._end_session(lock=True) |
581 | 589 |
|
582 | 590 | @property |
583 | 591 | def client(self) -> AsyncMongoClient[Any]: |
|
0 commit comments