Skip to content

Commit 324084e

Browse files
committed
fix matcher logic; clean up logging; add Postgres SSL matching support
1 parent f53c70e commit 324084e

File tree

4 files changed

+46
-34
lines changed

4 files changed

+46
-34
lines changed

paradedb/localstack_paradedb/extension.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
import os
22
import socket
3-
import logging
43

54
from localstack_extensions.utils.docker import ProxiedDockerContainerExtension
65
from werkzeug.datastructures import Headers
76
from localstack import config
87

9-
LOG = logging.getLogger(__name__)
10-
118
# Environment variables for configuration
129
ENV_POSTGRES_USER = "PARADEDB_POSTGRES_USER"
1310
ENV_POSTGRES_PASSWORD = "PARADEDB_POSTGRES_PASSWORD"
@@ -65,11 +62,26 @@ def tcp_connection_matcher(self, data: bytes) -> bool:
6562
"""
6663
Identify PostgreSQL/ParadeDB connections by protocol handshake.
6764
68-
PostgreSQL startup message format:
65+
PostgreSQL can start with either:
66+
1. SSL request: protocol code 80877103 (0x04D2162F)
67+
2. Startup message: protocol version 3.0 (0x00030000)
68+
69+
Both use the same format:
6970
- 4 bytes: message length
70-
- 4 bytes: protocol version (3.0 = 0x00030000)
71+
- 4 bytes: protocol version/code
7172
"""
72-
return len(data) >= 8 and data[4:8] == b"\x00\x03\x00\x00"
73+
if len(data) < 8:
74+
return False
75+
76+
# Check for SSL request (80877103 = 0x04D2162F)
77+
if data[4:8] == b"\x04\xd2\x16\x2f":
78+
return True
79+
80+
# Check for protocol version 3.0 (0x00030000)
81+
if data[4:8] == b"\x00\x03\x00\x00":
82+
return True
83+
84+
return False
7385

