Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/pyinfra/connectors/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CommandOutput,
execute_command_with_sudo_retry,
make_unix_command_for_host,
output_indicates_sudo_password_failure,
run_local_process,
)

Expand Down Expand Up @@ -95,6 +96,7 @@ def execute_command() -> Tuple[int, CommandOutput]:
arguments,
execute_command,
)
self._log_sudo_auth_failure(arguments, combined_output)

if _success_exit_codes:
status = return_code in _success_exit_codes
Expand All @@ -103,6 +105,21 @@ def execute_command() -> Tuple[int, CommandOutput]:

return status, combined_output

def _log_sudo_auth_failure(
self,
arguments: "ConnectorArguments",
output: CommandOutput,
) -> None:
"""
Log sudo auth failures to aid debugging when cached credentials are cleared.
"""
if not arguments.get("_sudo"):
return
if output_indicates_sudo_password_failure(output):
logger.debug(
"Sudo authentication failed on localhost; cached password cleared",
)

@override
def put_file(
self,
Expand Down
18 changes: 18 additions & 0 deletions src/pyinfra/connectors/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CommandOutput,
execute_command_with_sudo_retry,
make_unix_command_for_host,
output_indicates_sudo_password_failure,
read_output_buffers,
run_local_process,
write_stdin,
Expand Down Expand Up @@ -421,6 +422,7 @@ def execute_command() -> Tuple[int, CommandOutput]:
arguments,
execute_command,
)
self._log_sudo_auth_failure(arguments, combined_output)

if _success_exit_codes:
status = return_code in _success_exit_codes
Expand All @@ -429,6 +431,22 @@ def execute_command() -> Tuple[int, CommandOutput]:

return status, combined_output

def _log_sudo_auth_failure(
self,
arguments: "ConnectorArguments",
output: CommandOutput,
) -> None:
"""
Log sudo auth failures to aid debugging when cached credentials are cleared.
"""
if not arguments.get("_sudo"):
return
if output_indicates_sudo_password_failure(output):
logger.debug(
"Sudo authentication failed for %s; cached password cleared",
self.host.name,
)

@memoize
def get_file_transfer_connection(self) -> FileTransferClient | None:
assert self.client is not None
Expand Down
64 changes: 62 additions & 2 deletions src/pyinfra/connectors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from socket import timeout as timeout_error
from subprocess import PIPE, Popen
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from uuid import uuid4

import click
import gevent
Expand All @@ -22,13 +23,20 @@


SUDO_ASKPASS_ENV_VAR = "PYINFRA_SUDO_PASSWORD"
SUDO_ASKPASS_ONCE_ENV_VAR = "PYINFRA_SUDO_ASKPASS_ONCE_PATH"


