Skip to content

Commit 2e3adfe

Browse files
authored
Merge pull request #2352 from PolicyEngine/SonaliBedge/2237-tracer-get-method-handling-invalid-input-incorrectly
Added validation in get_tracer method to handle inputs
2 parents c39a641 + 6214420 commit 2e3adfe

5 files changed

Lines changed: 156 additions & 3 deletions

File tree

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- Added functions and tests to handle invalid or incorrect input parameters.

policyengine_api/routes/tracer_analysis_routes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
TracerAnalysisService,
99
)
1010
import json
11+
from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS
12+
import re
1113

1214
tracer_analysis_bp = Blueprint("tracer_analysis", __name__)
1315
tracer_analysis_service = TracerAnalysisService()
@@ -26,6 +28,10 @@ def execute_tracer_analysis(country_id):
2628
household_id = payload.get("household_id")
2729
policy_id = payload.get("policy_id")
2830
variable = payload.get("variable")
31+
api_version = COUNTRY_PACKAGE_VERSIONS[country_id]
32+
33+
if not isinstance(variable, str):
34+
raise BadRequest("variable must be a string")
2935

3036
analysis, analysis_type = tracer_analysis_service.execute_analysis(
3137
country_id,

policyengine_api/services/tracer_analysis_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ def _parse_tracer_output(self, tracer_output, target_variable):
112112
# Create a regex pattern to match the exact variable name
113113
# This will match the variable name followed by optional whitespace,
114114
# then optional angle brackets with any content, then optional whitespace
115-
116115
pattern = (
117116
rf"^(\s*)({re.escape(target_variable)})(?!\w)\s*(?:<[^>]*>)?\s*"
118117
)
118+
119119
for line in tracer_output:
120120
# Count leading spaces to determine indentation level
121121
indent = len(line) - len(line.strip())

policyengine_api/utils/payload_validators/validate_tracer_analysis_payload.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,22 @@ def validate_tracer_analysis_payload(payload: dict):
88
if key not in payload:
99
return False, f"Missing required key: {key}"
1010

11+
# Validate types and formats
12+
household_id = payload["household_id"]
13+
policy_id = payload["policy_id"]
14+
variable = payload["variable"]
15+
16+
if not isinstance(household_id, (str, int)) or (
17+
isinstance(household_id, str) and not household_id.isdigit()
18+
):
19+
return False, "household_id must be a numeric integer or string"
20+
21+
if not isinstance(policy_id, (str, int)) or (
22+
isinstance(policy_id, str) and not policy_id.isdigit()
23+
):
24+
return False, "policy_id must be a numeric integer or string"
25+
26+
if not isinstance(variable, str):
27+
return False, "variable must be a string"
28+
1129
return True, None

tests/to_refactor/python/test_tracer_analysis_routes.py

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
import pytest
22
from flask import json
33
from unittest.mock import patch
4+
from werkzeug.exceptions import BadRequest
5+
6+
# constants
7+
VALID_HOUSEHOLD_ID = 123
8+
VALID_POLICY_ID = 456
9+
INVALID_HOUSEHOLD_ID = "abc123"
10+
INVALID_POLICY_ID = "invalid-id"
11+
TEST_VARIABLE = "disposable_income"
12+
INVALID_VARIABLE = 123
413

514

615
@patch("policyengine_api.services.tracer_analysis_service.local_database")
@@ -41,8 +50,8 @@ def test_execute_tracer_analysis_no_tracer(mock_db, rest_client):
4150
response = rest_client.post(
4251
"/us/tracer-analysis",
4352
json={
44-
"household_id": "test_household",
45-
"policy_id": "test_policy",
53+
"household_id": VALID_HOUSEHOLD_ID,
54+
"policy_id": VALID_POLICY_ID,
4655
"variable": "disposable_income",
4756
},
4857
)
@@ -85,6 +94,32 @@ def test_execute_tracer_analysis_ai_error(
8594
assert json.loads(response.data)["status"] == "error"
8695

8796

97+
@patch("policyengine_api.services.tracer_analysis_service.local_database")
98+
def test_invalid_variable_types(mock_db, rest_client):
99+
"""Test that different non-string variable types are rejected"""
100+
invalid_variables = [
101+
123,
102+
None,
103+
{"key": "value"},
104+
["list"],
105+
True,
106+
]
107+
108+
for invalid_var in invalid_variables:
109+
response = rest_client.post(
110+
"/us/tracer-analysis",
111+
json={
112+
"household_id": VALID_HOUSEHOLD_ID,
113+
"policy_id": VALID_POLICY_ID,
114+
"variable": invalid_var,
115+
},
116+
)
117+
assert response.status_code == 400
118+
assert (
119+
"variable must be a string" in json.loads(response.data)["message"]
120+
)
121+
122+
88123
# Test invalid country
89124
def test_invalid_country(rest_client):
90125
response = rest_client.post(
@@ -97,3 +132,93 @@ def test_invalid_country(rest_client):
97132
)
98133
assert response.status_code == 400
99134
assert b"Country invalid_country not found" in response.data
135+
136+
137+
def test_invalid_household_id_format(rest_client):
138+
"""Test that non-numeric household_id is rejected"""
139+
response = rest_client.post(
140+
"/us/tracer-analysis",
141+
json={
142+
"household_id": INVALID_HOUSEHOLD_ID,
143+
"policy_id": VALID_POLICY_ID,
144+
"variable": "disposable_income",
145+
},
146+
)
147+
assert response.status_code == 400
148+
assert (
149+
"household_id must be a numeric integer or string"
150+
in json.loads(response.data)["message"]
151+
)
152+
153+
154+
def test_invalid_policy_id_format(rest_client):
155+
"""Test that non-numeric policy_id is rejected"""
156+
response = rest_client.post(
157+
"/us/tracer-analysis",
158+
json={
159+
"household_id": VALID_HOUSEHOLD_ID,
160+
"policy_id": INVALID_POLICY_ID,
161+
"variable": "disposable_income",
162+
},
163+
)
164+
assert response.status_code == 400
165+
assert (
166+
"policy_id must be a numeric integer or string"
167+
in json.loads(response.data)["message"]
168+
)
169+
170+
171+
def test_empty_household_id(rest_client):
172+
"""Test that empty household_id is rejected"""
173+
response = rest_client.post(
174+
"/us/tracer-analysis",
175+
json={
176+
"household_id": "",
177+
"policy_id": VALID_POLICY_ID,
178+
"variable": "disposable_income",
179+
},
180+
)
181+
assert response.status_code == 400
182+
183+
184+
def test_missing_required_fields(rest_client):
185+
"""Test that missing required fields are rejected"""
186+
response = rest_client.post(
187+
"/us/tracer-analysis",
188+
json={
189+
# household_id missing
190+
"policy_id": VALID_POLICY_ID,
191+
"variable": "disposable_income",
192+
},
193+
)
194+
assert response.status_code == 400
195+
196+
197+
def test_invalid_types(rest_client):
198+
"""Test that invalid types are rejected"""
199+
response = rest_client.post(
200+
"/us/tracer-analysis",
201+
json={
202+
"household_id": None, # Invalid type
203+
"policy_id": INVALID_POLICY_ID,
204+
"variable": "disposable_income",
205+
},
206+
)
207+
assert response.status_code == 400
208+
209+
210+
def test_validate_tracer_analysis_payload_failure(rest_client):
211+
"""Test handling of invalid payload from validate_tracer_analysis_payload"""
212+
response = rest_client.post(
213+
"/us/tracer-analysis",
214+
json={
215+
# Missing required field 'variable'
216+
"household_id": VALID_HOUSEHOLD_ID,
217+
"policy_id": VALID_POLICY_ID,
218+
},
219+
)
220+
assert response.status_code == 400
221+
assert (
222+
"Missing required key: variable"
223+
in json.loads(response.data)["message"]
224+
)

0 commit comments

Comments
 (0)