1616
1717ProxyRequestMatcher = Callable [[Headers ], bool ]
1818
19+
1920class TcpForwarder :
2021 """Simple helper class for bidirectional forwarding of TCP traffic."""
2122
@@ -47,6 +48,7 @@ def close(self):
4748
4849patched_connection = False
4950
51+
5052def apply_http2_patches_for_grpc_support (
5153 target_host : str , target_port : int , should_proxy_request : ProxyRequestMatcher
5254):
@@ -56,7 +58,9 @@ def apply_http2_patches_for_grpc_support(
5658 """
5759 LOG .debug (f"Enabling proxying to backend { target_host } :{ target_port } " )
5860 global patched_connection
59- assert not patched_connection , "It is not safe to patch H2Connection twice with this function"
61+ assert not patched_connection , (
62+ "It is not safe to patch H2Connection twice with this function"
63+ )
6064 patched_connection = True
6165
6266 class ForwardingBuffer :
@@ -65,35 +69,57 @@ class ForwardingBuffer:
6569 data until the ProxyRequestMatcher tells us whether to send it
6670 to the backend, or leave it to the default handler.
6771 """
72+
73+ backend : TcpForwarder
74+ buffer : list
75+ proxying : bool | None
76+
6877 def __init__ (self , http_response_stream ):
6978 self .http_response_stream = http_response_stream
70- LOG .debug (f"Starting TCP forwarder to port { target_port } for new HTTP2 connection" )
79+ LOG .debug (
80+ f"Starting TCP forwarder to port { target_port } for new HTTP2 connection"
81+ )
7182 self .backend = TcpForwarder (target_port , host = target_host )
7283 self .buffer = []
73- self .proxying = False
74- reactor .getThreadPool ().callInThread (self .backend .receive_loop , self .received_from_backend )
84+ self .proxying = None
85+ reactor .getThreadPool ().callInThread (
86+ self .backend .receive_loop , self .received_from_backend
87+ )
7588
7689 def received_from_backend (self , data ):
7790 LOG .debug (f"Received { len (data )} bytes from backend" )
7891 self .http_response_stream .write (data )
7992
80- def received_from_http2_client (self , data , default_handler ):
93+ def received_from_http2_client (self , data , default_handler : Callable ):
94+ if self .proxying is False :
95+ # Note: Return here only if `proxying` is `False` (a value of `None` indicates
96+ # that the headers have not fully been received yet)
97+ return default_handler (data )
98+
8199 if self .proxying :
82100 assert not self .buffer
83101 # Keep sending data to the backend for the lifetime of this connection
84102 self .backend .send (data )
85- else :
86- self .buffer .append (data )
87- if headers := get_headers_from_data_stream (self .buffer ):
88- self .proxying = should_proxy_request (headers )
89- # Now we know what to do with the buffer
90- buffered_data = b"" .join (self .buffer )
91- self .buffer = []
92- if self .proxying :
93- LOG .debug (f"Forwarding { len (buffered_data )} bytes to backend" )
94- self .backend .send (buffered_data )
95- else :
96- return default_handler (buffered_data )
103+ return
104+
105+ self .buffer .append (data )
106+
107+ if not (headers := get_headers_from_data_stream (self .buffer )):
108+ # If no headers received yet, then return (method will be called again for next chunk of data)
109+ return
110+
111+ self .proxying = should_proxy_request (headers )
112+
113+ buffered_data = b"" .join (self .buffer )
114+ self .buffer = []
115+
116+ if not self .proxying :
117+ # if this is not a target request, then call the default handler
118+ default_handler (buffered_data )
119+ return
120+
121+ LOG .debug (f"Forwarding { len (buffered_data )} bytes to backend" )
122+ self .backend .send (buffered_data )
97123
98124 def close (self ):
99125 self .backend .close ()
@@ -104,7 +130,9 @@ def _connectionMade(fn, self, *args, **kwargs):
104130
105131 @patch (H2Connection .dataReceived )
106132 def _dataReceived (fn , self , data , * args , ** kwargs ):
107- self ._ls_forwarding_buffer .received_from_http2_client (data , lambda d : fn (d , * args , ** kwargs ))
133+ self ._ls_forwarding_buffer .received_from_http2_client (
134+ data , lambda d : fn (self , d , * args , ** kwargs )
135+ )
108136
109137 @patch (H2Connection .connectionLost )
110138 def connectionLost (fn , self , * args , ** kwargs ):
@@ -132,12 +160,11 @@ def get_headers_from_frames(frames: Iterable[Frame]) -> Headers:
132160
133161
134162def get_frames_from_http2_stream (data : bytes ) -> Iterable [Frame ]:
135- """Parse the data from an HTTP2 stream into a list of frames"""
136- frames = []
163+ """Parse the data from an HTTP2 stream into an iterable of frames"""
137164 buffer = FrameBuffer (server = True )
138165 buffer .max_frame_size = 16384
139- buffer .add_data (data )
140166 try :
167+ buffer .add_data (data )
141168 for frame in buffer :
142169 yield frame
143170 except Exception :
0 commit comments