Skip to content

Commit 6297c84

Browse files
committed
Add deployed analytics metadata test
1 parent 70ed1ea commit 6297c84

1 file changed

Lines changed: 310 additions & 0 deletions

File tree

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
from datetime import datetime, timedelta, timezone
2+
from time import sleep
3+
from typing import Any
4+
from urllib.parse import urlencode
5+
from uuid import UUID
6+
7+
import pytest
8+
9+
10+
TOP_LEVEL_ANALYTICS_KEYS = {
11+
"status",
12+
"message",
13+
"start_time",
14+
"end_time",
15+
"requested_version",
16+
"resolved_channel",
17+
"unique",
18+
"requests",
19+
}
20+
REQUEST_ANALYTICS_KEYS = {
21+
"request_uuid",
22+
"created_at",
23+
"api_version",
24+
"country_id",
25+
"model_version",
26+
"requested_version",
27+
"resolved_channel",
28+
"endpoint",
29+
"method",
30+
"response_status_code",
31+
"distinct_variable_count",
32+
"unsupported_variable_count",
33+
"deprecated_allowlisted_variable_count",
34+
"variables",
35+
}
36+
VARIABLE_ANALYTICS_KEYS = {
37+
"variable_name",
38+
"entity_type",
39+
"source",
40+
"period_granularity",
41+
"entity_count",
42+
"period_count",
43+
"occurrence_count",
44+
"availability_status",
45+
"variable_name_truncated",
46+
}
47+
EXPECTED_VARIABLES = {
48+
"age": {
49+
"variable_name": "age",
50+
"entity_type": "person",
51+
"source": "household_input",
52+
"period_granularity": "year",
53+
"entity_count": 2,
54+
"period_count": 1,
55+
"occurrence_count": 2,
56+
"availability_status": "supported",
57+
"variable_name_truncated": False,
58+
},
59+
"employment_income": {
60+
"variable_name": "employment_income",
61+
"entity_type": "person",
62+
"source": "household_input",
63+
"period_granularity": "year",
64+
"entity_count": 1,
65+
"period_count": 1,
66+
"occurrence_count": 1,
67+
"availability_status": "supported",
68+
"variable_name_truncated": False,
69+
},
70+
"state_name": {
71+
"variable_name": "state_name",
72+
"entity_type": "household",
73+
"source": "household_input",
74+
"period_granularity": "year",
75+
"entity_count": 1,
76+
"period_count": 1,
77+
"occurrence_count": 1,
78+
"availability_status": "supported",
79+
"variable_name_truncated": False,
80+
},
81+
"ctc": {
82+
"variable_name": "ctc",
83+
"entity_type": "tax_unit",
84+
"source": "requested_output",
85+
"period_granularity": "year",
86+
"entity_count": 1,
87+
"period_count": 1,
88+
"occurrence_count": 1,
89+
"availability_status": "supported",
90+
"variable_name_truncated": False,
91+
},
92+
}
93+
94+
95+
def test_calculate_request_records_complete_analytics_metadata(
96+
deployed_api,
97+
auth_token,
98+
request_version,
99+
expected_channel,
100+
route_mode,
101+
):
102+
if not expected_channel or not route_mode:
103+
pytest.skip(
104+
"Modal route metadata is only asserted in Modal route tests"
105+
)
106+
107+
resolved_channel = _expected_resolved_channel(
108+
deployed_api,
109+
request_version,
110+
expected_channel,
111+
route_mode,
112+
)
113+
requested_version = request_version or "current"
114+
start_time = datetime.now(timezone.utc) - timedelta(seconds=5)
115+
116+
calculate_response = deployed_api.post(
117+
"/us/calculate",
118+
headers={"Authorization": f"Bearer {auth_token}"},
119+
json_body=_calculate_request_body(requested_version),
120+
)
121+
122+
assert calculate_response.status_code == 200
123+
124+
analytics_request = _wait_for_analytics_request(
125+
deployed_api,
126+
auth_token,
127+
start_time=start_time,
128+
requested_version=requested_version,
129+
resolved_channel=resolved_channel,
130+
)
131+
132+
_assert_request_metadata(
133+
analytics_request,
134+
requested_version=requested_version,
135+
resolved_channel=resolved_channel,
136+
)
137+
138+
139+
def _expected_resolved_channel(
140+
deployed_api,
141+
request_version: str | None,
142+
expected_channel: str,
143+
route_mode: str,
144+
) -> str:
145+
if route_mode == "channel":
146+
return expected_channel
147+
148+
if route_mode != "exact":
149+
raise AssertionError(f"Unexpected route mode: {route_mode}")
150+
151+
versions_response = deployed_api.get("/versions/us")
152+
assert versions_response.status_code == 200
153+
versions = versions_response.json()
154+
for channel in ("current", "frontier"):
155+
if versions.get(channel) == request_version:
156+
return channel
157+
158+
raise AssertionError(
159+
f"No active channel serves US package version {request_version}"
160+
)
161+
162+
163+
def _calculate_request_body(requested_version: str) -> dict[str, Any]:
164+
return {
165+
"version": requested_version,
166+
"household": {
167+
"people": {
168+
"parent": {
169+
"age": {"2026": 35},
170+
"employment_income": {"2026": 60_000},
171+
},
172+
"child": {
173+
"age": {"2026": 6},
174+
},
175+
},
176+
"tax_units": {
177+
"tax_unit": {
178+
"members": ["parent", "child"],
179+
"ctc": {"2026": None},
180+
},
181+
},
182+
"spm_units": {
183+
"spm_unit": {
184+
"members": ["parent", "child"],
185+
},
186+
},
187+
"households": {
188+
"household": {
189+
"members": ["parent", "child"],
190+
"state_name": {"2026": "AZ"},
191+
},
192+
},
193+
},
194+
}
195+
196+
197+
def _wait_for_analytics_request(
198+
deployed_api,
199+
auth_token: str,
200+
*,
201+
start_time: datetime,
202+
requested_version: str,
203+
resolved_channel: str,
204+
) -> dict[str, Any]:
205+
for _ in range(5):
206+
payload = _analytics_payload(
207+
deployed_api,
208+
auth_token,
209+
start_time=start_time,
210+
requested_version=requested_version,
211+
resolved_channel=resolved_channel,
212+
)
213+
request_record = _matching_request(payload)
214+
if request_record is not None:
215+
return request_record
216+
sleep(1)
217+
218+
raise AssertionError(
219+
"Calculate analytics request was not returned with the expected "
220+
f"metadata. Last payload: {payload}"
221+
)
222+
223+
224+
def _analytics_payload(
225+
deployed_api,
226+
auth_token: str,
227+
*,
228+
start_time: datetime,
229+
requested_version: str,
230+
resolved_channel: str,
231+
) -> dict[str, Any]:
232+
query = urlencode(
233+
{
234+
"start_time": _isoformat_utc(start_time),
235+
"requested_version": requested_version,
236+
"resolved_channel": resolved_channel,
237+
"limit": "20",
238+
}
239+
)
240+
response = deployed_api.get(
241+
f"/analytics/calculate/requests?{query}",
242+
headers={"Authorization": f"Bearer {auth_token}"},
243+
)
244+
assert response.status_code == 200
245+
246+
payload = response.json()
247+
assert set(payload) == TOP_LEVEL_ANALYTICS_KEYS
248+
assert payload["status"] == "ok"
249+
assert payload["message"] is None
250+
assert payload["start_time"] == _isoformat_utc(start_time)
251+
assert payload["end_time"] is None
252+
assert payload["requested_version"] == requested_version
253+
assert payload["resolved_channel"] == resolved_channel
254+
assert payload["unique"] is False
255+
assert isinstance(payload["requests"], list)
256+
return payload
257+
258+
259+
def _matching_request(payload: dict[str, Any]) -> dict[str, Any] | None:
260+
expected_variable_names = set(EXPECTED_VARIABLES)
261+
for request_record in payload["requests"]:
262+
variables = request_record.get("variables", [])
263+
variable_names = {
264+
variable.get("variable_name") for variable in variables
265+
}
266+
if variable_names == expected_variable_names:
267+
return request_record
268+
return None
269+
270+
271+
def _assert_request_metadata(
272+
request_record: dict[str, Any],
273+
*,
274+
requested_version: str,
275+
resolved_channel: str,
276+
) -> None:
277+
assert set(request_record) == REQUEST_ANALYTICS_KEYS
278+
UUID(request_record["request_uuid"])
279+
assert _parse_api_datetime(request_record["created_at"])
280+
assert isinstance(request_record["api_version"], str)
281+
assert request_record["api_version"]
282+
assert request_record["country_id"] == "us"
283+
assert isinstance(request_record["model_version"], str)
284+
assert request_record["model_version"]
285+
assert request_record["requested_version"] == requested_version
286+
assert request_record["resolved_channel"] == resolved_channel
287+
assert request_record["endpoint"] == "calculate"
288+
assert request_record["method"] == "POST"
289+
assert request_record["response_status_code"] == 200
290+
assert request_record["distinct_variable_count"] == len(EXPECTED_VARIABLES)
291+
assert request_record["unsupported_variable_count"] == 0
292+
assert request_record["deprecated_allowlisted_variable_count"] == 0
293+
294+
variables_by_name = {
295+
variable["variable_name"]: variable
296+
for variable in request_record["variables"]
297+
}
298+
assert variables_by_name == EXPECTED_VARIABLES
299+
for variable in variables_by_name.values():
300+
assert set(variable) == VARIABLE_ANALYTICS_KEYS
301+
302+
303+
def _isoformat_utc(value: datetime) -> str:
304+
return (
305+
value.astimezone(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
306+
)
307+
308+
309+
def _parse_api_datetime(value: str) -> datetime:
310+
return datetime.fromisoformat(value.replace("Z", "+00:00"))

0 commit comments

Comments
 (0)