Skip to content

Commit 9fe3746

Browse files
committed
Merge main into feature/extra-options-passthrough
Resolve conflicts in mssql and mysql adapters by combining: - extra_options passthrough from this branch - autocommit enablement (mssql) and charset auto-sync (mysql) from main
2 parents 70e4cd0 + 171c402 commit 9fe3746

18 files changed

Lines changed: 1376 additions & 8 deletions

File tree

infra/docker/docker-compose.test.yml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,48 @@ services:
5151
tmpfs:
5252
- /var/lib/mysql
5353

54+
# MySQL with TIS-620 charset (Thai) - for charset testing
55+
mysql-tis620:
56+
image: mysql:8.0
57+
container_name: sqlit-test-mysql-tis620
58+
command: --character-set-server=tis620 --collation-server=tis620_thai_ci
59+
environment:
60+
MYSQL_ROOT_PASSWORD: "TestPassword123!"
61+
MYSQL_USER: "testuser"
62+
MYSQL_PASSWORD: "TestPassword123!"
63+
MYSQL_DATABASE: "test_sqlit"
64+
ports:
65+
- "${MYSQL_TIS620_PORT:-3308}:3306"
66+
healthcheck:
67+
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "testuser", "-pTestPassword123!"]
68+
interval: 5s
69+
timeout: 5s
70+
retries: 10
71+
start_period: 30s
72+
tmpfs:
73+
- /var/lib/mysql
74+
75+
# MySQL with Latin1 charset - for charset testing
76+
mysql-latin1:
77+
image: mysql:8.0
78+
container_name: sqlit-test-mysql-latin1
79+
command: --character-set-server=latin1 --collation-server=latin1_swedish_ci
80+
environment:
81+
MYSQL_ROOT_PASSWORD: "TestPassword123!"
82+
MYSQL_USER: "testuser"
83+
MYSQL_PASSWORD: "TestPassword123!"
84+
MYSQL_DATABASE: "test_sqlit"
85+
ports:
86+
- "${MYSQL_LATIN1_PORT:-3309}:3306"
87+
healthcheck:
88+
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "testuser", "-pTestPassword123!"]
89+
interval: 5s
90+
timeout: 5s
91+
retries: 10
92+
start_period: 30s
93+
tmpfs:
94+
- /var/lib/mysql
95+
5496
clickhouse:
5597
image: clickhouse/clickhouse-server:latest
5698
container_name: sqlit-test-clickhouse

sqlit/domains/connections/providers/mariadb/adapter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def connect(self, config: ConnectionConfig) -> Any:
102102

103103
connect_args.update(config.extra_options)
104104
conn = mariadb_any.connect(**connect_args)
105+
106+
# Note: The MariaDB Python connector only supports UTF-8 family charsets.
107+
# Legacy charsets like TIS-620 or Latin1 are not supported. For databases
108+
# using legacy charsets, use the MySQL provider with PyMySQL instead.
109+
105110
self._supports_sequences = self._detect_sequences_support(conn)
106111
return conn
107112

sqlit/domains/connections/providers/mssql/adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,10 @@ def connect(self, config: ConnectionConfig) -> Any:
183183
# Append extra_options to connection string
184184
for key, value in config.extra_options.items():
185185
conn_str += f"{key}={value};"
186-
return mssql_python.connect(conn_str)
186+
conn = mssql_python.connect(conn_str)
187+
# Enable autocommit to allow DDL statements like CREATE DATABASE
188+
conn.autocommit = True
189+
return conn
187190

188191
def get_databases(self, conn: Any) -> list[str]:
189192
"""Get list of databases from SQL Server."""

sqlit/domains/connections/providers/mysql/adapter.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,24 @@ def connect(self, config: ConnectionConfig) -> Any:
111111
connect_args["ssl"] = ssl_params
112112

