|
6 | 6 | from typing import Any |
7 | 7 |
|
8 | 8 | from altimate_engine.connectors.base import Connector |
| 9 | +from altimate_engine.credential_store import resolve_config |
| 10 | +from altimate_engine.ssh_tunnel import start, stop |
| 11 | + |
| 12 | +SSH_FIELDS = { |
| 13 | + "ssh_host", |
| 14 | + "ssh_port", |
| 15 | + "ssh_user", |
| 16 | + "ssh_auth_type", |
| 17 | + "ssh_key_path", |
| 18 | + "ssh_password", |
| 19 | +} |
9 | 20 |
|
10 | 21 |
|
11 | 22 | class ConnectionRegistry: |
@@ -44,7 +55,33 @@ def get(cls, name: str) -> Connector: |
44 | 55 | if name not in cls._connections: |
45 | 56 | raise ValueError(f"Connection '{name}' not found in registry") |
46 | 57 |
|
47 | | - config = cls._connections[name] |
| 58 | + config = dict(cls._connections[name]) |
| 59 | + config = resolve_config(name, config) |
| 60 | + |
| 61 | + ssh_host = config.get("ssh_host") |
| 62 | + if ssh_host: |
| 63 | + if config.get("connection_string"): |
| 64 | + raise ValueError( |
| 65 | + "SSH tunneling requires explicit host/port — " |
| 66 | + "cannot be used with connection_string" |
| 67 | + ) |
| 68 | + ssh_config = { |
| 69 | + k: config.pop(k) for k in list(config.keys()) if k in SSH_FIELDS |
| 70 | + } |
| 71 | + local_port = start( |
| 72 | + name=name, |
| 73 | + ssh_host=ssh_config.get("ssh_host", ""), |
| 74 | + remote_host=config.get("host", "localhost"), |
| 75 | + remote_port=config.get("port", 5432), |
| 76 | + ssh_port=ssh_config.get("ssh_port", 22), |
| 77 | + ssh_user=ssh_config.get("ssh_user"), |
| 78 | + ssh_auth_type=ssh_config.get("ssh_auth_type", "key"), |
| 79 | + ssh_key_path=ssh_config.get("ssh_key_path"), |
| 80 | + ssh_password=ssh_config.get("ssh_password"), |
| 81 | + ) |
| 82 | + config["host"] = "127.0.0.1" |
| 83 | + config["port"] = local_port |
| 84 | + |
48 | 85 | dialect = config.get("type", "duckdb") |
49 | 86 |
|
50 | 87 | if dialect == "duckdb": |
@@ -219,3 +256,26 @@ def test(cls, name: str) -> dict[str, Any]: |
219 | 256 | return {"connected": True, "error": None} |
220 | 257 | except Exception as e: |
221 | 258 | return {"connected": False, "error": str(e)} |
| 259 | + finally: |
| 260 | + stop(name) |
| 261 | + |
| 262 | + @classmethod |
| 263 | + def add(cls, name: str, config: dict[str, Any]) -> dict[str, Any]: |
| 264 | + from altimate_engine.credential_store import save_connection |
| 265 | + |
| 266 | + result = save_connection(name, config) |
| 267 | + cls._loaded = False |
| 268 | + return result |
| 269 | + |
| 270 | + @classmethod |
| 271 | + def remove(cls, name: str) -> bool: |
| 272 | + from altimate_engine.credential_store import remove_connection |
| 273 | + |
| 274 | + result = remove_connection(name) |
| 275 | + cls._loaded = False |
| 276 | + return result |
| 277 | + |
| 278 | + @classmethod |
| 279 | + def reload(cls) -> None: |
| 280 | + cls._loaded = False |
| 281 | + cls._connections.clear() |
0 commit comments