Skip to content

Commit 15a46e9

Browse files
committed
Add async Polyaxon store
* Add support for downloads
* Keep artifact uploads explicitly unsupported on async clients to avoid hiding sync file I/O behind async APIs.
1 parent 959c53d commit 15a46e9

5 files changed

Lines changed: 316 additions & 14 deletions

File tree

cli/polyaxon/_client/run.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
get_global_or_inline_config,
4242
)
4343
from polyaxon._client.mixin import ClientMixin
44-
from polyaxon._client.store import PolyaxonStore
44+
from polyaxon._client.store import AsyncPolyaxonStore, PolyaxonStore
4545
from polyaxon._constants.metadata import META_COPY_ARTIFACTS, META_RECOMPILE, META_TMUX
4646
from polyaxon._containers.names import MAIN_CONTAINER_NAMES
4747
from polyaxon._contexts import paths as ctx_paths
@@ -3253,6 +3253,13 @@ async def _use_agent_host(self):
32533253
POLYAXON_HOST=self.settings.agent.url,
32543254
)
32553255

3256+
@property
def store(self):
    """Return this client's artifact store, creating it on first access.

    The store is an ``AsyncPolyaxonStore`` bound to this client and is cached
    on ``self._store`` so later accesses reuse the same instance.
    """
    if not self._store:
        # Nothing cached yet (or a falsy placeholder): build and remember it.
        self._store = AsyncPolyaxonStore(client=self)
    return self._store
3262+
32563263
@async_client_handler(check_no_op=True)
32573264
async def get_inputs(self) -> Dict[str, Any]:
32583265
if not self._run_data.inputs:
@@ -3734,9 +3741,7 @@ async def download_artifact_for_lineage(
37343741
)
37353742
params.update({"names": lineage.name, "pkg_assets": True})
37363743

