Skip to content

Commit 7250337

Browse files
authored
kafka.net: support socket_connection_setup_timeout_ms w/ exp backoff (KIP-601) (#3027)
1 parent 6e48314 commit 7250337

8 files changed

Lines changed: 192 additions & 129 deletions

File tree

kafka/net/connection.py

Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,6 @@ def connection_made(self, transport):
269269
client_id=self.config['client_id'],
270270
receive_message_max_bytes=self.config['receive_message_max_bytes'],
271271
ident=log_prefix)
272-
self.net.call_soon(self._check_version)
273272

274273
def pause(self, v):
275274
self.paused.add(v)
@@ -345,31 +344,36 @@ def _maybe_unthrottle(self):
345344
self._throttle_time = 0
346345
self.unpause('throttle')
347346

348-
async def _check_version(self, timeout_ms=None):
347+
async def initialize(self, timeout_at=None):
348+
if timeout_at is None:
349+
timeout_at = self._timeout_at()
350+
try:
351+
await self._get_api_versions(timeout_at)
352+
if self.sasl_enabled:
353+
await self._sasl_authenticate(timeout_at)
354+
except Exception as error:
355+
self.close(error)
356+
else:
357+
self._init_complete()
358+
359+
async def _get_api_versions(self, timeout_at=None):
360+
if timeout_at is None:
361+
timeout_at = self._timeout_at()
349362
if self.broker_version_data is not None:
350363
try:
351364
self._api_versions_idx = self.broker_version_data.api_version(ApiVersionsRequest)
352365
except Errors.IncompatibleBrokerVersion:
353366
log.debug('%s: Using pre-configured api_version %s for ApiVersions', self, self.broker_version)
354-
self._init_complete()
355367
return
356368

357-
if timeout_ms is not None:
358-
timeout_ms = self.config['api_version_auto_timeout_ms']
359-
timeout_at = self._timeout_at(timeout_ms=timeout_ms)
360369
while timeout_at > time.monotonic():
361370
version = self._api_versions_idx
362371
request = ApiVersionsRequest(
363372
version=version,
364373
client_software_name=self.config['client_software_name'],
365374
client_software_version=self.config['client_software_version'],
366375
)
367-
try:
368-
response = await self._send_request(request, timeout_at=timeout_at)
369-
except Exception as exc:
370-
self.close(exc)
371-
return
372-
376+
response = await self._send_request(request, timeout_at=timeout_at)
373377
error_type = Errors.for_code(response.error_code)
374378
if error_type is Errors.NoError:
375379
break
@@ -382,92 +386,74 @@ async def _check_version(self, timeout_ms=None):
382386
self._api_versions_idx = 0
383387
continue
384388
else:
385-
self.close(error_type())
386-
return
389+
raise error_type()
390+
else:
391+
raise Errors.KafkaTimeoutError('Timeout during ApiVersions check')
387392

388393
api_versions = {api_version.api_key: (api_version.min_version, api_version.max_version)
389394
for api_version in response.api_keys}
390395
self.broker_version_data = BrokerVersionData(api_versions=api_versions)
391396
log.info('%s: Broker version identified as %s', self, '.'.join(map(str, self.broker_version)))
392-
if self.sasl_enabled:
393-
await self._sasl_authenticate()
394-
if self.initializing:
395-
self._init_complete()
396397

397398
@property
398399
def sasl_enabled(self):
399400
return self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
400401

401-
async def _sasl_authenticate(self):
402+
async def _sasl_authenticate(self, timeout_at=None):
403+
if timeout_at is None:
404+
timeout_at = self._timeout_at()
402405
# Step 1: SaslHandshake to negotiate mechanism
403406
request = SaslHandshakeRequest(
404407
mechanism=self.config['sasl_mechanism'],
405408
max_version=1)
406-
try:
407-
response = await self._send_request(request)
408-
except Exception as exc:
409-
self.close(Errors.KafkaConnectionError('SaslHandshake failed: %s' % exc))
410-
return
411-
409+
response = await self._send_request(request, timeout_at=timeout_at)
412410
error_type = Errors.for_code(response.error_code)
413411
if error_type is not Errors.NoError:
414412
log.error('%s: SaslHandshake failed: %s', self, error_type.__name__)
415-
self.close(error_type())
416-
return
413+
raise error_type()
417414

