11import asyncio
2+ import contextlib
23import socket
34import ssl
45import sys
56import time
7+ import traceback
68from collections .abc import Callable
79from typing import Any , Literal
810from urllib .parse import urlparse
3537
3638
3739class SocketMixin (BaseTransport ):
40+ def _close_socket_safely (self ) -> None :
41+ if self ._socket is not None :
42+ with contextlib .suppress (ssl .SSLError , Exception ):
43+ self ._socket .close ()
44+
3845 @property
3946 def sock (self ) -> socket .socket :
4047 if self ._socket is None or not self .is_connected :
@@ -184,6 +191,31 @@ def _create_socket_with_proxy(self, proxy: str) -> socket.socket:
184191
185192 return sock
186193
194+ def _perform_ssl_handshake (self , raw_sock : socket .socket ) -> socket .socket :
195+ """
196+ Выполняет SSL handshake с сервером.
197+
198+ :param raw_sock: Обычный сокет
199+ :return: SSL сокет
200+ """
201+ raw_sock .setblocking (True )
202+
203+ try :
204+ raw_sock .settimeout (10.0 )
205+ wrapped = self ._ssl_context .wrap_socket (
206+ raw_sock ,
207+ server_hostname = self .host ,
208+ do_handshake_on_connect = True ,
209+ suppress_ragged_eofs = True ,
210+ )
211+ wrapped .setsockopt (socket .SOL_SOCKET , socket .SO_KEEPALIVE , 1 )
212+ wrapped .setblocking (False )
213+ return wrapped
214+ except ssl .SSLError as e :
215+ self .logger .error ("SSL handshake failed: %s" , e )
216+ raw_sock .close ()
217+ raise
218+
187219 async def connect (self , user_agent : UserAgentPayload | None = None ) -> dict [str , Any ]:
188220 """
189221 Устанавливает соединение с сервером и выполняет handshake.
@@ -213,27 +245,51 @@ async def connect(self, user_agent: UserAgentPayload | None = None) -> dict[str,
213245 )
214246 else :
215247 raw_sock = await loop .run_in_executor (
216- None , lambda : socket .create_connection ((self .host , self .port ))
248+ None , lambda : socket .create_connection ((self .host , self .port ), timeout = 10.0 )
249+ )
250+
251+ try :
252+ self ._socket = await asyncio .wait_for (
253+ loop .run_in_executor (None , lambda : self ._perform_ssl_handshake (raw_sock )),
254+ timeout = 15.0 ,
217255 )
218- self ._socket = self ._ssl_context .wrap_socket (raw_sock , server_hostname = self .host )
219- self ._socket .setsockopt (socket .SOL_SOCKET , socket .SO_KEEPALIVE , 1 )
256+ except asyncio .TimeoutError :
257+ raw_sock .close ()
258+ self .logger .error ("SSL handshake timeout" )
259+ raise
260+
220261 self .is_connected = True
221262 self ._incoming = asyncio .Queue ()
222263 self ._outgoing = asyncio .Queue ()
223264 self ._pending = {}
224- self ._recv_task = asyncio .create_task (self ._recv_loop ())
225- self ._outgoing_task = asyncio .create_task (self ._outgoing_loop ())
265+ self ._recv_task = self ._create_safe_task (self ._recv_loop (), name = "recv_loop socket task" )
266+ self ._outgoing_task = self ._create_safe_task (
267+ self ._outgoing_loop (), name = "outgoing_loop socket task"
268+ )
226269 self .logger .info ("Socket connected, starting handshake" )
227270 return await self ._handshake (user_agent )
228271
229272 def _recv_exactly (self , sock : socket .socket , n : int ) -> bytes :
273+ """
274+ Получает ровно n байт из сокета. Обрабатывает SSL ошибки корректно.
275+ """
230276 buf = bytearray ()
231- while len (buf ) < n :
232- chunk = sock .recv (n - len (buf ))
233- if not chunk :
234- return bytes (buf )
235- buf .extend (chunk )
236- return bytes (buf )
277+ try :
278+ while len (buf ) < n :
279+ try :
280+ chunk = sock .recv (n - len (buf ))
281+ except ssl .SSLWantReadError :
282+ continue
283+ except ssl .SSLWantWriteError :
284+ continue
285+
286+ if not chunk :
287+ break
288+ buf .extend (chunk )
289+ return bytes (buf )
290+ except (ssl .SSLError , ConnectionError , BrokenPipeError ) as e :
291+ self .logger .debug ("SSL/Connection error in _recv_exactly: %s" , e )
292+ raise
237293
238294 async def _parse_header (
239295 self , loop : asyncio .AbstractEventLoop , sock : socket .socket
@@ -299,6 +355,8 @@ async def _recv_loop(self) -> None:
299355
300356 sock = self ._socket
301357 loop = asyncio .get_running_loop ()
358+ consecutive_errors = 0
359+ max_consecutive_errors = 3
302360
303361 while True :
304362 try :
@@ -312,6 +370,8 @@ async def _recv_loop(self) -> None:
312370 if not datas :
313371 continue
314372
373+ consecutive_errors = 0
374+
315375 for data_item in datas :
316376 seq = data_item .get ("seq" )
317377
@@ -326,20 +386,58 @@ async def _recv_loop(self) -> None:
326386 except asyncio .CancelledError :
327387 self .logger .debug ("Recv loop cancelled" )
328388 raise
329- except Exception :
330- self .logger .exception ("Error in recv_loop" )
389+ except (
390+ ssl .SSLError ,
391+ ssl .SSLEOFError ,
392+ ConnectionResetError ,
393+ BrokenPipeError ,
394+ ) as ssl_err :
395+ consecutive_errors += 1
396+ self .logger .error (
397+ "SSL/Connection error in recv_loop (error %d/%d): %s" ,
398+ consecutive_errors ,
399+ max_consecutive_errors ,
400+ ssl_err ,
401+ )
331402 self .is_connected = False
332403
333- if self .reconnect :
404+ self ._close_socket_safely ()
405+
406+ if self .reconnect and consecutive_errors < max_consecutive_errors :
334407 self .logger .info ("Reconnect enabled, attempting to restore connection..." )
335408 try :
336- await asyncio .sleep (self . reconnect_delay )
409+ await asyncio .sleep (min ( 2 ** consecutive_errors , 10 ) )
337410 await self .connect (self .user_agent )
338411 sock = self ._socket
339412 self .logger .info ("Connection restored successfully" )
340413 except Exception :
341- self .logger .exception ("Failed to restore connection, exiting recv_loop" )
342- break
414+ self .logger .exception ("Failed to restore connection" )
415+ if consecutive_errors >= max_consecutive_errors :
416+ self .logger .error (
417+ "Max reconnection attempts reached, exiting recv_loop"
418+ )
419+ break
420+ else :
421+ self .logger .warning (
422+ "Reconnect disabled or max errors reached, exiting recv_loop"
423+ )
424+ break
425+ except Exception as e :
426+ consecutive_errors += 1
427+ self .logger .exception ("Error in recv_loop: %s" , e )
428+ self .is_connected = False
429+
430+ if self .reconnect and consecutive_errors < max_consecutive_errors :
431+ self .logger .info ("Reconnect enabled, attempting to restore connection..." )
432+ try :
433+ await asyncio .sleep (min (2 ** consecutive_errors , 10 ))
434+ await self .connect (self .user_agent )
435+ sock = self ._socket
436+ self .logger .info ("Connection restored successfully" )
437+ except Exception :
438+ self .logger .exception ("Failed to restore connection" )
439+ if consecutive_errors >= max_consecutive_errors :
440+ break
343441 else :
344442 self .logger .warning ("Reconnect disabled, exiting recv_loop" )
345443 break
@@ -389,15 +487,11 @@ async def _send_and_wait(
389487 )
390488 return data
391489
392- except (ssl .SSLEOFError , ssl .SSLError , ConnectionError ) as conn_err :
393- self .logger .warning ("Connection lost, reconnecting ..." )
490+ except (ssl .SSLEOFError , ssl .SSLError , ConnectionError , BrokenPipeError ) as conn_err :
491+ self .logger .warning ("Connection lost: %s, attempting reconnect ..." , conn_err )
394492 self .is_connected = False
395- try :
396- await self .connect (self .user_agent )
397- except Exception as exc :
398- self .logger .exception ("Reconnect failed" )
399- raise exc from conn_err
400- raise SocketNotConnectedError from conn_err
493+
494+ self ._close_socket_safely ()
401495 except asyncio .TimeoutError :
402496 self .logger .exception ("Send and wait failed (opcode=%s, seq=%s)" , opcode , msg ["seq" ])
403497 raise SocketSendError from None
0 commit comments