Skip to content

Commit 1addf32

Browse files
committed
test(session): add restorer unit tests
1 parent 25c6b1e commit 1addf32

1 file changed

Lines changed: 211 additions & 0 deletions

File tree

tests/test_session_restorer.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
"""Tests for SessionRestorer.
2+
3+
Includes property-based tests for correctness properties from the design doc.
4+
"""
5+
6+
from datetime import datetime
7+
from typing import Any, Dict, List
8+
from unittest.mock import MagicMock
9+
10+
import pytest
11+
from hypothesis import given, settings
12+
from hypothesis import strategies as st
13+
14+
from shello_cli.session.models import SessionMetadata
15+
from shello_cli.session.restorer import SessionRestorer
16+
17+
settings.load_profile("default")
18+
19+
# ---------------------------------------------------------------------------
20+
# Helpers
21+
# ---------------------------------------------------------------------------
22+
23+
NON_SYSTEM_ROLES = ["user", "assistant", "tool"]
24+
25+
26+
def _make_agent(system_prompt: str = "system prompt", model: str = "unknown-model") -> MagicMock:
27+
"""Return a minimal mock ShelloAgent."""
28+
agent = MagicMock()
29+
agent._messages = []
30+
agent._chat_history = []
31+
agent._build_system_prompt.return_value = system_prompt
32+
agent.get_current_model.return_value = model
33+
return agent
34+
35+
36+
def _make_meta(provider: str = "openai", model: str = "gpt-4o") -> SessionMetadata:
37+
return SessionMetadata(
38+
session_id="20250101_120000_abcd",
39+
start_time=datetime(2025, 1, 1, 12, 0, 0),
40+
working_directory="/tmp",
41+
provider=provider,
42+
model=model,
43+
)
44+
45+
46+
# ---------------------------------------------------------------------------
47+
# Strategies
48+
# ---------------------------------------------------------------------------
49+
50+
# Simple text content for messages
51+
_text = st.text(min_size=0, max_size=200)
52+
53+
# A single non-system API message dict
54+
_non_system_message = st.fixed_dictionaries(
55+
{
56+
"role": st.sampled_from(NON_SYSTEM_ROLES),
57+
"content": _text,
58+
}
59+
)
60+
61+
# A system message (should be filtered out by the restorer)
62+
_system_message = st.fixed_dictionaries(
63+
{
64+
"role": st.just("system"),
65+
"content": _text,
66+
}
67+
)
68+
69+
# A conversation state: mix of system and non-system messages
70+
_conversation_state = st.lists(
71+
st.one_of(_non_system_message, _system_message),
72+
min_size=0,
73+
max_size=20,
74+
)
75+
76+
77+
# ---------------------------------------------------------------------------
78+
# Property 8: Non-system messages preserved on restore
79+
# Feature: session-history, Property 8: Non-system messages preserved on restore
80+
# ---------------------------------------------------------------------------
81+
82+
@given(conversation_state=_conversation_state)
83+
def test_property8_non_system_messages_preserved(conversation_state: List[Dict[str, Any]]):
84+
"""**Validates: Requirements 4.8**
85+
86+
For any conversation state, after SessionRestorer.restore() rebuilds the
87+
agent state, all messages with role != "system" from the original state
88+
should appear in agent._messages in the same order with identical content.
89+
"""
90+
# Feature: session-history, Property 8: Non-system messages preserved on restore
91+
agent = _make_agent()
92+
meta = _make_meta()
93+
restorer = SessionRestorer()
94+
95+
restorer.restore(agent, conversation_state, meta)
96+
97+
# Collect expected non-system messages (original order)
98+
expected = [m for m in conversation_state if m.get("role") != "system"]
99+
100+
# agent._messages[0] is always the (new) system prompt; the rest are non-system
101+
actual_non_system = [m for m in agent._messages if m.get("role") != "system"]
102+
103+
assert len(actual_non_system) == len(expected), (
104+
f"Expected {len(expected)} non-system messages, got {len(actual_non_system)}"
105+
)
106+
107+
for i, (actual_msg, expected_msg) in enumerate(zip(actual_non_system, expected)):
108+
assert actual_msg["role"] == expected_msg["role"], (
109+
f"Message {i}: role mismatch — expected {expected_msg['role']!r}, "
110+
f"got {actual_msg['role']!r}"
111+
)
112+
assert actual_msg["content"] == expected_msg["content"], (
113+
f"Message {i}: content mismatch — expected {expected_msg['content']!r}, "
114+
f"got {actual_msg['content']!r}"
115+
)
116+
117+
118+
@given(conversation_state=_conversation_state)
119+
def test_property8_system_message_is_rebuilt(conversation_state: List[Dict[str, Any]]):
120+
"""**Validates: Requirements 4.5, 4.8**
121+
122+
The first message in agent._messages after restore must be a system message
123+
built from the current system info (not from the original conversation state).
124+
"""
125+
# Feature: session-history, Property 8: Non-system messages preserved on restore
126+
fresh_prompt = "fresh system prompt for current env"
127+
agent = _make_agent(system_prompt=fresh_prompt)
128+
meta = _make_meta()
129+
restorer = SessionRestorer()
130+
131+
restorer.restore(agent, conversation_state, meta)
132+
133+
assert len(agent._messages) >= 1
134+
first = agent._messages[0]
135+
assert first["role"] == "system"
136+
assert first["content"] == fresh_prompt
137+
138+
139+
# ---------------------------------------------------------------------------
140+
# Unit tests
141+
# ---------------------------------------------------------------------------
142+
143+
def test_restore_empty_conversation_state():
144+
"""Restoring an empty conversation state leaves only the system message."""
145+
agent = _make_agent()
146+
meta = _make_meta()
147+
restorer = SessionRestorer()
148+
149+
restorer.restore(agent, [], meta)
150+
151+
assert len(agent._messages) == 1
152+
assert agent._messages[0]["role"] == "system"
153+
assert agent._chat_history == []
154+
155+
156+
def test_restore_filters_system_messages():
157+
"""System messages in the original state are not carried over."""
158+
conversation_state = [
159+
{"role": "system", "content": "old system prompt"},
160+
{"role": "user", "content": "hello"},
161+
{"role": "assistant", "content": "hi there"},
162+
]
163+
agent = _make_agent()
164+
meta = _make_meta()
165+
restorer = SessionRestorer()
166+
167+
restorer.restore(agent, conversation_state, meta)
168+
169+
roles = [m["role"] for m in agent._messages]
170+
# Only one system message (the rebuilt one), then user + assistant
171+
assert roles.count("system") == 1
172+
assert roles[0] == "system"
173+
assert roles[1:] == ["user", "assistant"]
174+
175+
176+
def test_restore_preserves_order():
177+
"""Non-system messages appear in original order after restore."""
178+
conversation_state = [
179+
{"role": "user", "content": "first"},
180+
{"role": "assistant", "content": "second"},
181+
{"role": "user", "content": "third"},
182+
{"role": "tool", "content": "fourth"},
183+
]
184+
agent = _make_agent()
185+
meta = _make_meta()
186+
restorer = SessionRestorer()
187+
188+
restorer.restore(agent, conversation_state, meta)
189+
190+
non_system = [m for m in agent._messages if m["role"] != "system"]
191+
contents = [m["content"] for m in non_system]
192+
assert contents == ["first", "second", "third", "fourth"]
193+
194+
195+
def test_restore_rebuilds_chat_history():
196+
"""agent._chat_history is rebuilt with correct types and content."""
197+
conversation_state = [
198+
{"role": "user", "content": "hello"},
199+
{"role": "assistant", "content": "world"},
200+
]
201+
agent = _make_agent()
202+
meta = _make_meta()
203+
restorer = SessionRestorer()
204+
205+
restorer.restore(agent, conversation_state, meta)
206+
207+
assert len(agent._chat_history) == 2
208+
assert agent._chat_history[0].type == "user"
209+
assert agent._chat_history[0].content == "hello"
210+
assert agent._chat_history[1].type == "assistant"
211+
assert agent._chat_history[1].content == "world"

0 commit comments

Comments
 (0)