Skip to content

Commit 971b7b2

Browse files
committed
Connection dialog improvments. Revert provider lazy loading
1 parent 85aa924 commit 971b7b2

8 files changed

Lines changed: 448 additions & 245 deletions

File tree

.github/workflows/release.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ jobs:
104104
- name: Setup SSH for AUR
105105
run: |
106106
mkdir -p ~/.ssh
107-
echo "${{ secrets.AUR_SSH_KEY }}" > ~/.ssh/aur
107+
# Decode base64-encoded SSH key (encode with: base64 -w 0 < ~/.ssh/aur_key)
108+
echo "${{ secrets.AUR_SSH_KEY }}" | base64 -d > ~/.ssh/aur
108109
chmod 600 ~/.ssh/aur
109110
echo "Host aur.archlinux.org" >> ~/.ssh/config
110111
echo " IdentityFile ~/.ssh/aur" >> ~/.ssh/config

sqlit/db/providers.py

Lines changed: 53 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Canonical provider registry (Plan B).
1+
"""Canonical provider registry.
22
33
This module is the single source of truth for:
44
- supported provider ids (db_type)
@@ -8,67 +8,61 @@
88

99
from __future__ import annotations
1010

11+
from collections.abc import Iterable
1112
from dataclasses import dataclass
12-
from importlib import import_module
1313
from typing import TYPE_CHECKING
1414

15+
# Pre-import all schemas (no external dependencies)
16+
from .schema import (
17+
COCKROACHDB_SCHEMA,
18+
D1_SCHEMA,
19+
DUCKDB_SCHEMA,
20+
MARIADB_SCHEMA,
21+
MSSQL_SCHEMA,
22+
MYSQL_SCHEMA,
23+
ORACLE_SCHEMA,
24+
POSTGRESQL_SCHEMA,
25+
SQLITE_SCHEMA,
26+
SUPABASE_SCHEMA,
27+
TURSO_SCHEMA,
28+
ConnectionSchema,
29+
)
30+
31+
# Pre-import all adapter classes (they lazy-load their dependencies internally)
32+
from .adapters.cockroachdb import CockroachDBAdapter
33+
from .adapters.d1 import D1Adapter
34+
from .adapters.duckdb import DuckDBAdapter
35+
from .adapters.mariadb import MariaDBAdapter
36+
from .adapters.mssql import SQLServerAdapter
37+
from .adapters.mysql import MySQLAdapter
38+
from .adapters.oracle import OracleAdapter
39+
from .adapters.postgresql import PostgreSQLAdapter
40+
from .adapters.sqlite import SQLiteAdapter
41+
from .adapters.supabase import SupabaseAdapter
42+
from .adapters.turso import TursoAdapter
43+
1544
if TYPE_CHECKING:
16-
from collections.abc import Iterable
1745
from .adapters.base import DatabaseAdapter
18-
from .schema import ConnectionSchema
1946

2047

2148
@dataclass(frozen=True)
2249
class ProviderSpec:
23-
schema_path: tuple[str, str]
24-
adapter_path: tuple[str, str]
50+
schema: ConnectionSchema
51+
adapter_class: type["DatabaseAdapter"]
2552

2653

2754
PROVIDERS: dict[str, ProviderSpec] = {
28-
"mssql": ProviderSpec(
29-
schema_path=("sqlit.db.schema", "MSSQL_SCHEMA"),
30-
adapter_path=("sqlit.db.adapters.mssql", "SQLServerAdapter"),
31-
),
32-
"sqlite": ProviderSpec(
33-
schema_path=("sqlit.db.schema", "SQLITE_SCHEMA"),
34-
adapter_path=("sqlit.db.adapters.sqlite", "SQLiteAdapter"),
35-
),
36-
"postgresql": ProviderSpec(
37-
schema_path=("sqlit.db.schema", "POSTGRESQL_SCHEMA"),
38-
adapter_path=("sqlit.db.adapters.postgresql", "PostgreSQLAdapter"),
39-
),
40-
"mysql": ProviderSpec(
41-
schema_path=("sqlit.db.schema", "MYSQL_SCHEMA"),
42-
adapter_path=("sqlit.db.adapters.mysql", "MySQLAdapter"),
43-
),
44-
"oracle": ProviderSpec(
45-
schema_path=("sqlit.db.schema", "ORACLE_SCHEMA"),
46-
adapter_path=("sqlit.db.adapters.oracle", "OracleAdapter"),
47-
),
48-
"mariadb": ProviderSpec(
49-
schema_path=("sqlit.db.schema", "MARIADB_SCHEMA"),
50-
adapter_path=("sqlit.db.adapters.mariadb", "MariaDBAdapter"),
51-
),
52-
"duckdb": ProviderSpec(
53-
schema_path=("sqlit.db.schema", "DUCKDB_SCHEMA"),
54-
adapter_path=("sqlit.db.adapters.duckdb", "DuckDBAdapter"),
55-
),
56-
"cockroachdb": ProviderSpec(
57-
schema_path=("sqlit.db.schema", "COCKROACHDB_SCHEMA"),
58-
adapter_path=("sqlit.db.adapters.cockroachdb", "CockroachDBAdapter"),
59-
),
60-
"turso": ProviderSpec(
61-
schema_path=("sqlit.db.schema", "TURSO_SCHEMA"),
62-
adapter_path=("sqlit.db.adapters.turso", "TursoAdapter"),
63-
),
64-
"supabase": ProviderSpec(
65-
schema_path=("sqlit.db.schema", "SUPABASE_SCHEMA"),
66-
adapter_path=("sqlit.db.adapters.supabase", "SupabaseAdapter"),
67-
),
68-
"d1": ProviderSpec(
69-
schema_path=("sqlit.db.schema", "D1_SCHEMA"),
70-
adapter_path=("sqlit.db.adapters.d1", "D1Adapter"),
71-
),
55+
"mssql": ProviderSpec(schema=MSSQL_SCHEMA, adapter_class=SQLServerAdapter),
56+
"sqlite": ProviderSpec(schema=SQLITE_SCHEMA, adapter_class=SQLiteAdapter),
57+
"postgresql": ProviderSpec(schema=POSTGRESQL_SCHEMA, adapter_class=PostgreSQLAdapter),
58+
"mysql": ProviderSpec(schema=MYSQL_SCHEMA, adapter_class=MySQLAdapter),
59+
"oracle": ProviderSpec(schema=ORACLE_SCHEMA, adapter_class=OracleAdapter),
60+
"mariadb": ProviderSpec(schema=MARIADB_SCHEMA, adapter_class=MariaDBAdapter),
61+
"duckdb": ProviderSpec(schema=DUCKDB_SCHEMA, adapter_class=DuckDBAdapter),
62+
"cockroachdb": ProviderSpec(schema=COCKROACHDB_SCHEMA, adapter_class=CockroachDBAdapter),
63+
"turso": ProviderSpec(schema=TURSO_SCHEMA, adapter_class=TursoAdapter),
64+
"supabase": ProviderSpec(schema=SUPABASE_SCHEMA, adapter_class=SupabaseAdapter),
65+
"d1": ProviderSpec(schema=D1_SCHEMA, adapter_class=D1Adapter),
7266
}
7367

7468

@@ -77,7 +71,7 @@ def get_supported_db_types() -> list[str]:
7771

7872

7973
def iter_provider_schemas() -> Iterable[ConnectionSchema]:
80-
return (_get_schema(spec) for spec in PROVIDERS.values())
74+
return (spec.schema for spec in PROVIDERS.values())
8175

8276

8377
def get_provider_spec(db_type: str) -> ProviderSpec:
@@ -87,18 +81,12 @@ def get_provider_spec(db_type: str) -> ProviderSpec:
8781
return spec
8882

8983

90-
def _get_schema(spec: ProviderSpec) -> ConnectionSchema:
91-
module_name, attr_name = spec.schema_path
92-
module = import_module(module_name)
93-
return getattr(module, attr_name)
94-
95-
9684
def get_connection_schema(db_type: str) -> ConnectionSchema:
97-
return _get_schema(get_provider_spec(db_type))
85+
return get_provider_spec(db_type).schema
9886

9987

10088
def get_all_schemas() -> dict[str, ConnectionSchema]:
101-
return {k: _get_schema(v) for k, v in PROVIDERS.items()}
89+
return {k: v.schema for k, v in PROVIDERS.items()}
10290

10391

10492
def get_adapter(db_type: str) -> "DatabaseAdapter":
@@ -109,35 +97,31 @@ def get_adapter(db_type: str) -> "DatabaseAdapter":
10997

11098

11199
def get_adapter_class(db_type: str) -> type["DatabaseAdapter"]:
112-
spec = get_provider_spec(db_type)
113-
module_name, class_name = spec.adapter_path
114-
module = import_module(module_name)
115-
adapter_cls = getattr(module, class_name)
116-
return adapter_cls
100+
return get_provider_spec(db_type).adapter_class
117101

118102

119103
def get_default_port(db_type: str) -> str:
120104
spec = PROVIDERS.get(db_type)
121105
if spec is None:
122106
return "1433"
123-
return _get_schema(spec).default_port
107+
return spec.schema.default_port
124108

125109

126110
def get_display_name(db_type: str) -> str:
127111
spec = PROVIDERS.get(db_type)
128-
return _get_schema(spec).display_name if spec else db_type
112+
return spec.schema.display_name if spec else db_type
129113

130114

131115
def supports_ssh(db_type: str) -> bool:
132116
spec = PROVIDERS.get(db_type)
133-
return _get_schema(spec).supports_ssh if spec else False
117+
return spec.schema.supports_ssh if spec else False
134118

135119

136120
def is_file_based(db_type: str) -> bool:
137121
spec = PROVIDERS.get(db_type)
138-
return _get_schema(spec).is_file_based if spec else False
122+
return spec.schema.is_file_based if spec else False
139123

140124

141125
def has_advanced_auth(db_type: str) -> bool:
142126
spec = PROVIDERS.get(db_type)
143-
return _get_schema(spec).has_advanced_auth if spec else False
127+
return spec.schema.has_advanced_auth if spec else False

sqlit/ui/mixins/connection.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,15 @@ def handle_connection_result(self: AppProtocol, result: tuple | None) -> None:
299299
if not result:
300300
return
301301

302-
action, config = result
302+
action, config = result[0], result[1]
303+
original_name = result[2] if len(result) > 2 else None
303304

304305
if action == "save":
305-
def do_save(with_config) -> None: # noqa: ANN001
306+
def do_save(with_config, orig_name=None) -> None: # noqa: ANN001
307+
# When editing, remove by original name to properly update renamed connections
308+
if orig_name:
309+
self.connections = [c for c in self.connections if c.name != orig_name]
310+
# Also remove by new name to handle overwrites/duplicates
306311
self.connections = [c for c in self.connections if c.name != with_config.name]
307312
self.connections.append(with_config)
308313
if getattr(self, "_mock_profile", None):
@@ -319,13 +324,13 @@ def do_save(with_config) -> None: # noqa: ANN001
319324

320325
if allow_plaintext is True:
321326
reset_credentials_service()
322-
do_save(config)
327+
do_save(config, original_name)
323328
return
324329

325330
if allow_plaintext is False:
326331
config.password = ""
327332
config.ssh_password = ""
328-
do_save(config)
333+
do_save(config, original_name)
329334
self.notify("Keyring unavailable: passwords will be prompted when needed", severity="warning")
330335
return
331336

@@ -335,15 +340,15 @@ def on_confirm(confirmed: bool | None) -> None:
335340
settings2[ALLOW_PLAINTEXT_CREDENTIALS_SETTING] = True
336341
save_settings(settings2)
337342
reset_credentials_service()
338-
do_save(config)
343+
do_save(config, original_name)
339344
self.notify("Saved passwords as plaintext in ~/.sqlit/ (0600)", severity="warning")
340345
return
341346

342347
settings2[ALLOW_PLAINTEXT_CREDENTIALS_SETTING] = False
343348
save_settings(settings2)
344349
config.password = ""
345350
config.ssh_password = ""
346-
do_save(config)
351+
do_save(config, original_name)
347352
self.notify("Passwords were not saved (keyring unavailable)", severity="warning")
348353

349354
self.push_screen(
@@ -357,7 +362,7 @@ def on_confirm(confirmed: bool | None) -> None:
357362
)
358363
return
359364

360-
do_save(config)
365+
do_save(config, original_name)
361366

362367
def action_duplicate_connection(self: AppProtocol) -> None:
363368
from dataclasses import replace

sqlit/ui/screens/__init__.py

Lines changed: 13 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
"""Modal screens for sqlit."""
22

3-
from importlib import import_module
4-
from typing import TYPE_CHECKING, Any
3+
from .confirm import ConfirmScreen
4+
from .connection import ConnectionScreen
5+
from .connection_picker import ConnectionPickerScreen
6+
from .driver_setup import DriverSetupScreen
7+
from .error import ErrorScreen
8+
from .help import HelpScreen
9+
from .leader_menu import LeaderMenuScreen
10+
from .message import MessageScreen
11+
from .package_setup import PackageSetupScreen
12+
from .password_input import PasswordInputScreen
13+
from .query_history import QueryHistoryScreen
14+
from .theme import ThemeScreen
15+
from .value_view import ValueViewScreen
516

617
__all__ = [
718
"ConfirmScreen",
@@ -18,43 +29,3 @@
1829
"ThemeScreen",
1930
"ValueViewScreen",
2031
]
21-
22-
_LAZY_ATTRS: dict[str, tuple[str, str]] = {
23-
"ConfirmScreen": ("sqlit.ui.screens.confirm", "ConfirmScreen"),
24-
"ConnectionScreen": ("sqlit.ui.screens.connection", "ConnectionScreen"),
25-
"ConnectionPickerScreen": ("sqlit.ui.screens.connection_picker", "ConnectionPickerScreen"),
26-
"DriverSetupScreen": ("sqlit.ui.screens.driver_setup", "DriverSetupScreen"),
27-
"ErrorScreen": ("sqlit.ui.screens.error", "ErrorScreen"),
28-
"HelpScreen": ("sqlit.ui.screens.help", "HelpScreen"),
29-
"LeaderMenuScreen": ("sqlit.ui.screens.leader_menu", "LeaderMenuScreen"),
30-
"MessageScreen": ("sqlit.ui.screens.message", "MessageScreen"),
31-
"PackageSetupScreen": ("sqlit.ui.screens.package_setup", "PackageSetupScreen"),
32-
"PasswordInputScreen": ("sqlit.ui.screens.password_input", "PasswordInputScreen"),
33-
"QueryHistoryScreen": ("sqlit.ui.screens.query_history", "QueryHistoryScreen"),
34-
"ThemeScreen": ("sqlit.ui.screens.theme", "ThemeScreen"),
35-
"ValueViewScreen": ("sqlit.ui.screens.value_view", "ValueViewScreen"),
36-
}
37-
38-
if TYPE_CHECKING:
39-
from .confirm import ConfirmScreen
40-
from .connection import ConnectionScreen
41-
from .connection_picker import ConnectionPickerScreen
42-
from .driver_setup import DriverSetupScreen
43-
from .error import ErrorScreen
44-
from .help import HelpScreen
45-
from .leader_menu import LeaderMenuScreen
46-
from .message import MessageScreen
47-
from .package_setup import PackageSetupScreen
48-
from .password_input import PasswordInputScreen
49-
from .query_history import QueryHistoryScreen
50-
from .theme import ThemeScreen
51-
from .value_view import ValueViewScreen
52-
53-
54-
def __getattr__(name: str) -> Any:
55-
target = _LAZY_ATTRS.get(name)
56-
if target is None:
57-
raise AttributeError(name)
58-
module_name, attr_name = target
59-
module = import_module(module_name)
60-
return getattr(module, attr_name)

0 commit comments

Comments
 (0)