Skip to content

Commit 1230510

Browse files
authored
feat: add middleware (#502)
* feat: add middleware Add middleware implementations for the async and synchronous Momento clients. Add aio and sync middleware to the configuration.
1 parent 9f809c8 commit 1230510

18 files changed

Lines changed: 761 additions & 126 deletions

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,13 @@ module = [
109109
"momento.internal.aio._scs_data_client",
110110
"momento.internal.aio._scs_grpc_manager",
111111
"momento.internal.aio._utilities",
112+
"momento.internal.synchronous._utilities",
112113
"momento.responses.control.signing_key.*",
114+
"momento.internal.aio._middleware_interceptor",
115+
"momento.internal.synchronous._middleware_interceptor",
116+
"momento.config.middleware.models",
117+
"momento.config.middleware.aio.middleware_metadata",
118+
"momento.config.middleware.synchronous.middleware_metadata",
113119
]
114120
disallow_any_expr = false
115121

src/momento/config/configuration.py

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
from abc import ABC, abstractmethod
44
from datetime import timedelta
55
from pathlib import Path
6+
from typing import List, Optional
67

8+
import momento.config.middleware.aio
79
from momento.retry import RetryStrategy
810

11+
from .middleware import Middleware
912
from .transport.transport_strategy import TransportStrategy
1013

1114

1215
class ConfigurationBase(ABC):
13-
# TODO: Middlewares
1416
@abstractmethod
1517
def get_retry_strategy(self) -> RetryStrategy:
1618
pass
@@ -35,20 +37,39 @@ def with_client_timeout(self, client_timeout: timedelta) -> Configuration:
3537
def with_root_certificates_pem(self, root_certificate_path: Path) -> Configuration:
3638
pass
3739

40+
@abstractmethod
41+
def with_middlewares(self, middlewares: List[Middleware]) -> Configuration:
42+
pass
43+
44+
@abstractmethod
45+
def add_middleware(self, middleware: Middleware) -> Configuration:
46+
pass
47+
48+
@abstractmethod
49+
def get_middlewares(self) -> List[Middleware]:
50+
pass
51+
3852

3953
class Configuration(ConfigurationBase):
4054
"""Configuration options for Momento Simple Cache Client."""
4155

42-
def __init__(self, transport_strategy: TransportStrategy, retry_strategy: RetryStrategy):
56+
def __init__(
57+
self,
58+
transport_strategy: TransportStrategy,
59+
retry_strategy: RetryStrategy,
60+
middlewares: Optional[List[Middleware]] = None,
61+
):
4362
"""Instantiate a Configuration.
4463
4564
Args:
4665
transport_strategy (TransportStrategy): Configuration options for networking with
4766
the Momento service.
4867
retry_strategy (RetryStrategy): the strategy to use when determining whether to retry a grpc call.
68+
middlewares: Middleware that can intercept Momento calls. May be aio or synchronous.
4969
"""
5070
self._transport_strategy = transport_strategy
5171
self._retry_strategy = retry_strategy
72+
self._middlewares: List[Middleware] = list(middlewares or [])
5273

5374
def get_retry_strategy(self) -> RetryStrategy:
5475
"""Access the retry strategy.
@@ -67,7 +88,7 @@ def with_retry_strategy(self, retry_strategy: RetryStrategy) -> Configuration:
6788
Returns:
6889
Configuration: the new Configuration with the specified RetryStrategy.
6990
"""
70-
return Configuration(self._transport_strategy, retry_strategy)
91+
return Configuration(self._transport_strategy, retry_strategy, self._middlewares)
7192

7293
def get_transport_strategy(self) -> TransportStrategy:
7394
"""Access the transport strategy.
@@ -86,7 +107,7 @@ def with_transport_strategy(self, transport_strategy: TransportStrategy) -> Conf
86107
Returns:
87108
Configuration: the new Configuration with the specified TransportStrategy.
88109
"""
89-
return Configuration(transport_strategy, self._retry_strategy)
110+
return Configuration(transport_strategy, self._retry_strategy, self._middlewares)
90111

91112
def with_client_timeout(self, client_timeout: timedelta) -> Configuration:
92113
"""Copies the Configuration and sets the new client-side timeout in the copy's TransportStrategy.
@@ -97,7 +118,11 @@ def with_client_timeout(self, client_timeout: timedelta) -> Configuration:
97118
Return:
98119
Configuration: the new Configuration.
99120
"""
100-
return Configuration(self._transport_strategy.with_client_timeout(client_timeout), self._retry_strategy)
121+
return Configuration(
122+
self._transport_strategy.with_client_timeout(client_timeout),
123+
self._retry_strategy,
124+
self._middlewares,
125+
)
101126

102127
def with_root_certificates_pem(self, root_certificates_pem_path: Path) -> Configuration:
103128
"""Copies the Configuration and sets the new root certificates in the copy's TransportStrategy.
@@ -106,10 +131,60 @@ def with_root_certificates_pem(self, root_certificates_pem_path: Path) -> Config
106131
root_certificates_pem_path (Path): the new root certificates.
107132
108133
Returns:
109-
ConfigurationBase: the new Configuration.
134+
Configuration: the new Configuration.
110135
"""
111136
grpc_configuration = self._transport_strategy.get_grpc_configuration().with_root_certificates_pem(
112137
root_certificates_pem_path
113138
)
114139
transport_strategy = self._transport_strategy.with_grpc_configuration(grpc_configuration)
115140
return self.with_transport_strategy(transport_strategy)
141+
142+
def with_middlewares(self, middlewares: List[Middleware]) -> Configuration:
143+
"""Copies the Configuration and adds the new middlewares to the end of the list.
144+
145+
Args:
146+
middlewares: the middleware list to be appended to the Configuration's existing middleware. These can be
147+
aio or synchronous middleware.
148+
149+
Returns:
150+
Configuration: the new Configuration.
151+
"""
152+
new_middlewares = self._middlewares.copy() + middlewares
153+
return Configuration(self._transport_strategy, self._retry_strategy, new_middlewares)
154+
155+
def add_middleware(self, middleware: Middleware) -> Configuration:
156+
"""Copies the Configuration and adds the new middleware to the end of the list.
157+
158+
Args:
159+
middleware: the middleware to be appended to the Configuration's existing middleware. This can be aio or
160+
synchronous middleware.
161+
162+
Returns:
163+
Configuration: the new Configuration.
164+
"""
165+
new_middlewares = self._middlewares.copy() + [middleware]
166+
return Configuration(self._transport_strategy, self._retry_strategy, new_middlewares)
167+
168+
def get_middlewares(self) -> List[Middleware]:
169+
"""Access the middleware list.
170+
171+
Returns:
172+
the configuration's list of middleware.
173+
"""
174+
return self._middlewares.copy()
175+
176+
def get_aio_middlewares(self) -> List[momento.config.middleware.aio.Middleware]:
177+
"""Access the aio middleware from the middleware list.
178+
179+
Returns:
180+
the configuration's list of aio middleware.
181+
"""
182+
return [m for m in self._middlewares if isinstance(m, momento.config.middleware.aio.Middleware)]
183+
184+
def get_sync_middlewares(self) -> List[momento.config.middleware.synchronous.Middleware]:
185+
"""Access the synchronous middleware from the middleware list.
186+
187+
Returns:
188+
the configuration's list of synchronous middleware.
189+
"""
190+
return [m for m in self._middlewares if isinstance(m, momento.config.middleware.synchronous.Middleware)]
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Union
2+
3+
from momento.config.middleware.aio import Middleware as AsyncMiddleware
4+
from momento.config.middleware.models import (
5+
MiddlewareMessage,
6+
MiddlewareRequestHandlerContext,
7+
MiddlewareStatus,
8+
)
9+
from momento.config.middleware.synchronous import Middleware as SyncMiddleware
10+
11+
Middleware = Union[SyncMiddleware, AsyncMiddleware]
12+
13+
__all__ = [
14+
"Middleware",
15+
"MiddlewareMessage",
16+
"MiddlewareStatus",
17+
"MiddlewareRequestHandlerContext",
18+
]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from momento.config.middleware.aio.middleware import Middleware, MiddlewareRequestHandler
2+
from momento.config.middleware.aio.middleware_metadata import MiddlewareMetadata
3+
4+
__all__ = ["Middleware", "MiddlewareMetadata", "MiddlewareRequestHandler"]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import abc
2+
3+
from momento.config.middleware.aio.middleware_metadata import MiddlewareMetadata
4+
from momento.config.middleware.models import MiddlewareMessage, MiddlewareRequestHandlerContext, MiddlewareStatus
5+
6+
7+
class MiddlewareRequestHandler(abc.ABC):
8+
@abc.abstractmethod
9+
async def on_request_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata:
10+
pass
11+
12+
@abc.abstractmethod
13+
async def on_request_body(self, request: MiddlewareMessage) -> MiddlewareMessage:
14+
pass
15+
16+
@abc.abstractmethod
17+
async def on_response_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata:
18+
pass
19+
20+
@abc.abstractmethod
21+
async def on_response_body(self, response: MiddlewareMessage) -> MiddlewareMessage:
22+
pass
23+
24+
@abc.abstractmethod
25+
async def on_response_status(self, status: MiddlewareStatus) -> MiddlewareStatus:
26+
pass
27+
28+
29+
class Middleware(abc.ABC):
30+
@abc.abstractmethod
31+
async def on_new_request(self, context: MiddlewareRequestHandlerContext) -> MiddlewareRequestHandler:
32+
pass
33+
34+
# noinspection PyMethodMayBeStatic
35+
def close(self) -> None:
36+
return None
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Optional
2+
3+
from grpc.aio import Metadata
4+
5+
6+
class MiddlewareMetadata:
7+
"""Wrapper for gRPC metadata."""
8+
9+
def __init__(self, metadata: Optional[Metadata]):
10+
self.grpc_metadata = metadata
11+
12+
def get_grpc_metadata(self) -> Optional[Metadata]:
13+
"""Get the underlying gRPC metadata."""
14+
return self.grpc_metadata
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Dict
2+
3+
import grpc
4+
from google.protobuf.message import Message
5+
6+
CONNECTION_ID_KEY = "connectionID"
7+
8+
9+
class MiddlewareMessage:
10+
"""Wrapper for a gRPC protobuf message."""
11+
12+
def __init__(self, message: Message):
13+
self.grpc_message = message
14+
15+
def get_message_length(self) -> int:
16+
"""Get the length of the message in bytes."""
17+
return len(self.grpc_message.SerializeToString())
18+
19+
def get_constructor_name(self) -> str:
20+
"""Get the class name of the message."""
21+
return str(self.grpc_message.__class__.__name__)
22+
23+
def get_message(self) -> Message:
24+
"""Get the underlying gRPC message."""
25+
return self.grpc_message
26+
27+
28+
class MiddlewareStatus:
29+
"""Wrapper for gRPC status."""
30+
31+
def __init__(self, status: grpc.StatusCode):
32+
self.grpc_status = status
33+
34+
def get_code(self) -> grpc.StatusCode:
35+
"""Get the status code."""
36+
return self.grpc_status
37+
38+
39+
class MiddlewareRequestHandlerContext:
40+
"""Context for middleware request handlers."""
41+
42+
def __init__(self, context: Dict[str, str]):
43+
self.context = context
44+
45+
def get_context(self) -> Dict[str, str]:
46+
return self.context
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from momento.config.middleware.synchronous.middleware import Middleware, MiddlewareRequestHandler
2+
from momento.config.middleware.synchronous.middleware_metadata import MiddlewareMetadata
3+
4+
__all__ = ["Middleware", "MiddlewareMetadata", "MiddlewareRequestHandler"]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import abc
2+
3+
from momento.config.middleware.models import MiddlewareMessage, MiddlewareRequestHandlerContext, MiddlewareStatus
4+
from momento.config.middleware.synchronous.middleware_metadata import MiddlewareMetadata
5+
6+
7+
class MiddlewareRequestHandler(abc.ABC):
8+
@abc.abstractmethod
9+
def on_request_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata:
10+
pass
11+
12+
@abc.abstractmethod
13+
def on_request_body(self, request: MiddlewareMessage) -> MiddlewareMessage:
14+
pass
15+
16+
@abc.abstractmethod
17+
def on_response_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata:
18+
pass
19+
20+
@abc.abstractmethod
21+
def on_response_body(self, response: MiddlewareMessage) -> MiddlewareMessage:
22+
pass
23+
24+
@abc.abstractmethod
25+
def on_response_status(self, status: MiddlewareStatus) -> MiddlewareStatus:
26+
pass
27+
28+
29+
class Middleware(abc.ABC):
30+
@abc.abstractmethod
31+
def on_new_request(self, context: MiddlewareRequestHandlerContext) -> MiddlewareRequestHandler:
32+
pass
33+
34+
# noinspection PyMethodMayBeStatic
35+
def close(self) -> None:
36+
return None
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Optional
2+
3+
from grpc._typing import MetadataType
4+
5+
6+
class MiddlewareMetadata:
7+
"""Wrapper for gRPC metadata."""
8+
9+
def __init__(self, metadata: Optional[MetadataType]):
10+
self.grpc_metadata = metadata
11+
12+
def get_grpc_metadata(self) -> Optional[MetadataType]:
13+
"""Get the underlying gRPC metadata."""
14+
return self.grpc_metadata

0 commit comments

Comments
 (0)