Skip to content

Commit 1f068d0

Browse files
committed
Prevent extra_checkpoint from overriding checkpoint_payload
1 parent 7343100 commit 1f068d0

2 files changed

Lines changed: 21 additions & 0 deletions

File tree

ddev/src/ddev/ai/phases/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ async def process_message(self, message: PhaseTrigger) -> None:
149149
},
150150
"memory_path": str(self._checkpoint_manager.memory_path(self._phase_id)),
151151
}
152+
reserved = set(checkpoint_payload) & set(outcome.extra_checkpoint)
153+
if reserved:
154+
raise ValueError(
155+
f"Phase {self._phase_id!r}: extra_checkpoint cannot override reserved keys: {sorted(reserved)}"
156+
)
152157
checkpoint_payload.update(outcome.extra_checkpoint)
153158

154159
self._checkpoint_manager.write_phase_checkpoint(self._phase_id, checkpoint_payload)

ddev/tests/ai/phases/test_base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from datetime import UTC, datetime
66
from unittest.mock import MagicMock
77

8+
import pytest
9+
810
from ddev.ai.phases.base import Phase, PhaseOutcome, _make_memory_resolver
911
from ddev.ai.phases.checkpoint import CheckpointManager
1012
from ddev.ai.phases.config import PhaseConfig
@@ -226,6 +228,20 @@ async def test_process_message_writes_memory_and_checkpoint(flow_dir, message_qu
226228
assert checkpoint["finished_at"]
227229

228230

231+
@pytest.mark.parametrize(
232+
"reserved_key",
233+
["status", "started_at", "finished_at", "tokens", "memory_path"],
234+
)
235+
async def test_extra_checkpoint_cannot_override_reserved_keys(flow_dir, message_queue, reserved_key):
236+
outcome = PhaseOutcome(memory_text="m", extra_checkpoint={reserved_key: "evil"})
237+
phase, mgr = _make_stub_phase(flow_dir, message_queue, outcome=outcome)
238+
239+
with pytest.raises(ValueError, match=f"reserved keys.*{reserved_key}"):
240+
await phase.process_message(PhaseTrigger(id="start", phase_id=None))
241+
242+
assert mgr.read() == {}
243+
244+
229245
async def test_failed_phase_omits_memory_path(flow_dir, message_queue):
230246
phase, mgr = _make_stub_phase(flow_dir, message_queue)
231247

0 commit comments

Comments
 (0)