113113
connect_args.update(config.extra_options)
114-
return pymysql.connect(**connect_args)
114+
conn = pymysql.connect(**connect_args)
115+
116+
# Auto-sync charset with server to handle legacy encodings (e.g., TIS-620, Latin1).
117+
# This ensures data is read correctly when the database uses a non-UTF-8 charset.
118+
try:
119+
cursor = conn.cursor()
120+
cursor.execute("SELECT @@character_set_database")
121+
row = cursor.fetchone()
122+
if row:
123+
server_charset = row[0]
124+
# Only switch if server uses a different charset than our default (utf8mb4)
125+
if server_charset and server_charset.lower() != "utf8mb4":
126+
# Use set_charset() which both sends SET NAMES AND updates
127+
# PyMySQL's internal encoding for proper byte decoding
128+
conn.set_charset(server_charset)
129+
cursor.close()
130+
except Exception:
131+
# If charset sync fails, continue with default - better than failing completely
132+
pass
133+
134+
return conn

sqlit/domains/connections/providers/oracle/adapter.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,18 @@ def connect(self, config: ConnectionConfig) -> Any:
7272
if endpoint is None:
7373
raise ValueError("Oracle connections require a TCP-style endpoint.")
7474
port = int(endpoint.port or get_default_port("oracle"))
75-
# Use Easy Connect string format: host:port/service_name
76-
dsn = f"{endpoint.host}:{port}/{endpoint.database}"
75+
76+
# Determine connection type: service_name (default) or sid
77+
connection_type = config.get_option("oracle_connection_type", "service_name")
78+
79+
if connection_type == "sid":
80+
# SID format: host:port:sid (uses colon separator)
81+
# SID is stored in oracle_sid field, fall back to database for backward compat
82+
sid = config.get_option("oracle_sid") or endpoint.database
83+
dsn = f"{endpoint.host}:{port}:{sid}"
84+
else:
85+
# Service Name format: host:port/service_name (uses slash separator)
86+
dsn = f"{endpoint.host}:{port}/{endpoint.database}"
7787

7888
# Determine connection mode based on oracle_role
7989
oracle_role = config.get_option("oracle_role", "normal")

sqlit/domains/connections/providers/oracle/schema.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
2020
)
2121

2222

23+
def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
24+
return (
25+
SelectOption("service_name", "Service Name"),
26+
SelectOption("sid", "SID"),
27+
)
28+
29+
30+
def _oracle_connection_type_is_service_name(values: dict) -> bool:
31+
return values.get("oracle_connection_type", "service_name") != "sid"
32+
33+
34+
def _oracle_connection_type_is_sid(values: dict) -> bool:
35+
return values.get("oracle_connection_type") == "sid"
36+
37+
2338
SCHEMA = ConnectionSchema(
2439
db_type="oracle",
2540
display_name="Oracle",
@@ -32,11 +47,26 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
3247
group="server_port",
3348
),
3449
_port_field("1521"),
50+
SchemaField(
51+
name="oracle_connection_type",
52+
label="Connection Type",
53+
field_type=FieldType.DROPDOWN,
54+
options=_get_oracle_connection_type_options(),
55+
default="service_name",
56+
),
3557
SchemaField(
3658
name="database",
3759
label="Service Name",
3860
placeholder="ORCL or XEPDB1",
3961
required=True,
62+
visible_when=_oracle_connection_type_is_service_name,
63+
),
64+
SchemaField(
65+
name="oracle_sid",
66+
label="SID",
67+
placeholder="ORCL",
68+
required=True,
69+
visible_when=_oracle_connection_type_is_sid,
4070
),
4171
_username_field(),
4272
_password_field(),

sqlit/domains/connections/providers/oracle_legacy/schema.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
2020
)
2121

2222

