|
4 | 4 |
|
5 | 5 | import asyncio |
6 | 6 | import json |
| 7 | +import time |
7 | 8 | from collections.abc import AsyncIterator |
8 | 9 | from datetime import UTC, datetime |
9 | 10 | from typing import Annotated, Any, Final, Optional, cast |
|
39 | 40 | from configuration import configuration |
40 | 41 | from constants import SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER |
41 | 42 | from log import get_logger |
| 43 | +from metrics import recording |
42 | 44 | from models.config import Action |
43 | 45 | from models.requests import ResponsesRequest |
44 | 46 | from models.responses import ( |
@@ -136,6 +138,37 @@ def _get_user_agent(request: Request) -> Optional[str]: |
136 | 138 | return sanitized or None |
137 | 139 |
|
138 | 140 |
|
def _check_response_quota(user_id: str, endpoint_path: str) -> None:
    """Run the quota check for a user and record one quota metric for the outcome.

    Calls ``check_tokens_available`` with the configured quota limiters and
    reports exactly one quota-check sample via ``recording.record_quota_check``,
    labelled with the endpoint path, the user-id quota type, the outcome, and
    the elapsed time measured with ``time.monotonic``.

    Args:
        user_id: Identifier of the user whose quota is being checked.
        endpoint_path: Endpoint path passed through to the recorded metric.

    Raises:
        HTTPException: Re-raised unchanged after a ``QUOTA_RESULT_FAILURE``
            sample has been recorded.
        Exception: Any other exception from the quota check is re-raised
            unchanged after a ``QUOTA_RESULT_ERROR`` sample has been recorded.
    """
    began_at = time.monotonic()

    def _emit(outcome: str) -> None:
        # Elapsed time is computed at emit time so every outcome reports the
        # full duration of the quota check.
        elapsed = time.monotonic() - began_at
        recording.record_quota_check(
            endpoint_path,
            recording.QUOTA_TYPE_USER_ID,
            outcome,
            elapsed,
        )

    try:
        check_tokens_available(configuration.quota_limiters, user_id)
    except HTTPException:
        _emit(recording.QUOTA_RESULT_FAILURE)
        raise
    except Exception:  # pylint: disable=broad-exception-caught
        # Unexpected quota backend failures are still recorded before they
        # propagate to the endpoint error handling layer.
        _emit(recording.QUOTA_RESULT_ERROR)
        raise
    else:
        _emit(recording.QUOTA_RESULT_SUCCESS)
| 170 | + |
| 171 | + |
139 | 172 | responses_response: dict[int | str, dict[str, Any]] = { |
140 | 173 | 200: ResponsesResponse.openapi_response(), |
141 | 174 | 401: UnauthorizedResponse.openapi_response( |
@@ -275,11 +308,12 @@ async def responses_endpoint_handler( |
275 | 308 | started_at = datetime.now(UTC) |
276 | 309 | rh_identity_context = get_rh_identity_context(request) |
277 | 310 | user_id, _, _, token = auth |
| 311 | + endpoint_path = "/v1/responses" |
278 | 312 |
|
279 | 313 | await check_mcp_auth(configuration, mcp_headers, token, request.headers) |
280 | 314 |
|
281 | 315 | # Check token availability |
282 | | - check_tokens_available(configuration.quota_limiters, user_id) |
| 316 | + _check_response_quota(user_id, endpoint_path) |
283 | 317 |
|
284 | 318 | # Enforce RBAC: optionally disallow overriding model in requests |
285 | 319 | validate_model_provider_override( |
@@ -331,7 +365,6 @@ async def responses_endpoint_handler( |
331 | 365 | ) |
332 | 366 | attachments_text = extract_attachments_text(original_request.input) |
333 | 367 |
|
334 | | - endpoint_path = "/v1/responses" |
335 | 368 | moderation_result = await run_shield_moderation( |
336 | 369 | client, |
337 | 370 | input_text + "\n\n" + attachments_text, |
|
0 commit comments