Skip to content

Commit 13f1a15

Browse files
committed
CP review
1 parent d6b883b commit 13f1a15

3 files changed

Lines changed: 53 additions & 36 deletions

File tree

pymongo/asynchronous/client_session.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
import time
140140
import uuid
141141
from collections.abc import Mapping as _Mapping
142-
from contextvars import ContextVar
142+
from contextvars import ContextVar, Token
143143
from typing import (
144144
TYPE_CHECKING,
145145
Any,
@@ -154,8 +154,6 @@
154154
TypeVar,
155155
)
156156

157-
from _contextvars import Token
158-
159157
from bson.binary import Binary
160158
from bson.int64 import Int64
161159
from bson.timestamp import Timestamp
@@ -193,6 +191,24 @@ def __init__(self, session: AsyncClientSession, client_id: int):
193191
self.client_id = client_id
194192

195193

194+
class AsyncBoundSessionContext:
195+
"""Context manager returned by AsyncClientSession.bind() that manages bound state."""
196+
197+
def __init__(self, session: AsyncClientSession) -> None:
198+
self._session = session
199+
self._session_token: Optional[Token[_AsyncBoundClientSession]] = None
200+
201+
async def __aenter__(self) -> AsyncClientSession:
202+
bound_session = _AsyncBoundClientSession(self._session, id(self._session._client))
203+
self._session_token = _SESSION.set(bound_session) # type: ignore[assignment]
204+
return self._session
205+
206+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
207+
if self._session_token:
208+
_SESSION.reset(self._session_token) # type: ignore[arg-type]
209+
self._session_token = None
210+
211+
196212
class SessionOptions:
197213
"""Options for a new :class:`AsyncClientSession`.
198214
@@ -528,9 +544,6 @@ def __init__(
528544
self._attached_to_cursor = False
529545
# Should we leave the session alive when the cursor is closed?
530546
self._leave_alive = False
531-
# Is this session bound to a context manager scope?
532-
self._bound = False
533-
self._session_token: Optional[Token[_AsyncBoundClientSession]] = None
534547

535548
async def end_session(self) -> None:
536549
"""Finish this session. If a transaction has started, abort it.
@@ -561,23 +574,18 @@ def _check_ended(self) -> None:
561574
if self._server_session is None:
562575
raise InvalidOperation("Cannot use ended session")
563576

564-
def bind(self) -> AsyncClientSession:
565-
self._bound = True
566-
return self
577+
def bind(self) -> AsyncBoundSessionContext:
578+
"""Bind this session so it is implicitly passed to all database operations within the returned context.
579+
580+
.. versionadded:: 4.17
581+
"""
582+
return AsyncBoundSessionContext(self)
567583

568584
async def __aenter__(self) -> AsyncClientSession:
569-
if self._bound:
570-
bound_session = _AsyncBoundClientSession(self, id(self._client))
571-
self._session_token = _SESSION.set(bound_session) # type: ignore[assignment]
572585
return self
573586

574587
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
575-
if self._session_token:
576-
_SESSION.reset(self._session_token) # type: ignore[arg-type]
577-
self._session_token = None
578-
self._bound = False
579-
else:
580-
await self._end_session(lock=True)
588+
await self._end_session(lock=True)
581589

582590
@property
583591
def client(self) -> AsyncMongoClient[Any]:

pymongo/synchronous/client_session.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
import time
140140
import uuid
141141
from collections.abc import Mapping as _Mapping
142-
from contextvars import ContextVar
142+
from contextvars import ContextVar, Token
143143
from typing import (
144144
TYPE_CHECKING,
145145
Any,
@@ -153,8 +153,6 @@
153153
TypeVar,
154154
)
155155

156-
from _contextvars import Token
157-
158156
from bson.binary import Binary
159157
from bson.int64 import Int64
160158
from bson.timestamp import Timestamp
@@ -192,6 +190,24 @@ def __init__(self, session: ClientSession, client_id: int):
192190
self.client_id = client_id
193191

194192

193+
class BoundSessionContext:
194+
"""Context manager returned by ClientSession.bind() that manages bound state."""
195+
196+
def __init__(self, session: ClientSession) -> None:
197+
self._session = session
198+
self._session_token: Optional[Token[_BoundClientSession]] = None
199+
200+
def __enter__(self) -> ClientSession:
201+
bound_session = _BoundClientSession(self._session, id(self._session._client))
202+
self._session_token = _SESSION.set(bound_session) # type: ignore[assignment]
203+
return self._session
204+
205+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
206+
if self._session_token:
207+
_SESSION.reset(self._session_token) # type: ignore[arg-type]
208+
self._session_token = None
209+
210+
195211
class SessionOptions:
196212
"""Options for a new :class:`ClientSession`.
197213
@@ -527,9 +543,6 @@ def __init__(
527543
self._attached_to_cursor = False
528544
# Should we leave the session alive when the cursor is closed?
529545
self._leave_alive = False
530-
# Is this session bound to a context manager scope?
531-
self._bound = False
532-
self._session_token: Optional[Token[_BoundClientSession]] = None
533546

534547
def end_session(self) -> None:
535548
"""Finish this session. If a transaction has started, abort it.
@@ -560,23 +573,18 @@ def _check_ended(self) -> None:
560573
if self._server_session is None:
561574
raise InvalidOperation("Cannot use ended session")
562575

563-
def bind(self) -> ClientSession:
564-
self._bound = True
565-
return self
576+
def bind(self) -> BoundSessionContext:
577+
"""Bind this session so it is implicitly passed to all database operations within the returned context.
578+
579+
.. versionadded:: 4.17
580+
"""
581+
return BoundSessionContext(self)
566582

567583
def __enter__(self) -> ClientSession:
568-
if self._bound:
569-
bound_session = _BoundClientSession(self, id(self._client))
570-
self._session_token = _SESSION.set(bound_session) # type: ignore[assignment]
571584
return self
572585

573586
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
574-
if self._session_token:
575-
_SESSION.reset(self._session_token) # type: ignore[arg-type]
576-
self._session_token = None
577-
self._bound = False
578-
else:
579-
self._end_session(lock=True)
587+
self._end_session(lock=True)
580588

581589
@property
582590
def client(self) -> MongoClient[Any]:

tools/synchro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"AsyncRawBatchCommandCursor": "RawBatchCommandCursor",
3939
"AsyncClientSession": "ClientSession",
4040
"_AsyncBoundClientSession": "_BoundClientSession",
41+
"AsyncBoundSessionContext": "BoundSessionContext",
4142
"AsyncChangeStream": "ChangeStream",
4243
"AsyncCollectionChangeStream": "CollectionChangeStream",
4344
"AsyncDatabaseChangeStream": "DatabaseChangeStream",

0 commit comments

Comments
 (0)