Skip to content

Commit eb7c03e

Browse files
committed
Fix osquery is_file_based and add SurrealDB SSL support
Set osquery is_file_based=False since it has no file_path field and the True value breaks connection validation. Add use_ssl option to SurrealDB to support wss:// connections.
1 parent 9f9b362 commit eb7c03e

16 files changed

Lines changed: 973 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ all = [
5959
"redshift-connector",
6060
"pyathena>=3.22.0",
6161
"adbc-driver-flightsql>=1.0.0",
62+
"impyla>=0.18.0",
63+
"osquery>=3.0.0",
64+
"surrealdb>=0.3.0",
6265
]
6366
postgres = ["psycopg2-binary>=2.9.0"]
6467
cockroachdb = ["psycopg2-binary>=2.9.0"]
@@ -81,6 +84,9 @@ firebird = ["firebirdsql>=1.3.5"]
8184
snowflake = ["snowflake-connector-python>=3.7.0"]
8285
athena = ["pyathena>=3.22.0"]
8386
flight = ["adbc-driver-flightsql>=1.0.0"]
87+
impala = ["impyla>=0.18.0"]
88+
osquery = ["osquery>=3.0.0"]
89+
surrealdb = ["surrealdb>=0.3.0"]
8490
ssh = [
8591
"sshtunnel>=0.4.0",
8692
"paramiko>=2.0.0,<4.0.0",
@@ -241,7 +247,11 @@ module = [
241247
"adbc_driver_flightsql",
242248
"adbc_driver_flightsql.dbapi",
243249
"adbc_driver_manager",
244-
"textual_fastdatatable"
250+
"textual_fastdatatable",
251+
"impala",
252+
"impala.dbapi",
253+
"osquery",
254+
"surrealdb"
245255
]
246256
ignore_missing_imports = true
247257

sqlit/domains/connections/domain/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,21 @@ class DatabaseType(str, Enum):
1919
FIREBIRD = "firebird"
2020
FLIGHT = "flight"
2121
HANA = "hana"
22+
IMPALA = "impala"
2223
MARIADB = "mariadb"
2324
MOTHERDUCK = "motherduck"
2425
MSSQL = "mssql"
2526
MYSQL = "mysql"
2627
ORACLE = "oracle"
2728
ORACLE_LEGACY = "oracle_legacy"
29+
OSQUERY = "osquery"
2830
POSTGRESQL = "postgresql"
2931
PRESTO = "presto"
3032
REDSHIFT = "redshift"
3133
SNOWFLAKE = "snowflake"
3234
SQLITE = "sqlite"
3335
SUPABASE = "supabase"
36+
SURREALDB = "surrealdb"
3437
TERADATA = "teradata"
3538
TRINO = "trino"
3639
TURSO = "turso"
@@ -52,6 +55,7 @@ class DatabaseType(str, Enum):
5255
DatabaseType.BIGQUERY,
5356
DatabaseType.TRINO,
5457
DatabaseType.PRESTO,
58+
DatabaseType.IMPALA,
5559
DatabaseType.DUCKDB,
5660
DatabaseType.MOTHERDUCK,
5761
DatabaseType.REDSHIFT,
@@ -63,6 +67,8 @@ class DatabaseType(str, Enum):
6367
DatabaseType.ATHENA,
6468
DatabaseType.FIREBIRD,
6569
DatabaseType.FLIGHT,
70+
DatabaseType.OSQUERY,
71+
DatabaseType.SURREALDB,
6672
]
6773

