Skip to content

Commit 0ac9c92

Browse files
authored
feat: supporting s2s user context (#1592)
1 parent 8dd6737 commit 0ac9c92

File tree

3 files changed

+108
-0
lines changed

3 files changed

+108
-0
lines changed

api/src/middleware/request_context.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
import base64
2+
import hashlib
3+
import hmac
4+
import json
15
import logging
6+
import re
27
from contextvars import ContextVar
8+
from typing import Optional
39

410
import requests
511
from google.auth import jwt
@@ -55,6 +61,68 @@ def decode_jwt(self, token: str):
5561
logging.error("Error decoding JWT: %s", e)
5662
return None
5763

64+
def decode_user_context_jwt(self, token: str):
65+
"""Decode and verify the custom user-context JWT sent by the web app.
66+
67+
This token is signed with HS256 using a shared secret (S2S_JWT_SECRET).
68+
If verification fails for any reason, None is returned and the request
69+
falls back to the existing IAP / Authorization-based identity handling.
70+
"""
71+
try:
72+
secret = get_config("S2S_JWT_SECRET")
73+
if not secret or len(secret) < 32:
74+
# Misconfiguration: do not fail the request, just skip user-context.
75+
logging.error(
76+
"S2S_JWT_SECRET is missing or too short; " "cannot verify x-mdb-user-context token.",
77+
)
78+
return None
79+
80+
token = token.replace("Bearer ", "")
81+
parts = token.split(".")
82+
if len(parts) != 3:
83+
return None
84+
85+
header_b64, payload_b64, signature_b64 = parts
86+
signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
87+
88+
expected_sig = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest()
89+
90+
# JWT uses URL-safe base64 without padding
91+
def b64url_decode(value: str) -> bytes:
92+
padding = "=" * (-len(value) % 4)
93+
return base64.urlsafe_b64decode(value + padding)
94+
95+
actual_sig = b64url_decode(signature_b64)
96+
if not hmac.compare_digest(expected_sig, actual_sig):
97+
logging.warning("Invalid signature for x-mdb-user-context token")
98+
return None
99+
100+
payload_json = b64url_decode(payload_b64).decode("utf-8")
101+
payload = json.loads(payload_json)
102+
# Minimal shape we care about: { uid, email?, isGuest? }
103+
if not isinstance(payload, dict) or "uid" not in payload:
104+
return None
105+
return payload
106+
except Exception as e: # pragma: no cover - defensive
107+
logging.error("Error decoding user-context JWT: %s", e)
108+
return None
109+
110+
@staticmethod
111+
def extract_user_id(raw_user_id: Optional[str]) -> Optional[str]:
112+
"""
113+
Extracts the user ID from the raw user ID string.
114+
- If there is a colon, return the substring after the last colon.
115+
- If there is no colon, return the original raw_user_id.
116+
- If raw_user_id is None, return None.
117+
"""
118+
if raw_user_id is None:
119+
return None
120+
121+
match = re.search(r":([^:]+)$", raw_user_id)
122+
if match:
123+
return match.group(1)
124+
return raw_user_id
125+
58126
def _extract_from_headers(self, headers: dict, scope: Scope) -> None:
59127
self.host = headers.get("host")
60128
self.protocol = headers.get("x-forwarded-proto") if headers.get("x-forwarded-proto") else scope.get("scheme")
@@ -87,13 +155,29 @@ def _extract_from_headers(self, headers: dict, scope: Scope) -> None:
87155
# auth header is used for local development
88156
self.user_id = headers.get("x-goog-authenticated-user-id")
89157
self.user_email = headers.get("x-goog-authenticated-user-email")
158+
self.is_guest = False
90159
self.google_public_keys = None
91160
if not self.iap_jwt_assertion and headers.get("authorization"):
92161
self.iap_jwt_assertion = self.decode_jwt(headers.get("authorization"))
93162
if self.iap_jwt_assertion:
94163
self.user_id = self.iap_jwt_assertion.get("user_id")
95164
self.user_email = self.iap_jwt_assertion.get("email")
96165

166+
# Optional user-context header set by the web app for server-to-server calls.
167+
# Name is aligned with the frontend's USER_CONTEXT_HEADER.
168+
user_context_header = headers.get("x-mdb-user-context") or headers.get("md-user-context")
169+
if user_context_header:
170+
user_context = self.decode_user_context_jwt(user_context_header)
171+
if user_context:
172+
# Prefer values from the verified user-context token when present.
173+
self.user_id = user_context.get("uid", self.user_id)
174+
self.user_email = user_context.get("email", self.user_email)
175+
self.is_guest = bool(user_context.get("isGuest"))
176+
# if the user_id is in the format "accounts.google.com:1234567890",
177+
# extract just the numeric ID part for consistency with legacy IAP user_id format
178+
if self.user_id:
179+
self.user_id = RequestContext.extract_user_id(self.user_id)
180+
97181
def __repr__(self) -> str:
98182
# Omitting sensitive data like email and jwt assertion
99183
safe_properties = dict(

api/tests/unittest/middleware/test_request_context.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
from unittest.mock import MagicMock
33

4+
import pytest
45
from starlette.datastructures import Headers
56

67
from middleware.request_context import RequestContext, get_request_context, _request_context
@@ -35,6 +36,7 @@ def test_init_extract_headers(self):
3536
"client_host": "client",
3637
"client_user_agent": "user-agent",
3738
"google_public_keys": None,
39+
"is_guest": False,
3840
"headers": Headers(scope=scope_instance),
3941
"host": "localhost",
4042
"iap_jwt_assertion": "jwt",
@@ -54,3 +56,16 @@ def test_get_request_context(self):
5456
request_context = RequestContext(MagicMock())
5557
_request_context.set(request_context)
5658
self.assertEqual(request_context, get_request_context())
59+
60+
61+
@pytest.mark.parametrize(
62+
"raw_user_id, expected",
63+
[
64+
(None, None),
65+
("plainuserid", "plainuserid"),
66+
("accounts.google.com:1234567890", "1234567890"),
67+
("prefix:middle:finalpart", "finalpart"),
68+
],
69+
)
70+
def test_extract_user_id_parametrized(raw_user_id, expected):
71+
assert RequestContext.extract_user_id(raw_user_id) == expected

infra/feed-api/main.tf

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ resource "google_cloud_run_v2_service" "mobility-feed-api" {
9595
name = "PROJECT_ID"
9696
value = data.google_project.project.project_id
9797
}
98+
env {
99+
name = "S2S_JWT_SECRET"
100+
value_source {
101+
secret_key_ref {
102+
secret = "${upper(var.environment)}_S2S_JWT_SECRET"
103+
version = "latest"
104+
}
105+
}
106+
}
98107
resources {
99108
limits = {
100109
cpu = "1"

0 commit comments

Comments
 (0)