|
18 | 18 | ) |
19 | 19 | from fastapi.datastructures import Address |
20 | 20 | from fastapi.middleware.cors import CORSMiddleware |
| 21 | +from fastapi.requests import HTTPConnection |
21 | 22 | from fastapi.responses import RedirectResponse, StreamingResponse |
22 | 23 | from fastapi.security import OAuth2AuthorizationCodeBearer |
23 | 24 | from observability_utils.tracing import ( |
@@ -161,6 +162,11 @@ def inner(request: Request, access_token: str = Depends(oauth_scheme)): |
161 | 162 | return inner |
162 | 163 |
|
163 | 164 |
|
| 165 | +def _user(request: HTTPConnection) -> str | None: |
| 166 | + user = getattr(request.state, "decoded_access_token", {}) |
| 167 | + return user.get("fedid", None) |
| 168 | + |
| 169 | + |
164 | 170 | TRACER = get_tracer("interface") |
165 | 171 |
|
166 | 172 |
|
@@ -283,18 +289,11 @@ def submit_task( |
283 | 289 | response: Response, |
284 | 290 | task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], |
285 | 291 | runner: Annotated[WorkerDispatcher, Depends(_runner)], |
| 292 | + user: Annotated[str, Depends(_user)], |
286 | 293 | ) -> TaskResponse: |
287 | 294 | """Submit a task to the worker.""" |
288 | 295 | try: |
289 | | - # Extract user from jwt if using OIDC (if jwt exists) |
290 | | - access_token: dict[str, Any] | None = getattr( |
291 | | - request.state, "decoded_access_token", None |
292 | | - ) |
293 | | - if access_token: |
294 | | - user: str = access_token.get("fedid", "Unknown") |
295 | | - else: |
296 | | - user = "Unknown" |
297 | | - |
| 296 | + user = user or "UNKNOWN" |
298 | 297 | task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) |
299 | 298 | response.headers["Location"] = f"{request.url}/{task_id}" |
300 | 299 | return TaskResponse(task_id=task_id) |
|
0 commit comments