Skip to content

Commit 6c76ad2

Browse files
author
Dylan Huang
committed
update
1 parent 3ffc83a commit 6c76ad2

4 files changed

Lines changed: 56 additions & 27 deletions

File tree

eval_protocol/dataset_logger/tinydb_evaluation_row_store.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List, Optional
33

44
from tinydb import Query, TinyDB
5+
from tinyrecord.transaction import transaction
56

67
from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore
78

@@ -12,6 +13,9 @@ class TinyDBEvaluationRowStore(EvaluationRowStore):
1213
1314
Stores data as plain JSON files, which are human-readable and
1415
don't suffer from SQLite's binary format corruption issues.
16+
17+
Uses tinyrecord for atomic transactions to handle concurrent access
18+
from multiple processes safely.
1519
"""
1620

1721
def __init__(self, db_path: str):
@@ -33,18 +37,34 @@ def upsert_row(self, data: dict) -> None:
3337
raise ValueError("execution_metadata.rollout_id is required to upsert a row")
3438

3539
Row = Query()
36-
self._table.upsert(data, Row.execution_metadata.rollout_id == rollout_id)
40+
condition = Row.execution_metadata.rollout_id == rollout_id
41+
42+
# tinyrecord doesn't support upsert directly, so we implement it manually
43+
# within a transaction for atomicity
44+
with transaction(self._table) as tr:
45+
# Check if document exists
46+
existing = self._table.search(condition)
47+
if existing:
48+
# Update existing document
49+
tr.update(data, condition)
50+
else:
51+
# Insert new document
52+
tr.insert(data)
3753

3854
def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
55+
# Clear cache to ensure fresh read in multi-process scenarios
56+
self._table.clear_cache()
3957
if rollout_id is not None:
4058
Row = Query()
4159
return list(self._table.search(Row.execution_metadata.rollout_id == rollout_id))
4260
return list(self._table.all())
4361

4462
def delete_row(self, rollout_id: str) -> int:
4563
Row = Query()
46-
removed = self._table.remove(Row.execution_metadata.rollout_id == rollout_id)
47-
return len(removed)
64+
with transaction(self._table) as tr:
65+
tr.remove(Row.execution_metadata.rollout_id == rollout_id)
66+
# Return count after removal (we don't have access to removed count in transaction)
67+
return 1
4868

4969
def delete_all_rows(self) -> int:
5070
count = len(self._table)

eval_protocol/event_bus/tinydb_event_bus_database.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from uuid import uuid4
55

66
from tinydb import Query, TinyDB
7+
from tinyrecord.transaction import transaction
78

89
from eval_protocol.event_bus.event_bus_database import EventBusDatabase
910
from eval_protocol.event_bus.logger import logger
@@ -15,6 +16,9 @@ class TinyDBEventBusDatabase(EventBusDatabase):
1516
1617
Stores data as plain JSON files, which are human-readable and
1718
don't suffer from SQLite's binary format corruption issues.
19+
20+
Uses tinyrecord for atomic transactions to handle concurrent access
21+
from multiple processes safely.
1822
"""
1923

2024
def __init__(self, db_path: str):
@@ -27,24 +31,26 @@ def __init__(self, db_path: str):
2731
self._table = self._db.table("events")
2832

2933
def publish_event(self, event_type: str, data: Any, process_id: str) -> None:
30-
"""Publish an event to the database."""
34+
"""Publish an event to the database using atomic transaction."""
3135
try:
3236
# Serialize data, handling pydantic models
3337
if hasattr(data, "model_dump"):
3438
serialized_data = data.model_dump(mode="json", exclude_none=True)
3539
else:
3640
serialized_data = data
3741

