forked from kfzteile24/postgresql-proxy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathproxy.py
More file actions
406 lines (352 loc) · 16.9 KB
/
proxy.py
File metadata and controls
406 lines (352 loc) · 16.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
'''For every configured instance, a Proxy object is created, that starts a listener.
On connect, it initiates a parallel connection to postgresql and pairs them together.
Using selectors, packets are received, intercepted and relayed to the other party.
Protocol:
The challenge is in identifying 3 types of packets:
1. With type and data.
ex. 1 byte for the type identifier, then a 4-byte length (it counts itself and the body, but not the type byte), then
the body. Usually the body ends with a 0x00 byte as well, which is included in the length.
The queries are part of this type of packets. A query is b'Q####SELECT whatever\\x00'
2. Without type. They contain just a 4 byte header with packet length. It just so happens that the first byte is 0x00
just because nothing is that long. These contain information about connection.
Usually it's the client sending connection information. Ex.
b'\\x00\\x00\\x00O' - length (0x4F = 79 bytes, including these 4)
b'\\x00\\x03\\x00\\x00' - protocol version 3.0
then a list of key/value strings, each terminated by \\x00: user, database, application_name, client_encoding, etc.
then a final terminating b'\\x00'
3. Without data: the whole packet is a single type byte. b'N' is the server's answer to an SSLRequest meaning
"SSL not supported" (b'S' means it is supported and a TLS handshake follows).
Handling:
proxy.py - connections and sockets things
connection.py - parsing and composing packets, launching interceptors
interceptors.py - intercepting for modification
'''
from __future__ import annotations
import logging
import selectors
import socket
import ssl
from types import ModuleType
from postgresql_proxy import connection, config_schema as cfg
from postgresql_proxy.interceptors import ResponseInterceptor, CommandInterceptor
# Module logger: all proxy diagnostics go to the "postgresql_proxy" logger.
LOG = logging.getLogger("postgresql_proxy")


class SelectorKeyProxy(selectors.SelectorKey):
    """Typing-only view of the ``selectors.SelectorKey`` objects this module registers.

    The selector itself returns plain ``SelectorKey`` tuples; this subclass only narrows the
    field types for readers and type checkers.
    """

    fileobj: socket.socket  # client or PostgreSQL socket (possibly an ssl.SSLSocket)
    data: connection.Connection  # None only for the listening socket (see Proxy.listen)
    fd: int
    events: int


