Skip to content

Commit 53402b9

Browse files
feat: add Dependency injection feature
1 parent 8b42829 commit 53402b9

9 files changed

Lines changed: 569 additions & 2 deletions

File tree

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ def __init__(
472472

473473
self.custom_response_validation_http_code = custom_response_validation_http_code
474474

475+
# Cache whether this route's handler declares Depends() parameters
476+
self._has_dependencies: bool | None = None
477+
475478
# Caches the name of any Request-typed parameter in the handler.
476479
# Avoids re-scanning the signature on every invocation.
477480
self.request_param_name: str | None = None
@@ -613,6 +616,21 @@ def dependant(self) -> Dependant:
613616

614617
return self._dependant
615618

619+
@property
620+
def has_dependencies(self) -> bool:
621+
"""Check if handler declares Depends() parameters without triggering full dependant computation."""
622+
if self._has_dependencies is None:
623+
from aws_lambda_powertools.event_handler.openapi.dependant import (
624+
_get_depends_from_annotation,
625+
get_typed_signature,
626+
)
627+
628+
sig = get_typed_signature(self.func)
629+
self._has_dependencies = any(
630+
_get_depends_from_annotation(p.annotation) is not None for p in sig.parameters.values()
631+
)
632+
return self._has_dependencies
633+
616634
@property
617635
def body_field(self) -> ModelField | None:
618636
if self._body_field is None:
@@ -1428,6 +1446,17 @@ def _registered_api_adapter(
14281446
if route.request_param_name:
14291447
route_args = {**route_args, route.request_param_name: app.request}
14301448

1449+
# Resolve Depends() parameters
1450+
if route.has_dependencies:
1451+
from aws_lambda_powertools.event_handler.openapi.dependant import solve_dependencies
1452+
1453+
dep_values = solve_dependencies(
1454+
dependant=route.dependant,
1455+
request=app.request,
1456+
dependency_overrides=app.dependency_overrides or None,
1457+
)
1458+
route_args.update(dep_values)
1459+
14311460
return app._to_response(next_middleware(**route_args))
14321461

14331462

@@ -1496,6 +1525,7 @@ def __init__(
14961525
by default json.loads when integrating with EventSource data class
14971526
"""
14981527
self._proxy_type = proxy_type
1528+
self.dependency_overrides: dict[Callable, Callable] = {}
14991529
self._dynamic_routes: list[Route] = []
15001530
self._static_routes: list[Route] = []
15011531
self._route_keys: list[str] = []

aws_lambda_powertools/event_handler/openapi/dependant.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import inspect
44
import re
5-
from typing import TYPE_CHECKING, Any, ForwardRef, cast
5+
from typing import TYPE_CHECKING, Any, ForwardRef, cast, get_type_hints
6+
7+
from typing_extensions import Annotated, get_args, get_origin
68

79
from aws_lambda_powertools.event_handler.openapi.compat import (
810
ModelField,
@@ -13,6 +15,8 @@
1315
from aws_lambda_powertools.event_handler.openapi.params import (
1416
Body,
1517
Dependant,
18+
DependencyParam,
19+
Depends,
1620
File,
1721
Form,
1822
Param,
@@ -149,6 +153,15 @@ def get_path_param_names(path: str) -> set[str]:
149153
return set(re.findall("{(.*?)}", path))
150154

151155

156+
def _get_depends_from_annotation(annotation: Any) -> Depends | None:
157+
"""Extract a Depends instance from an Annotated[Type, Depends(...)] annotation."""
158+
if get_origin(annotation) is Annotated:
159+
for arg in get_args(annotation)[1:]:
160+
if isinstance(arg, Depends):
161+
return arg
162+
return None
163+
164+
152165
def get_dependant(
153166
*,
154167
path: str,
@@ -193,6 +206,22 @@ def get_dependant(
193206
if param.annotation is Request:
194207
continue
195208

209+
# Depends() parameters (via Annotated[Type, Depends(fn)]) are resolved at call time.
210+
depends_instance = _get_depends_from_annotation(param.annotation)
211+
if depends_instance is not None:
212+
sub_dependant = get_dependant(
213+
path=path,
214+
call=depends_instance.dependency,
215+
)
216+
dependant.dependencies.append(
217+
DependencyParam(
218+
param_name=param_name,
219+
depends=depends_instance,
220+
dependant=sub_dependant,
221+
),
222+
)
223+
continue
224+
196225
# If the parameter is a path parameter, we need to set the in_ field to "path".
197226
is_path_param = param_name in path_param_names
198227

@@ -386,3 +415,75 @@ def get_body_field_info(
386415
body_field_info_kwargs["media_type"] = body_param_media_types[0]
387416

388417
return body_field_info, body_field_info_kwargs
418+
419+
420+
def solve_dependencies(
421+
*,
422+
dependant: Dependant,
423+
request: Request | None = None,
424+
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
425+
dependency_cache: dict[Callable[..., Any], Any] | None = None,
426+
) -> dict[str, Any]:
427+
"""
428+
Recursively resolve all ``Depends()`` parameters for a given dependant.
429+
430+
Parameters
431+
----------
432+
dependant: Dependant
433+
The dependant model containing dependency declarations
434+
request: Request, optional
435+
The current request object, injected into dependencies that declare a Request parameter
436+
dependency_overrides: dict, optional
437+
Mapping of original dependency callable to override callable (for testing)
438+
dependency_cache: dict, optional
439+
Per-invocation cache of resolved dependency values
440+
441+
Returns
442+
-------
443+
dict[str, Any]
444+
Mapping of parameter name to resolved dependency value
445+
"""
446+
if dependency_cache is None:
447+
dependency_cache = {}
448+
449+
values: dict[str, Any] = {}
450+
451+
for dep in dependant.dependencies:
452+
use_fn = dep.depends.dependency
453+
454+
# Apply overrides (for testing)
455+
if dependency_overrides and use_fn in dependency_overrides:
456+
use_fn = dependency_overrides[use_fn]
457+
458+
# Check cache
459+
if dep.depends.use_cache and use_fn in dependency_cache:
460+
values[dep.param_name] = dependency_cache[use_fn]
461+
continue
462+
463+
# Recursively resolve sub-dependencies
464+
sub_values = solve_dependencies(
465+
dependant=dep.dependant,
466+
request=request,
467+
dependency_overrides=dependency_overrides,
468+
dependency_cache=dependency_cache,
469+
)
470+
471+
# Inject Request if the dependency declares it
472+
if request is not None:
473+
try:
474+
hints = get_type_hints(use_fn)
475+
except Exception:
476+
hints = {}
477+
for param_name, annotation in hints.items():
478+
if annotation is Request:
479+
sub_values[param_name] = request
480+
481+
solved = use_fn(**sub_values)
482+
483+
# Cache result
484+
if dep.depends.use_cache:
485+
dependency_cache[use_fn] = solved
486+
487+
values[dep.param_name] = solved
488+
489+
return values

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,47 @@ class ParamTypes(Enum):
4242
_Unset: Any = Undefined
4343

4444

45+
class Depends:
46+
"""
47+
Declares a dependency for a route handler parameter.
48+
49+
Dependencies are resolved automatically before the handler is called. The return value
50+
of the dependency callable is injected as the parameter value.
51+
52+
Parameters
53+
----------
54+
dependency: Callable[..., Any]
55+
A callable whose return value will be injected into the handler parameter.
56+
The callable can itself declare ``Depends()`` parameters to form a dependency tree.
57+
use_cache: bool
58+
If ``True`` (default), the dependency result is cached per invocation so that
59+
the same dependency used multiple times is only called once.
60+
61+
Examples
62+
--------
63+
64+
```python
65+
from typing_extensions import Annotated
66+
67+
from aws_lambda_powertools.event_handler import APIGatewayHttpResolver
68+
from aws_lambda_powertools.event_handler.openapi.params import Depends
69+
70+
app = APIGatewayHttpResolver()
71+
72+
def get_tenant() -> str:
73+
return "default-tenant"
74+
75+
@app.get("/orders")
76+
def list_orders(tenant_id: Annotated[str, Depends(get_tenant)]):
77+
return {"tenant": tenant_id}
78+
```
79+
"""
80+
81+
def __init__(self, dependency: Callable[..., Any], *, use_cache: bool = True) -> None:
82+
self.dependency = dependency
83+
self.use_cache = use_cache
84+
85+
4586
class Dependant:
4687
"""
4788
A class used internally to represent a dependency between path operation decorators and the path operation function.
@@ -64,6 +105,7 @@ def __init__(
64105
http_connection_param_name: str | None = None,
65106
response_param_name: str | None = None,
66107
background_tasks_param_name: str | None = None,
108+
dependencies: list[DependencyParam] | None = None,
67109
path: str | None = None,
68110
) -> None:
69111
self.path_params = path_params or []
@@ -78,6 +120,7 @@ def __init__(
78120
self.http_connection_param_name = http_connection_param_name
79121
self.response_param_name = response_param_name
80122
self.background_tasks_param_name = background_tasks_param_name
123+
self.dependencies = dependencies or []
81124
self.name = name
82125
self.call = call
83126
# Store the path to be able to re-generate a dependable from it in overrides
@@ -86,6 +129,15 @@ def __init__(
86129
self.cache_key: CacheKey = self.call
87130

88131

132+
class DependencyParam:
133+
"""Holds a dependency's parameter name and its resolved Dependant sub-tree."""
134+
135+
def __init__(self, *, param_name: str, depends: Depends, dependant: Dependant) -> None:
136+
self.param_name = param_name
137+
self.depends = depends
138+
self.dependant = dependant
139+
140+
89141
class Param(FieldInfo): # type: ignore[misc]
90142
"""
91143
A class used internally to represent a parameter in a path operation.
@@ -816,7 +868,7 @@ def get_flat_dependant(
816868
visited = []
817869
visited.append(dependant.cache_key)
818870

819-
return Dependant(
871+
flat = Dependant(
820872
path_params=dependant.path_params.copy(),
821873
query_params=dependant.query_params.copy(),
822874
header_params=dependant.header_params.copy(),
@@ -825,6 +877,18 @@ def get_flat_dependant(
825877
path=dependant.path,
826878
)
827879

880+
# Flatten sub-dependencies that declare HTTP params (query, header, etc.)
881+
for dep in dependant.dependencies:
882+
if dep.dependant.cache_key not in visited:
883+
sub_flat = get_flat_dependant(dep.dependant, visited=visited)
884+
flat.path_params.extend(sub_flat.path_params)
885+
flat.query_params.extend(sub_flat.query_params)
886+
flat.header_params.extend(sub_flat.header_params)
887+
flat.cookie_params.extend(sub_flat.cookie_params)
888+
flat.body_params.extend(sub_flat.body_params)
889+
890+
return flat
891+
828892

829893
def analyze_param(
830894
*,

docs/core/event_handler/api_gateway.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,48 @@ You can use `append_context` when you want to share data between your App and Ro
13651365
--8<-- "examples/event_handler_rest/src/split_route_append_context_module.py"
13661366
```
13671367

1368+
### Dependency injection
1369+
1370+
You can use `Depends()` to declare dependencies that are automatically resolved and injected into your route handlers. This provides type-safe, composable, and testable dependency injection.
1371+
1372+
#### Basic usage
1373+
1374+
Use `Annotated[Type, Depends(fn)]` to declare a dependency. The return value of `fn` is injected into the parameter automatically.
1375+
1376+
```python hl_lines="5 8 20 25"
1377+
--8<-- "examples/event_handler_rest/src/dependency_injection.py"
1378+
```
1379+
1380+
#### Nested dependencies
1381+
1382+
Dependencies can depend on other dependencies, forming a composable tree. Shared sub-dependencies are resolved once per invocation and cached automatically.
1383+
1384+
```python hl_lines="18 22 29-30"
1385+
--8<-- "examples/event_handler_rest/src/dependency_injection_nested.py"
1386+
```
1387+
1388+
#### Accessing the request
1389+
1390+
Dependencies that need access to the current request can declare a parameter typed as `Request`. It will be injected automatically.
1391+
1392+
```python hl_lines="5-6 12 20"
1393+
--8<-- "examples/event_handler_rest/src/dependency_injection_with_request.py"
1394+
```
1395+
1396+
#### Testing with dependency overrides
1397+
1398+
Use `dependency_overrides` to replace any dependency with a mock or stub during testing - no monkeypatching needed.
1399+
1400+
```python hl_lines="3 12 26"
1401+
--8<-- "examples/event_handler_rest/src/dependency_injection_testing.py"
1402+
```
1403+
1404+
???+ tip "Caching behavior"
1405+
By default, dependencies are cached within the same invocation (`use_cache=True`). If the same dependency is used by multiple handlers or sub-dependencies, it is resolved once and the result is reused. Use `Depends(fn, use_cache=False)` to resolve every time.
1406+
1407+
???+ info "`append_context` vs `Depends()`"
1408+
`append_context` remains available for backward compatibility. `Depends()` is recommended for new code because it provides type safety, IDE autocomplete, composable dependency trees, and `dependency_overrides` for testing.
1409+
13681410
#### Sample layout
13691411

13701412
This is a sample project layout for a monolithic function with routes split in different files (`/todos`, `/health`).
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
from typing import Any
3+
4+
import boto3
5+
from typing_extensions import Annotated
6+
7+
from aws_lambda_powertools.event_handler import APIGatewayHttpResolver
8+
from aws_lambda_powertools.event_handler.openapi.params import Depends
9+
from aws_lambda_powertools.utilities.typing import LambdaContext
10+
11+
app = APIGatewayHttpResolver()
12+
13+
14+
def get_dynamodb_table():
15+
dynamodb = boto3.resource("dynamodb")
16+
return dynamodb.Table(os.environ["TABLE_NAME"])
17+
18+
19+
@app.get("/orders")
20+
def list_orders(table: Annotated[Any, Depends(get_dynamodb_table)]):
21+
return table.scan()["Items"]
22+
23+
24+
@app.post("/orders")
25+
def create_order(table: Annotated[Any, Depends(get_dynamodb_table)]):
26+
order = app.current_event.json_body
27+
table.put_item(Item=order)
28+
return {"message": "Order created"}
29+
30+
31+
def lambda_handler(event: dict, context: LambdaContext) -> dict:
32+
return app.resolve(event, context)

0 commit comments

Comments
 (0)