Skip to content

Commit 1f97577

Browse files
authored
Merge pull request #250 from eibx/feature/update-delete-query-qualified-name
fix: Add database/schema for update and delete queries
2 parents 0f46bee + 57dffa1 commit 1f97577

5 files changed

Lines changed: 252 additions & 11 deletions

File tree

sqlit/domains/connections/providers/adapters/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,13 +422,15 @@ def qualified_name(self, database: str | None, schema: str | None, name: str) ->
422422
423423
Default handles SQL Server-style `[db].[schema].[name]`, PostgreSQL-
424424
style `"schema"."name"`, and single-part `"name"` by omitting any
425-
empty/None component. Dialects that want different composition
426-
(e.g. MySQL, which has no schemas within databases) can override.
425+
empty/None component, removing database component if provider
426+
doesn't support cross-database queries and remove default schema name.
427+
Dialects that want different composition (e.g. MySQL, which has
428+
no schemas within databases) can override.
427429
"""
428430
parts: list[str] = []
429-
if database:
431+
if database and self.supports_cross_database_queries:
430432
parts.append(self.quote_identifier(database))
431-
if schema:
433+
if schema and (schema != self.default_schema or parts):
432434
parts.append(self.quote_identifier(schema))
433435
parts.append(self.quote_identifier(name))
434436
return ".".join(parts)

sqlit/domains/results/ui/mixins/results.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,11 +1050,14 @@ def sql_value(v: object) -> str:
10501050
return "'" + str(v).replace("'", "''") + "'"
10511051

10521052
# Get table name and primary key columns
1053-
table_name = "<table>"
1053+
qualified_name = "<table>"
10541054
pk_column_names: set[str] = set()
10551055
table_info = self._get_active_results_table_info(table, _stacked)
10561056
if table_info:
1057-
table_name = table_info.get("name", table_name)
1057+
database_name = table_info.get("database")
1058+
schema_name = table_info.get("schema")
1059+
table_name = table_info.get("name")
1060+
qualified_name = self.current_provider.dialect.qualified_name(database_name, schema_name, table_name)
10581061
# Get PK columns from column info
10591062
for col in table_info.get("columns", []):
10601063
if col.is_primary_key:
@@ -1090,7 +1093,7 @@ def sql_value(v: object) -> str:
10901093
where_clause = " AND ".join(where_parts)
10911094

10921095
# Generate DELETE query for the row
1093-
query = f"DELETE FROM {table_name} WHERE {where_clause};"
1096+
query = f"DELETE FROM {qualified_name} WHERE {where_clause};"
10941097

10951098
# Set query and switch to insert mode
10961099
self._suppress_autocomplete_once = True
@@ -1147,10 +1150,14 @@ def sql_value(v: object) -> str:
11471150
return "'" + str(v).replace("'", "''") + "'"
11481151

11491152
# Get table name and primary key columns
1150-
table_name = "<table>"
1153+
qualified_name = "<table>"
11511154
pk_column_names: set[str] = set()
11521155
if table_info:
1153-
table_name = table_info.get("name", table_name)
1156+
database_name = table_info.get("database")
1157+
schema_name = table_info.get("schema")
1158+
table_name = table_info.get("name")
1159+
qualified_name = self.current_provider.dialect.qualified_name(database_name, schema_name, table_name)
1160+
11541161
# Get PK columns from column info
11551162
for col in table_info.get("columns", []):
11561163
if col.is_primary_key:
@@ -1182,7 +1189,7 @@ def sql_value(v: object) -> str:
11821189
where_clause = " AND ".join(where_parts)
11831190

11841191
# Generate UPDATE query with empty placeholder for the new value
1185-
query = f"UPDATE {table_name} SET {column_name} = '' WHERE {where_clause};"
1192+
query = f"UPDATE {qualified_name} SET {column_name} = '' WHERE {where_clause};"
11861193

11871194
# Find position inside the empty quotes (after "SET column = '")
11881195
set_prefix = f"SET {column_name} = '"
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
"""Integration test for PR #250 against a real multi-database SQL Server.
2+
3+
The bug: ``action_delete_row`` / ``action_edit_cell`` in the results panel built
4+
the DELETE/UPDATE from the *bare* table name. When you preview a table that
5+
lives in a different database than the one the connection is currently on, the
6+
unqualified statement targets the *wrong* database (the connection's current
7+
one), silently mutating the wrong table.
8+
9+
This test reproduces it end to end:
10+
11+
* a real SQL Server with two databases, each holding ``dbo.widgets`` with one
12+
row carrying a database-specific marker;
13+
* a connection whose current database is A;
14+
* the real ``ResultsMixin`` actions driven with ``table_info`` pointing at the
15+
table in database B (exactly what the explorer stashes when you open a table
16+
from another database);
17+
* the generated SQL executed against the live connection.
18+
19+
With the bug the statement hits database A. With the fix it hits B. We assert
20+
that B is mutated and A is left untouched, so the test is RED on the old code
21+
and GREEN on PR #250.
22+
"""
23+
24+
from __future__ import annotations
25+
26+
from types import SimpleNamespace
27+
from typing import Any
28+
29+
import pytest
30+
31+
from sqlit.domains.connections.providers.adapters.base import ColumnInfo
32+
from sqlit.domains.connections.providers.mssql.adapter import SQLServerAdapter
33+
from sqlit.domains.results.ui.mixins.results import ResultsMixin
34+
from tests.conftest import MSSQL_HOST, MSSQL_PASSWORD, MSSQL_PORT, MSSQL_USER
35+
from tests.fixtures.mssql import mssql_available
36+
37+
DB_A = "sqlit_qual_a"
38+
DB_B = "sqlit_qual_b"
39+
40+
41+
def _master_config() -> Any:
42+
from tests.helpers import ConnectionConfig
43+
44+
return ConnectionConfig(
45+
name="test-qual-master",
46+
db_type="mssql",
47+
server=MSSQL_HOST,
48+
port=str(MSSQL_PORT),
49+
database="master",
50+
username=MSSQL_USER,
51+
password=MSSQL_PASSWORD,
52+
options={"auth_type": "sql"},
53+
)
54+
55+
56+
def _db_config(database: str) -> Any:
57+
from tests.helpers import ConnectionConfig
58+
59+
return ConnectionConfig(
60+
name=f"test-qual-{database}",
61+
db_type="mssql",
62+
server=MSSQL_HOST,
63+
port=str(MSSQL_PORT),
64+
database=database,
65+
username=MSSQL_USER,
66+
password=MSSQL_PASSWORD,
67+
options={"auth_type": "sql"},
68+
)
69+
70+
71+
@pytest.fixture
72+
def two_databases():
73+
"""Create two databases each with dbo.widgets(id PK, label) and one row."""
74+
if not mssql_available():
75+
pytest.skip("SQL Server is not available")
76+
77+
adapter = SQLServerAdapter()
78+
master = adapter.connect(_master_config())
79+
master.autocommit = True
80+
cur = master.cursor()
81+
for db, marker in ((DB_A, "A_original"), (DB_B, "B_original")):
82+
cur.execute(f"IF DB_ID('{db}') IS NOT NULL BEGIN ALTER DATABASE [{db}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; DROP DATABASE [{db}]; END")
83+
cur.execute(f"CREATE DATABASE [{db}]")
84+
cur.execute(f"CREATE TABLE [{db}].[dbo].[widgets] (id INT PRIMARY KEY, label NVARCHAR(50))")
85+
cur.execute(f"INSERT INTO [{db}].[dbo].[widgets] (id, label) VALUES (1, '{marker}')")
86+
cur.close()
87+
master.close()
88+
89+
yield adapter
90+
91+
master = adapter.connect(_master_config())
92+
master.autocommit = True
93+
cur = master.cursor()
94+
for db in (DB_A, DB_B):
95+
cur.execute(f"IF DB_ID('{db}') IS NOT NULL BEGIN ALTER DATABASE [{db}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; DROP DATABASE [{db}]; END")
96+
cur.close()
97+
master.close()
98+
99+
100+
class _FakeInput:
101+
def __init__(self) -> None:
102+
self.text = ""
103+
self.cursor_location = (0, 0)
104+
self.read_only = False
105+
106+
def focus(self) -> None:
107+
pass
108+
109+
110+
class _FakeTable:
111+
"""Mimics the focused DataTable holding the single previewed row."""
112+
113+
def __init__(self, row: tuple[Any, ...]) -> None:
114+
self._row = row
115+
self.row_count = 1
116+
self.cursor_coordinate: tuple[int, int] = (0, 0)
117+
118+
def get_row_at(self, _row: int) -> list[Any]:
119+
return list(self._row)
120+
121+
122+
class _ResultsHost(ResultsMixin):
123+
"""Minimal host so the *real* ResultsMixin actions run without Textual.
124+
125+
Everything that matters for the bug — qualified_name composition, the
126+
table_info lookup, WHERE/PK handling — is the production mixin + adapter.
127+
"""
128+
129+
def __init__(self, adapter: SQLServerAdapter, table_info: dict[str, Any], row: tuple[Any, ...], columns: list[str]) -> None:
130+
self._table = _FakeTable(row)
131+
self._columns = columns
132+
# The explorer stashes the previewed table's identity here; the real
133+
# _get_active_results_table_info falls back to it.
134+
self._last_query_table = table_info
135+
self.query_input = _FakeInput()
136+
self._suppress_autocomplete_once = False
137+
self.current_provider = SimpleNamespace(dialect=adapter)
138+
self.vim_mode = None
139+
140+
def _get_active_results_context(self) -> tuple[Any, list, list, bool]:
141+
return self._table, self._columns, [tuple(self._table._row)], False
142+
143+
def notify(self, *_a: Any, **_k: Any) -> None:
144+
pass
145+
146+
def action_focus_query(self) -> None:
147+
pass
148+
149+
def _update_footer_bindings(self) -> None:
150+
pass
151+
152+
def _update_vim_mode_visuals(self) -> None:
153+
pass
154+
155+
156+
def _columns_meta() -> list[ColumnInfo]:
157+
return [
158+
ColumnInfo(name="id", data_type="int", is_primary_key=True),
159+
ColumnInfo(name="label", data_type="nvarchar", is_primary_key=False),
160+
]
161+
162+
163+
def _count(adapter: SQLServerAdapter, conn: Any, database: str) -> int:
164+
_cols, rows, _ = adapter.execute_query(conn, f"SELECT COUNT(*) FROM [{database}].[dbo].[widgets]")
165+
return rows[0][0]
166+
167+
168+
def _label(adapter: SQLServerAdapter, conn: Any, database: str) -> str | None:
169+
_cols, rows, _ = adapter.execute_query(conn, f"SELECT label FROM [{database}].[dbo].[widgets] WHERE id = 1")
170+
return rows[0][0] if rows else None
171+
172+
173+
@pytest.mark.integration
174+
@pytest.mark.mssql
175+
class TestResultsQualifiedMutation:
176+
def test_delete_targets_table_own_database(self, two_databases: SQLServerAdapter) -> None:
177+
adapter = two_databases
178+
# Connection's *current* database is A; we operate on a row from B.
179+
conn = adapter.connect(_db_config(DB_A))
180+
try:
181+
table_info = {"database": DB_B, "schema": "dbo", "name": "widgets", "columns": _columns_meta()}
182+
host = _ResultsHost(adapter, table_info, row=(1, "B_original"), columns=["id", "label"])
183+
184+
host.action_delete_row()
185+
query = host.query_input.text
186+
assert query, "no DELETE query generated"
187+
188+
# Execute exactly what the panel produced, against the A-connection.
189+
cur = conn.cursor()
190+
cur.execute(query)
191+
conn.commit()
192+
cur.close()
193+
194+
# The fix must delete from B (the table we were viewing) and leave A.
195+
assert _count(adapter, conn, DB_B) == 0, f"row in {DB_B} should be deleted; query was: {query}"
196+
assert _count(adapter, conn, DB_A) == 1, f"row in {DB_A} must be untouched; query was: {query}"
197+
finally:
198+
conn.close()
199+
200+
def test_update_targets_table_own_database(self, two_databases: SQLServerAdapter) -> None:
201+
adapter = two_databases
202+
conn = adapter.connect(_db_config(DB_A))
203+
try:
204+
table_info = {"database": DB_B, "schema": "dbo", "name": "widgets", "columns": _columns_meta()}
205+
host = _ResultsHost(adapter, table_info, row=(1, "B_original"), columns=["id", "label"])
206+
# Put the cursor on the editable (non-PK) `label` column.
207+
host._table.cursor_coordinate = (0, 1)
208+
209+
host.action_edit_cell()
210+
query = host.query_input.text
211+
assert query and query.startswith("UPDATE"), f"no UPDATE query generated: {query!r}"
212+
213+
cur = conn.cursor()
214+
cur.execute(query)
215+
conn.commit()
216+
cur.close()
217+
218+
# B's label was set to '' (the panel's placeholder); A stays original.
219+
assert _label(adapter, conn, DB_B) == "", f"row in {DB_B} should be updated; query was: {query}"
220+
assert _label(adapter, conn, DB_A) == "A_original", f"row in {DB_A} must be untouched; query was: {query}"
221+
finally:
222+
conn.close()

tests/unit/test_autocomplete_multidb.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,13 @@ def test_postgresql_qualified_name_uses_schema_only() -> None:
4646
for cross-reference within the connected database."""
4747
dialect = _get_dialect("postgresql")
4848
# No db segment expected when schema is present.
49-
assert dialect.qualified_name(None, "public", "users") == '"public"."users"'
49+
assert dialect.qualified_name(None, "test", "users") == '"test"."users"'
5050

