Skip to content

Commit dcfbdd1

Browse files
fix: update stale tests (provider key + streaming mock) and ruff format
1 parent 271a226 commit dcfbdd1

4 files changed

Lines changed: 28 additions & 12 deletions

File tree

extropy/population/sampler/core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ class SamplingError(Exception):
171171
pass
172172

173173

174-
175174
def _has_household_attributes(spec: PopulationSpec) -> bool:
176175
"""Check if the spec has household-scoped attributes, indicating household mode."""
177176
return any(attr.scope == "household" for attr in spec.attributes)
@@ -452,7 +451,11 @@ def _sample_population_households(
452451
"""
453452
if config is None:
454453
config = HouseholdConfig()
455-
focus_mode = agent_focus_mode if agent_focus_mode in ("all", "couples", "primary_only") else "primary_only"
454+
focus_mode = (
455+
agent_focus_mode
456+
if agent_focus_mode in ("all", "couples", "primary_only")
457+
else "primary_only"
458+
)
456459

457460
hh_id_width = len(str(target_n - 1)) # safe upper bound for household IDs
458461

tests/test_agent_focus.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ class TestAgentFocusPrimaryOnly:
140140

141141
def test_partners_are_npc(self):
142142
spec = _make_household_spec()
143-
result = sample_population(spec, count=200, seed=42, agent_focus_mode="primary_only")
143+
result = sample_population(
144+
spec, count=200, seed=42, agent_focus_mode="primary_only"
145+
)
144146
agents = result.agents
145147

146148
# Find primary adults with partners
@@ -168,7 +170,9 @@ def test_partners_are_npc(self):
168170

169171
def test_no_partner_agents_in_result(self):
170172
spec = _make_household_spec()
171-
result = sample_population(spec, count=200, seed=42, agent_focus_mode="primary_only")
173+
result = sample_population(
174+
spec, count=200, seed=42, agent_focus_mode="primary_only"
175+
)
172176
agents = result.agents
173177

174178
# Count agents with household_role = adult_secondary
@@ -183,12 +187,16 @@ def test_no_partner_agents_in_result(self):
183187
def test_exact_agent_count(self):
184188
"""Regression test: -n must produce exactly N agents."""
185189
spec = _make_household_spec()
186-
result = sample_population(spec, count=500, seed=42, agent_focus_mode="primary_only")
190+
result = sample_population(
191+
spec, count=500, seed=42, agent_focus_mode="primary_only"
192+
)
187193
assert len(result.agents) == 500
188194

189195
def test_npc_partner_has_correlated_demographics(self):
190196
spec = _make_household_spec()
191-
result = sample_population(spec, count=300, seed=42, agent_focus_mode="primary_only")
197+
result = sample_population(
198+
spec, count=300, seed=42, agent_focus_mode="primary_only"
199+
)
192200
agents = result.agents
193201

194202
age_diffs = []
@@ -211,7 +219,9 @@ def test_npc_partner_has_correlated_demographics(self):
211219

212220
def test_npc_partner_shares_last_name(self):
213221
spec = _make_household_spec()
214-
result = sample_population(spec, count=200, seed=42, agent_focus_mode="primary_only")
222+
result = sample_population(
223+
spec, count=200, seed=42, agent_focus_mode="primary_only"
224+
)
215225

216226
for agent in result.agents:
217227
npc = agent.get("partner_npc")

tests/test_estimator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,8 @@ def test_resolve_openai_default(self):
217217
assert resolve_default_model("openai", "simple") == "gpt-5-mini"
218218

219219
def test_resolve_claude_default(self):
220-
assert (
221-
resolve_default_model("claude", "reasoning") == "claude-sonnet-4-5-20250929"
222-
)
223-
assert resolve_default_model("claude", "simple") == "claude-haiku-4-5-20251001"
220+
assert resolve_default_model("anthropic", "reasoning") == "claude-sonnet-4-6"
221+
assert resolve_default_model("anthropic", "simple") == "claude-haiku-4-5-20251001"
224222

225223
def test_resolve_unknown_provider_falls_back(self):
226224
model = resolve_default_model("unknown_provider", "reasoning")

tests/test_providers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,12 @@ def test_extracts_sources(self, mock_get_client):
404404
response.content = [search_block, tool_block, text_block]
405405

406406
mock_client = MagicMock()
407-
mock_client.messages.create.return_value = response
407+
# Code uses client.messages.stream() context manager → .get_final_message()
408+
mock_stream = MagicMock()
409+
mock_stream.__enter__ = MagicMock(return_value=mock_stream)
410+
mock_stream.__exit__ = MagicMock(return_value=False)
411+
mock_stream.get_final_message.return_value = response
412+
mock_client.messages.stream.return_value = mock_stream
408413
mock_get_client.return_value = mock_client
409414

410415
result, sources = provider.agentic_research(

0 commit comments

Comments
 (0)