|
| 1 | +import base64 |
| 2 | +import hashlib |
| 3 | +import hmac |
| 4 | +import json |
1 | 5 | import logging |
| 6 | +import re |
2 | 7 | from contextvars import ContextVar |
| 8 | +from typing import Optional |
3 | 9 |
|
4 | 10 | import requests |
5 | 11 | from google.auth import jwt |
@@ -55,6 +61,68 @@ def decode_jwt(self, token: str): |
55 | 61 | logging.error("Error decoding JWT: %s", e) |
56 | 62 | return None |
57 | 63 |
|
| 64 | + def decode_user_context_jwt(self, token: str): |
| 65 | + """Decode and verify the custom user-context JWT sent by the web app. |
| 66 | +
|
| 67 | + This token is signed with HS256 using a shared secret (S2S_JWT_SECRET). |
| 68 | + If verification fails for any reason, None is returned and the request |
| 69 | + falls back to the existing IAP / Authorization-based identity handling. |
| 70 | + """ |
| 71 | + try: |
| 72 | + secret = get_config("S2S_JWT_SECRET") |
| 73 | + if not secret or len(secret) < 32: |
| 74 | + # Misconfiguration: do not fail the request, just skip user-context. |
| 75 | + logging.error( |
| 76 | + "S2S_JWT_SECRET is missing or too short; " "cannot verify x-mdb-user-context token.", |
| 77 | + ) |
| 78 | + return None |
| 79 | + |
| 80 | + token = token.replace("Bearer ", "") |
| 81 | + parts = token.split(".") |
| 82 | + if len(parts) != 3: |
| 83 | + return None |
| 84 | + |
| 85 | + header_b64, payload_b64, signature_b64 = parts |
| 86 | + signing_input = f"{header_b64}.{payload_b64}".encode("ascii") |
| 87 | + |
| 88 | + expected_sig = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest() |
| 89 | + |
| 90 | + # JWT uses URL-safe base64 without padding |
| 91 | + def b64url_decode(value: str) -> bytes: |
| 92 | + padding = "=" * (-len(value) % 4) |
| 93 | + return base64.urlsafe_b64decode(value + padding) |
| 94 | + |
| 95 | + actual_sig = b64url_decode(signature_b64) |
| 96 | + if not hmac.compare_digest(expected_sig, actual_sig): |
| 97 | + logging.warning("Invalid signature for x-mdb-user-context token") |
| 98 | + return None |
| 99 | + |
| 100 | + payload_json = b64url_decode(payload_b64).decode("utf-8") |
| 101 | + payload = json.loads(payload_json) |
| 102 | + # Minimal shape we care about: { uid, email?, isGuest? } |
| 103 | + if not isinstance(payload, dict) or "uid" not in payload: |
| 104 | + return None |
| 105 | + return payload |
| 106 | + except Exception as e: # pragma: no cover - defensive |
| 107 | + logging.error("Error decoding user-context JWT: %s", e) |
| 108 | + return None |
| 109 | + |
| 110 | + @staticmethod |
| 111 | + def extract_user_id(raw_user_id: Optional[str]) -> Optional[str]: |
| 112 | + """ |
| 113 | + Extracts the user ID from the raw user ID string. |
| 114 | + - If there is a colon, return the substring after the last colon. |
| 115 | + - If there is no colon, return the original raw_user_id. |
| 116 | + - If raw_user_id is None, return None. |
| 117 | + """ |
| 118 | + if raw_user_id is None: |
| 119 | + return None |
| 120 | + |
| 121 | + match = re.search(r":([^:]+)$", raw_user_id) |
| 122 | + if match: |
| 123 | + return match.group(1) |
| 124 | + return raw_user_id |
| 125 | + |
58 | 126 | def _extract_from_headers(self, headers: dict, scope: Scope) -> None: |
59 | 127 | self.host = headers.get("host") |
60 | 128 | self.protocol = headers.get("x-forwarded-proto") if headers.get("x-forwarded-proto") else scope.get("scheme") |
@@ -87,13 +155,29 @@ def _extract_from_headers(self, headers: dict, scope: Scope) -> None: |
87 | 155 | # auth header is used for local development |
88 | 156 | self.user_id = headers.get("x-goog-authenticated-user-id") |
89 | 157 | self.user_email = headers.get("x-goog-authenticated-user-email") |
| 158 | + self.is_guest = False |
90 | 159 | self.google_public_keys = None |
91 | 160 | if not self.iap_jwt_assertion and headers.get("authorization"): |
92 | 161 | self.iap_jwt_assertion = self.decode_jwt(headers.get("authorization")) |
93 | 162 | if self.iap_jwt_assertion: |
94 | 163 | self.user_id = self.iap_jwt_assertion.get("user_id") |
95 | 164 | self.user_email = self.iap_jwt_assertion.get("email") |
96 | 165 |
|
| 166 | + # Optional user-context header set by the web app for server-to-server calls. |
| 167 | + # Name is aligned with the frontend's USER_CONTEXT_HEADER. |
| 168 | + user_context_header = headers.get("x-mdb-user-context") or headers.get("md-user-context") |
| 169 | + if user_context_header: |
| 170 | + user_context = self.decode_user_context_jwt(user_context_header) |
| 171 | + if user_context: |
| 172 | + # Prefer values from the verified user-context token when present. |
| 173 | + self.user_id = user_context.get("uid", self.user_id) |
| 174 | + self.user_email = user_context.get("email", self.user_email) |
| 175 | + self.is_guest = bool(user_context.get("isGuest")) |
| 176 | + # if the user_id is in the format "accounts.google.com:1234567890", |
| 177 | + # extract just the numeric ID part for consistency with legacy IAP user_id format |
| 178 | + if self.user_id: |
| 179 | + self.user_id = RequestContext.extract_user_id(self.user_id) |
| 180 | + |
97 | 181 | def __repr__(self) -> str: |
98 | 182 | # Omitting sensitive data like email and jwt assertion |
99 | 183 | safe_properties = dict( |
|
0 commit comments