1818from collections import defaultdict
1919from functools import partial
2020from ipaddress import IPv4Address , IPv6Address , IPv4Network , IPv6Network
21- from typing import Optional , TYPE_CHECKING , Sequence
21+ from typing import Iterable , Optional , TYPE_CHECKING , Sequence
2222
2323import attr
2424from aiorpcx import (Event , JSONRPCAutoDetect , JSONRPCConnection ,
25- ReplyAndDisconnect , Request , RPCError , RPCSession ,
25+ ReplyAndDisconnect , Request , RPCError , RPCSession , Service ,
2626 handler_invocation , serve_rs , serve_ws , sleep ,
2727 NewlineFramer , TaskTimeout , timeout_after , run_in_thread )
2828
@@ -221,15 +221,14 @@ async def _start_external_servers(self):
221221
222222 async def _stop_servers (self , services ):
223223 '''Stop the servers of the given protocols.'''
224- server_map = {service : self .servers .pop (service )
225- for service in set (services ).intersection (self .servers )}
226- # Close all before waiting
227- for service , server in server_map .items ():
224+ for service in services :
228225 self .logger .info (f'closing down server for { service } ' )
229- server .close ()
230- # No value in doing these concurrently
231- for server in server_map .values ():
232- await server .wait_closed ()
226+ self .servers [service ].close ()
227+
228+ def _remove_servers (self , services : Iterable [Service ]):
229+ '''Remove the servers of the given protocols.'''
230+ for service in services :
231+ del self .servers [service ]
233232
234233 async def _manage_servers (self ):
235234 paused = False
@@ -242,8 +241,10 @@ async def _manage_servers(self):
242241 self .logger .info (f'maximum sessions { max_sessions :,d} '
243242 f'reached, stopping new connections until '
244243 f'count drops to { low_watermark :,d} ' )
245- await self ._stop_servers (service for service in self .servers
246- if service .protocol != 'rpc' )
244+ services_to_remove = [service for service in self .servers
245+ if service .protocol != 'rpc' ]
246+ await self ._stop_servers (services_to_remove )
247+ self ._remove_servers (services_to_remove )
247248 paused = True
248249 # Start listening for incoming connections if paused and
249250 # session count has fallen
@@ -680,11 +681,19 @@ async def serve(self, notifications, event):
680681 await group .spawn (self ._log_sessions ())
681682 await group .spawn (self ._manage_servers ())
682683 finally :
683- # Close servers then sessions
684+ # Stop listening on servers, so no new sessions can be created
684685 await self ._stop_servers (self .servers .keys ())
686+ # Then close sessions
687+ self .logger .info (f'closing { len (self .sessions ):,d} active sessions' )
685688 async with OldTaskGroup () as group :
686689 for session in list (self .sessions ):
687690 await group .spawn (session .close (force_after = 1 ))
691+ # Finally, wait for servers to be cleaned up and remove servers
692+ self .logger .info (f"waiting for all server's resources to close" )
693+ for server in self .servers .values ():
694+ await server .wait_closed ()
695+ servers_to_remove = list (self .servers .keys ())
696+ self ._remove_servers (servers_to_remove )
688697
689698 def extra_cost (self , session ):
690699 # Note there is no guarantee that session is still in self.sessions. Example traceback:
0 commit comments