Skip to content

Commit f586d20

Browse files
committed
Fix possible race condition
1 parent 734d392 commit f586d20

2 files changed

Lines changed: 45 additions & 13 deletions

File tree

sqlit/ui/mixins/autocomplete.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ class AutocompleteMixin:
2121
_schema_cache: dict[str, Any] = {}
2222
_table_metadata: dict[str, tuple[str, str, str | None]] = {}
2323

24+
def _run_db_call(self: AppProtocol, fn: Any, *args: Any, **kwargs: Any) -> Any:
25+
session = getattr(self, "_session", None)
26+
if session is not None:
27+
return session.executor.submit(fn, *args, **kwargs).result()
28+
return fn(*args, **kwargs)
29+
2430
def _get_word_before_cursor(self, text: str, cursor_pos: int) -> tuple[str, str]:
2531
"""Get the current word being typed and the context keyword before it."""
2632
if cursor_pos <= 0 or cursor_pos > len(text):
@@ -103,7 +109,9 @@ def work() -> None:
103109
column_names = []
104110
else:
105111
try:
106-
columns = adapter.get_columns(connection, actual_table_name, database, schema_name)
112+
columns = self._run_db_call(
113+
adapter.get_columns, connection, actual_table_name, database, schema_name
114+
)
107115
column_names = [c.name for c in columns]
108116
except Exception:
109117
column_names = []
@@ -362,6 +370,12 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
362370
}
363371
table_metadata: dict[str, tuple[str, str, str | None]] = {}
364372

373+
async def run_db_call(fn: Any, *args: Any, **kwargs: Any) -> Any:
374+
session = getattr(self, "_session", None)
375+
if session is not None:
376+
return await session.executor.run_async(fn, *args, **kwargs)
377+
return await asyncio.to_thread(fn, *args, **kwargs)
378+
365379
try:
366380
# Get database list in thread
367381
databases: list[str | None]
@@ -370,7 +384,7 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
370384
if db and db.lower() not in ("", "master"):
371385
databases = [db]
372386
else:
373-
all_dbs = await asyncio.to_thread(adapter.get_databases, connection)
387+
all_dbs = await run_db_call(adapter.get_databases, connection)
374388
system_dbs = {"master", "tempdb", "model", "msdb"}
375389
databases = [d for d in all_dbs if d.lower() not in system_dbs]
376390
else:
@@ -379,7 +393,7 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
379393
for database in databases:
380394
try:
381395
# Get tables in thread (NO columns - lazy loaded)
382-
tables = await asyncio.to_thread(adapter.get_tables, connection, database)
396+
tables = await run_db_call(adapter.get_tables, connection, database)
383397
for schema_name, table_name in tables:
384398
display_name = adapter.format_table_name(schema_name, table_name)
385399
schema_cache["tables"].append(display_name)
@@ -392,7 +406,7 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
392406
table_metadata[full_name.lower()] = (schema_name, table_name, database)
393407

394408
# Get views in thread (NO columns - lazy loaded)
395-
views = await asyncio.to_thread(adapter.get_views, connection, database)
409+
views = await run_db_call(adapter.get_views, connection, database)
396410
for schema_name, view_name in views:
397411
display_name = adapter.format_table_name(schema_name, view_name)
398412
schema_cache["views"].append(display_name)
@@ -405,7 +419,7 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
405419
table_metadata[full_name.lower()] = (schema_name, view_name, database)
406420

407421
if adapter.supports_stored_procedures:
408-
procedures = await asyncio.to_thread(adapter.get_procedures, connection, database)
422+
procedures = await run_db_call(adapter.get_procedures, connection, database)
409423
schema_cache["procedures"].extend(procedures)
410424

411425
except Exception:

sqlit/ui/mixins/tree.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
class TreeMixin:
3333
"""Mixin providing tree/explorer functionality."""
3434

35+
def _run_db_call(self: AppProtocol, fn: Any, *args: Any, **kwargs: Any) -> Any:
36+
session = getattr(self, "_session", None)
37+
if session is not None:
38+
return session.executor.submit(fn, *args, **kwargs).result()
39+
return fn(*args, **kwargs)
40+
3541
def _db_type_badge(self, db_type: str) -> str:
3642
"""Get short badge for database type."""
3743
return get_badge_label(db_type)
@@ -166,7 +172,7 @@ def get_conn_label(config: Any, connected: Any = False) -> str:
166172
dbs_node = active_node.add("Databases")
167173
dbs_node.data = FolderNode(folder_type="databases")
168174

169-
databases = adapter.get_databases(self.current_connection)
175+
databases = self._run_db_call(adapter.get_databases, self.current_connection)
170176
for db_name in databases:
171177
db_node = dbs_node.add(escape_markup(db_name))
172178
db_node.data = DatabaseNode(name=db_name)
@@ -331,7 +337,7 @@ def work() -> None:
331337
else:
332338
adapter = self._session.adapter
333339
conn = self._session.connection
334-
columns = adapter.get_columns(conn, obj_name, db_name, schema_name)
340+
columns = self._run_db_call(adapter.get_columns, conn, obj_name, db_name, schema_name)
335341

336342
# Update UI from worker thread
337343
self.call_from_thread(self._on_columns_loaded, node, db_name, schema_name, obj_name, columns)
@@ -377,27 +383,39 @@ def work() -> None:
377383
conn = self._session.connection
378384

379385
if folder_type == "tables":
380-
items = [("table", s, t) for s, t in adapter.get_tables(conn, db_name)]
386+
items = [("table", s, t) for s, t in self._run_db_call(adapter.get_tables, conn, db_name)]
381387
elif folder_type == "views":
382-
items = [("view", s, v) for s, v in adapter.get_views(conn, db_name)]
388+
items = [("view", s, v) for s, v in self._run_db_call(adapter.get_views, conn, db_name)]
383389
elif folder_type == "indexes":
384390
if adapter.supports_indexes:
385-
items = [("index", i.name, i.table_name) for i in adapter.get_indexes(conn, db_name)]
391+
items = [
392+
("index", i.name, i.table_name)
393+
for i in self._run_db_call(adapter.get_indexes, conn, db_name)
394+
]
386395
else:
387396
items = []
388397
elif folder_type == "triggers":
389398
if adapter.supports_triggers:
390-
items = [("trigger", t.name, t.table_name) for t in adapter.get_triggers(conn, db_name)]
399+
items = [
400+
("trigger", t.name, t.table_name)
401+
for t in self._run_db_call(adapter.get_triggers, conn, db_name)
402+
]
391403
else:
392404
items = []
393405
elif folder_type == "sequences":
394406
if adapter.supports_sequences:
395-
items = [("sequence", s.name, "") for s in adapter.get_sequences(conn, db_name)]
407+
items = [
408+
("sequence", s.name, "")
409+
for s in self._run_db_call(adapter.get_sequences, conn, db_name)
410+
]
396411
else:
397412
items = []
398413
elif folder_type == "procedures":
399414
if adapter.supports_stored_procedures:
400-
items = [("procedure", "", p) for p in adapter.get_procedures(conn, db_name)]
415+
items = [
416+
("procedure", "", p)
417+
for p in self._run_db_call(adapter.get_procedures, conn, db_name)
418+
]
401419
else:
402420
items = []
403421
else:

0 commit comments

Comments
 (0)