Skip to content

Commit 496ba58

Browse files
committed
Enable .pgpass support for SSH tunnel connections
Preserve original hostname for .pgpass lookup using PostgreSQL's host/hostaddr parameters: host keeps the original DB hostname (for .pgpass and SSL), hostaddr gets 127.0.0.1 (the tunnel endpoint). Changes: - main.py: Use hostaddr instead of replacing host with 127.0.0.1 - pgexecute.py: Simplify DSN filtering to keep dsn, password, hostaddr - tests: Add 3 new tests, update existing to verify host preservation Made with ❤️ and 🤖 Claude
1 parent b5bc102 commit 496ba58

4 files changed

Lines changed: 65 additions & 27 deletions

File tree

changelog.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ Features:
88
reflects the current editing mode: beam in INSERT, block in NORMAL, underline in REPLACE.
99
Uses prompt_toolkit's ``ModalCursorShapeConfig``.
1010
* Add support of Python 3.14.
11+
* Enable ``.pgpass`` support for SSH tunnel connections.
12+
* Preserve original hostname for ``.pgpass`` lookup using PostgreSQL's ``hostaddr`` parameter
13+
* SSH tunnel endpoint (``127.0.0.1``) is passed via ``hostaddr``, keeping ``host`` for ``.pgpass``
14+
* Works with both DSN and host/port connection styles
1115

1216
Bug fixes:
1317
----------

pgcli/main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,11 +718,15 @@ def should_ask_for_password(exc):
718718
self.logger.handlers = logger_handlers
719719

720720
atexit.register(self.ssh_tunnel.stop)
721-
host = "127.0.0.1"
721+
# Preserve original host for .pgpass lookup and SSL certificate verification.
722+
# Use hostaddr to specify the actual connection endpoint (SSH tunnel).
723+
hostaddr = "127.0.0.1"
722724
port = self.ssh_tunnel.local_bind_ports[0]
723725

724726
if dsn:
725-
dsn = make_conninfo(dsn, host=host, port=port)
727+
dsn = make_conninfo(dsn, host=host, hostaddr=hostaddr, port=port)
728+
else:
729+
kwargs["hostaddr"] = hostaddr
726730

727731
# Attempt to connect to the database.
728732
# Note that passwd may be empty on the first attempt. If connection