23+
def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
24+
return (
25+
SelectOption("service_name", "Service Name"),
26+
SelectOption("sid", "SID"),
27+
)
28+
29+
30+
def _oracle_connection_type_is_service_name(values: dict) -> bool:
31+
return values.get("oracle_connection_type", "service_name") != "sid"
32+
33+
34+
def _oracle_connection_type_is_sid(values: dict) -> bool:
35+
return values.get("oracle_connection_type") == "sid"
36+
37+
2338
def _get_oracle_client_mode_options() -> tuple[SelectOption, ...]:
2439
return (
2540
SelectOption("thick", "Thick (Instant Client)"),
@@ -43,11 +58,26 @@ def _oracle_thick_mode_enabled(values: dict) -> bool:
4358
group="server_port",
4459
),
4560
_port_field("1521"),
61+
SchemaField(
62+
name="oracle_connection_type",
63+
label="Connection Type",
64+
field_type=FieldType.DROPDOWN,
65+
options=_get_oracle_connection_type_options(),
66+
default="service_name",
67+
),
4668
SchemaField(
4769
name="database",
4870
label="Service Name",
71+
placeholder="ORCL or XEPDB1",
72+
required=True,
73+
visible_when=_oracle_connection_type_is_service_name,
74+
),
75+
SchemaField(
76+
name="oracle_sid",
77+
label="SID",
4978
placeholder="ORCL",
5079
required=True,
80+
visible_when=_oracle_connection_type_is_sid,
5181
),
5282
_username_field(),
5383
_password_field(),

sqlit/domains/explorer/app/schema_service.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,11 @@ def list_folder_items(self, folder_type: str, database: str | None) -> list[Any]
9494
cache_key = database or "__default__"
9595
obj_cache = self.object_cache
9696

97-
def cached(key: str, loader: Callable[[], Any]) -> Any:
97+
def cached(key: str, loader: Callable[[], Any], *, allow_empty: bool = True) -> Any:
9898
if cache_key in obj_cache and key in obj_cache[cache_key]:
99-
return obj_cache[cache_key][key]
99+
data = obj_cache[cache_key][key]
100+
if allow_empty or data:
101+
return data
100102
data = loader()
101103
if cache_key not in obj_cache:
102104
obj_cache[cache_key] = {}
@@ -110,6 +112,7 @@ def cached(key: str, loader: Callable[[], Any]) -> Any:
110112
lambda: inspector.get_tables(self.session.connection, db_arg),
111113
database,
112114
),
115+
allow_empty=self.session.provider.metadata.db_type != "duckdb",
113116
)
114117
return [("table", schema, name) for schema, name in raw_data]
115118
if folder_type == "views":
@@ -119,6 +122,7 @@ def cached(key: str, loader: Callable[[], Any]) -> Any:
119122
lambda: inspector.get_views(self.session.connection, db_arg),
120123
database,
121124
),
125+
allow_empty=self.session.provider.metadata.db_type != "duckdb",
122126
)
123127
return [("view", schema, name) for schema, name in raw_data]
124128
if folder_type == "databases":

sqlit/domains/explorer/ui/mixins/tree.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,21 @@ def on_tree_node_highlighted(self: TreeMixinHost, event: Tree.NodeHighlighted) -
210210

211211
def action_refresh_tree(self: TreeMixinHost) -> None:
212212
"""Refresh the explorer."""
213+
self._refresh_tree_common(notify=True)
214+
215+
def _refresh_tree_after_schema_change(self: TreeMixinHost) -> None:
216+
"""Refresh tree after DDL without showing a notification."""
217+
self._refresh_tree_common(notify=False)
218+
219+
def _refresh_tree_common(self: TreeMixinHost, *, notify: bool) -> None:
213220
self._get_object_cache().clear()
214-
if hasattr(self, "_schema_cache") and "columns" in self._schema_cache:
221+
if hasattr(self, "_schema_cache") and isinstance(self._schema_cache, dict):
215222
self._schema_cache["columns"] = {}
223+
self._schema_cache["tables"] = []
224+
self._schema_cache["views"] = []
225+
self._schema_cache["procedures"] = []
226+
if hasattr(self, "_db_object_cache"):
227+
self._db_object_cache = {}
216228
if hasattr(self, "_loading_nodes"):
217229
self._loading_nodes.clear()
218230
self._schema_service = None
@@ -253,7 +265,8 @@ def run_loader() -> None:
253265
)
254266
else:
255267
self._schedule_timer(MIN_TIMER_DELAY_S, run_loader)
256-
self.notify("Refreshed")
268+
if notify:
269+
self.notify("Refreshed")
257270

258271
def refresh_tree(self: TreeMixinHost) -> None:
259272
tree_builder.refresh_tree_chunked(self)

