Skip to content

Commit 70e4cd0

Browse files
committed
feat: pass extra_options to database drivers
- Add extra_options pass-through to all adapters, allowing users to pass custom driver parameters via connections.json or CLI URLs - Add Snowflake authentication dropdown with support for: - Username & Password (default) - SSO (Browser) - Key Pair (JWT) - OAuth Token - Add conditional fields for private key file and password when JWT is selected
1 parent dd3d4ca commit 70e4cd0

21 files changed

Lines changed: 290 additions & 13 deletions

File tree

sqlit/domains/connections/providers/clickhouse/adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def connect(self, config: ConnectionConfig) -> Any:
141141
if tls_mode != TLS_MODE_DEFAULT:
142142
connect_args["verify"] = tls_mode != TLS_MODE_REQUIRE
143143

144+
connect_args.update(config.extra_options)
144145
client = clickhouse_connect.get_client(**connect_args)
145146
return client
146147

sqlit/domains/connections/providers/cockroachdb/adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def connect(self, config: ConnectionConfig) -> Any:
8787
if tls_key_password:
8888
connect_args["sslpassword"] = tls_key_password
8989

90+
connect_args.update(config.extra_options)
9091
conn = psycopg2.connect(**connect_args)
9192
# Enable autocommit to avoid transaction issues
9293
conn.autocommit = True

sqlit/domains/connections/providers/db2/adapter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def connect(self, config: ConnectionConfig) -> Any:
8181
f"UID={endpoint.username};"
8282
f"PWD={endpoint.password};"
8383
)
84-
return ibm_db_dbi.connect(conn_str, "", "")
84+
connect_args: dict[str, Any] = {}
85+
connect_args.update(config.extra_options)
86+
return ibm_db_dbi.connect(conn_str, "", "", **connect_args)
8587

8688
def get_databases(self, conn: Any) -> list[str]:
8789
return []

sqlit/domains/connections/providers/duckdb/adapter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def connect(self, config: ConnectionConfig) -> Any:
8383
raise ValueError("DuckDB connections require a file endpoint.")
8484
file_path = resolve_file_path(str(file_endpoint.path))
8585
duckdb_any: Any = duckdb
86-
return duckdb_any.connect(str(file_path))
86+
connect_args: dict[str, Any] = {}
87+
connect_args.update(config.extra_options)
88+
return duckdb_any.connect(str(file_path), **connect_args)
8789

8890
def get_databases(self, conn: Any) -> list[str]:
8991
"""DuckDB doesn't support multiple databases - return empty list."""

sqlit/domains/connections/providers/firebird/adapter.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,15 @@ def connect(self, config: "ConnectionConfig") -> Any:
7272
endpoint = config.tcp_endpoint
7373
if endpoint is None:
7474
raise ValueError("Firebird connections require a TCP-style endpoint.")
75-
conn = firebirdsql.connect(
76-
host=endpoint.host or "localhost",
77-
port=int(endpoint.port) if endpoint.port else 3050,
78-
database=endpoint.database or "security.db",
79-
user=endpoint.username,
80-
password=endpoint.password,
81-
)
75+
connect_args: dict[str, Any] = {
76+
"host": endpoint.host or "localhost",
77+
"port": int(endpoint.port) if endpoint.port else 3050,
78+
"database": endpoint.database or "security.db",
79+
"user": endpoint.username,
80+
"password": endpoint.password,
81+
}
82+
connect_args.update(config.extra_options)
83+
conn = firebirdsql.connect(**connect_args)
8284
return conn
8385

8486
def get_databases(self, conn: Any) -> list[str]:

sqlit/domains/connections/providers/flight/adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def connect(self, config: ConnectionConfig) -> Any:
130130
if use_tls and config.get_option("flight_skip_verify", "false") == "true":
131131
db_kwargs["adbc.flight.sql.client_option.tls_skip_verify"] = "true"
132132

133+
db_kwargs.update(config.extra_options)
133134
conn = flight_sql.connect(uri, db_kwargs=db_kwargs)
134135

135136
# Store the catalog/database for later use

sqlit/domains/connections/providers/hana/adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def connect(self, config: ConnectionConfig) -> Any:
8282
if schema:
8383
connect_args["currentSchema"] = schema
8484

85+
connect_args.update(config.extra_options)
8586
return hdbcli.connect(**connect_args)
8687

8788
def get_databases(self, conn: Any) -> list[str]:

sqlit/domains/connections/providers/mariadb/adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def connect(self, config: ConnectionConfig) -> Any:
100100
connect_args["ssl_verify_cert"] = tls_mode_verifies_cert(tls_mode)
101101
connect_args["ssl_verify_identity"] = tls_mode_verifies_hostname(tls_mode)
102102

103+
connect_args.update(config.extra_options)
103104
conn = mariadb_any.connect(**connect_args)
104105
self._supports_sequences = self._detect_sequences_support(conn)
105106
return conn

sqlit/domains/connections/providers/mssql/adapter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ def connect(self, config: ConnectionConfig) -> Any:
180180
)
181181

182182
conn_str = self._build_connection_string(config)
183+
# Append extra_options to connection string
184+
for key, value in config.extra_options.items():
185+
conn_str += f"{key}={value};"
183186
return mssql_python.connect(conn_str)
184187

185188
def get_databases(self, conn: Any) -> list[str]:

sqlit/domains/connections/providers/mysql/adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,5 @@ def connect(self, config: ConnectionConfig) -> Any:
110110
ssl_params["check_hostname"] = tls_mode_verifies_hostname(tls_mode)
111111
connect_args["ssl"] = ssl_params
112112

113+
connect_args.update(config.extra_options)
113114
return pymysql.connect(**connect_args)

0 commit comments

Comments
 (0)