418415
if self.config['sasl_mechanism'] not in response.mechanisms:
419-
self.close(Errors.UnsupportedSaslMechanismError(
416+
raise Errors.UnsupportedSaslMechanismError(
420417
'Kafka broker does not support %s sasl mechanism. Enabled mechanisms: %s'
421-
% (self.config['sasl_mechanism'], response.mechanisms)))
422-
return
418+
% (self.config['sasl_mechanism'], response.mechanisms))
423419

424420
# Step 2: SASL authentication exchange
425421
version = response.API_VERSION
426422
# Prefer the configured hostname (stored on the transport) so that
427423
# mechanisms like GSSAPI construct service principals against the
428424
# user-supplied name, not whichever IP getaddrinfo handed us.
429425
sasl_host = self.transport.host if self.transport.host else self.transport.getPeer()[0]
430-
try:
431-
mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])(
432-
host=sasl_host, **self.config)
433-
except Exception as exc:
434-
self.close(exc)
435-
return
426+
mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])(
427+
host=sasl_host, **self.config)
436428

437-
while not mechanism.is_done():
429+
while not mechanism.is_done() and timeout_at > time.monotonic():
438430
token = mechanism.auth_bytes()
439431
if version == 1:
440432
auth_request = SaslAuthenticateRequest(token, version=0)
441433
else:
442434
auth_request = SaslBytesRequest(token)
443-
444-
try:
445-
auth_response = await self._send_request(auth_request)
446-
except Exception as exc:
447-
self.close(Errors.KafkaConnectionError('SaslAuthenticate failed: %s' % exc))
448-
return
449-
435+
auth_response = await self._send_request(auth_request, timeout_at=timeout_at)
450436
error_type = Errors.for_code(auth_response.error_code)
451437
if error_type is not Errors.NoError:
452-
self.close(Errors.SaslAuthenticationFailedError(
453-
'%s: %s' % (error_type.__name__, auth_response.error_message)))
454-
return
438+
raise Errors.SaslAuthenticationFailedError(
439+
'%s: %s' % (error_type.__name__, auth_response.error_message))
455440

456441
# GSSAPI does not get a final recv in v0 unframed mode
457442
if version == 0 and mechanism.is_done():
458443
break
459-
460444
mechanism.receive(auth_response.auth_bytes)
461445

462-
if not mechanism.is_authenticated():
463-
self.close(Errors.SaslAuthenticationFailedError(
464-
'Failed to authenticate via SASL %s' % self.config['sasl_mechanism']))
465-
return
446+
if time.monotonic() > timeout_at:
447+
raise Errors.KafkaTimeoutError('SASL Authentication timed out')
448+
elif not mechanism.is_authenticated():
449+
raise Errors.SaslAuthenticationFailedError(
450+
'Failed to authenticate via SASL %s' % self.config['sasl_mechanism'])
466451

467452
log.info('%s: %s', self, mechanism.auth_details())
468453

469454
def _init_complete(self):
470-
self.initializing = False
471-
self.connected = True
472-
self.send_buffered()
473-
self._init_future.success(True)
455+
if self.initializing:
456+
self.initializing = False
457+
self.connected = True
458+
self.send_buffered()
459+
self._init_future.success(True)

kafka/net/inet.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import errno
22
import logging
33
import socket
4+
import time
45
from urllib.parse import urlparse
56

67
import kafka.errors as Errors
@@ -9,18 +10,20 @@
910
log = logging.getLogger(__name__)
1011

1112

12-
async def create_connection(net, host, port, socket_options=(), proxy_url=None):
13+
async def create_connection(net, host, port, socket_options=(), proxy_url=None, timeout_at=None):
1314
"""Connect to host:port; raises KafkaConnectionError on failure"""
1415
socket_factory = KafkaNetSocket(proxy_url)
1516
addrs = socket_factory.dns_lookup(host, port)
1617
exceptions = [Errors.KafkaConnectionError('DNS Resolution failure')]
1718
for res in addrs:
1819
try:
1920
log.debug('%s: Attempting to connect to %s (options: %s)', socket_factory, res, socket_options)
20-
sock = await socket_factory.connect(net, res, socket_options)
21+
sock = await socket_factory.connect(net, res, socket_options, timeout_at=timeout_at)
2122
except (socket.error, OSError) as e:
2223
exceptions.append(Errors.KafkaConnectionError('unable to connect: %s' % (e,)))
2324
continue
25+
except Errors.KafkaTimeoutError:
26+
raise Errors.KafkaConnectionError('Connection timed out')
2427
except Errors.KafkaConnectionError as e:
2528
exceptions.append(e)
2629
continue
@@ -79,17 +82,17 @@ def dns_lookup(self, host, port, raise_error=False):
7982
def socket(self, family=socket.AF_UNSPEC, sock_type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP):
8083
return socket.socket(family, sock_type, proto)
8184

