-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathsqlite_evaluation_row_store.py
More file actions
71 lines (55 loc) · 2.59 KB
/
Copy pathsqlite_evaluation_row_store.py
File metadata and controls
71 lines (55 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from typing import List, Optional
try:
from peewee import CharField, Model, SqliteDatabase
from playhouse.sqlite_ext import JSONField
except ImportError:
raise ImportError(
"SQLite storage backend requires 'peewee' package. Install it with: pip install eval-protocol[sqlite_storage]"
)
from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore
class SqliteEvaluationRowStore(EvaluationRowStore):
    """
    Lightweight reusable SQLite store for evaluation rows.

    Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
    The database is opened with ``journal_mode=wal`` and the backing table is
    created on construction if it does not already exist.
    """

    def __init__(self, db_path: str):
        """
        Open the SQLite database at ``db_path`` and ensure its table exists.

        Args:
            db_path: Filesystem path of the SQLite file. Its parent directory
                is created if missing.
        """
        # os.path.dirname() returns "" for a bare filename (no directory part);
        # os.makedirs("") would raise, so only create a directory when there is one.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self._db_path = db_path
        self._db = SqliteDatabase(self._db_path, pragmas={"journal_mode": "wal"})

        # Models are defined per instance so that each store's model class is
        # bound to this instance's database object.
        class BaseModel(Model):
            class Meta:
                database = self._db

        class EvaluationRow(BaseModel):  # type: ignore
            rollout_id = CharField(unique=True)
            data = JSONField()

        self._EvaluationRow = EvaluationRow

        self._db.connect()
        # Use safe=True to avoid errors when tables/indexes already exist
        self._db.create_tables([EvaluationRow], safe=True)

    @property
    def db_path(self) -> str:
        """Path of the SQLite database file this store was opened with."""
        return self._db_path

    def upsert_row(self, data: dict) -> None:
        """
        Insert ``data`` as a new row, or replace the row with the same rollout_id.

        Args:
            data: Row payload. It must contain
                ``data["execution_metadata"]["rollout_id"]``. The whole dict is
                stored in the JSON ``data`` column.

        Raises:
            KeyError: If ``execution_metadata`` or its ``rollout_id`` key is absent.
            ValueError: If ``rollout_id`` is present but ``None``.
        """
        rollout_id = data["execution_metadata"]["rollout_id"]
        if rollout_id is None:
            raise ValueError("execution_metadata.rollout_id is required to upsert a row")
        # EXCLUSIVE transaction makes the existence check and the write one
        # atomic step, so concurrent writers cannot interleave between them.
        with self._db.atomic("EXCLUSIVE"):
            if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists():
                self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute()
            else:
                self._EvaluationRow.create(rollout_id=rollout_id, data=data)

    def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
        """
        Return stored row payloads.

        Args:
            rollout_id: If given, return only the row with this rollout_id
                (at most one, because the column is unique). Otherwise return
                all rows.

        Returns:
            The stored ``data`` dicts, ordered by primary key. Because an
            upsert updates the existing row in place, this is the order in
            which each rollout_id was first inserted.
        """
        model = self._EvaluationRow
        # Only the JSON payload is returned, so only that column is fetched.
        query = model.select(model.data)
        if rollout_id is not None:
            query = query.where(model.rollout_id == rollout_id)
        # Without ORDER BY, SQLite's result order is unspecified; pin it so
        # callers get a deterministic, insertion-ordered list.
        return [row["data"] for row in query.order_by(model.id).dicts()]

    def delete_row(self, rollout_id: str) -> int:
        """
        Delete the row with the given rollout_id.

        Returns:
            The row count reported by the DELETE query: 0 if no such row, else 1.
        """
        return self._EvaluationRow.delete().where(self._EvaluationRow.rollout_id == rollout_id).execute()

    def delete_all_rows(self) -> int:
        """
        Delete every stored row.

        Returns:
            The row count reported by the DELETE query.
        """
        return self._EvaluationRow.delete().execute()