66validation, and lifecycle management.
77"""
88
9- import asyncio
109from datetime import datetime , timedelta
1110import logging
1211from pathlib import Path
1716from cryptography .hazmat .primitives import hashes , serialization
1817from cryptography .hazmat .primitives .asymmetric import rsa
1918from cryptography .x509 .oid import NameOID
19+ import trio
2020
2121from libp2p .peer .id import ID
2222
@@ -51,7 +51,18 @@ def __init__(
5151 self .renewal_threshold_hours = renewal_threshold_hours
5252
5353 self ._certificates : dict [tuple [ID , str ], dict [Any , Any ]] = {}
54- self ._renewal_tasks : dict [tuple [ID , str ], asyncio .Task [Any ]] = {}
54+ self ._renewal_scopes : dict [tuple [ID , str ], trio .CancelScope ] = {}
55+ self ._nursery : trio .Nursery | None = None
56+
57+ async def start (self , nursery : trio .Nursery ) -> None :
58+ """Attach manager to a long-lived nursery."""
59+ self ._nursery = nursery
60+
61+ async def shutdown (self ) -> None :
62+ """Cancel all renewal jobs."""
63+ for scope in self ._renewal_scopes .values ():
64+ scope .cancel ()
65+ logger .info ("Certificate manager shutdown complete" )
5566
5667 async def get_certificate (
5768 self ,
@@ -183,13 +194,15 @@ async def _schedule_renewal(
183194 peer_id : ID ,
184195 domain : str ,
185196 cert_data : dict [Any , Any ],
197+ _current_scope : trio .CancelScope | None = None ,
186198 ) -> None :
187199 """Schedule certificate renewal."""
188200 key = (peer_id , domain )
189201
190202 # Cancel existing renewal task
191- if key in self ._renewal_tasks :
192- self ._renewal_tasks [key ].cancel ()
203+ existing_scope = self ._renewal_scopes .get (key )
204+ if existing_scope is not None and existing_scope is not _current_scope :
205+ existing_scope .cancel ()
193206
194207 # Calculate renewal time
195208 expires_at = datetime .fromisoformat (cert_data ["expires_at" ])
@@ -204,24 +217,41 @@ async def _schedule_renewal(
204217 f"Scheduling certificate renewal for { peer_id } in { delay :.0f} seconds"
205218 )
206219
207- async def renew_certificate () -> None :
208- try :
209- await asyncio .sleep (delay )
210-
211- logger .info (f"Renewing certificate for { peer_id } on { domain } " )
212- new_cert_data = await self .get_certificate (
213- peer_id , domain , force_renew = True
214- )
215-
216- # Update cached certificate
217- self ._certificates [key ] = new_cert_data # type: ignore
220+ scope = trio .CancelScope ()
221+ self ._renewal_scopes [key ] = scope
218222
219- except asyncio .CancelledError :
220- logger .debug (f"Certificate renewal cancelled for { peer_id } " )
221- except Exception as e :
222- logger .error (f"Certificate renewal failed for { peer_id } : { e } " )
223-
224- self ._renewal_tasks [key ] = asyncio .create_task (renew_certificate ())
223+ async def renew_certificate () -> None :
224+ with scope :
225+ try :
226+ await trio .sleep (delay )
227+
228+ logger .info (f"Renewing certificate for { peer_id } on { domain } " )
229+ new_cert_data = await self ._generate_certificate (peer_id , domain )
230+ await self ._store_certificate_to_storage (
231+ peer_id , domain , new_cert_data
232+ )
233+ self ._certificates [key ] = new_cert_data
234+
235+ await self ._schedule_renewal (
236+ peer_id ,
237+ domain ,
238+ new_cert_data ,
239+ _current_scope = scope ,
240+ )
241+
242+ except trio .Cancelled :
243+ logger .debug (f"Certificate renewal cancelled for { peer_id } " )
244+ raise
245+ except Exception as e :
246+ logger .error (f"Certificate renewal failed for { peer_id } : { e } " )
247+ finally :
248+ if self ._renewal_scopes .get (key ) is scope :
249+ self ._renewal_scopes .pop (key , None )
250+
251+ if self ._nursery is not None :
252+ self ._nursery .start_soon (renew_certificate )
253+ else :
254+ trio .lowlevel .spawn_system_task (renew_certificate )
225255
226256 def _get_cert_path (self , peer_id : ID , domain : str ) -> Path :
227257 """Get certificate file path."""
@@ -323,9 +353,9 @@ async def cleanup_expired_certificates(self) -> None:
323353 del self ._certificates [key ]
324354
325355 # Cancel renewal task
326- if key in self ._renewal_tasks :
327- self . _renewal_tasks [ key ]. cancel ()
328- del self . _renewal_tasks [ key ]
356+ scope = self ._renewal_scopes . pop ( key , None )
357+ if scope is not None :
358+ scope . cancel ()
329359
330360 if expired_keys :
331361 logger .info (f"Cleaned up { len (expired_keys )} expired certificates" )
@@ -375,26 +405,13 @@ async def revoke_certificate(
375405 del self ._certificates [key ]
376406
377407 # Cancel renewal task
378- if key in self ._renewal_tasks :
379- self . _renewal_tasks [ key ]. cancel ()
380- del self . _renewal_tasks [ key ]
408+ scope = self ._renewal_scopes . pop ( key , None )
409+ if scope is not None :
410+ scope . cancel ()
381411
382412 # Remove from storage
383413 cert_path = self ._get_cert_path (peer_id , domain )
384414 if cert_path .exists ():
385415 cert_path .unlink ()
386416
387417 logger .info (f"Revoked certificate for { peer_id } on { domain } " )
388-
389- async def shutdown (self ) -> None :
390- """Shutdown certificate manager."""
391- # Cancel all renewal tasks
392- for task in self ._renewal_tasks .values ():
393- if not task .done ():
394- task .cancel ()
395-
396- # Wait for tasks to complete
397- if self ._renewal_tasks :
398- await asyncio .gather (* self ._renewal_tasks .values (), return_exceptions = True )
399-
400- logger .info ("Certificate manager shutdown complete" )
0 commit comments