Skip to content

Commit 9c745ba

Browse files
authored
Merge pull request #2310 from PolicyEngine/HerveRI/tests_tracer_analysis_service
Added tests for execute_analysis function
2 parents 4d07b2a + dc583bc commit 9c745ba

3 files changed

Lines changed: 133 additions & 0 deletions

File tree

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: minor
2+
changes:
3+
added:
4+
- Tests for execute_analysis function

tests/fixtures/services/tracer_analysis_service.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
import pytest
2+
import json
3+
from policyengine_api.services.tracer_analysis_service import (
4+
TracerAnalysisService,
5+
)
6+
from unittest.mock import patch
7+
8+
19
valid_tracer_output = [
210
" snap<2027, (default)> = [6769.799]",
311
" snap<2027-01, (default)> = [561.117]",
@@ -22,3 +30,55 @@
2230
spliced_valid_tracer_output_leaf_variable = valid_tracer_output[8:]
2331

2432
empty_tracer = []
33+
34+
35+
@pytest.fixture
36+
def sample_tracer_data():
37+
return valid_tracer_output
38+
39+
40+
@pytest.fixture
41+
def sample_expected_segment():
42+
return spliced_valid_tracer_output_nested_variable
43+
44+
45+
@pytest.fixture
46+
def mock_get_tracer(sample_tracer_data):
47+
with patch.object(
48+
TracerAnalysisService, "get_tracer", return_value=sample_tracer_data
49+
) as mock:
50+
yield mock
51+
52+
53+
@pytest.fixture
54+
def mock_parse_tracer_output(sample_expected_segment):
55+
with patch.object(
56+
TracerAnalysisService,
57+
"_parse_tracer_output",
58+
return_value=sample_expected_segment,
59+
) as mock:
60+
yield mock
61+
62+
63+
@pytest.fixture
64+
def mock_get_existing_analysis():
65+
with patch.object(
66+
TracerAnalysisService,
67+
"get_existing_analysis",
68+
return_value="Existing static analysis",
69+
) as mock:
70+
yield mock
71+
72+
73+
@pytest.fixture
74+
def mock_trigger_ai_analysis():
75+
def dummy_generator():
76+
yield "stream chunk 1"
77+
yield "stream chunk 2"
78+
79+
with patch.object(
80+
TracerAnalysisService,
81+
"trigger_ai_analysis",
82+
return_value=dummy_generator(),
83+
) as mock:
84+
yield mock
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pytest
2+
import json
3+
from policyengine_api.services.tracer_analysis_service import (
4+
TracerAnalysisService,
5+
)
6+
from werkzeug.exceptions import NotFound
7+
8+
from tests.fixtures.services.tracer_analysis_service import (
9+
sample_tracer_data,
10+
sample_expected_segment,
11+
mock_get_tracer,
12+
mock_get_existing_analysis,
13+
mock_parse_tracer_output,
14+
mock_trigger_ai_analysis,
15+
)
16+
17+
service = TracerAnalysisService()
18+
country_id = "us"
19+
household_id = "71424"
20+
policy_id = "2"
21+
target_variable = "takes_up_snap_if_eligible"
22+
23+
24+
class TestExecuteAnalysis:
25+
def test_execute_analysis_static(
26+
self,
27+
mock_get_tracer,
28+
mock_parse_tracer_output,
29+
mock_get_existing_analysis,
30+
):
31+
"""
32+
GIVEN a valid tracer data and an expected parsed segment (included as fixture),
33+
AND get_existing_analysis returns a static analysis (included as fixture),
34+
WHEN execute_analysis is called,
35+
THEN then a static analysis with the "static" flag should be returned.
36+
"""
37+
38+
analysis, analysis_type = service.execute_analysis(
39+
country_id, household_id, policy_id, target_variable
40+
)
41+
42+
assert analysis == "Existing static analysis"
43+
assert analysis_type == "static"
44+
45+
def test_execute_analysis_streaming(
46+
self,
47+
mock_get_tracer,
48+
mock_parse_tracer_output,
49+
mock_get_existing_analysis,
50+
mock_trigger_ai_analysis,
51+
):
52+
"""
53+
GIVEN a valid tracer data and an expected parsed segment,
54+
AND get_existing_analysis returns None,
55+
WHEN execute_analysis is called,
56+
THEN trigger_ai_analysis is called and returns a generator with the "streaming" flag.
57+
"""
58+
59+
# When existing analysis value is None
60+
mock_get_existing_analysis.return_value = None
61+
62+
analysis, analysis_type = service.execute_analysis(
63+
country_id, household_id, policy_id, target_variable
64+
)
65+
66+
expected_streaming_output = ["stream chunk 1", "stream chunk 2"]
67+
streaming_output = list(analysis)
68+
assert streaming_output == expected_streaming_output
69+
assert analysis_type == "streaming"

0 commit comments

Comments
 (0)