Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 129 additions & 2 deletions kafka/net/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import copy
import logging
import random
import struct
import time

Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(self, net, node_id=None, broker_version_data=None, **configs):
self.broker_version_data = broker_version_data
self._api_versions_idx = ApiVersionsRequest.max_version # version of ApiVersionsRequest to try on first connect
self._throttle_time = 0
self._reauth = SaslReauthenticator(self)
if self.config['metrics']:
self._sensors = KafkaConnectionMetrics(
self.config['metrics'], self.config['metric_group_prefix'], node_id)
Expand Down Expand Up @@ -117,7 +119,7 @@ def _timeout_at(self, now=None, timeout_ms=None):
def send_request(self, request, request_timeout_ms=None):
future = Future()
timeout_at = self._timeout_at(timeout_ms=request_timeout_ms)
if self.initializing:
if self.initializing or self._reauth.is_reauthenticating:
self._request_buffer.append((request, future, timeout_at))
return future
elif self.paused:
Expand Down Expand Up @@ -212,6 +214,7 @@ def data_received(self, data):
future.success(response)
if 'max_in_flight' in self.paused and len(self.in_flight_requests) < self.config['max_in_flight_requests_per_connection']:
self.unpause('max_in_flight')
self._reauth.on_response_processed()

def eof_received(self):
""" Called when the other end calls write_eof() or equivalent.
Expand All @@ -231,6 +234,7 @@ def connection_lost(self, exc):
"""
self.connected = self.initializing = False
self.transport = None
self._reauth.cancel()
error = exc or Errors.KafkaConnectionError()
if not self._init_future.is_done:
self._init_future.failure(error)
Expand Down Expand Up @@ -426,10 +430,11 @@ async def _sasl_authenticate(self, timeout_at=None):
mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])(
host=sasl_host, **self.config)

auth_response = None
while not mechanism.is_done() and timeout_at > time.monotonic():
token = mechanism.auth_bytes()
if version == 1:
auth_request = SaslAuthenticateRequest(token, version=0)
auth_request = SaslAuthenticateRequest(token)
else:
auth_request = SaslBytesRequest(token)
auth_response = await self._send_request(auth_request, timeout_at=timeout_at)
Expand All @@ -449,6 +454,9 @@ async def _sasl_authenticate(self, timeout_at=None):
raise Errors.SaslAuthenticationFailedError(
'Failed to authenticate via SASL %s' % self.config['sasl_mechanism'])

# KIP-368: SessionLifetimeMs is only present on SaslAuthenticateResponse v1+.
if version == 1:
self._reauth.session_updated(auth_response.session_lifetime_ms)
log.info('%s: %s', self, mechanism.auth_details())

def _init_complete(self):
Expand All @@ -457,3 +465,122 @@ def _init_complete(self):
self.connected = True
self.send_buffered()
self._init_future.success(True)
self._reauth.schedule()


class SaslReauthenticator:
"""KIP-368 SASL re-authentication state and scheduling for a single
KafkaConnection. Owns the per-connection re-auth lifecycle so the
connection doesn't have to carry the related attributes and coroutines
inline. The connection plugs this in at five points:

- after each successful SASL auth -> session_updated()
- after init completes -> schedule()
- when send_request needs to gate the public API -> is_reauthenticating
- on every response popped from in_flight_requests -> on_response_processed()
- on connection_lost -> cancel()
"""

def __init__(self, conn):
self._conn = conn
self.session_lifetime_ms = 0
self.authenticated_at = None
self._task = None
self._reauthenticating = False
self._drain_future = None

@property
def is_reauthenticating(self):
return self._reauthenticating

@property
def task(self):
"""The scheduled re-auth task, or None. Exposed for tests/observability."""
return self._task

def session_updated(self, session_lifetime_ms):
"""Capture broker-advertised session lifetime after each successful
auth round (initial and subsequent re-auths). Clamp negative values to 0,
and require minimum non-zero lifetime of 1sec (1000)."""
self.session_lifetime_ms = session_lifetime_ms or 0
if self.session_lifetime_ms < 0:
self.session_lifetime_ms = 0
elif 0 < self.session_lifetime_ms <= 1000:
self.session_lifetime_ms = 1000
self.authenticated_at = time.monotonic()

