diff --git a/src/dstack/_internal/cli/commands/init.py b/src/dstack/_internal/cli/commands/init.py index 2a5487a47a..64c5a240ae 100644 --- a/src/dstack/_internal/cli/commands/init.py +++ b/src/dstack/_internal/cli/commands/init.py @@ -11,6 +11,8 @@ register_init_repo_args, ) from dstack._internal.cli.utils.common import console +from dstack._internal.core.errors import CLIError, RepoInvalidCredentialsError +from dstack._internal.core.services.repos import get_repo_creds_and_default_branch from dstack.api import Client @@ -55,10 +57,19 @@ def _command(self, args: argparse.Namespace): repo = get_repo_from_dir(repo_path) else: assert False, "should not reach here" + + try: + repo_creds, _ = get_repo_creds_and_default_branch( + repo_url=repo.repo_url, + identity_file=args.git_identity_file, + oauth_token=args.gh_token, + ) + except RepoInvalidCredentialsError: + raise CLIError( + "No valid default Git credentials found. Pass valid `--token` or `--git-identity`." + ) + api = Client.from_config(project_name=args.project) - api.repos.init( - repo=repo, - git_identity_file=args.git_identity_file, - oauth_token=args.gh_token, - ) + api.repos.init(repo=repo, creds=repo_creds) + console.print("OK") diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 8b7c125ea6..63738c968f 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -32,6 +32,7 @@ from dstack._internal.core.errors import ( CLIError, ConfigurationError, + RepoInvalidCredentialsError, ResourceNotExistsError, ServerClientError, ) @@ -52,10 +53,7 @@ from dstack._internal.core.models.resources import CPUSpec from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus from dstack._internal.core.services.diff import diff_models -from dstack._internal.core.services.repos import ( - InvalidRepoCredentialsError, - get_repo_creds_and_default_branch, -) +from dstack._internal.core.services.repos import get_repo_creds_and_default_branch from dstack._internal.core.services.ssh.ports import PortUsedError from dstack._internal.settings import FeatureFlags from dstack._internal.utils.common import local_time @@ -554,8 +552,6 @@ def get_repo( else: assert False, "should not reach here" - assert repo.repo_url is not None - if repo_head is not None and repo_head.repo_creds is not None: if git_identity_file is None and oauth_token is None: git_private_key = repo_head.repo_creds.private_key @@ -570,20 +566,17 @@ def get_repo( private_key=git_private_key, oauth_token=oauth_token, ) - except InvalidRepoCredentialsError as e: - raise CLIError(*e.args) from e + except RepoInvalidCredentialsError: + raise CLIError( + "No valid default Git credentials found. Pass valid `--token` or `--git-identity`." + ) repo.run_repo_data.repo_branch = repo_branch if repo_hash is not None: repo.run_repo_data.repo_hash = repo_hash if init: - self.api.repos.init( - repo=repo, - git_identity_file=git_identity_file, - oauth_token=oauth_token, - creds=repo_creds, - ) + self.api.repos.init(repo=repo, creds=repo_creds) return repo diff --git a/src/dstack/_internal/cli/services/repos.py b/src/dstack/_internal/cli/services/repos.py index 9805c717fb..b9a48e295a 100644 --- a/src/dstack/_internal/cli/services/repos.py +++ b/src/dstack/_internal/cli/services/repos.py @@ -1,10 +1,13 @@ from pathlib import Path -import git - from dstack._internal.cli.services.configurators.base import ArgsParser -from dstack._internal.core.errors import CLIError -from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError +from dstack._internal.core.errors import ( + CLIError, + RepoDetachedHeadError, + RepoError, + RepoInvalidGitRepositoryError, +) +from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo from dstack._internal.core.models.repos.virtual import VirtualRepo from dstack._internal.utils.path import PathLike from dstack.api._public import Client @@ -42,14 +45,14 @@ def get_repo_from_dir(repo_dir: PathLike) -> RemoteRepo: raise CLIError(f"Path is not a directory: {repo_dir}") try: return RemoteRepo.from_dir(repo_dir) - except git.InvalidGitRepositoryError: + except RepoInvalidGitRepositoryError: raise CLIError( f"Git repo not found: {repo_dir}\n" "Use `files` to mount an arbitrary directory:" " https://dstack.ai/docs/concepts/tasks/#files" ) - except git.GitError as e: - raise CLIError(f"{e.__class__.__name__}: {e}") from e + except RepoDetachedHeadError: + raise CLIError(f"Git repo in 'detached HEAD' state: {repo_dir}\nCheck out to a branch") except RepoError as e: raise CLIError(str(e)) from e @@ -57,8 +60,6 @@ def get_repo_from_dir(repo_dir: PathLike) -> RemoteRepo: def get_repo_from_url(repo_url: str) -> RemoteRepo: try: return RemoteRepo.from_url(repo_url) - except git.GitError as e: - raise CLIError(f"{e.__class__.__name__}: {e}") from e except RepoError as e: raise CLIError(str(e)) from e diff --git a/src/dstack/api/_public/common.py b/src/dstack/_internal/core/deprecated.py similarity index 100% rename from src/dstack/api/_public/common.py rename to src/dstack/_internal/core/deprecated.py diff --git a/src/dstack/_internal/core/errors.py b/src/dstack/_internal/core/errors.py index 0d4262fe9b..5182d063b2 100644 --- a/src/dstack/_internal/core/errors.py +++ b/src/dstack/_internal/core/errors.py @@ -162,3 +162,40 @@ class SSHPortInUseError(SSHError): class DockerRegistryError(DstackError): pass + + +class RepoError(DstackError): + pass + + +class RepoDetachedHeadError(RepoError): + pass + + +class RepoInvalidCredentialsError(RepoError): + pass + + +class RepoGitError(RepoError): + """ + A wrapper for `git.exc.GitError` and its subclasses. + + Should be raised with `from e` clause to indicate the underlying exception. + To build a message from the underlying exception, raise this exception without arguments. + + try: + ... + except git.GitError as e: + raise RepoGitError() from e + """ + + def __str__(self) -> str: + if self.args or self.__cause__ is None: + return super().__str__() + return f"{self.__cause__.__class__.__name__}: {self.__cause__}" + + +class RepoInvalidGitRepositoryError(RepoGitError): + """ + `DstackError` counterpart for `git.exc.InvalidGitRepositoryError`. + """ diff --git a/src/dstack/_internal/core/models/repos/remote.py b/src/dstack/_internal/core/models/repos/remote.py index 3bfd34024d..20a6678626 100644 --- a/src/dstack/_internal/core/models/repos/remote.py +++ b/src/dstack/_internal/core/models/repos/remote.py @@ -3,25 +3,30 @@ import subprocess import time from dataclasses import dataclass -from typing import Any, BinaryIO, Callable, Dict, Optional +from typing import Annotated, Any, BinaryIO, Callable, Dict, Optional, Union, cast import git import pydantic from pydantic import Field from typing_extensions import Literal -from dstack._internal.core.errors import DstackError +from dstack._internal.core.deprecated import Deprecated +from dstack._internal.core.errors import ( + RepoDetachedHeadError, + RepoError, + RepoGitError, + RepoInvalidGitRepositoryError, +) from dstack._internal.core.models.common import CoreConfig, generate_dual_core_model from dstack._internal.core.models.repos.base import BaseRepoInfo, Repo from dstack._internal.utils.hash import get_sha256, slugify +from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike from dstack._internal.utils.ssh import get_host_config -SCP_LOCATION_REGEX = re.compile(r"(?P[^/]+)@(?P[^/]+?):(?P.+)", re.IGNORECASE) - +logger = get_logger(__name__) -class RepoError(DstackError): - pass +SCP_LOCATION_REGEX = re.compile(r"(?P[^/]+)@(?P[^/]+?):(?P.+)", re.IGNORECASE) class RemoteRepoCredsConfig(CoreConfig): @@ -53,7 +58,7 @@ class RemoteRepoInfo( class RemoteRunRepoData(RemoteRepoInfo): repo_branch: Optional[str] = None repo_hash: Optional[str] = None - repo_diff: Optional[str] = Field(None, exclude=True) + repo_diff: Annotated[Optional[str], Field(exclude=True)] = None repo_config_name: Optional[str] = None repo_config_email: Optional[str] = None @@ -102,6 +107,7 @@ class RemoteRepo(Repo): """ run_repo_data: RemoteRunRepoData + repo_url: str @staticmethod def from_dir(repo_dir: PathLike) -> "RemoteRepo": @@ -143,45 +149,36 @@ def __init__( repo_id: Optional[str] = None, local_repo_dir: Optional[PathLike] = None, repo_url: Optional[str] = None, - repo_data: Optional[RemoteRunRepoData] = None, repo_branch: Optional[str] = None, repo_hash: Optional[str] = None, + repo_data: Union[Deprecated, RemoteRunRepoData, None] = Deprecated.PLACEHOLDER, ): - self.repo_dir = local_repo_dir - self.repo_url = repo_url - - if self.repo_dir is not None: - repo = git.Repo(self.repo_dir) - tracking_branch = repo.active_branch.tracking_branch() - if tracking_branch is None: - raise RepoError("No remote branch is configured") - self.repo_url = repo.remote(tracking_branch.remote_name).url - repo_data = RemoteRunRepoData.from_url(self.repo_url) - repo_data.repo_branch = tracking_branch.remote_head - repo_data.repo_hash = tracking_branch.commit.hexsha - repo_data.repo_config_name = repo.config_reader().get_value("user", "name", "") or None - repo_data.repo_config_email = ( - repo.config_reader().get_value("user", "email", "") or None + if repo_data is not Deprecated.PLACEHOLDER: + logger.warning( + "The repo_data argument is deprecated, ignored, and will be removed soon." + " As it was always ignored, it's safe to remove it." ) - repo_data.repo_diff = _repo_diff_verbose(repo, repo_data.repo_hash) - elif self.repo_url is not None: - repo_data = RemoteRunRepoData.from_url(self.repo_url) - if repo_branch is not None: - repo_data.repo_branch = repo_branch - if repo_hash is not None: - repo_data.repo_hash = repo_hash - elif repo_data is None: - raise RepoError("No remote repo data provided") + # _init_from_* methods must set repo_dir, repo_url, and run_repo_data + if local_repo_dir is not None: + try: + self._init_from_repo_dir(local_repo_dir) + except git.InvalidGitRepositoryError as e: + raise RepoInvalidGitRepositoryError() from e + except git.GitError as e: + raise RepoGitError() from e + elif repo_url is not None: + self._init_from_repo_url(repo_url, repo_branch, repo_hash) + else: + raise RepoError("Neither local repo dir nor repo URL provided") if repo_id is None: repo_id = slugify( - repo_data.repo_name, + self.run_repo_data.repo_name, GitRepoURL.parse( self.repo_url, get_ssh_config=get_host_config ).get_unique_location(), ) self.repo_id = repo_id - self.run_repo_data = repo_data def write_code_file(self, fp: BinaryIO) -> str: if self.run_repo_data.repo_diff is not None: @@ -191,6 +188,42 @@ def write_code_file(self, fp: BinaryIO) -> str: def get_repo_info(self) -> RemoteRepoInfo: return RemoteRepoInfo(repo_name=self.run_repo_data.repo_name) + def _init_from_repo_dir(self, repo_dir: PathLike): + git_repo = git.Repo(repo_dir) + if git_repo.head.is_detached: + raise RepoDetachedHeadError() + tracking_branch = git_repo.active_branch.tracking_branch() + if tracking_branch is None: + raise RepoError("No remote branch is configured") + + repo_url = git_repo.remote(tracking_branch.remote_name).url + repo_data = RemoteRunRepoData.from_url(repo_url) + repo_data.repo_branch = tracking_branch.remote_head + repo_data.repo_hash = tracking_branch.commit.hexsha + git_config = git_repo.config_reader() + if user_name := cast(str, git_config.get_value("user", "name", "")): + repo_data.repo_config_name = user_name + if user_email := cast(str, git_config.get_value("user", "email", "")): + repo_data.repo_config_email = user_email + repo_data.repo_diff = _repo_diff_verbose(git_repo, repo_data.repo_hash) + + self.repo_dir = str(repo_dir) + self.repo_url = repo_url + self.run_repo_data = repo_data + + def _init_from_repo_url( + self, repo_url: str, repo_branch: Optional[str], repo_hash: Optional[str] + ): + repo_data = RemoteRunRepoData.from_url(repo_url) + if repo_branch is not None: + repo_data.repo_branch = repo_branch + if repo_hash is not None: + repo_data.repo_hash = repo_hash + + self.repo_dir = None + self.repo_url = repo_url + self.run_repo_data = repo_data + class _DiffCollector: def __init__(self, warning_time: float, delay: float = 5): diff --git a/src/dstack/_internal/core/services/repos.py b/src/dstack/_internal/core/services/repos.py index f3b37443e4..dd02d93de9 100644 --- a/src/dstack/_internal/core/services/repos.py +++ b/src/dstack/_internal/core/services/repos.py @@ -4,11 +4,11 @@ from tempfile import NamedTemporaryFile from typing import Optional +import git import git.cmd import yaml -from git.exc import GitCommandError -from dstack._internal.core.errors import DstackError +from dstack._internal.core.errors import RepoInvalidCredentialsError from dstack._internal.core.models.repos import RemoteRepoCreds from dstack._internal.core.models.repos.remote import GitRepoURL from dstack._internal.utils.logging import get_logger @@ -21,10 +21,6 @@ default_ssh_key = os.path.expanduser("~/.ssh/id_rsa") -class InvalidRepoCredentialsError(DstackError): - pass - - def get_repo_creds_and_default_branch( repo_url: str, identity_file: Optional[PathLike] = None, @@ -34,7 +30,7 @@ def get_repo_creds_and_default_branch( url = GitRepoURL.parse(repo_url, get_ssh_config=get_host_config) # no auth - with suppress(InvalidRepoCredentialsError): + with suppress(RepoInvalidCredentialsError): creds, default_branch = _get_repo_creds_and_default_branch_https(url) logger.debug( "Git repo %s is public. Using no auth. Default branch: %s", repo_url, default_branch @@ -93,7 +89,7 @@ def get_repo_creds_and_default_branch( identities = get_host_config(url.original_host).get("identityfile") if identities: _identity_file = identities[0] - with suppress(InvalidRepoCredentialsError): + with suppress(RepoInvalidCredentialsError): _private_key = _read_private_key(_identity_file) creds, default_branch = _get_repo_creds_and_default_branch_ssh( url, _identity_file, _private_key @@ -112,7 +108,7 @@ def get_repo_creds_and_default_branch( gh_hosts = yaml.load(f, Loader=yaml.FullLoader) _oauth_token = gh_hosts.get(url.host, {}).get("oauth_token") if _oauth_token is not None: - with suppress(InvalidRepoCredentialsError): + with suppress(RepoInvalidCredentialsError): creds, default_branch = _get_repo_creds_and_default_branch_https(url, _oauth_token) masked_token = ( len(_oauth_token[:-4]) * "*" + _oauth_token[-4:] @@ -130,7 +126,7 @@ def get_repo_creds_and_default_branch( # default user key if os.path.exists(default_ssh_key): - with suppress(InvalidRepoCredentialsError): + with suppress(RepoInvalidCredentialsError): _private_key = _read_private_key(default_ssh_key) creds, default_branch = _get_repo_creds_and_default_branch_ssh( url, default_ssh_key, _private_key @@ -143,9 +139,7 @@ def get_repo_creds_and_default_branch( ) return creds, default_branch - raise InvalidRepoCredentialsError( - "No valid default Git credentials found. Pass valid `--token` or `--git-identity`." - ) + raise RepoInvalidCredentialsError() def _get_repo_creds_and_default_branch_ssh( @@ -155,9 +149,9 @@ def _get_repo_creds_and_default_branch_ssh( env = _make_git_env_for_creds_check(identity_file=identity_file) try: default_branch = _get_repo_default_branch(_url, env) - except GitCommandError as e: + except git.GitCommandError as e: message = f"Cannot access `{_url}` using the `{identity_file}` private SSH key" - raise InvalidRepoCredentialsError(message) from e + raise RepoInvalidCredentialsError(message) from e creds = RemoteRepoCreds( clone_url=_url, private_key=private_key, @@ -173,12 +167,12 @@ def _get_repo_creds_and_default_branch_https( env = _make_git_env_for_creds_check() try: default_branch = _get_repo_default_branch(url.as_https(oauth_token), env) - except GitCommandError as e: + except git.GitCommandError as e: message = f"Cannot access `{_url}`" if oauth_token is not None: masked_token = len(oauth_token[:-4]) * "*" + oauth_token[-4:] message = f"{message} using the `{masked_token}` token" - raise InvalidRepoCredentialsError(message) from e + raise RepoInvalidCredentialsError(message) from e creds = RemoteRepoCreds( clone_url=_url, private_key=None, @@ -224,11 +218,11 @@ def _get_repo_default_branch(url: str, env: dict[str, str]) -> Optional[str]: def _read_private_key(identity_file: PathLike) -> str: identity_file = Path(identity_file).expanduser().resolve() if not Path(identity_file).exists(): - raise InvalidRepoCredentialsError(f"The `{identity_file}` private SSH key doesn't exist") + raise RepoInvalidCredentialsError(f"The `{identity_file}` private SSH key doesn't exist") if not os.access(identity_file, os.R_OK): - raise InvalidRepoCredentialsError(f"Cannot access the `{identity_file}` private SSH key") + raise RepoInvalidCredentialsError(f"Cannot access the `{identity_file}` private SSH key") if not try_ssh_key_passphrase(identity_file): - raise InvalidRepoCredentialsError( + raise RepoInvalidCredentialsError( f"Cannot use the `{identity_file}` private SSH key. " "Ensure that it is valid and passphrase-free" ) diff --git a/src/dstack/_internal/server/services/placement.py b/src/dstack/_internal/server/services/placement.py index f99f1beadf..8a2e27b160 100644 --- a/src/dstack/_internal/server/services/placement.py +++ b/src/dstack/_internal/server/services/placement.py @@ -3,7 +3,6 @@ from typing import Optional from uuid import UUID -from git import List from sqlalchemy import and_, select, update from sqlalchemy.ext.asyncio import AsyncSession @@ -66,7 +65,7 @@ def get_placement_group_provisioning_data( async def get_fleet_placement_group_models( session: AsyncSession, fleet_id: Optional[UUID], -) -> List[PlacementGroupModel]: +) -> list[PlacementGroupModel]: if fleet_id is None: return [] res = await session.execute( @@ -138,7 +137,7 @@ def get_placement_group_model_for_job( async def find_or_create_suitable_placement_group( fleet_model: FleetModel, - placement_groups: List[PlacementGroupModel], + placement_groups: list[PlacementGroupModel], instance_offer: InstanceOffer, compute: ComputeWithPlacementGroupSupport, ) -> Optional[PlacementGroupModel]: @@ -157,7 +156,7 @@ async def find_or_create_suitable_placement_group( def find_suitable_placement_group( - placement_groups: List[PlacementGroupModel], + placement_groups: list[PlacementGroupModel], instance_offer: InstanceOffer, compute: ComputeWithPlacementGroupSupport, ) -> Optional[PlacementGroupModel]: diff --git a/src/dstack/api/_public/repos.py b/src/dstack/api/_public/repos.py index d212201439..07dfa1e46d 100644 --- a/src/dstack/api/_public/repos.py +++ b/src/dstack/api/_public/repos.py @@ -1,8 +1,11 @@ from typing import Literal, Optional, Union, overload -from git import InvalidGitRepositoryError - -from dstack._internal.core.errors import ConfigurationError, ResourceNotExistsError +from dstack._internal.core.errors import ( + ConfigurationError, + RepoInvalidCredentialsError, + RepoInvalidGitRepositoryError, + ResourceNotExistsError, +) from dstack._internal.core.models.repos import ( LocalRepo, RemoteRepo, @@ -11,10 +14,7 @@ RepoHead, RepoHeadWithCreds, ) -from dstack._internal.core.services.repos import ( - InvalidRepoCredentialsError, - get_repo_creds_and_default_branch, -) +from dstack._internal.core.services.repos import get_repo_creds_and_default_branch from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike from dstack.api.server import APIClient @@ -77,15 +77,14 @@ def init( " an arbitrary directory: https://dstack.ai/docs/concepts/tasks/#files" ) if creds is None and isinstance(repo, RemoteRepo): - assert repo.repo_url is not None try: creds, _ = get_repo_creds_and_default_branch( repo_url=repo.repo_url, identity_file=git_identity_file, oauth_token=oauth_token, ) - except InvalidRepoCredentialsError as e: - raise ConfigurationError(*e.args) + except RepoInvalidCredentialsError: + raise ConfigurationError("No valid default Git credentials found") self._api_client.repos.init(self._project, repo.repo_id, repo.get_repo_info(), creds) def load( @@ -129,7 +128,7 @@ def load( logger.debug("Initializing repo") try: repo = RemoteRepo.from_dir(repo_dir) - except InvalidGitRepositoryError: + except RepoInvalidGitRepositoryError: raise ConfigurationError( f"Git repo not found: {repo_dir}. Use `files` to mount an arbitrary" " directory: https://dstack.ai/docs/concepts/tasks/#files" diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index a8afac24ce..963978ee03 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -16,6 +16,7 @@ import dstack.api as api from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT +from dstack._internal.core.deprecated import Deprecated from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import ( @@ -49,7 +50,6 @@ from dstack._internal.utils.files import create_file_archive from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike -from dstack.api._public.common import Deprecated from dstack.api.server import APIClient logger = get_logger(__name__) diff --git a/src/tests/_internal/core/models/repos/test_remote.py b/src/tests/_internal/core/models/repos/test_remote.py index 04c257f146..acf80f2596 100644 --- a/src/tests/_internal/core/models/repos/test_remote.py +++ b/src/tests/_internal/core/models/repos/test_remote.py @@ -1,6 +1,7 @@ import pytest -from dstack._internal.core.models.repos.remote import GitRepoURL, RepoError +from dstack._internal.core.errors import RepoError +from dstack._internal.core.models.repos.remote import GitRepoURL class TestGitRepoURL: