Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions eval_protocol/agent/resources/sql_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@
from ..resource_abc import ForkableResource


# SQLite connection settings for hardened concurrency safety
SQLITE_CONNECTION_TIMEOUT = 30  # 30 seconds


def _apply_hardened_pragmas(conn: sqlite3.Connection) -> None:
    """Apply hardened SQLite pragmas for concurrency safety.

    The busy timeout is derived from ``SQLITE_CONNECTION_TIMEOUT`` so the
    value given to ``sqlite3.connect`` and the pragma cannot drift apart.

    Args:
        conn: An open SQLite connection; each pragma is executed on it in order.
    """
    pragmas = (
        ("journal_mode", "WAL"),  # Write-Ahead Logging
        ("synchronous", "NORMAL"),  # Balance safety and performance
        # busy_timeout is expressed in milliseconds; the constant is in seconds.
        ("busy_timeout", int(SQLITE_CONNECTION_TIMEOUT * 1000)),
        ("wal_autocheckpoint", 1000),  # Checkpoint every 1000 pages
        ("cache_size", -64000),  # 64MB cache
        ("foreign_keys", "ON"),  # Enable foreign key constraints
        ("temp_store", "MEMORY"),  # Store temp tables in memory
    )
    # Only constants defined above are interpolated; no external input reaches this SQL.
    for name, value in pragmas:
        conn.execute(f"PRAGMA {name}={value}")


class SQLResource(ForkableResource):
"""
A ForkableResource for managing SQL database states, primarily SQLite.
Expand All @@ -20,6 +35,8 @@ class SQLResource(ForkableResource):
and seed data, forked (by copying the DB file), checkpointed (by copying),
and restored.

Uses hardened SQLite settings for concurrency safety.

Attributes:
_config (Dict[str, Any]): Configuration for the resource.
_db_path (Optional[Path]): Path to the current SQLite database file.
Expand All @@ -38,8 +55,14 @@ def __init__(self) -> None:
def _get_db_connection(self) -> sqlite3.Connection:
    """Open a connection to the current database file with hardened settings.

    Returns:
        A new ``sqlite3.Connection`` with the hardened pragmas applied.

    Raises:
        ConnectionError: If no database path has been set yet.
    """
    if not self._db_path:
        raise ConnectionError("Database path not set. Call setup() or fork() first.")
    # Set timeout to prevent indefinite hangs with hardened settings
    conn = sqlite3.connect(
        str(self._db_path),
        timeout=SQLITE_CONNECTION_TIMEOUT,
        isolation_level="DEFERRED",  # Better for concurrent access
    )
    try:
        _apply_hardened_pragmas(conn)
    except Exception:
        # Do not leak the handle when a pragma fails.
        conn.close()
        raise
    return conn

async def setup(self, config: Dict[str, Any]) -> None:
"""
Expand Down
120 changes: 105 additions & 15 deletions eval_protocol/cli_commands/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,82 @@

import os
from ..utils.logs_server import serve_logs
from ..event_bus.sqlite_event_bus_database import DatabaseCorruptedError, _backup_and_remove_database