51+
def test_postgresql_qualified_name_uses_table_only() -> None:
52+
"""PostgreSQL uses public as default schema. Only table makes sense."""
53+
dialect = _get_dialect("postgresql")
54+
# No db segment expected when schema is present.
55+
assert dialect.qualified_name(None, "public", "users") == '"users"'
5156

5257
def test_sqlserver_qualified_name_is_three_part() -> None:
5358
"""SQL Server explicitly uses [db].[schema].[table]."""

tests/unit/test_results_copy_markup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from __future__ import annotations
1111

12+
from types import SimpleNamespace
1213
from typing import Any
1314

1415
import pytest
@@ -86,6 +87,7 @@ def test_copy_cell_preserves_literal_brackets_when_not_rendering_markup() -> Non
8687
app.action_copy_cell()
8788
assert app.clipboard_text == "[bold]hello"
8889

90+
8991
class _FakeQueryInput:
9092
def __init__(self) -> None:
9193
self.text = ""
@@ -102,6 +104,9 @@ def __init__(self, cells: list[tuple[str, ...]], columns: list[str]) -> None:
102104
self._columns = columns
103105
self.query_input = _FakeQueryInput()
104106
self._suppress_autocomplete_once = False
107+
self.current_provider = SimpleNamespace(
108+
dialect=SimpleNamespace(qualified_name=lambda database, schema, name: name),
109+
)
105110

106111
def _get_active_results_context(self) -> tuple[Any, list, list, bool]:
107112
return self._table, self._columns, [], False

0 commit comments

Comments
 (0)