Skip to content

Commit 7252af4

Browse files
RedoubtsxShinnRyuu
andauthored
Add context manager support to clients (#5693)
* minimum context Signed-off-by: redoubts <Redoubts@users.noreply.github.com> * tests Signed-off-by: redoubts <Redoubts@users.noreply.github.com> * test_api_consistency Signed-off-by: redoubts <Redoubts@users.noreply.github.com> * xShinnRyuu rc Signed-off-by: redoubts <Redoubts@users.noreply.github.com> --------- Signed-off-by: redoubts <Redoubts@users.noreply.github.com> Co-authored-by: Thomas Zhou <54688146+xShinnRyuu@users.noreply.github.com>
1 parent b72ebd1 commit 7252af4

5 files changed

Lines changed: 52 additions & 1 deletion

File tree

python/glide-async/python/glide/glide_client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import sys
44
import threading
5+
from types import TracebackType
56
from typing import (
67
TYPE_CHECKING,
78
Any,
@@ -288,6 +289,12 @@ async def _create_uds_connection(self) -> None:
288289
raise ClosingError("Failed to create UDS connection") from e
289290

290291
async def close(self, err_message: Optional[str] = None) -> None:
292+
"""
293+
Forwards to `aclose`, the more common method for async resources.
294+
"""
295+
await self.aclose(err_message)
296+
297+
async def aclose(self, err_message: Optional[str] = None) -> None:
291298
"""
292299
Terminate the client by closing all associated resources, including the socket and any active futures.
293300
All open futures will be closed with an exception.
@@ -313,6 +320,17 @@ async def close(self, err_message: Optional[str] = None) -> None:
313320

314321
await self._stream.aclose()
315322

323+
async def __aenter__(self) -> Self:
324+
return self
325+
326+
async def __aexit__(
327+
self,
328+
exc_type: Optional[type[BaseException]],
329+
exc: Optional[BaseException],
330+
tb: Optional[TracebackType],
331+
) -> None:
332+
await self.aclose()
333+
316334
def _get_future(self, callback_idx: int) -> "TFuture":
317335
response_future: "TFuture" = _get_new_future_instance()
318336
self._available_futures.update({callback_idx: response_future})

python/glide-sync/glide_sync/glide_client.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import sys
55
import threading
6+
from types import TracebackType
67
from typing import Any, List, Optional, Tuple, Union
78

89
from glide_shared._fast_response import parse_response as _fast_parse_response
@@ -911,7 +912,7 @@ def _parse_pubsub_state(self, result, is_cluster):
911912
actual_subscriptions=actual_subscriptions,
912913
)
913914

914-
def close(self):
915+
def close(self) -> None:
915916
if not self._is_closed:
916917
self._is_closed = True
917918
with self._pubsub_condition:
@@ -920,6 +921,17 @@ def close(self):
920921
self._core_client = self._ffi.NULL
921922
self._pubsub_callback_ref = None
922923

924+
def __enter__(self) -> Self:
925+
return self
926+
927+
def __exit__(
928+
self,
929+
exc_type: Optional[type[BaseException]],
930+
exc: Optional[BaseException],
931+
tb: Optional[TracebackType],
932+
) -> None:
933+
self.close()
934+
923935

924936
class GlideClusterClient(BaseClient, ClusterCommands):
925937
"""

python/tests/async_tests/test_async_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,16 @@ async def test_register_client_name_and_version(self, glide_client: TGlideClient
135135
assert "lib-name=GlidePy" in info_str
136136
assert "lib-ver=unknown" in info_str
137137

138+
@pytest.mark.parametrize("cluster_mode", [True, False])
139+
@pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
140+
async def test_context_manager(self, request, cluster_mode, protocol):
141+
async with await create_client(
142+
request, cluster_mode=cluster_mode, protocol=protocol, request_timeout=5000
143+
) as client:
144+
assert not client._is_closed
145+
146+
assert client._is_closed
147+
138148
@pytest.mark.parametrize("cluster_mode", [True, False])
139149
@pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
140150
async def test_send_and_receive_large_values(self, request, cluster_mode, protocol):

python/tests/sync_tests/test_sync_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,16 @@ def test_sync_register_client_name_and_version(
139139
assert "lib-name=GlidePySync" in info_str
140140
assert "lib-ver=unknown" in info_str
141141

142+
@pytest.mark.parametrize("cluster_mode", [True, False])
143+
@pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
144+
def test_context_manager(self, request, cluster_mode, protocol):
145+
with create_sync_client(
146+
request, cluster_mode=cluster_mode, protocol=protocol, request_timeout=5000
147+
) as client:
148+
assert not client._is_closed
149+
150+
assert client._is_closed
151+
142152
@pytest.mark.parametrize("cluster_mode", [True, False])
143153
@pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
144154
def test_sync_send_and_receive_large_values(self, request, cluster_mode, protocol):

python/tests/test_api_consistency.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"create_leaked_value",
3838
"start_socket_listener_external",
3939
"value_from_pointer",
40+
"aclose",
4041
],
4142
"sync_only": [],
4243
}

0 commit comments

Comments
 (0)