Skip to content

Commit f95c85c

Browse files
authored
Merge pull request #1528 from PolicyEngine/enable-modal-memory-snapshot
Enable Modal memory snapshots on household API worker
2 parents b56dd03 + f95b0a8 commit f95c85c

6 files changed

Lines changed: 211 additions & 11 deletions

File tree

changelog.d/1527.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Enable Modal memory snapshots on the household API worker so cold starts restore the pre-loaded Flask app and policyengine country systems instead of re-running the ~45s import chain on every fresh container.

docs/engineering/skills/modal-release-prs.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ Modal deployment history for deploy provenance.
121121

122122
For `/calculate` and `/calculate_demo`, the Modal gateway reads the top-level
123123
request field `version` and removes it before dispatching to the worker's
124-
`handle_household_request` Modal function. Accepted values are:
124+
`HouseholdWorker.handle_household_request` Modal class method. Accepted
125+
values are:
125126

126127
- omitted or `current`: route to the current worker
127128
- `frontier`: route to the frontier worker

policyengine_household_api/modal_release/gateway.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,22 @@ def resolve_app_for_request(
168168
def call_worker_function(app_name: str, payload: dict[str, Any]) -> Response:
169169
import modal
170170

171-
worker_function = modal.Function.from_name(
172-
app_name,
173-
"handle_household_request",
174-
)
175-
return _response_from_dispatch_result(worker_function.remote(payload))
171+
# Prefer the class-based worker (post #1528). During the release
172+
# transition the existing frontier is promoted to current without a
173+
# redeploy, so for one release cycle the current worker may still expose
174+
# the pre-#1528 top-level `handle_household_request` function. Fall back
175+
# to that shape if the class is not present.
176+
try:
177+
worker_cls = modal.Cls.from_name(app_name, "HouseholdWorker")
178+
return _response_from_dispatch_result(
179+
worker_cls().handle_household_request.remote(payload)
180+
)
181+
except modal.exception.NotFoundError:
182+
worker_function = modal.Function.from_name(
183+
app_name,
184+
"handle_household_request",
185+
)
186+
return _response_from_dispatch_result(worker_function.remote(payload))
176187

177188

178189
def _extract_requested_version(body: bytes) -> tuple[bytes, str]:

policyengine_household_api/modal_release/worker_app.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def worker_function_options(
4545
"secrets": [household_api_secret()],
4646
"timeout": 180,
4747
"scaledown_window": 300,
48+
"enable_memory_snapshot": True,
4849
}
4950
if environment == "main":
5051
options["min_containers"] = 3
@@ -53,9 +54,69 @@ def worker_function_options(
5354
return options
5455

5556

56-
@app.function(**worker_function_options())
57-
def handle_household_request(payload: dict[str, Any]) -> dict[str, Any]:
58-
configure_google_credentials()
59-
from policyengine_household_api.api import app as flask_app
57+
@app.cls(**worker_function_options())
58+
class HouseholdWorker:
59+
"""Worker class for handling household API requests.
6060
61-
return dispatch_to_flask_app(flask_app, payload)
61+
Uses a Modal class with ``@modal.enter(snap=True)`` so the heavy Flask
62+
app import runs at memory-snapshot creation time. Subsequent container
63+
starts restore from the snapshot in seconds rather than re-running the
64+
full policyengine country-package import chain on every cold start.
65+
"""
66+
67+
@modal.enter(snap=True)
68+
def load_flask_app(self) -> None:
69+
# Importing `policyengine_household_api.api` runs
70+
# `initialize_analytics_db_if_enabled` at module level, which opens a
71+
# Cloud SQL connection in environments where analytics is enabled.
72+
# That connection needs GOOGLE_APPLICATION_CREDENTIALS, set by
73+
# `configure_google_credentials()`. Configure credentials first so the
74+
# snapshot-time import can succeed even before any request method runs.
75+
configure_google_credentials()
76+
77+
from policyengine_household_api.api import app as flask_app
78+
79+
self.flask_app = flask_app
80+
81+
@modal.enter(snap=False)
82+
def reset_post_snapshot_state(self) -> None:
83+
# Runs on every container start AFTER snapshot restore. Memory
84+
# snapshots preserve Python object state but not live network
85+
# connections; the SQLAlchemy pool and the Cloud SQL Connector
86+
# captured in the snapshot hold sockets that closed at snapshot
87+
# time. Reset them so the first request opens fresh connections.
88+
#
89+
# Also force-recreate the Google credentials file: Modal preserves
90+
# env vars across snapshot restore, but /tmp is not guaranteed to
91+
# be preserved. Without popping the env var first,
92+
# configure_google_credentials() would short-circuit on the
93+
# surviving GOOGLE_APPLICATION_CREDENTIALS and leave it pointing
94+
# at a missing file, breaking analytics DB reconnects.
95+
# See: https://modal.com/docs/guide/memory-snapshot
96+
os.environ.pop("GOOGLE_APPLICATION_CREDENTIALS", None)
97+
configure_google_credentials()
98+
99+
from policyengine_household_api.data import analytics_setup
100+
101+
if not analytics_setup.is_analytics_enabled():
102+
return
103+
104+
analytics_setup.cleanup()
105+
106+
try:
107+
with self.flask_app.app_context():
108+
analytics_setup.db.engine.dispose()
109+
except Exception as exc:
110+
import logging
111+
112+
logging.getLogger(__name__).warning(
113+
"Failed to dispose analytics DB engine after snapshot "
114+
"restore; subsequent queries may reconnect lazily: %s",
115+
exc,
116+
)
117+
118+
@modal.method()
119+
def handle_household_request(
120+
self, payload: dict[str, Any]
121+
) -> dict[str, Any]:
122+
return dispatch_to_flask_app(self.flask_app, payload)

tests/unit/modal_release/test_gateway.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,97 @@ def from_name(name, *, create_if_missing):
267267

268268
with pytest.raises(modal.exception.AuthError):
269269
load_modal_manifest()
270+
271+
272+
def test_call_worker_function_uses_class_when_available(monkeypatch):
273+
"""When the worker exposes the new ``HouseholdWorker`` class, the
274+
gateway should dispatch through ``modal.Cls.from_name``."""
275+
from policyengine_household_api.modal_release import gateway
276+
277+
captured = {}
278+
279+
class _StubMethod:
280+
def remote(self, payload):
281+
captured["dispatched_via"] = "class"
282+
captured["payload"] = payload
283+
return {"status_code": 200, "body": b'{"status":"ok"}'}
284+
285+
class _StubInstance:
286+
handle_household_request = _StubMethod()
287+
288+
class _StubCls:
289+
def __call__(self):
290+
return _StubInstance()
291+
292+
def fake_cls_from_name(app_name, class_name):
293+
captured["cls_app_name"] = app_name
294+
captured["class_name"] = class_name
295+
return _StubCls()
296+
297+
def fake_function_from_name(app_name, function_name):
298+
captured["fallback_invoked"] = True
299+
raise AssertionError("Function fallback must not be invoked")
300+
301+
monkeypatch.setattr(
302+
modal.Cls, "from_name", staticmethod(fake_cls_from_name)
303+
)
304+
monkeypatch.setattr(
305+
modal.Function, "from_name", staticmethod(fake_function_from_name)
306+
)
307+
308+
response = gateway.call_worker_function(
309+
"frontier-app", {"household": {"foo": "bar"}}
310+
)
311+
312+
assert response.status_code == 200
313+
assert captured["dispatched_via"] == "class"
314+
assert captured["cls_app_name"] == "frontier-app"
315+
assert captured["class_name"] == "HouseholdWorker"
316+
assert captured["payload"] == {"household": {"foo": "bar"}}
317+
assert "fallback_invoked" not in captured
318+
319+
320+
def test_call_worker_function_falls_back_to_function_for_legacy_workers(
321+
monkeypatch,
322+
):
323+
"""During a release transition, the existing frontier worker gets
324+
promoted to current without a redeploy, so the current worker may
325+
still expose the pre-class ``handle_household_request`` function.
326+
The gateway must fall back to ``modal.Function.from_name`` when the
327+
class cannot be found."""
328+
from policyengine_household_api.modal_release import gateway
329+
330+
captured = {}
331+
332+
def fake_cls_from_name(app_name, class_name):
333+
raise modal.exception.NotFoundError(
334+
f"No class named `{class_name}` in app `{app_name}`"
335+
)
336+
337+
class _StubFunction:
338+
def remote(self, payload):
339+
captured["dispatched_via"] = "function"
340+
captured["payload"] = payload
341+
return {"status_code": 200, "body": b'{"status":"ok"}'}
342+
343+
def fake_function_from_name(app_name, function_name):
344+
captured["fn_app_name"] = app_name
345+
captured["function_name"] = function_name
346+
return _StubFunction()
347+
348+
monkeypatch.setattr(
349+
modal.Cls, "from_name", staticmethod(fake_cls_from_name)
350+
)
351+
monkeypatch.setattr(
352+
modal.Function, "from_name", staticmethod(fake_function_from_name)
353+
)
354+
355+
response = gateway.call_worker_function(
356+
"current-app", {"household": {"foo": "bar"}}
357+
)
358+
359+
assert response.status_code == 200
360+
assert captured["dispatched_via"] == "function"
361+
assert captured["fn_app_name"] == "current-app"
362+
assert captured["function_name"] == "handle_household_request"
363+
assert captured["payload"] == {"household": {"foo": "bar"}}

tests/unit/modal_release/test_worker_app.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,38 @@ def test_worker_function_options_do_not_keep_workers_warm_without_env(
5252
worker_function_options(modal_environment=None)
5353

5454

55+
def test_worker_function_options_enable_memory_snapshot_in_all_envs(
56+
worker_app,
57+
):
58+
for environment in ("main", "staging", "testing"):
59+
options = worker_app.worker_function_options(
60+
modal_environment=environment
61+
)
62+
assert options["enable_memory_snapshot"] is True, (
63+
f"enable_memory_snapshot must be True in `{environment}` "
64+
"so cold starts restore from a memory snapshot instead of "
65+
"re-running the ~45s policyengine import chain"
66+
)
67+
68+
69+
def test_household_worker_exposes_snapshot_entrypoint(worker_app):
70+
"""The class must declare its snapshot-time hook so heavy imports
71+
are captured in the memory snapshot rather than running per cold
72+
start."""
73+
worker_cls = worker_app.HouseholdWorker
74+
assert hasattr(worker_cls, "load_flask_app")
75+
assert hasattr(worker_cls, "handle_household_request")
76+
77+
78+
def test_household_worker_exposes_post_snapshot_reset_hook(worker_app):
79+
"""The class must declare a post-restore hook so network state
80+
captured in the memory snapshot (SQLAlchemy pool, Cloud SQL
81+
Connector) gets reset on every container start. Modal preserves
82+
Python object state but not live TCP sockets across snapshots."""
83+
worker_cls = worker_app.HouseholdWorker
84+
assert hasattr(worker_cls, "reset_post_snapshot_state")
85+
86+
5587
def test_country_package_install_specs_use_release_package_versions_only():
5688
assert country_package_install_specs(
5789
{

0 commit comments

Comments
 (0)