Skip to content

Commit 103d8a6

Browse files
committed
Group tables by schema if multiple
1 parent bf771ca commit 103d8a6

8 files changed

Lines changed: 1112 additions & 41 deletions

File tree

sqlit/state_machine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def is_active(self, app: SSMSTUI) -> bool:
599599

600600

601601
class TreeOnFolderState(State):
602-
"""Tree focused on a folder or database node."""
602+
"""Tree focused on a folder, database, or schema node."""
603603

604604
def _setup_actions(self) -> None:
605605
pass # Just inherits from parent
@@ -632,7 +632,7 @@ def is_active(self, app: SSMSTUI) -> bool:
632632
return (
633633
node is not None
634634
and node.data is not None
635-
and node.data[0] in ("folder", "database")
635+
and node.data[0] in ("folder", "database", "schema")
636636
)
637637

638638

sqlit/ui/mixins/query.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import TYPE_CHECKING, Any
66

7+
from rich.markup import escape as escape_markup
78
from textual.timer import Timer
89
from textual.widgets import DataTable, TextArea
910
from textual.worker import Worker
@@ -187,9 +188,8 @@ def _display_query_results(
187188
self.results_table.clear(columns=True)
188189
self.results_table.add_columns(*columns)
189190

190-
# Only display first 1000 rows in the table
191191
for row in rows[:1000]:
192-
str_row = tuple(str(v) if v is not None else "NULL" for v in row)
192+
str_row = tuple(escape_markup(str(v)) if v is not None else "NULL" for v in row)
193193
self.results_table.add_row(*str_row)
194194

195195
time_str = f"{elapsed_ms:.0f}ms" if elapsed_ms >= 1 else f"{elapsed_ms:.2f}ms"
@@ -218,7 +218,7 @@ def _display_query_error(self, error_message: str) -> None:
218218

219219
self.results_table.clear(columns=True)
220220
self.results_table.add_column("Error")
221-
self.results_table.add_row(error_message)
221+
self.results_table.add_row(escape_markup(error_message))
222222
self.notify(f"Query error: {error_message}", severity="error")
223223

224224
def _restore_insert_mode(self) -> None:

sqlit/ui/mixins/tree.py

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import TYPE_CHECKING, Any
66

7+
from rich.markup import escape as escape_markup
78
from textual.widgets import Tree
89

910
if TYPE_CHECKING:
@@ -44,10 +45,11 @@ def refresh_tree(self) -> None:
4445
self.object_tree.root.expand()
4546

4647
for conn in self.connections:
47-
display_info = conn.get_display_info()
48+
display_info = escape_markup(conn.get_display_info())
4849
db_type_label = self._db_type_badge(conn.db_type)
50+
escaped_name = escape_markup(conn.name)
4951
node = self.object_tree.root.add(
50-
f"[dim]{conn.name}[/dim] [{db_type_label}] ({display_info})"
52+
f"[dim]{escaped_name}[/dim] [{db_type_label}] ({display_info})"
5153
)
5254
node.data = ("connection", conn)
5355
node.allow_expand = True
@@ -63,12 +65,13 @@ def populate_connected_tree(self) -> None:
6365
adapter = self.current_adapter
6466

6567
def get_conn_label(config, connected=False):
66-
display_info = config.get_display_info()
68+
display_info = escape_markup(config.get_display_info())
6769
db_type_label = self._db_type_badge(config.db_type)
70+
escaped_name = escape_markup(config.name)
6871
if connected:
69-
name = f"[green]{config.name}[/green]"
72+
name = f"[green]{escaped_name}[/green]"
7073
else:
71-
name = config.name
74+
name = escaped_name
7275
return f"{name} [{db_type_label}] ({display_info})"
7376

7477
active_node = None
@@ -99,7 +102,7 @@ def get_conn_label(config, connected=False):
99102

100103
databases = adapter.get_databases(self.current_connection)
101104
for db_name in databases:
102-
db_node = dbs_node.add(db_name)
105+
db_node = dbs_node.add(escape_markup(db_name))
103106
db_node.data = ("database", db_name)
104107
db_node.allow_expand = True
105108
self._add_database_object_nodes(db_node, db_name)
@@ -143,8 +146,9 @@ def _get_node_path(self, node) -> str:
143146
parts.append(f"db:{data[1]}")
144147
elif data[0] == "folder":
145148
parts.append(f"folder:{data[1]}")
149+
elif data[0] == "schema":
150+
parts.append(f"schema:{data[2]}")
146151
elif data[0] in ("table", "view") and len(data) >= 4:
147-
# Include schema in path for uniqueness
148152
schema_name = data[2]
149153
obj_name = data[3]
150154
parts.append(f"{data[0]}:{schema_name}.{obj_name}")
@@ -258,14 +262,14 @@ def _on_columns_loaded(self, node, db_name: str, schema_name: str, obj_name: str
258262
node_path = self._get_node_path(node)
259263
self._loading_nodes.discard(node_path)
260264

261-
# Remove loading placeholder
262265
for child in list(node.children):
263266
if child.data == ("loading",):
264267
child.remove()
265268

266-
# Add column nodes
267269
for col in columns:
268-
child = node.add_leaf(f"[dim]{col.name}[/] [italic dim]{col.data_type}[/]")
270+
col_name = escape_markup(col.name)
271+
col_type = escape_markup(col.data_type)
272+
child = node.add_leaf(f"[dim]{col_name}[/] [italic dim]{col_type}[/]")
269273
child.data = ("column", db_name, schema_name, obj_name, col.name)
270274

271275
def _load_folder_async(self, node, data: tuple) -> None:
@@ -306,7 +310,6 @@ def _on_folder_loaded(self, node, db_name: str | None, folder_type: str, items:
306310
node_path = self._get_node_path(node)
307311
self._loading_nodes.discard(node_path)
308312

309-
# Remove loading placeholder
310313
for child in list(node.children):
311314
if child.data == ("loading",):
312315
child.remove()
@@ -316,36 +319,70 @@ def _on_folder_loaded(self, node, db_name: str | None, folder_type: str, items:
316319

317320
adapter = self._session.adapter
318321

319-
# Add nodes based on type
322+
if folder_type in ("tables", "views"):
323+
self._add_schema_grouped_items(node, db_name, folder_type, items, adapter.default_schema)
324+
else:
325+
for item in items:
326+
if item[0] == "procedure":
327+
child = node.add(escape_markup(item[1]))
328+
child.data = ("procedure", db_name, item[1])
329+
330+
def _add_schema_grouped_items(
331+
self,
332+
node,
333+
db_name: str | None,
334+
folder_type: str,
335+
items: list,
336+
default_schema: str,
337+
) -> None:
338+
"""Add tables/views grouped by schema."""
339+
from collections import defaultdict
340+
341+
by_schema: dict[str, list] = defaultdict(list)
320342
for item in items:
321-
if item[0] == "table":
322-
schema_name, table_name = item[1], item[2]
323-
display_name = adapter.format_table_name(schema_name, table_name)
324-
child = node.add(display_name)
325-
child.data = ("table", db_name, schema_name, table_name)
326-
child.allow_expand = True
327-
elif item[0] == "view":
328-
schema_name, view_name = item[1], item[2]
329-
display_name = adapter.format_table_name(schema_name, view_name)
330-
child = node.add(display_name)
331-
child.data = ("view", db_name, schema_name, view_name)
343+
by_schema[item[1]].append(item)
344+
345+
def schema_sort_key(schema: str) -> tuple[int, str]:
346+
if not schema or schema == default_schema:
347+
return (0, schema)
348+
return (1, schema.lower())
349+
350+
sorted_schemas = sorted(by_schema.keys(), key=schema_sort_key)
351+
has_multiple_schemas = len(sorted_schemas) > 1
352+
schema_nodes: dict[str, any] = {}
353+
354+
for schema in sorted_schemas:
355+
schema_items = by_schema[schema]
356+
is_default = not schema or schema == default_schema
357+
358+
if is_default and not has_multiple_schemas:
359+
parent = node
360+
else:
361+
if schema not in schema_nodes:
362+
display_name = schema if schema else default_schema
363+
escaped_name = escape_markup(display_name)
364+
schema_node = node.add(f"[dim]\\[{escaped_name}][/]")
365+
schema_node.data = ("schema", db_name, schema or default_schema, folder_type)
366+
schema_node.allow_expand = True
367+
schema_nodes[schema] = schema_node
368+
parent = schema_nodes[schema]
369+
370+
for item in schema_items:
371+
item_type, schema_name, obj_name = item[0], item[1], item[2]
372+
child = parent.add(escape_markup(obj_name))
373+
child.data = (item_type, db_name, schema_name, obj_name)
332374
child.allow_expand = True
333-
elif item[0] == "procedure":
334-
proc_name = item[1]
335-
child = node.add(proc_name)
336-
child.data = ("procedure", db_name, proc_name)
337375

338376
def _on_tree_load_error(self, node, error_message: str) -> None:
339377
"""Handle tree load error on main thread."""
340378
node_path = self._get_node_path(node)
341379
self._loading_nodes.discard(node_path)
342380

343-
# Remove loading placeholder
344381
for child in list(node.children):
345382
if child.data == ("loading",):
346383
child.remove()
347384

348-
self.notify(error_message, severity="error")
385+
self.notify(escape_markup(error_message), severity="error")
349386

350387
def on_tree_node_selected(self, event: Tree.NodeSelected) -> None:
351388
"""Handle tree node selection (double-click/enter)."""

tests/ui/explorer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Explorer tree tests."""

0 commit comments

Comments
 (0)