11"""Svn repository."""
22
33import contextlib
4+ import functools
45import os
56import pathlib
67import re
78from collections .abc import Callable , Generator , Sequence
89from pathlib import Path
910from typing import NamedTuple
11+ from urllib .parse import urlparse
1012
1113from dfetch .log import get_logger
1214from dfetch .util .cmdline import SubprocessCommandError , run_on_cmdline
1517
1618logger = get_logger (__name__ )
1719
20+ _SSH_HOST_KEY_MSGS = ("host key verification failed" , "authenticity of host" )
21+
22+
23+ # As a cli tool, we can safely assume this remains stable during the runtime, caching for speed is better
24+ @functools .lru_cache
25+ def _extend_env_for_non_interactive_mode () -> dict [str , str ]:
26+ """Extend the environment vars for svn running in non-interactive mode."""
27+ env = os .environ .copy ()
28+ ssh_cmd = env .get ("SVN_SSH" , "ssh" )
29+ if "BatchMode=" not in ssh_cmd :
30+ ssh_cmd += " -o BatchMode=yes"
31+ else :
32+ logger .debug ('BatchMode already configured in SVN_SSH: "%s"' , ssh_cmd )
33+ env ["SVN_SSH" ] = ssh_cmd
34+ return env
35+
36+
37+ def _raise_if_ssh_host_key_error (url : str , exc : SubprocessCommandError ) -> None :
38+ """Raise a helpful RuntimeError if *exc* looks like an SSH host-key failure."""
39+ stderr_lower = exc .stderr .lower ()
40+ if any (msg in stderr_lower for msg in _SSH_HOST_KEY_MSGS ):
41+ parsed = urlparse (url )
42+ host = parsed .hostname or url
43+ target = f"{ parsed .username } @{ host } " if parsed .username else host
44+ raise RuntimeError (
45+ f"SSH host key verification failed while connecting to '{ url } '.\n "
46+ "Add the host to your known hosts file, for example by running:\n "
47+ f" ssh-keyscan { host } >> ~/.ssh/known_hosts\n "
48+ "Or test the SSH connection manually:\n "
49+ f" ssh -T { target } "
50+ ) from exc
51+
1852
1953def get_svn_version () -> tuple [str , str ]:
2054 """Get the name and version of svn."""
@@ -49,9 +83,14 @@ def __init__(self, remote: str) -> None:
4983 def is_svn (self ) -> bool :
5084 """Check if is SVN."""
5185 try :
52- run_on_cmdline (logger , ["svn" , "info" , self ._remote , "--non-interactive" ])
86+ run_on_cmdline (
87+ logger ,
88+ ["svn" , "info" , self ._remote , "--non-interactive" ],
89+ env = _extend_env_for_non_interactive_mode (),
90+ )
5391 return True
5492 except SubprocessCommandError as exc :
93+ _raise_if_ssh_host_key_error (self ._remote , exc )
5594 if exc .stderr .startswith ("svn: E170013" ):
5695 raise RuntimeError (
5796 f">>>{ exc .cmd } <<< failed!\n "
@@ -67,20 +106,30 @@ def list_of_branches(self) -> list[str]:
67106 result = run_on_cmdline (
68107 logger ,
69108 ["svn" , "ls" , "--non-interactive" , f"{ self ._remote } /branches" ],
109+ env = _extend_env_for_non_interactive_mode (),
70110 )
71111 return [
72112 line .strip ("/\r " )
73113 for line in result .stdout .decode ().splitlines ()
74114 if line .strip ("/\r " )
75115 ]
76- except (SubprocessCommandError , RuntimeError ):
116+ except SubprocessCommandError as exc :
117+ _raise_if_ssh_host_key_error (self ._remote , exc )
118+ return []
119+ except RuntimeError :
77120 return []
78121
79122 def list_of_tags (self ) -> list [str ]:
80123 """Get list of all available tags."""
81- result = run_on_cmdline (
82- logger , ["svn" , "ls" , "--non-interactive" , f"{ self ._remote } /tags" ]
83- )
124+ try :
125+ result = run_on_cmdline (
126+ logger ,
127+ ["svn" , "ls" , "--non-interactive" , f"{ self ._remote } /tags" ],
128+ env = _extend_env_for_non_interactive_mode (),
129+ )
130+ except SubprocessCommandError as exc :
131+ _raise_if_ssh_host_key_error (self ._remote , exc )
132+ raise
84133 return [
85134 str (tag ).strip ("/\r " ) for tag in result .stdout .decode ().split ("\n " ) if tag
86135 ]
@@ -116,7 +165,9 @@ def ls_tree(self, url_path: str) -> list[tuple[str, bool]]:
116165 """List immediate children of *url_path* as ``(name, is_dir)`` pairs."""
117166 try :
118167 result = run_on_cmdline (
119- logger , ["svn" , "ls" , "--non-interactive" , url_path ]
168+ logger ,
169+ ["svn" , "ls" , "--non-interactive" , url_path ],
170+ env = _extend_env_for_non_interactive_mode (),
120171 )
121172 entries : list [tuple [str , bool ]] = []
122173 for line in result .stdout .decode ().splitlines ():
@@ -126,7 +177,10 @@ def ls_tree(self, url_path: str) -> list[tuple[str, bool]]:
126177 is_dir = line .endswith ("/" )
127178 entries .append ((line .rstrip ("/" ), is_dir ))
128179 return entries
129- except (SubprocessCommandError , RuntimeError ):
180+ except SubprocessCommandError as exc :
181+ _raise_if_ssh_host_key_error (url_path , exc )
182+ return []
183+ except RuntimeError :
130184 return []
131185
132186
@@ -176,7 +230,13 @@ def externals_from_url(url: str, revision: str = "") -> list[External]:
176230 if revision :
177231 cmd += ["--revision" , revision ]
178232 cmd += [url ]
179- result = run_on_cmdline (logger , cmd )
233+ try :
234+ result = run_on_cmdline (
235+ logger , cmd , env = _extend_env_for_non_interactive_mode ()
236+ )
237+ except SubprocessCommandError as exc :
238+ _raise_if_ssh_host_key_error (url , exc )
239+ raise
180240 repo_root = SvnRepo .get_info_from_target (url )["Repository Root" ]
181241 normalized = SvnRepo ._normalize_url_prefix (result .stdout .decode (), url )
182242 return SvnRepo ._parse_externals (normalized , repo_root )
@@ -292,9 +352,12 @@ def get_info_from_target(target: str = "") -> dict[str, str]:
292352 """Get the info of the given target."""
293353 try :
294354 result = run_on_cmdline (
295- logger , ["svn" , "info" , "--non-interactive" , target .strip ()]
355+ logger ,
356+ ["svn" , "info" , "--non-interactive" , target .strip ()],
357+ env = _extend_env_for_non_interactive_mode (),
296358 ).stdout .decode ()
297359 except SubprocessCommandError as exc :
360+ _raise_if_ssh_host_key_error (target , exc )
298361 if exc .stderr .startswith ("svn: E170013" ):
299362 raise RuntimeError (
300363 f">>>{ exc .cmd } <<< failed!\n "
@@ -324,8 +387,8 @@ def get_last_changed_revision(target: str | Path) -> str:
324387 return parsed_version .group ("digits" )
325388 raise RuntimeError (f"svnversion output was unexpected: { version } " )
326389
327- return str (
328- run_on_cmdline (
390+ try :
391+ result = run_on_cmdline (
329392 logger ,
330393 [
331394 "svn" ,
@@ -335,10 +398,12 @@ def get_last_changed_revision(target: str | Path) -> str:
335398 "last-changed-revision" ,
336399 target_str ,
337400 ],
401+ env = _extend_env_for_non_interactive_mode (),
338402 )
339- .stdout .decode ()
340- .strip ()
341- )
403+ except SubprocessCommandError as exc :
404+ _raise_if_ssh_host_key_error (target_str , exc )
405+ raise
406+ return str (result .stdout .decode ().strip ())
342407
343408 @staticmethod
344409 def untracked_files (path : str , ignore : Sequence [str ]) -> list [str ]:
@@ -377,24 +442,31 @@ def export(url: str, rev: str = "", dst: str = ".") -> None:
377442 """
378443 if rev and not rev .isdigit ():
379444 raise ValueError (f"SVN revision must be digits only, got: { rev !r} " )
380- run_on_cmdline (
381- logger ,
382- ["svn" , "export" , "--non-interactive" , "--force" ]
383- + (["--revision" , rev ] if rev else [])
384- + [url , dst ],
385- )
445+ try :
446+ run_on_cmdline (
447+ logger ,
448+ ["svn" , "export" , "--non-interactive" , "--force" ]
449+ + (["--revision" , rev ] if rev else [])
450+ + [url , dst ],
451+ env = _extend_env_for_non_interactive_mode (),
452+ )
453+ except SubprocessCommandError as exc :
454+ _raise_if_ssh_host_key_error (url , exc )
455+ raise
386456
387457 @staticmethod
388458 def files_in_path (url_path : str ) -> list [str ]:
389459 """List all files in path at the given url."""
390- return [
391- str (line )
392- for line in run_on_cmdline (
393- logger , ["svn" , "list" , "--non-interactive" , url_path ]
460+ try :
461+ result = run_on_cmdline (
462+ logger ,
463+ ["svn" , "list" , "--non-interactive" , url_path ],
464+ env = _extend_env_for_non_interactive_mode (),
394465 )
395- .stdout .decode ()
396- .splitlines ()
397- ]
466+ except SubprocessCommandError as exc :
467+ _raise_if_ssh_host_key_error (url_path , exc )
468+ raise
469+ return [str (line ) for line in result .stdout .decode ().splitlines ()]
398470
399471 @staticmethod
400472 def ignored_files (path : str ) -> Sequence [str ]:
0 commit comments