38-
self._table.insert(
39-
{
40-
"event_id": str(uuid4()),
41-
"event_type": event_type,
42-
"data": serialized_data,
43-
"timestamp": time.time(),
44-
"process_id": process_id,
45-
"processed": False,
46-
}
47-
)
42+
document = {
43+
"event_id": str(uuid4()),
44+
"event_type": event_type,
45+
"data": serialized_data,
46+
"timestamp": time.time(),
47+
"process_id": process_id,
48+
"processed": False,
49+
}
50+
51+
# Use tinyrecord transaction for atomic, concurrent-safe insert
52+
with transaction(self._table) as tr:
53+
tr.insert(document)
4854
except Exception as e:
4955
logger.warning(f"Failed to publish event to database: {e}")
5056

@@ -83,21 +89,23 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]:
8389
return []
8490

8591
def mark_event_processed(self, event_id: str) -> None:
86-
"""Mark an event as processed."""
92+
"""Mark an event as processed using atomic transaction."""
8793
try:
8894
Event = Query()
89-
self._table.update({"processed": True}, Event.event_id == event_id)
95+
with transaction(self._table) as tr:
96+
tr.update({"processed": True}, Event.event_id == event_id)
9097
except Exception as e:
9198
logger.debug(f"Failed to mark event as processed: {e}")
9299

93100
def cleanup_old_events(self, max_age_hours: int = 24) -> None:
94-
"""Clean up old processed events."""
101+
"""Clean up old processed events using atomic transaction."""
95102
try:
96-
# Reload table from disk to see latest data before cleanup
97-
self._table._read_table()
103+
# Clear cache to see latest data before cleanup
104+
self._table.clear_cache()
98105

99106
cutoff_time = time.time() - (max_age_hours * 3600)
100107
Event = Query()
101-
self._table.remove((Event.processed == True) & (Event.timestamp < cutoff_time)) # noqa: E712
108+
with transaction(self._table) as tr:
109+
tr.remove((Event.processed == True) & (Event.timestamp < cutoff_time)) # noqa: E712
102110
except Exception as e:
103111
logger.debug(f"Failed to cleanup old events: {e}")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
"pytest>=6.0.0",
3636
"pytest-asyncio>=0.21.0",
3737
"tinydb>=4.8.0",
38+
"tinyrecord>=0.2.0",
3839
"backoff>=2.2.0",
3940
"questionary>=2.0.0",
4041
# Dependencies for vendored tau2 package

tests/pytest/test_pytest_ensure_logging.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@ async def test_ensure_logging(monkeypatch):
1515
monkeypatch.setattr(_dl, "_logger", None, raising=False)
1616
except Exception:
1717
pass
18-
# Mock the SqliteEvaluationRowStore to track calls
18+
# Mock the EvaluationRowStore to track calls
1919
mock_store = Mock()
2020
mock_store.upsert_row = Mock()
2121
mock_store.read_rows = Mock(return_value=[])
2222
mock_store.db_path = "/tmp/test.db"
2323

24-
# Mock the SqliteEvaluationRowStore constructor so that when SqliteDatasetLoggerAdapter
25-
# creates its store, it gets our mock instead
26-
with patch(
27-
"eval_protocol.dataset_logger.sqlite_dataset_logger_adapter.SqliteEvaluationRowStore", return_value=mock_store
28-
):
24+
# Mock get_evaluation_row_store so that when DatasetLoggerAdapter
25+
# creates its store, it gets our mock instead.
26+
# We patch at the module level where it's defined, which is where
27+
# dataset_logger_adapter imports it from.
28+
with patch("eval_protocol.dataset_logger.get_evaluation_row_store", return_value=mock_store):
2929
from eval_protocol.models import EvaluationRow, EvaluateResult
3030
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
3131
from eval_protocol.pytest.evaluation_test import evaluation_test
@@ -55,7 +55,7 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
5555
)
5656

5757
# Verify that the store's upsert_row method was called
58-
assert mock_store.upsert_row.called, "SqliteEvaluationRowStore.upsert_row should have been called"
58+
assert mock_store.upsert_row.called, "EvaluationRowStore.upsert_row should have been called"
5959

6060
# Check that it was called multiple times (once for each row)
6161
call_count = mock_store.upsert_row.call_count

0 commit comments

Comments
 (0)