def schedule(self):
"""Schedule the next re-auth before the lifetime elapses. Jittered to
85-95% of the lifetime to avoid synchronised re-auth storms across
many connections (Apache Java semantics). No-op when SASL is disabled
or the broker advertised lifetime=0.
"""
if not self._conn.sasl_enabled or not self.session_lifetime_ms:
return
pct = random.uniform(0.85, 0.95)
delay = (self.session_lifetime_ms * pct) / 1000
log.debug('%s: Scheduling SASL re-authentication in %.3fs (session_lifetime_ms=%d)',
self._conn, delay, self.session_lifetime_ms)
self._task = self._conn.net.call_later(delay, self._run)

def cancel(self):
"""Cancel any pending re-auth and fail the drain awaiter if present.
Called from KafkaConnection.connection_lost."""
if self._task is not None:
try:
self._conn.net.unschedule(self._task)
except (ValueError, KeyError):
pass
self._task = None
if self._drain_future is not None and not self._drain_future.is_done:
self._drain_future.failure(Errors.KafkaConnectionError())
self._drain_future = None
self._reauthenticating = False

def on_response_processed(self):
"""Wake the drain awaiter once in_flight_requests clears during reauth.
Called from KafkaConnection.data_received after each pop."""
if (self._reauthenticating
and self._drain_future is not None
and not self._conn.in_flight_requests
and not self._drain_future.is_done):
self._drain_future.success(None)

async def _run(self):
self._task = None
if self._conn.closed:
return
try:
await self._do_reauth()
except BaseException as exc: # pylint: disable=W0718
# Re-auth failure is transient (KIP-368: not cached like initial
# auth failure); close the connection so the manager reconnects on
# next demand.
log.warning('%s: SASL re-authentication failed: %s', self._conn, exc)
err = exc if isinstance(exc, Exception) else Errors.SaslAuthenticationFailedError(str(exc))
self._conn.close(err)

async def _do_reauth(self):
self._reauthenticating = True
try:
# Drain in-flight so the SaslHandshake/Authenticate frames are the
# next bytes on the wire (Apache Java does the same; avoids
# reasoning about FIFO interleaving with the broker's reauth
# validation).
while self._conn.in_flight_requests and not self._conn.closed:
self._drain_future = Future()
if not self._conn.in_flight_requests:
break
await self._drain_future
self._drain_future = None
if self._conn.closed:
return
log.debug('%s: Beginning SASL re-authentication', self._conn)
await self._conn._sasl_authenticate() # pylint: disable=W0212
finally:
self._reauthenticating = False
self._drain_future = None
if self._conn.closed:
return
self._conn.send_buffered()
self.schedule()
3 changes: 3 additions & 0 deletions test/mock_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __init__(self, net, broker, node_id=0, host='localhost', port=9092):
self._broker = broker
self._node_id = node_id
self._host = host
# Public mirror -- KafkaTCPTransport / KafkaSSLTransport both expose
# `host` for the SASL hostname fallback.
self.host = host
self._port = port
self._protocol = None
self._closed = False
Expand Down
4 changes: 4 additions & 0 deletions test/net/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,8 @@ def test_sasl_authenticate_success(self, net):
auth_response = MagicMock()
auth_response.error_code = 0
auth_response.auth_bytes = b''
# KIP-368: explicit 0 so _schedule_reauthenticate() short-circuits
auth_response.session_lifetime_ms = 0

responses = iter([handshake_response, auth_response])
def mock_send_request(request, **kwargs):
Expand Down Expand Up @@ -546,6 +548,8 @@ def _drive_handshake_with_recording_mechanism(self, net, conn):
auth_response = MagicMock()
auth_response.error_code = 0
auth_response.auth_bytes = b''
# KIP-368: explicit 0 so _schedule_reauthenticate() short-circuits
auth_response.session_lifetime_ms = 0
responses = iter([handshake_response, auth_response])
def mock_send_request(request, **kwargs):
f = Future()
Expand Down
Loading
Loading