Skip to content

Commit 82c11b2

Browse files
committed
Fix SurrealDB adapter for surrealdb SDK 1.0
The adapter was written against the 0.3 async SDK. Its blocking methods are synchronous in 1.0+, so drop the old `db.connect()` call, rename signin keys from `user`/`pass` to `username`/`password`, and stop unwrapping an outer list from query results (scalars for RETURN, dict for INFO, list[dict] for SELECT are already returned directly). Convert RecordID values to strings so CLI output stays readable. Pin surrealdb>=1.0.0 in the extras.
1 parent f8737bf commit 82c11b2

3 files changed

Lines changed: 1263 additions & 219 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ all = [
6262
"adbc-driver-flightsql>=1.0.0",
6363
"impyla>=0.18.0",
6464
"osquery>=3.0.0",
65-
"surrealdb>=0.3.0",
65+
"surrealdb>=1.0.0",
6666
]
6767
postgres = ["psycopg2-binary>=2.9.0"]
6868
cockroachdb = ["psycopg2-binary>=2.9.0"]
@@ -88,7 +88,7 @@ athena = ["pyathena>=3.22.0"]
8888
flight = ["adbc-driver-flightsql>=1.0.0"]
8989
impala = ["impyla>=0.18.0"]
9090
osquery = ["osquery>=3.0.0"]
91-
surrealdb = ["surrealdb>=0.3.0"]
91+
surrealdb = ["surrealdb>=1.0.0"]
9292
ssh = [
9393
"sshtunnel>=0.4.0",
9494
"paramiko>=2.0.0,<4.0.0",

sqlit/domains/connections/providers/surrealdb/adapter.py

Lines changed: 84 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@
1818
from sqlit.domains.connections.domain.config import ConnectionConfig
1919

2020

21+
def _to_plain(value: Any) -> Any:
22+
"""Convert SurrealDB SDK types (RecordID, etc.) into something printable."""
23+
if value is None or isinstance(value, (bool, int, float, str)):
24+
return value
25+
cls_name = type(value).__name__
26+
if cls_name == "RecordID":
27+
return str(value)
28+
return value
29+
30+
2131
class SurrealDBAdapter(DatabaseAdapter):
2232
"""Adapter for SurrealDB using the official Python SDK.
2333
@@ -96,15 +106,12 @@ def connect(self, config: ConnectionConfig) -> Any:
96106
scheme = "wss" if use_ssl else "ws"
97107
url = f"{scheme}://{endpoint.host}:{port}/rpc"
98108

99-
# Create sync connection
109+
# Surreal() is blocking in surrealdb>=1.0; opens the socket in __init__
100110
db = surrealdb_module.Surreal(url)
101-
db.connect()
102111

103-
# Sign in if credentials provided
104112
if endpoint.username and endpoint.password:
105-
db.signin({"user": endpoint.username, "pass": endpoint.password})
113+
db.signin({"username": endpoint.username, "password": endpoint.password})
106114

107-
# Select namespace and database
108115
namespace = config.get_option("namespace", "test")
109116
database = endpoint.database or config.get_option("database", "test")
110117
db.use(namespace, database)
@@ -117,31 +124,25 @@ def disconnect(self, conn: Any) -> None:
117124

118125
def execute_test_query(self, conn: Any) -> None:
119126
"""Execute a simple query to verify the connection works."""
120-
result = conn.query("RETURN 1")
121-
if not result:
122-
raise Exception("SurrealDB test query returned no result")
127+
# query() raises on error; a successful RETURN 1 returns the int 1.
128+
conn.query("RETURN 1")
123129

124130
def get_databases(self, conn: Any) -> list[str]:
125131
"""Get list of databases in the current namespace."""
126132
try:
127-
result = conn.query("INFO FOR NS")
128-
if result and isinstance(result, list) and result[0]:
129-
info = result[0]
130-
if isinstance(info, dict) and "databases" in info:
131-
return list(info["databases"].keys())
133+
info = conn.query("INFO FOR NS")
134+
if isinstance(info, dict) and "databases" in info:
135+
return list(info["databases"].keys())
132136
except Exception:
133137
pass
134138
return []
135139

136140
def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
137141
"""Get list of tables in the current database."""
138142
try:
139-
result = conn.query("INFO FOR DB")
140-
if result and isinstance(result, list) and result[0]:
141-
info = result[0]
142-
if isinstance(info, dict) and "tables" in info:
143-
tables = list(info["tables"].keys())
144-
return [("", t) for t in sorted(tables)]
143+
info = conn.query("INFO FOR DB")
144+
if isinstance(info, dict) and "tables" in info:
145+
return [("", t) for t in sorted(info["tables"].keys())]
145146
except Exception:
146147
pass
147148
return []
@@ -161,33 +162,40 @@ def get_columns(
161162
columns: list[ColumnInfo] = []
162163

163164
try:
164-
# First try to get schema info
165-
result = conn.query(f"INFO FOR TABLE {table}")
166-
if result and isinstance(result, list) and result[0]:
167-
info = result[0]
168-
if isinstance(info, dict) and "fields" in info:
169-
for field_name, field_def in info["fields"].items():
170-
# field_def might contain type info
171-
data_type = "any"
172-
if isinstance(field_def, dict) and "type" in field_def:
173-
data_type = str(field_def["type"])
174-
elif isinstance(field_def, str):
175-
# Try to extract type from definition string
176-
data_type = field_def.split()[0] if field_def else "any"
177-
columns.append(ColumnInfo(name=field_name, data_type=data_type))
178-
179-
# If no schema fields, sample data to infer columns
165+
info = conn.query(f"INFO FOR TABLE {table}")
166+
if isinstance(info, dict) and "fields" in info and info["fields"]:
167+
# SurrealDB's INFO FOR TABLE only lists explicitly-defined
168+
# fields, but every record implicitly has an `id`. Surface it
169+
# as the primary key column so consumers can detect it.
170+
columns.append(
171+
ColumnInfo(name="id", data_type="record", is_primary_key=True)
172+
)
173+
for field_name, field_def in info["fields"].items():
174+
data_type = "any"
175+
if isinstance(field_def, dict) and "type" in field_def:
176+
data_type = str(field_def["type"])
177+
elif isinstance(field_def, str):
178+
# e.g. "DEFINE FIELD name ON t1 TYPE string PERMISSIONS FULL"
179+
parts = field_def.split()
180+
if "TYPE" in parts:
181+
idx = parts.index("TYPE")
182+
if idx + 1 < len(parts):
183+
data_type = parts[idx + 1]
184+
elif parts:
185+
data_type = parts[0]
186+
columns.append(ColumnInfo(name=field_name, data_type=data_type))
187+
180188
if not columns:
181189
sample = conn.query(f"SELECT * FROM {table} LIMIT 1")
182-
if sample and isinstance(sample, list) and sample[0]:
190+
if isinstance(sample, list) and sample:
183191
first_row = sample[0]
184192
if isinstance(first_row, dict):
185193
for key in first_row.keys():
186-
if key != "id": # id is always present
187-
value = first_row[key]
188-
data_type = type(value).__name__ if value is not None else "any"
189-
columns.append(ColumnInfo(name=key, data_type=data_type))
190-
# Add id column first
194+
if key == "id":
195+
continue
196+
value = first_row[key]
197+
data_type = type(value).__name__ if value is not None else "any"
198+
columns.append(ColumnInfo(name=key, data_type=data_type))
191199
columns.insert(0, ColumnInfo(name="id", data_type="record"))
192200
except Exception:
193201
pass
@@ -201,22 +209,20 @@ def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]
201209
"""Get list of indexes across all tables."""
202210
indexes: list[IndexInfo] = []
203211
try:
204-
result = conn.query("INFO FOR DB")
205-
if result and isinstance(result, list) and result[0]:
206-
info = result[0]
207-
if isinstance(info, dict) and "tables" in info:
208-
for table_name in info["tables"].keys():
209-
table_info = conn.query(f"INFO FOR TABLE {table_name}")
210-
if table_info and isinstance(table_info, list) and table_info[0]:
211-
t_info = table_info[0]
212-
if isinstance(t_info, dict) and "indexes" in t_info:
213-
for idx_name, idx_def in t_info["indexes"].items():
214-
is_unique = "UNIQUE" in str(idx_def).upper() if idx_def else False
215-
indexes.append(IndexInfo(
216-
name=idx_name,
217-
table_name=table_name,
218-
is_unique=is_unique
219-
))
212+
info = conn.query("INFO FOR DB")
213+
if not (isinstance(info, dict) and "tables" in info):
214+
return []
215+
for table_name in info["tables"].keys():
216+
t_info = conn.query(f"INFO FOR TABLE {table_name}")
217+
if not (isinstance(t_info, dict) and "indexes" in t_info):
218+
continue
219+
for idx_name, idx_def in t_info["indexes"].items():
220+
is_unique = "UNIQUE" in str(idx_def).upper() if idx_def else False
221+
indexes.append(IndexInfo(
222+
name=idx_name,
223+
table_name=table_name,
224+
is_unique=is_unique,
225+
))
220226
except Exception:
221227
pass
222228
return indexes
@@ -242,59 +248,47 @@ def build_select_query(
242248
def execute_query(
243249
self, conn: Any, query: str, max_rows: int | None = None
244250
) -> tuple[list[str], list[tuple], bool]:
245-
"""Execute a query and return (columns, rows, truncated)."""
246-
result = conn.query(query)
251+
"""Execute a query and return (columns, rows, truncated).
247252
248-
if not result:
249-
return [], [], False
250-
251-
# SurrealDB returns a list of results (one per statement)
252-
# For a single query, we take the first result
253-
data = result[0] if isinstance(result, list) else result
254-
255-
# Handle single value returns (like RETURN 1)
256-
if not isinstance(data, (list, dict)):
257-
return ["result"], [(data,)], False
253+
surrealdb>=1.0 returns the query result already unwrapped: scalars for
254+
RETURN, a dict for INFO, and a list[dict] for SELECT.
255+
"""
256+
data = conn.query(query)
258257

259-
# Handle empty results
260-
if isinstance(data, list) and not data:
258+
if data is None:
261259
return [], [], False
262260

263-
# Handle list of records
264261
if isinstance(data, list):
265262
if not data:
266263
return [], [], False
267264
first = data[0]
268265
if isinstance(first, dict):
269266
columns = list(first.keys())
270-
all_rows = [tuple(row.get(col) for col in columns) for row in data]
271-
if max_rows is not None and len(all_rows) > max_rows:
272-
return columns, all_rows[:max_rows], True
273-
return columns, all_rows, False
274-
# List of non-dict values
275-
rows = [(v,) for v in (data[:max_rows] if max_rows else data)]
267+
rows = [
268+
tuple(_to_plain(row.get(col)) for col in columns) for row in data
269+
]
270+
if max_rows is not None and len(rows) > max_rows:
271+
return columns, rows[:max_rows], True
272+
return columns, rows, False
273+
rows = [(_to_plain(v),) for v in (data[:max_rows] if max_rows else data)]
276274
truncated = max_rows is not None and len(data) > max_rows
277275
return ["value"], rows, truncated
278276

279-
# Handle single dict
280277
if isinstance(data, dict):
281278
columns = list(data.keys())
282-
return columns, [tuple(data.values())], False
279+
return columns, [tuple(_to_plain(v) for v in data.values())], False
283280

284-
return [], [], False
281+
# Scalar returns (int, str, bool, etc.)
282+
return ["result"], [(_to_plain(data),)], False
285283

286284
def execute_non_query(self, conn: Any, query: str) -> int:
287285
"""Execute a non-query statement."""
288286
result = conn.query(query)
289-
# SurrealDB doesn't return row counts in the traditional sense
290-
# Return 1 if operation succeeded
291-
if result is not None:
292-
if isinstance(result, list) and result:
293-
data = result[0]
294-
if isinstance(data, list):
295-
return len(data)
296-
return 1
297-
return 0
287+
if result is None:
288+
return 0
289+
if isinstance(result, list):
290+
return len(result)
291+
return 1
298292

299293
def classify_query(self, query: str) -> bool:
300294
"""Return True if the query is expected to return rows."""

0 commit comments

Comments
 (0)