Skip to content

Commit ee0f09d

Browse files
committed
chore: add tests for the fixed count retry strategy
Add sync and async tests for the fixed count retry strategy. Add a make target for local tests. Add a github action to run the local tests. Fix a linter error.
1 parent 6b810cc commit ee0f09d

15 files changed

Lines changed: 652 additions & 5 deletions
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
name: Momento Local tests
2+
3+
on:
4+
pull_request:
5+
branches: [main]
6+
7+
jobs:
8+
local-tests:
9+
strategy:
10+
matrix:
11+
os: [ubuntu-24.04]
12+
python-version: ["3.13"]
13+
runs-on: ${{ matrix.os }}
14+
15+
env:
16+
TEST_API_KEY: ${{ secrets.ALPHA_TEST_AUTH_TOKEN }}
17+
TEST_CACHE_NAME: python-integration-test-${{ matrix.python-version }}-${{ matrix.new-python-protobuf }}-${{ github.sha }}
18+
19+
steps:
20+
- uses: actions/checkout@v4
21+
22+
- name: Setup Python ${{ matrix.python-version }}
23+
uses: actions/setup-python@v4
24+
with:
25+
python-version: ${{ matrix.python-version }}
26+
27+
- name: Install and configure Poetry
28+
uses: snok/install-poetry@v1
29+
with:
30+
version: 1.3.1
31+
virtualenvs-in-project: true
32+
33+
- name: Install dependencies
34+
run: poetry install
35+
36+
- name: Start Momento Local
37+
run: |
38+
docker run --cap-add=NET_ADMIN --rm -d -p 8080:8080 -p 9090:9090 gomomento/momento-local --enable-test-admin
39+
40+
- name: Run tests
41+
run: poetry run pytest -p no:sugar -q -m local

.github/workflows/on-pull-request.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ jobs:
5757
run: poetry run ruff format --check --diff src tests
5858

5959
- name: Run tests
60-
run: poetry run pytest -p no:sugar -q
60+
run: poetry run pytest -p no:sugar -q -m "not local"
6161

6262
test-examples:
6363
runs-on: ubuntu-24.04

Makefile

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ gen-sync: do-gen-sync format lint
5050
.PHONY: test
5151
## Run unit and integration tests with pytest
5252
test:
53-
@poetry run pytest
53+
@poetry run pytest -m "not local"
54+
55+
.PHONY: test-local
56+
## Run the integration tests that require Momento Local
57+
test-local:
58+
@poetry run pytest -m local
5459

5560
.PHONY: precommit
5661
## Run format, lint, and test as a step before committing.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ log_level = "ERROR"
5959
log_cli = true
6060
log_cli_format = "%(asctime)s [%(levelname)s] %(message)s"
6161
log_cli_date_format = "%Y-%m-%d %H:%M:%S.%f"
62+
markers = [
63+
"local: tests that require Momento Local",
64+
]
6265

6366
[tool.mypy]
6467
python_version = "3.7"

src/momento/retry/fixed_timeout_retry_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def determine_when_to_retry(self, props: RetryableProps) -> Optional[float]:
4949
# If a retry attempt's timeout has passed but the client's overall timeout has not yet passed,
5050
# we should reset the deadline and retry.
5151
if (
52-
props.attempt_number > 0
52+
props.attempt_number > 0 # type: ignore[misc]
5353
and props.grpc_status == grpc.StatusCode.DEADLINE_EXCEEDED # type: ignore[misc]
5454
and props.overall_deadline > datetime.now()
5555
):

tests/conftest.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import asyncio
44
import os
55
import random
6+
from contextlib import asynccontextmanager, contextmanager
67
from datetime import timedelta
7-
from typing import AsyncIterator, Callable, Iterator, List, Optional, Union, cast
8+
from typing import AsyncGenerator, AsyncIterator, Callable, Iterator, List, Optional, Union, cast
89

910
import pytest
1011
import pytest_asyncio
@@ -41,6 +42,8 @@
4142
TTopicName,
4243
)
4344

