Skip to content

Commit 87fa1ae

Browse files
committed
add TCP proxy logic for ParadeDB
1 parent 0f3e62b commit 87fa1ae

File tree

7 files changed

+48
-141
lines changed

7 files changed

+48
-141
lines changed

paradedb/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ entrypoints: venv ## Generate plugin entrypoints for Python package
3434
$(VENV_RUN); python -m plux entrypoints
3535

3636
format: ## Run ruff to format the codebase
37-
$(VENV_RUN); python -m ruff format .; make lint
37+
$(VENV_RUN); python -m ruff format .; python -m ruff check --fix .
3838

3939
lint: ## Run ruff to lint the codebase
4040
$(VENV_RUN); python -m ruff check --output-format=full .

paradedb/localstack_paradedb/extension.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import logging
44

55
from localstack_extensions.utils.docker import ProxiedDockerContainerExtension
6-
from localstack.extensions.api import http
76
from werkzeug.datastructures import Headers
7+
from localstack import config
88

99
LOG = logging.getLogger(__name__)
1010

@@ -50,29 +50,33 @@ def __init__(self):
5050
}
5151

5252
def _tcp_health_check():
53-
"""Check if PostgreSQL port is accepting connections."""
53+
"""Check if ParadeDB port is accepting connections."""
5454
self._check_tcp_port(self.container_host, self.postgres_port)
5555

5656
super().__init__(
5757
image_name=self.DOCKER_IMAGE,
5858
container_ports=[postgres_port],
5959
env_vars=env_vars,
6060
health_check_fn=_tcp_health_check,
61+
tcp_ports=[postgres_port], # Enable TCP proxying through gateway
6162
)
6263

63-
def should_proxy_request(self, headers: Headers) -> bool:
64+
def tcp_connection_matcher(self, data: bytes) -> bool:
6465
"""
65-
Define whether a request should be proxied based on request headers.
66-
For database extensions, this is not used as connections are direct TCP.
66+
Identify PostgreSQL/ParadeDB connections by protocol handshake.
67+
68+
PostgreSQL startup message format:
69+
- 4 bytes: message length
70+
- 4 bytes: protocol version (3.0 = 0x00030000)
6771
"""
68-
return False
72+
return len(data) >= 8 and data[4:8] == b"\x00\x03\x00\x00"
6973

70-
def update_gateway_routes(self, router: http.Router[http.RouteHandler]):
74+
def should_proxy_request(self, headers: Headers) -> bool:
7175
"""
72-
Override to start container without setting up HTTP gateway routes.
73-
Database extensions don't need HTTP routing - clients connect directly via TCP.
76+
Define whether a request should be proxied based on request headers.
77+
Not used for TCP connections - see tcp_connection_matcher instead.
7478
"""
75-
self.start_container()
79+
return False
7680

7781
def _check_tcp_port(self, host: str, port: int, timeout: float = 2.0) -> None:
7882
"""Check if a TCP port is accepting connections."""
@@ -86,15 +90,21 @@ def _check_tcp_port(self, host: str, port: int, timeout: float = 2.0) -> None:
8690

8791
def get_connection_info(self) -> dict:
8892
"""Return connection information for ParadeDB."""
93+
# Clients should connect through the LocalStack gateway
94+
gateway_host = "paradedb.localhost.localstack.cloud"
95+
gateway_port = config.LOCALSTACK_HOST.port
96+
8997
return {
90-
"host": self.container_host,
98+
"host": gateway_host,
9199
"database": self.postgres_db,
92100
"user": self.postgres_user,
93101
"password": self.postgres_password,
94-
"port": self.postgres_port,
95-
"ports": {self.postgres_port: self.postgres_port},
102+
"port": gateway_port,
96103
"connection_string": (
97104
f"postgresql://{self.postgres_user}:{self.postgres_password}"
98-
f"@{self.container_host}:{self.postgres_port}/{self.postgres_db}"
105+
f"@{gateway_host}:{gateway_port}/{self.postgres_db}"
99106
),
107+
# Also include container connection details for debugging
108+
"container_host": self.container_host,
109+
"container_port": self.postgres_port,
100110
}