6874

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Impala provider package."""
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""Impala adapter using impyla."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any
6+
7+
from sqlit.domains.connections.providers.adapters.base import (
8+
ColumnInfo,
9+
CursorBasedAdapter,
10+
IndexInfo,
11+
SequenceInfo,
12+
TableInfo,
13+
TriggerInfo,
14+
)
15+
from sqlit.domains.connections.providers.registry import get_default_port
16+
17+
if TYPE_CHECKING:
18+
from sqlit.domains.connections.domain.config import ConnectionConfig
19+
20+
21+
class ImpalaAdapter(CursorBasedAdapter):
    """Adapter for Apache Impala using impyla.

    Metadata is read with Impala's ``SHOW`` / ``DESCRIBE`` statements.
    Identifiers are always quoted through :meth:`quote_identifier`, and every
    cursor opened for a metadata lookup is closed once its rows are fetched.
    """

    @property
    def name(self) -> str:
        """Display name of the database engine."""
        return "Impala"

    @property
    def install_extra(self) -> str:
        """Name of the optional-dependency extra that provides the driver."""
        return "impala"

    @property
    def install_package(self) -> str:
        """Name of the distribution that provides the driver."""
        return "impyla"

    @property
    def driver_import_names(self) -> tuple[str, ...]:
        """Importable module names belonging to the driver."""
        return ("impala.dbapi",)

    @property
    def supports_multiple_databases(self) -> bool:
        return True

    @property
    def supports_cross_database_queries(self) -> bool:
        return True

    @property
    def supports_stored_procedures(self) -> bool:
        return False

    @property
    def supports_indexes(self) -> bool:
        return False  # Impala uses partitions, not traditional indexes

    @property
    def supports_triggers(self) -> bool:
        return False

    @property
    def supports_sequences(self) -> bool:
        return False

    @property
    def system_databases(self) -> frozenset[str]:
        """Databases treated as system databases."""
        return frozenset({"_impala_builtins"})

    @property
    def default_schema(self) -> str:
        """Default schema name; empty for this adapter."""
        return ""

    def apply_database_override(self, config: ConnectionConfig, database: str) -> ConnectionConfig:
        """Return *config* with *database* as the default for unqualified queries.

        An empty *database* leaves *config* untouched.
        """
        if not database:
            return config
        return config.with_endpoint(database=database)

    def connect(self, config: ConnectionConfig) -> Any:
        """Open an impyla connection described by *config*.

        Raises:
            ValueError: if *config* has no TCP endpoint.
        """
        impala_dbapi = self._import_driver_module(
            "impala.dbapi",
            driver_name=self.name,
            extra_name=self.install_extra,
            package_name=self.install_package,
        )

        endpoint = config.tcp_endpoint
        if endpoint is None:
            raise ValueError("Impala connections require a TCP-style endpoint.")
        port = int(endpoint.port or get_default_port("impala"))

        auth_mechanism = str(config.get_option("auth_mechanism", "NOSASL"))
        # The option is compared as a string; only "true" (any case) enables SSL.
        use_ssl = str(config.get_option("use_ssl", "false")).lower() == "true"

        connect_args: dict[str, Any] = {
            "host": endpoint.host,
            "port": port,
            "auth_mechanism": auth_mechanism,
            "use_ssl": use_ssl,
        }

        if endpoint.database:
            connect_args["database"] = endpoint.database

        # A password is only forwarded together with a username.
        if endpoint.username:
            connect_args["user"] = endpoint.username
            if endpoint.password:
                connect_args["password"] = endpoint.password

        # Applied last so explicit extra options override the values above.
        connect_args.update(config.extra_options)
        return impala_dbapi.connect(**connect_args)

    def _fetchall_and_close(self, conn: Any, sql: str) -> list[Any]:
        """Execute *sql* on a fresh cursor, return all rows, and close the cursor.

        The cursor is closed even when ``execute``/``fetchall`` raises, so
        metadata lookups do not leave cursors open on the connection.
        """
        cursor = conn.cursor()
        try:
            cursor.execute(sql)
            return list(cursor.fetchall())
        finally:
            cursor.close()

    def get_databases(self, conn: Any) -> list[str]:
        """Return the first column of every ``SHOW DATABASES`` row."""
        return [row[0] for row in self._fetchall_and_close(conn, "SHOW DATABASES")]

    def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """List tables in *database*, or in the connection's current database.

        Each entry is ``("", table_name)``; the first element is always empty.
        """
        if database:
            sql = f"SHOW TABLES IN {self.quote_identifier(database)}"
        else:
            sql = "SHOW TABLES"
        return [("", row[0]) for row in self._fetchall_and_close(conn, sql)]

    def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """List views via ``information_schema``, or ``[]`` if the lookup fails.

        This is deliberately best-effort: any error from the query results in
        an empty list instead of propagating.
        """
        # Impala doesn't distinguish views in SHOW TABLES by default
        # We can query from information_schema if available
        if database:
            source = f"{self.quote_identifier(database)}.information_schema.tables"
        else:
            source = "information_schema.tables"
        sql = f"SELECT table_name FROM {source} WHERE table_type = 'VIEW' ORDER BY table_name"
        try:
            return [("", row[0]) for row in self._fetchall_and_close(conn, sql)]
        except Exception:
            # information_schema might not be available
            return []

    def get_columns(
        self, conn: Any, table: str, database: str | None = None, schema: str | None = None
    ) -> list[ColumnInfo]:
        """Describe *table* and return one ``ColumnInfo`` per result row.

        The first two columns of each ``DESCRIBE`` row are used as the column
        name and data type. *schema* is accepted but not used.
        """
        target = self.quote_identifier(table)
        if database:
            target = f"{self.quote_identifier(database)}.{target}"
        rows = self._fetchall_and_close(conn, f"DESCRIBE {target}")
        return [ColumnInfo(name=row[0], data_type=row[1]) for row in rows]

    def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
        """Always empty; see ``supports_stored_procedures``."""
        return []

    def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]:
        """Always empty; see ``supports_indexes``."""
        return []

    def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]:
        """Always empty; see ``supports_triggers``."""
        return []

    def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
        """Always empty; see ``supports_sequences``."""
        return []

    def quote_identifier(self, name: str) -> str:
        """Wrap *name* in backticks, doubling any backtick it contains."""
        escaped = name.replace("`", "``")
        return f"`{escaped}`"

    def build_select_query(
        self, table: str, limit: int, database: str | None = None, schema: str | None = None
    ) -> str:
        """Build ``SELECT * FROM <table> LIMIT <limit>``.

        The table (and *database*, when given) is quoted with
        :meth:`quote_identifier` so names containing backticks are escaped the
        same way as in the metadata queries. *schema* is accepted but not used.
        """
        table_ref = self.quote_identifier(table)
        if database:
            table_ref = f"{self.quote_identifier(database)}.{table_ref}"
        return f"SELECT * FROM {table_ref} LIMIT {limit}"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Provider registration for Impala."""
