Skip to content

Commit 5e2ae4f

Browse files
committed
feat: add client-side round-robin load balancing for Core clusters
1 parent 036fd29 commit 5e2ae4f

10 files changed

Lines changed: 1652 additions & 17 deletions

File tree

src/firebolt/async_db/connection.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
fix_url_schema,
4646
parse_url_and_params,
4747
validate_engine_name_and_url_v1,
48+
validate_firebolt_parameters_v1,
49+
validate_firebolt_parameters_v2,
4850
)
4951

5052
logger = logging.getLogger(__name__)
@@ -292,6 +294,7 @@ async def connect(
292294
url: Optional[str] = None,
293295
autocommit: bool = True,
294296
additional_parameters: Dict[str, Any] = {},
297+
client_side_lb: Optional[bool] = None,
295298
) -> Connection:
296299
# auth parameter is optional in function signature
297300
# but is required to connect.
@@ -313,14 +316,22 @@ async def connect(
313316
if auth_version == FireboltAuthVersion.CORE:
314317
# Verify that Core-incompatible parameters are not provided
315318
validate_firebolt_core_parameters(account_name, engine_name, engine_url)
319+
if client_side_lb == None:
320+
# When using Core, client_side_lb is True by default
321+
client_side_lb = True
322+
316323
return connect_core(
317324
auth=auth,
318325
user_agent_header=user_agent_header,
319326
database=database,
320327
connection_url=url,
321328
autocommit=autocommit,
329+
client_side_lb=client_side_lb,
322330
)
323331
elif auth_version == FireboltAuthVersion.V2:
332+
# Verify that v2-incompatible parameters are not provided
333+
validate_firebolt_parameters_v2(client_side_lb)
334+
324335
assert account_name is not None
325336
return await connect_v2(
326337
auth=auth,
@@ -334,6 +345,9 @@ async def connect(
334345
autocommit=autocommit,
335346
)
336347
elif auth_version == FireboltAuthVersion.V1:
348+
# Verify that v1-incompatible parameters are not provided
349+
validate_firebolt_parameters_v1(client_side_lb)
350+
337351
return await connect_v1(
338352
auth=auth,
339353
user_agent_header=user_agent_header,
@@ -490,6 +504,7 @@ def connect_core(
490504
database: Optional[str] = None,
491505
connection_url: Optional[str] = None,
492506
autocommit: bool = True,
507+
client_side_lb: bool = False,
493508
) -> Connection:
494509
"""Connect to Firebolt Core.
495510
@@ -519,6 +534,7 @@ def connect_core(
519534
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),
520535
headers={"User-Agent": user_agent_header},
521536
verify=ctx,
537+
client_side_lb=client_side_lb,
522538
)
523539

524540
return Connection(

src/firebolt/client/client.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,16 @@ def clone(self) -> "Client":
106106

107107
class Client(FireboltClientMixin, HttpxClient, metaclass=ABCMeta):
108108
def __init__(self, *args: Any, **kwargs: Any):
    """Create the sync client with a keepalive transport.

    Args:
        *args: Positional arguments forwarded to the parent constructor.
        **kwargs: Keyword arguments forwarded to the parent constructor.
            ``client_side_lb`` (default ``False``) is consumed here and
            handed to the transport instead; ``verify`` (default ``True``)
            is passed both to the transport and to the parent.
    """
    # The parent constructor does not know `client_side_lb`, so strip it
    # out of kwargs before they are forwarded.
    lb_enabled = kwargs.pop("client_side_lb", False)

    keepalive_transport = KeepaliveTransport(
        verify=kwargs.get("verify", True),
        client_side_lb=lb_enabled,
    )
    super().__init__(*args, transport=keepalive_transport, **kwargs)
114120

115121
@property
@@ -139,13 +145,15 @@ def __init__(
139145
auth: Auth,
140146
account_name: str,
141147
api_endpoint: str = DEFAULT_API_URL,
148+
client_side_lb: bool = False,
142149
**kwargs: Any,
143150
):
144151
super().__init__(
145152
*args,
146153
auth=auth,
147154
account_name=account_name,
148155
api_endpoint=api_endpoint,
156+
client_side_lb=client_side_lb,
149157
**kwargs,
150158
)
151159

@@ -273,10 +281,15 @@ def _resolve_engine_url(self, engine_name: str) -> str:
273281

274282
class AsyncClient(FireboltClientMixin, HttpxAsyncClient, metaclass=ABCMeta):
275283
def __init__(self, *args: Any, **kwargs: Any):
    """Create the async client with a keepalive transport.

    Args:
        *args: Positional arguments forwarded to the parent constructor.
        **kwargs: Keyword arguments forwarded to the parent constructor.
            ``client_side_lb`` (default ``False``) is consumed here and
            handed to the transport instead; ``verify`` (default ``True``)
            is passed both to the transport and to the parent.
    """
    # HttpxAsyncClient does not know `client_side_lb`, so strip it out of
    # kwargs before they are forwarded.
    lb_enabled = kwargs.pop("client_side_lb", False)

    keepalive_transport = AsyncKeepaliveTransport(
        verify=kwargs.get("verify", True),
        client_side_lb=lb_enabled,
    )
    super().__init__(*args, transport=keepalive_transport, **kwargs)
281294

282295
@property
@@ -306,13 +319,15 @@ def __init__(
306319
auth: Auth,
307320
account_name: str,
308321
api_endpoint: str = DEFAULT_API_URL,
322+
client_side_lb: bool = False,
309323
**kwargs: Any,
310324
):
311325
super().__init__(
312326
*args,
313327
auth=auth,
314328
account_name=account_name,
315329
api_endpoint=api_endpoint,
330+
client_side_lb=client_side_lb,
316331
**kwargs,
317332
)
318333

src/firebolt/client/http_backend.py

Lines changed: 155 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import anyio
12
import socket
2-
from typing import Any
3+
import threading
4+
import time
5+
from typing import Any, Dict, List
36

47
try:
58
from httpcore.backends.auto import AutoBackend # type: ignore
@@ -8,7 +11,7 @@
811
from httpcore._backends.auto import AutoBackend # type: ignore
912
from httpcore._backends.sync import SyncBackend # type: ignore
1013

11-
from httpx import AsyncHTTPTransport, HTTPTransport
14+
from httpx import AsyncHTTPTransport, HTTPTransport, Request, Response
1215

1316
from firebolt.common.constants import KEEPALIVE_FLAG, KEEPIDLE_RATE
1417

@@ -29,6 +32,41 @@ def override_stream(stream): # type: ignore [no-untyped-def]
2932
return stream
3033

3134

35+
class DNSCache:
    """Thread-safe DNS cache that hands out a host's addresses round-robin.

    Resolved address lists are kept per hostname for ``ttl`` seconds. Each
    call to :meth:`get_ip_round_robin` returns the next address in the
    (sorted) list, so successive requests are spread across all addresses.

    Args:
        ttl: Seconds a resolved address list stays fresh.
        resolver: Optional callable ``hostname -> (name, aliases, ips)`` with
            the same return shape as ``socket.gethostbyname_ex``. When
            omitted, ``socket.gethostbyname_ex`` is looked up at call time.
    """

    def __init__(self, ttl: float = 30.0, resolver: Any = None):
        self.ttl = ttl
        # hostname -> sorted list of resolved addresses
        self.cache: Dict[str, List[str]] = {}
        # hostname -> time.monotonic() value after which the entry is stale
        self.expiry: Dict[str, float] = {}
        # hostname -> index of the next address to hand out
        self.indices: Dict[str, int] = {}
        self._resolver = resolver
        self._lock = threading.Lock()

    def _refresh(
        self, hostname: str, stale_ips: Any, now: float
    ) -> List[str]:
        """Re-resolve ``hostname``; fall back to ``stale_ips`` on failure.

        Must be called with ``self._lock`` held.

        Raises:
            socket.gaierror: The resolver returned no addresses and there is
                no previously cached list to fall back to.
            Exception: Whatever the resolver raised, when nothing is cached.
        """
        # Looked up late on purpose so the module-level function can still be
        # swapped out after this cache has been constructed.
        resolve = self._resolver or socket.gethostbyname_ex
        try:
            _, _, new_ips = resolve(hostname)
        except Exception:
            # Keep serving the stale list if we have one; otherwise the
            # caller needs to see the failure.
            if not stale_ips:
                raise
            return stale_ips

        if new_ips:
            # Sorted so the rotation order is stable between refreshes.
            self.cache[hostname] = sorted(new_ips)
            self.expiry[hostname] = now + self.ttl
            return self.cache[hostname]
        if stale_ips:
            return stale_ips
        # Previously this fell through with no address list and crashed with
        # a TypeError on len(None); report it as a resolution failure instead.
        raise socket.gaierror(f"No addresses found for host {hostname!r}")

    def get_ip_round_robin(self, hostname: str) -> str:
        """Return the next address for ``hostname`` in round-robin order.

        The address list is (re)resolved when it is missing or older than
        ``ttl``. Resolution happens while the internal lock is held, so
        concurrent callers wait for an in-flight lookup.

        Raises:
            OSError: Resolution failed (or returned nothing) and no earlier
                result is cached. Other exception types raised by a custom
                resolver propagate unchanged in the same situation.
        """
        now = time.monotonic()

        with self._lock:
            cached_ips = self.cache.get(hostname)
            if not cached_ips or now >= self.expiry.get(hostname, 0):
                cached_ips = self._refresh(hostname, cached_ips, now)

            # The list length may have changed since the index was stored,
            # so wrap before indexing as well as after advancing.
            position = self.indices.get(hostname, 0) % len(cached_ips)
            self.indices[hostname] = (position + 1) % len(cached_ips)
            return cached_ips[position]
68+
69+
3270
class AsyncOverriddenHttpBackend(AutoBackend):
3371
"""
3472
`OverriddenHttpBackend` is a short-term solution for the TCP
@@ -68,18 +106,125 @@ def open_tcp_stream(self, *args, **kwargs): # type: ignore
68106

69107

70108
class AsyncKeepaliveTransport(AsyncHTTPTransport):
    """Async HTTP transport with TCP keepalive and optional client-side
    round-robin load balancing.

    When ``client_side_lb`` is enabled, each request's hostname is resolved
    through the shared DNS cache and the request is sent through a dedicated
    sub-transport for the selected IP address.
    """

    # Class-level, so the cached lookups and the rotation position are shared
    # by every transport instance.
    _dns_cache = DNSCache(ttl=30.0)

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Args:
            *args: Positional arguments for ``AsyncHTTPTransport``.
            **kwargs: Keyword arguments for ``AsyncHTTPTransport``, plus
                ``client_side_lb`` (default ``False``), which is consumed
                here and not forwarded.
        """
        self._client_side_lb = kwargs.pop("client_side_lb", False)
        super().__init__(*args, **kwargs)
        self._apply_custom_backend(self)
        # Remembered so per-IP sub-transports are configured exactly like
        # this one (positional arguments included).
        self._transport_args = args
        self._transport_kwargs = kwargs
        self._ip_transports: Dict[str, AsyncHTTPTransport] = {}
        # Created on first use inside handle_async_request, i.e. from within
        # the event loop that actually drives the requests.
        self._lock: Any = None

    def _apply_custom_backend(self, transport: AsyncHTTPTransport) -> None:
        """Install the keepalive network backend on ``transport``'s pool.

        Both ``_network_backend`` and ``_backend`` are checked; whichever
        attribute the pool has is replaced.
        """
        pool = getattr(transport, "_pool", None)
        if pool:
            for attr in ["_network_backend", "_backend"]:
                if hasattr(pool, attr):
                    setattr(pool, attr, AsyncOverriddenHttpBackend())

    async def handle_async_request(self, request: Request) -> Response:
        """Send ``request``, routing it to a round-robin-selected IP when
        client-side load balancing is enabled.

        Falls back to the regular transport behaviour when load balancing is
        disabled or the hostname cannot be resolved.
        """
        if not self._client_side_lb:
            return await super().handle_async_request(request)

        hostname = request.url.host

        # NOTE(review): the cache resolves with a blocking socket call, so a
        # cache miss stalls the event loop for the duration of the lookup.
        try:
            target_ip = self._dns_cache.get_ip_round_robin(hostname)
        except Exception:
            # Best effort: let the regular connection path resolve the host.
            return await super().handle_async_request(request)

        if self._lock is None:
            self._lock = anyio.Lock()

        async with self._lock:
            sub_transport = self._ip_transports.get(target_ip)
            if sub_transport is None:
                sub_transport = AsyncHTTPTransport(
                    *self._transport_args, **self._transport_kwargs
                )
                self._apply_custom_backend(sub_transport)
                self._ip_transports[target_ip] = sub_transport

        original_url = request.url
        # With the URL host replaced by an IP, TLS would use the IP for SNI
        # and certificate hostname checks. Pin the real hostname through the
        # `sni_hostname` request extension unless the caller already set one.
        pin_sni = (
            original_url.scheme == "https"
            and "sni_hostname" not in request.extensions
        )
        if pin_sni:
            request.extensions["sni_hostname"] = hostname
        request.url = original_url.copy_with(host=target_ip)
        try:
            return await sub_transport.handle_async_request(request)
        finally:
            # Restore the request so callers still see the original URL.
            request.url = original_url
            if pin_sni:
                request.extensions.pop("sni_hostname", None)

    async def aclose(self) -> None:
        """
        Close the primary transport and all sub-transports created for load balancing.
        """
        # Close the base transport first
        await super().aclose()

        # Close all child transports created for specific IPs
        if self._ip_transports:
            async with anyio.create_task_group() as tg:
                # Schedule every sub-transport's aclose() concurrently; the
                # task group waits for all of them before exiting.
                for transport in self._ip_transports.values():
                    tg.start_soon(transport.aclose)

        self._ip_transports.clear()
77170

78171

79172
class KeepaliveTransport(HTTPTransport):
    """Sync HTTP transport with TCP keepalive and optional client-side
    round-robin load balancing.

    When ``client_side_lb`` is enabled, each request's hostname is resolved
    through the shared DNS cache and the request is sent through a dedicated
    sub-transport for the selected IP address.
    """

    # Class-level, so the cached lookups and the rotation position are shared
    # by every transport instance.
    _dns_cache = DNSCache(ttl=30.0)

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Args:
            *args: Positional arguments for ``HTTPTransport``.
            **kwargs: Keyword arguments for ``HTTPTransport``, plus
                ``client_side_lb`` (default ``False``), which is consumed
                here and not forwarded.
        """
        self._client_side_lb = kwargs.pop("client_side_lb", False)
        super().__init__(*args, **kwargs)
        self._apply_custom_backend(self)
        # Remembered so per-IP sub-transports are configured exactly like
        # this one (positional arguments included).
        self._transport_args = args
        self._transport_kwargs = kwargs
        self._ip_transports: Dict[str, HTTPTransport] = {}
        # Guards creation and teardown of the per-IP sub-transports.
        self._lock = threading.Lock()

    def _apply_custom_backend(self, transport: HTTPTransport) -> None:
        """Install the keepalive network backend on ``transport``'s pool.

        Both ``_network_backend`` and ``_backend`` are checked; whichever
        attribute the pool has is replaced.
        """
        pool = getattr(transport, "_pool", None)
        if pool:
            for attr in ["_network_backend", "_backend"]:
                if hasattr(pool, attr):
                    setattr(pool, attr, OverriddenHttpBackend())

    def handle_request(self, request: Request) -> Response:
        """Send ``request``, routing it to a round-robin-selected IP when
        client-side load balancing is enabled.

        Falls back to the regular transport behaviour when load balancing is
        disabled or the hostname cannot be resolved.
        """
        if not self._client_side_lb:
            return super().handle_request(request)

        hostname = request.url.host

        try:
            target_ip = self._dns_cache.get_ip_round_robin(hostname)
        except Exception:
            # Best effort: let the regular connection path resolve the host.
            return super().handle_request(request)

        with self._lock:
            sub_transport = self._ip_transports.get(target_ip)
            if sub_transport is None:
                sub_transport = HTTPTransport(
                    *self._transport_args, **self._transport_kwargs
                )
                self._apply_custom_backend(sub_transport)
                self._ip_transports[target_ip] = sub_transport

        original_url = request.url
        # With the URL host replaced by an IP, TLS would use the IP for SNI
        # and certificate hostname checks. Pin the real hostname through the
        # `sni_hostname` request extension unless the caller already set one.
        pin_sni = (
            original_url.scheme == "https"
            and "sni_hostname" not in request.extensions
        )
        if pin_sni:
            request.extensions["sni_hostname"] = hostname
        request.url = original_url.copy_with(host=target_ip)
        try:
            return sub_transport.handle_request(request)
        finally:
            # Restore the request so callers still see the original URL.
            request.url = original_url
            if pin_sni:
                request.extensions.pop("sni_hostname", None)

    def close(self) -> None:
        """
        Close the primary transport and all sub-transports.
        """
        # Close the base transport first
        super().close()

        # Close all child transports created for specific IPs
        with self._lock:
            for transport in self._ip_transports.values():
                try:
                    transport.close()
                except Exception:
                    # Deliberate best effort: one failing close must not
                    # prevent the remaining sub-transports from closing.
                    pass
            self._ip_transports.clear()

src/firebolt/db/connection.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
fix_url_schema,
4646
parse_url_and_params,
4747
validate_engine_name_and_url_v1,
48+
validate_firebolt_parameters_v1,
49+
validate_firebolt_parameters_v2,
4850
)
4951

5052
logger = logging.getLogger(__name__)
@@ -61,6 +63,7 @@ def connect(
6163
url: Optional[str] = None,
6264
autocommit: bool = True,
6365
additional_parameters: Dict[str, Any] = {},
66+
client_side_lb: Optional[bool] = None,
6467
) -> Connection:
6568
# auth parameter is optional in function signature
6669
# but is required to connect.
@@ -82,15 +85,22 @@ def connect(
8285
if auth_version == FireboltAuthVersion.CORE:
8386
# Verify that Core-incompatible parameters are not provided
8487
validate_firebolt_core_parameters(account_name, engine_name, engine_url)
88+
if client_side_lb == None:
89+
# When using Core, client_side_lb is True by default
90+
client_side_lb = True
8591

8692
return connect_core(
8793
auth=auth,
8894
user_agent_header=user_agent_header,
8995
database=database,
9096
connection_url=url,
9197
autocommit=autocommit,
98+
client_side_lb=client_side_lb,
9299
)
93100
elif auth_version == FireboltAuthVersion.V2:
101+
# Verify that v2-incompatible parameters are not provided
102+
validate_firebolt_parameters_v2(client_side_lb)
103+
94104
assert account_name is not None
95105
return connect_v2(
96106
auth=auth,
@@ -104,6 +114,9 @@ def connect(
104114
autocommit=autocommit,
105115
)
106116
elif auth_version == FireboltAuthVersion.V1:
117+
# Verify that v1-incompatible parameters are not provided
118+
validate_firebolt_parameters_v1(client_side_lb)
119+
107120
return connect_v1(
108121
auth=auth,
109122
user_agent_header=user_agent_header,
@@ -490,6 +503,7 @@ def connect_core(
490503
database: Optional[str] = None,
491504
connection_url: Optional[str] = None,
492505
autocommit: bool = True,
506+
client_side_lb: bool = True,
493507
) -> Connection:
494508
"""Connect to Firebolt Core.
495509
@@ -520,6 +534,7 @@ def connect_core(
520534
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),
521535
headers={"User-Agent": user_agent_header},
522536
verify=ctx,
537+
client_side_lb=client_side_lb,
523538
)
524539

525540
return Connection(

0 commit comments

Comments
 (0)