def _handle_database_corruption(db_path: str) -> bool:
    """
    Explain the corruption to the user and offer to reset the database.

    Args:
        db_path: Path to the corrupted database

    Returns:
        True when the user accepted and the database file was backed up and
        removed; False when the user declined or the prompt was interrupted.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60
    banner = (
        "\n" + heavy_rule,
        "⚠️ DATABASE CORRUPTION DETECTED",
        heavy_rule,
        f"\nThe database file at:\n {db_path}\n",
        "appears to be corrupted or is not a valid SQLite database.",
        "\nThis can happen due to:",
        " • Incomplete writes during a crash",
        " • Concurrent access issues",
        " • File system errors",
        "\n" + light_rule,
        "Would you like to automatically fix this?",
        " • The corrupted file will be backed up",
        " • A fresh database will be created",
        " • You will lose existing log data, but can continue using the tool",
        light_rule,
    )
    for text in banner:
        print(text)

    try:
        answer = input("\nFix database automatically? [Y/n]: ").strip().lower()
        # An empty answer counts as "yes" (the default shown in the prompt).
        if answer not in ("", "y", "yes"):
            print("\n❌ Database repair cancelled.")
            print(f" You can manually delete the corrupted file: {db_path}")
            return False
        _backup_and_remove_database(db_path)
        print("\n✅ Database has been reset. Restarting server...")
        return True
    except (EOFError, KeyboardInterrupt):
        # No usable stdin or Ctrl-C at the prompt: treat as a refusal.
        print("\n❌ Database repair cancelled.")
        return False


def _is_database_corruption_error(error: Exception) -> tuple[bool, str]:
    """
    Check if an exception is related to database corruption.

    A ``DatabaseCorruptedError`` is recognised first because it carries the
    exact path of the affected file. Otherwise the error message is matched
    against SQLite's corruption messages. Lock contention ("database is
    locked") and open failures ("unable to open database file") are
    deliberately NOT treated as corruption: they describe a busy or
    unreachable database, and offering to reset it would discard healthy data.

    Args:
        error: The exception raised while starting the logs server.

    Returns:
        Tuple of (is_corruption_error, db_path). ``db_path`` is an empty
        string when corruption is detected but the path cannot be determined.
    """
    # Prefer the typed error: its db_path is authoritative, not a guess.
    if isinstance(error, DatabaseCorruptedError):
        return True, error.db_path

    error_str = str(error).lower()
    corruption_indicators = (
        "file is not a database",
        "database disk image is malformed",
    )
    if not any(indicator in error_str for indicator in corruption_indicators):
        return False, ""

    # Try to find the database path
    from ..directory_utils import find_eval_protocol_dir

    try:
        return True, os.path.join(find_eval_protocol_dir(), "logs.db")
    except Exception:
        # Corruption is still reported; the caller decides what to do without a path.
        return True, ""


def logs_command(args):
Expand Down Expand Up @@ -40,18 +116,32 @@ def logs_command(args):
or "https://tracing.fireworks.ai"
)

try:
serve_logs(
port=args.port,
elasticsearch_config=elasticsearch_config,
debug=args.debug,
backend="fireworks" if use_fireworks else "elasticsearch",
fireworks_base_url=fireworks_base_url if use_fireworks else None,
)
return 0
except KeyboardInterrupt:
print("\n🛑 Server stopped by user")
return 0
except Exception as e:
print(f"❌ Error starting server: {e}")
return 1
max_retries = 2
for attempt in range(max_retries):
try:
serve_logs(
port=args.port,
elasticsearch_config=elasticsearch_config,
debug=args.debug,
backend="fireworks" if use_fireworks else "elasticsearch",
fireworks_base_url=fireworks_base_url if use_fireworks else None,
)
return 0
except KeyboardInterrupt:
print("\n🛑 Server stopped by user")
return 0
except Exception as e:
is_corruption, db_path = _is_database_corruption_error(e)

if is_corruption and db_path and attempt < max_retries - 1:
if _handle_database_corruption(db_path):
# User chose to fix, retry
continue
else:
# User declined fix
return 1

print(f"❌ Error starting server: {e}")
return 1

return 1
21 changes: 17 additions & 4 deletions eval_protocol/dataset_logger/sqlite_evaluation_row_store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import os
from typing import List, Optional

from peewee import CharField, Model, SqliteDatabase
from peewee import CharField, DatabaseError, Model, SqliteDatabase
from playhouse.sqlite_ext import JSONField

from eval_protocol.event_bus.sqlite_event_bus_database import (
SQLITE_HARDENED_PRAGMAS,
DatabaseCorruptedError,
check_and_repair_database,
)
from eval_protocol.models import EvaluationRow


Expand All @@ -12,12 +17,20 @@ class SqliteEvaluationRowStore:
Lightweight reusable SQLite store for evaluation rows.

Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
Uses hardened SQLite settings for concurrency safety.
"""

