Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/dstack/_internal/cli/commands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
register_init_repo_args,
)
from dstack._internal.cli.utils.common import console
from dstack._internal.core.errors import CLIError, RepoInvalidCredentialsError
from dstack._internal.core.services.repos import get_repo_creds_and_default_branch
from dstack.api import Client


Expand Down Expand Up @@ -55,10 +57,19 @@ def _command(self, args: argparse.Namespace):
repo = get_repo_from_dir(repo_path)
else:
assert False, "should not reach here"

try:
repo_creds, _ = get_repo_creds_and_default_branch(
repo_url=repo.repo_url,
identity_file=args.git_identity_file,
oauth_token=args.gh_token,
)
except RepoInvalidCredentialsError:
raise CLIError(
"No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
)

api = Client.from_config(project_name=args.project)
api.repos.init(
repo=repo,
git_identity_file=args.git_identity_file,
oauth_token=args.gh_token,
)
api.repos.init(repo=repo, creds=repo_creds)

console.print("OK")
21 changes: 7 additions & 14 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from dstack._internal.core.errors import (
CLIError,
ConfigurationError,
RepoInvalidCredentialsError,
ResourceNotExistsError,
ServerClientError,
)
Expand All @@ -52,10 +53,7 @@
from dstack._internal.core.models.resources import CPUSpec
from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus
from dstack._internal.core.services.diff import diff_models
from dstack._internal.core.services.repos import (
InvalidRepoCredentialsError,
get_repo_creds_and_default_branch,
)
from dstack._internal.core.services.repos import get_repo_creds_and_default_branch
from dstack._internal.core.services.ssh.ports import PortUsedError
from dstack._internal.settings import FeatureFlags
from dstack._internal.utils.common import local_time
Expand Down Expand Up @@ -554,8 +552,6 @@ def get_repo(
else:
assert False, "should not reach here"

assert repo.repo_url is not None

if repo_head is not None and repo_head.repo_creds is not None:
if git_identity_file is None and oauth_token is None:
git_private_key = repo_head.repo_creds.private_key
Expand All @@ -570,20 +566,17 @@ def get_repo(
private_key=git_private_key,
oauth_token=oauth_token,
)
except InvalidRepoCredentialsError as e:
raise CLIError(*e.args) from e
except RepoInvalidCredentialsError:
raise CLIError(
"No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
)

repo.run_repo_data.repo_branch = repo_branch
if repo_hash is not None:
repo.run_repo_data.repo_hash = repo_hash

if init:
self.api.repos.init(
repo=repo,
git_identity_file=git_identity_file,
oauth_token=oauth_token,
creds=repo_creds,
)
self.api.repos.init(repo=repo, creds=repo_creds)

return repo

Expand Down
19 changes: 10 additions & 9 deletions src/dstack/_internal/cli/services/repos.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from pathlib import Path

import git

from dstack._internal.cli.services.configurators.base import ArgsParser
from dstack._internal.core.errors import CLIError
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError
from dstack._internal.core.errors import (
CLIError,
RepoDetachedHeadError,
RepoError,
RepoInvalidGitRepositoryError,
)
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo
from dstack._internal.core.models.repos.virtual import VirtualRepo
from dstack._internal.utils.path import PathLike
from dstack.api._public import Client
Expand Down Expand Up @@ -42,23 +45,21 @@ def get_repo_from_dir(repo_dir: PathLike) -> RemoteRepo:
raise CLIError(f"Path is not a directory: {repo_dir}")
try:
return RemoteRepo.from_dir(repo_dir)
except git.InvalidGitRepositoryError:
except RepoInvalidGitRepositoryError:
raise CLIError(
f"Git repo not found: {repo_dir}\n"
"Use `files` to mount an arbitrary directory:"
" https://dstack.ai/docs/concepts/tasks/#files"
)
except git.GitError as e:
raise CLIError(f"{e.__class__.__name__}: {e}") from e
except RepoDetachedHeadError:
raise CLIError(f"Git repo in 'detached HEAD' state: {repo_dir}\nCheck out to a branch")
except RepoError as e:
raise CLIError(str(e)) from e


def get_repo_from_url(repo_url: str) -> RemoteRepo:
try:
return RemoteRepo.from_url(repo_url)
except git.GitError as e:
raise CLIError(f"{e.__class__.__name__}: {e}") from e
except RepoError as e:
raise CLIError(str(e)) from e

Expand Down
37 changes: 37 additions & 0 deletions src/dstack/_internal/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,40 @@ class SSHPortInUseError(SSHError):

class DockerRegistryError(DstackError):
pass


class RepoError(DstackError):
    """
    Base class for repo-related errors.

    CLI helpers catch it and re-raise it as `CLIError(str(e))`, so the message
    should be suitable for showing to the user.
    """

    pass


class RepoDetachedHeadError(RepoError):
    """
    Raised when a local Git repo is in the 'detached HEAD' state,
    i.e. `git.Repo(...).head.is_detached` is true.
    """

    pass


class RepoInvalidCredentialsError(RepoError):
    """
    Signals that no valid Git credentials are available for a repo.

    NOTE(review): the raise site is not visible here; the CLI catches this
    error around `get_repo_creds_and_default_branch()` calls - confirm that
    this is the only producer.
    """

    pass


class RepoGitError(RepoError):
    """
    `RepoError` wrapper around `git.exc.GitError` and its subclasses.

    Raise it with a `from e` clause so that the underlying git exception is
    kept as `__cause__`. When raised without arguments, the message is built
    from that cause as `<CauseClassName>: <cause message>`:

        try:
            ...
        except git.GitError as e:
            raise RepoGitError() from e

    If arguments are passed, or there is no cause, the regular exception
    message is used instead.
    """

    def __str__(self) -> str:
        cause = self.__cause__
        # Explicit arguments always win; the cause is only used as the message
        # source when the error was raised bare.
        if cause is not None and not self.args:
            return f"{cause.__class__.__name__}: {cause}"
        return super().__str__()


class RepoInvalidGitRepositoryError(RepoGitError):
    """
    `DstackError` counterpart for `git.exc.InvalidGitRepositoryError`.

    Like any `RepoGitError`, it is meant to be raised without arguments and
    with a `from e` clause, so that the message is derived from the
    underlying git exception.
    """
101 changes: 67 additions & 34 deletions src/dstack/_internal/core/models/repos/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,30 @@
import subprocess
import time
from dataclasses import dataclass
from typing import Any, BinaryIO, Callable, Dict, Optional
from typing import Annotated, Any, BinaryIO, Callable, Dict, Optional, Union, cast

import git
import pydantic
from pydantic import Field
from typing_extensions import Literal

from dstack._internal.core.errors import DstackError
from dstack._internal.core.deprecated import Deprecated
from dstack._internal.core.errors import (
RepoDetachedHeadError,
RepoError,
RepoGitError,
RepoInvalidGitRepositoryError,
)
from dstack._internal.core.models.common import CoreConfig, generate_dual_core_model
from dstack._internal.core.models.repos.base import BaseRepoInfo, Repo
from dstack._internal.utils.hash import get_sha256, slugify
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.path import PathLike
from dstack._internal.utils.ssh import get_host_config

SCP_LOCATION_REGEX = re.compile(r"(?P<user>[^/]+)@(?P<host>[^/]+?):(?P<path>.+)", re.IGNORECASE)

logger = get_logger(__name__)

class RepoError(DstackError):
pass
SCP_LOCATION_REGEX = re.compile(r"(?P<user>[^/]+)@(?P<host>[^/]+?):(?P<path>.+)", re.IGNORECASE)


class RemoteRepoCredsConfig(CoreConfig):
Expand Down Expand Up @@ -53,7 +58,7 @@ class RemoteRepoInfo(
class RemoteRunRepoData(RemoteRepoInfo):
repo_branch: Optional[str] = None
repo_hash: Optional[str] = None
repo_diff: Optional[str] = Field(None, exclude=True)
repo_diff: Annotated[Optional[str], Field(exclude=True)] = None
repo_config_name: Optional[str] = None
repo_config_email: Optional[str] = None

Expand Down Expand Up @@ -102,6 +107,7 @@ class RemoteRepo(Repo):
"""

run_repo_data: RemoteRunRepoData
repo_url: str

@staticmethod
def from_dir(repo_dir: PathLike) -> "RemoteRepo":
Expand Down Expand Up @@ -143,45 +149,36 @@ def __init__(
repo_id: Optional[str] = None,
local_repo_dir: Optional[PathLike] = None,
repo_url: Optional[str] = None,
repo_data: Optional[RemoteRunRepoData] = None,
repo_branch: Optional[str] = None,
repo_hash: Optional[str] = None,
repo_data: Union[Deprecated, RemoteRunRepoData, None] = Deprecated.PLACEHOLDER,
):
self.repo_dir = local_repo_dir
self.repo_url = repo_url

if self.repo_dir is not None:
repo = git.Repo(self.repo_dir)
tracking_branch = repo.active_branch.tracking_branch()
if tracking_branch is None:
raise RepoError("No remote branch is configured")
self.repo_url = repo.remote(tracking_branch.remote_name).url
repo_data = RemoteRunRepoData.from_url(self.repo_url)
repo_data.repo_branch = tracking_branch.remote_head
repo_data.repo_hash = tracking_branch.commit.hexsha
repo_data.repo_config_name = repo.config_reader().get_value("user", "name", "") or None
repo_data.repo_config_email = (
repo.config_reader().get_value("user", "email", "") or None
if repo_data is not Deprecated.PLACEHOLDER:
logger.warning(
"The repo_data argument is deprecated, ignored, and will be removed soon."
" As it was always ignored, it's safe to remove it."
)
repo_data.repo_diff = _repo_diff_verbose(repo, repo_data.repo_hash)
elif self.repo_url is not None:
repo_data = RemoteRunRepoData.from_url(self.repo_url)
if repo_branch is not None:
repo_data.repo_branch = repo_branch
if repo_hash is not None:
repo_data.repo_hash = repo_hash
elif repo_data is None:
raise RepoError("No remote repo data provided")
# _init_from_* methods must set repo_dir, repo_url, and run_repo_data
if local_repo_dir is not None:
try:
self._init_from_repo_dir(local_repo_dir)
except git.InvalidGitRepositoryError as e:
raise RepoInvalidGitRepositoryError() from e
except git.GitError as e:
raise RepoGitError() from e
elif repo_url is not None:
self._init_from_repo_url(repo_url, repo_branch, repo_hash)
else:
raise RepoError("Neither local repo dir nor repo URL provided")

if repo_id is None:
repo_id = slugify(
repo_data.repo_name,
self.run_repo_data.repo_name,
GitRepoURL.parse(
self.repo_url, get_ssh_config=get_host_config
).get_unique_location(),
)
self.repo_id = repo_id
self.run_repo_data = repo_data

def write_code_file(self, fp: BinaryIO) -> str:
if self.run_repo_data.repo_diff is not None:
Expand All @@ -191,6 +188,42 @@ def write_code_file(self, fp: BinaryIO) -> str:
def get_repo_info(self) -> RemoteRepoInfo:
return RemoteRepoInfo(repo_name=self.run_repo_data.repo_name)

def _init_from_repo_dir(self, repo_dir: PathLike):
    """
    Populate `repo_dir`, `repo_url`, and `run_repo_data` from a local Git repo.

    The repo URL, branch, and commit hash are taken from the remote-tracking
    branch of the currently checked out branch; `user.name` and `user.email`
    are copied from the Git config when they are non-empty.

    Raises:
        RepoDetachedHeadError: if HEAD is detached.
        RepoError: if the active branch has no tracking branch.

    `git` exceptions are not handled here and propagate to the caller.
    """
    local_git = git.Repo(repo_dir)
    if local_git.head.is_detached:
        raise RepoDetachedHeadError()
    upstream = local_git.active_branch.tracking_branch()
    if upstream is None:
        raise RepoError("No remote branch is configured")

    remote_url = local_git.remote(upstream.remote_name).url
    run_repo_data = RemoteRunRepoData.from_url(remote_url)
    run_repo_data.repo_branch = upstream.remote_head
    run_repo_data.repo_hash = upstream.commit.hexsha

    # An empty string is used as the "not configured" default, in which case
    # the corresponding field is left untouched.
    config = local_git.config_reader()
    user_name = cast(str, config.get_value("user", "name", ""))
    if user_name:
        run_repo_data.repo_config_name = user_name
    user_email = cast(str, config.get_value("user", "email", ""))
    if user_email:
        run_repo_data.repo_config_email = user_email
    run_repo_data.repo_diff = _repo_diff_verbose(local_git, run_repo_data.repo_hash)

    self.repo_dir = str(repo_dir)
    self.repo_url = remote_url
    self.run_repo_data = run_repo_data

def _init_from_repo_url(
    self, repo_url: str, repo_branch: Optional[str], repo_hash: Optional[str]
):
    """
    Populate `repo_dir`, `repo_url`, and `run_repo_data` from a repo URL.

    `repo_branch` and `repo_hash` override the corresponding fields of the
    data built from the URL only when they are not `None`. `repo_dir` is set
    to `None` since there is no local checkout.
    """
    run_repo_data = RemoteRunRepoData.from_url(repo_url)
    overrides = (("repo_branch", repo_branch), ("repo_hash", repo_hash))
    for field_name, value in overrides:
        if value is not None:
            setattr(run_repo_data, field_name, value)

    self.repo_dir = None
    self.repo_url = repo_url
    self.run_repo_data = run_repo_data


class _DiffCollector:
def __init__(self, warning_time: float, delay: float = 5):
Expand Down
Loading
Loading