Skip to content

Commit a334cc2

Browse files
committed
refactor(worker): add pydantic payload validation to all jobs
Replace unsafe data.get() lookups with Pydantic model_validate() for runtime type safety. Routers now return 422 for invalid payloads instead of 500.
1 parent c6100b7 commit a334cc2

15 files changed

Lines changed: 185 additions & 58 deletions

apps/worker/src/jobs/data_cleanup.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import structlog
77

88
from src.jobs.base import BaseJob, register_job
9+
from src.jobs.schemas import DataCleanupPayload
910
from src.lib.config import settings
1011
from src.lib.retry import with_retry
1112

@@ -32,13 +33,12 @@ async def _call_api(self, headers: dict[str, str], payload: dict[str, Any]) -> i
3233
return response.status_code
3334

3435
async def execute(self, data: dict[str, Any]) -> dict[str, Any]:
35-
retention_days: int = data.get("retention_days", 90)
36-
resource_type: str = data.get("resource_type", "all")
36+
payload = DataCleanupPayload.model_validate(data)
3737

3838
logger.info(
3939
"data_cleanup_start",
40-
retention_days=retention_days,
41-
resource_type=resource_type,
40+
retention_days=payload.retention_days,
41+
resource_type=payload.resource_type,
4242
)
4343

4444
headers: dict[str, str] = {}
@@ -47,14 +47,17 @@ async def execute(self, data: dict[str, Any]) -> dict[str, Any]:
4747

4848
status_code = await self._call_api(
4949
headers,
50-
{"retention_days": retention_days, "resource_type": resource_type},
50+
{
51+
"retention_days": payload.retention_days,
52+
"resource_type": payload.resource_type,
53+
},
5154
)
5255

5356
logger.info(
5457
"data_cleanup_complete",
5558
api_status=status_code,
5659
)
57-
return {"cleaned": True, "retention_days": retention_days}
60+
return {"cleaned": True, "retention_days": payload.retention_days}
5861

5962

6063
register_job(DataCleanupJob())

apps/worker/src/jobs/medication_reminder.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import structlog
66

77
from src.jobs.base import BaseJob, register_job
8+
from src.jobs.schemas import MedicationReminderPayload
89
from src.lib.dispatch import dispatch_local
910

1011
logger = structlog.get_logger(__name__)
@@ -18,30 +19,28 @@ def job_type(self) -> str:
1819
return "medication.reminder"
1920

2021
async def execute(self, data: dict[str, Any]) -> dict[str, Any]:
21-
host_id: str = data.get("host_id", "")
22-
pill_name: str = data.get("pill_name", "Unknown")
23-
tokens: list[str] = data.get("tokens", [])
22+
payload = MedicationReminderPayload.model_validate(data)
2423

25-
if not tokens:
26-
logger.warning("medication_reminder_no_tokens", host_id=host_id)
24+
if not payload.tokens:
25+
logger.warning("medication_reminder_no_tokens", host_id=payload.host_id)
2726
return {"sent_count": 0, "skipped": True}
2827

2928
await dispatch_local(
3029
"notification.send",
3130
{
32-
"tokens": tokens,
31+
"tokens": payload.tokens,
3332
"title": "Medication Reminder",
34-
"body": f"Time to take {pill_name}",
35-
"data": {"host_id": host_id, "type": "medication_reminder"},
33+
"body": f"Time to take {payload.pill_name}",
34+
"data": {"host_id": payload.host_id, "type": "medication_reminder"},
3635
},
3736
)
3837

3938
logger.info(
4039
"medication_reminder_dispatched",
41-
host_id=host_id,
42-
pill_name=pill_name,
40+
host_id=payload.host_id,
41+
pill_name=payload.pill_name,
4342
)
44-
return {"dispatched": True, "pill_name": pill_name}
43+
return {"dispatched": True, "pill_name": payload.pill_name}
4544

4645

4746
register_job(MedicationReminderJob())

apps/worker/src/jobs/notification_send.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import structlog
88

99
from src.jobs.base import BaseJob, register_job
10+
from src.jobs.schemas import NotificationSendPayload
1011
from src.lib.config import settings
1112
from src.lib.retry import with_retry
1213

@@ -35,10 +36,11 @@ def job_type(self) -> str:
3536
return "notification.send"
3637

3738
async def execute(self, data: dict[str, Any]) -> dict[str, Any]:
38-
tokens: list[str] = data.get("tokens", [])
39-
title: str = data.get("title", "")
40-
body: str = data.get("body", "")
41-
extra: dict[str, str] = data.get("data", {})
39+
payload = NotificationSendPayload.model_validate(data)
40+
tokens = payload.tokens
41+
title = payload.title
42+
body = payload.body
43+
extra = payload.data
4244

