Skip to content

Commit 2f8df17

Browse files
authored
feat: support SSH IdentityAgent config directive (#1630)
1 parent e3f41dc commit 2f8df17

3 files changed

Lines changed: 121 additions & 14 deletions

File tree

src/pyinfra/connectors/ssh.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
from random import uniform
45
from shutil import which
56
from socket import error as socket_error, gaierror
@@ -308,10 +309,23 @@ def _retry_paramiko_agent_keys(
308309
if not kwargs.get("allow_agent"):
309310
return False
310311

312+
# Honor IdentityAgent from SSH config
313+
identity_agent = getattr(self.client, "identity_agent", None)
314+
old_auth_sock = os.environ.get("SSH_AUTH_SOCK")
315+
if isinstance(identity_agent, str):
316+
os.environ["SSH_AUTH_SOCK"] = identity_agent
317+
else:
318+
identity_agent = None
311319
try:
312320
agent_keys = list(Agent().get_keys())
313321
except Exception:
314322
return False
323+
finally:
324+
if identity_agent:
325+
if old_auth_sock is not None:
326+
os.environ["SSH_AUTH_SOCK"] = old_auth_sock
327+
else:
328+
os.environ.pop("SSH_AUTH_SOCK", None)
315329

316330
if not agent_keys:
317331
return False

src/pyinfra/connectors/sshuserclient/client.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
source has now vanished (https://github.com/tobald/sshuserclient).
44
"""
55

6+
import os
67
from os import path
78

89
from gevent.lock import BoundedSemaphore
@@ -166,6 +167,7 @@ def connect( # type: ignore[override]
166167
missing_host_key_policy,
167168
host_keys_files,
168169
keep_alive,
170+
identity_agent,
169171
) = self.parse_config(
170172
hostname,
171173
kwargs,
@@ -188,7 +190,21 @@ def connect( # type: ignore[override]
188190
config.update(_pyinfra_ssh_paramiko_connect_kwargs)
189191

190192
self._ssh_config = config
191-
super().connect(hostname, **config)
193+
self.identity_agent = identity_agent
194+
195+
# Honor IdentityAgent from SSH config by temporarily setting SSH_AUTH_SOCK
196+
# so Paramiko's Agent class connects to the correct socket.
197+
old_auth_sock = os.environ.get("SSH_AUTH_SOCK")
198+
if identity_agent:
199+
os.environ["SSH_AUTH_SOCK"] = identity_agent
200+
try:
201+
super().connect(hostname, **config)
202+
finally:
203+
if identity_agent:
204+
if old_auth_sock is not None:
205+
os.environ["SSH_AUTH_SOCK"] = old_auth_sock
206+
else:
207+
os.environ.pop("SSH_AUTH_SOCK", None)
192208

193209
if _pyinfra_ssh_forward_agent is not None:
194210
forward_agent = _pyinfra_ssh_forward_agent
@@ -225,6 +241,7 @@ def parse_config(
225241

226242
keep_alive = 0
227243
forward_agent = False
244+
identity_agent = None
228245
missing_host_key_policy = get_missing_host_key_policy(strict_host_key_checking)
229246
host_keys_files = (path.expanduser("~/.ssh/known_hosts"),)
230247

@@ -237,6 +254,7 @@ def parse_config(
237254
missing_host_key_policy,
238255
host_keys_files,
239256
keep_alive,
257+
identity_agent,
240258
)
241259

242260
host_config = ssh_config.lookup(hostname)
@@ -269,6 +287,11 @@ def parse_config(
269287
if "serveraliveinterval" in host_config:
270288
keep_alive = int(host_config["serveraliveinterval"])
271289

290+
if "identityagent" in host_config:
291+
agent_path = host_config["identityagent"]
292+
if agent_path.lower() != "none":
293+
identity_agent = path.expanduser(agent_path)
294+
272295
if "proxycommand" in host_config:
273296
cfg["sock"] = ProxyCommand(host_config["proxycommand"])
274297

@@ -294,7 +317,15 @@ def parse_config(
294317
sock = c.gateway(hostname, cfg["port"], target, target_config["port"])
295318
cfg["sock"] = sock
296319

297-
return hostname, cfg, forward_agent, missing_host_key_policy, host_keys_files, keep_alive
320+
return (
321+
hostname,
322+
cfg,
323+
forward_agent,
324+
missing_host_key_policy,
325+
host_keys_files,
326+
keep_alive,
327+
identity_agent,
328+
)
298329

299330
@staticmethod
300331
def derive_shorthand(ssh_config, host_string):

tests/test_connectors/test_sshuserclient.py

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@
4646
UserKnownHostsFile ~/.ssh/known_hosts ~/.ssh/known_hosts.infra ~/.ssh/known_hosts.webservers
4747
"""
4848

49+
SSH_CONFIG_IDENTITY_AGENT = """
50+
Host 10.0.0.1
51+
User agentuser
52+
IdentityAgent ~/Library/Group Containers/2BUA8C4S2C.com.1password/t/agent.sock
53+
"""
54+
55+
SSH_CONFIG_IDENTITY_AGENT_NONE = """
56+
Host 10.0.0.2
57+
User agentuser
58+
IdentityAgent none
59+
"""
60+
4961
BAD_SSH_CONFIG_DATA = """
5062
&
5163
"""
@@ -90,13 +102,20 @@ def setUp(self):
90102
def test_load_ssh_config_no_exist(self):
91103
client = SSHClient()
92104

93-
_, config, forward_agent, missing_host_key_policy, host_keys_file, keep_alive = (
94-
client.parse_config(
95-
"127.0.0.1",
96-
)
105+
(
106+
_,
107+
config,
108+
forward_agent,
109+
missing_host_key_policy,
110+
host_keys_file,
111+
keep_alive,
112+
identity_agent,
113+
) = client.parse_config(
114+
"127.0.0.1",
97115
)
98116

99117
assert config.get("port") == 22
118+
assert identity_agent is None
100119

101120

102121
@patch(
@@ -140,10 +159,16 @@ def setUp(self):
140159
def test_load_ssh_config(self):
141160
client = SSHClient()
142161

143-
_, config, forward_agent, missing_host_key_policy, host_keys_file, keep_alive = (
144-
client.parse_config(
145-
"127.0.0.1",
146-
)
162+
(
163+
_,
164+
config,
165+
forward_agent,
166+
missing_host_key_policy,
167+
host_keys_file,
168+
keep_alive,
169+
identity_agent,
170+
) = client.parse_config(
171+
"127.0.0.1",
147172
)
148173

149174
assert config.get("key_filename") == ["/id_rsa", "/id_rsa2"]
@@ -153,6 +178,7 @@ def test_load_ssh_config(self):
153178
assert forward_agent is False
154179
assert isinstance(missing_host_key_policy, AskPolicy)
155180
assert host_keys_file == ("~/.ssh/known_hosts",) # OpenSSH default
181+
assert identity_agent is None
156182

157183
(
158184
_,
@@ -161,6 +187,7 @@ def test_load_ssh_config(self):
161187
missing_host_key_policy,
162188
host_keys_file,
163189
keep_alive,
190+
identity_agent,
164191
) = client.parse_config("192.168.1.1")
165192

166193
assert other_config.get("username") == "otheruser"
@@ -177,9 +204,15 @@ def test_load_ssh_config_inline_comments(self):
177204
"""Test that inline comments are stripped from SSH config values (issue #1568)."""
178205
client = SSHClient()
179206

180-
_, config, forward_agent, missing_host_key_policy, host_keys_file, keep_alive = (
181-
client.parse_config("127.0.0.1")
182-
)
207+
(
208+
_,
209+
config,
210+
forward_agent,
211+
missing_host_key_policy,
212+
host_keys_file,
213+
keep_alive,
214+
identity_agent,
215+
) = client.parse_config("127.0.0.1")
183216

184217
assert config.get("key_filename") == ["/id_rsa"]
185218
assert config.get("username") == "testuser"
@@ -206,6 +239,7 @@ def test_load_ssh_config_multiple_known_hosts(self):
206239
missing_host_key_policy,
207240
host_keys_files,
208241
keep_alive,
242+
identity_agent,
209243
) = client.parse_config("192.168.1.3")
210244

211245
# Verify multiple known hosts files are parsed as a tuple
@@ -215,6 +249,34 @@ def test_load_ssh_config_multiple_known_hosts(self):
215249
"~/.ssh/known_hosts.webservers",
216250
)
217251

252+
@patch(
253+
"pyinfra.connectors.sshuserclient.client.open",
254+
mock_open(read_data=SSH_CONFIG_IDENTITY_AGENT),
255+
create=True,
256+
)
257+
def test_load_ssh_config_identity_agent(self):
258+
"""Test that IdentityAgent is parsed from SSH config."""
259+
client = SSHClient()
260+
261+
_, config, _, _, _, _, identity_agent = client.parse_config("10.0.0.1")
262+
263+
assert config.get("username") == "agentuser"
264+
assert identity_agent == "~/Library/Group Containers/2BUA8C4S2C.com.1password/t/agent.sock"
265+
266+
@patch(
267+
"pyinfra.connectors.sshuserclient.client.open",
268+
mock_open(read_data=SSH_CONFIG_IDENTITY_AGENT_NONE),
269+
create=True,
270+
)
271+
def test_load_ssh_config_identity_agent_none(self):
272+
"""Test that IdentityAgent set to 'none' is ignored."""
273+
client = SSHClient()
274+
275+
_, config, _, _, _, _, identity_agent = client.parse_config("10.0.0.2")
276+
277+
assert config.get("username") == "agentuser"
278+
assert identity_agent is None
279+
218280
@patch(
219281
"pyinfra.connectors.sshuserclient.client.open",
220282
mock_open(read_data=BAD_SSH_CONFIG_DATA),
@@ -262,7 +324,7 @@ def test_load_ssh_config_proxyjump(self, fake_gateway, fake_ssh_connect):
262324
client = SSHClient()
263325

264326
# Load the SSH config with ProxyJump configured
265-
_, config, forward_agent, _, _, _ = client.parse_config(
327+
_, config, forward_agent, _, _, _, _ = client.parse_config(
266328
"192.168.1.2",
267329
{"port": 1022},
268330
ssh_config_file="other_file",

0 commit comments

Comments
 (0)