@@ -21,6 +21,12 @@ class AutocompleteMixin:
2121 _schema_cache : dict [str , Any ] = {}
2222 _table_metadata : dict [str , tuple [str , str , str | None ]] = {}
2323
24+ def _run_db_call (self : AppProtocol , fn : Any , * args : Any , ** kwargs : Any ) -> Any :
25+ session = getattr (self , "_session" , None )
26+ if session is not None :
27+ return session .executor .submit (fn , * args , ** kwargs ).result ()
28+ return fn (* args , ** kwargs )
29+
2430 def _get_word_before_cursor (self , text : str , cursor_pos : int ) -> tuple [str , str ]:
2531 """Get the current word being typed and the context keyword before it."""
2632 if cursor_pos <= 0 or cursor_pos > len (text ):
@@ -103,7 +109,9 @@ def work() -> None:
103109 column_names = []
104110 else :
105111 try :
106- columns = adapter .get_columns (connection , actual_table_name , database , schema_name )
112+ columns = self ._run_db_call (
113+ adapter .get_columns , connection , actual_table_name , database , schema_name
114+ )
107115 column_names = [c .name for c in columns ]
108116 except Exception :
109117 column_names = []
@@ -362,6 +370,12 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
362370 }
363371 table_metadata : dict [str , tuple [str , str , str | None ]] = {}
364372
373+ async def run_db_call (fn : Any , * args : Any , ** kwargs : Any ) -> Any :
374+ session = getattr (self , "_session" , None )
375+ if session is not None :
376+ return await session .executor .run_async (fn , * args , ** kwargs )
377+ return await asyncio .to_thread (fn , * args , ** kwargs )
378+
365379 try :
366380 # Get database list in thread
367381 databases : list [str | None ]
@@ -370,7 +384,7 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
370384 if db and db .lower () not in ("" , "master" ):
371385 databases = [db ]
372386 else :
373- all_dbs = await asyncio . to_thread (adapter .get_databases , connection )
387+ all_dbs = await run_db_call (adapter .get_databases , connection )
374388 system_dbs = {"master" , "tempdb" , "model" , "msdb" }
375389 databases = [d for d in all_dbs if d .lower () not in system_dbs ]
376390 else :
@@ -379,7 +393,7 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
379393 for database in databases :
380394 try :
381395 # Get tables in thread (NO columns - lazy loaded)
382- tables = await asyncio . to_thread (adapter .get_tables , connection , database )
396+ tables = await run_db_call (adapter .get_tables , connection , database )
383397 for schema_name , table_name in tables :
384398 display_name = adapter .format_table_name (schema_name , table_name )
385399 schema_cache ["tables" ].append (display_name )
@@ -392,7 +406,7 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
392406 table_metadata [full_name .lower ()] = (schema_name , table_name , database )
393407
394408 # Get views in thread (NO columns - lazy loaded)
395- views = await asyncio . to_thread (adapter .get_views , connection , database )
409+ views = await run_db_call (adapter .get_views , connection , database )
396410 for schema_name , view_name in views :
397411 display_name = adapter .format_table_name (schema_name , view_name )
398412 schema_cache ["views" ].append (display_name )
@@ -405,7 +419,7 @@ async def _load_schema_cache_async(self: AppProtocol) -> None:
405419 table_metadata [full_name .lower ()] = (schema_name , view_name , database )
406420
407421 if adapter .supports_stored_procedures :
408- procedures = await asyncio . to_thread (adapter .get_procedures , connection , database )
422+ procedures = await run_db_call (adapter .get_procedures , connection , database )
409423 schema_cache ["procedures" ].extend (procedures )
410424
411425 except Exception :
0 commit comments