Skip to content

Commit eb5b3f4

Browse files
committed
Add polyaxon client lifecycle methods
1 parent 037d6a9 commit eb5b3f4

2 files changed

Lines changed: 220 additions & 19 deletions

File tree

cli/polyaxon/_client/client.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from polyaxon._sdk.async_client.api_client import AsyncApiClient
2525
from polyaxon._sdk.sync_client.api_client import ApiClient
26+
from polyaxon.exceptions import PolyaxonClientException
2627

2728
if TYPE_CHECKING:
2829
from polyaxon._schemas.client import ClientConfig
@@ -86,6 +87,9 @@ def __init__(
8687
self.is_async = is_async
8788
self.is_internal = is_internal
8889
self.api_client = self._get_client()
90+
self._reset_api_wrappers()
91+
92+
def _reset_api_wrappers(self):
8993
self._projects_v1 = None
9094
self._runs_v1 = None
9195
self._project_dashboards_v1 = None
@@ -118,24 +122,46 @@ def _get_client(self):
118122
return ApiClient(self.config.sdk_config, **self.config.client_header)
119123

120124
def reset(self):
121-
self._projects_v1 = None
122-
self._runs_v1 = None
123-
self._project_dashboards_v1 = None
124-
self._project_searches_v1 = None
125-
self._auth_v1 = None
126-
self._users_v1 = None
127-
self._versions_v1 = None
128-
self._agents_v1 = None
129-
self._queues_v1 = None
130-
self._service_accounts_v1 = None
131-
self._presets_v1 = None
132-
self._tags_v1 = None
133-
self._teams_v1 = None
134-
self._connections_v1 = None
135-
self._dashboards_v1 = None
136-
self._searches_v1 = None
137-
self._organizations_v1 = None
125+
if self.is_async:
126+
raise PolyaxonClientException("Use `await areset()` for async clients.")
127+
previous_client = self.api_client
138128
self.api_client = self._get_client()
129+
self._reset_api_wrappers()
130+
previous_client.close()
131+
132+
async def areset(self):
133+
if not self.is_async:
134+
raise PolyaxonClientException("Use `reset()` for sync clients.")
135+
previous_client = self.api_client
136+
self.api_client = self._get_client()
137+
self._reset_api_wrappers()
138+
await previous_client.close()
139+
140+
def close(self):
141+
if self.is_async:
142+
raise PolyaxonClientException("Use `await aclose()` for async clients.")
143+
self.api_client.close()
144+
145+
async def aclose(self):
146+
if not self.is_async:
147+
raise PolyaxonClientException("Use `close()` for sync clients.")
148+
await self.api_client.close()
149+
150+
def __enter__(self):
151+
if self.is_async:
152+
raise PolyaxonClientException("Use `async with` for async clients.")
153+
return self
154+
155+
def __exit__(self, exc_type, exc_value, traceback):
156+
self.close()
157+
158+
async def __aenter__(self):
159+
if not self.is_async:
160+
raise PolyaxonClientException("Use `with` for sync clients.")
161+
return self
162+
163+
async def __aexit__(self, exc_type, exc_value, traceback):
164+
await self.aclose()
139165

140166
@property
141167
def config(self):

cli/tests/test_client/test_polyaxon_client.py

Lines changed: 177 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,47 @@
33

44
from mock import patch
55

6+
from clipped.utils.paths import delete_path
7+
68
from polyaxon import settings
79
from polyaxon._client.client import PolyaxonClient
10+
from polyaxon._contexts import paths as ctx_paths
811
from polyaxon._constants.globals import NO_AUTH
912
from polyaxon._schemas.client import ClientConfig
1013
from polyaxon._sdk.api import (
14+
AgentsV1Api,
1115
AuthV1Api,
1216
ProjectsV1Api,
1317
RunsV1Api,
1418
UsersV1Api,
15-
VersionsV1Api, AgentsV1Api,
19+
VersionsV1Api,
1620
)
17-
from polyaxon._utils.test_utils import BaseTestCase
21+
from polyaxon._utils.test_utils import BaseTestCase, patch_settings
22+
from polyaxon.exceptions import PolyaxonClientException
23+
24+
25+
class SDKClientMock:
26+
def __init__(self, name=None, events=None):
27+
self.name = name
28+
self.events = events
29+
self.close_calls = 0
30+
31+
def close(self):
32+
self.close_calls += 1
33+
if self.events is not None:
34+
self.events.append("{}:close".format(self.name))
35+
36+
37+
class AsyncSDKClientMock(SDKClientMock):
38+
async def close(self):
39+
self.close_calls += 1
40+
if self.events is not None:
41+
self.events.append("{}:close".format(self.name))
42+
43+
44+
def setup_async_test_settings():
45+
delete_path(ctx_paths.CONTEXT_USER_POLYAXON_PATH)
46+
patch_settings()
1847

1948

2049
@pytest.mark.client_mark
@@ -104,3 +133,149 @@ def test_load_token(self):
104133
assert client.config.is_managed is False
105134
assert client.config.host == "http://localhost:8000"
106135
assert client.config.token == "test2"
136+
137+
def test_close_raises_on_async_instance(self):
138+
sdk_client = AsyncSDKClientMock()
139+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
140+
client = PolyaxonClient(is_async=True)
141+
142+
with pytest.raises(PolyaxonClientException):
143+
client.close()
144+
145+
def test_reset_raises_on_async_instance(self):
146+
sdk_client = AsyncSDKClientMock()
147+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
148+
client = PolyaxonClient(is_async=True)
149+
150+
with pytest.raises(PolyaxonClientException):
151+
client.reset()
152+
153+
def test_close_calls_api_client_close(self):
154+
sdk_client = SDKClientMock()
155+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
156+
client = PolyaxonClient()
157+
client.close()
158+
159+
assert sdk_client.close_calls == 1
160+
161+
def test_reset_closes_previous_after_replacement(self):
162+
events = []
163+
previous = SDKClientMock(name="previous", events=events)
164+
replacement = SDKClientMock(name="replacement", events=events)
165+
166+
def get_client(client):
167+
if not events:
168+
events.append("previous:make")
169+
return previous
170+
events.append("replacement:make")
171+
return replacement
172+
173+
with patch.object(PolyaxonClient, "_get_client", autospec=True) as mock_get:
174+
mock_get.side_effect = get_client
175+
client = PolyaxonClient()
176+
client.reset()
177+
178+
assert client.api_client is replacement
179+
assert events == ["previous:make", "replacement:make", "previous:close"]
180+
181+
def test_supports_sync_context_manager(self):
182+
sdk_client = SDKClientMock()
183+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
184+
with PolyaxonClient() as client:
185+
assert client.api_client is sdk_client
186+
187+
assert sdk_client.close_calls == 1
188+
189+
def test_sync_context_manager_raises_on_async_instance(self):
190+
sdk_client = AsyncSDKClientMock()
191+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
192+
client = PolyaxonClient(is_async=True)
193+
194+
with pytest.raises(PolyaxonClientException):
195+
with client:
196+
pass
197+
198+
199+
@pytest.mark.client_mark
200+
@pytest.mark.asyncio
201+
async def test_aclose_raises_on_sync_instance():
202+
setup_async_test_settings()
203+
sdk_client = SDKClientMock()
204+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
205+
client = PolyaxonClient()
206+
207+
with pytest.raises(PolyaxonClientException):
208+
await client.aclose()
209+
210+
211+
@pytest.mark.client_mark
212+
@pytest.mark.asyncio
213+
async def test_areset_raises_on_sync_instance():
214+
setup_async_test_settings()
215+
sdk_client = SDKClientMock()
216+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
217+
client = PolyaxonClient()
218+
219+
with pytest.raises(PolyaxonClientException):
220+
await client.areset()
221+
222+
223+
@pytest.mark.client_mark
224+
@pytest.mark.asyncio
225+
async def test_aclose_awaits_api_client_close():
226+
setup_async_test_settings()
227+
sdk_client = AsyncSDKClientMock()
228+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
229+
client = PolyaxonClient(is_async=True)
230+
await client.aclose()
231+
232+
assert sdk_client.close_calls == 1
233+
234+
235+
@pytest.mark.client_mark
236+
@pytest.mark.asyncio
237+
async def test_areset_closes_previous_after_replacement():
238+
setup_async_test_settings()
239+
events = []
240+
previous = AsyncSDKClientMock(name="previous", events=events)
241+
replacement = AsyncSDKClientMock(name="replacement", events=events)
242+
243+
def get_client(client):
244+
if not events:
245+
events.append("previous:make")
246+
return previous
247+
events.append("replacement:make")
248+
return replacement
249+
250+
with patch.object(PolyaxonClient, "_get_client", autospec=True) as mock_get:
251+
mock_get.side_effect = get_client
252+
client = PolyaxonClient(is_async=True)
253+
await client.areset()
254+
255+
assert client.api_client is replacement
256+
assert events == ["previous:make", "replacement:make", "previous:close"]
257+
258+
259+
@pytest.mark.client_mark
260+
@pytest.mark.asyncio
261+
async def test_supports_async_context_manager():
262+
setup_async_test_settings()
263+
sdk_client = AsyncSDKClientMock()
264+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
265+
async with PolyaxonClient(is_async=True) as client:
266+
assert client.api_client is sdk_client
267+
268+
assert sdk_client.close_calls == 1
269+
270+
271+
@pytest.mark.client_mark
272+
@pytest.mark.asyncio
273+
async def test_async_context_manager_raises_on_sync_instance():
274+
setup_async_test_settings()
275+
sdk_client = SDKClientMock()
276+
with patch.object(PolyaxonClient, "_get_client", return_value=sdk_client):
277+
client = PolyaxonClient()
278+
279+
with pytest.raises(PolyaxonClientException):
280+
async with client:
281+
pass

0 commit comments

Comments
 (0)