class Proxy:
    """Listen on one configured address and relay every client to the configured PostgreSQL.

    Each accepted client socket is paired with a freshly opened socket to the redirect target.
    Both are registered on a single selector and serviced by :meth:`service_connection`, which
    hands received bytes to their ``Connection`` objects and forwards the built output to the
    other side.
    """

    # Int32 code of the PostgreSQL SSLRequest packet (length 8, then this code).
    SSL_REQUEST_CODE = 80877103

    def __init__(
        self,
        instance_config: cfg.InstanceSettings,
        plugins: dict[str, ModuleType],
        debug: bool = False,
        ssl_context: ssl.SSLContext | None = None,
    ) -> None:
        """
        :param instance_config: listen / redirect / intercept settings of this instance
        :param plugins: loaded plugin modules by name, handed to the interceptors
        :param debug: track registered connections so leftovers can be reported on shutdown
        :param ssl_context: when set, clients may upgrade to TLS via the SSLRequest packet
        """
        self.plugins = plugins
        self.num_clients = 0
        self.instance_config = instance_config
        self.connections = []
        self.selector = selectors.DefaultSelector()
        self.running = True
        self.sock = None
        self.ssl_context = ssl_context
        self._debug = debug
        if self._debug:
            # "<name>-<fd>" of every registered connection; used to track leftover sockets
            self._registered_conn = set()

    def _create_pg_connection(self, address, context):
        """
        Open a TCP connection to the redirect target and wrap it in a ``Connection``.

        :param address: address of the client this connection is paired with
        :param context: context dict shared by both sides of the pair
        :return: the new, non-blocking PostgreSQL-side ``Connection``
        :raises OSError: if the target cannot be reached (the socket is closed first)
        """
        redirect_config = self.instance_config.redirect
        pg_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            pg_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            # NOTE(review): connect() is blocking, so a slow target stalls the selector loop.
            pg_sock.connect((redirect_config.host, redirect_config.port))
            pg_sock.setblocking(False)
            redirect_config_name = redirect_config.name + '_' + str(self.num_clients)
            pg_conn = connection.Connection(
                pg_sock,
                name=redirect_config_name,
                address=address,
                events=selectors.EVENT_READ,
                context=context,
            )
        except Exception:
            # do not leak the socket when connecting or wrapping fails
            pg_sock.close()
            raise
        LOG.info("initiated client connection to %s:%s called %s",
                 redirect_config.host, redirect_config.port, redirect_config_name)
        return pg_conn

    def _register_conn(self, conn: connection.Connection):
        """Register conn's socket on the selector with conn as its key data."""
        try:
            self.selector.register(conn.sock, conn.events, data=conn)
        except KeyError as e:
            # Already registered: this can happen when a file descriptor is reused by a new
            # socket while a stale key for the old one is still present. modify() would keep
            # the stale key's old fileobj, so drop the old key and register from scratch.
            LOG.debug("exception while trying to register %s: %s", conn.name, e)
            self.selector.unregister(conn.sock)
            self.selector.register(conn.sock, conn.events, data=conn)
        if self._debug:
            self._registered_conn.add(f"{conn.name}-{conn.sock.fileno()}")

    def _unregister_conn(self, conn: connection.Connection):
        """
        Remove conn's socket from the selector. Closing the socket is left to the caller.

        If conn is the client side and the client never sent Terminate, a Terminate message is
        sent to the PostgreSQL peer so that postgres closes its socket cleanly instead of
        waiting for a query.
        """
        LOG.debug("closing connection %s", conn.name)
        self.selector.unregister(conn.sock)
        peer = conn.redirect_conn
        # NOTE(review): the client side is recognised by name; client names are built from the
        # listen name (see accept_wrapper), so this assumes that name starts with "proxy".
        if conn.name.startswith("proxy") and not conn.terminated and peer is not None:
            try:
                LOG.debug("try closing connection %s", peer.name)
                peer.sock.send(b'X\x00\x00\x00\x04')
                # remove the peer's reference back to this connection
                peer.redirect_conn = None
            except OSError:
                # OSError includes all socket exceptions + Connection* related exceptions
                LOG.debug("tried closing connection %s: already closed", peer.name)
        if self._debug:
            self._registered_conn.discard(f"{conn.name}-{conn.sock.fileno()}")

    def accept_wrapper(self, sock: socket.socket):
        """
        Accept a new client and pair it with a fresh connection to postgres.

        A `Connection` object is attached to each socket's SelectorKey so that state can be
        stored and shared between both sides. Failures that concern only this client (it left
        before accept, SSL negotiation failed, postgres is unreachable) close the client socket
        and return, rather than propagating into `listen` and shutting the whole proxy down.
        :param sock: the listening socket, reported readable by the selector
        :return:
        """
        try:
            clientsocket, address = sock.accept()
        except (BlockingIOError, ConnectionAbortedError) as e:
            # the client went away between the readiness notification and accept()
            LOG.debug("accept on %s failed: %s", self.instance_config.listen.name, e)
            return
        try:
            clientsocket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            # On macOS, accepted sockets inherit O_NONBLOCK from the listening socket.
            # SSL negotiation uses blocking recv, so we must set blocking explicitly here.
            clientsocket.setblocking(True)
            if self.ssl_context:
                # Handle SSL negotiation - must happen before setblocking(False)
                clientsocket = self._handle_ssl_negotiation(clientsocket, self.ssl_context)
            clientsocket.setblocking(False)
        except OSError as e:
            # includes ssl.SSLError from a failed handshake
            LOG.warning("dropping client %s: socket setup / SSL negotiation failed: %s", address, e)
            clientsocket.close()
            return

        self.num_clients += 1
        sock_name = f"{self.instance_config.listen.name}_{self.num_clients}"
        LOG.info(
            "Connection from %s, connection initiated %s (SSL: %s)",
            address,
            sock_name,
            self.ssl_context is not None,
        )
        context = {"instance_config": self.instance_config}
        conn = connection.Connection(
            clientsocket,
            name=sock_name,
            address=address,
            events=selectors.EVENT_READ,
            context=context,
        )
        try:
            pg_conn = self._create_pg_connection(address, context)
        except OSError as e:
            redirect = self.instance_config.redirect
            LOG.error("%s: could not connect to %s:%s: %s", sock_name, redirect.host, redirect.port, e)
            clientsocket.close()
            return

        intercept = self.instance_config.intercept
        if intercept is not None and intercept.responses is not None:
            pg_conn.interceptor = ResponseInterceptor(intercept.responses, self.plugins, context)
        pg_conn.redirect_conn = conn
        if intercept is not None and intercept.commands is not None:
            conn.interceptor = CommandInterceptor(intercept.commands, self.plugins, context)
        conn.redirect_conn = pg_conn
        self._register_conn(conn)
        self._register_conn(pg_conn)

    def _handle_ssl_negotiation(
        self, client_socket: socket.socket, ssl_context: ssl.SSLContext
    ) -> socket.socket:
        """
        Handle PostgreSQL SSL negotiation on an accepted (blocking) socket.

        PostgreSQL SSL flow:
        1. Client sends SSLRequest (8 bytes): length (4) + code 80877103 (4)
        2. Server responds 'S' (SSL supported) or 'N' (not supported)
        3. If 'S', TLS handshake follows
        4. After TLS, normal PostgreSQL protocol begins

        Returns the SSL-wrapped socket if negotiation succeeds, or the original socket when the
        first packet is not an SSLRequest (its bytes are left unread).
        :raises OSError: on socket errors, including ``ssl.SSLError`` from the handshake
        """
        # NOTE(review): recv is blocking, so a client that connects and stays silent stalls
        # the selector loop until it sends something or disconnects.
        # MSG_PEEK so the data is not consumed if it is not an SSLRequest
        data = client_socket.recv(8, socket.MSG_PEEK)
        if len(data) == 8:
            length = int.from_bytes(data[:4], "big")
            code = int.from_bytes(data[4:8], "big")
            if length == 8 and code == self.SSL_REQUEST_CODE:
                # Consume the SSLRequest and tell the client SSL is supported
                client_socket.recv(8)
                client_socket.sendall(b"S")
                ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True)
                LOG.debug("SSL handshake completed for PostgreSQL connection")
                return ssl_socket
        # Not an SSLRequest, return original socket
        return client_socket

    def service_connection(self, key: SelectorKeyProxy, mask):
        """
        Proxy messages between the two sockets of a pair.

        Received bytes are handed to the Connection object, which decodes / intercepts them.
        Its peer's pending output (``out_bytes``) is then sent to the peer socket. When the
        peer's send buffer is full, write interest is registered and the backlog is flushed on
        a later EVENT_WRITE.
        :param key: SelectorKeyProxy, containing the socket and the Connection object
        :param mask: mask of event, indicating what the socket is ready for
        :return:
        """
        sock = key.fileobj
        conn = key.data
        if mask & selectors.EVENT_READ:
            LOG.debug('%s can receive', conn.name)
            try:
                recv_data = sock.recv(4096)  # Should be ready to read
                if not recv_data:
                    # orderly shutdown by the peer
                    self._unregister_conn(conn)
                    LOG.debug('%s connection closing %s', conn.name, conn.address)
                    # A file object shall be unregistered prior to being closed.
                    sock.close()
                    return
                LOG.debug('%s received data:\n%s', conn.name, recv_data)
                conn.received(recv_data)
                # excerpt from https://docs.python.org/3/library/ssl.html#ssl-nonblocking
                # Conversely, since the SSL layer has its own framing, a SSL socket may still have data available
                # for reading without select() being aware of it. Therefore, you should first call SSLSocket.recv()
                # to drain any potentially available data, and then only block on a select() call if still necessary.
                while isinstance(sock, ssl.SSLSocket) and sock.pending() > 0:
                    extra = sock.recv(4096)
                    if extra:
                        LOG.debug('%s received pending SSL data:\n%s', conn.name, extra)
                        conn.received(extra)
            except (BlockingIOError, ssl.SSLWantReadError, ssl.SSLWantWriteError):
                # Spurious wakeup, or only part of a TLS record has arrived yet. The connection
                # is fine; retry on the next readiness event.
                pass
            except OSError as e:
                # it means the socket was closed by peer
                LOG.debug('%s connection closed by peer %s: %s', conn.name, conn.address, e)
                self._unregister_conn(conn)
                sock.close()
                return
        if mask & selectors.EVENT_WRITE:
            # Socket has buffer space — flush this connection's backlogged output.
            try:
                while conn.out_bytes:
                    sent = sock.send(conn.out_bytes)
                    conn.sent(sent)
                # All data drained; stop watching for writability.
                conn.events = selectors.EVENT_READ
                self.selector.modify(sock, selectors.EVENT_READ, data=conn)
            except (BlockingIOError, ssl.SSLWantWriteError):
                pass  # Still full; will retry on the next EVENT_WRITE notification.
            except OSError as e:
                LOG.debug('%s closed while flushing backlog: %s', conn.name, e)
                self._unregister_conn(conn)
                sock.close()
                return
        next_conn = conn.redirect_conn
        if next_conn and next_conn.out_bytes:
            try:
                while next_conn.out_bytes:
                    LOG.debug('sending to %s:\n%s', next_conn.name, next_conn.out_bytes)
                    sent = next_conn.sock.send(next_conn.out_bytes)
                    next_conn.sent(sent)
                # All sent; clear write interest if it was previously registered.
                if next_conn.events & selectors.EVENT_WRITE:
                    next_conn.events = selectors.EVENT_READ
                    self.selector.modify(next_conn.sock, selectors.EVENT_READ, data=next_conn)
            except (BlockingIOError, ssl.SSLWantWriteError):
                # next_conn's send buffer is full — register for writability so we retry when there's space.
                if not (next_conn.events & selectors.EVENT_WRITE):
                    next_conn.events = selectors.EVENT_READ | selectors.EVENT_WRITE
                    self.selector.modify(next_conn.sock, next_conn.events, data=next_conn)
            except OSError:
                # If one side is closed, close the other one
                # this can happen in the case where the client disconnects, and postgres still return a response
                # we then read the response then close the PG side of the socket.
                LOG.debug('error sending to %s: connection closed', next_conn.name)
                self._unregister_conn(conn)
                sock.close()

    def listen(self, max_connections: int = 8):
        """
        Bind the listening socket and run the selector loop until :meth:`stop` is called.

        Readiness of the listening socket goes to :meth:`accept_wrapper`, readiness of proxied
        sockets to :meth:`service_connection`. On exit, every socket that is still registered is
        closed, whether the loop ended normally or the listener could not be set up.
        :param max_connections: backlog passed to ``socket.listen``
        :return:
        """
        lconf = self.instance_config.listen
        ip, port = lconf.host, lconf.port
        try:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.sock.bind((ip, port))
            self.sock.listen(max_connections)
            self.sock.setblocking(False)
            self.selector.register(self.sock, selectors.EVENT_READ, data=None)
            while self.running:
                # the 1 s timeout lets the loop notice stop() while idle
                events = self.selector.select(timeout=1)
                if not events:
                    LOG.debug("polling selector...")
                    continue
                for key, mask in events:
                    if key.data is None:
                        # only the listening socket is registered without a Connection
                        self.accept_wrapper(key.fileobj)
                    else:
                        # manage the already proxied connections
                        self.service_connection(key, mask)
        except OSError as ex:
            LOG.error("Can't establish PostgreSQL proxy listener on port %s", port, exc_info=ex)
        except Exception:
            LOG.exception("PostgreSQL proxy quit unexpectedly:")
        finally:
            LOG.info("Closing PostgreSQL proxy on port %s", port)
            self._close_all()

    def _close_all(self):
        """Unregister and close the listening socket, any socket still registered, and the selector."""
        if self.sock is not None:
            try:
                self.selector.unregister(self.sock)
            except (KeyError, ValueError):
                # bind()/listen() failed before the listening socket got registered
                pass
        # this cleans up in case any connection was still opened; it should not happen anymore
        if self._debug:
            LOG.debug("Registered connections dangling: %s", self._registered_conn)
        # snapshot the keys: unregistering while iterating the live map is not allowed
        for selector_key in list(self.selector.get_map().values()):
            LOG.debug("Connection left: %s", selector_key)
            try:
                self.selector.unregister(selector_key.fileobj)
                selector_key.fileobj.close()
            except (OSError, KeyError, ValueError):
                # best-effort cleanup: keep closing the remaining sockets
                continue
        self.selector.close()
        if self.sock is not None:
            self.sock.close()
        self.sock = None

    def stop(self):
        """Ask :meth:`listen` to exit; it notices within one selector timeout (1 second)."""
        self.running = False
