66from click .testing import CliRunner
77from sshtunnel import SSHTunnelForwarder
88
9- from pgcli .main import cli , notify_callback , PGCli
9+ from pgcli .main import cli , PGCli
1010from pgcli .pgexecute import PGExecute
1111
1212
@@ -50,15 +50,13 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
5050 mock_pgexecute .assert_called_once ()
5151
5252 call_args , call_kwargs = mock_pgexecute .call_args
53- assert call_args == (
54- db_params ["database" ],
55- db_params ["user" ],
56- db_params ["passwd" ],
57- "127.0.0.1" ,
58- pgcli .ssh_tunnel .local_bind_ports [0 ],
59- "" ,
60- notify_callback ,
61- )
53+ # Original host is preserved for .pgpass lookup, hostaddr has tunnel endpoint
54+ assert call_args [0 ] == db_params ["database" ]
55+ assert call_args [1 ] == db_params ["user" ]
56+ assert call_args [2 ] == db_params ["passwd" ]
57+ assert call_args [3 ] == db_params ["host" ] # original host preserved
58+ assert call_args [4 ] == pgcli .ssh_tunnel .local_bind_ports [0 ]
59+ assert call_kwargs .get ("hostaddr" ) == "127.0.0.1"
6260 mock_ssh_tunnel_forwarder .reset_mock ()
6361 mock_pgexecute .reset_mock ()
6462
@@ -86,15 +84,9 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
8684 mock_pgexecute .assert_called_once ()
8785
8886 call_args , call_kwargs = mock_pgexecute .call_args
89- assert call_args == (
90- db_params ["database" ],
91- db_params ["user" ],
92- db_params ["passwd" ],
93- "127.0.0.1" ,
94- pgcli .ssh_tunnel .local_bind_ports [0 ],
95- "" ,
96- notify_callback ,
97- )
87+ assert call_args [3 ] == db_params ["host" ] # original host preserved
88+ assert call_args [4 ] == pgcli .ssh_tunnel .local_bind_ports [0 ]
89+ assert call_kwargs .get ("hostaddr" ) == "127.0.0.1"
9890 mock_ssh_tunnel_forwarder .reset_mock ()
9991 mock_pgexecute .reset_mock ()
10092
@@ -104,13 +96,51 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
10496 pgcli = PGCli (ssh_tunnel_url = tunnel_url )
10597 pgcli .connect (dsn = dsn )
10698
107- expected_dsn = f"user={ db_params ['user' ]} password={ db_params ['passwd' ]} host=127.0.0.1 port={ pgcli .ssh_tunnel .local_bind_ports [0 ]} "
108-
10999 mock_ssh_tunnel_forwarder .assert_called_once_with (** expected_tunnel_params )
110100 mock_pgexecute .assert_called_once ()
111101
112102 call_args , call_kwargs = mock_pgexecute .call_args
113- assert expected_dsn in call_args
103+ # DSN should contain original host AND hostaddr for tunnel
104+ dsn_arg = call_args [5 ]
105+ assert f"host={ db_params ['host' ]} " in dsn_arg
106+ assert "hostaddr=127.0.0.1" in dsn_arg
107+ assert f"port={ pgcli .ssh_tunnel .local_bind_ports [0 ]} " in dsn_arg
108+
109+
110+ def test_ssh_tunnel_preserves_original_host_for_pgpass (mock_ssh_tunnel_forwarder : MagicMock , mock_pgexecute : MagicMock ) -> None :
111+ """Verify that the original hostname is preserved for .pgpass lookup."""
112+ tunnel_url = "bastion.example.com"
113+ original_host = "production.db.example.com"
114+
115+ pgcli = PGCli (ssh_tunnel_url = tunnel_url )
116+ pgcli .connect (database = "mydb" , host = original_host , user = "dbuser" , passwd = "dbpass" )
117+
118+ call_args , call_kwargs = mock_pgexecute .call_args
119+ assert call_args [3 ] == original_host # host preserved
120+ assert call_kwargs .get ("hostaddr" ) == "127.0.0.1" # tunnel endpoint
121+
122+
123+ def test_ssh_tunnel_with_dsn_preserves_host (mock_ssh_tunnel_forwarder : MagicMock , mock_pgexecute : MagicMock ) -> None :
124+ """DSN connections should include hostaddr for tunnel while preserving host."""
125+ tunnel_url = "bastion.example.com"
126+ dsn = "host=production.db.example.com port=5432 dbname=mydb user=dbuser"
127+
128+ pgcli = PGCli (ssh_tunnel_url = tunnel_url )
129+ pgcli .connect (dsn = dsn )
130+
131+ call_args , call_kwargs = mock_pgexecute .call_args
132+ dsn_arg = call_args [5 ]
133+ assert "host=production.db.example.com" in dsn_arg
134+ assert "hostaddr=127.0.0.1" in dsn_arg
135+
136+
137+ def test_no_ssh_tunnel_does_not_set_hostaddr (mock_pgexecute : MagicMock ) -> None :
138+ """Without SSH tunnel, hostaddr should not be set."""
139+ pgcli = PGCli ()
140+ pgcli .connect (database = "mydb" , host = "localhost" , user = "user" , passwd = "pass" )
141+
142+ call_args , call_kwargs = mock_pgexecute .call_args
143+ assert "hostaddr" not in call_kwargs
114144
115145
116146def test_cli_with_tunnel () -> None :
0 commit comments