3737-
# TODO: Update with AsyncPolyaxonStore is done
3738-
return await asyncio.to_thread(
3739-
self.store.download_file,
3744+
return await self.store.download_file(
37403745
url=url,
37413746
path=self.run_uuid,
37423747
use_filepath=False,
@@ -3795,8 +3800,7 @@ async def download_artifact(
37953800
)
37963801
url = absolute_uri(url=url, host=self.client.config.host)
37973802
params = get_streams_params(connection=self.artifacts_store, force=force)
3798-
return await asyncio.to_thread(
3799-
self.store.download_file,
3803+
return await self.store.download_file(
38003804
url=url,
38013805
path=path,
38023806
path_to=path_to,
@@ -3830,8 +3834,7 @@ async def download_artifacts(
38303834
if check_path:
38313835
params["check_path"] = True
38323836

3833-
return await asyncio.to_thread(
3834-
self.store.download_file,
3837+
return await self.store.download_file(
38353838
url=url,
38363839
path=path,
38373840
untar=untar,

cli/polyaxon/_client/store.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import asyncio
12
import os
23
import requests
34
from typing import Dict, List
45

6+
import aiofiles
7+
import aiohttp
58
from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor
69

710
from clipped.formatting import Printer
@@ -326,3 +329,155 @@ def download_dir(
326329

327330
def delete(self, path, **kwargs):
328331
pass
332+
333+
334+
class AsyncPolyaxonStore(PolyaxonStore):
335+
async def ls(self, path):
336+
return await self.list(path=path)
337+
338+
async def list(self, path):
339+
return await self._client.get_artifacts_tree(path=path)
340+
341+
@staticmethod
342+
async def check_response_status(response, endpoint):
343+
if 200 <= response.status < 300:
344+
return response
345+
346+
try:
347+
reason = await response.text()
348+
logger.error(
349+
"Request to %s failed with status code %s. \nReason: %s",
350+
endpoint,
351+
response.status,
352+
reason,
353+
)
354+
except TypeError:
355+
logger.error("Request to %s failed with status code", endpoint)
356+
357+
raise PolyaxonClientException(HTTP_ERROR_MESSAGES_MAPPING.get(response.status))
358+
359+
@staticmethod
360+
def _get_request_timeout(timeout):
361+
if isinstance(timeout, aiohttp.ClientTimeout):
362+
return timeout
363+
return aiohttp.ClientTimeout(total=timeout)
364+
365+
async def download(
366+
self,
367+
url,
368+
filename,
369+
params=None,
370+
headers=None,
371+
timeout=None,
372+
session=None,
373+
untar=False,
374+
delete_tar=True,
375+
extract_path=None,
376+
use_filepath=True,
377+
show_progress=True,
378+
):
379+
logger.debug("Downloading files from url: %s", url)
380+
381+
request_headers = self._client.client.config.get_full_headers(headers=headers)
382+
timeout = timeout if timeout is not None else settings.LONG_REQUEST_TIMEOUT
383+
request_timeout = self._get_request_timeout(timeout)
384+
close_session = session is None
385+
session = session or aiohttp.ClientSession(timeout=request_timeout)
386+
387+
try:
388+
with Printer.console.status("Loading content ..."):
389+
async with session.get(
390+
url=url,
391+
params=params,
392+
headers=request_headers,
393+
timeout=request_timeout,
394+
) as response:
395+
content_disposition = self._get_header_value(
396+
headers=response.headers,
397+
key="content-disposition",
398+
)
399+
has_tar = (
400+
'.tar"' in content_disposition
401+
or '.tar.gz"' in content_disposition
402+
)
403+
if has_tar:
404+
filename = filename + ".tar.gz"
405+
if untar:
406+
untar = has_tar
407+
408+
await self.check_response_status(response, url)
409+
410+
content_length = self._get_header_value(
411+
headers=response.headers,
412+
key="content-length",
413+
)
414+
content_length = float(content_length) if content_length else None
415+
chunk_size = 1024 * 10
416+
417+
async def _download_impl(progress=None, task=None):
418+
async with aiofiles.open(filename, "wb") as f:
419+
async for chunk in response.content.iter_chunked(
420+
chunk_size
421+
):
422+
if progress:
423+
progress.update(task, advance=len(chunk))
424+
if chunk:
425+
await f.write(chunk)
426+
427+
if show_progress:
428+
with Printer.get_progress() as progress:
429+
task = progress.add_task(
430+
"Writing content:", total=content_length
431+
)
432+
await _download_impl(progress, task)
433+
else:
434+
await _download_impl()
435+
436+
if untar:
437+
filename = await asyncio.to_thread(
438+
untar_file,
439+
filename=filename,
440+
delete_tar=delete_tar,
441+
extract_path=extract_path,
442+
use_filepath=use_filepath,
443+
)
444+
return filename
445+
except (aiohttp.ClientError, asyncio.TimeoutError) as exception:
446+
try:
447+
logger.debug("Exception: %s", exception)
448+
except TypeError:
449+
pass
450+
451+
raise PolyaxonShouldExitError(
452+
"Error connecting to Polyaxon server on `{}`.\n"
453+
"An Error `{}` occurred.\n"
454+
"Check your host and ports configuration "
455+
"and your internet connection.".format(url, exception)
456+
)
457+
finally:
458+
if close_session:
459+
await session.close()
460+
461+
async def download_file(self, url, path, **kwargs):
462+
local_path = kwargs.pop("path_to", None)
463+
local_path = local_path or os.path.join(
464+
settings.CLIENT_CONFIG.archives_root, self._client.run_uuid
465+
)
466+
if path:
467+
local_path = os.path.join(local_path, path)
468+
469+
await asyncio.to_thread(check_or_create_path, local_path, is_dir=False)
470+
if not await asyncio.to_thread(os.path.exists, local_path):
471+
params = kwargs.pop("params", {})
472+
params["path"] = path
473+
await self.download(filename=local_path, params=params, url=url, **kwargs)
474+
return local_path
475+
476+
async def upload_file(self, *args, **kwargs):
477+
raise PolyaxonClientException("Async artifact upload is not supported yet.")
478+
479+
async def upload_dir(self, *args, **kwargs):
480+
raise PolyaxonClientException("Async artifact upload is not supported yet.")
481+
482+
async def upload(self, *args, **kwargs):
483+
raise PolyaxonClientException("Async artifact upload is not supported yet.")

cli/polyaxon/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from polyaxon._client.organization import AsyncOrganizationClient, OrganizationClient
44
from polyaxon._client.project import AsyncProjectClient, ProjectClient
55
from polyaxon._client.run import AsyncRunClient, RunClient, get_run_logs
6-
from polyaxon._client.store import PolyaxonStore
6+
from polyaxon._client.store import AsyncPolyaxonStore, PolyaxonStore
77
from polyaxon._schemas.agent import AgentConfig
88
from polyaxon._schemas.authentication import AccessTokenConfig
99
from polyaxon._schemas.cli import CliConfig

cli/tests/test_client/test_async_run_client.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from clipped.utils.hashing import hash_file, hash_value
66
from polyaxon import settings
77
from polyaxon._client.run import AsyncRunClient, RunClient
8+
from polyaxon._client.store import AsyncPolyaxonStore
89
from polyaxon._schemas.lifecycle import (
910
V1ProjectVersionKind,
1011
V1StatusCondition,
@@ -148,6 +149,13 @@ def make_run(**kwargs):
148149
return V1Run.model_construct(**data)
149150

150151

152+
def test_async_run_client_uses_async_store():
    """The client built over the async SDK mock must expose an AsyncPolyaxonStore."""
    patch_settings()
    run_client = make_client(AsyncPolyaxonClientMock())

    # Accessing the property builds the store on demand.
    resolved_store = run_client.store
    assert isinstance(resolved_store, AsyncPolyaxonStore)
157+
158+
151159
def get_logged_lineage_artifact(sdk_client, index=0):
152160
body = sdk_client.runs_v1.create_run_artifacts_lineage.call_args_list[index][1][
153161
"body"
@@ -583,7 +591,7 @@ async def test_download_artifact_methods_call_store_download_file():
583591
client = make_client(sdk_client)
584592
client._run_data = make_run()
585593
client._store = mock.Mock()
586-
client._store.download_file = mock.Mock(side_effect=["/tmp/file", "/tmp/archive"])
594+
client._store.download_file = AsyncMock(side_effect=["/tmp/file", "/tmp/archive"])
587595

588596
file_path = await client.download_artifact(
589597
"outputs/model.pkl",
@@ -700,7 +708,7 @@ async def test_download_artifact_for_lineage_downloads_event_package():
700708
client = make_client(sdk_client)
701709
client._run_data = make_run()
702710
client._store = mock.Mock()
703-
client._store.download_file = mock.Mock(return_value="/tmp/events")
711+
client._store.download_file = AsyncMock(return_value="/tmp/events")
704712
lineage = V1RunArtifact.model_construct(
705713
kind=V1ArtifactKind.MODEL,
706714
name="model",

0 commit comments

Comments
 (0)