def __init__(self, db_path: str):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
def __init__(self, db_path: str, auto_repair: bool = True):
db_dir = os.path.dirname(db_path)
if db_dir:
os.makedirs(db_dir, exist_ok=True)
self._db_path = db_path
self._db = SqliteDatabase(self._db_path, pragmas={"journal_mode": "wal"})

# Check and optionally repair corrupted database
check_and_repair_database(db_path, auto_repair=auto_repair)

# Use hardened pragmas for concurrency safety
self._db = SqliteDatabase(self._db_path, pragmas=SQLITE_HARDENED_PRAGMAS)

class BaseModel(Model):
class Meta:
Expand Down
5 changes: 5 additions & 0 deletions eval_protocol/event_bus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Global event bus instance - uses SqliteEventBus for cross-process functionality
from typing import Any, Callable
from eval_protocol.event_bus.event_bus import EventBus
from eval_protocol.event_bus.sqlite_event_bus_database import (
DatabaseCorruptedError,
check_and_repair_database,
SQLITE_HARDENED_PRAGMAS,
)


def _get_default_event_bus():
Expand Down
115 changes: 111 additions & 4 deletions eval_protocol/event_bus/sqlite_event_bus_database.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,125 @@
import os
import time
from typing import Any, List
from uuid import uuid4

from peewee import BooleanField, CharField, DateTimeField, Model, SqliteDatabase
from peewee import BooleanField, CharField, DatabaseError, DateTimeField, Model, SqliteDatabase
from playhouse.sqlite_ext import JSONField

from eval_protocol.event_bus.logger import logger


# SQLite pragmas for hardened concurrency safety.
# Shared mapping intended to be passed as the `pragmas=` argument of a peewee
# SqliteDatabase so every store opens its file with the same settings.
SQLITE_HARDENED_PRAGMAS = {
"journal_mode": "wal", # Write-Ahead Logging for concurrent reads/writes
"synchronous": "normal", # Balance between safety and performance
"busy_timeout": 30000, # Milliseconds: wait up to 30 seconds on a locked database
"wal_autocheckpoint": 1000, # Checkpoint every 1000 pages
"cache_size": -64000, # 64MB cache (negative = size in KB rather than pages)
"foreign_keys": 1, # Enable foreign key constraints
"temp_store": "memory", # Store temp tables in memory
}


class DatabaseCorruptedError(Exception):
    """Signals that a database file is corrupted or is not a valid SQLite database.

    Attributes:
        db_path: Path of the database file that failed validation.
        original_error: The underlying exception that exposed the problem.
    """

    def __init__(self, db_path: str, original_error: Exception):
        message = f"Database file is corrupted: {db_path}. Original error: {original_error}"
        super().__init__(message)
        self.original_error = original_error
        self.db_path = db_path


def check_and_repair_database(db_path: str, auto_repair: bool = False) -> bool:
    """
    Check if a database file is valid and optionally repair it.

    The file is opened with a short busy timeout and ``PRAGMA integrity_check``
    is run. The probe connection is always closed before any repair is
    attempted, so the file is not held open while it is renamed or removed.

    Args:
        db_path: Path to the database file
        auto_repair: If True, automatically back up and remove a corrupted
            database so that a fresh one can be created by the caller

    Returns:
        True if the database is missing (new), valid, or was repaired

    Raises:
        DatabaseCorruptedError: If database is corrupted and auto_repair is False
        peewee.DatabaseError: For database errors that do not indicate
            corruption (for example a locked database); these are re-raised
            unchanged
    """
    if not os.path.exists(db_path):
        return True  # New database, nothing to check

    failure: Exception
    test_db = None
    try:
        # Try to open the database and run an integrity check
        test_db = SqliteDatabase(db_path, pragmas={"busy_timeout": 5000})
        test_db.connect()
        result = test_db.execute_sql("PRAGMA integrity_check").fetchone()
    except DatabaseError as e:
        error_str = str(e).lower()
        if "file is not a database" not in error_str and "database disk image is malformed" not in error_str:
            # Not a corruption signature; let the caller see the real error.
            raise
        logger.warning(f"Database file is corrupted: {db_path}")
        failure = e
    except Exception as e:
        logger.warning(f"Error checking database {db_path}: {e}")
        failure = e
    else:
        if result and result[0] == "ok":
            return True
        logger.warning(f"Database integrity check failed for {db_path}: {result}")
        failure = Exception(f"Integrity check failed: {result}")
    finally:
        # Release the probe connection on every path, including errors.
        if test_db is not None and not test_db.is_closed():
            test_db.close()

    if auto_repair:
        _backup_and_remove_database(db_path)
        return True
    # Raised outside the try block so it is reported once, not re-wrapped.
    raise DatabaseCorruptedError(db_path, failure) from failure
