Skip to content

Commit 87566dd

Browse files
committed
Fix osquery is_file_based and add SurrealDB SSL support
Set osquery is_file_based=False since it has no file_path field and the True value breaks connection validation. Add use_ssl option to SurrealDB to support wss:// connections.
1 parent 9f9b362 commit 87566dd

File tree

4 files changed

+427
-0
lines changed

4 files changed

+427
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Provider registration for osquery."""
2+
3+
from sqlit.domains.connections.providers.adapter_provider import build_adapter_provider
4+
from sqlit.domains.connections.providers.catalog import register_provider
5+
from sqlit.domains.connections.providers.model import DatabaseProvider, ProviderSpec
6+
from sqlit.domains.connections.providers.osquery.schema import SCHEMA
7+
8+
9+
def _provider_factory(spec: ProviderSpec) -> DatabaseProvider:
10+
from sqlit.domains.connections.providers.osquery.adapter import OsqueryAdapter
11+
12+
return build_adapter_provider(spec, SCHEMA, OsqueryAdapter())
13+
14+
15+
SPEC = ProviderSpec(
16+
db_type="osquery",
17+
display_name="osquery",
18+
schema_path=("sqlit.domains.connections.providers.osquery.schema", "SCHEMA"),
19+
supports_ssh=False,
20+
is_file_based=False,
21+
has_advanced_auth=False,
22+
default_port="",
23+
requires_auth=False,
24+
badge_label="osq",
25+
url_schemes=("osquery",),
26+
provider_factory=_provider_factory,
27+
)
28+
29+
register_provider(SPEC)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Connection schema for osquery."""
2+
3+
from sqlit.domains.connections.providers.schema_helpers import (
4+
ConnectionSchema,
5+
FieldType,
6+
SchemaField,
7+
SelectOption,
8+
)
9+
10+
11+
def _connection_mode_is_socket(v: dict) -> bool:
12+
return v.get("connection_mode") == "socket"
13+
14+
15+
SCHEMA = ConnectionSchema(
16+
db_type="osquery",
17+
display_name="osquery",
18+
fields=(
19+
SchemaField(
20+
name="connection_mode",
21+
label="Connection Mode",
22+
field_type=FieldType.SELECT,
23+
options=(
24+
SelectOption("spawn", "Spawn Instance (embedded)"),
25+
SelectOption("socket", "Connect to Socket"),
26+
),
27+
default="spawn",
28+
),
29+
SchemaField(
30+
name="socket_path",
31+
label="Socket Path",
32+
placeholder="/var/osquery/osquery.em",
33+
required=False,
34+
visible_when=_connection_mode_is_socket,
35+
description="Path to osqueryd extension socket",
36+
),
37+
),
38+
supports_ssh=False,
39+
is_file_based=False,
40+
default_port="",
41+
requires_auth=False,
42+
)
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
"""SurrealDB adapter using surrealdb.py SDK."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any
6+
7+
from sqlit.domains.connections.providers.adapters.base import (
8+
ColumnInfo,
9+
DatabaseAdapter,
10+
IndexInfo,
11+
SequenceInfo,
12+
TableInfo,
13+
TriggerInfo,
14+
)
15+
from sqlit.domains.connections.providers.registry import get_default_port
16+
17+
if TYPE_CHECKING:
18+
from sqlit.domains.connections.domain.config import ConnectionConfig
19+
20+
21+
class SurrealDBAdapter(DatabaseAdapter):
22+
"""Adapter for SurrealDB using the official Python SDK.
23+
24+
SurrealDB is a multi-model database that uses SurrealQL,
25+
a query language similar to SQL but with some differences.
26+
"""
27+
28+
@property
29+
def name(self) -> str:
30+
return "SurrealDB"
31+
32+
@property
33+
def install_extra(self) -> str:
34+
return "surrealdb"
35+
36+
@property
37+
def install_package(self) -> str:
38+
return "surrealdb"
39+
40+
@property
41+
def driver_import_names(self) -> tuple[str, ...]:
42+
return ("surrealdb",)
43+
44+
@property
45+
def supports_multiple_databases(self) -> bool:
46+
return True # Namespace/database hierarchy
47+
48+
@property
49+
def supports_cross_database_queries(self) -> bool:
50+
return False # Must use() a specific database
51+
52+
@property
53+
def supports_stored_procedures(self) -> bool:
54+
return False
55+
56+
@property
57+
def supports_indexes(self) -> bool:
58+
return True
59+
60+
@property
61+
def supports_triggers(self) -> bool:
62+
return False
63+
64+
@property
65+
def supports_sequences(self) -> bool:
66+
return False
67+
68+
@property
69+
def supports_process_worker(self) -> bool:
70+
# WebSocket connections may not work well across process boundaries
71+
return False
72+
73+
@property
74+
def default_schema(self) -> str:
75+
return ""
76+
77+
@property
78+
def test_query(self) -> str:
79+
return "RETURN 1"
80+
81+
def connect(self, config: ConnectionConfig) -> Any:
82+
surrealdb_module = self._import_driver_module(
83+
"surrealdb",
84+
driver_name=self.name,
85+
extra_name=self.install_extra,
86+
package_name=self.install_package,
87+
)
88+
89+
endpoint = config.tcp_endpoint
90+
if endpoint is None:
91+
raise ValueError("SurrealDB connections require a TCP-style endpoint.")
92+
port = int(endpoint.port or get_default_port("surrealdb"))
93+
94+
# Build WebSocket URL
95+
use_ssl = str(config.get_option("use_ssl", "false")).lower() == "true"
96+
scheme = "wss" if use_ssl else "ws"
97+
url = f"{scheme}://{endpoint.host}:{port}/rpc"
98+
99+
# Create sync connection
100+
db = surrealdb_module.Surreal(url)
101+
db.connect()
102+
103+
# Sign in if credentials provided
104+
if endpoint.username and endpoint.password:
105+
db.signin({"user": endpoint.username, "pass": endpoint.password})
106+
107+
# Select namespace and database
108+
namespace = config.get_option("namespace", "test")
109+
database = endpoint.database or config.get_option("database", "test")
110+
db.use(namespace, database)
111+
112+
return db
113+
114+
def disconnect(self, conn: Any) -> None:
115+
if hasattr(conn, "close"):
116+
conn.close()
117+
118+
def execute_test_query(self, conn: Any) -> None:
119+
"""Execute a simple query to verify the connection works."""
120+
result = conn.query("RETURN 1")
121+
if not result:
122+
raise Exception("SurrealDB test query returned no result")
123+
124+
def get_databases(self, conn: Any) -> list[str]:
125+
"""Get list of databases in the current namespace."""
126+
try:
127+
result = conn.query("INFO FOR NS")
128+
if result and isinstance(result, list) and result[0]:
129+
info = result[0]
130+
if isinstance(info, dict) and "databases" in info:
131+
return list(info["databases"].keys())
132+
except Exception:
133+
pass
134+
return []
135+
136+
def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
137+
"""Get list of tables in the current database."""
138+
try:
139+
result = conn.query("INFO FOR DB")
140+
if result and isinstance(result, list) and result[0]:
141+
info = result[0]
142+
if isinstance(info, dict) and "tables" in info:
143+
tables = list(info["tables"].keys())
144+
return [("", t) for t in sorted(tables)]
145+
except Exception:
146+
pass
147+
return []
148+
149+
def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
150+
# SurrealDB doesn't have traditional views
151+
return []
152+
153+
def get_columns(
154+
self, conn: Any, table: str, database: str | None = None, schema: str | None = None
155+
) -> list[ColumnInfo]:
156+
"""Get column information for a table.
157+
158+
SurrealDB is schemaless by default, so we sample records to infer columns.
159+
If a schema is defined, we use INFO FOR TABLE.
160+
"""
161+
columns: list[ColumnInfo] = []
162+
163+
try:
164+
# First try to get schema info
165+
result = conn.query(f"INFO FOR TABLE {table}")
166+
if result and isinstance(result, list) and result[0]:
167+
info = result[0]
168+
if isinstance(info, dict) and "fields" in info:
169+
for field_name, field_def in info["fields"].items():
170+
# field_def might contain type info
171+
data_type = "any"
172+
if isinstance(field_def, dict) and "type" in field_def:
173+
data_type = str(field_def["type"])
174+
elif isinstance(field_def, str):
175+
# Try to extract type from definition string
176+
data_type = field_def.split()[0] if field_def else "any"
177+
columns.append(ColumnInfo(name=field_name, data_type=data_type))
178+
179+
# If no schema fields, sample data to infer columns
180+
if not columns:
181+
sample = conn.query(f"SELECT * FROM {table} LIMIT 1")
182+
if sample and isinstance(sample, list) and sample[0]:
183+
first_row = sample[0]
184+
if isinstance(first_row, dict):
185+
for key in first_row.keys():
186+
if key != "id": # id is always present
187+
value = first_row[key]
188+
data_type = type(value).__name__ if value is not None else "any"
189+
columns.append(ColumnInfo(name=key, data_type=data_type))
190+
# Add id column first
191+
columns.insert(0, ColumnInfo(name="id", data_type="record"))
192+
except Exception:
193+
pass
194+
195+
return columns
196+
197+
def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
198+
return []
199+
200+
def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]:
201+
"""Get list of indexes across all tables."""
202+
indexes: list[IndexInfo] = []
203+
try:
204+
result = conn.query("INFO FOR DB")
205+
if result and isinstance(result, list) and result[0]:
206+
info = result[0]
207+
if isinstance(info, dict) and "tables" in info:
208+
for table_name in info["tables"].keys():
209+
table_info = conn.query(f"INFO FOR TABLE {table_name}")
210+
if table_info and isinstance(table_info, list) and table_info[0]:
211+
t_info = table_info[0]
212+
if isinstance(t_info, dict) and "indexes" in t_info:
213+
for idx_name, idx_def in t_info["indexes"].items():
214+
is_unique = "UNIQUE" in str(idx_def).upper() if idx_def else False
215+
indexes.append(IndexInfo(
216+
name=idx_name,
217+
table_name=table_name,
218+
is_unique=is_unique
219+
))
220+
except Exception:
221+
pass
222+
return indexes
223+
224+
def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]:
225+
return []
226+
227+
def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
228+
return []
229+
230+
def quote_identifier(self, name: str) -> str:
231+
# SurrealDB uses backticks for identifiers with special characters
232+
if any(c in name for c in " -./"):
233+
escaped = name.replace("`", "``")
234+
return f"`{escaped}`"
235+
return name
236+
237+
def build_select_query(
238+
self, table: str, limit: int, database: str | None = None, schema: str | None = None
239+
) -> str:
240+
return f"SELECT * FROM {self.quote_identifier(table)} LIMIT {limit}"
241+
242+
def execute_query(
243+
self, conn: Any, query: str, max_rows: int | None = None
244+
) -> tuple[list[str], list[tuple], bool]:
245+
"""Execute a query and return (columns, rows, truncated)."""
246+
result = conn.query(query)
247+
248+
if not result:
249+
return [], [], False
250+
251+
# SurrealDB returns a list of results (one per statement)
252+
# For a single query, we take the first result
253+
data = result[0] if isinstance(result, list) else result
254+
255+
# Handle single value returns (like RETURN 1)
256+
if not isinstance(data, (list, dict)):
257+
return ["result"], [(data,)], False
258+
259+
# Handle empty results
260+
if isinstance(data, list) and not data:
261+
return [], [], False
262+
263+
# Handle list of records
264+
if isinstance(data, list):
265+
if not data:
266+
return [], [], False
267+
first = data[0]
268+
if isinstance(first, dict):
269+
columns = list(first.keys())
270+
all_rows = [tuple(row.get(col) for col in columns) for row in data]
271+
if max_rows is not None and len(all_rows) > max_rows:
272+
return columns, all_rows[:max_rows], True
273+
return columns, all_rows, False
274+
# List of non-dict values
275+
rows = [(v,) for v in (data[:max_rows] if max_rows else data)]
276+
truncated = max_rows is not None and len(data) > max_rows
277+
return ["value"], rows, truncated
278+
279+
# Handle single dict
280+
if isinstance(data, dict):
281+
columns = list(data.keys())
282+
return columns, [tuple(data.values())], False
283+
284+
return [], [], False
285+
286+
def execute_non_query(self, conn: Any, query: str) -> int:
287+
"""Execute a non-query statement."""
288+
result = conn.query(query)
289+
# SurrealDB doesn't return row counts in the traditional sense
290+
# Return 1 if operation succeeded
291+
if result is not None:
292+
if isinstance(result, list) and result:
293+
data = result[0]
294+
if isinstance(data, list):
295+
return len(data)
296+
return 1
297+
return 0
298+
299+
def classify_query(self, query: str) -> bool:
300+
"""Return True if the query is expected to return rows."""
301+
query_type = query.strip().upper().split()[0] if query.strip() else ""
302+
# SurrealQL query types that return data
303+
return query_type in {
304+
"SELECT", "RETURN", "INFO", "SHOW", "LIVE",
305+
"CREATE", "INSERT", "UPDATE", "UPSERT", "DELETE" # These also return the affected records
306+
}

0 commit comments

Comments
 (0)