4345
if not tokens:
4446
logger.warning("notification_send_no_tokens")
@@ -57,7 +59,7 @@ async def execute(self, data: dict[str, Any]) -> dict[str, Any]:
5759
message = messaging.MulticastMessage(
5860
tokens=tokens,
5961
notification=notification,
60-
data={k: str(v) for k, v in extra.items()},
62+
data=extra,
6163
)
6264
response: messaging.BatchResponse = await asyncio.wait_for(
6365
asyncio.to_thread(messaging.send_each_for_multicast, message),

apps/worker/src/jobs/relation_inactive.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import structlog
77

88
from src.jobs.base import BaseJob, register_job
9+
from src.jobs.schemas import RelationInactiveCheckPayload
910
from src.lib.config import settings
1011
from src.lib.retry import with_retry
1112

@@ -32,24 +33,26 @@ async def _call_api(self, headers: dict[str, str], params: dict[str, Any]) -> in
3233
return response.status_code
3334

3435
async def execute(self, data: dict[str, Any]) -> dict[str, Any]:
35-
threshold_days: int = data.get("threshold_days", 30)
36+
payload = RelationInactiveCheckPayload.model_validate(data)
3637

3738
logger.info(
3839
"relation_inactive_check_start",
39-
threshold_days=threshold_days,
40+
threshold_days=payload.threshold_days,
4041
)
4142

4243
headers: dict[str, str] = {}
4344
if settings.INTERNAL_API_KEY:
4445
headers["X-Internal-Key"] = settings.INTERNAL_API_KEY
4546

46-
status_code = await self._call_api(headers, {"threshold_days": threshold_days})
47+
status_code = await self._call_api(
48+
headers, {"threshold_days": payload.threshold_days}
49+
)
4750

4851
logger.info(
4952
"relation_inactive_check_complete",
5053
api_status=status_code,
5154
)
52-
return {"checked": True, "threshold_days": threshold_days}
55+
return {"checked": True, "threshold_days": payload.threshold_days}
5356

5457

5558
register_job(RelationInactiveCheckJob())

apps/worker/src/jobs/schemas.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Pydantic payload schemas for worker jobs."""
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class NotificationSendPayload(BaseModel):
7+
"""Payload for notification.send job."""
8+
9+
tokens: list[str]
10+
title: str = ""
11+
body: str = ""
12+
data: dict[str, str] = Field(default_factory=dict)
13+
14+
15+
class WellnessEscalationPayload(BaseModel):
16+
"""Payload for wellness.escalation job."""
17+
18+
log_id: str = ""
19+
host_id: str = ""
20+
status: str = ""
21+
summary: str = ""
22+
contact_tokens: list[str] = Field(default_factory=list)
23+
24+
25+
class MedicationReminderPayload(BaseModel):
26+
"""Payload for medication.reminder job."""
27+
28+
host_id: str = ""
29+
pill_name: str = "Unknown"
30+
tokens: list[str] = Field(default_factory=list)
31+
32+
33+
class DataCleanupPayload(BaseModel):
34+
"""Payload for data.cleanup job."""
35+
36+
retention_days: int = 90
37+
resource_type: str = "all"
38+
39+
40+
class RelationInactiveCheckPayload(BaseModel):
41+
"""Payload for relation.inactive_check job."""
42+
43+
threshold_days: int = 30
44+
45+
46+
class WellnessAggregatePayload(BaseModel):
47+
"""Payload for wellness.aggregate job."""
48+
49+
host_id: str = Field(min_length=1)
50+
date: str = Field(min_length=1)

apps/worker/src/jobs/wellness_aggregate.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import structlog
77

88
from src.jobs.base import BaseJob, register_job
9+
from src.jobs.schemas import WellnessAggregatePayload
910
from src.lib.config import settings
1011
from src.lib.retry import with_retry
1112

@@ -32,32 +33,33 @@ async def _call_api(self, headers: dict[str, str], params: dict[str, Any]) -> in
3233
return response.status_code
3334

3435
async def execute(self, data: dict[str, Any]) -> dict[str, Any]:
35-
host_id: str = data.get("host_id", "")
36-
date: str = data.get("date", "")
37-
38-
if not host_id or not date:
39-
msg = "host_id and date are required"
40-
raise ValueError(msg)
36+
payload = WellnessAggregatePayload.model_validate(data)
4137

4238
logger.info(
4339
"wellness_aggregate_start",
44-
host_id=host_id,
45-
date=date,
40+
host_id=payload.host_id,
41+
date=payload.date,
4642
)
4743

4844
headers: dict[str, str] = {}
4945
if settings.INTERNAL_API_KEY:
5046
headers["X-Internal-Key"] = settings.INTERNAL_API_KEY
5147

52-
status_code = await self._call_api(headers, {"host_id": host_id, "date": date})
48+
status_code = await self._call_api(
49+
headers, {"host_id": payload.host_id, "date": payload.date}
50+
)
5351

5452
logger.info(
5553
"wellness_aggregate_complete",
56-
host_id=host_id,
57-
date=date,
54+
host_id=payload.host_id,
55+
date=payload.date,
5856
api_status=status_code,
5957
)
60-
return {"host_id": host_id, "date": date, "aggregated": True}
58+
return {
59+
"host_id": payload.host_id,
60+
"date": payload.date,
61+
"aggregated": True,
62+
}
6163

6264

6365
register_job(WellnessAggregateJob())

apps/worker/src/jobs/wellness_escalation.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import structlog
66

77
from src.jobs.base import BaseJob, register_job
8+
from src.jobs.schemas import WellnessEscalationPayload
89
from src.lib.dispatch import dispatch_local
910

1011
logger = structlog.get_logger(__name__)
@@ -18,43 +19,39 @@ def job_type(self) -> str:
1819
return "wellness.escalation"
1920

2021
async def execute(self, data: dict[str, Any]) -> dict[str, Any]:
21-
log_id: str = data.get("log_id", "")
22-
host_id: str = data.get("host_id", "")
23-
status: str = data.get("status", "")
24-
summary: str = data.get("summary", "")
25-
contact_tokens: list[str] = data.get("contact_tokens", [])
22+
payload = WellnessEscalationPayload.model_validate(data)
2623

27-
if not contact_tokens:
24+
if not payload.contact_tokens:
2825
logger.warning(
2926
"wellness_escalation_no_contacts",
30-
log_id=log_id,
31-
host_id=host_id,
27+
log_id=payload.log_id,
28+
host_id=payload.host_id,
3229
)
3330
return {"escalated": False, "reason": "no_contacts"}
3431

3532
await dispatch_local(
3633
"notification.send",
3734
{
38-
"tokens": contact_tokens,
39-
"title": f"URGENT: Wellness {status.upper()}",
40-
"body": summary or "Immediate attention required",
35+
"tokens": payload.contact_tokens,
36+
"title": f"URGENT: Wellness {payload.status.upper()}",
37+
"body": payload.summary or "Immediate attention required",
4138
"data": {
42-
"log_id": log_id,
43-
"host_id": host_id,
39+
"log_id": payload.log_id,
40+
"host_id": payload.host_id,
4441
"type": "wellness_escalation",
4542
},
4643
},
4744
)
4845

4946
logger.info(
5047
"wellness_escalation_dispatched",
51-
log_id=log_id,
52-
host_id=host_id,
53-
contact_count=len(contact_tokens),
48+
log_id=payload.log_id,
49+
host_id=payload.host_id,
50+
contact_count=len(payload.contact_tokens),
5451
)
5552
return {
5653
"escalated": True,
57-
"contact_count": len(contact_tokens),
54+
"contact_count": len(payload.contact_tokens),
5855
}
5956

6057

apps/worker/src/routers/pubsub.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import structlog
88
from fastapi import APIRouter, HTTPException
9-
from pydantic import BaseModel
9+
from pydantic import BaseModel, ValidationError
1010

1111
from src.jobs.base import get_job
1212
from src.lib.idempotency import release_claim, try_claim
@@ -54,6 +54,14 @@ async def handle_pubsub_push(envelope: PubSubEnvelope) -> dict[str, Any]:
5454
logger.info("pubsub_job_execute_start", job_type=task_type)
5555
try:
5656
result = await job.execute(data)
57+
except (ValueError, ValidationError) as exc:
58+
release_claim(task_type, data, idempotency_key=idempotency_key)
59+
logger.warning(
60+
"pubsub_job_payload_invalid",
61+
job_type=task_type,
62+
error=str(exc),
63+
)
64+
raise HTTPException(status_code=422, detail=str(exc)) from exc
5765
except Exception:
5866
release_claim(task_type, data, idempotency_key=idempotency_key)
5967
raise

apps/worker/src/routers/tasks.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import structlog
66
from fastapi import APIRouter, HTTPException
77
from opentelemetry import trace
8-
from pydantic import BaseModel
8+
from pydantic import BaseModel, ValidationError
99

1010
from src.jobs.base import get_job, list_jobs
1111
from src.lib.idempotency import release_claim, try_claim
@@ -43,6 +43,14 @@ async def process_task(payload: TaskPayload) -> dict[str, Any]:
4343
attributes={"job.type": payload.task_type},
4444
):
4545
result = await job.execute(payload.data)
46+
except (ValueError, ValidationError) as exc:
47+
release_claim(payload.task_type, payload.data)
48+
logger.warning(
49+
"job_payload_invalid",
50+
job_type=payload.task_type,
51+
error=str(exc),
52+
)
53+
raise HTTPException(status_code=422, detail=str(exc)) from exc
4654
except Exception:
4755
release_claim(payload.task_type, payload.data)
4856
raise

0 commit comments

Comments
 (0)