Skip to content

Commit 10b076e

Browse files
authored
Revamp repo errors handling (#3730)
* Revamp repo errors handling * Move all repo-related exception classes to `core.errors` * Wrap exceptions from `git` library * Handle 'detached HEAD` error * Fix _some_ typing issues Fixes: #3722 * Restore RemoteRepo repo_data argument
1 parent 3a891ef commit 10b076e

File tree

11 files changed

+167
-99
lines changed

11 files changed

+167
-99
lines changed

src/dstack/_internal/cli/commands/init.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
register_init_repo_args,
1212
)
1313
from dstack._internal.cli.utils.common import console
14+
from dstack._internal.core.errors import CLIError, RepoInvalidCredentialsError
15+
from dstack._internal.core.services.repos import get_repo_creds_and_default_branch
1416
from dstack.api import Client
1517

1618

@@ -55,10 +57,19 @@ def _command(self, args: argparse.Namespace):
5557
repo = get_repo_from_dir(repo_path)
5658
else:
5759
assert False, "should not reach here"
60+
61+
try:
62+
repo_creds, _ = get_repo_creds_and_default_branch(
63+
repo_url=repo.repo_url,
64+
identity_file=args.git_identity_file,
65+
oauth_token=args.gh_token,
66+
)
67+
except RepoInvalidCredentialsError:
68+
raise CLIError(
69+
"No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
70+
)
71+
5872
api = Client.from_config(project_name=args.project)
59-
api.repos.init(
60-
repo=repo,
61-
git_identity_file=args.git_identity_file,
62-
oauth_token=args.gh_token,
63-
)
73+
api.repos.init(repo=repo, creds=repo_creds)
74+
6475
console.print("OK")

