11"""Svn repository."""
22
33import contextlib
4+ import functools
45import os
56import pathlib
67import re
7- from collections .abc import Callable , Generator , Sequence
8+ from collections .abc import Callable , Generator , Mapping , Sequence
89from pathlib import Path
10+ from types import MappingProxyType
911from typing import NamedTuple
12+ from urllib .parse import urlparse
1013
1114from dfetch .log import get_logger
1215from dfetch .util .cmdline import SubprocessCommandError , run_on_cmdline
1518
1619logger = get_logger (__name__ )
1720
21+ _SSH_HOST_KEY_MSGS = ("host key verification failed" , "authenticity of host" )
22+
23+
24+ class SshHostKeyError (RuntimeError ):
25+ """Raised when SVN cannot connect due to an untrusted SSH host key."""
26+
27+
28+ # As a cli tool, we can safely assume this remains stable during the runtime, caching for speed is better
29+ @functools .lru_cache
30+ def _extend_env_for_non_interactive_mode () -> Mapping [str , str ]:
31+ """Extend the environment vars for svn running in non-interactive mode."""
32+ env = os .environ .copy ()
33+ ssh_cmd = env .get ("SVN_SSH" , "ssh" )
34+ if "BatchMode=" not in ssh_cmd :
35+ ssh_cmd += " -o BatchMode=yes"
36+ else :
37+ logger .debug ('BatchMode already configured in SVN_SSH: "%s"' , ssh_cmd )
38+ env ["SVN_SSH" ] = ssh_cmd
39+ return MappingProxyType (env )
40+
41+
42+ def _raise_if_ssh_host_key_error (url : str , exc : SubprocessCommandError ) -> None :
43+ """Raise a helpful SshHostKeyError if *exc* looks like an SSH host-key failure."""
44+ stderr_lower = exc .stderr .lower ()
45+ if any (msg in stderr_lower for msg in _SSH_HOST_KEY_MSGS ):
46+ parsed = urlparse (url )
47+ host = parsed .hostname or url
48+ target = f"{ parsed .username } @{ host } " if parsed .username else host
49+ raise SshHostKeyError (
50+ f"SSH host key verification failed while connecting to '{ url } '.\n "
51+ "Add the host to your known hosts file, for example by running:\n "
52+ f" ssh-keyscan { host } >> ~/.ssh/known_hosts\n "
53+ "Or test the SSH connection manually:\n "
54+ f" ssh -T { target } "
55+ ) from exc
56+
57+
58+ def _run_svn (args : list [str ], * , url : str = "" ) -> str :
59+ """Run an svn subcommand and return decoded stdout.
60+
61+ Uses --non-interactive and the non-interactive SSH env on every call.
62+ SSH host-key failures are converted to SshHostKeyError so callers don't
63+ need to handle that case individually.
64+ """
65+ try :
66+ result = run_on_cmdline (
67+ logger ,
68+ ["svn" , "--non-interactive" ] + args ,
69+ env = _extend_env_for_non_interactive_mode (),
70+ )
71+ return str (result .stdout .decode ())
72+ except SubprocessCommandError as exc :
73+ _raise_if_ssh_host_key_error (url , exc )
74+ raise
75+
1876
1977def get_svn_version () -> tuple [str , str ]:
2078 """Get the name and version of svn."""
21- result = run_on_cmdline (logger , ["svn" , "--version" , "--non-interactive" ])
22- first_line = result .stdout .decode ().split ("\n " )[0 ]
79+ first_line = _run_svn (["--version" ]).split ("\n " , maxsplit = 1 )[0 ]
2380 if "version" not in first_line .lower ():
2481 raise RuntimeError (f"Unexpected svn --version output format: { first_line } " )
2582 tool , version = first_line .replace ("," , "" ).split ("version" , maxsplit = 1 )
@@ -49,8 +106,10 @@ def __init__(self, remote: str) -> None:
49106 def is_svn (self ) -> bool :
50107 """Check if is SVN."""
51108 try :
52- run_on_cmdline ( logger , [ "svn" , " info" , self ._remote , "--non-interactive" ] )
109+ _run_svn ([ " info" , self ._remote ], url = self . _remote )
53110 return True
111+ except SshHostKeyError :
112+ raise
54113 except SubprocessCommandError as exc :
55114 if exc .stderr .startswith ("svn: E170013" ):
56115 raise RuntimeError (
@@ -64,26 +123,19 @@ def is_svn(self) -> bool:
64123 def list_of_branches (self ) -> list [str ]:
65124 """List branch names from the ``branches/`` directory."""
66125 try :
67- result = run_on_cmdline (
68- logger ,
69- ["svn" , "ls" , "--non-interactive" , f"{ self ._remote } /branches" ],
70- )
126+ output = _run_svn (["ls" , f"{ self ._remote } /branches" ], url = self ._remote )
71127 return [
72- line .strip ("/\r " )
73- for line in result .stdout .decode ().splitlines ()
74- if line .strip ("/\r " )
128+ line .strip ("/\r " ) for line in output .splitlines () if line .strip ("/\r " )
75129 ]
130+ except SshHostKeyError :
131+ raise
76132 except (SubprocessCommandError , RuntimeError ):
77133 return []
78134
79135 def list_of_tags (self ) -> list [str ]:
80136 """Get list of all available tags."""
81- result = run_on_cmdline (
82- logger , ["svn" , "ls" , "--non-interactive" , f"{ self ._remote } /tags" ]
83- )
84- return [
85- str (tag ).strip ("/\r " ) for tag in result .stdout .decode ().split ("\n " ) if tag
86- ]
137+ output = _run_svn (["ls" , f"{ self ._remote } /tags" ], url = self ._remote )
138+ return [str (tag ).strip ("/\r " ) for tag in output .split ("\n " ) if tag ]
87139
88140 @contextlib .contextmanager
89141 def browse_tree (
@@ -115,17 +167,17 @@ def ls(path: str = "") -> list[tuple[str, bool]]:
115167 def ls_tree (self , url_path : str ) -> list [tuple [str , bool ]]:
116168 """List immediate children of *url_path* as ``(name, is_dir)`` pairs."""
117169 try :
118- result = run_on_cmdline (
119- logger , ["svn" , "ls" , "--non-interactive" , url_path ]
120- )
170+ output = _run_svn (["ls" , url_path ], url = url_path )
121171 entries : list [tuple [str , bool ]] = []
122- for line in result . stdout . decode () .splitlines ():
172+ for line in output .splitlines ():
123173 line = line .strip ("\r " )
124174 if not line :
125175 continue
126176 is_dir = line .endswith ("/" )
127177 entries .append ((line .rstrip ("/" ), is_dir ))
128178 return entries
179+ except SshHostKeyError :
180+ raise
129181 except (SubprocessCommandError , RuntimeError ):
130182 return []
131183
@@ -146,39 +198,25 @@ def is_svn(self) -> bool:
146198 """Check if is SVN."""
147199 try :
148200 with in_directory (self ._path ):
149- run_on_cmdline ( logger , [ "svn" , " info" , "--non-interactive " ])
201+ _run_svn ([ " info" ])
150202 return True
151203 except (SubprocessCommandError , RuntimeError ):
152204 return False
153205
154206 def externals (self ) -> list [External ]:
155207 """Get list of externals."""
156208 with in_directory (self ._path ):
157- result = run_on_cmdline (
158- logger ,
159- [
160- "svn" ,
161- "--non-interactive" ,
162- "propget" ,
163- "svn:externals" ,
164- "-R" ,
165- ],
166- )
209+ output = _run_svn (["propget" , "svn:externals" , "-R" ])
167210 repo_root = SvnRepo .get_info_from_target ()["Repository Root" ]
168- return SvnRepo ._parse_externals (
169- result .stdout .decode (), repo_root , toplevel = self ._path
170- )
211+ return SvnRepo ._parse_externals (output , repo_root , toplevel = self ._path )
171212
172213 @staticmethod
173214 def externals_from_url (url : str , revision : str = "" ) -> list [External ]:
174215 """Get list of externals from a remote SVN URL."""
175- cmd = ["svn" , "--non-interactive" , "propget" , "svn:externals" , "-R" ]
176- if revision :
177- cmd += ["--revision" , revision ]
178- cmd += [url ]
179- result = run_on_cmdline (logger , cmd )
216+ extra = ["--revision" , revision ] if revision else []
217+ output = _run_svn (["propget" , "svn:externals" , "-R" ] + extra + [url ], url = url )
180218 repo_root = SvnRepo .get_info_from_target (url )["Repository Root" ]
181- normalized = SvnRepo ._normalize_url_prefix (result . stdout . decode () , url )
219+ normalized = SvnRepo ._normalize_url_prefix (output , url )
182220 return SvnRepo ._parse_externals (normalized , repo_root )
183221
184222 @staticmethod
@@ -291,9 +329,7 @@ def _split_url(url: str, repo_root: str) -> tuple[str, str, str, str]:
291329 def get_info_from_target (target : str = "" ) -> dict [str , str ]:
292330 """Get the info of the given target."""
293331 try :
294- result = run_on_cmdline (
295- logger , ["svn" , "info" , "--non-interactive" , target .strip ()]
296- ).stdout .decode ()
332+ output = _run_svn (["info" , target .strip ()], url = target )
297333 except SubprocessCommandError as exc :
298334 if exc .stderr .startswith ("svn: E170013" ):
299335 raise RuntimeError (
@@ -306,7 +342,7 @@ def get_info_from_target(target: str = "") -> dict[str, str]:
306342 key .strip (): value .strip ()
307343 for key , value in (
308344 line .split (":" , maxsplit = 1 )
309- for line in result .split (os .linesep )
345+ for line in output .split (os .linesep )
310346 if line and ":" in line
311347 )
312348 }
@@ -324,36 +360,16 @@ def get_last_changed_revision(target: str | Path) -> str:
324360 return parsed_version .group ("digits" )
325361 raise RuntimeError (f"svnversion output was unexpected: { version } " )
326362
327- return str (
328- run_on_cmdline (
329- logger ,
330- [
331- "svn" ,
332- "info" ,
333- "--non-interactive" ,
334- "--show-item" ,
335- "last-changed-revision" ,
336- target_str ,
337- ],
338- )
339- .stdout .decode ()
340- .strip ()
341- )
363+ return _run_svn (
364+ ["info" , "--show-item" , "last-changed-revision" , target_str ],
365+ url = target_str ,
366+ ).strip ()
342367
343368 @staticmethod
344369 def untracked_files (path : str , ignore : Sequence [str ]) -> list [str ]:
345370 """Get list of untracked files in the working copy."""
346- result = (
347- run_on_cmdline (
348- logger ,
349- ["svn" , "status" , "--non-interactive" , path ],
350- )
351- .stdout .decode ()
352- .splitlines ()
353- )
354-
355371 files = []
356- for line in result :
372+ for line in _run_svn ([ "status" , path ]). splitlines () :
357373 if line .startswith ("?" ):
358374 file_path = line [1 :].strip ()
359375 if not any (
@@ -377,24 +393,15 @@ def export(url: str, rev: str = "", dst: str = ".") -> None:
377393 """
378394 if rev and not rev .isdigit ():
379395 raise ValueError (f"SVN revision must be digits only, got: { rev !r} " )
380- run_on_cmdline (
381- logger ,
382- ["svn" , "export" , "--non-interactive" , "--force" ]
383- + (["--revision" , rev ] if rev else [])
384- + [url , dst ],
396+ _run_svn (
397+ ["export" , "--force" ] + (["--revision" , rev ] if rev else []) + [url , dst ],
398+ url = url ,
385399 )
386400
387401 @staticmethod
388402 def files_in_path (url_path : str ) -> list [str ]:
389403 """List all files in path at the given url."""
390- return [
391- str (line )
392- for line in run_on_cmdline (
393- logger , ["svn" , "list" , "--non-interactive" , url_path ]
394- )
395- .stdout .decode ()
396- .splitlines ()
397- ]
404+ return _run_svn (["list" , url_path ], url = url_path ).splitlines ()
398405
399406 @staticmethod
400407 def ignored_files (path : str ) -> Sequence [str ]:
@@ -403,16 +410,9 @@ def ignored_files(path: str) -> Sequence[str]:
403410 return []
404411
405412 with in_directory (path ):
406- result = (
407- run_on_cmdline (
408- logger ,
409- ["svn" , "status" , "--non-interactive" , "--no-ignore" , "." ],
410- )
411- .stdout .decode ()
412- .splitlines ()
413- )
413+ lines = _run_svn (["status" , "--no-ignore" , "." ]).splitlines ()
414414
415- return [line [1 :].strip () for line in result if line .startswith ("I" )]
415+ return [line [1 :].strip () for line in lines if line .startswith ("I" )]
416416
417417 @staticmethod
418418 def any_changes_or_untracked (path : str ) -> bool :
@@ -421,18 +421,7 @@ def any_changes_or_untracked(path: str) -> bool:
421421 raise RuntimeError ("Path does not exist." )
422422
423423 with in_directory (path ):
424- return bool (
425- run_on_cmdline (
426- logger ,
427- [
428- "svn" ,
429- "status" ,
430- "." ,
431- ],
432- )
433- .stdout .decode ()
434- .splitlines ()
435- )
424+ return bool (_run_svn (["status" , "." ]).splitlines ())
436425
437426 def create_diff (
438427 self ,
@@ -441,7 +430,7 @@ def create_diff(
441430 ignore : Sequence [str ],
442431 ) -> Patch :
443432 """Generate a relative diff patch."""
444- cmd = ["svn" , " diff" , "--non-interactive " , "--ignore-properties" , "." ]
433+ cmd = ["diff" , "--ignore-properties" , "." ]
445434
446435 if old_revision :
447436 cmd .extend (
@@ -452,26 +441,15 @@ def create_diff(
452441 )
453442
454443 with in_directory (self ._path ):
455- patch_text = run_on_cmdline ( logger , cmd ). stdout
444+ patch_text = _run_svn ( cmd )
456445
457446 if not patch_text .strip ():
458447 return Patch .empty ().convert_type (PatchType .SVN )
459- return Patch .from_bytes (patch_text ).filter (ignore )
448+ return Patch .from_string (patch_text ).filter (ignore )
460449
461450 def get_username (self ) -> str :
462451 """Get the username of the local svn repo."""
463452 try :
464- result = run_on_cmdline (
465- logger ,
466- [
467- "svn" ,
468- "info" ,
469- "--non-interactive" ,
470- "--show-item" ,
471- "author" ,
472- self ._path ,
473- ],
474- )
475- return str (result .stdout .decode ().strip ())
453+ return _run_svn (["info" , "--show-item" , "author" , self ._path ]).strip ()
476454 except SubprocessCommandError :
477455 return ""
0 commit comments