44
55from typing import TYPE_CHECKING , Any
66
7+ from rich .markup import escape as escape_markup
78from textual .widgets import Tree
89
910if TYPE_CHECKING :
@@ -44,10 +45,11 @@ def refresh_tree(self) -> None:
4445 self .object_tree .root .expand ()
4546
4647 for conn in self .connections :
47- display_info = conn .get_display_info ()
48+ display_info = escape_markup ( conn .get_display_info () )
4849 db_type_label = self ._db_type_badge (conn .db_type )
50+ escaped_name = escape_markup (conn .name )
4951 node = self .object_tree .root .add (
50- f"[dim]{ conn . name } [/dim] [{ db_type_label } ] ({ display_info } )"
52+ f"[dim]{ escaped_name } [/dim] [{ db_type_label } ] ({ display_info } )"
5153 )
5254 node .data = ("connection" , conn )
5355 node .allow_expand = True
@@ -63,12 +65,13 @@ def populate_connected_tree(self) -> None:
6365 adapter = self .current_adapter
6466
6567 def get_conn_label (config , connected = False ):
66- display_info = config .get_display_info ()
68+ display_info = escape_markup ( config .get_display_info () )
6769 db_type_label = self ._db_type_badge (config .db_type )
70+ escaped_name = escape_markup (config .name )
6871 if connected :
69- name = f"[green]{ config . name } [/green]"
72+ name = f"[green]{ escaped_name } [/green]"
7073 else :
71- name = config . name
74+ name = escaped_name
7275 return f"{ name } [{ db_type_label } ] ({ display_info } )"
7376
7477 active_node = None
@@ -99,7 +102,7 @@ def get_conn_label(config, connected=False):
99102
100103 databases = adapter .get_databases (self .current_connection )
101104 for db_name in databases :
102- db_node = dbs_node .add (db_name )
105+ db_node = dbs_node .add (escape_markup ( db_name ) )
103106 db_node .data = ("database" , db_name )
104107 db_node .allow_expand = True
105108 self ._add_database_object_nodes (db_node , db_name )
@@ -143,8 +146,9 @@ def _get_node_path(self, node) -> str:
143146 parts .append (f"db:{ data [1 ]} " )
144147 elif data [0 ] == "folder" :
145148 parts .append (f"folder:{ data [1 ]} " )
149+ elif data [0 ] == "schema" :
150+ parts .append (f"schema:{ data [2 ]} " )
146151 elif data [0 ] in ("table" , "view" ) and len (data ) >= 4 :
147- # Include schema in path for uniqueness
148152 schema_name = data [2 ]
149153 obj_name = data [3 ]
150154 parts .append (f"{ data [0 ]} :{ schema_name } .{ obj_name } " )
@@ -258,14 +262,14 @@ def _on_columns_loaded(self, node, db_name: str, schema_name: str, obj_name: str
258262 node_path = self ._get_node_path (node )
259263 self ._loading_nodes .discard (node_path )
260264
261- # Remove loading placeholder
262265 for child in list (node .children ):
263266 if child .data == ("loading" ,):
264267 child .remove ()
265268
266- # Add column nodes
267269 for col in columns :
268- child = node .add_leaf (f"[dim]{ col .name } [/] [italic dim]{ col .data_type } [/]" )
270+ col_name = escape_markup (col .name )
271+ col_type = escape_markup (col .data_type )
272+ child = node .add_leaf (f"[dim]{ col_name } [/] [italic dim]{ col_type } [/]" )
269273 child .data = ("column" , db_name , schema_name , obj_name , col .name )
270274
271275 def _load_folder_async (self , node , data : tuple ) -> None :
@@ -306,7 +310,6 @@ def _on_folder_loaded(self, node, db_name: str | None, folder_type: str, items:
306310 node_path = self ._get_node_path (node )
307311 self ._loading_nodes .discard (node_path )
308312
309- # Remove loading placeholder
310313 for child in list (node .children ):
311314 if child .data == ("loading" ,):
312315 child .remove ()
@@ -316,36 +319,70 @@ def _on_folder_loaded(self, node, db_name: str | None, folder_type: str, items:
316319
317320 adapter = self ._session .adapter
318321
319- # Add nodes based on type
322+ if folder_type in ("tables" , "views" ):
323+ self ._add_schema_grouped_items (node , db_name , folder_type , items , adapter .default_schema )
324+ else :
325+ for item in items :
326+ if item [0 ] == "procedure" :
327+ child = node .add (escape_markup (item [1 ]))
328+ child .data = ("procedure" , db_name , item [1 ])
329+
330+ def _add_schema_grouped_items (
331+ self ,
332+ node ,
333+ db_name : str | None ,
334+ folder_type : str ,
335+ items : list ,
336+ default_schema : str ,
337+ ) -> None :
338+ """Add tables/views grouped by schema."""
339+ from collections import defaultdict
340+
341+ by_schema : dict [str , list ] = defaultdict (list )
320342 for item in items :
321- if item [0 ] == "table" :
322- schema_name , table_name = item [1 ], item [2 ]
323- display_name = adapter .format_table_name (schema_name , table_name )
324- child = node .add (display_name )
325- child .data = ("table" , db_name , schema_name , table_name )
326- child .allow_expand = True
327- elif item [0 ] == "view" :
328- schema_name , view_name = item [1 ], item [2 ]
329- display_name = adapter .format_table_name (schema_name , view_name )
330- child = node .add (display_name )
331- child .data = ("view" , db_name , schema_name , view_name )
343+ by_schema [item [1 ]].append (item )
344+
345+ def schema_sort_key (schema : str ) -> tuple [int , str ]:
346+ if not schema or schema == default_schema :
347+ return (0 , schema )
348+ return (1 , schema .lower ())
349+
350+ sorted_schemas = sorted (by_schema .keys (), key = schema_sort_key )
351+ has_multiple_schemas = len (sorted_schemas ) > 1
352+ schema_nodes : dict [str , any ] = {}
353+
354+ for schema in sorted_schemas :
355+ schema_items = by_schema [schema ]
356+ is_default = not schema or schema == default_schema
357+
358+ if is_default and not has_multiple_schemas :
359+ parent = node
360+ else :
361+ if schema not in schema_nodes :
362+ display_name = schema if schema else default_schema
363+ escaped_name = escape_markup (display_name )
364+ schema_node = node .add (f"[dim]\\ [{ escaped_name } ][/]" )
365+ schema_node .data = ("schema" , db_name , schema or default_schema , folder_type )
366+ schema_node .allow_expand = True
367+ schema_nodes [schema ] = schema_node
368+ parent = schema_nodes [schema ]
369+
370+ for item in schema_items :
371+ item_type , schema_name , obj_name = item [0 ], item [1 ], item [2 ]
372+ child = parent .add (escape_markup (obj_name ))
373+ child .data = (item_type , db_name , schema_name , obj_name )
332374 child .allow_expand = True
333- elif item [0 ] == "procedure" :
334- proc_name = item [1 ]
335- child = node .add (proc_name )
336- child .data = ("procedure" , db_name , proc_name )
337375
338376 def _on_tree_load_error (self , node , error_message : str ) -> None :
339377 """Handle tree load error on main thread."""
340378 node_path = self ._get_node_path (node )
341379 self ._loading_nodes .discard (node_path )
342380
343- # Remove loading placeholder
344381 for child in list (node .children ):
345382 if child .data == ("loading" ,):
346383 child .remove ()
347384
348- self .notify (error_message , severity = "error" )
385+ self .notify (escape_markup ( error_message ) , severity = "error" )
349386
350387 def on_tree_node_selected (self , event : Tree .NodeSelected ) -> None :
351388 """Handle tree node selection (double-click/enter)."""
0 commit comments