paradedb/tests/test_extension.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44

55
# Connection details for ParadeDB
6-
HOST = "localhost"
7-
PORT = 5432
6+
# Connect through LocalStack gateway with TCP proxying
7+
HOST = "paradedb.localhost.localstack.cloud"
8+
PORT = 4566
89
USER = "myuser"
910
PASSWORD = "mypassword"
1011
DATABASE = "mydatabase"

utils/localstack_extensions/utils/docker.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from rolo.proxy import Proxy
2020
from rolo.routing import RuleAdapter, WithHost
2121
from werkzeug.datastructures import Headers
22-
from twisted.internet import reactor
23-
from twisted.protocols.portforward import ProxyFactory
2422

2523
LOG = logging.getLogger(__name__)
2624

utils/localstack_extensions/utils/tcp_protocol_detector.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ def matcher(data: bytes) -> bool:
3535
return matcher
3636

3737

38-
def create_signature_matcher(
39-
signature: bytes, offset: int = 0
40-
) -> ConnectionMatcher:
38+
def create_signature_matcher(signature: bytes, offset: int = 0) -> ConnectionMatcher:
4139
"""
4240
Create a matcher that matches bytes at a specific offset.
4341

utils/localstack_extensions/utils/tcp_protocol_router.py

Lines changed: 4 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99
import logging
10-
from twisted.internet import protocol, reactor
10+
from twisted.internet import reactor
1111
from twisted.protocols.portforward import ProxyClient, ProxyClientFactory
1212
from twisted.web.http import HTTPChannel
1313

@@ -21,99 +21,6 @@
2121
_gateway_patched = False
2222

2323

24-
class ProtocolDetectingChannel(HTTPChannel):
25-
"""
26-
HTTP channel wrapper that detects TCP protocols before HTTP processing.
27-
28-
This wraps the standard HTTPChannel and intercepts the first dataReceived call
29-
to check if it's a known TCP protocol. If so, it morphs into a TCP proxy.
30-
Otherwise, it passes through to normal HTTP handling.
31-
"""
32-
33-
def __init__(self):
34-
super().__init__()
35-
self._detection_buffer = []
36-
self._detecting = True
37-
self._tcp_peer = None
38-
self._detection_buffer_size = 512
39-
40-
def dataReceived(self, data):
41-
"""Intercept data to detect protocol before HTTP processing."""
42-
if not self._detecting:
43-
# Already decided - either proxying TCP or processing HTTP
44-
if self._tcp_peer:
45-
# TCP proxying mode
46-
self._tcp_peer.transport.write(data)
47-
else:
48-
# HTTP mode - pass to parent
49-
super().dataReceived(data)
50-
return
51-
52-
# Still detecting - buffer data
53-
self._detection_buffer.append(data)
54-
buffered_data = b"".join(self._detection_buffer)
55-
56-
# Try detection once we have enough bytes
57-
if len(buffered_data) >= 8:
58-
protocol_name = detect_protocol(buffered_data)
59-
60-
if protocol_name and protocol_name not in ("http", "http2"):
61-
# Known TCP protocol (not HTTP) - check if we have a backend
62-
backend_info = _protocol_backends.get(protocol_name)
63-
64-
if backend_info:
65-
LOG.info(
66-
f"Detected {protocol_name} on gateway port, routing to "
67-
f"{backend_info['host']}:{backend_info['port']}"
68-
)
69-
self._switch_to_tcp_proxy(
70-
backend_info["host"], backend_info["port"], buffered_data
71-
)
72-
self._detecting = False
73-
return
74-
75-
# Not a known TCP protocol, or no backend configured
76-
# Check if we've buffered enough
77-
if (
78-
len(buffered_data) >= self._detection_buffer_size
79-
or protocol_name in ("http", "http2")
80-
):
81-
LOG.debug(
82-
f"Protocol detected as {protocol_name or 'unknown'}, using HTTP handler"
83-
)
84-
self._detecting = False
85-
# Feed buffered data to HTTP handler
86-
for chunk in self._detection_buffer:
87-
super().dataReceived(chunk)
88-
self._detection_buffer = []
89-
90-
def _switch_to_tcp_proxy(self, host, port, initial_data):
91-
"""Switch this connection to TCP proxy mode."""
92-
# Pause reading while we establish backend connection
93-
self.transport.pauseProducing()
94-
95-
# Create backend connection
96-
client_factory = ProxyClientFactory()
97-
client_factory.server = self
98-
client_factory.initial_data = initial_data
99-
100-
# Connect to backend
101-
reactor.connectTCP(host, port, client_factory)
102-
103-
def set_tcp_peer(self, peer):
104-
"""Called when backend connection is established."""
105-
self._tcp_peer = peer
106-
# Resume reading from client
107-
self.transport.resumeProducing()
108-
109-
def connectionLost(self, reason):
110-
"""Handle connection close."""
111-
if self._tcp_peer:
112-
self._tcp_peer.transport.loseConnection()
113-
self._tcp_peer = None
114-
super().connectionLost(reason)
115-
116-
11724
class TcpProxyClient(ProxyClient):
11825
"""Backend TCP connection for protocol-detected connections."""
11926

@@ -261,7 +168,9 @@ def register_tcp_extension(
261168
backend_port: Backend port to route to
262169
"""
263170
_tcp_extensions.append((extension_name, matcher, backend_host, backend_port))
264-
LOG.info(f"Registered TCP extension {extension_name} -> {backend_host}:{backend_port}")
171+
LOG.info(
172+
f"Registered TCP extension {extension_name} -> {backend_host}:{backend_port}"
173+
)
265174