src/dstack/_internal/cli/services/configurators/run.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from dstack._internal.core.errors import (
3333
CLIError,
3434
ConfigurationError,
35+
RepoInvalidCredentialsError,
3536
ResourceNotExistsError,
3637
ServerClientError,
3738
)
@@ -52,10 +53,7 @@
5253
from dstack._internal.core.models.resources import CPUSpec
5354
from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus
5455
from dstack._internal.core.services.diff import diff_models
55-
from dstack._internal.core.services.repos import (
56-
InvalidRepoCredentialsError,
57-
get_repo_creds_and_default_branch,
58-
)
56+
from dstack._internal.core.services.repos import get_repo_creds_and_default_branch
5957
from dstack._internal.core.services.ssh.ports import PortUsedError
6058
from dstack._internal.settings import FeatureFlags
6159
from dstack._internal.utils.common import local_time
@@ -554,8 +552,6 @@ def get_repo(
554552
else:
555553
assert False, "should not reach here"
556554

557-
assert repo.repo_url is not None
558-
559555
if repo_head is not None and repo_head.repo_creds is not None:
560556
if git_identity_file is None and oauth_token is None:
561557
git_private_key = repo_head.repo_creds.private_key
@@ -570,20 +566,17 @@ def get_repo(
570566
private_key=git_private_key,
571567
oauth_token=oauth_token,
572568
)
573-
except InvalidRepoCredentialsError as e:
574-
raise CLIError(*e.args) from e
569+
except RepoInvalidCredentialsError:
570+
raise CLIError(
571+
"No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
572+
)
575573

576574
repo.run_repo_data.repo_branch = repo_branch
577575
if repo_hash is not None:
578576
repo.run_repo_data.repo_hash = repo_hash
579577

580578
if init:
581-
self.api.repos.init(
582-
repo=repo,
583-
git_identity_file=git_identity_file,
584-
oauth_token=oauth_token,
585-
creds=repo_creds,
586-
)
579+
self.api.repos.init(repo=repo, creds=repo_creds)
587580

588581
return repo
589582

src/dstack/_internal/cli/services/repos.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from pathlib import Path
22

3-
import git
4-
53
from dstack._internal.cli.services.configurators.base import ArgsParser
6-
from dstack._internal.core.errors import CLIError
7-
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError
4+
from dstack._internal.core.errors import (
5+
CLIError,
6+
RepoDetachedHeadError,
7+
RepoError,
8+
RepoInvalidGitRepositoryError,
9+
)
10+
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo
811
from dstack._internal.core.models.repos.virtual import VirtualRepo
912
from dstack._internal.utils.path import PathLike
1013
from dstack.api._public import Client
@@ -42,23 +45,21 @@ def get_repo_from_dir(repo_dir: PathLike) -> RemoteRepo:
4245
raise CLIError(f"Path is not a directory: {repo_dir}")
4346
try:
4447
return RemoteRepo.from_dir(repo_dir)
45-
except git.InvalidGitRepositoryError:
48+
except RepoInvalidGitRepositoryError:
4649
raise CLIError(
4750
f"Git repo not found: {repo_dir}\n"
4851
"Use `files` to mount an arbitrary directory:"
4952
" https://dstack.ai/docs/concepts/tasks/#files"
5053
)
51-
except git.GitError as e:
52-
raise CLIError(f"{e.__class__.__name__}: {e}") from e
54+
except RepoDetachedHeadError:
55+
raise CLIError(f"Git repo in 'detached HEAD' state: {repo_dir}\nCheck out to a branch")
5356
except RepoError as e:
5457
raise CLIError(str(e)) from e
5558

5659

5760
def get_repo_from_url(repo_url: str) -> RemoteRepo:
5861
try:
5962
return RemoteRepo.from_url(repo_url)
60-
except git.GitError as e:
61-
raise CLIError(f"{e.__class__.__name__}: {e}") from e
6263
except RepoError as e:
6364
raise CLIError(str(e)) from e
6465

src/dstack/_internal/core/errors.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,40 @@ class SSHPortInUseError(SSHError):
162162

163163
class DockerRegistryError(DstackError):
164164
pass
165+
166+
167+
class RepoError(DstackError):
168+
pass
169+
170+
171+
class RepoDetachedHeadError(RepoError):
172+
pass
173+
174+
175+
class RepoInvalidCredentialsError(RepoError):
176+
pass
177+
178+
179+
class RepoGitError(RepoError):
180+
"""
181+
A wrapper for `git.exc.GitError` and its subclasses.
182+
183+
Should be raised with `from e` clause to indicate the underlying exception.
184+
To build a message from the underlying exception, raise this exception without arguments.
185+
186+
try:
187+
...
188+
except git.GitError as e:
189+
raise RepoGitError() from e
190+
"""
191+
192+
def __str__(self) -> str:
193+
if self.args or self.__cause__ is None:
194+
return super().__str__()
195+
return f"{self.__cause__.__class__.__name__}: {self.__cause__}"
196+
197+
198+
class RepoInvalidGitRepositoryError(RepoGitError):
199+
"""
200+
`DstackError` counterpart for `git.exc.InvalidGitRepositoryError`.
201+
"""

src/dstack/_internal/core/models/repos/remote.py

Lines changed: 67 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,30 @@
33
import subprocess
44
import time
55
from dataclasses import dataclass
6-
from typing import Any, BinaryIO, Callable, Dict, Optional
6+
from typing import Annotated, Any, BinaryIO, Callable, Dict, Optional, Union, cast
77

88
import git
99
import pydantic
1010
from pydantic import Field
1111
from typing_extensions import Literal
1212

13-
from dstack._internal.core.errors import DstackError
13+
from dstack._internal.core.deprecated import Deprecated
14+
from dstack._internal.core.errors import (
15+
RepoDetachedHeadError,
16+
RepoError,
17+
RepoGitError,
18+
RepoInvalidGitRepositoryError,
19+
)
1420
from dstack._internal.core.models.common import CoreConfig, generate_dual_core_model
1521
from dstack._internal.core.models.repos.base import BaseRepoInfo, Repo
1622
from dstack._internal.utils.hash import get_sha256, slugify
23+
from dstack._internal.utils.logging import get_logger
1724
from dstack._internal.utils.path import PathLike
1825
from dstack._internal.utils.ssh import get_host_config
1926

20-
SCP_LOCATION_REGEX = re.compile(r"(?P<user>[^/]+)@(?P<host>[^/]+?):(?P<path>.+)", re.IGNORECASE)
21-
27+
logger = get_logger(__name__)
2228

23-
class RepoError(DstackError):
24-
pass
29+
SCP_LOCATION_REGEX = re.compile(r"(?P<user>[^/]+)@(?P<host>[^/]+?):(?P<path>.+)", re.IGNORECASE)
2530

2631

2732
class RemoteRepoCredsConfig(CoreConfig):
@@ -53,7 +58,7 @@ class RemoteRepoInfo(
5358
class RemoteRunRepoData(RemoteRepoInfo):
5459
repo_branch: Optional[str] = None
5560
repo_hash: Optional[str] = None
56-
repo_diff: Optional[str] = Field(None, exclude=True)
61+
repo_diff: Annotated[Optional[str], Field(exclude=True)] = None
5762
repo_config_name: Optional[str] = None
5863
repo_config_email: Optional[str] = None
5964

@@ -102,6 +107,7 @@ class RemoteRepo(Repo):
102107
"""
103108

104109
run_repo_data: RemoteRunRepoData
110+
repo_url: str
105111

106112
@staticmethod
107113
def from_dir(repo_dir: PathLike) -> "RemoteRepo":
@@ -143,45 +149,36 @@ def __init__(
143149
repo_id: Optional[str] = None,
144150
local_repo_dir: Optional[PathLike] = None,
145151
repo_url: Optional[str] = None,
146-
repo_data: Optional[RemoteRunRepoData] = None,
147152
repo_branch: Optional[str] = None,
148153
repo_hash: Optional[str] = None,
154+
repo_data: Union[Deprecated, RemoteRunRepoData, None] = Deprecated.PLACEHOLDER,
149155
):
150-
self.repo_dir = local_repo_dir
151-
self.repo_url = repo_url
152-
153-
if self.repo_dir is not None:
154-
repo = git.Repo(self.repo_dir)
155-
tracking_branch = repo.active_branch.tracking_branch()
156-
if tracking_branch is None:
157-
raise RepoError("No remote branch is configured")
158-
self.repo_url = repo.remote(tracking_branch.remote_name).url
159-
repo_data = RemoteRunRepoData.from_url(self.repo_url)
160-
repo_data.repo_branch = tracking_branch.remote_head
161-
repo_data.repo_hash = tracking_branch.commit.hexsha
162-
repo_data.repo_config_name = repo.config_reader().get_value("user", "name", "") or None
163-
repo_data.repo_config_email = (
164-
repo.config_reader().get_value("user", "email", "") or None
156+
if repo_data is not Deprecated.PLACEHOLDER:
157+
logger.warning(
158+
"The repo_data argument is deprecated, ignored, and will be removed soon."
159+
" As it was always ignored, it's safe to remove it."
165160
)
166-
repo_data.repo_diff = _repo_diff_verbose(repo, repo_data.repo_hash)
167-
elif self.repo_url is not None:
168-
repo_data = RemoteRunRepoData.from_url(self.repo_url)
169-
if repo_branch is not None:
170-
repo_data.repo_branch = repo_branch
171-
if repo_hash is not None:
172-
repo_data.repo_hash = repo_hash
173-
elif repo_data is None:
174-
raise RepoError("No remote repo data provided")
161+
# _init_from_* methods must set repo_dir, repo_url, and run_repo_data
162+
if local_repo_dir is not None:
163+
try:
164+
self._init_from_repo_dir(local_repo_dir)
165+
except git.InvalidGitRepositoryError as e:
166+
raise RepoInvalidGitRepositoryError() from e
167+
except git.GitError as e:
168+
raise RepoGitError() from e
169+
elif repo_url is not None:
170+
self._init_from_repo_url(repo_url, repo_branch, repo_hash)
171+
else:
172+
raise RepoError("Neither local repo dir nor repo URL provided")
175173

176174
if repo_id is None:
177175
repo_id = slugify(
178-
repo_data.repo_name,
176+
self.run_repo_data.repo_name,
179177
GitRepoURL.parse(
180178
self.repo_url, get_ssh_config=get_host_config
181179
).get_unique_location(),
182180
)
183181
self.repo_id = repo_id
184-
self.run_repo_data = repo_data
185182

186183
def write_code_file(self, fp: BinaryIO) -> str:
187184
if self.run_repo_data.repo_diff is not None:
@@ -191,6 +188,42 @@ def write_code_file(self, fp: BinaryIO) -> str:
191188
def get_repo_info(self) -> RemoteRepoInfo:
192189
return RemoteRepoInfo(repo_name=self.run_repo_data.repo_name)
193190

191+
def _init_from_repo_dir(self, repo_dir: PathLike):
192+
git_repo = git.Repo(repo_dir)
193+
if git_repo.head.is_detached:
194+
raise RepoDetachedHeadError()
195+
tracking_branch = git_repo.active_branch.tracking_branch()
196+
if tracking_branch is None:
197+
raise RepoError("No remote branch is configured")
198+
199+
repo_url = git_repo.remote(tracking_branch.remote_name).url
200+
repo_data = RemoteRunRepoData.from_url(repo_url)
201+
repo_data.repo_branch = tracking_branch.remote_head
202+
repo_data.repo_hash = tracking_branch.commit.hexsha
203+
git_config = git_repo.config_reader()
204+
if user_name := cast(str, git_config.get_value("user", "name", "")):
205+
repo_data.repo_config_name = user_name
206+
if user_email := cast(str, git_config.get_value("user", "email", "")):
207+
repo_data.repo_config_email = user_email
208+
repo_data.repo_diff = _repo_diff_verbose(git_repo, repo_data.repo_hash)
209+
210+
self.repo_dir = str(repo_dir)
211+
self.repo_url = repo_url
212+
self.run_repo_data = repo_data
213+
214+
def _init_from_repo_url(
215+
self, repo_url: str, repo_branch: Optional[str], repo_hash: Optional[str]
216+
):
217+
repo_data = RemoteRunRepoData.from_url(repo_url)
218+
if repo_branch is not None:
219+
repo_data.repo_branch = repo_branch
220+
if repo_hash is not None:
221+
repo_data.repo_hash = repo_hash
222+
223+
self.repo_dir = None
224+
self.repo_url = repo_url
225+
self.run_repo_data = repo_data
226+
194227

195228
class _DiffCollector:
196229
def __init__(self, warning_time: float, delay: float = 5):

0 commit comments

Comments
 (0)