11import asyncio
22import contextlib
3+ import errno
34import json
45import socket
56import ssl
89from urllib .parse import urlparse
910
1011import lz4 .block
12+ import lz4 .frame
1113import msgpack
1214from typing_extensions import override
1315
2628
2729
2830class SocketMixin (BaseTransport ):
31+ MAX_UNCOMPRESSED_SIZE = 10 * 1024 * 1024
32+ MAX_PAYLOAD_LENGTH = 50 * 1024 * 1024
33+
34+ async def _close_socket (self ):
35+ async with self ._sock_lock :
36+ sock = self ._socket
37+ self ._socket = None
38+ if sock :
39+ try :
40+ try :
41+ sock .shutdown (socket .SHUT_RDWR )
42+ except Exception :
43+ pass
44+ sock .close ()
45+ except Exception as e :
46+ self .logger .debug ("Error closing socket: %s" , e , exc_info = True )
47+
2948 @property
3049 def sock (self ) -> socket .socket :
3150 if self ._socket is None or not self .is_connected :
3251 self .logger .critical ("Socket not connected when access attempted" )
3352 raise SocketNotConnectedError ()
3453 return self ._socket
3554
55+ def _looks_like_lz4_frame (self , payload : bytes ) -> bool :
56+ return len (payload ) >= 4 and payload [0 :4 ] == b"\x04 \x22 \x4d \x18 "
57+
3658 def _unpack_packet (self , data : bytes ) -> dict [str , Any ] | None :
37- ver = int .from_bytes (data [0 :1 ], "big" )
38- cmd = int .from_bytes (data [1 :2 ], "big" )
59+ if len (data ) < 10 :
60+ self .logger .warning ("Packet too short: %d bytes" , len (data ))
61+ return None
62+
63+ ver = data [0 ]
64+ cmd = data [1 ]
3965 seq = int .from_bytes (data [2 :4 ], "big" )
4066 opcode = int .from_bytes (data [4 :6 ], "big" )
4167 packed_len = int .from_bytes (data [6 :10 ], "big" , signed = False )
4268 comp_flag = packed_len >> 24
4369 payload_length = packed_len & 0xFFFFFF
70+
71+ if payload_length > self .MAX_PAYLOAD_LENGTH :
72+ self .logger .warning ("payload_length too large: %d" , payload_length )
73+ return None
74+
75+ if len (data ) < 10 + payload_length :
76+ self .logger .warning (
77+ "Not enough bytes for declared payload_length: have=%d need=%d" ,
78+ len (data ) - 10 ,
79+ payload_length ,
80+ )
81+ return None
82+
4483 payload_bytes = data [10 : 10 + payload_length ]
4584
46- payload = None
47- if payload_bytes :
48- if comp_flag != 0 :
49- # TODO: надо выяснить правильный размер распаковки
50- # uncompressed_size = int.from_bytes(payload_bytes[0:4], "big")
51- compressed_data = payload_bytes
52- try :
53- payload_bytes = lz4 .block .decompress (
54- compressed_data ,
55- uncompressed_size = 99999 ,
56- )
57- except lz4 .block .LZ4BlockError :
58- return None
85+ if not payload_bytes :
86+ return {"ver" : ver , "cmd" : cmd , "seq" : seq , "opcode" : opcode , "payload" : None }
87+
88+ try :
5989 payload = msgpack .unpackb (payload_bytes , raw = False , strict_map_key = False )
6090
61- return {
62- "ver" : ver ,
63- "cmd" : cmd ,
64- "seq" : seq ,
65- "opcode" : opcode ,
66- "payload" : payload , #
67- }
91+ return {"ver" : ver , "cmd" : cmd , "seq" : seq , "opcode" : opcode , "payload" : payload }
92+ except Exception as ex_msgpack :
93+ self .logger .debug (
94+ "msgpack direct unpack failed: %s — will try compressed paths" , ex_msgpack
95+ )
96+
97+ try :
98+ if self ._looks_like_lz4_frame (payload_bytes ):
99+ try :
100+ decompressed = lz4 .frame .decompress (payload_bytes )
101+ payload = msgpack .unpackb (decompressed , raw = False , strict_map_key = False )
102+ return {
103+ "ver" : ver ,
104+ "cmd" : cmd ,
105+ "seq" : seq ,
106+ "opcode" : opcode ,
107+ "payload" : payload ,
108+ }
109+ except Exception as ex_frame :
110+ self .logger .warning ("lz4.frame.decompress failed: %s" , ex_frame )
111+ except Exception :
112+ self .logger .exception ("Unexpected error testing lz4.frame" )
113+
114+ try :
115+ if len (payload_bytes ) >= 4 :
116+ maybe_size = int .from_bytes (payload_bytes [0 :4 ], "big" )
117+ if 0 < maybe_size <= self .MAX_UNCOMPRESSED_SIZE :
118+ compressed_data = payload_bytes [4 :]
119+ try :
120+ decompressed = lz4 .block .decompress (
121+ compressed_data , uncompressed_size = maybe_size
122+ )
123+ payload = msgpack .unpackb (decompressed , raw = False , strict_map_key = False )
124+ return {
125+ "ver" : ver ,
126+ "cmd" : cmd ,
127+ "seq" : seq ,
128+ "opcode" : opcode ,
129+ "payload" : payload ,
130+ }
131+ except (lz4 .block .LZ4BlockError , MemoryError ) as e :
132+ self .logger .warning ("lz4.block with prefixed size failed: %s" , e )
133+ else :
134+ self .logger .debug ("prefixed size %r not plausible, skipping" , maybe_size )
135+ except Exception :
136+ self .logger .exception ("Error during prefixed-size lz4 handling" )
137+
138+ try :
139+ try :
140+ decompressed = lz4 .block .decompress (
141+ payload_bytes , uncompressed_size = self .MAX_UNCOMPRESSED_SIZE
142+ )
143+ payload = msgpack .unpackb (decompressed , raw = False , strict_map_key = False )
144+ return {"ver" : ver , "cmd" : cmd , "seq" : seq , "opcode" : opcode , "payload" : payload }
145+ except lz4 .block .LZ4BlockError as e :
146+ self .logger .warning ("lz4.block.decompress (no-pref) failed: %s" , e )
147+ except MemoryError as me :
148+ self .logger .error ("MemoryError while decompressing LZ4: %s" , me )
149+ except Exception :
150+ self .logger .exception ("Unexpected error when attempting lz4.block.decompress" )
151+
152+ self .logger .warning (
153+ "Failed to unpack payload: seq=%s opcode=%s comp_flag=%s payload_len=%d" ,
154+ seq ,
155+ opcode ,
156+ comp_flag ,
157+ payload_length ,
158+ )
159+ return None
68160
69161 def _pack_packet (
70162 self ,
@@ -325,7 +417,9 @@ async def _recv_data(self, loop, header, sock):
325417
326418 data = self ._unpack_packet (raw )
327419 if not data :
328- self .logger .warning ("Failed to unpack packet" )
420+ self .logger .warning (
421+ "Failed to unpack packet (possibly corrupted or unsupported compression)"
422+ )
329423 return None
330424
331425 payload_objs = data .get ("payload" )
@@ -363,6 +457,7 @@ async def _recv_loop(self) -> None:
363457
364458 if not datas :
365459 self .logger .warning ("No data received, continuing recv loop" )
460+ await asyncio .sleep (RECV_LOOP_BACKOFF_DELAY )
366461 continue
367462
368463 consecutive_errors = 0
@@ -402,7 +497,7 @@ async def _recv_loop(self) -> None:
402497
403498 self ._pending .clear ()
404499
405- self ._socket = None
500+ await self ._close_socket ()
406501
407502 if self .reconnect and consecutive_errors < max_consecutive_errors :
408503 self .logger .info (
@@ -420,7 +515,7 @@ async def _recv_loop(self) -> None:
420515 self .logger .exception ("Error in recv_loop: %s" , e )
421516 self .is_connected = False
422517
423- self ._socket = None
518+ await self ._close_socket ()
424519
425520 if self .reconnect and consecutive_errors < max_consecutive_errors :
426521 self .logger .info (
@@ -469,7 +564,15 @@ async def _send_and_wait(
469564 raise SocketNotConnectedError
470565
471566 sock = self ._socket
472- await loop .run_in_executor (None , lambda : sock .sendall (packet ))
567+ try :
568+ await loop .run_in_executor (None , lambda : sock .sendall (packet ))
569+ except OSError as e :
570+ if e .errno in (errno .EBADF , errno .EPIPE , errno .ENOTCONN ):
571+ self .logger .debug ("Socket closed during send (errno=%s)" , e .errno )
572+ self .is_connected = False
573+ await self ._close_socket ()
574+ raise SocketNotConnectedError from e
575+ raise
473576
474577 data = await asyncio .wait_for (fut , timeout = timeout )
475578 self .logger .debug (
@@ -482,10 +585,15 @@ async def _send_and_wait(
482585 except (ssl .SSLEOFError , ssl .SSLError , ConnectionError , BrokenPipeError ) as conn_err :
483586 self .logger .warning ("Connection lost while sending: %s" , conn_err )
484587 self .is_connected = False
588+ for pending_fut in list (self ._pending .values ()):
589+ if not pending_fut .done ():
590+ pending_fut .set_exception (SocketNotConnectedError ())
591+ self ._pending .clear ()
592+
485593 if not fut .done ():
486594 fut .set_exception (SocketSendError ("connection lost during send" ))
487595
488- self ._socket = None
596+ await self ._close_socket ()
489597 raise SocketSendError ("Connection lost during send" ) from conn_err
490598
491599 except asyncio .TimeoutError :
@@ -507,18 +615,27 @@ async def _get_chat(self, chat_id: int) -> Chat | None:
507615
508616 async def _send_only (self , opcode : Opcode , payload : dict [str , Any ], cmd : int = 0 ) -> None :
509617 async def send_task ():
510- async with self ._sock_lock :
511- if not self .is_connected or self ._socket is None :
512- return
513- msg = self ._make_message (opcode , payload , cmd )
514- packet = self ._pack_packet (
515- msg ["ver" ],
516- msg ["cmd" ],
517- msg ["seq" ],
518- msg ["opcode" ],
519- msg ["payload" ],
520- )
521- loop = asyncio .get_running_loop ()
522- await loop .run_in_executor (None , lambda : self ._socket .sendall (packet ))
618+ try :
619+ async with self ._sock_lock :
620+ if not self .is_connected or self ._socket is None :
621+ self .logger .debug ("Socket not connected in _send_only, skipping" )
622+ return
623+ msg = self ._make_message (opcode , payload , cmd )
624+ packet = self ._pack_packet (
625+ msg ["ver" ],
626+ msg ["cmd" ],
627+ msg ["seq" ],
628+ msg ["opcode" ],
629+ msg ["payload" ],
630+ )
631+ loop = asyncio .get_running_loop ()
632+ await loop .run_in_executor (None , lambda : self ._socket .sendall (packet ))
633+ except (ssl .SSLEOFError , ssl .SSLError , ConnectionError , BrokenPipeError ) as e :
634+ self .logger .debug ("Connection error in _send_only (fire-and-forget): %s" , e )
635+ self .is_connected = False
636+ await self ._close_socket ()
637+ except Exception as e :
638+ self .logger .warning ("Unexpected error in _send_only: %s" , e , exc_info = True )
523639
524- self ._create_safe_task (send_task ())
640+ task = self ._create_safe_task (send_task (), name = "_send_only_task" )
641+ task .add_done_callback (lambda t : self ._log_task_exception (t ))
0 commit comments