266175

267176
def unregister_tcp_extension(extension_name: str):

utils/tests/unit/test_tcp_protocol_detector.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
Unit tests for TCP connection matcher helpers.
33
"""
44

5-
import pytest
65
from localstack_extensions.utils.tcp_protocol_detector import (
76
create_prefix_matcher,
87
create_signature_matcher,
98
create_custom_matcher,
109
combine_matchers,
1110
)
11+
from localstack_extensions.utils.docker import ProxiedDockerContainerExtension
12+
from werkzeug.datastructures import Headers
1213

1314

1415
class TestMatcherFactories:
@@ -26,13 +27,13 @@ def test_create_prefix_matcher(self):
2627
def test_create_signature_matcher(self):
2728
"""Test creating a signature matcher with offset."""
2829
# Match signature at offset 4
29-
matcher = create_signature_matcher(b"\xAA\xBB", offset=4)
30+
matcher = create_signature_matcher(b"\xaa\xbb", offset=4)
3031

31-
assert matcher(b"\x00\x00\x00\x00\xAA\xBB\xCC")
32-
assert matcher(b"\x00\x00\x00\x00\xAA\xBB")
33-
assert not matcher(b"\xAA\xBB\xCC") # Wrong offset
34-
assert not matcher(b"\x00\x00\x00\x00\xCC\xDD") # Wrong signature
35-
assert not matcher(b"\x00\x00\x00\x00\xAA") # Incomplete
32+
assert matcher(b"\x00\x00\x00\x00\xaa\xbb\xcc")
33+
assert matcher(b"\x00\x00\x00\x00\xaa\xbb")
34+
assert not matcher(b"\xaa\xbb\xcc") # Wrong offset
35+
assert not matcher(b"\x00\x00\x00\x00\xcc\xdd") # Wrong signature
36+
assert not matcher(b"\x00\x00\x00\x00\xaa") # Incomplete
3637

3738
def test_create_custom_matcher(self):
3839
"""Test creating a custom matcher."""
@@ -42,8 +43,8 @@ def my_check(data):
4243

4344
matcher = create_custom_matcher(my_check)
4445

45-
assert matcher(b"\x00\x00\x00\x00\x00\xFF")
46-
assert matcher(b"\x00\x00\x00\x00\x00\xFF\xFF")
46+
assert matcher(b"\x00\x00\x00\x00\x00\xff")
47+
assert matcher(b"\x00\x00\x00\x00\x00\xff\xff")
4748
assert not matcher(b"\x00\x00\x00\x00\x00\x00")
4849
assert not matcher(b"\x00\x00\x00\x00\x00") # Too short
4950

@@ -88,16 +89,14 @@ def test_matcher_with_extra_data(self):
8889
matcher = create_prefix_matcher(b"PREFIX")
8990

9091
# Should match even with lots of extra data
91-
assert matcher(b"PREFIX" + b"\xFF" * 1000)
92+
assert matcher(b"PREFIX" + b"\xff" * 1000)
9293

9394

9495
class TestRealWorldUsage:
9596
"""Tests for real-world usage patterns."""
9697

9798
def test_extension_with_custom_protocol_matcher(self):
9899
"""Test using custom matchers in an extension context."""
99-
from localstack_extensions.utils.docker import ProxiedDockerContainerExtension
100-
from werkzeug.datastructures import Headers
101100

102101
class CustomProtocolExtension(ProxiedDockerContainerExtension):
103102
name = "custom"
@@ -111,7 +110,7 @@ def __init__(self):
111110

112111
def tcp_connection_matcher(self, data: bytes) -> bool:
113112
# Match custom protocol with magic bytes at offset 4
114-
matcher = create_signature_matcher(b"\xDE\xAD\xBE\xEF", offset=4)
113+
matcher = create_signature_matcher(b"\xde\xad\xbe\xef", offset=4)
115114
return matcher(data)
116115

117116
def should_proxy_request(self, headers: Headers) -> bool:
@@ -121,16 +120,14 @@ def should_proxy_request(self, headers: Headers) -> bool:
121120
assert hasattr(extension, "tcp_connection_matcher")
122121

123122
# Test the matcher
124-
valid_data = b"\x00\x00\x00\x00\xDE\xAD\xBE\xEF\xFF"
123+
valid_data = b"\x00\x00\x00\x00\xde\xad\xbe\xef\xff"
125124
assert extension.tcp_connection_matcher(valid_data)
126125

127-
invalid_data = b"\x00\x00\x00\x00\xFF\xFF\xFF\xFF"
126+
invalid_data = b"\x00\x00\x00\x00\xff\xff\xff\xff"
128127
assert not extension.tcp_connection_matcher(invalid_data)
129128

130129
def test_extension_with_combined_matchers(self):
131130
"""Test using combined matchers in an extension."""
132-
from localstack_extensions.utils.docker import ProxiedDockerContainerExtension
133-
from werkzeug.datastructures import Headers
134131

135132
class MultiProtocolExtension(ProxiedDockerContainerExtension):
136133
name = "multi-protocol"
@@ -160,8 +157,6 @@ def should_proxy_request(self, headers: Headers) -> bool:
160157

161158
def test_extension_with_inline_matcher(self):
162159
"""Test using an inline matcher function."""
163-
from localstack_extensions.utils.docker import ProxiedDockerContainerExtension
164-
from werkzeug.datastructures import Headers
165160

166161
class InlineMatcherExtension(ProxiedDockerContainerExtension):
167162
name = "inline"
@@ -175,11 +170,7 @@ def __init__(self):
175170

176171
def tcp_connection_matcher(self, data: bytes) -> bool:
177172
# Inline custom logic without helper functions
178-
return (
179-
len(data) >= 8
180-
and data.startswith(b"MAGIC")
181-
and data[7] == 0x42
182-
)
173+
return len(data) >= 8 and data.startswith(b"MAGIC") and data[7] == 0x42
183174

184175
def should_proxy_request(self, headers: Headers) -> bool:
185176
return False

0 commit comments

Comments
 (0)