pgcli/pgexecute.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ def connect(
212212
new_params.update(kwargs)
213213

214214
if new_params["dsn"]:
215-
new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
215+
# When using DSN, only keep dsn, password, and hostaddr (for SSH tunnels)
216+
new_params = {k: v for k, v in new_params.items() if k in ("dsn", "password", "hostaddr")}
216217

217218
if new_params["password"]:
218219
new_params["dsn"] = make_conninfo(new_params["dsn"], password=new_params.pop("password"))
@@ -505,8 +506,7 @@ def view_definition(self, spec):
505506
else:
506507
template = "CREATE OR REPLACE VIEW {name} AS \n{stmt}"
507508
return (
508-
psycopg.sql
509-
.SQL(template)
509+
psycopg.sql.SQL(template)
510510
.format(
511511
name=psycopg.sql.Identifier(result.nspname, result.relname),
512512
stmt=psycopg.sql.SQL(result.viewdef),

tests/test_ssh_tunnel.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from click.testing import CliRunner
77
from sshtunnel import SSHTunnelForwarder
88

9-
from pgcli.main import cli, notify_callback, PGCli
9+
from pgcli.main import cli, PGCli
1010
from pgcli.pgexecute import PGExecute
1111

1212

@@ -50,15 +50,13 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
5050
mock_pgexecute.assert_called_once()
5151

5252
call_args, call_kwargs = mock_pgexecute.call_args
53-
assert call_args == (
54-
db_params["database"],
55-
db_params["user"],
56-
db_params["passwd"],
57-
"127.0.0.1",
58-
pgcli.ssh_tunnel.local_bind_ports[0],
59-
"",
60-
notify_callback,
61-
)
53+
# Original host is preserved for .pgpass lookup, hostaddr has tunnel endpoint
54+
assert call_args[0] == db_params["database"]
55+
assert call_args[1] == db_params["user"]
56+
assert call_args[2] == db_params["passwd"]
57+
assert call_args[3] == db_params["host"] # original host preserved
58+
assert call_args[4] == pgcli.ssh_tunnel.local_bind_ports[0]
59+
assert call_kwargs.get("hostaddr") == "127.0.0.1"
6260
mock_ssh_tunnel_forwarder.reset_mock()
6361
mock_pgexecute.reset_mock()
6462

@@ -86,15 +84,9 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
8684
mock_pgexecute.assert_called_once()
8785

8886
call_args, call_kwargs = mock_pgexecute.call_args
89-
assert call_args == (
90-
db_params["database"],
91-
db_params["user"],
92-
db_params["passwd"],
93-
"127.0.0.1",
94-
pgcli.ssh_tunnel.local_bind_ports[0],
95-
"",
96-
notify_callback,
97-
)
87+
assert call_args[3] == db_params["host"] # original host preserved
88+
assert call_args[4] == pgcli.ssh_tunnel.local_bind_ports[0]
89+
assert call_kwargs.get("hostaddr") == "127.0.0.1"
9890
mock_ssh_tunnel_forwarder.reset_mock()
9991
mock_pgexecute.reset_mock()
10092

@@ -104,13 +96,51 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
10496
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
10597
pgcli.connect(dsn=dsn)
10698

107-
expected_dsn = f"user={db_params['user']} password={db_params['passwd']} host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}"
108-
10999
mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
110100
mock_pgexecute.assert_called_once()
111101

112102
call_args, call_kwargs = mock_pgexecute.call_args
113-
assert expected_dsn in call_args
103+
# DSN should contain original host AND hostaddr for tunnel
104+
dsn_arg = call_args[5]
105+
assert f"host={db_params['host']}" in dsn_arg
106+
assert "hostaddr=127.0.0.1" in dsn_arg
107+
assert f"port={pgcli.ssh_tunnel.local_bind_ports[0]}" in dsn_arg
108+
109+
110+
def test_ssh_tunnel_preserves_original_host_for_pgpass(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock) -> None:
111+
"""Verify that the original hostname is preserved for .pgpass lookup."""
112+
tunnel_url = "bastion.example.com"
113+
original_host = "production.db.example.com"
114+
115+
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
116+
pgcli.connect(database="mydb", host=original_host, user="dbuser", passwd="dbpass")
117+
118+
call_args, call_kwargs = mock_pgexecute.call_args
119+
assert call_args[3] == original_host # host preserved
120+
assert call_kwargs.get("hostaddr") == "127.0.0.1" # tunnel endpoint
121+
122+
123+
def test_ssh_tunnel_with_dsn_preserves_host(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock) -> None:
124+
"""DSN connections should include hostaddr for tunnel while preserving host."""
125+
tunnel_url = "bastion.example.com"
126+
dsn = "host=production.db.example.com port=5432 dbname=mydb user=dbuser"
127+
128+
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
129+
pgcli.connect(dsn=dsn)
130+
131+
call_args, call_kwargs = mock_pgexecute.call_args
132+
dsn_arg = call_args[5]
133+
assert "host=production.db.example.com" in dsn_arg
134+
assert "hostaddr=127.0.0.1" in dsn_arg
135+
136+
137+
def test_no_ssh_tunnel_does_not_set_hostaddr(mock_pgexecute: MagicMock) -> None:
138+
"""Without SSH tunnel, hostaddr should not be set."""
139+
pgcli = PGCli()
140+
pgcli.connect(database="mydb", host="localhost", user="user", passwd="pass")
141+
142+
call_args, call_kwargs = mock_pgexecute.call_args
143+
assert "hostaddr" not in call_kwargs
114144

115145

116146
def test_cli_with_tunnel() -> None:

0 commit comments

Comments
 (0)