2+
3+
from sqlit.domains.connections.providers.adapter_provider import build_adapter_provider
4+
from sqlit.domains.connections.providers.catalog import register_provider
5+
from sqlit.domains.connections.providers.impala.schema import SCHEMA
6+
from sqlit.domains.connections.providers.model import DatabaseProvider, ProviderSpec
7+
8+
9+
def _provider_factory(spec: ProviderSpec) -> DatabaseProvider:
    """Create the Impala ``DatabaseProvider`` for *spec*.

    A new ``ImpalaAdapter`` is instantiated on each call and combined with
    the module-level ``SCHEMA`` through ``build_adapter_provider``.
    """
    # Imported inside the function, so the adapter module is only loaded
    # when a provider is actually built.
    from sqlit.domains.connections.providers.impala.adapter import ImpalaAdapter

    adapter = ImpalaAdapter()
    return build_adapter_provider(spec, SCHEMA, adapter)
13+
14+
15+
# Static description of the Impala provider.
SPEC = ProviderSpec(
    db_type="impala",
    display_name="Impala",
    # (module path, attribute name) pair naming the SCHEMA object imported above.
    # NOTE(review): presumably used to resolve the schema by name -- confirm.
    schema_path=("sqlit.domains.connections.providers.impala.schema", "SCHEMA"),
    supports_ssh=True,
    is_file_based=False,
    has_advanced_auth=True,  # Kerberos support
    # Given as a string, not an int.
    default_port="21050",
    requires_auth=False,
    badge_label="Impala",
    url_schemes=("impala",),
    provider_factory=_provider_factory,
)

# Module-level side effect: importing this module registers the provider.
register_provider(SPEC)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Connection schema for Impala."""
2+
3+
from sqlit.domains.connections.providers.schema_helpers import (
4+
SSH_FIELDS,
5+
ConnectionSchema,
6+
FieldType,
7+
SchemaField,
8+
SelectOption,
9+
_password_field,
10+
_port_field,
11+
_server_field,
12+
_username_field,
13+
)
14+
15+
16+
def _get_auth_mechanism_options() -> tuple[SelectOption, ...]:
    """Return the selectable auth mechanisms, in display order."""
    # Each pair holds the two positional arguments passed to SelectOption.
    choices = (
        ("NOSASL", "No Auth"),
        ("PLAIN", "PLAIN (LDAP)"),
        ("GSSAPI", "Kerberos (GSSAPI)"),
    )
    return tuple(SelectOption(mechanism, caption) for mechanism, caption in choices)
22+
23+
24+
# Connection form definition for Impala: server, port, optional database and
# username, password, two advanced select fields, then the shared SSH fields.
SCHEMA = ConnectionSchema(
    db_type="impala",
    display_name="Impala",
    fields=(
        _server_field(),
        _port_field("21050"),
        SchemaField(
            name="database",
            label="Database",
            placeholder="default",
            required=False,
        ),
        _username_field(required=False),
        _password_field(),
        SchemaField(
            name="auth_mechanism",
            label="Auth Mechanism",
            field_type=FieldType.SELECT,
            options=_get_auth_mechanism_options(),
            default="NOSASL",
            advanced=True,
        ),
        # The SSL choice is carried as the strings "true"/"false", not booleans.
        SchemaField(
            name="use_ssl",
            label="Use SSL",
            field_type=FieldType.SELECT,
            options=(
                SelectOption("false", "No"),
                SelectOption("true", "Yes"),
            ),
            default="false",
            advanced=True,
        ),
    )
    # SSH_FIELDS is concatenated after the Impala-specific fields.
    + SSH_FIELDS,
    # Same value as the one passed to _port_field above.
    default_port="21050",
    requires_auth=False,
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""osquery provider package."""

0 commit comments

Comments
 (0)