22
33import math
44from datetime import timedelta
5+ from typing import Callable
56
67from momento_wire_types import cachepubsub_pb2 as pubsub_pb
78from momento_wire_types import cachepubsub_pb2_grpc as pubsub_grpc
1011from momento .auth import CredentialProvider
1112from momento .config import TopicConfiguration
1213from momento .errors import convert_error
14+ from momento .errors .exceptions import ClientResourceExhaustedException
1315from momento .internal ._utilities import _validate_cache_name , _validate_topic_name
1416from momento .internal .aio ._scs_grpc_manager import (
1517 _PubsubGrpcManager ,
2729class _ScsPubsubClient :
2830 """Internal pubsub client."""
2931
30- stream_topic_manager_count = 0
31-
3232 def __init__ (self , configuration : TopicConfiguration , credential_provider : CredentialProvider ):
3333 endpoint = credential_provider .cache_endpoint
3434 self ._logger = logs .logger
@@ -38,20 +38,23 @@ def __init__(self, configuration: TopicConfiguration, credential_provider: Crede
3838 default_deadline : timedelta = configuration .get_transport_strategy ().get_grpc_configuration ().get_deadline ()
3939 self ._default_deadline_seconds = default_deadline .total_seconds ()
4040
41- num_subscriptions = configuration .get_max_subscriptions ()
4241 # Default to a single channel and scale up if necessary. Each channel can support
4342 # 100 subscriptions. Issuing more subscribe requests than you have channels to handle
44- # will cause the last request to hang indefinitely, so it's important to get this right .
43+ # will cause a ClientResourceExhaustedException .
4544 num_channels = 1
45+ num_subscriptions = configuration .get_max_subscriptions ()
4646 if num_subscriptions > 0 :
4747 num_channels = math .ceil (num_subscriptions / 100.0 )
4848 self ._logger .debug (f"creating { num_channels } subscription channels" )
49-
50- self ._grpc_manager = _PubsubGrpcManager (configuration , credential_provider )
5149 self ._stream_managers = [
5250 _PubsubGrpcStreamManager (configuration , credential_provider ) for i in range (0 , num_channels )
5351 ]
5452
53+ # Default to 4 unary pubsub channels. TODO: Make this configurable.
54+ self ._unary_managers = [_PubsubGrpcManager (configuration , credential_provider ) for i in range (0 , 4 )]
55+ self ._stream_manager_count = 0
56+ self ._unary_manager_count = 0
57+
5558 @property
5659 def endpoint (self ) -> str :
5760 return self ._endpoint
@@ -72,7 +75,7 @@ async def publish(self, cache_name: str, topic_name: str, value: str | bytes) ->
7275 value = topic_value ,
7376 )
7477
75- await self ._get_stub ().Publish ( # type: ignore[misc]
78+ await self ._get_unary_stub ().Publish ( # type: ignore[misc]
7679 request ,
7780 timeout = self ._default_deadline_seconds ,
7881 )
@@ -98,7 +101,8 @@ async def subscribe(
98101 resume_at_topic_sequence_number = resume_at_topic_sequence_number ,
99102 sequence_page = resume_at_topic_sequence_page ,
100103 )
101- stream = self ._get_stream_stub ().Subscribe ( # type: ignore[misc]
104+ stub , decrement_stream_count = self ._get_stream_stub ()
105+ stream = stub .Subscribe ( # type: ignore[misc]
102106 request ,
103107 )
104108
@@ -112,23 +116,48 @@ async def subscribe(
112116 err = Exception (f"expected a heartbeat message but got '{ msg_type } '" )
113117 self ._log_request_error ("subscribe" , err )
114118 return TopicSubscribe .Error (convert_error (err , Service .TOPICS ))
115- return TopicSubscribe .SubscriptionAsync (cache_name , topic_name , client_stream = stream ) # type: ignore[misc]
119+ return TopicSubscribe .SubscriptionAsync (
120+ cache_name ,
121+ topic_name ,
122+ client_stream = stream , # type: ignore[misc]
123+ decrement_stream_count_method = decrement_stream_count ,
124+ )
116125 except Exception as e :
117126 self ._log_request_error ("subscribe" , e )
118127 return TopicSubscribe .Error (convert_error (e , Service .TOPICS ))
119128
120129 def _log_request_error (self , request_type : str , e : Exception ) -> None :
121130 self ._logger .warning (f"{ request_type } failed with exception: { e } " )
122131
123- def _get_stub (self ) -> pubsub_grpc .PubsubStub :
124- return self ._grpc_manager .async_stub ()
125-
126- def _get_stream_stub (self ) -> pubsub_grpc .PubsubStub :
127- stub = self ._stream_managers [self .stream_topic_manager_count % len (self ._stream_managers )].async_stub ()
128- self .stream_topic_manager_count += 1
129- return stub
132+ def _get_unary_stub (self ) -> pubsub_grpc .PubsubStub :
133+ # Simply round-robin through the unary managers.
134+ # Unary requests will eventually complete (unlike long-lived subscriptions),
135+ # so we do not need the same bookkeeping logic here.
136+ manager = self ._unary_managers [self ._unary_manager_count % len (self ._unary_managers )]
137+ self ._unary_manager_count += 1
138+ return manager .async_stub ()
139+
140+ def _get_stream_stub (self ) -> tuple [pubsub_grpc .PubsubStub , Callable [[], None ]]:
141+ # Try to get a client with capacity for another subscription by round-robining through the stubs.
142+ # Allow up to max_stream_capacity attempts to account for large bursts of requests.
143+ max_stream_capacity = len (self ._stream_managers ) * 100
144+ for _ in range (0 , max_stream_capacity ):
145+ try :
146+ manager = self ._stream_managers [self ._stream_manager_count % len (self ._stream_managers )]
147+ self ._stream_manager_count += 1
148+ return manager .async_stub (), manager .decrement_stream_count
149+ except ClientResourceExhaustedException :
150+ # If the stub is at capacity, continue to the next one.
151+ continue
152+
153+ # Otherwise return an error if no stubs have capacity.
154+ raise ClientResourceExhaustedException (
155+ message = "Maximum number of active subscriptions reached" ,
156+ service = Service .TOPICS ,
157+ )
130158
131159 async def close (self ) -> None :
132- await self ._grpc_manager .close ()
160+ for unary_client in self ._unary_managers :
161+ await unary_client .close ()
133162 for stream_client in self ._stream_managers :
134163 await stream_client .close ()
0 commit comments