sqlit/domains/query/ui/mixins/query_execution.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import re
6+
57
from typing import TYPE_CHECKING, Any, Callable
68

79
from sqlit.domains.explorer.ui.tree import db_switching as tree_db_switching
@@ -21,6 +23,19 @@
2123
from sqlit.domains.query.app.transaction import TransactionExecutor
2224

2325

26+
_SCHEMA_CHANGE_RE = re.compile(
27+
r"\b(create|alter|drop|truncate|rename|comment|grant|revoke)\b",
28+
re.IGNORECASE,
29+
)
30+
_SQL_COMMENT_RE = re.compile(r"(--[^\n]*|/\*.*?\*/)", re.DOTALL)
31+
_SQL_LITERAL_RE = re.compile(r"('([^']|'')*'|\"([^\"]|\"\")*\"|`[^`]*`|\[[^\]]*\])", re.DOTALL)
32+
33+
34+
def _strip_sql_comments_and_literals(sql: str) -> str:
35+
sql = _SQL_COMMENT_RE.sub(" ", sql)
36+
return _SQL_LITERAL_RE.sub(" ", sql)
37+
38+
2439
class QueryExecutionMixin(ProcessWorkerLifecycleMixin):
2540
"""Mixin providing query execution actions."""
2641

@@ -216,6 +231,21 @@ def _on_result(confirmed: bool | None) -> None:
216231
_on_result,
217232
)
218233

234+
def _query_changes_schema(self: QueryMixinHost, query: str) -> bool:
235+
cleaned = _strip_sql_comments_and_literals(query)
236+
return bool(_SCHEMA_CHANGE_RE.search(cleaned))
237+
238+
def _maybe_refresh_explorer_after_query(self: QueryMixinHost, query: str) -> None:
239+
if not self._query_changes_schema(query):
240+
return
241+
refresh = getattr(self, "_refresh_tree_after_schema_change", None)
242+
if callable(refresh):
243+
refresh()
244+
return
245+
action = getattr(self, "action_refresh_tree", None)
246+
if callable(action):
247+
action()
248+
219249
def _start_query_spinner(self: QueryMixinHost) -> None:
220250
"""Start the query execution spinner animation."""
221251
import time
@@ -481,6 +511,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
481511
)
482512
else:
483513
self._display_non_query_result(result.rows_affected, elapsed_ms)
514+
self._maybe_refresh_explorer_after_query(query)
484515
if keep_insert_mode:
485516
self._restore_insert_mode()
486517
return
@@ -500,6 +531,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
500531
except Exception:
501532
pass
502533
self._display_multi_statement_results(multi_result, elapsed_ms)
534+
self._maybe_refresh_explorer_after_query(query)
503535
else:
504536
# Single statement - existing behavior
505537
result = await asyncio.to_thread(
@@ -520,6 +552,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
520552
)
521553
else:
522554
self._display_non_query_result(result.rows_affected, elapsed_ms)
555+
self._maybe_refresh_explorer_after_query(query)
523556

524557
if keep_insert_mode:
525558
self._restore_insert_mode()
@@ -584,14 +617,17 @@ async def _run_query_atomic_async(self: QueryMixinHost, query: str) -> None:
584617
self.notify("Transaction rolled back (error in statement)", severity="error")
585618
else:
586619
self.notify("Query executed atomically (committed)", severity="information")
620+
self._maybe_refresh_explorer_after_query(query)
587621
elif isinstance(result, QueryResult):
588622
await self._display_query_results(
589623
result.columns, result.rows, result.row_count, result.truncated, elapsed_ms
590624
)
591625
self.notify("Query executed atomically (committed)", severity="information")
626+
self._maybe_refresh_explorer_after_query(query)
592627
else:
593628
self._display_non_query_result(result.rows_affected, elapsed_ms)
594629
self.notify("Query executed atomically (committed)", severity="information")
630+
self._maybe_refresh_explorer_after_query(query)
595631

596632
except Exception as e:
597633
self._display_query_error(f"Transaction rolled back: {e}")

0 commit comments

Comments
 (0)