82-
async def connect(self, net, addrinfo, socket_options=()):
85+
async def connect(self, net, addrinfo, socket_options=(), timeout_at=None):
8386
"""Create non-blocking socket (with options) and connect to addrinfo tuple"""
8487
family, sock_type, proto, _canonname, sockaddr = addrinfo
8588
sock = self.socket(family, sock_type, proto)
8689
sock.setblocking(False)
8790
for option in socket_options:
8891
sock.setsockopt(*option)
89-
return await self.sock_connect(net, sock, sockaddr)
92+
return await self.sock_connect(net, sock, sockaddr, timeout_at=timeout_at)
9093

91-
async def sock_connect(self, net, sock, sockaddr):
92-
while True:
94+
async def sock_connect(self, net, sock, sockaddr, timeout_at=None):
95+
while timeout_at is None or time.monotonic() < timeout_at:
9396
ret = None
9497
try:
9598
ret = self.connect_ex(sock, sockaddr)
@@ -106,12 +109,14 @@ async def sock_connect(self, net, sock, sockaddr):
106109
# Needs retry
107110
# WSAEINVAL == 10022, but errno.WSAEINVAL is not available on non-win systems
108111
elif ret in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK, 10022):
109-
await net.wait_write(sock)
112+
await net.wait_write(sock, timeout_at=timeout_at)
110113

111114
# Connection failed
112115
else:
113116
errstr = errno.errorcode.get(ret, 'UNKNOWN')
114117
raise Errors.KafkaConnectionError('{} {}'.format(ret, errstr))
118+
else:
119+
raise Errors.KafkaTimeoutError('Connection timed out')
115120

116121
def connect_ex(self, sock, sockaddr):
117122
return sock.connect_ex(sockaddr)

kafka/net/manager.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ class KafkaConnectionManager:
3030
'reconnect_backoff_ms': 50,
3131
'reconnect_backoff_max_ms': 30000,
3232
'request_timeout_ms': 30000,
33-
'socket_connection_timeout_ms': 5000,
33+
'socket_connection_setup_timeout_ms': 10000,
34+
'socket_connection_setup_timeout_max_ms': 30000,
3435
'socket_options': [
3536
(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),
3637
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
@@ -86,7 +87,7 @@ def __init__(self, net, **configs):
8687
)
8788
self.cluster.attach(self)
8889
self._conns = {}
89-
self._backoff = dict() # node_id => (failures, backoff_until)
90+
self._backoff = dict() # node_id => (failures, backoff_until, socket_connect_setup_timeout_ms)
9091
# Cache the most recent SASL / SSL / auth failure per node so we can
9192
# surface it to the user instead of silently retrying forever.
9293
# Cleared on successful connect.
@@ -119,7 +120,9 @@ async def _do_bootstrap(self, deadline):
119120
bootstrap_broker = random.choice(self.cluster.bootstrap_brokers())
120121
log.debug('Attempting bootstrap with %s', bootstrap_broker)
121122
try:
123+
timeout_ms = (deadline - time.monotonic()) * 1000 if deadline is not None else None
122124
conn = self.get_connection(bootstrap_broker.node_id,
125+
timeout_ms=timeout_ms,
123126
pop_on_close=False,
124127
refresh_metadata_on_err=False,
125128
reset_backoff_on_connect=False)
@@ -218,10 +221,11 @@ def _build_ssl_context(self):
218221
ctx.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
219222
return ctx
220223

221-
async def _build_transport(self, node):
224+
async def _build_transport(self, node, timeout_at=None):
222225
sock = await create_connection(self._net, node.host, node.port,
223226
self.config['socket_options'],
224-
proxy_url=self.config['proxy_url'])
227+
proxy_url=self.config['proxy_url'],
228+
timeout_at=timeout_at)
225229
if self.ssl_enabled:
226230
transport = KafkaSSLTransport(self._net, sock, self._build_ssl_context(),
227231
host=node.host, ssl_check_hostname=self.config['ssl_check_hostname'])
@@ -235,11 +239,11 @@ async def _build_transport(self, node):
235239
else:
236240
return transport
237241

