diff --git a/zulip/tests/test_do_api_query.py b/zulip/tests/test_do_api_query.py new file mode 100644 index 000000000..2b4130047 --- /dev/null +++ b/zulip/tests/test_do_api_query.py @@ -0,0 +1,71 @@ +import unittest +from unittest.mock import MagicMock, patch + +import requests + +import zulip + + +def make_client() -> zulip.Client: + with patch.object(zulip.Client, "call_endpoint") as mock_call: + mock_call.return_value = { + "result": "success", + "zulip_version": "9.0.0", + "zulip_feature_level": 300, + "msg": "", + } + client = zulip.Client( + email="test@example.com", + api_key="deadbeef", + site="https://testserver", + ) + return client + + +class TestStaleConnectionRetry(unittest.TestCase): + def test_stale_connection_resets_session(self) -> None: + client = make_client() + + # Simulate a session that has already been used (has_connected = True) + stale_session = MagicMock() + client.session = stale_session + client.has_connected = True + + # First request raises ConnectionError (stale socket), + # second request succeeds + success_response = MagicMock() + success_response.status_code = 200 + success_response.json.return_value = {"result": "success", "msg": ""} + + stale_session.request.side_effect = requests.exceptions.ConnectionError("stale") + + fresh_session = MagicMock() + fresh_session.request.return_value = success_response + + call_count = 0 + + def mock_ensure_session(self: zulip.Client) -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: leave the stale session in place + self.session = stale_session + else: + # Subsequent calls: provide a fresh session + self.session = fresh_session + + with patch.object(zulip.Client, "ensure_session", mock_ensure_session): + result = client.do_api_query({}, "/api/v1/messages", method="POST") + + # The stale session should have been closed + stale_session.close.assert_called_once() + + # The result should come from the fresh session + self.assertEqual(result, {"result": "success", "msg": ""}) + + # ensure_session should have been called twice (once per loop iteration) + self.assertEqual(call_count, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/zulip/zulip/__init__.py b/zulip/zulip/__init__.py index 82190a51d..03a5fcb2a 100644 --- a/zulip/zulip/__init__.py +++ b/zulip/zulip/__init__.py @@ -600,9 +600,6 @@ def do_api_query( req_files = [(f.name, f) for f in files] - self.ensure_session() - assert self.session is not None - query_state: Dict[str, Any] = { "had_error_retry": False, "request": request, @@ -637,6 +634,9 @@ def end_error_retry(succeeded: bool) -> None: print("Failed!") while True: + self.ensure_session() + assert self.session is not None + try: kwarg = "params" if method == "GET" else "data" @@ -686,6 +686,9 @@ def end_error_retry(succeeded: bool) -> None: raise UnrecoverableNetworkError( "cannot connect to server " + self.base_url ) from e + if self.session is not None: + self.session.close() + self.session = None if error_retry(""): continue