Comment thread
dphuang2 marked this conversation as resolved.
Outdated


def _backup_and_remove_database(db_path: str) -> None:
    """Backup a corrupted database file and remove it.

    The file is renamed to ``<db_path>.corrupted.<unix-timestamp>``. If the
    rename fails, the file is deleted instead. The ``-wal`` and ``-shm``
    companion files are deleted when present. Every step is best-effort:
    failures are logged and never raised.

    Args:
        db_path: Path to the corrupted database file.
    """
    backup_path = f"{db_path}.corrupted.{int(time.time())}"
    try:
        os.rename(db_path, backup_path)
        logger.info(f"Backed up corrupted database to: {backup_path}")
    except OSError as e:
        logger.warning(f"Failed to backup corrupted database, removing: {e}")
        try:
            os.remove(db_path)
        except FileNotFoundError:
            pass  # Already gone; nothing left to clean up
        except OSError as remove_error:
            # Stay best-effort, but make the leftover file visible in the logs.
            logger.warning(f"Failed to remove corrupted database {db_path}: {remove_error}")

    # Also try to remove WAL and SHM files if they exist
    for suffix in ("-wal", "-shm"):
        sidecar_file = f"{db_path}{suffix}"
        try:
            os.remove(sidecar_file)
        except FileNotFoundError:
            continue  # No such companion file
        except OSError as remove_error:
            logger.warning(f"Failed to remove {sidecar_file}: {remove_error}")


class SqliteEventBusDatabase:
"""SQLite database for cross-process event communication."""

def __init__(self, db_path: str):
def __init__(self, db_path: str, auto_repair: bool = True):
self._db_path = db_path
self._db = SqliteDatabase(db_path)

# Ensure directory exists
db_dir = os.path.dirname(db_path)
if db_dir:
os.makedirs(db_dir, exist_ok=True)

# Check and optionally repair corrupted database
check_and_repair_database(db_path, auto_repair=auto_repair)

# Initialize database with hardened concurrency settings
self._db = SqliteDatabase(db_path, pragmas=SQLITE_HARDENED_PRAGMAS)

class BaseModel(Model):
class Meta:
Expand All @@ -29,7 +135,8 @@ class Event(BaseModel): # type: ignore

self._Event = Event
self._db.connect()
self._db.create_tables([Event])
# Use safe=True to avoid errors when tables already exist
self._db.create_tables([Event], safe=True)

def publish_event(self, event_type: str, data: Any, process_id: str) -> None:
"""Publish an event to the database."""
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ combine-as-imports = true
[tool.pyright]
typeCheckingMode = "basic" # Changed from "standard" to reduce memory usage
pythonVersion = "3.10"
include = ["eval_protocol"] # Reduced scope to just the main package
exclude = ["vite-app", "vendor", "examples", "tests", "development", "local_evals"]
include = ["eval_protocol", "tests"] # Main package plus the test suite
exclude = ["vite-app", "vendor", "examples", "development", "local_evals"]
# Ignore diagnostics for vendored generator code
ignore = ["versioneer.py"]
reportUnusedCallResult = "none"
Expand Down
Loading
Loading