Skip to content

Commit 4ffdc1e

Browse files
committed
feat: add custom socket transport support for Postgres and MySQL
Add methods to PgConnection and MySqlConnection that accept pre-connected sockets implementing AsyncRead + AsyncWrite, enabling custom transport layers (vsock, QUIC, turmoil, SSH tunnels, etc.) without forking sqlx. Per maintainer feedback on transact-rs#4187, this uses AsyncRead + AsyncWrite traits instead of exposing the internal Socket trait. Two separate methods are provided for each runtime's trait set: - connect_with_custom_tokio(): accepts tokio::io::{AsyncRead, AsyncWrite} - connect_with_custom_futures(): accepts futures_io::{AsyncRead, AsyncWrite} Also adds PoolOptions::connector() so pools can use custom transports: PgPoolOptions::new() .connector(|options| async move { let socket = VsockStream::connect(addr).await?; PgConnection::connect_with_custom_tokio(socket, &options).await }) .connect_with(options) .await? TLS upgrade is negotiated automatically based on the connection options. No new public trait exposure. No behavioral changes to existing code.
1 parent 75bc048 commit 4ffdc1e

11 files changed

Lines changed: 589 additions & 7 deletions

File tree

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ _unstable-docs = [
9494
]
9595

9696
# Base runtime features without TLS
97-
runtime-async-global-executor = ["_rt-async-global-executor", "sqlx-core/_rt-async-global-executor", "sqlx-macros?/_rt-async-global-executor"]
98-
runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"]
99-
runtime-smol = ["_rt-smol", "sqlx-core/_rt-smol", "sqlx-macros?/_rt-smol"]
100-
runtime-tokio = ["_rt-tokio", "sqlx-core/_rt-tokio", "sqlx-macros?/_rt-tokio"]
97+
runtime-async-global-executor = ["_rt-async-global-executor", "sqlx-core/_rt-async-global-executor", "sqlx-macros?/_rt-async-global-executor", "sqlx-postgres?/_rt-async-io", "sqlx-mysql?/_rt-async-io"]
98+
runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std", "sqlx-postgres?/_rt-async-io", "sqlx-mysql?/_rt-async-io"]
99+
runtime-smol = ["_rt-smol", "sqlx-core/_rt-smol", "sqlx-macros?/_rt-smol", "sqlx-postgres?/_rt-async-io", "sqlx-mysql?/_rt-async-io"]
100+
runtime-tokio = ["_rt-tokio", "sqlx-core/_rt-tokio", "sqlx-macros?/_rt-tokio", "sqlx-postgres?/_rt-tokio", "sqlx-mysql?/_rt-tokio"]
101101

102102
# TLS features
103103
tls-native-tls = ["sqlx-core/_tls-native-tls", "sqlx-macros?/_tls-native-tls"]

