From 183ddfa6145c94b7f7920aa94f1d86cf14c3048a Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 3 Jun 2026 23:27:21 -0700 Subject: [PATCH 1/2] KIP-368: Periodic SASL reauthentication via session_lifetime_ms --- kafka/net/connection.py | 126 ++++++++++++- test/mock_broker.py | 3 + test/net/test_connection.py | 4 + test/net/test_sasl_reauthentication.py | 252 +++++++++++++++++++++++++ 4 files changed, 383 insertions(+), 2 deletions(-) create mode 100644 test/net/test_sasl_reauthentication.py diff --git a/kafka/net/connection.py b/kafka/net/connection.py index 65ce98fea..8edf99309 100644 --- a/kafka/net/connection.py +++ b/kafka/net/connection.py @@ -1,6 +1,7 @@ import collections import copy import logging +import random import struct import time @@ -59,6 +60,7 @@ def __init__(self, net, node_id=None, broker_version_data=None, **configs): self.broker_version_data = broker_version_data self._api_versions_idx = ApiVersionsRequest.max_version # version of ApiVersionsRequest to try on first connect self._throttle_time = 0 + self._reauth = SaslReauthenticator(self) if self.config['metrics']: self._sensors = KafkaConnectionMetrics( self.config['metrics'], self.config['metric_group_prefix'], node_id) @@ -117,7 +119,7 @@ def _timeout_at(self, now=None, timeout_ms=None): def send_request(self, request, request_timeout_ms=None): future = Future() timeout_at = self._timeout_at(timeout_ms=request_timeout_ms) - if self.initializing: + if self.initializing or self._reauth.is_reauthenticating: self._request_buffer.append((request, future, timeout_at)) return future elif self.paused: @@ -212,6 +214,7 @@ def data_received(self, data): future.success(response) if 'max_in_flight' in self.paused and len(self.in_flight_requests) < self.config['max_in_flight_requests_per_connection']: self.unpause('max_in_flight') + self._reauth.on_response_processed() def eof_received(self): """ Called when the other end calls write_eof() or equivalent. 
@@ -231,6 +234,7 @@ def connection_lost(self, exc): """ self.connected = self.initializing = False self.transport = None + self._reauth.cancel() error = exc or Errors.KafkaConnectionError() if not self._init_future.is_done: self._init_future.failure(error) @@ -426,10 +430,11 @@ async def _sasl_authenticate(self, timeout_at=None): mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])( host=sasl_host, **self.config) + auth_response = None while not mechanism.is_done() and timeout_at > time.monotonic(): token = mechanism.auth_bytes() if version == 1: - auth_request = SaslAuthenticateRequest(token, version=0) + auth_request = SaslAuthenticateRequest(token) else: auth_request = SaslBytesRequest(token) auth_response = await self._send_request(auth_request, timeout_at=timeout_at) @@ -449,6 +454,9 @@ async def _sasl_authenticate(self, timeout_at=None): raise Errors.SaslAuthenticationFailedError( 'Failed to authenticate via SASL %s' % self.config['sasl_mechanism']) + # KIP-368: SessionLifetimeMs is only present on SaslAuthenticateResponse v1+. + if version == 1: + self._reauth.session_updated(auth_response.session_lifetime_ms) log.info('%s: %s', self, mechanism.auth_details()) def _init_complete(self): @@ -457,3 +465,117 @@ def _init_complete(self): self.connected = True self.send_buffered() self._init_future.success(True) + self._reauth.schedule() + + +class SaslReauthenticator: + """KIP-368 SASL re-authentication state and scheduling for a single + KafkaConnection. Owns the per-connection re-auth lifecycle so the + connection doesn't have to carry the related attributes and coroutines + inline. 
The connection plugs this in at five points: + + - after each successful SASL auth -> session_updated() + - after init completes -> schedule() + - when send_request needs to gate the public API -> is_reauthenticating + - on every response popped from in_flight_requests -> on_response_processed() + - on connection_lost -> cancel() + """ + + def __init__(self, conn): + self._conn = conn + self.session_lifetime_ms = 0 + self.authenticated_at = None + self._task = None + self._reauthenticating = False + self._drain_future = None + + @property + def is_reauthenticating(self): + return self._reauthenticating + + @property + def task(self): + """The scheduled re-auth task, or None. Exposed for tests/observability.""" + return self._task + + def session_updated(self, session_lifetime_ms): + """Capture broker-advertised session lifetime after each successful + auth round (initial and subsequent re-auths).""" + self.session_lifetime_ms = session_lifetime_ms or 0 + self.authenticated_at = time.monotonic() + + def schedule(self): + """Schedule the next re-auth before the lifetime elapses. Jittered to + 85-95% of the lifetime to avoid synchronised re-auth storms across + many connections (Apache Java semantics). No-op when SASL is disabled + or the broker advertised lifetime=0. + """ + if not self._conn.sasl_enabled or not self.session_lifetime_ms: + return + pct = random.uniform(0.85, 0.95) + delay = max(0.1, (self.session_lifetime_ms * pct) / 1000) + log.debug('%s: Scheduling SASL re-authentication in %.3fs (session_lifetime_ms=%d)', + self._conn, delay, self.session_lifetime_ms) + self._task = self._conn.net.call_later(delay, self._run) + + def cancel(self): + """Cancel any pending re-auth and fail the drain awaiter if present. 
+ Called from KafkaConnection.connection_lost.""" + if self._task is not None: + try: + self._conn.net.unschedule(self._task) + except (ValueError, KeyError): + pass + self._task = None + if self._drain_future is not None and not self._drain_future.is_done: + self._drain_future.failure(Errors.KafkaConnectionError()) + self._drain_future = None + self._reauthenticating = False + + def on_response_processed(self): + """Wake the drain awaiter once in_flight_requests clears during reauth. + Called from KafkaConnection.data_received after each pop.""" + if (self._reauthenticating + and self._drain_future is not None + and not self._conn.in_flight_requests + and not self._drain_future.is_done): + self._drain_future.success(None) + + async def _run(self): + self._task = None + if self._conn.closed: + return + try: + await self._do_reauth() + except BaseException as exc: # pylint: disable=W0718 + # Re-auth failure is transient (KIP-368: not cached like initial + # auth failure); close the connection so the manager reconnects on + # next demand. + log.warning('%s: SASL re-authentication failed: %s', self._conn, exc) + err = exc if isinstance(exc, Exception) else Errors.SaslAuthenticationFailedError(str(exc)) + self._conn.close(err) + + async def _do_reauth(self): + self._reauthenticating = True + try: + # Drain in-flight so the SaslHandshake/Authenticate frames are the + # next bytes on the wire (Apache Java does the same; avoids + # reasoning about FIFO interleaving with the broker's reauth + # validation). 
+ while self._conn.in_flight_requests and not self._conn.closed: + self._drain_future = Future() + if not self._conn.in_flight_requests: + break + await self._drain_future + self._drain_future = None + if self._conn.closed: + return + log.debug('%s: Beginning SASL re-authentication', self._conn) + await self._conn._sasl_authenticate() # pylint: disable=W0212 + finally: + self._reauthenticating = False + self._drain_future = None + if self._conn.closed: + return + self._conn.send_buffered() + self.schedule() diff --git a/test/mock_broker.py b/test/mock_broker.py index 42145268a..bf58356d9 100644 --- a/test/mock_broker.py +++ b/test/mock_broker.py @@ -63,6 +63,9 @@ def __init__(self, net, broker, node_id=0, host='localhost', port=9092): self._broker = broker self._node_id = node_id self._host = host + # Public mirror -- KafkaTCPTransport / KafkaSSLTransport both expose + # `host` for the SASL hostname fallback. + self.host = host self._port = port self._protocol = None self._closed = False diff --git a/test/net/test_connection.py b/test/net/test_connection.py index e546142a0..b43f01b5e 100644 --- a/test/net/test_connection.py +++ b/test/net/test_connection.py @@ -487,6 +487,8 @@ def test_sasl_authenticate_success(self, net): auth_response = MagicMock() auth_response.error_code = 0 auth_response.auth_bytes = b'' + # KIP-368: explicit 0 so SaslReauthenticator.schedule() short-circuits + auth_response.session_lifetime_ms = 0 responses = iter([handshake_response, auth_response]) def mock_send_request(request, **kwargs): @@ -546,6 +548,8 @@ def _drive_handshake_with_recording_mechanism(self, net, conn): auth_response = MagicMock() auth_response.error_code = 0 auth_response.auth_bytes = b'' + # KIP-368: explicit 0 so SaslReauthenticator.schedule() short-circuits + auth_response.session_lifetime_ms = 0 responses = iter([handshake_response, auth_response]) def mock_send_request(request, **kwargs): f = Future() diff --git a/test/net/test_sasl_reauthentication.py 
b/test/net/test_sasl_reauthentication.py new file mode 100644 index 000000000..2b0353b47 --- /dev/null +++ b/test/net/test_sasl_reauthentication.py @@ -0,0 +1,252 @@ +"""KIP-368: SASL connection re-authentication.""" +import time + +import pytest + +import kafka.errors as Errors +from kafka.net.connection import KafkaConnection +from kafka.net.manager import KafkaConnectionManager +from kafka.net.selector import NetworkSelector +from kafka.protocol.sasl import ( + SaslAuthenticateRequest, + SaslAuthenticateResponse, + SaslHandshakeRequest, + SaslHandshakeResponse, +) + +from test.mock_broker import MockBroker + + +SASL_CONFIG = { + 'security_protocol': 'SASL_PLAINTEXT', + 'sasl_mechanism': 'PLAIN', + 'sasl_plain_username': 'user', + 'sasl_plain_password': 'pass', +} + + +@pytest.fixture +def net(): + sel = NetworkSelector() + try: + yield sel + finally: + sel.close() + + +@pytest.fixture +def sasl_broker(): + return MockBroker(broker_version=(2, 5)) # supports SaslAuthenticate v0-2 + + +@pytest.fixture +def sasl_manager(net, sasl_broker): + manager = KafkaConnectionManager( + net, + bootstrap_servers='%s:%d' % (sasl_broker.host, sasl_broker.port), + api_version=sasl_broker.broker_version, + request_timeout_ms=5000, + **SASL_CONFIG, + ) + sasl_broker.attach(manager) + try: + yield manager + finally: + manager.close() + + +def _script_sasl(broker, session_lifetime_ms=0, auth_error_code=0): + """Queue the SaslHandshake + SaslAuthenticate response pair on the mock broker.""" + handshake = SaslHandshakeResponse(error_code=0, mechanisms=['PLAIN']) + broker.respond(SaslHandshakeRequest, handshake) + + auth = SaslAuthenticateResponse( + version=1, + error_code=auth_error_code, + error_message='' if auth_error_code == 0 else 'failed', + auth_bytes=b'', + session_lifetime_ms=session_lifetime_ms, + ) + broker.respond(SaslAuthenticateRequest, auth) + + +def _bootstrap_and_open(manager, broker, lifetime_ms=0, auth_error_code=0): + """Bootstrap the manager (consumes one SASL 
pair), then open a real + connection (consumes a second SASL pair scripted with `lifetime_ms`).""" + _script_sasl(broker, session_lifetime_ms=0) # bootstrap connection + _script_sasl(broker, session_lifetime_ms=lifetime_ms, auth_error_code=auth_error_code) + manager.bootstrap() + conn = manager.get_connection(broker.node_id) + # get_connection is non-blocking; spin the loop until init completes. + deadline = time.monotonic() + 2.0 + while not conn.connected and not conn.closed and time.monotonic() < deadline: + manager._net.poll(timeout_ms=10) + return conn + + +class TestSaslReauthentication: + def test_no_lifetime_does_not_schedule_reauth(self, sasl_manager, sasl_broker): + conn = _bootstrap_and_open(sasl_manager, sasl_broker, lifetime_ms=0) + + assert conn.connected + assert conn._reauth.session_lifetime_ms == 0 + assert conn._reauth.task is None + assert conn._reauth.authenticated_at is not None + + def test_lifetime_schedules_reauth(self, sasl_manager, sasl_broker, monkeypatch): + # Pin jitter to lower bound for deterministic delay. + monkeypatch.setattr('kafka.net.connection.random.uniform', lambda a, b: a) + + conn = _bootstrap_and_open(sasl_manager, sasl_broker, lifetime_ms=10_000) + + assert conn.connected + assert conn._reauth.session_lifetime_ms == 10_000 + assert conn._reauth.task is not None + assert conn._reauth.task.scheduled_at is not None + delay = conn._reauth.task.scheduled_at - time.monotonic() + assert 8.0 < delay <= 8.6 # 10s * 0.85 jitter + + def test_reauth_fires_and_runs_second_handshake(self, sasl_manager, sasl_broker, monkeypatch): + monkeypatch.setattr('kafka.net.connection.random.uniform', lambda a, b: a) + + conn = _bootstrap_and_open(sasl_manager, sasl_broker, lifetime_ms=200) + # Queue responses for the re-auth that's about to fire. 
+ _script_sasl(sasl_broker, session_lifetime_ms=0) + + assert conn._reauth.session_lifetime_ms == 200 + initial_auth_at = conn._reauth.authenticated_at + + # Drive the loop past the jittered delay (200ms * 0.85 = 170ms) + deadline = time.monotonic() + 1.0 + while conn._reauth.authenticated_at == initial_auth_at and time.monotonic() < deadline: + sasl_manager._net.poll(timeout_ms=10) + + assert conn.connected + assert conn._reauth.authenticated_at > initial_auth_at + assert conn._reauth.session_lifetime_ms == 0 + assert conn._reauth.task is None + assert not conn._reauth.is_reauthenticating + + def test_reauth_failure_closes_connection(self, sasl_manager, sasl_broker, monkeypatch): + monkeypatch.setattr('kafka.net.connection.random.uniform', lambda a, b: a) + + conn = _bootstrap_and_open(sasl_manager, sasl_broker, lifetime_ms=200) + # Re-auth response with SaslAuthenticationFailedError (error code 58) + _script_sasl(sasl_broker, session_lifetime_ms=0, auth_error_code=58) + + deadline = time.monotonic() + 1.0 + while not conn.closed and time.monotonic() < deadline: + sasl_manager._net.poll(timeout_ms=10) + + assert conn.closed + # KIP-368: reauth failure is transient -- must NOT be sticky in the + # manager's auth-failure cache (initial-auth failures are sticky). + assert sasl_broker.node_id not in sasl_manager._auth_failures + + def test_close_cancels_scheduled_reauth(self, sasl_manager, sasl_broker, monkeypatch): + monkeypatch.setattr('kafka.net.connection.random.uniform', lambda a, b: a) + + conn = _bootstrap_and_open(sasl_manager, sasl_broker, lifetime_ms=60_000) + assert conn._reauth.task is not None + + conn.close() + # Let close propagate through the loop. + sasl_manager._net.drain() + + assert conn._reauth.task is None + + def test_negotiates_v2_when_broker_supports(self, sasl_manager, sasl_broker): + # Bootstrap uses an unobserved SASL pair; on the real connection we + # capture the SaslAuthenticateRequest's wire version via respond_fn. 
+ _script_sasl(sasl_broker, session_lifetime_ms=0) # bootstrap + sasl_broker.respond(SaslHandshakeRequest, + SaslHandshakeResponse(error_code=0, mechanisms=['PLAIN'])) + captured = {} + + def auth_handler(api_key, api_version, correlation_id, request_bytes): + captured['api_version'] = api_version + return SaslAuthenticateResponse( + version=1, error_code=0, error_message='', + auth_bytes=b'', session_lifetime_ms=0) + sasl_broker.respond_fn(SaslAuthenticateRequest, auth_handler) + + sasl_manager.bootstrap() + conn = sasl_manager.get_connection(sasl_broker.node_id) + deadline = time.monotonic() + 2.0 + while not conn.connected and not conn.closed and time.monotonic() < deadline: + sasl_manager._net.poll(timeout_ms=10) + + # MockBroker is broker_version=(2,5) -> SaslAuthenticate v0-2 + assert captured.get('api_version') == 2 + + +class TestSaslReauthenticationUnit: + """Lower-level tests that don't need a MockBroker.""" + + def test_session_lifetime_captured_from_v1_response(self, net): + """A v1 response's session_lifetime_ms is captured and a re-auth is scheduled.""" + from unittest.mock import MagicMock + from kafka.future import Future + from kafka.protocol.broker_version_data import BrokerVersionData + + conn = KafkaConnection(net, node_id='test', **SASL_CONFIG) + transport = MagicMock() + transport.getPeer.return_value = ('127.0.0.1', 9092) + transport.host = 'broker' + conn.transport = transport + conn.initializing = True + + # Broker supports SaslHandshake v0-1 AND SaslAuthenticate v0-1 + conn.broker_version_data = BrokerVersionData(api_versions={ + SaslHandshakeRequest.API_KEY: (0, 1), + SaslAuthenticateRequest.API_KEY: (0, 1), + }) + + handshake_response = MagicMock() + handshake_response.error_code = 0 + handshake_response.mechanisms = ['PLAIN'] + handshake_response.API_VERSION = 1 + + auth_response = MagicMock() + auth_response.error_code = 0 + auth_response.auth_bytes = b'' + auth_response.session_lifetime_ms = 5_000 + + responses = iter([handshake_response, 
auth_response]) + def mock_send_request(request, **kwargs): + f = Future() + f.success(next(responses)) + return f + conn._send_request = mock_send_request + + net.run(conn.initialize()) + + assert conn.connected + assert conn._reauth.session_lifetime_ms == 5_000 + assert conn._reauth.task is not None + + def test_schedule_skipped_when_sasl_disabled(self, net): + conn = KafkaConnection(net, node_id='test', security_protocol='PLAINTEXT') + conn._reauth.session_lifetime_ms = 30_000 # would normally schedule + conn._reauth.schedule() + assert conn._reauth.task is None + + def test_schedule_skipped_when_lifetime_zero(self, net): + conn = KafkaConnection(net, node_id='test', **SASL_CONFIG) + conn._reauth.session_lifetime_ms = 0 + conn._reauth.schedule() + assert conn._reauth.task is None + + def test_send_request_buffers_during_reauth(self, net): + from kafka.protocol.metadata import MetadataRequest + + conn = KafkaConnection(net, node_id='test', **SASL_CONFIG) + conn.connected = True + conn.initializing = False + conn._reauth._reauthenticating = True + + req = MetadataRequest() + future = conn.send_request(req) + + assert not future.is_done + assert len(conn._request_buffer) == 1 From e961b4d80e335cc1177a8ae4b030906346ec0b5a Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Thu, 4 Jun 2026 08:12:16 -0700 Subject: [PATCH 2/2] Validate session_lifetime_ms --- kafka/net/connection.py | 9 +++++++-- test/net/test_sasl_reauthentication.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/kafka/net/connection.py b/kafka/net/connection.py index 8edf99309..737822428 100644 --- a/kafka/net/connection.py +++ b/kafka/net/connection.py @@ -500,8 +500,13 @@ def task(self): def session_updated(self, session_lifetime_ms): """Capture broker-advertised session lifetime after each successful - auth round (initial and subsequent re-auths).""" + auth round (initial and subsequent re-auths). 
Clamp negative values to 0, + and require minimum non-zero lifetime of 1sec (1000).""" self.session_lifetime_ms = session_lifetime_ms or 0 + if self.session_lifetime_ms < 0: + self.session_lifetime_ms = 0 + elif 0 < self.session_lifetime_ms <= 1000: + self.session_lifetime_ms = 1000 self.authenticated_at = time.monotonic() def schedule(self): @@ -513,7 +518,7 @@ def schedule(self): if not self._conn.sasl_enabled or not self.session_lifetime_ms: return pct = random.uniform(0.85, 0.95) - delay = max(0.1, (self.session_lifetime_ms * pct) / 1000) + delay = (self.session_lifetime_ms * pct) / 1000 log.debug('%s: Scheduling SASL re-authentication in %.3fs (session_lifetime_ms=%d)', self._conn, delay, self.session_lifetime_ms) self._task = self._conn.net.call_later(delay, self._run) diff --git a/test/net/test_sasl_reauthentication.py b/test/net/test_sasl_reauthentication.py index 2b0353b47..c238063fa 100644 --- a/test/net/test_sasl_reauthentication.py +++ b/test/net/test_sasl_reauthentication.py @@ -109,11 +109,11 @@ def test_lifetime_schedules_reauth(self, sasl_manager, sasl_broker, monkeypatch) def test_reauth_fires_and_runs_second_handshake(self, sasl_manager, sasl_broker, monkeypatch): monkeypatch.setattr('kafka.net.connection.random.uniform', lambda a, b: a) - conn = _bootstrap_and_open(sasl_manager, sasl_broker, lifetime_ms=200) + conn = _bootstrap_and_open(sasl_manager, sasl_broker, lifetime_ms=1000) # Queue responses for the re-auth that's about to fire. _script_sasl(sasl_broker, session_lifetime_ms=0) - assert conn._reauth.session_lifetime_ms == 200 + assert conn._reauth.session_lifetime_ms == 1000 initial_auth_at = conn._reauth.authenticated_at # Drive the loop past the jittered delay (200ms * 0.85 = 170ms)