45+
from tests.momento.local.momento_local_async_middleware import MomentoLocalAsyncMiddleware, MomentoLocalMiddlewareArgs
46+
from tests.momento.local.momento_local_middleware import MomentoLocalMiddleware
4447
from tests.utils import (
4548
unique_test_cache_name,
4649
uuid_bytes,
@@ -51,13 +54,17 @@
5154
# Integration test data
5255
#######################
5356

54-
TEST_CONFIGURATION = Configurations.Laptop.latest()
57+
TEST_CONFIGURATION: Configuration = Configurations.Laptop.latest()
5558
TEST_TOPIC_CONFIGURATION = TopicConfigurations.Default.latest().with_client_timeout(timedelta(seconds=10))
5659
TEST_AUTH_CONFIGURATION = AuthConfigurations.Laptop.latest()
5760

5861

5962
TEST_AUTH_PROVIDER = CredentialProvider.from_environment_variable("TEST_API_KEY")
6063

64+
MOMENTO_LOCAL_HOSTNAME = os.environ.get("MOMENTO_HOSTNAME", "127.0.0.1")
65+
MOMENTO_LOCAL_PORT = int(os.environ.get("MOMENTO_PORT", "8080"))
66+
TEST_LOCAL_AUTH_PROVIDER = CredentialProvider.for_momento_local(MOMENTO_LOCAL_HOSTNAME, MOMENTO_LOCAL_PORT)
67+
6168

6269
TEST_CACHE_NAME: Optional[str] = os.getenv("TEST_CACHE_NAME")
6370
if not TEST_CACHE_NAME:
@@ -354,6 +361,48 @@ async def auth_client_async() -> AsyncIterator[AuthClientAsync]:
354361
yield _auth_client
355362

356363

364+
@asynccontextmanager
365+
async def client_async_local(
366+
cache_name: str,
367+
middleware_args: Optional[MomentoLocalMiddlewareArgs] = None,
368+
config_fn: Optional[Callable[[Configuration], Configuration]] = None,
369+
) -> AsyncGenerator[CacheClientAsync, None]:
370+
config = TEST_CONFIGURATION
371+
372+
if config_fn:
373+
config = config_fn(config)
374+
375+
if middleware_args:
376+
config = config.add_middleware(MomentoLocalAsyncMiddleware(middleware_args))
377+
378+
client = await CacheClientAsync.create(config, TEST_LOCAL_AUTH_PROVIDER, DEFAULT_TTL_SECONDS)
379+
380+
await client.create_cache(cache_name)
381+
382+
yield client
383+
384+
385+
@contextmanager
386+
def client_local(
387+
cache_name: str,
388+
middleware_args: Optional[MomentoLocalMiddlewareArgs] = None,
389+
config_fn: Optional[Callable[[Configuration], Configuration]] = None,
390+
) -> Iterator[CacheClient]:
391+
config = TEST_CONFIGURATION
392+
393+
if config_fn:
394+
config = config_fn(config)
395+
396+
if middleware_args:
397+
config = config.add_middleware(MomentoLocalMiddleware(middleware_args))
398+
399+
client = CacheClient.create(config, TEST_LOCAL_AUTH_PROVIDER, DEFAULT_TTL_SECONDS)
400+
401+
client.create_cache(cache_name)
402+
403+
yield client
404+
405+
357406
TUniqueCacheName = Callable[[CacheClient], str]
358407

359408

tests/momento/local/__init__.py

Whitespace-only changes.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from momento.errors import MomentoErrorCode
2+
3+
MOMENTO_ERROR_CODE_TO_METADATA = {
4+
MomentoErrorCode.INVALID_ARGUMENT_ERROR: "invalid-argument",
5+
MomentoErrorCode.UNKNOWN_SERVICE_ERROR: "unknown",
6+
MomentoErrorCode.ALREADY_EXISTS_ERROR: "already-exists",
7+
MomentoErrorCode.NOT_FOUND_ERROR: "not-found",
8+
MomentoErrorCode.INTERNAL_SERVER_ERROR: "internal",
9+
MomentoErrorCode.PERMISSION_ERROR: "permission-denied",
10+
MomentoErrorCode.AUTHENTICATION_ERROR: "unauthenticated",
11+
MomentoErrorCode.CANCELLED_ERROR: "cancelled",
12+
MomentoErrorCode.LIMIT_EXCEEDED_ERROR: "resource-exhausted",
13+
MomentoErrorCode.BAD_REQUEST_ERROR: "invalid-argument",
14+
MomentoErrorCode.TIMEOUT_ERROR: "deadline-exceeded",
15+
MomentoErrorCode.SERVER_UNAVAILABLE: "unavailable",
16+
MomentoErrorCode.CLIENT_RESOURCE_EXHAUSTED: "resource-exhausted",
17+
MomentoErrorCode.FAILED_PRECONDITION_ERROR: "failed-precondition",
18+
MomentoErrorCode.UNKNOWN_ERROR: "unknown",
19+
MomentoErrorCode.CONNECTION_ERROR: "unavailable",
20+
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import asyncio
2+
from typing import List
3+
4+
from grpc.aio import Metadata
5+
from momento import logs
6+
from momento.config.middleware import MiddlewareMessage, MiddlewareRequestHandlerContext, MiddlewareStatus
7+
from momento.config.middleware.aio import Middleware, MiddlewareMetadata, MiddlewareRequestHandler
8+
9+
from tests.momento.local.momento_error_code_metadata import MOMENTO_ERROR_CODE_TO_METADATA
10+
from tests.momento.local.momento_local_middleware_args import MomentoLocalMiddlewareArgs
11+
from tests.momento.local.momento_rpc_method import MomentoRpcMethod
12+
13+
14+
class MomentoLocalAsyncMiddlewareRequestHandler(MiddlewareRequestHandler):
15+
def __init__(self, args: MomentoLocalMiddlewareArgs):
16+
self._args = args
17+
self._cache_name = None
18+
self._logger = logs.logger
19+
20+
async def on_request_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata:
21+
grpc_metadata = metadata.grpc_metadata
22+
23+
if grpc_metadata is not None:
24+
self._set_grpc_metadata(grpc_metadata, "request-id", self._args.request_id)
25+
26+
if self._args.return_error is not None:
27+
error = MOMENTO_ERROR_CODE_TO_METADATA[self._args.return_error]
28+
if error is not None:
29+
self._set_grpc_metadata(grpc_metadata, "return-error", error)
30+
31+
if self._args.error_rpc_list is not None:
32+
rpcs = self._concatenate_rpcs(self._args.error_rpc_list)
33+
self._set_grpc_metadata(grpc_metadata, "error-rpcs", rpcs)
34+
35+
if self._args.delay_rpc_list is not None:
36+
rpcs = self._concatenate_rpcs(self._args.delay_rpc_list)
37+
self._set_grpc_metadata(grpc_metadata, "delay-rpcs", rpcs)
38+
39+
if self._args.error_count is not None:
40+
self._set_grpc_metadata(grpc_metadata, "error-count", str(self._args.error_count))
41+
42+
if self._args.delay_millis is not None:
43+
self._set_grpc_metadata(grpc_metadata, "delay-ms", str(self._args.delay_millis))
44+
45+
if self._args.delay_count is not None:
46+
self._set_grpc_metadata(grpc_metadata, "delay-count", str(self._args.delay_count))
47+
48+
if self._args.stream_error_rpc_list is not None:
49+
rpcs = self._concatenate_rpcs(self._args.stream_error_rpc_list)
50+
self._set_grpc_metadata(grpc_metadata, "stream-error-rpcs", rpcs)
51+
52+
if self._args.stream_error is not None:
53+
error = MOMENTO_ERROR_CODE_TO_METADATA[self._args.stream_error]
54+
if error is not None:
55+
self._set_grpc_metadata(grpc_metadata, "stream-error", error)
56+
57+
if self._args.stream_error_message_limit is not None:
58+
limit_str = str(self._args.stream_error_message_limit)
59+
self._set_grpc_metadata(grpc_metadata, "stream-error-message-limit", limit_str)
60+
61+
cache_name = grpc_metadata.get("cache")
62+
if cache_name is not None:
63+
self._cache_name = cache_name
64+
else:
65+
self._logger.debug("No cache name found in metadata.")
66+
67+
return metadata
68+
69+
async def on_request_body(self, request: MiddlewareMessage) -> MiddlewareMessage:
70+
request_type = request.constructor_name
71+
72+
if self._cache_name is not None:
73+
if self._args.test_metrics_collector is not None: # type: ignore[unreachable]
74+
rpc_method = MomentoRpcMethod.from_request_name(request_type)
75+
if rpc_method:
76+
self._args.test_metrics_collector.add_timestamp(
77+
self._cache_name,
78+
rpc_method,
79+
int(asyncio.get_event_loop().time() * 1000), # Current time in milliseconds
80+
)
81+
else:
82+
self._logger.debug("No cache name available. Timestamp will not be collected.")
83+
84+
return request
85+
86+
async def on_response_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata:
87+
return metadata
88+
89+
async def on_response_body(self, response: MiddlewareMessage) -> MiddlewareMessage:
90+
return response
91+
92+
async def on_response_status(self, status: MiddlewareStatus) -> MiddlewareStatus:
93+
return status
94+
95+
@staticmethod
96+
def _set_grpc_metadata(metadata: Metadata, key: str, value: str) -> None:
97+
if value is not None:
98+
metadata[key] = value
99+
100+
@staticmethod
101+
def _concatenate_rpcs(rpcs: List[MomentoRpcMethod]) -> str:
102+
return " ".join(rpc.metadata for rpc in rpcs)
103+
104+
105+
class MomentoLocalAsyncMiddleware(Middleware):
106+
def __init__(self, args: MomentoLocalMiddlewareArgs):
107+
self._args = args
108+
109+
async def on_new_request(self, context: MiddlewareRequestHandlerContext) -> MiddlewareRequestHandler:
110+
return MomentoLocalAsyncMiddlewareRequestHandler(self._args)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from collections import defaultdict
2+
from typing import Dict, List
3+
4+
from tests.momento.local.momento_rpc_method import MomentoRpcMethod
5+
6+
7+
class MomentoLocalMetricsCollector:
8+
def __init__(self) -> None:
9+
# Data structure to store timestamps: cacheName -> requestName -> [timestamps]
10+
self.data: Dict[str, Dict[MomentoRpcMethod, List[int]]] = defaultdict(lambda: defaultdict(list))
11+
12+
def add_timestamp(self, cache_name: str, request_name: MomentoRpcMethod, timestamp: int) -> None:
13+
"""Add a timestamp for a specific request and cache.
14+
15+
Args:
16+
cache_name: The name of the cache
17+
request_name: The name of the request (using MomentoRpcMethod enum)
18+
timestamp: The timestamp to record in seconds since epoch
19+
"""
20+
self.data[cache_name][request_name].append(timestamp)
21+
22+
def get_total_retry_count(self, cache_name: str, request_name: MomentoRpcMethod) -> int:
23+
"""Calculate the total retry count for a specific cache and request.
24+
25+
Args:
26+
cache_name: The name of the cache
27+
request_name: The name of the request (using MomentoRpcMethod enum)
28+
29+
Returns:
30+
The total number of retries
31+
"""
32+
timestamps = self.data.get(cache_name, {}).get(request_name, [])
33+
# Number of retries is one less than the number of timestamps
34+
return max(0, len(timestamps) - 1)
35+
36+
def get_average_time_between_retries(self, cache_name: str, request_name: MomentoRpcMethod) -> float:
37+
"""Calculate the average time between retries for a specific cache and request.
38+
39+
Args:
40+
cache_name: The name of the cache
41+
request_name: The name of the request (using MomentoRpcMethod enum)
42+
43+
Returns:
44+
The average time in seconds, or 0.0 if there are no retries
45+
"""
46+
timestamps = self.data.get(cache_name, {}).get(request_name, [])
47+
if len(timestamps) < 2:
48+
return 0.0 # No retries occurred
49+
50+
total_interval = sum(timestamps[i] - timestamps[i - 1] for i in range(1, len(timestamps)))
51+
return total_interval / (len(timestamps) - 1)
52+
53+
def get_all_metrics(self) -> Dict[str, Dict[MomentoRpcMethod, List[int]]]:
54+
"""Retrieve all collected metrics for debugging or analysis.
55+
56+
Returns:
57+
The complete data structure with all recorded metrics
58+
"""
59+
return self.data

0 commit comments

Comments
 (0)