sqlx-core/src/net/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,9 @@ pub mod tls;
44
pub use socket::{
55
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
66
};
7+
8+
#[cfg(feature = "_rt-tokio")]
9+
pub use socket::async_rw_adapter::TokioStream;
10+
11+
#[cfg(feature = "_rt-async-io")]
12+
pub use socket::async_rw_adapter::FuturesStream;
Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
use std::io;
2+
use std::pin::Pin;
3+
use std::task::{Context, Poll};
4+
5+
use bytes::BufMut;
6+
7+
use crate::io::ReadBuf;
8+
use crate::net::Socket;
9+
10+
// Internal buffer size for the read-ahead used by `poll_read_ready`.
11+
const ADAPTER_BUF_SIZE: usize = 8192;
12+
13+
/// Adapter that wraps a tokio [`AsyncRead`][tokio::io::AsyncRead] + [`AsyncWrite`][tokio::io::AsyncWrite]
14+
/// into a [`Socket`] implementation.
15+
#[cfg(feature = "_rt-tokio")]
16+
pub struct TokioStream<S> {
17+
inner: S,
18+
read_buf: Vec<u8>,
19+
read_len: usize,
20+
read_pos: usize,
21+
}
22+
23+
#[cfg(feature = "_rt-tokio")]
24+
impl<S> TokioStream<S> {
25+
pub fn new(inner: S) -> Self {
26+
Self {
27+
inner,
28+
read_buf: vec![0u8; ADAPTER_BUF_SIZE],
29+
read_len: 0,
30+
read_pos: 0,
31+
}
32+
}
33+
34+
fn buffered(&self) -> &[u8] {
35+
&self.read_buf[self.read_pos..self.read_len]
36+
}
37+
}
38+
39+
#[cfg(feature = "_rt-tokio")]
40+
impl<S> Socket for TokioStream<S>
41+
where
42+
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin + 'static,
43+
{
44+
fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize> {
45+
let buffered = self.buffered();
46+
if !buffered.is_empty() {
47+
let to_copy = buffered.len().min(buf.remaining_mut());
48+
buf.put_slice(&buffered[..to_copy]);
49+
self.read_pos += to_copy;
50+
if self.read_pos == self.read_len {
51+
self.read_pos = 0;
52+
self.read_len = 0;
53+
}
54+
return Ok(to_copy);
55+
}
56+
Err(io::Error::from(io::ErrorKind::WouldBlock))
57+
}
58+
59+
fn try_write(&mut self, buf: &[u8]) -> io::Result<usize> {
60+
// Use a noop waker to attempt a non-blocking write.
61+
let waker = futures_util::task::noop_waker();
62+
let mut cx = Context::from_waker(&waker);
63+
let pin = Pin::new(&mut self.inner);
64+
match pin.poll_write(&mut cx, buf) {
65+
Poll::Ready(result) => result,
66+
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
67+
}
68+
}
69+
70+
fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
71+
// If we already have buffered data, we're ready.
72+
if !self.buffered().is_empty() {
73+
return Poll::Ready(Ok(()));
74+
}
75+
76+
// Reset buffer positions.
77+
self.read_pos = 0;
78+
self.read_len = 0;
79+
80+
let mut read_buf = tokio::io::ReadBuf::new(&mut self.read_buf);
81+
let pin = Pin::new(&mut self.inner);
82+
match pin.poll_read(cx, &mut read_buf) {
83+
Poll::Ready(Ok(())) => {
84+
let n = read_buf.filled().len();
85+
if n == 0 {
86+
return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof)));
87+
}
88+
self.read_len = n;
89+
Poll::Ready(Ok(()))
90+
}
91+
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
92+
Poll::Pending => Poll::Pending,
93+
}
94+
}
95+
96+
fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
97+
// For tokio AsyncWrite, we just attempt a zero-length write to check readiness.
98+
// Most implementations will return Ready(Ok(0)) if writable.
99+
// A more robust approach: just return Ready since we rely on try_write to handle Pending.
100+
let pin = Pin::new(&mut self.inner);
101+
match pin.poll_write(cx, &[]) {
102+
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
103+
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
104+
Poll::Pending => Poll::Pending,
105+
}
106+
}
107+
108+
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
109+
Pin::new(&mut self.inner).poll_flush(cx)
110+
}
111+
112+
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113+
Pin::new(&mut self.inner).poll_shutdown(cx)
114+
}
115+
}
116+
117+
/// Adapter that wraps a futures-io [`AsyncRead`][futures_io::AsyncRead] + [`AsyncWrite`][futures_io::AsyncWrite]
118+
/// into a [`Socket`] implementation.
119+
#[cfg(feature = "_rt-async-io")]
120+
pub struct FuturesStream<S> {
121+
inner: S,
122+
read_buf: Vec<u8>,
123+
read_len: usize,
124+
read_pos: usize,
125+
}
126+
127+
#[cfg(feature = "_rt-async-io")]
128+
impl<S> FuturesStream<S> {
129+
pub fn new(inner: S) -> Self {
130+
Self {
131+
inner,
132+
read_buf: vec![0u8; ADAPTER_BUF_SIZE],
133+
read_len: 0,
134+
read_pos: 0,
135+
}
136+
}
137+
138+
fn buffered(&self) -> &[u8] {
139+
&self.read_buf[self.read_pos..self.read_len]
140+
}
141+
}
142+
143+
#[cfg(feature = "_rt-async-io")]
144+
impl<S> Socket for FuturesStream<S>
145+
where
146+
S: futures_io::AsyncRead + futures_io::AsyncWrite + Send + Sync + Unpin + 'static,
147+
{
148+
fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize> {
149+
let buffered = self.buffered();
150+
if !buffered.is_empty() {
151+
let to_copy = buffered.len().min(buf.remaining_mut());
152+
buf.put_slice(&buffered[..to_copy]);
153+
self.read_pos += to_copy;
154+
if self.read_pos == self.read_len {
155+
self.read_pos = 0;
156+
self.read_len = 0;
157+
}
158+
return Ok(to_copy);
159+
}
160+
Err(io::Error::from(io::ErrorKind::WouldBlock))
161+
}
162+
163+
fn try_write(&mut self, buf: &[u8]) -> io::Result<usize> {
164+
let waker = futures_util::task::noop_waker();
165+
let mut cx = Context::from_waker(&waker);
166+
let pin = Pin::new(&mut self.inner);
167+
match pin.poll_write(&mut cx, buf) {
168+
Poll::Ready(result) => result,
169+
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
170+
}
171+
}
172+
173+
fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
174+
if !self.buffered().is_empty() {
175+
return Poll::Ready(Ok(()));
176+
}
177+
178+
self.read_pos = 0;
179+
self.read_len = 0;
180+
181+
let pin = Pin::new(&mut self.inner);
182+
match pin.poll_read(cx, &mut self.read_buf) {
183+
Poll::Ready(Ok(n)) => {
184+
if n == 0 {
185+
return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof)));
186+
}
187+
self.read_len = n;
188+
Poll::Ready(Ok(()))
189+
}
190+
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
191+
Poll::Pending => Poll::Pending,
192+
}
193+
}
194+
195+
fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
196+
let pin = Pin::new(&mut self.inner);
197+
match pin.poll_write(cx, &[]) {
198+
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
199+
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
200+
Poll::Pending => Poll::Pending,
201+
}
202+
}
203+
204+
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
205+
Pin::new(&mut self.inner).poll_flush(cx)
206+
}
207+
208+
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
209+
Pin::new(&mut self.inner).poll_close(cx)
210+
}
211+
}
212+
213+
#[cfg(test)]
214+
mod tests {
215+
use super::*;
216+
use std::io::Cursor;
217+
218+
// Cursor<Vec<u8>> implements both tokio and futures-io AsyncRead/AsyncWrite
219+
220+
#[cfg(feature = "_rt-tokio")]
221+
mod tokio_adapter {
222+
use super::*;
223+
use crate::net::Socket;
224+
use bytes::BytesMut;
225+
use std::task::{Context, Poll};
226+
227+
fn noop_cx() -> Context<'static> {
228+
Context::from_waker(futures_util::task::noop_waker_ref())
229+
}
230+
231+
#[test]
232+
fn try_read_returns_would_block_when_empty() {
233+
let stream = tokio::io::duplex(64).0;
234+
let mut adapter = TokioStream::new(stream);
235+
let mut buf = BytesMut::with_capacity(32);
236+
let err = adapter.try_read(&mut buf).unwrap_err();
237+
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
238+
}
239+
240+
#[test]
241+
fn poll_read_ready_fills_buffer_then_try_read_drains() {
242+
// Use a duplex stream with pre-written data
243+
let (client, mut server) = tokio::io::duplex(64);
244+
245+
// Write data into the server side so client can read it
246+
let rt = tokio::runtime::Builder::new_current_thread()
247+
.enable_all()
248+
.build()
249+
.unwrap();
250+
251+
rt.block_on(async {
252+
use tokio::io::AsyncWriteExt;
253+
server.write_all(b"hello world").await.unwrap();
254+
255+
let mut adapter = TokioStream::new(client);
256+
let mut buf = BytesMut::with_capacity(32);
257+
258+
// poll_read_ready should fill internal buffer
259+
let poll =
260+
std::future::poll_fn(|cx| adapter.poll_read_ready(cx)).await;
261+
assert!(poll.is_ok());
262+
263+
// try_read should drain from internal buffer
264+
let n = adapter.try_read(&mut buf).unwrap();
265+
assert_eq!(&buf[..n], b"hello world");
266+
267+
// After draining, try_read should return WouldBlock again
268+
let mut buf2 = BytesMut::with_capacity(32);
269+
let err = adapter.try_read(&mut buf2).unwrap_err();
270+
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
271+
});
272+
}
273+
274+
#[test]
275+
fn try_write_writes_data() {
276+
let (client, mut server) = tokio::io::duplex(64);
277+
278+
let rt = tokio::runtime::Builder::new_current_thread()
279+
.enable_all()
280+
.build()
281+
.unwrap();
282+
283+
rt.block_on(async {
284+
use tokio::io::AsyncReadExt;
285+
let mut adapter = TokioStream::new(client);
286+
287+
// Write via the adapter
288+
let n = std::future::poll_fn(|cx| {
289+
match adapter.try_write(b"test data") {
290+
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
291+
match adapter.poll_write_ready(cx) {
292+
Poll::Ready(Ok(())) => {
293+
Poll::Ready(adapter.try_write(b"test data"))
294+
}
295+
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
296+
Poll::Pending => Poll::Pending,
297+
}
298+
}
299+
other => Poll::Ready(other),
300+
}
301+
})
302+
.await
303+
.unwrap();
304+
305+
assert_eq!(n, 9);
306+
307+
// Read it back from the other end
308+
let mut read_buf = vec![0u8; 32];
309+
let n = server.read(&mut read_buf).await.unwrap();
310+
assert_eq!(&read_buf[..n], b"test data");
311+
});
312+
}
313+
314+
#[test]
315+
fn partial_drain_preserves_remaining() {
316+
let (client, mut server) = tokio::io::duplex(64);
317+
318+
let rt = tokio::runtime::Builder::new_current_thread()
319+
.enable_all()
320+
.build()
321+
.unwrap();
322+
323+
rt.block_on(async {
324+
use tokio::io::AsyncWriteExt;
325+
server.write_all(b"abcdefghij").await.unwrap();
326+
327+
let mut adapter = TokioStream::new(client);
328+
329+
// Fill internal buffer
330+
std::future::poll_fn(|cx| adapter.poll_read_ready(cx)).await.unwrap();
331+
332+
// Read only 4 bytes using a fixed-size slice
333+
let mut buf = [0u8; 4];
334+
let n = adapter.try_read(&mut buf.as_mut_slice()).unwrap();
335+
assert_eq!(n, 4);
336+
assert_eq!(&buf, b"abcd");
337+
338+
// Remaining 6 bytes should still be available
339+
let mut buf2 = [0u8; 32];
340+
let n = adapter.try_read(&mut buf2.as_mut_slice()).unwrap();
341+
assert_eq!(n, 6);
342+
assert_eq!(&buf2[..6], b"efghij");
343+
});
344+
}
345+
}
346+
}

sqlx-core/src/net/socket/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ use cfg_if::cfg_if;
1010

1111
use crate::io::ReadBuf;
1212

13+
#[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))]
14+
pub(crate) mod async_rw_adapter;
1315
mod buffered;
1416

1517
pub trait Socket: Send + Sync + Unpin + 'static {

0 commit comments

Comments
 (0)