7486
def should_proxy_request(self, headers: Headers) -> bool:
7587
"""

utils/localstack_extensions/utils/docker.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ def _setup_tcp_protocol_routing(self):
170170
)
171171

172172
LOG.info(
173-
f"Registered TCP extension {self.name} -> "
174-
f"{self.container_host}:{target_port} on gateway"
173+
f"Registered TCP extension {self.name} -> {self.container_host}:{target_port} on gateway"
175174
)
176175

177176
@abstractmethod
@@ -220,8 +219,6 @@ def start_container(self) -> None:
220219
self._remove_container()
221220
raise
222221

223-
LOG.debug("Successfully started extension container %s", self.container_name)
224-
225222
def _default_health_check(self) -> None:
226223
"""Default health check: HTTP GET request to the main port."""
227224
response = requests.get(f"http://{self.container_host}:{self.main_port}/")

utils/localstack_extensions/utils/h2_proxy.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,29 @@ def __init__(self, port: int, host: str = "localhost"):
2929
self.host = host
3030
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
3131
self._socket.connect((self.host, self.port))
32+
self._closed = False
3233

3334
def receive_loop(self, callback):
34-
while data := self._socket.recv(self.buffer_size):
35+
while data := self.recv(self.buffer_size):
3536
callback(data)
3637

38+
def recv(self, length):
39+
try:
40+
return self._socket.recv(length)
41+
except OSError as e:
42+
if self._closed:
43+
return None
44+
else:
45+
raise e
46+
3747
def send(self, data):
3848
self._socket.sendall(data)
3949

4050
def close(self):
51+
if self._closed:
52+
return
4153
LOG.debug(f"Closing connection to upstream HTTP2 server on port {self.port}")
54+
self._closed = True
4255
try:
4356
self._socket.shutdown(socket.SHUT_RDWR)
4457
self._socket.close()
@@ -93,7 +106,6 @@ def __init__(self, http_response_stream):
93106
)
94107

95108
def received_from_backend(self, data):
96-
LOG.debug(f"Received {len(data)} bytes from backend")
97109
self.http_response_stream.write(data)
98110

99111
def received_from_http2_client(self, data, default_handler: Callable):
@@ -113,9 +125,6 @@ def received_from_http2_client(self, data, default_handler: Callable):
113125

114126
if should_proxy_request(headers):
115127
self.state = ForwardingState.FORWARDING
116-
LOG.debug(
117-
f"Forwarding {len(buffered_data)} bytes to backend"
118-
)
119128
self.backend.send(buffered_data)
120129
else:
121130
self.state = ForwardingState.PASSTHROUGH

utils/localstack_extensions/utils/tcp_protocol_router.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
from twisted.web.http import HTTPChannel
1313

1414
from localstack.utils.patch import patch
15+
from localstack import config
1516

1617
LOG = logging.getLogger(__name__)
18+
LOG.setLevel(logging.DEBUG if config.DEBUG else logging.INFO)
1719

1820
# Global registry of extensions with TCP matchers
1921
# List of tuples: (extension_name, matcher_func, backend_host, backend_port)
@@ -31,14 +33,19 @@ def connectionMade(self):
3133
# Set up peer relationship
3234
server.set_tcp_peer(self)
3335

36+
# Unregister any existing producer on server transport (HTTPChannel may have one)
37+
try:
38+
server.transport.unregisterProducer()
39+
except Exception:
40+
pass # No producer was registered, which is fine
41+
3442
# Enable flow control
3543
self.transport.registerProducer(server.transport, True)
3644
server.transport.registerProducer(self.transport, True)
3745

3846
# Send buffered data from detection phase
3947
if hasattr(self.factory, "initial_data"):
4048
initial_data = self.factory.initial_data
41-
LOG.debug(f"Sending {len(initial_data)} buffered bytes to backend")
4249
self.transport.write(initial_data)
4350
del self.factory.initial_data
4451

@@ -61,12 +68,8 @@ def patch_gateway_for_tcp_routing():
6168
global _gateway_patched
6269

6370
if _gateway_patched:
64-
LOG.debug("Gateway already patched for TCP routing")
6571
return
6672

67-
LOG.debug("Patching LocalStack gateway for TCP protocol detection")
68-
peek_bytes_length = 32
69-
7073
# Patch HTTPChannel to use our protocol-detecting version
7174
@patch(HTTPChannel.__init__)
7275
def _patched_init(fn, self, *args, **kwargs):
@@ -76,7 +79,6 @@ def _patched_init(fn, self, *args, **kwargs):
7679
self._detection_buffer = []
7780
self._detecting = True
7881
self._tcp_peer = None
79-
self._detection_buffer_size = peek_bytes_length
8082

8183
@patch(HTTPChannel.dataReceived)
8284
def _patched_dataReceived(fn, self, data):
@@ -102,10 +104,6 @@ def _patched_dataReceived(fn, self, data):
102104
for ext_name, matcher, backend_host, backend_port in _tcp_extensions:
103105
try:
104106
if matcher(buffered_data):
105-
LOG.info(
106-
f"Extension {ext_name} claimed connection, routing to "
107-
f"{backend_host}:{backend_port}"
108-
)
109107
# Switch to TCP proxy mode
110108
self._detecting = False
111109
self.transport.pauseProducing()
@@ -123,14 +121,11 @@ def _patched_dataReceived(fn, self, data):
123121
continue
124122

125123
# No extension claimed the connection
126-
buffer_size = getattr(self, "_detection_buffer_size", peek_bytes_length)
127-
if len(buffered_data) >= buffer_size:
128-
LOG.debug("No TCP extension matched, using HTTP handler")
129-
self._detecting = False
130-
# Feed buffered data to HTTP handler
131-
for chunk in self._detection_buffer:
132-
fn(self, chunk)
133-
self._detection_buffer = []
124+
self._detecting = False
125+
# Feed buffered data to HTTP handler
126+
for chunk in self._detection_buffer:
127+
fn(self, chunk)
128+
self._detection_buffer = []
134129

135130
@patch(HTTPChannel.connectionLost)
136131
def _patched_connectionLost(fn, self, reason):
@@ -150,7 +145,6 @@ def set_tcp_peer(self, peer):
150145
HTTPChannel.set_tcp_peer = set_tcp_peer
151146

152147
_gateway_patched = True
153-
LOG.info("Gateway patched successfully for TCP protocol routing")
154148

155149

156150
def register_tcp_extension(

0 commit comments

Comments
 (0)