|
2 | 2 |
|
3 | 3 | import inspect |
4 | 4 | import re |
5 | | -from typing import TYPE_CHECKING, Any, ForwardRef, cast |
| 5 | +from typing import TYPE_CHECKING, Any, ForwardRef, cast, get_type_hints |
| 6 | + |
| 7 | +from typing_extensions import Annotated, get_args, get_origin |
6 | 8 |
|
7 | 9 | from aws_lambda_powertools.event_handler.openapi.compat import ( |
8 | 10 | ModelField, |
|
13 | 15 | from aws_lambda_powertools.event_handler.openapi.params import ( |
14 | 16 | Body, |
15 | 17 | Dependant, |
| 18 | + DependencyParam, |
| 19 | + Depends, |
16 | 20 | File, |
17 | 21 | Form, |
18 | 22 | Param, |
@@ -149,6 +153,15 @@ def get_path_param_names(path: str) -> set[str]: |
149 | 153 | return set(re.findall("{(.*?)}", path)) |
150 | 154 |
|
151 | 155 |
|
| 156 | +def _get_depends_from_annotation(annotation: Any) -> Depends | None: |
| 157 | + """Extract a Depends instance from an Annotated[Type, Depends(...)] annotation.""" |
| 158 | + if get_origin(annotation) is Annotated: |
| 159 | + for arg in get_args(annotation)[1:]: |
| 160 | + if isinstance(arg, Depends): |
| 161 | + return arg |
| 162 | + return None |
| 163 | + |
| 164 | + |
152 | 165 | def get_dependant( |
153 | 166 | *, |
154 | 167 | path: str, |
@@ -193,6 +206,22 @@ def get_dependant( |
193 | 206 | if param.annotation is Request: |
194 | 207 | continue |
195 | 208 |
|
| 209 | + # Depends() parameters (via Annotated[Type, Depends(fn)]) are resolved at call time. |
| 210 | + depends_instance = _get_depends_from_annotation(param.annotation) |
| 211 | + if depends_instance is not None: |
| 212 | + sub_dependant = get_dependant( |
| 213 | + path=path, |
| 214 | + call=depends_instance.dependency, |
| 215 | + ) |
| 216 | + dependant.dependencies.append( |
| 217 | + DependencyParam( |
| 218 | + param_name=param_name, |
| 219 | + depends=depends_instance, |
| 220 | + dependant=sub_dependant, |
| 221 | + ), |
| 222 | + ) |
| 223 | + continue |
| 224 | + |
196 | 225 | # If the parameter is a path parameter, we need to set the in_ field to "path". |
197 | 226 | is_path_param = param_name in path_param_names |
198 | 227 |
|
@@ -386,3 +415,75 @@ def get_body_field_info( |
386 | 415 | body_field_info_kwargs["media_type"] = body_param_media_types[0] |
387 | 416 |
|
388 | 417 | return body_field_info, body_field_info_kwargs |
| 418 | + |
| 419 | + |
| 420 | +def solve_dependencies( |
| 421 | + *, |
| 422 | + dependant: Dependant, |
| 423 | + request: Request | None = None, |
| 424 | + dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None, |
| 425 | + dependency_cache: dict[Callable[..., Any], Any] | None = None, |
| 426 | +) -> dict[str, Any]: |
| 427 | + """ |
| 428 | + Recursively resolve all ``Depends()`` parameters for a given dependant. |
| 429 | +
|
| 430 | + Parameters |
| 431 | + ---------- |
| 432 | + dependant: Dependant |
| 433 | + The dependant model containing dependency declarations |
| 434 | + request: Request, optional |
| 435 | + The current request object, injected into dependencies that declare a Request parameter |
| 436 | + dependency_overrides: dict, optional |
| 437 | + Mapping of original dependency callable to override callable (for testing) |
| 438 | + dependency_cache: dict, optional |
| 439 | + Per-invocation cache of resolved dependency values |
| 440 | +
|
| 441 | + Returns |
| 442 | + ------- |
| 443 | + dict[str, Any] |
| 444 | + Mapping of parameter name to resolved dependency value |
| 445 | + """ |
| 446 | + if dependency_cache is None: |
| 447 | + dependency_cache = {} |
| 448 | + |
| 449 | + values: dict[str, Any] = {} |
| 450 | + |
| 451 | + for dep in dependant.dependencies: |
| 452 | + use_fn = dep.depends.dependency |
| 453 | + |
| 454 | + # Apply overrides (for testing) |
| 455 | + if dependency_overrides and use_fn in dependency_overrides: |
| 456 | + use_fn = dependency_overrides[use_fn] |
| 457 | + |
| 458 | + # Check cache |
| 459 | + if dep.depends.use_cache and use_fn in dependency_cache: |
| 460 | + values[dep.param_name] = dependency_cache[use_fn] |
| 461 | + continue |
| 462 | + |
| 463 | + # Recursively resolve sub-dependencies |
| 464 | + sub_values = solve_dependencies( |
| 465 | + dependant=dep.dependant, |
| 466 | + request=request, |
| 467 | + dependency_overrides=dependency_overrides, |
| 468 | + dependency_cache=dependency_cache, |
| 469 | + ) |
| 470 | + |
| 471 | + # Inject Request if the dependency declares it |
| 472 | + if request is not None: |
| 473 | + try: |
| 474 | + hints = get_type_hints(use_fn) |
| 475 | + except Exception: |
| 476 | + hints = {} |
| 477 | + for param_name, annotation in hints.items(): |
| 478 | + if annotation is Request: |
| 479 | + sub_values[param_name] = request |
| 480 | + |
| 481 | + solved = use_fn(**sub_values) |
| 482 | + |
| 483 | + # Cache result |
| 484 | + if dep.depends.use_cache: |
| 485 | + dependency_cache[use_fn] = solved |
| 486 | + |
| 487 | + values[dep.param_name] = solved |
| 488 | + |
| 489 | + return values |
0 commit comments