1+ import anyio
12import socket
2- from typing import Any
3+ import threading
4+ import time
5+ from typing import Any , Dict , List
36
47try :
58 from httpcore .backends .auto import AutoBackend # type: ignore
811 from httpcore ._backends .auto import AutoBackend # type: ignore
912 from httpcore ._backends .sync import SyncBackend # type: ignore
1013
11- from httpx import AsyncHTTPTransport , HTTPTransport
14+ from httpx import AsyncHTTPTransport , HTTPTransport , Request , Response
1215
1316from firebolt .common .constants import KEEPALIVE_FLAG , KEEPIDLE_RATE
1417
@@ -29,6 +32,41 @@ def override_stream(stream): # type: ignore [no-untyped-def]
2932 return stream
3033
3134
35+ class DNSCache :
36+ def __init__ (self , ttl : float = 30.0 ):
37+ self .ttl = ttl
38+ self .cache : Dict [str , List [str ]] = {}
39+ self .expiry : Dict [str , float ] = {}
40+ self .indices : Dict [str , int ] = {}
41+ self ._lock = threading .Lock ()
42+
43+ def get_ip_round_robin (self , hostname : str ) -> str :
44+ now = time .monotonic ()
45+
46+ with self ._lock :
47+ cached_ips = self .cache .get (hostname )
48+ expires_at = self .expiry .get (hostname , 0 )
49+
50+ if not cached_ips or now >= expires_at :
51+ try :
52+ _ , _ , new_ips = socket .gethostbyname_ex (hostname )
53+ if new_ips :
54+ self .cache [hostname ] = sorted (new_ips )
55+ self .expiry [hostname ] = now + self .ttl
56+ cached_ips = self .cache [hostname ]
57+ except Exception :
58+ if not cached_ips :
59+ raise
60+
61+ # calculate round robin index
62+ current_index = self .indices .get (hostname , 0 )
63+ target_ip = cached_ips [current_index % len (cached_ips )]
64+
65+ self .indices [hostname ] = (current_index + 1 ) % len (cached_ips )
66+
67+ return target_ip
68+
69+
3270class AsyncOverriddenHttpBackend (AutoBackend ):
3371 """
3472 `OverriddenHttpBackend` is a short-term solution for the TCP
@@ -68,18 +106,125 @@ def open_tcp_stream(self, *args, **kwargs): # type: ignore
68106
69107
70108class AsyncKeepaliveTransport (AsyncHTTPTransport ):
109+ _dns_cache = DNSCache (ttl = 30.0 )
110+
71111 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
112+ self ._client_side_lb = kwargs .pop ("client_side_lb" , False )
72113 super ().__init__ (* args , ** kwargs )
73- if hasattr (self ._pool , "_network_backend" ):
74- self ._pool ._network_backend = AsyncOverriddenHttpBackend () # type: ignore
75- if hasattr (self ._pool , "_backend" ):
76- self ._pool ._backend = AsyncOverriddenHttpBackend () # type: ignore
114+ self ._apply_custom_backend (self )
115+ self ._transport_kwargs = kwargs
116+ self ._ip_transports : Dict [str , AsyncHTTPTransport ] = {}
117+ self ._lock = anyio .Lock ()
118+
119+ def _apply_custom_backend (self , transport : AsyncHTTPTransport ) -> None :
120+ pool = getattr (transport , "_pool" , None )
121+ if pool :
122+ for attr in ["_network_backend" , "_backend" ]:
123+ if hasattr (pool , attr ):
124+ setattr (pool , attr , AsyncOverriddenHttpBackend ())
125+
126+ async def handle_async_request (self , request : Request ) -> Response :
127+ if not self ._client_side_lb :
128+ return await super ().handle_async_request (request )
129+
130+ hostname = request .url .host
131+
132+ try :
133+ target_ip = self ._dns_cache .get_ip_round_robin (hostname )
134+ except Exception :
135+ return await super ().handle_async_request (request )
136+
137+ # Lazy-load the lock to ensure it's bound to the correct event loop
138+ if self ._lock is None :
139+ self ._lock = anyio .Lock ()
140+
141+ async with self ._lock :
142+ if target_ip not in self ._ip_transports :
143+ new_transport = AsyncHTTPTransport (** self ._transport_kwargs )
144+ self ._apply_custom_backend (new_transport )
145+ self ._ip_transports [target_ip ] = new_transport
146+ sub_transport = self ._ip_transports [target_ip ]
147+
148+ original_url = request .url
149+ request .url = request .url .copy_with (host = target_ip )
150+ try :
151+ return await sub_transport .handle_async_request (request )
152+ finally :
153+ request .url = original_url
154+
155+ async def aclose (self ) -> None :
156+ """
157+ Close the primary transport and all sub-transports created for load balancing.
158+ """
159+ # Close the base transport first
160+ await super ().aclose ()
161+
162+ # Close all child transports created for specific IPs
163+ if self ._ip_transports :
164+ async with anyio .create_task_group () as tg :
165+ # Gather all transports in task group and close them
166+ for transport in self ._ip_transports .values ():
167+ tg .start_soon (transport .aclose )
168+
169+ self ._ip_transports .clear ()
77170
78171
79172class KeepaliveTransport (HTTPTransport ):
173+ _dns_cache = DNSCache (ttl = 30.0 )
174+
80175 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
176+ self ._client_side_lb = kwargs .pop ("client_side_lb" , False )
81177 super ().__init__ (* args , ** kwargs )
82- if hasattr (self ._pool , "_network_backend" ):
83- self ._pool ._network_backend = OverriddenHttpBackend () # type: ignore
84- if hasattr (self ._pool , "_backend" ):
85- self ._pool ._backend = OverriddenHttpBackend () # type: ignore
178+ self ._apply_custom_backend (self )
179+ self ._transport_kwargs = kwargs
180+ self ._ip_transports : Dict [str , HTTPTransport ] = {}
181+ self ._lock = threading .Lock ()
182+
183+ def _apply_custom_backend (self , transport : HTTPTransport ) -> None :
184+ pool = getattr (transport , "_pool" , None )
185+ if pool :
186+ for attr in ["_network_backend" , "_backend" ]:
187+ if hasattr (pool , attr ):
188+ setattr (pool , attr , OverriddenHttpBackend ())
189+
190+ def handle_request (self , request : Request ) -> Response :
191+ if not self ._client_side_lb :
192+ return super ().handle_request (request )
193+
194+ hostname = request .url .host
195+
196+ try :
197+ target_ip = self ._dns_cache .get_ip_round_robin (hostname )
198+ except Exception :
199+ return super ().handle_request (request )
200+
201+ with self ._lock :
202+ if target_ip not in self ._ip_transports :
203+ new_transport = HTTPTransport (** self ._transport_kwargs )
204+ self ._apply_custom_backend (new_transport )
205+ self ._ip_transports [target_ip ] = new_transport
206+ sub_transport = self ._ip_transports [target_ip ]
207+
208+ original_url = request .url
209+ request .url = request .url .copy_with (host = target_ip )
210+ try :
211+ return sub_transport .handle_request (request )
212+ finally :
213+ request .url = original_url
214+
215+ def close (self ) -> None :
216+ """
217+ Close the primary transport and all sub-transports.
218+ """
219+ # Close the base transport first
220+ super ().close ()
221+
222+ # Close all child transports created for specific IPs
223+ with self ._lock :
224+ for transport in self ._ip_transports .values ():
225+ try :
226+ transport .close ()
227+ except Exception :
228+ # Best effort to close others if one fails
229+ pass
230+ self ._ip_transports .clear ()
0 commit comments