def main():
    """
    Run the proxy from the command line.

    Reads ``config.yml`` from this file's directory, configures the general and intercept logs,
    imports the configured plugins from the ``plugins`` package and starts one :class:`Proxy` per
    configured instance. ``Proxy.listen`` blocks until its proxy stops, so each instance listens
    in its own thread; calling it sequentially would only ever start the first instance.
    Ctrl+C stops all instances.
    """
    import importlib
    import os
    import sys
    import threading

    import yaml

    path = os.path.dirname(os.path.realpath(__file__))
    try:
        with open(os.path.join(path, 'config.yml'), 'r') as fp:
            # safe_load: the config only needs plain YAML types. yaml.load() without an explicit
            # Loader is unsafe on PyYAML < 5.1 and a TypeError on PyYAML >= 6.
            config = cfg.Config(yaml.safe_load(fp))
    except Exception:
        logging.critical("Could not read config. Aborting.", exc_info=True)
        sys.exit(1)

    logging.basicConfig(
        filename=config.settings.general_log,
        level=getattr(logging, config.settings.log_level.upper()),
        format='%(asctime)s : %(levelname)s : %(message)s'
    )

    # NOTE(review): presumably the logger the interceptors write to — confirm in interceptors.py.
    qlog = logging.getLogger('intercept')
    qhandler = logging.FileHandler(config.settings.intercept_log, mode='w')
    qhandler.setFormatter(logging.Formatter('%(asctime)s : %(message)s'))
    qlog.addHandler(qhandler)
    qlog.setLevel(logging.DEBUG)

    print('general log, level {}: {}'.format(config.settings.log_level, config.settings.general_log))
    print('intercept log: {}'.format(config.settings.intercept_log))
    print('further messages directed to log')

    plugins = {}
    for plugin in config.plugins:
        logging.info("Loading module %s", plugin)
        plugins[plugin] = importlib.import_module('plugins.' + plugin)

    # NOTE(review): the plugin modules are shared by all instance threads; confirm they keep
    # no unsynchronised mutable module state.
    proxies = []
    threads = []
    for index, instance in enumerate(config.instances):
        logging.info("Starting proxy instance")
        proxy = Proxy(instance, plugins)
        thread = threading.Thread(target=proxy.listen, name=f"postgresql-proxy-{index}", daemon=True)
        thread.start()
        proxies.append(proxy)
        threads.append(thread)

    try:
        for thread in threads:
            thread.join()
    except KeyboardInterrupt:
        logging.info("Interrupted, stopping proxy instances")
        for proxy in proxies:
            proxy.stop()
        # each listen loop notices stop() within its 1 s selector timeout and cleans up
        for thread in threads:
            thread.join()


if __name__ == "__main__":
    main()