|
| 1 | +import asyncio |
1 | 2 | import os |
2 | 3 | import requests |
3 | 4 | from typing import Dict, List |
4 | 5 |
|
| 6 | +import aiofiles |
| 7 | +import aiohttp |
5 | 8 | from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor |
6 | 9 |
|
7 | 10 | from clipped.formatting import Printer |
@@ -326,3 +329,155 @@ def download_dir( |
326 | 329 |
|
327 | 330 | def delete(self, path, **kwargs): |
328 | 331 | pass |
| 332 | + |
| 333 | + |
| 334 | +class AsyncPolyaxonStore(PolyaxonStore): |
| 335 | + async def ls(self, path): |
| 336 | + return await self.list(path=path) |
| 337 | + |
| 338 | + async def list(self, path): |
| 339 | + return await self._client.get_artifacts_tree(path=path) |
| 340 | + |
| 341 | + @staticmethod |
| 342 | + async def check_response_status(response, endpoint): |
| 343 | + if 200 <= response.status < 300: |
| 344 | + return response |
| 345 | + |
| 346 | + try: |
| 347 | + reason = await response.text() |
| 348 | + logger.error( |
| 349 | + "Request to %s failed with status code %s. \nReason: %s", |
| 350 | + endpoint, |
| 351 | + response.status, |
| 352 | + reason, |
| 353 | + ) |
| 354 | + except TypeError: |
| 355 | + logger.error("Request to %s failed with status code", endpoint) |
| 356 | + |
| 357 | + raise PolyaxonClientException(HTTP_ERROR_MESSAGES_MAPPING.get(response.status)) |
| 358 | + |
| 359 | + @staticmethod |
| 360 | + def _get_request_timeout(timeout): |
| 361 | + if isinstance(timeout, aiohttp.ClientTimeout): |
| 362 | + return timeout |
| 363 | + return aiohttp.ClientTimeout(total=timeout) |
| 364 | + |
| 365 | + async def download( |
| 366 | + self, |
| 367 | + url, |
| 368 | + filename, |
| 369 | + params=None, |
| 370 | + headers=None, |
| 371 | + timeout=None, |
| 372 | + session=None, |
| 373 | + untar=False, |
| 374 | + delete_tar=True, |
| 375 | + extract_path=None, |
| 376 | + use_filepath=True, |
| 377 | + show_progress=True, |
| 378 | + ): |
| 379 | + logger.debug("Downloading files from url: %s", url) |
| 380 | + |
| 381 | + request_headers = self._client.client.config.get_full_headers(headers=headers) |
| 382 | + timeout = timeout if timeout is not None else settings.LONG_REQUEST_TIMEOUT |
| 383 | + request_timeout = self._get_request_timeout(timeout) |
| 384 | + close_session = session is None |
| 385 | + session = session or aiohttp.ClientSession(timeout=request_timeout) |
| 386 | + |
| 387 | + try: |
| 388 | + with Printer.console.status("Loading content ..."): |
| 389 | + async with session.get( |
| 390 | + url=url, |
| 391 | + params=params, |
| 392 | + headers=request_headers, |
| 393 | + timeout=request_timeout, |
| 394 | + ) as response: |
| 395 | + content_disposition = self._get_header_value( |
| 396 | + headers=response.headers, |
| 397 | + key="content-disposition", |
| 398 | + ) |
| 399 | + has_tar = ( |
| 400 | + '.tar"' in content_disposition |
| 401 | + or '.tar.gz"' in content_disposition |
| 402 | + ) |
| 403 | + if has_tar: |
| 404 | + filename = filename + ".tar.gz" |
| 405 | + if untar: |
| 406 | + untar = has_tar |
| 407 | + |
| 408 | + await self.check_response_status(response, url) |
| 409 | + |
| 410 | + content_length = self._get_header_value( |
| 411 | + headers=response.headers, |
| 412 | + key="content-length", |
| 413 | + ) |
| 414 | + content_length = float(content_length) if content_length else None |
| 415 | + chunk_size = 1024 * 10 |
| 416 | + |
| 417 | + async def _download_impl(progress=None, task=None): |
| 418 | + async with aiofiles.open(filename, "wb") as f: |
| 419 | + async for chunk in response.content.iter_chunked( |
| 420 | + chunk_size |
| 421 | + ): |
| 422 | + if progress: |
| 423 | + progress.update(task, advance=len(chunk)) |
| 424 | + if chunk: |
| 425 | + await f.write(chunk) |
| 426 | + |
| 427 | + if show_progress: |
| 428 | + with Printer.get_progress() as progress: |
| 429 | + task = progress.add_task( |
| 430 | + "Writing content:", total=content_length |
| 431 | + ) |
| 432 | + await _download_impl(progress, task) |
| 433 | + else: |
| 434 | + await _download_impl() |
| 435 | + |
| 436 | + if untar: |
| 437 | + filename = await asyncio.to_thread( |
| 438 | + untar_file, |
| 439 | + filename=filename, |
| 440 | + delete_tar=delete_tar, |
| 441 | + extract_path=extract_path, |
| 442 | + use_filepath=use_filepath, |
| 443 | + ) |
| 444 | + return filename |
| 445 | + except (aiohttp.ClientError, asyncio.TimeoutError) as exception: |
| 446 | + try: |
| 447 | + logger.debug("Exception: %s", exception) |
| 448 | + except TypeError: |
| 449 | + pass |
| 450 | + |
| 451 | + raise PolyaxonShouldExitError( |
| 452 | + "Error connecting to Polyaxon server on `{}`.\n" |
| 453 | + "An Error `{}` occurred.\n" |
| 454 | + "Check your host and ports configuration " |
| 455 | + "and your internet connection.".format(url, exception) |
| 456 | + ) |
| 457 | + finally: |
| 458 | + if close_session: |
| 459 | + await session.close() |
| 460 | + |
| 461 | + async def download_file(self, url, path, **kwargs): |
| 462 | + local_path = kwargs.pop("path_to", None) |
| 463 | + local_path = local_path or os.path.join( |
| 464 | + settings.CLIENT_CONFIG.archives_root, self._client.run_uuid |
| 465 | + ) |
| 466 | + if path: |
| 467 | + local_path = os.path.join(local_path, path) |
| 468 | + |
| 469 | + await asyncio.to_thread(check_or_create_path, local_path, is_dir=False) |
| 470 | + if not await asyncio.to_thread(os.path.exists, local_path): |
| 471 | + params = kwargs.pop("params", {}) |
| 472 | + params["path"] = path |
| 473 | + await self.download(filename=local_path, params=params, url=url, **kwargs) |
| 474 | + return local_path |
| 475 | + |
| 476 | + async def upload_file(self, *args, **kwargs): |
| 477 | + raise PolyaxonClientException("Async artifact upload is not supported yet.") |
| 478 | + |
| 479 | + async def upload_dir(self, *args, **kwargs): |
| 480 | + raise PolyaxonClientException("Async artifact upload is not supported yet.") |
| 481 | + |
| 482 | + async def upload(self, *args, **kwargs): |
| 483 | + raise PolyaxonClientException("Async artifact upload is not supported yet.") |
0 commit comments