33import subprocess
44import time
55from dataclasses import dataclass
6- from typing import Any , BinaryIO , Callable , Dict , Optional
6+ from typing import Annotated , Any , BinaryIO , Callable , Dict , Optional , Union , cast
77
88import git
99import pydantic
1010from pydantic import Field
1111from typing_extensions import Literal
1212
13- from dstack ._internal .core .errors import DstackError
13+ from dstack ._internal .core .deprecated import Deprecated
14+ from dstack ._internal .core .errors import (
15+ RepoDetachedHeadError ,
16+ RepoError ,
17+ RepoGitError ,
18+ RepoInvalidGitRepositoryError ,
19+ )
1420from dstack ._internal .core .models .common import CoreConfig , generate_dual_core_model
1521from dstack ._internal .core .models .repos .base import BaseRepoInfo , Repo
1622from dstack ._internal .utils .hash import get_sha256 , slugify
23+ from dstack ._internal .utils .logging import get_logger
1724from dstack ._internal .utils .path import PathLike
1825from dstack ._internal .utils .ssh import get_host_config
1926
20- SCP_LOCATION_REGEX = re .compile (r"(?P<user>[^/]+)@(?P<host>[^/]+?):(?P<path>.+)" , re .IGNORECASE )
21-
27+ logger = get_logger (__name__ )
2228
23- class RepoError (DstackError ):
24- pass
29+ SCP_LOCATION_REGEX = re .compile (r"(?P<user>[^/]+)@(?P<host>[^/]+?):(?P<path>.+)" , re .IGNORECASE )
2530
2631
2732class RemoteRepoCredsConfig (CoreConfig ):
@@ -53,7 +58,7 @@ class RemoteRepoInfo(
5358class RemoteRunRepoData (RemoteRepoInfo ):
5459 repo_branch : Optional [str ] = None
5560 repo_hash : Optional [str ] = None
56- repo_diff : Optional [str ] = Field (None , exclude = True )
61+ repo_diff : Annotated [ Optional [str ], Field (exclude = True )] = None
5762 repo_config_name : Optional [str ] = None
5863 repo_config_email : Optional [str ] = None
5964
@@ -102,6 +107,7 @@ class RemoteRepo(Repo):
102107 """
103108
104109 run_repo_data : RemoteRunRepoData
110+ repo_url : str
105111
106112 @staticmethod
107113 def from_dir (repo_dir : PathLike ) -> "RemoteRepo" :
@@ -143,45 +149,36 @@ def __init__(
143149 repo_id : Optional [str ] = None ,
144150 local_repo_dir : Optional [PathLike ] = None ,
145151 repo_url : Optional [str ] = None ,
146- repo_data : Optional [RemoteRunRepoData ] = None ,
147152 repo_branch : Optional [str ] = None ,
148153 repo_hash : Optional [str ] = None ,
154+ repo_data : Union [Deprecated , RemoteRunRepoData , None ] = Deprecated .PLACEHOLDER ,
149155 ):
150- self .repo_dir = local_repo_dir
151- self .repo_url = repo_url
152-
153- if self .repo_dir is not None :
154- repo = git .Repo (self .repo_dir )
155- tracking_branch = repo .active_branch .tracking_branch ()
156- if tracking_branch is None :
157- raise RepoError ("No remote branch is configured" )
158- self .repo_url = repo .remote (tracking_branch .remote_name ).url
159- repo_data = RemoteRunRepoData .from_url (self .repo_url )
160- repo_data .repo_branch = tracking_branch .remote_head
161- repo_data .repo_hash = tracking_branch .commit .hexsha
162- repo_data .repo_config_name = repo .config_reader ().get_value ("user" , "name" , "" ) or None
163- repo_data .repo_config_email = (
164- repo .config_reader ().get_value ("user" , "email" , "" ) or None
156+ if repo_data is not Deprecated .PLACEHOLDER :
157+ logger .warning (
158+ "The repo_data argument is deprecated, ignored, and will be removed soon."
159+ " As it was always ignored, it's safe to remove it."
165160 )
166- repo_data .repo_diff = _repo_diff_verbose (repo , repo_data .repo_hash )
167- elif self .repo_url is not None :
168- repo_data = RemoteRunRepoData .from_url (self .repo_url )
169- if repo_branch is not None :
170- repo_data .repo_branch = repo_branch
171- if repo_hash is not None :
172- repo_data .repo_hash = repo_hash
173- elif repo_data is None :
174- raise RepoError ("No remote repo data provided" )
161+ # _init_from_* methods must set repo_dir, repo_url, and run_repo_data
162+ if local_repo_dir is not None :
163+ try :
164+ self ._init_from_repo_dir (local_repo_dir )
165+ except git .InvalidGitRepositoryError as e :
166+ raise RepoInvalidGitRepositoryError () from e
167+ except git .GitError as e :
168+ raise RepoGitError () from e
169+ elif repo_url is not None :
170+ self ._init_from_repo_url (repo_url , repo_branch , repo_hash )
171+ else :
172+ raise RepoError ("Neither local repo dir nor repo URL provided" )
175173
176174 if repo_id is None :
177175 repo_id = slugify (
178- repo_data .repo_name ,
176+ self . run_repo_data .repo_name ,
179177 GitRepoURL .parse (
180178 self .repo_url , get_ssh_config = get_host_config
181179 ).get_unique_location (),
182180 )
183181 self .repo_id = repo_id
184- self .run_repo_data = repo_data
185182
186183 def write_code_file (self , fp : BinaryIO ) -> str :
187184 if self .run_repo_data .repo_diff is not None :
@@ -191,6 +188,42 @@ def write_code_file(self, fp: BinaryIO) -> str:
191188 def get_repo_info (self ) -> RemoteRepoInfo :
192189 return RemoteRepoInfo (repo_name = self .run_repo_data .repo_name )
193190
def _init_from_repo_dir(self, repo_dir: PathLike):
    """Initialise this repo from a local git working tree.

    Sets ``self.repo_dir``, ``self.repo_url`` and ``self.run_repo_data``.
    The attributes are only assigned once every git lookup has succeeded,
    so a failure part-way through leaves them untouched.

    Args:
        repo_dir: Path passed to ``git.Repo``.

    Raises:
        RepoDetachedHeadError: If HEAD is detached.
        RepoError: If the active branch has no tracking branch.
    """
    local_repo = git.Repo(repo_dir)
    # A detached HEAD has no active branch, so check it before asking for one.
    if local_repo.head.is_detached:
        raise RepoDetachedHeadError()
    upstream = local_repo.active_branch.tracking_branch()
    if upstream is None:
        raise RepoError("No remote branch is configured")

    # The URL comes from the remote that the tracking branch belongs to.
    remote_url = local_repo.remote(upstream.remote_name).url
    run_data = RemoteRunRepoData.from_url(remote_url)
    run_data.repo_branch = upstream.remote_head
    # The hash is the tracking branch's commit; the diff helper receives
    # this same hash below.
    run_data.repo_hash = upstream.commit.hexsha

    config = local_repo.config_reader()
    # An empty string is the "not configured" default; in that case the
    # fields keep whatever value from_url() gave them.
    configured_name = cast(str, config.get_value("user", "name", ""))
    if configured_name:
        run_data.repo_config_name = configured_name
    configured_email = cast(str, config.get_value("user", "email", ""))
    if configured_email:
        run_data.repo_config_email = configured_email
    run_data.repo_diff = _repo_diff_verbose(local_repo, run_data.repo_hash)

    self.repo_dir = str(repo_dir)
    self.repo_url = remote_url
    self.run_repo_data = run_data
213+
def _init_from_repo_url(
    self,
    repo_url: str,
    repo_branch: Optional[str],
    repo_hash: Optional[str],
):
    """Initialise this repo from a remote URL, with no local checkout.

    Sets ``self.repo_dir`` (always ``None`` here), ``self.repo_url`` and
    ``self.run_repo_data``.

    Args:
        repo_url: URL passed to ``RemoteRunRepoData.from_url``.
        repo_branch: Branch to record; when ``None`` the value produced by
            ``from_url`` is kept.
        repo_hash: Commit hash to record; when ``None`` the value produced
            by ``from_url`` is kept.
    """
    run_data = RemoteRunRepoData.from_url(repo_url)
    # Only override what the caller supplied explicitly.
    if repo_hash is not None:
        run_data.repo_hash = repo_hash
    if repo_branch is not None:
        run_data.repo_branch = repo_branch

    self.run_repo_data = run_data
    self.repo_url = repo_url
    self.repo_dir = None
226+
194227
195228class _DiffCollector :
196229 def __init__ (self , warning_time : float , delay : float = 5 ):
0 commit comments