238-
async def _connect(self, node, conn, reset_backoff_on_connect=True):
242+
async def _connect(self, node, conn, reset_backoff_on_connect=True, timeout_at=None):
239243
try:
240-
transport = await self._build_transport(node)
244+
transport = await self._build_transport(node, timeout_at=timeout_at)
241245
conn.connection_made(transport)
242-
await conn.init_future
246+
await conn.initialize(timeout_at=timeout_at)
243247
except Exception as exc:
244248
log.error('Connection failed: %s', exc)
245249
conn.connection_lost(exc)
@@ -280,12 +284,10 @@ def get_connection(self, node_id, timeout_ms=None,
280284
if refresh_metadata_on_err:
281285
conn.close_future.add_errback(lambda _: self.cluster.request_update())
282286
self._conns[node_id] = conn
283-
self._net.call_soon(lambda: self._connect(node, conn, reset_backoff_on_connect=reset_backoff_on_connect))
284287
if timeout_ms is None:
285-
timeout_ms = self.config['socket_connection_timeout_ms']
286-
self._net.call_later(timeout_ms / 1000,
287-
lambda: conn.close(Errors.KafkaConnectionError('Connection timed out'))
288-
if not conn.init_future.is_done else None)
288+
timeout_ms = self.socket_connection_setup_timeout_ms(node_id)
289+
timeout_at = time.monotonic() + timeout_ms / 1000
290+
self._net.call_soon(lambda: self._connect(node, conn, reset_backoff_on_connect=reset_backoff_on_connect, timeout_at=timeout_at))
289291
return conn
290292

291293
def send(self, request, node_id=None, request_timeout_ms=None):
@@ -335,18 +337,29 @@ def reset_backoff(self, node_id):
335337
except KeyError:
336338
pass
337339

338-
def reconnect_jitter_pct(self):
340+
def jitter_pct(self):
339341
return random.uniform(0.8, 1.2)
340342

343+
def _calculate_exp_timeout(self, key, failures):
344+
max_keys = {
345+
'reconnect_backoff_ms': 'reconnect_backoff_max_ms',
346+
'socket_connection_setup_timeout_ms': 'socket_connection_setup_timeout_max_ms',
347+
}
348+
timeout_ms = self.config[key] * 2 ** (failures - 1)
349+
if key in max_keys:
350+
max_ms = self.config[max_keys[key]]
351+
timeout_ms = min(max_ms, timeout_ms)
352+
return timeout_ms * self.jitter_pct()
353+
341354
def update_backoff(self, node_id):
342-
failures, _ = self._backoff.get(node_id, (0, 0))
355+
failures, _, _ = self._backoff.get(node_id, (0, 0, 0))
343356
failures += 1
344-
backoff_ms = self.config['reconnect_backoff_ms'] * 2 ** (failures - 1)
345-
backoff_ms = min(backoff_ms, self.config['reconnect_backoff_max_ms'])
346-
backoff_ms *= self.reconnect_jitter_pct()
347-
log.debug('%s reconnect backoff %d ms after %s failures', node_id, backoff_ms, failures)
357+
backoff_ms = self._calculate_exp_timeout('reconnect_backoff_ms', failures)
358+
connect_ms = self._calculate_exp_timeout('socket_connection_setup_timeout_ms', failures)
359+
log.debug('%s reconnect backoff %d ms / connect timeout %d ms after %s failures',
360+
node_id, backoff_ms, connect_ms, failures)
348361
backoff_until_time = time.monotonic() + (backoff_ms / 1000)
349-
self._backoff[node_id] = (failures, backoff_until_time)
362+
self._backoff[node_id] = (failures, backoff_until_time, connect_ms)
350363

351364
def connection_delay(self, node_id):
352365
"""Connection delay in seconds.
@@ -357,6 +370,11 @@ def connection_delay(self, node_id):
357370
return 0
358371
return max(0, self._backoff[node_id][1] - time.monotonic())
359372

373+
def socket_connection_setup_timeout_ms(self, node_id):
374+
if node_id not in self._backoff:
375+
return self.config['socket_connection_setup_timeout_ms']
376+
return self._backoff[node_id][2]
377+
360378
def auth_failure(self, node_id):
361379
"""Return the most recent auth-class failure for ``node_id``,
362380
or None if there is no sticky failure on record."""

0 commit comments

Comments
 (0)