|
2 | 2 | import json |
3 | 3 | import logging |
4 | 4 | from copy import deepcopy |
5 | | -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple |
| 5 | +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple |
6 | 6 |
|
7 | 7 | from pydantic import BaseModel |
8 | 8 |
|
@@ -55,6 +55,18 @@ def get_todos(): List[Todo]: |
55 | 55 | ``` |
56 | 56 | """ |
57 | 57 |
|
| 58 | + def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None): |
| 59 | + """ |
| 60 | + Initialize the OpenAPIValidationMiddleware. |
| 61 | +
|
| 62 | + Parameters |
| 63 | + ---------- |
| 64 | + validation_serializer : Callable, optional |
| 65 | + Optional serializer to use when serializing the response for validation. |
| 66 | + Use it when you have a custom type that cannot be serialized by the default jsonable_encoder. |
| 67 | + """ |
| 68 | + self._validation_serializer = validation_serializer |
| 69 | + |
58 | 70 | def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: |
59 | 71 | logger.debug("OpenAPIValidationMiddleware handler") |
60 | 72 |
|
@@ -181,10 +193,11 @@ def _serialize_response( |
181 | 193 | exclude_unset=exclude_unset, |
182 | 194 | exclude_defaults=exclude_defaults, |
183 | 195 | exclude_none=exclude_none, |
| 196 | + custom_serializer=self._validation_serializer, |
184 | 197 | ) |
185 | 198 | else: |
186 | 199 | # Just serialize the response content returned from the handler |
187 | | - return jsonable_encoder(response_content) |
| 200 | + return jsonable_encoder(response_content, custom_serializer=self._validation_serializer) |
188 | 201 |
|
189 | 202 | def _prepare_response_content( |
190 | 203 | self, |
|
0 commit comments