11import collections
22import copy
33import logging
4+ import random
45import struct
56import time
67
@@ -59,6 +60,7 @@ def __init__(self, net, node_id=None, broker_version_data=None, **configs):
5960 self .broker_version_data = broker_version_data
6061 self ._api_versions_idx = ApiVersionsRequest .max_version # version of ApiVersionsRequest to try on first connect
6162 self ._throttle_time = 0
63+ self ._reauth = SaslReauthenticator (self )
6264 if self .config ['metrics' ]:
6365 self ._sensors = KafkaConnectionMetrics (
6466 self .config ['metrics' ], self .config ['metric_group_prefix' ], node_id )
@@ -117,7 +119,7 @@ def _timeout_at(self, now=None, timeout_ms=None):
117119 def send_request (self , request , request_timeout_ms = None ):
118120 future = Future ()
119121 timeout_at = self ._timeout_at (timeout_ms = request_timeout_ms )
120- if self .initializing :
122+ if self .initializing or self . _reauth . is_reauthenticating :
121123 self ._request_buffer .append ((request , future , timeout_at ))
122124 return future
123125 elif self .paused :
@@ -212,6 +214,7 @@ def data_received(self, data):
212214 future .success (response )
213215 if 'max_in_flight' in self .paused and len (self .in_flight_requests ) < self .config ['max_in_flight_requests_per_connection' ]:
214216 self .unpause ('max_in_flight' )
217+ self ._reauth .on_response_processed ()
215218
216219 def eof_received (self ):
217220 """ Called when the other end calls write_eof() or equivalent.
@@ -231,6 +234,7 @@ def connection_lost(self, exc):
231234 """
232235 self .connected = self .initializing = False
233236 self .transport = None
237+ self ._reauth .cancel ()
234238 error = exc or Errors .KafkaConnectionError ()
235239 if not self ._init_future .is_done :
236240 self ._init_future .failure (error )
@@ -426,10 +430,11 @@ async def _sasl_authenticate(self, timeout_at=None):
426430 mechanism = get_sasl_mechanism (self .config ['sasl_mechanism' ])(
427431 host = sasl_host , ** self .config )
428432
433+ auth_response = None
429434 while not mechanism .is_done () and timeout_at > time .monotonic ():
430435 token = mechanism .auth_bytes ()
431436 if version == 1 :
432- auth_request = SaslAuthenticateRequest (token , version = 0 )
437+ auth_request = SaslAuthenticateRequest (token )
433438 else :
434439 auth_request = SaslBytesRequest (token )
435440 auth_response = await self ._send_request (auth_request , timeout_at = timeout_at )
@@ -449,6 +454,9 @@ async def _sasl_authenticate(self, timeout_at=None):
449454 raise Errors .SaslAuthenticationFailedError (
450455 'Failed to authenticate via SASL %s' % self .config ['sasl_mechanism' ])
451456
457+ # KIP-368: SessionLifetimeMs is only present on SaslAuthenticateResponse v1+.
458+ if version == 1 :
459+ self ._reauth .session_updated (auth_response .session_lifetime_ms )
452460 log .info ('%s: %s' , self , mechanism .auth_details ())
453461
454462 def _init_complete (self ):
@@ -457,3 +465,117 @@ def _init_complete(self):
457465 self .connected = True
458466 self .send_buffered ()
459467 self ._init_future .success (True )
468+ self ._reauth .schedule ()
469+
470+
471+ class SaslReauthenticator :
472+ """KIP-368 SASL re-authentication state and scheduling for a single
473+ KafkaConnection. Owns the per-connection re-auth lifecycle so the
474+ connection doesn't have to carry the related attributes and coroutines
475+ inline. The connection plugs this in at five points:
476+
477+ - after each successful SASL auth -> session_updated()
478+ - after init completes -> schedule()
479+ - when send_request needs to gate the public API -> is_reauthenticating
480+ - on every response popped from in_flight_requests -> on_response_processed()
481+ - on connection_lost -> cancel()
482+ """
483+
484+ def __init__ (self , conn ):
485+ self ._conn = conn
486+ self .session_lifetime_ms = 0
487+ self .authenticated_at = None
488+ self ._task = None
489+ self ._reauthenticating = False
490+ self ._drain_future = None
491+
492+ @property
493+ def is_reauthenticating (self ):
494+ return self ._reauthenticating
495+
496+ @property
497+ def task (self ):
498+ """The scheduled re-auth task, or None. Exposed for tests/observability."""
499+ return self ._task
500+
501+ def session_updated (self , session_lifetime_ms ):
502+ """Capture broker-advertised session lifetime after each successful
503+ auth round (initial and subsequent re-auths)."""
504+ self .session_lifetime_ms = session_lifetime_ms or 0
505+ self .authenticated_at = time .monotonic ()
506+
507+ def schedule (self ):
508+ """Schedule the next re-auth before the lifetime elapses. Jittered to
509+ 85-95% of the lifetime to avoid synchronised re-auth storms across
510+ many connections (Apache Java semantics). No-op when SASL is disabled
511+ or the broker advertised lifetime=0.
512+ """
513+ if not self ._conn .sasl_enabled or not self .session_lifetime_ms :
514+ return
515+ pct = random .uniform (0.85 , 0.95 )
516+ delay = max (0.1 , (self .session_lifetime_ms * pct ) / 1000 )
517+ log .debug ('%s: Scheduling SASL re-authentication in %.3fs (session_lifetime_ms=%d)' ,
518+ self ._conn , delay , self .session_lifetime_ms )
519+ self ._task = self ._conn .net .call_later (delay , self ._run )
520+
521+ def cancel (self ):
522+ """Cancel any pending re-auth and fail the drain awaiter if present.
523+ Called from KafkaConnection.connection_lost."""
524+ if self ._task is not None :
525+ try :
526+ self ._conn .net .unschedule (self ._task )
527+ except (ValueError , KeyError ):
528+ pass
529+ self ._task = None
530+ if self ._drain_future is not None and not self ._drain_future .is_done :
531+ self ._drain_future .failure (Errors .KafkaConnectionError ())
532+ self ._drain_future = None
533+ self ._reauthenticating = False
534+
535+ def on_response_processed (self ):
536+ """Wake the drain awaiter once in_flight_requests clears during reauth.
537+ Called from KafkaConnection.data_received after each pop."""
538+ if (self ._reauthenticating
539+ and self ._drain_future is not None
540+ and not self ._conn .in_flight_requests
541+ and not self ._drain_future .is_done ):
542+ self ._drain_future .success (None )
543+
544+ async def _run (self ):
545+ self ._task = None
546+ if self ._conn .closed :
547+ return
548+ try :
549+ await self ._do_reauth ()
550+ except BaseException as exc : # pylint: disable=W0718
551+ # Re-auth failure is transient (KIP-368: not cached like initial
552+ # auth failure); close the connection so the manager reconnects on
553+ # next demand.
554+ log .warning ('%s: SASL re-authentication failed: %s' , self ._conn , exc )
555+ err = exc if isinstance (exc , Exception ) else Errors .SaslAuthenticationFailedError (str (exc ))
556+ self ._conn .close (err )
557+
558+ async def _do_reauth (self ):
559+ self ._reauthenticating = True
560+ try :
561+ # Drain in-flight so the SaslHandshake/Authenticate frames are the
562+ # next bytes on the wire (Apache Java does the same; avoids
563+ # reasoning about FIFO interleaving with the broker's reauth
564+ # validation).
565+ while self ._conn .in_flight_requests and not self ._conn .closed :
566+ self ._drain_future = Future ()
567+ if not self ._conn .in_flight_requests :
568+ break
569+ await self ._drain_future
570+ self ._drain_future = None
571+ if self ._conn .closed :
572+ return
573+ log .debug ('%s: Beginning SASL re-authentication' , self ._conn )
574+ await self ._conn ._sasl_authenticate () # pylint: disable=W0212
575+ finally :
576+ self ._reauthenticating = False
577+ self ._drain_future = None
578+ if self ._conn .closed :
579+ return
580+ self ._conn .send_buffered ()
581+ self .schedule ()
0 commit comments