Skip to content

Commit 6eaf094

Browse files
committed
Add end_session parameter to ClientSession.bind()
1 parent 5edd9a0 commit 6eaf094

4 files changed

Lines changed: 48 additions & 6 deletions

File tree

pymongo/asynchronous/client_session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,10 @@
188188
class _AsyncBoundSessionContext:
189189
"""Context manager returned by AsyncClientSession.bind() that manages bound state."""
190190

191-
def __init__(self, session: AsyncClientSession) -> None:
191+
def __init__(self, session: AsyncClientSession, end_session: bool) -> None:
192192
self._session = session
193193
self._session_token: Optional[Token[AsyncClientSession]] = None
194+
self._end_session = end_session
194195

195196
async def __aenter__(self) -> AsyncClientSession:
196197
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
@@ -200,6 +201,8 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
200201
if self._session_token:
201202
_SESSION.reset(self._session_token) # type: ignore[arg-type]
202203
self._session_token = None
204+
if self._end_session:
205+
await self._session.end_session()
203206

204207

205208
class SessionOptions:
@@ -567,12 +570,14 @@ def _check_ended(self) -> None:
567570
if self._server_session is None:
568571
raise InvalidOperation("Cannot use ended session")
569572

570-
def bind(self) -> _AsyncBoundSessionContext:
573+
def bind(self, end_session: bool = False) -> _AsyncBoundSessionContext:
571574
"""Bind this session so it is implicitly passed to all database operations within the returned context.
572575
576+
:param end_session: Whether to end the session on exiting the returned context. Defaults to False.
577+
573578
.. versionadded:: 4.17
574579
"""
575-
return _AsyncBoundSessionContext(self)
580+
return _AsyncBoundSessionContext(self, end_session)
576581

577582
async def __aenter__(self) -> AsyncClientSession:
578583
return self

pymongo/synchronous/client_session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,10 @@
187187
class BoundSessionContext:
188188
"""Context manager returned by ClientSession.bind() that manages bound state."""
189189

190-
def __init__(self, session: ClientSession) -> None:
190+
def __init__(self, session: ClientSession, end_session: bool) -> None:
191191
self._session = session
192192
self._session_token: Optional[Token[ClientSession]] = None
193+
self._end_session = end_session
193194

194195
def __enter__(self) -> ClientSession:
195196
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
@@ -199,6 +200,8 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
199200
if self._session_token:
200201
_SESSION.reset(self._session_token) # type: ignore[arg-type]
201202
self._session_token = None
203+
if self._end_session:
204+
self._session.end_session()
202205

203206

204207
class SessionOptions:
@@ -566,12 +569,14 @@ def _check_ended(self) -> None:
566569
if self._server_session is None:
567570
raise InvalidOperation("Cannot use ended session")
568571

569-
def bind(self) -> BoundSessionContext:
572+
def bind(self, end_session: bool = False) -> BoundSessionContext:
570573
"""Bind this session so it is implicitly passed to all database operations within the returned context.
571574
575+
:param end_session: Whether to end the session on exiting the returned context. Defaults to False.
576+
572577
.. versionadded:: 4.17
573578
"""
574-
return BoundSessionContext(self)
579+
return BoundSessionContext(self, end_session)
575580

576581
def __enter__(self) -> ClientSession:
577582
return self

test/asynchronous/test_session.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,22 @@ async def test_nested_session_binding(self):
922922
await session1.end_session()
923923
await session2.end_session()
924924

925+
async def test_session_binding_end_session(self):
926+
coll = self.client.pymongo_test.test
927+
await coll.insert_one({"x": 1})
928+
929+
async with self.client.start_session().bind(end_session=True) as s1:
930+
await coll.find_one()
931+
932+
self.assertTrue(s1.has_ended)
933+
934+
async with self.client.start_session().bind() as s2:
935+
await coll.find_one()
936+
937+
self.assertFalse(s2.has_ended)
938+
939+
await s2.end_session()
940+
925941

926942
class TestCausalConsistency(AsyncUnitTest):
927943
listener: SessionTestListener

test/test_session.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,22 @@ def test_nested_session_binding(self):
922922
session1.end_session()
923923
session2.end_session()
924924

925+
def test_session_binding_end_session(self):
926+
coll = self.client.pymongo_test.test
927+
coll.insert_one({"x": 1})
928+
929+
with self.client.start_session().bind(end_session=True) as s1:
930+
coll.find_one()
931+
932+
self.assertTrue(s1.has_ended)
933+
934+
with self.client.start_session().bind() as s2:
935+
coll.find_one()
936+
937+
self.assertFalse(s2.has_ended)
938+
939+
s2.end_session()
940+
925941

926942
class TestCausalConsistency(UnitTest):
927943
listener: SessionTestListener

0 commit comments

Comments
 (0)