-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathsql_resource.py
More file actions
326 lines (276 loc) · 13.8 KB
/
Copy pathsql_resource.py
File metadata and controls
326 lines (276 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
"""
SQLResource: A ForkableResource for managing SQL database states, initially focusing on SQLite.
"""
import os
import shutil
import sqlite3
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional
from ..resource_abc import ForkableResource
# SQLite connection settings for hardened concurrency safety
SQLITE_CONNECTION_TIMEOUT = 30 # 30 seconds
def _apply_hardened_pragmas(conn: sqlite3.Connection) -> None:
"""Apply hardened SQLite pragmas for concurrency safety."""
conn.execute("PRAGMA journal_mode=WAL") # Write-Ahead Logging
conn.execute("PRAGMA synchronous=NORMAL") # Balance safety and performance
conn.execute("PRAGMA busy_timeout=30000") # 30 second timeout
conn.execute("PRAGMA wal_autocheckpoint=1000") # Checkpoint every 1000 pages
conn.execute("PRAGMA cache_size=-64000") # 64MB cache
conn.execute("PRAGMA foreign_keys=ON") # Enable foreign key constraints
conn.execute("PRAGMA temp_store=MEMORY") # Store temp tables in memory
def _checkpoint_and_copy_database(
source_path: Path, dest_path: Path, timeout: int = SQLITE_CONNECTION_TIMEOUT
) -> None:
"""
Safely copy a SQLite database by checkpointing WAL first.
In WAL mode, data may exist in the -wal file that hasn't been written
to the main database file. This function performs a TRUNCATE checkpoint
to flush all WAL data to the main file before copying, ensuring a
complete and consistent copy.
Args:
source_path: Path to the source database file.
dest_path: Path where the copy should be created.
timeout: Connection timeout in seconds.
"""
# First, checkpoint the WAL to ensure all data is in the main file
conn = sqlite3.connect(str(source_path), timeout=timeout)
try:
# TRUNCATE mode: checkpoint and truncate the WAL file to zero bytes
# This ensures all data is flushed to the main database file
conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
finally:
conn.close()
# Now safely copy just the main database file
shutil.copyfile(str(source_path), str(dest_path))
class SQLResource(ForkableResource):
"""
A ForkableResource for managing SQL database states, primarily SQLite.
Manages a SQLite database file, allowing it to be initialized with a schema
and seed data, forked (by copying the DB file), checkpointed (by copying),
and restored.
Uses hardened SQLite settings for concurrency safety.
Attributes:
_config (Dict[str, Any]): Configuration for the resource.
_db_path (Optional[Path]): Path to the current SQLite database file.
_base_db_path (Optional[Path]): Path to the initially set up database, used for forking.
_temp_dir (Path): Directory to store database files.
"""
def __init__(self) -> None:
self._config: Dict[str, Any] = {}
self._db_path: Optional[Path] = None
self._base_db_path: Optional[Path] = None
# Consider making temp_dir configurable or using a more robust temp solution
self._temp_dir = Path("./.rk_temp_dbs").resolve() # Ensure absolute path
self._temp_dir.mkdir(parents=True, exist_ok=True)
def _get_db_connection(self) -> sqlite3.Connection:
if not self._db_path:
raise ConnectionError("Database path not set. Call setup() or fork() first.")
# Set timeout to prevent indefinite hangs with hardened settings
conn = sqlite3.connect(
str(self._db_path),
timeout=SQLITE_CONNECTION_TIMEOUT,
isolation_level="DEFERRED", # Better for concurrent access
)
_apply_hardened_pragmas(conn)
return conn
async def setup(self, config: Dict[str, Any]) -> None:
"""
Initializes the SQLite database.
Args:
config: Configuration dictionary. Expected keys:
- 'db_type' (str): Must be 'sqlite'.
- 'db_name' (Optional[str]): Name for the database file. Defaults to a UUID.
- 'schema_file' (Optional[str]): Path to an SQL file to execute for schema setup.
- 'seed_data_file' (Optional[str]): Path to an SQL file for initial data seeding.
- 'schema_sql' (Optional[str]): SQL string for schema setup.
- 'seed_sql' (Optional[str]): SQL string for initial data seeding.
"""
self._config = config.copy()
db_type = self._config.get("db_type", "sqlite")
if db_type != "sqlite":
raise ValueError("SQLResource currently only supports 'sqlite'.")
db_name = self._config.get("db_name", f"db_{uuid.uuid4().hex}.sqlite")
self._base_db_path = self._temp_dir / db_name
self._db_path = self._base_db_path # Initially, the current DB is the base DB
# Ensure a fresh start if the base DB file already exists from a previous run
if self._base_db_path is not None and self._base_db_path.exists():
self._base_db_path.unlink()
conn = self._get_db_connection()
try:
with conn:
# Apply schema
schema_file = self._config.get("schema_file")
if schema_file and Path(schema_file).exists():
with open(schema_file, "r") as f:
conn.executescript(f.read())
schema_sql = self._config.get("schema_sql")
if schema_sql:
conn.executescript(schema_sql)
# Apply seed data
seed_data_file = self._config.get("seed_data_file")
if seed_data_file and Path(seed_data_file).exists():
with open(seed_data_file, "r") as f:
conn.executescript(f.read())
seed_sql = self._config.get("seed_sql")
if seed_sql:
conn.executescript(seed_sql)
finally:
conn.close()
async def fork(self) -> "SQLResource":
"""
Creates a new SQLResource instance with a copy of the base database state.
If called on an already forked resource, it forks from its current state.
"""
if not self._db_path or not self._db_path.exists():
raise RuntimeError("Cannot fork: original database does not exist or setup was not called.")
forked_resource = SQLResource()
forked_resource._config = self._config.copy()
forked_resource._temp_dir = self._temp_dir # Share the same temp dir base
# The new fork's base is the current state of this resource
forked_resource._base_db_path = self._db_path
# Create a new unique DB file for this fork
forked_db_name = f"fork_{uuid.uuid4().hex}.sqlite"
forked_resource._db_path = self._temp_dir / forked_db_name
# Use checkpoint-and-copy to ensure WAL data is flushed before copying
_checkpoint_and_copy_database(self._db_path, forked_resource._db_path)
return forked_resource
async def checkpoint(self) -> Dict[str, Any]:
"""
Returns a serializable representation of the resource's current state.
For SQLite, this involves copying the database file to a checkpoint location
and returning the path.
"""
if not self._db_path or not self._db_path.exists():
raise RuntimeError("Cannot checkpoint: database does not exist.")
checkpoint_name = f"checkpoint_{self._db_path.stem}_{uuid.uuid4().hex}.sqlite"
checkpoint_path = self._temp_dir / checkpoint_name
# Use checkpoint-and-copy to ensure WAL data is flushed before copying
_checkpoint_and_copy_database(self._db_path, checkpoint_path)
return {"db_type": "sqlite", "checkpoint_path": str(checkpoint_path)}
async def restore(self, state_data: Dict[str, Any]) -> None:
"""
Restores the resource's state from a previously checkpointed state.
For SQLite, this means copying the checkpointed DB file to become the current DB.
"""
db_type = state_data.get("db_type")
checkpoint_path_str = state_data.get("checkpoint_path")
if db_type != "sqlite" or not checkpoint_path_str:
raise ValueError("Invalid state_data for SQLite restore.")
checkpoint_path = Path(checkpoint_path_str)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
# If current db_path is not set (e.g. fresh resource), assign one
if not self._db_path:
self._db_path = self._temp_dir / f"restored_{uuid.uuid4().hex}.sqlite"
# Use checkpoint-and-copy to ensure WAL data is flushed before copying
_checkpoint_and_copy_database(checkpoint_path, self._db_path)
self._base_db_path = self._db_path # The restored state becomes the new base for future forks
async def step(self, action_name: str, action_params: Dict[str, Any]) -> Any:
"""
Executes a SQL query on the database.
Args:
action_name: Should be 'execute_sql'.
action_params: Dictionary containing:
- 'query' (str): The SQL query to execute.
- 'parameters' (Optional[Dict | List]): Parameters for the query.
- 'fetch_mode' (Optional[str]): 'one', 'all', or 'val'. If None, no fetch.
Returns:
Query result based on fetch_mode, or rowcount for DML.
"""
if action_name != "execute_sql":
raise NotImplementedError(f"Action '{action_name}' not supported by SQLResource.")
query = action_params.get("query")
if not query:
raise ValueError("Missing 'query' in action_params for 'execute_sql'.")
params = action_params.get("parameters", [])
fetch_mode = action_params.get("fetch_mode") # 'one', 'all', 'val'
conn = self._get_db_connection()
try:
with conn:
cursor = conn.cursor()
cursor.execute(query, params)
if fetch_mode == "one":
columns = [desc[0] for desc in cursor.description]
row = cursor.fetchone()
return dict(zip(columns, row)) if row else None
elif fetch_mode == "all":
columns = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
return [dict(zip(columns, row)) for row in rows]
elif fetch_mode == "val":
row = cursor.fetchone()
return row[0] if row else None
else: # DML or no fetch needed
return {"rowcount": cursor.rowcount}
finally:
conn.close()
async def get_observation(self) -> Dict[str, Any]:
"""
Returns the current observable state of the resource.
For SQLResource, this could be the path to the DB or a status message.
"""
return {
"db_type": "sqlite",
"db_path": str(self._db_path) if self._db_path else None,
"status": ("ready" if self._db_path and self._db_path.exists() else "uninitialized"),
}
async def get_tools_spec(self) -> List[Dict[str, Any]]:
"""
Returns tool specifications for interacting with the SQL database.
"""
return [
{
"type": "function",
"function": {
"name": "execute_sql",
"description": "Executes a SQL query against the database. "
"Use 'fetch_mode' to control return value: "
"'one' for a single row, "
"'all' for all rows, "
"'val' for a single value from the first row. "
"If 'fetch_mode' is not provided, returns rowcount for DML statements.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The SQL query to execute.",
},
"parameters": {
"type": "array", # Or object for named parameters, sqlite3 supports both
"description": "Parameters for the SQL query (optional).",
"items": {"type": "any"},
},
"fetch_mode": {
"type": "string",
"enum": ["one", "all", "val"],
"description": "Specifies how to fetch results (optional).",
},
},
"required": ["query"],
},
},
}
]
async def close(self) -> None:
"""
Cleans up by deleting the created database file(s).
More robust cleanup of the _temp_dir might be needed if it's shared or persistent.
"""
if self._db_path and self._db_path.exists():
try:
self._db_path.unlink()
except OSError as e:
print(f"Error deleting database file {self._db_path}: {e}")
# Potentially clean up base_db_path if it's different and also temporary
# if self._base_db_path and self._base_db_path.exists() and self._base_db_path != self._db_path:
# try:
# self._base_db_path.unlink()
# except OSError:
# pass # ignore if it was already deleted or moved
# For now, we don't delete the _temp_dir itself, as it might contain checkpoints
# or other DBs from concurrent operations. A more sophisticated cleanup strategy
# for _temp_dir might be needed for long-running processes.
self._db_path = None
self._base_db_path = None