SUDO_ASKPASS_COMMAND = r"""
temp=$(mktemp "${{TMPDIR:={0}}}/pyinfra-sudo-askpass-XXXXXXXXXXXX")
cat >"$temp"<<'__EOF__'
#!/bin/sh
printf '%s\n' "${1}"
if [ -n "${{{2}}}" ]; then
if [ -e "${{{2}}}" ]; then
exit 1
fi
: > "${{{2}}}"
fi
printf '%s\n' "${{{1}}}"
__EOF__
chmod 755 "$temp"
echo "$temp"
Expand Down Expand Up @@ -114,6 +122,21 @@ def stderr(self) -> str:
return "\n".join(self.stderr_lines)


def output_indicates_sudo_password_failure(output: CommandOutput) -> bool:
if not output or not output.combined_lines:
return False
for line in output.combined_lines:
message = line.line.strip()
if not message:
continue
normalized = message.lower()
if normalized == "sorry, try again.":
return True
if normalized.startswith("sudo:") and "incorrect password attempt" in normalized:
return True
return False


def read_buffer(
name: str,
io: Iterable,
Expand Down Expand Up @@ -214,6 +237,11 @@ def execute_command_with_sudo_retry(
return_code, output = execute_command()
break

if return_code != 0 and command_arguments.get("_sudo"):
if output_indicates_sudo_password_failure(output):
if host.connector_data.get("prompted_sudo_password"):
host.connector_data["prompted_sudo_password"] = None

return return_code, output


Expand All @@ -237,6 +265,12 @@ def remove_any_sudo_askpass_file(host) -> None:
host.run_shell_command("rm -f {0}".format(sudo_askpass_path))
host.connector_data["sudo_askpass_path"] = None

sudo_askpass_once_paths = host.connector_data.get("sudo_askpass_once_paths")
if sudo_askpass_once_paths:
for path in sudo_askpass_once_paths:
host.run_shell_command("rm -f {0}".format(shlex.quote(path)))
host.connector_data["sudo_askpass_once_paths"] = set()


@memoize
def _show_use_su_login_warning() -> None:
Expand Down Expand Up @@ -268,11 +302,25 @@ def _ensure_sudo_askpass_set_for_host(host: "Host"):
if host.connector_data.get("sudo_askpass_path"):
return
_, output = host.run_shell_command(
SUDO_ASKPASS_COMMAND.format(host.get_temp_dir_config(), SUDO_ASKPASS_ENV_VAR)
SUDO_ASKPASS_COMMAND.format(
host.get_temp_dir_config(),
SUDO_ASKPASS_ENV_VAR,
SUDO_ASKPASS_ONCE_ENV_VAR,
)
)
host.connector_data["sudo_askpass_path"] = shlex.quote(output.stdout_lines[0])


def _build_sudo_askpass_once_path(host: "Host") -> str:
temp_dir = host.get_temp_dir_config()
return "{0}/pyinfra-sudo-askpass-once-{1}".format(temp_dir, uuid4().hex)


def _track_sudo_askpass_once_path(host: "Host", path: str) -> None:
sudo_askpass_once_paths = host.connector_data.setdefault("sudo_askpass_once_paths", set())
sudo_askpass_once_paths.add(path)


def make_unix_command_for_host(
state: "State",
host: "Host",
Expand All @@ -292,6 +340,10 @@ def make_unix_command_for_host(
# Ensure the askpass path is correctly set and passed through
_ensure_sudo_askpass_set_for_host(host)
command_arguments["_sudo_askpass_path"] = host.connector_data["sudo_askpass_path"]
if not command_arguments.get("_sudo_askpass_attempt_path"):
attempt_path = _build_sudo_askpass_once_path(host)
command_arguments["_sudo_askpass_attempt_path"] = attempt_path
_track_sudo_askpass_once_path(host, attempt_path)
return make_unix_command(command, **command_arguments)


Expand All @@ -315,6 +367,7 @@ def make_unix_command(
_use_sudo_login=False,
_sudo_password="",
_sudo_askpass_path=None,
_sudo_askpass_attempt_path=None,
_preserve_sudo_env=False,
# Doas config
_doas=False,
Expand Down Expand Up @@ -356,6 +409,13 @@ def make_unix_command(
MaskString("{0}={1}".format(SUDO_ASKPASS_ENV_VAR, shlex.quote(_sudo_password))),
],
)
if _sudo_askpass_attempt_path:
command_bits.append(
"{0}={1}".format(
SUDO_ASKPASS_ONCE_ENV_VAR,
shlex.quote(_sudo_askpass_attempt_path),
)
)

if _sudo:
command_bits.extend(["sudo", "-H"])
Expand Down
15 changes: 14 additions & 1 deletion tests/test_connectors/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,12 +590,14 @@ def test_run_shell_command_error(self, fake_ssh_client):
assert len(out) == 2
assert out[0] is False

@mock.patch("pyinfra.connectors.util.uuid4")
@mock.patch("pyinfra.connectors.util.getpass")
@mock.patch("pyinfra.connectors.ssh.SSHClient")
def test_run_shell_command_sudo_password_automatic_prompt(
self,
fake_ssh_client,
fake_getpass,
fake_uuid4,
):
fake_ssh = mock.MagicMock()
first_fake_stdout = mock.MagicMock()
Expand Down Expand Up @@ -625,6 +627,7 @@ def test_run_shell_command_sudo_password_automatic_prompt(

fake_ssh_client.return_value = fake_ssh
fake_getpass.return_value = "password"
fake_uuid4.return_value.hex = "deadbeef"

inventory = make_inventory(hosts=("somehost",))
State(inventory, Config())
Expand All @@ -648,17 +651,20 @@ def test_run_shell_command_sudo_password_automatic_prompt(
(
"env SUDO_ASKPASS=/tmp/pyinfra-sudo-askpass-XXXXXXXXXXXX "
"PYINFRA_SUDO_PASSWORD=password "
"PYINFRA_SUDO_ASKPASS_ONCE_PATH=/tmp/pyinfra-sudo-askpass-once-deadbeef "
"sudo -H -A -k sh -c 'echo Šablony'"
),
get_pty=False,
)

@mock.patch("pyinfra.connectors.util.uuid4")
@mock.patch("pyinfra.connectors.util.getpass")
@mock.patch("pyinfra.connectors.ssh.SSHClient")
def test_run_shell_command_sudo_password_automatic_prompt_with_special_chars_in_password(
self,
fake_ssh_client,
fake_getpass,
fake_uuid4,
):
fake_ssh = mock.MagicMock()
first_fake_stdout = mock.MagicMock()
Expand Down Expand Up @@ -688,6 +694,7 @@ def test_run_shell_command_sudo_password_automatic_prompt_with_special_chars_in_

fake_ssh_client.return_value = fake_ssh
fake_getpass.return_value = "p@ss'word';"
fake_uuid4.return_value.hex = "deadbeef"

inventory = make_inventory(hosts=("somehost",))
State(inventory, Config())
Expand All @@ -711,6 +718,7 @@ def test_run_shell_command_sudo_password_automatic_prompt_with_special_chars_in_
(
"env SUDO_ASKPASS=/tmp/pyinfra-sudo-askpass-XXXXXXXXXXXX "
"""PYINFRA_SUDO_PASSWORD='p@ss'"'"'word'"'"';' """
"PYINFRA_SUDO_ASKPASS_ONCE_PATH=/tmp/pyinfra-sudo-askpass-once-deadbeef "
"sudo -H -A -k sh -c 'echo Šablony'"
),
get_pty=False,
Expand All @@ -719,14 +727,17 @@ def test_run_shell_command_sudo_password_automatic_prompt_with_special_chars_in_
# SSH file put/get tests
#

@mock.patch("pyinfra.connectors.util.uuid4")
@mock.patch("pyinfra.connectors.ssh.SSHClient")
@mock.patch("pyinfra.connectors.util.getpass")
def test_run_shell_command_retry_for_sudo_password(
self,
fake_getpass,
fake_ssh_client,
fake_uuid4,
):
fake_getpass.return_value = "PASSWORD"
fake_uuid4.return_value.hex = "deadbeef"

fake_ssh = mock.MagicMock()
fake_stdin = mock.MagicMock()
Expand All @@ -752,7 +763,9 @@ def test_run_shell_command_retry_for_sudo_password(
assert fake_getpass.called
fake_ssh.exec_command.assert_called_with(
"env SUDO_ASKPASS=/tmp/pyinfra-sudo-askpass-XXXXXXXXXXXX "
"PYINFRA_SUDO_PASSWORD=PASSWORD sudo -H -A -k sh -c 'echo hi'",
"PYINFRA_SUDO_PASSWORD=PASSWORD "
"PYINFRA_SUDO_ASKPASS_ONCE_PATH=/tmp/pyinfra-sudo-askpass-once-deadbeef "
"sudo -H -A -k sh -c 'echo hi'",
get_pty=False,
)

Expand Down
Loading