Skip to content

Commit 900d9c7

Browse files
authored
PYTHON-5436 - Always include session on getMores if the initial curso… (#2794)
1 parent 575d75f commit 900d9c7

4 files changed

Lines changed: 78 additions & 10 deletions

File tree

pymongo/asynchronous/client_session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1101,7 +1101,11 @@ def _apply_to(
11011101
read_preference: _ServerMode,
11021102
conn: AsyncConnection,
11031103
) -> None:
1104-
if not conn.supports_sessions:
1104+
# getMores must be sent with a session if the cursor was opened with one
1105+
operation = next(iter(command))
1106+
if not conn.supports_sessions and (
1107+
isinstance(self._server_session, _EmptyServerSession) or operation != "getMore"
1108+
):
11051109
if not self._implicit:
11061110
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
11071111
return

pymongo/synchronous/client_session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,11 @@ def _apply_to(
10971097
read_preference: _ServerMode,
10981098
conn: Connection,
10991099
) -> None:
1100-
if not conn.supports_sessions:
1100+
# getMores must be sent with a session if the cursor was opened with one
1101+
operation = next(iter(command))
1102+
if not conn.supports_sessions and (
1103+
isinstance(self._server_session, _EmptyServerSession) or operation != "getMore"
1104+
):
11011105
if not self._implicit:
11021106
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
11031107
return

test/asynchronous/test_session.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Test the client_session module."""
1616
from __future__ import annotations
1717

18-
import asyncio
1918
import copy
2019
import sys
2120
import time
@@ -24,8 +23,6 @@
2423
from test.asynchronous.helpers import ExceptionCatchingTask
2524
from typing import Any, Callable, List, Set, Tuple
2625

27-
from pymongo.synchronous.mongo_client import MongoClient
28-
2926
sys.path[0:0] = [""]
3027

3128
from test.asynchronous import (
@@ -45,7 +42,7 @@
4542

4643
from bson import DBRef
4744
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
48-
from pymongo import ASCENDING, AsyncMongoClient, _csot, monitoring
45+
from pymongo import ASCENDING, AsyncMongoClient, monitoring
4946
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
5047
from pymongo.asynchronous.cursor import AsyncCursor
5148
from pymongo.asynchronous.helpers import anext
@@ -938,6 +935,39 @@ async def test_session_binding_end_session(self):
938935

939936
await s2.end_session()
940937

938+
async def test_getmore_preserves_lsid_after_session_support_lost(self):
939+
listener = OvertCommandListener()
940+
client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
941+
coll = client.pymongo_test.test
942+
await coll.drop()
943+
await coll.insert_many([{"x": i} for i in range(10)])
944+
self.addAsyncCleanup(coll.drop)
945+
946+
async with client.start_session() as s:
947+
cursor = coll.find({}, batch_size=2, session=s)
948+
await anext(cursor)
949+
950+
find_event = next(e for e in listener.started_events if e.command_name == "find")
951+
lsid = find_event.command["lsid"]
952+
953+
# Simulate a node stepping down: mark idle connections as not supporting sessions.
954+
for server in client._topology._servers.values():
955+
for conn in server.pool.conns:
956+
conn.supports_sessions = False
957+
958+
listener.reset()
959+
await cursor.to_list()
960+
961+
getmore_events = [e for e in listener.started_events if e.command_name == "getMore"]
962+
self.assertGreater(len(getmore_events), 0, "expected at least one getMore command")
963+
for event in getmore_events:
964+
self.assertIn(
965+
"lsid", event.command, "getMore must include lsid when session is materialized"
966+
)
967+
self.assertEqual(
968+
lsid, event.command["lsid"], "getMore lsid must match the session lsid from find"
969+
)
970+
941971

942972
class TestCausalConsistency(AsyncUnitTest):
943973
listener: SessionTestListener

test/test_session.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Test the client_session module."""
1616
from __future__ import annotations
1717

18-
import asyncio
1918
import copy
2019
import sys
2120
import time
@@ -24,8 +23,6 @@
2423
from test.helpers import ExceptionCatchingTask
2524
from typing import Any, Callable, List, Set, Tuple
2625

27-
from pymongo.synchronous.mongo_client import MongoClient
28-
2926
sys.path[0:0] = [""]
3027

3128
from test import (
@@ -45,7 +42,7 @@
4542

4643
from bson import DBRef
4744
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
48-
from pymongo import ASCENDING, MongoClient, _csot, monitoring
45+
from pymongo import ASCENDING, MongoClient, monitoring
4946
from pymongo.common import _MAX_END_SESSIONS
5047
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
5148
from pymongo.operations import IndexModel, InsertOne, UpdateOne
@@ -938,6 +935,39 @@ def test_session_binding_end_session(self):
938935

939936
s2.end_session()
940937

938+
def test_getmore_preserves_lsid_after_session_support_lost(self):
939+
listener = OvertCommandListener()
940+
client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
941+
coll = client.pymongo_test.test
942+
coll.drop()
943+
coll.insert_many([{"x": i} for i in range(10)])
944+
self.addCleanup(coll.drop)
945+
946+
with client.start_session() as s:
947+
cursor = coll.find({}, batch_size=2, session=s)
948+
next(cursor)
949+
950+
find_event = next(e for e in listener.started_events if e.command_name == "find")
951+
lsid = find_event.command["lsid"]
952+
953+
# Simulate a node stepping down: mark idle connections as not supporting sessions.
954+
for server in client._topology._servers.values():
955+
for conn in server.pool.conns:
956+
conn.supports_sessions = False
957+
958+
listener.reset()
959+
cursor.to_list()
960+
961+
getmore_events = [e for e in listener.started_events if e.command_name == "getMore"]
962+
self.assertGreater(len(getmore_events), 0, "expected at least one getMore command")
963+
for event in getmore_events:
964+
self.assertIn(
965+
"lsid", event.command, "getMore must include lsid when session is materialized"
966+
)
967+
self.assertEqual(
968+
lsid, event.command["lsid"], "getMore lsid must match the session lsid from find"
969+
)
970+
941971

942972
class TestCausalConsistency(UnitTest):
943973
listener: SessionTestListener

0 commit comments

Comments
 (0)