|
| 1 | +from typing import Optional |
| 2 | + |
1 | 3 | from polyaxon import settings |
2 | 4 | from polyaxon._client.client import PolyaxonClient |
3 | 5 | from polyaxon._schemas.client import ClientConfig |
4 | 6 | from polyaxon._utils.fqn_utils import split_owner_team_space |
| 7 | +from polyaxon.exceptions import PolyaxonClientException |
5 | 8 |
|
6 | 9 |
|
7 | 10 | class ClientMixin: |
| 11 | + _IS_ASYNC: bool = False |
| 12 | + |
| 13 | + def _set_client(self, client: Optional[PolyaxonClient]): |
| 14 | + if client is not None and client.is_async != self._IS_ASYNC: |
| 15 | + raise PolyaxonClientException( |
| 16 | + "Injected PolyaxonClient transport mode does not match client class." |
| 17 | + ) |
| 18 | + self._client = client |
| 19 | + self._owns_client = client is None |
| 20 | + |
8 | 21 | @property |
9 | | - def client(self): |
10 | | - if self._client: |
| 22 | + def client(self) -> PolyaxonClient: |
| 23 | + if getattr(self, "_client", None) is not None: |
11 | 24 | return self._client |
12 | | - self._client = PolyaxonClient() |
| 25 | + self._client = PolyaxonClient(is_async=self._IS_ASYNC) |
| 26 | + self._owns_client = True |
13 | 27 | return self._client |
14 | 28 |
|
15 | 29 | def reset_client(self, **kwargs): |
| 30 | + if self._IS_ASYNC: |
| 31 | + raise PolyaxonClientException( |
| 32 | + "Use `await areset_client(...)` for async clients." |
| 33 | + ) |
16 | 34 | if not settings.CLIENT_CONFIG.in_cluster: |
| 35 | + previous = ( |
| 36 | + self._client |
| 37 | + if getattr(self, "_owns_client", False) |
| 38 | + and getattr(self, "_client", None) is not None |
| 39 | + else None |
| 40 | + ) |
17 | 41 | self._client = PolyaxonClient( |
18 | | - ClientConfig.patch_from(settings.CLIENT_CONFIG, **kwargs) |
| 42 | + ClientConfig.patch_from(settings.CLIENT_CONFIG, **kwargs), |
| 43 | + is_async=False, |
| 44 | + ) |
| 45 | + self._owns_client = True |
| 46 | + if previous is not None: |
| 47 | + previous.close() |
| 48 | + |
| 49 | + async def areset_client(self, **kwargs): |
| 50 | + if not self._IS_ASYNC: |
| 51 | + raise PolyaxonClientException( |
| 52 | + "Use `reset_client(...)` for sync clients." |
19 | 53 | ) |
| 54 | + if not settings.CLIENT_CONFIG.in_cluster: |
| 55 | + previous = ( |
| 56 | + self._client |
| 57 | + if getattr(self, "_owns_client", False) |
| 58 | + and getattr(self, "_client", None) is not None |
| 59 | + else None |
| 60 | + ) |
| 61 | + self._client = PolyaxonClient( |
| 62 | + ClientConfig.patch_from(settings.CLIENT_CONFIG, **kwargs), |
| 63 | + is_async=True, |
| 64 | + ) |
| 65 | + self._owns_client = True |
| 66 | + if previous is not None: |
| 67 | + await previous.aclose() |
| 68 | + |
| 69 | + async def _flush_on_exit(self) -> None: |
| 70 | + return None |
| 71 | + |
| 72 | + def __enter__(self): |
| 73 | + if self._IS_ASYNC: |
| 74 | + raise PolyaxonClientException("Use `async with` for async clients.") |
| 75 | + return self |
| 76 | + |
| 77 | + def __exit__(self, exc_type, exc_value, traceback): |
| 78 | + if getattr(self, "_owns_client", False) and getattr(self, "_client", None): |
| 79 | + self._client.close() |
| 80 | + |
| 81 | + async def __aenter__(self): |
| 82 | + if not self._IS_ASYNC: |
| 83 | + raise PolyaxonClientException("Use `with` for sync clients.") |
| 84 | + return self |
| 85 | + |
| 86 | + async def __aexit__(self, exc_type, exc_value, traceback): |
| 87 | + try: |
| 88 | + await self._flush_on_exit() |
| 89 | + finally: |
| 90 | + if getattr(self, "_owns_client", False) and getattr(self, "_client", None): |
| 91 | + await self._client.aclose() |
20 | 92 |
|
21 | 93 | @property |
22 | 94 | def owner(self) -> str: |
|
0 commit comments