11"""Svn repository."""
22
33import contextlib
4+ import functools
45import os
56import pathlib
67import re
7- from collections .abc import Callable , Generator , Sequence
8+ from collections .abc import Callable , Generator , Mapping , Sequence
89from pathlib import Path
10+ from types import MappingProxyType
911from typing import NamedTuple
12+ from urllib .parse import urlparse
1013
1114from dfetch .log import get_logger
1215from dfetch .util .cmdline import SubprocessCommandError , run_on_cmdline
1518
1619logger = get_logger (__name__ )
1720
21+ _SSH_HOST_KEY_MSGS = ("host key verification failed" , "authenticity of host" )
22+
23+
24+ class SshHostKeyError (RuntimeError ):
25+ """Raised when SVN cannot connect due to an untrusted SSH host key."""
26+
27+
28+ # As a cli tool, we can safely assume this remains stable during the runtime, caching for speed is better
29+ @functools .lru_cache
30+ def _extend_env_for_non_interactive_mode () -> Mapping [str , str ]:
31+ """Extend the environment vars for svn running in non-interactive mode."""
32+ env = os .environ .copy ()
33+ ssh_cmd = env .get ("SVN_SSH" , "ssh" )
34+ if "BatchMode=" not in ssh_cmd :
35+ ssh_cmd += " -o BatchMode=yes"
36+ else :
37+ logger .debug ('BatchMode already configured in SVN_SSH: "%s"' , ssh_cmd )
38+ env ["SVN_SSH" ] = ssh_cmd
39+ return MappingProxyType (env )
40+
41+
42+ def _raise_if_ssh_host_key_error (url : str , exc : SubprocessCommandError ) -> None :
43+ """Raise a helpful SshHostKeyError if *exc* looks like an SSH host-key failure."""
44+ stderr_lower = exc .stderr .lower ()
45+ if not any (msg in stderr_lower for msg in _SSH_HOST_KEY_MSGS ):
46+ return
47+ parsed = urlparse (url )
48+ if parsed .hostname :
49+ host = parsed .hostname
50+ target = f"{ parsed .username } @{ host } " if parsed .username else host
51+ raise SshHostKeyError (
52+ f"SSH host key verification failed while connecting to '{ url } '.\n "
53+ "Add the host to your known hosts file, for example by running:\n "
54+ f" ssh-keyscan { host } >> ~/.ssh/known_hosts\n "
55+ "Or test the SSH connection manually:\n "
56+ f" ssh -T { target } "
57+ ) from exc
58+ raise SshHostKeyError (
59+ "SSH host key verification failed while connecting to the repository.\n "
60+ "Add the repository's host to your known hosts file, for example by running:\n "
61+ " ssh-keyscan <host> >> ~/.ssh/known_hosts"
62+ ) from exc
63+
64+
65+ def _run_svn_raw (args : list [str ], * , url : str = "" ) -> bytes :
66+ """Run an svn subcommand and return raw stdout bytes.
67+
68+ Uses --non-interactive and the non-interactive SSH env on every call.
69+ SSH host-key failures are converted to SshHostKeyError so callers don't
70+ need to handle that case individually.
71+ """
72+ try :
73+ result = run_on_cmdline (
74+ logger ,
75+ ["svn" , "--non-interactive" ] + args ,
76+ env = _extend_env_for_non_interactive_mode (),
77+ )
78+ return bytes (result .stdout )
79+ except SubprocessCommandError as exc :
80+ _raise_if_ssh_host_key_error (url , exc )
81+ raise
82+
83+
84+ def _run_svn (args : list [str ], * , url : str = "" ) -> str :
85+ """Run an svn subcommand and return decoded stdout (see _run_svn_raw)."""
86+ return _run_svn_raw (args , url = url ).decode ()
87+
1888
1989def get_svn_version () -> tuple [str , str ]:
2090 """Get the name and version of svn."""
21- result = run_on_cmdline (logger , ["svn" , "--version" , "--non-interactive" ])
22- first_line = result .stdout .decode ().split ("\n " )[0 ]
91+ first_line = _run_svn (["--version" ]).split ("\n " , maxsplit = 1 )[0 ]
2392 if "version" not in first_line .lower ():
2493 raise RuntimeError (f"Unexpected svn --version output format: { first_line } " )
2594 tool , version = first_line .replace ("," , "" ).split ("version" , maxsplit = 1 )
@@ -49,8 +118,10 @@ def __init__(self, remote: str) -> None:
49118 def is_svn (self ) -> bool :
50119 """Check if is SVN."""
51120 try :
52- run_on_cmdline ( logger , [ "svn" , " info" , self ._remote , "--non-interactive" ] )
121+ _run_svn ([ " info" , self ._remote ], url = self . _remote )
53122 return True
123+ except SshHostKeyError :
124+ raise
54125 except SubprocessCommandError as exc :
55126 if exc .stderr .startswith ("svn: E170013" ):
56127 raise RuntimeError (
@@ -64,26 +135,19 @@ def is_svn(self) -> bool:
64135 def list_of_branches (self ) -> list [str ]:
65136 """List branch names from the ``branches/`` directory."""
66137 try :
67- result = run_on_cmdline (
68- logger ,
69- ["svn" , "ls" , "--non-interactive" , f"{ self ._remote } /branches" ],
70- )
138+ output = _run_svn (["ls" , f"{ self ._remote } /branches" ], url = self ._remote )
71139 return [
72- line .strip ("/\r " )
73- for line in result .stdout .decode ().splitlines ()
74- if line .strip ("/\r " )
140+ line .strip ("/\r " ) for line in output .splitlines () if line .strip ("/\r " )
75141 ]
142+ except SshHostKeyError :
143+ raise
76144 except (SubprocessCommandError , RuntimeError ):
77145 return []
78146
79147 def list_of_tags (self ) -> list [str ]:
80148 """Get list of all available tags."""
81- result = run_on_cmdline (
82- logger , ["svn" , "ls" , "--non-interactive" , f"{ self ._remote } /tags" ]
83- )
84- return [
85- str (tag ).strip ("/\r " ) for tag in result .stdout .decode ().split ("\n " ) if tag
86- ]
149+ output = _run_svn (["ls" , f"{ self ._remote } /tags" ], url = self ._remote )
150+ return [str (tag ).strip ("/\r " ) for tag in output .split ("\n " ) if tag ]
87151
88152 @contextlib .contextmanager
89153 def browse_tree (
@@ -103,6 +167,8 @@ def browse_tree(
103167 try :
104168 SvnRepo .get_info_from_target (branches_url )
105169 base_url = branches_url
170+ except SshHostKeyError :
171+ raise
106172 except RuntimeError :
107173 base_url = f"{ self ._remote } /tags/{ version } "
108174
@@ -115,17 +181,17 @@ def ls(path: str = "") -> list[tuple[str, bool]]:
115181 def ls_tree (self , url_path : str ) -> list [tuple [str , bool ]]:
116182 """List immediate children of *url_path* as ``(name, is_dir)`` pairs."""
117183 try :
118- result = run_on_cmdline (
119- logger , ["svn" , "ls" , "--non-interactive" , url_path ]
120- )
184+ output = _run_svn (["ls" , url_path ], url = url_path )
121185 entries : list [tuple [str , bool ]] = []
122- for line in result . stdout . decode () .splitlines ():
186+ for line in output .splitlines ():
123187 line = line .strip ("\r " )
124188 if not line :
125189 continue
126190 is_dir = line .endswith ("/" )
127191 entries .append ((line .rstrip ("/" ), is_dir ))
128192 return entries
193+ except SshHostKeyError :
194+ raise
129195 except (SubprocessCommandError , RuntimeError ):
130196 return []
131197
@@ -146,39 +212,25 @@ def is_svn(self) -> bool:
146212 """Check if is SVN."""
147213 try :
148214 with in_directory (self ._path ):
149- run_on_cmdline ( logger , [ "svn" , " info" , "--non-interactive " ])
215+ _run_svn ([ " info" ])
150216 return True
151217 except (SubprocessCommandError , RuntimeError ):
152218 return False
153219
154220 def externals (self ) -> list [External ]:
155221 """Get list of externals."""
156222 with in_directory (self ._path ):
157- result = run_on_cmdline (
158- logger ,
159- [
160- "svn" ,
161- "--non-interactive" ,
162- "propget" ,
163- "svn:externals" ,
164- "-R" ,
165- ],
166- )
223+ output = _run_svn (["propget" , "svn:externals" , "-R" ])
167224 repo_root = SvnRepo .get_info_from_target ()["Repository Root" ]
168- return SvnRepo ._parse_externals (
169- result .stdout .decode (), repo_root , toplevel = self ._path
170- )
225+ return SvnRepo ._parse_externals (output , repo_root , toplevel = self ._path )
171226
172227 @staticmethod
173228 def externals_from_url (url : str , revision : str = "" ) -> list [External ]:
174229 """Get list of externals from a remote SVN URL."""
175- cmd = ["svn" , "--non-interactive" , "propget" , "svn:externals" , "-R" ]
176- if revision :
177- cmd += ["--revision" , revision ]
178- cmd += [url ]
179- result = run_on_cmdline (logger , cmd )
230+ extra = ["--revision" , revision ] if revision else []
231+ output = _run_svn (["propget" , "svn:externals" , "-R" ] + extra + [url ], url = url )
180232 repo_root = SvnRepo .get_info_from_target (url )["Repository Root" ]
181- normalized = SvnRepo ._normalize_url_prefix (result . stdout . decode () , url )
233+ normalized = SvnRepo ._normalize_url_prefix (output , url )
182234 return SvnRepo ._parse_externals (normalized , repo_root )
183235
184236 @staticmethod
@@ -291,9 +343,7 @@ def _split_url(url: str, repo_root: str) -> tuple[str, str, str, str]:
291343 def get_info_from_target (target : str = "" ) -> dict [str , str ]:
292344 """Get the info of the given target."""
293345 try :
294- result = run_on_cmdline (
295- logger , ["svn" , "info" , "--non-interactive" , target .strip ()]
296- ).stdout .decode ()
346+ output = _run_svn (["info" , target .strip ()], url = target )
297347 except SubprocessCommandError as exc :
298348 if exc .stderr .startswith ("svn: E170013" ):
299349 raise RuntimeError (
@@ -306,7 +356,7 @@ def get_info_from_target(target: str = "") -> dict[str, str]:
306356 key .strip (): value .strip ()
307357 for key , value in (
308358 line .split (":" , maxsplit = 1 )
309- for line in result .split (os .linesep )
359+ for line in output .split (os .linesep )
310360 if line and ":" in line
311361 )
312362 }
@@ -324,36 +374,16 @@ def get_last_changed_revision(target: str | Path) -> str:
324374 return parsed_version .group ("digits" )
325375 raise RuntimeError (f"svnversion output was unexpected: { version } " )
326376
327- return str (
328- run_on_cmdline (
329- logger ,
330- [
331- "svn" ,
332- "info" ,
333- "--non-interactive" ,
334- "--show-item" ,
335- "last-changed-revision" ,
336- target_str ,
337- ],
338- )
339- .stdout .decode ()
340- .strip ()
341- )
377+ return _run_svn (
378+ ["info" , "--show-item" , "last-changed-revision" , target_str ],
379+ url = target_str ,
380+ ).strip ()
342381
343382 @staticmethod
344383 def untracked_files (path : str , ignore : Sequence [str ]) -> list [str ]:
345384 """Get list of untracked files in the working copy."""
346- result = (
347- run_on_cmdline (
348- logger ,
349- ["svn" , "status" , "--non-interactive" , path ],
350- )
351- .stdout .decode ()
352- .splitlines ()
353- )
354-
355385 files = []
356- for line in result :
386+ for line in _run_svn ([ "status" , path ]). splitlines () :
357387 if line .startswith ("?" ):
358388 file_path = line [1 :].strip ()
359389 if not any (
@@ -377,24 +407,15 @@ def export(url: str, rev: str = "", dst: str = ".") -> None:
377407 """
378408 if rev and not rev .isdigit ():
379409 raise ValueError (f"SVN revision must be digits only, got: { rev !r} " )
380- run_on_cmdline (
381- logger ,
382- ["svn" , "export" , "--non-interactive" , "--force" ]
383- + (["--revision" , rev ] if rev else [])
384- + [url , dst ],
410+ _run_svn (
411+ ["export" , "--force" ] + (["--revision" , rev ] if rev else []) + [url , dst ],
412+ url = url ,
385413 )
386414
387415 @staticmethod
388416 def files_in_path (url_path : str ) -> list [str ]:
389417 """List all files in path at the given url."""
390- return [
391- str (line )
392- for line in run_on_cmdline (
393- logger , ["svn" , "list" , "--non-interactive" , url_path ]
394- )
395- .stdout .decode ()
396- .splitlines ()
397- ]
418+ return _run_svn (["list" , url_path ], url = url_path ).splitlines ()
398419
399420 @staticmethod
400421 def ignored_files (path : str ) -> Sequence [str ]:
@@ -403,16 +424,9 @@ def ignored_files(path: str) -> Sequence[str]:
403424 return []
404425
405426 with in_directory (path ):
406- result = (
407- run_on_cmdline (
408- logger ,
409- ["svn" , "status" , "--non-interactive" , "--no-ignore" , "." ],
410- )
411- .stdout .decode ()
412- .splitlines ()
413- )
427+ lines = _run_svn (["status" , "--no-ignore" , "." ]).splitlines ()
414428
415- return [line [1 :].strip () for line in result if line .startswith ("I" )]
429+ return [line [1 :].strip () for line in lines if line .startswith ("I" )]
416430
417431 @staticmethod
418432 def any_changes_or_untracked (path : str ) -> bool :
@@ -421,18 +435,7 @@ def any_changes_or_untracked(path: str) -> bool:
421435 raise RuntimeError ("Path does not exist." )
422436
423437 with in_directory (path ):
424- return bool (
425- run_on_cmdline (
426- logger ,
427- [
428- "svn" ,
429- "status" ,
430- "." ,
431- ],
432- )
433- .stdout .decode ()
434- .splitlines ()
435- )
438+ return bool (_run_svn (["status" , "." ]).splitlines ())
436439
437440 def create_diff (
438441 self ,
@@ -441,7 +444,7 @@ def create_diff(
441444 ignore : Sequence [str ],
442445 ) -> Patch :
443446 """Generate a relative diff patch."""
444- cmd = ["svn" , " diff" , "--non-interactive " , "--ignore-properties" , "." ]
447+ cmd = ["diff" , "--ignore-properties" , "." ]
445448
446449 if old_revision :
447450 cmd .extend (
@@ -452,7 +455,7 @@ def create_diff(
452455 )
453456
454457 with in_directory (self ._path ):
455- patch_text = run_on_cmdline ( logger , cmd ). stdout
458+ patch_text = _run_svn_raw ( cmd )
456459
457460 if not patch_text .strip ():
458461 return Patch .empty ().convert_type (PatchType .SVN )
@@ -461,17 +464,6 @@ def create_diff(
461464 def get_username (self ) -> str :
462465 """Get the username of the local svn repo."""
463466 try :
464- result = run_on_cmdline (
465- logger ,
466- [
467- "svn" ,
468- "info" ,
469- "--non-interactive" ,
470- "--show-item" ,
471- "author" ,
472- self ._path ,
473- ],
474- )
475- return str (result .stdout .decode ().strip ())
467+ return _run_svn (["info" , "--show-item" , "author" , self ._path ]).strip ()
476468 except SubprocessCommandError :
477469 return ""
0 commit comments