|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +from __future__ import annotations |
| 17 | + |
16 | 18 | import base64 |
17 | 19 | import logging |
18 | 20 | import os |
|
22 | 24 | import time |
23 | 25 | from dataclasses import dataclass, field |
24 | 26 | from datetime import datetime |
| 27 | +from enum import Enum |
25 | 28 | from pathlib import Path |
26 | 29 | from typing import Any, List, Optional, Set, Type |
27 | 30 |
|
28 | 31 | from invoke.context import Context |
29 | | -from leptonai.api.v2.client import APIClient |
30 | | -from leptonai.api.v1.types.affinity import LeptonResourceAffinity |
31 | | -from leptonai.api.v1.types.common import Metadata, LeptonVisibility |
32 | | -from leptonai.api.v1.types.dedicated_node_group import DedicatedNodeGroup |
33 | | -from leptonai.api.v1.types.deployment import ( |
34 | | - EnvVar, |
35 | | - EnvValue, |
36 | | - LeptonContainer, |
37 | | - Mount, |
38 | | -) |
39 | | -from leptonai.api.v1.types.job import ( |
40 | | - LeptonJob, |
41 | | - LeptonJobState, |
42 | | - LeptonJobUserSpec, |
43 | | - ReservationConfig, |
44 | | -) |
45 | | -from leptonai.api.v1.types.replica import Replica |
46 | 32 |
|
47 | 33 | from nemo_run.config import get_nemorun_home |
48 | 34 | from nemo_run.core.execution.base import Executor, ExecutorMacros |
49 | 35 | from nemo_run.core.packaging.base import Packager |
50 | 36 | from nemo_run.core.packaging.git import GitArchivePackager |
51 | 37 |
|
# leptonai is an optional dependency. The flags below record whether its
# imports succeeded so callers can fail with a helpful message at use time
# instead of at module import time.
_LEPTON_IMPORT_ERROR: ImportError | None = None
_LEPTON_AVAILABLE = False

try:
    from leptonai.api.v1.types.affinity import LeptonResourceAffinity
    from leptonai.api.v1.types.common import LeptonVisibility, Metadata
    from leptonai.api.v1.types.dedicated_node_group import DedicatedNodeGroup
    from leptonai.api.v1.types.deployment import (
        EnvVar,
        EnvValue,
        LeptonContainer,
        Mount,
    )
    from leptonai.api.v1.types.job import (
        LeptonJob,
        LeptonJobState,
        LeptonJobUserSpec,
        ReservationConfig,
    )
    from leptonai.api.v1.types.replica import Replica
    from leptonai.api.v2.client import APIClient
except ImportError as import_failure:
    # Keep the original failure so it can be chained onto the error raised
    # later by the availability check.
    _LEPTON_IMPORT_ERROR = import_failure

    # Placeholders so that references to these names elsewhere in the module
    # still resolve when leptonai is absent.
    APIClient = None
    DedicatedNodeGroup = None
    EnvValue = None
    EnvVar = None
    LeptonContainer = None
    LeptonJob = None
    LeptonJobUserSpec = None
    LeptonResourceAffinity = None
    LeptonVisibility = None
    Metadata = None
    Mount = None
    Replica = None
    ReservationConfig = None

    class LeptonJobState(Enum):
        """Local stand-in for the job-state enum used when leptonai is missing."""

        Starting = "Starting"
        Running = "Running"
        Failed = "Failed"
        Completed = "Completed"
        Deleting = "Deleting"
        Restarting = "Restarting"
        Archived = "Archived"
        Stopped = "Stopped"
        Stopping = "Stopping"
        Unknown = "Unknown"

else:
    # Every import above succeeded.
    _LEPTON_AVAILABLE = True
| 89 | + |
52 | 90 | logger = logging.getLogger(__name__) |
53 | 91 |
|
54 | 92 |
|
def _require_leptonai() -> None:
    """Raise ``ImportError`` unless the optional leptonai imports succeeded.

    Reads the module-level ``_LEPTON_AVAILABLE`` flag. When it is falsy, an
    ``ImportError`` with installation instructions is raised, chained from
    the exception stored in ``_LEPTON_IMPORT_ERROR``.

    Raises:
        ImportError: If ``_LEPTON_AVAILABLE`` is falsy.
    """
    if _LEPTON_AVAILABLE:
        return

    install_hint = (
        "leptonai package is required for LeptonExecutor. "
        'Install it with: pip install "nemo_run[lepton]"'
    )
    raise ImportError(install_hint) from _LEPTON_IMPORT_ERROR
| 99 | + |
| 100 | + |
55 | 101 | @dataclass(kw_only=True) |
56 | 102 | class LeptonExecutor(Executor): |
57 | 103 | """ |
@@ -84,6 +130,9 @@ class LeptonExecutor(Executor): |
84 | 130 | head_resource_shape: Optional[str] = "" # Only used for LeptonRayCluster |
85 | 131 | ray_version: Optional[str] = None # Only used for LeptonRayCluster |
86 | 132 |
|
def __post_init__(self) -> None:
    """Check the optional dependency as soon as the executor is constructed.

    Calls ``_require_leptonai()`` so that creating an executor without the
    leptonai package fails immediately, then forwards to a parent
    ``__post_init__`` when one exists.

    Raises:
        ImportError: Propagated from ``_require_leptonai()``.
    """
    _require_leptonai()

    # Defining __post_init__ here replaces any hook of the same name inherited
    # from a parent class. Forward to it when present so parent initialisation
    # is not silently dropped.
    # NOTE(review): whether Executor defines __post_init__ is not visible
    # here; this is a no-op if it does not.
    parent_hook = getattr(super(), "__post_init__", None)
    if callable(parent_hook):
        parent_hook()
| 135 | + |
87 | 136 | def stop_job(self, job_id: str): |
88 | 137 | """ |
89 | 138 | Send a stop signal to the requested job |
@@ -376,6 +425,7 @@ def cancel(self, job_id: str): |
376 | 425 |
|
377 | 426 | @classmethod |
378 | 427 | def logs(cls: Type["LeptonExecutor"], app_id: str, fallback_path: Optional[str]): |
| 428 | + _require_leptonai() |
379 | 429 | client = APIClient() |
380 | 430 |
|
381 | 431 | # Get the first replica from the job which contains the job logs |
|
0 commit comments