1313from pathlib import Path
1414from tempfile import TemporaryDirectory
1515from typing import Annotated
16+ from urllib .parse import urlparse , urlunparse
1617
1718import sh
1819import yaml
@@ -136,13 +137,15 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
136137 MAX_CS_CACHED_VERSIONS , DEFAULT_CS_CACHE_TTL
137138 )
138139 self ._read_raw_cache : Cache = LRUCache (MAX_CS_CACHED_VERSIONS )
140+ self .remote_url = self .exctract_remote_url (backend_url )
141+ self .git_branch = self .get_git_branch_from_url (backend_url )
139142
140143 @cachedmethod (lambda self : self ._latest_revision_cache )
141144 def latest_revision (self ) -> tuple [str , datetime ]:
142145 try :
143146 rev = sh .git (
144147 "rev-parse" ,
145- DEFAULT_GIT_BRANCH ,
148+ self . git_branch ,
146149 _cwd = self .repo_location ,
147150 _tty_out = False ,
148151 _async = is_running_in_async_context (),
@@ -192,6 +195,16 @@ def clear_caches(self):
192195 self ._latest_revision_cache .clear ()
193196 self ._read_raw_cache .clear ()
194197
198+ def exctract_remote_url (self , backend_url : ConfigSourceUrl ) -> str :
199+ """Extract the base URL without the 'git+' prefix and query parameters."""
200+ parsed_url = urlparse (str (backend_url ).replace ("git+" , "" ))
201+ remote_url = urlunparse (parsed_url ._replace (query = "" ))
202+ return remote_url
203+
204+ def get_git_branch_from_url (self , backend_url : ConfigSourceUrl ) -> str :
205+ """Extract the branch from the query parameters."""
206+ return dict (backend_url .query_params ()).get ("branch" , DEFAULT_GIT_BRANCH )
207+
195208
196209class LocalGitConfigSource (BaseGitConfigSource ):
197210 """The configuration is stored on a local git repository
@@ -219,6 +232,7 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
219232 raise ValueError (
220233 f"{ self .repo_location } is not a valid git repository"
221234 ) from e
235+ sh .git .checkout (self .git_branch , _cwd = self .repo_location , _async = False )
222236
223237 def __hash__ (self ):
224238 return hash (self .repo_location )
@@ -234,14 +248,13 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
234248 if not backend_url :
235249 raise ValueError ("No remote url for RemoteGitConfigSource" )
236250
237- # git does not understand `git+https`, so we remove the `git+` part
238- self .remote_url = str (backend_url ).replace ("git+" , "" )
239251 self ._temp_dir = TemporaryDirectory ()
240252 self .repo_location = Path (self ._temp_dir .name )
241253 sh .git .clone (self .remote_url , self .repo_location , _async = False )
242254 self ._pull_cache : Cache = TTLCache (
243255 MAX_PULL_CACHED_VERSIONS , DEFAULT_PULL_CACHE_TTL
244256 )
257+ sh .git .checkout (self .git_branch , _cwd = self .repo_location , _async = False )
245258
246259 def clear_caches (self ):
247260 super ().clear_caches ()
0 commit comments