Skip to content

Commit c8b3067

Browse files
committed
feat(mysql): support load data infile
1 parent 6956cef commit c8b3067

7 files changed

Lines changed: 158 additions & 5 deletions

File tree

sqlx-core/src/fs.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,22 @@ impl ReadDir {
9494
}
9595
}
9696
}
97+
98+
#[cfg(feature = "_rt-tokio")]
99+
pub async fn open_file<P: AsRef<Path>>(path: P) -> Result<tokio::fs::File, io::Error> {
100+
if rt::rt_tokio::available() {
101+
return tokio::fs::File::open(path).await;
102+
}
103+
104+
rt::missing_rt(path);
105+
}
106+
107+
#[cfg(all(feature = "_rt-async-io", not(feature = "_rt-tokio")))]
108+
pub async fn open_file<P: AsRef<Path>>(path: P) -> Result<async_fs::File, io::Error> {
109+
async_fs::File::open(path).await
110+
}
111+
112+
#[cfg(all(not(feature = "_rt-async-io"), not(feature = "_rt-tokio")))]
113+
pub async fn open_file<P: AsRef<Path>>(path: P) -> Result<futures_util::io::Empty, io::Error> {
114+
rt::missing_rt(path)
115+
}

sqlx-mysql/src/connection/executor.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::executor::{Execute, Executor};
55
use crate::ext::ustr::UStr;
66
use crate::io::MySqlBufExt;
77
use crate::logger::QueryLogger;
8-
use crate::protocol::response::Status;
8+
use crate::protocol::response::{LocalInfilePacket, Status};
99
use crate::protocol::statement::{
1010
BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, StmtClose,
1111
};
@@ -22,7 +22,9 @@ use futures_core::stream::BoxStream;
2222
use futures_core::Stream;
2323
use futures_util::TryStreamExt;
2424
use sqlx_core::column::{ColumnOrigin, TableColumn};
25+
use sqlx_core::fs::open_file;
2526
use sqlx_core::sql_str::SqlStr;
27+
use std::path::PathBuf;
2628
use std::{pin::pin, sync::Arc};
2729

2830
impl MySqlConnection {
@@ -209,6 +211,19 @@ impl MySqlConnection {
209211
return Ok(());
210212
}
211213

214+
if packet[0] == 0xfb {
215+
// LocalInfileRequest
216+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html
217+
let packet = packet.decode::<LocalInfilePacket>()?;
218+
let path = PathBuf::from(String::from_utf8_lossy(&packet.filename).into_owned());
219+
let file = open_file(&path).await.map_err(|_| err_protocol!("cannot open file {} for local infile request", path.display()))?;
220+
221+
self.inner.stream.send_stream(file).await?;
222+
223+
continue;
224+
}
225+
226+
212227
// otherwise, this first packet is the start of the result-set metadata,
213228
*self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row;
214229

sqlx-mysql/src/connection/stream.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use bytes::{Buf, Bytes, BytesMut};
55

66
use crate::error::Error;
77
use crate::io::MySqlBufExt;
8-
use crate::io::{ProtocolDecode, ProtocolEncode};
8+
use crate::io::{AsyncRead, ProtocolDecode, ProtocolEncode};
99
use crate::net::{BufferedSocket, Socket};
1010
use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
1111
use crate::protocol::{Capabilities, Packet};
@@ -43,7 +43,8 @@ impl<S: Socket> MySqlStream<S> {
4343
| Capabilities::MULTI_RESULTS
4444
| Capabilities::PLUGIN_AUTH
4545
| Capabilities::PS_MULTI_RESULTS
46-
| Capabilities::SSL;
46+
| Capabilities::SSL
47+
| Capabilities::LOCAL_FILES;
4748

4849
if options.database.is_some() {
4950
capabilities |= Capabilities::CONNECT_WITH_DB;
@@ -108,6 +109,43 @@ impl<S: Socket> MySqlStream<S> {
108109
Ok(())
109110
}
110111

112+
/// Send data from a stream to the database server as MySQL packets
113+
///
114+
/// This is used to send data for a LOCAL INFILE query
115+
pub(crate) async fn send_stream(
116+
&mut self,
117+
mut source: impl AsyncRead + Unpin,
118+
) -> Result<(), Error> {
119+
self.socket.flush().await?;
120+
121+
loop {
122+
let buf = self.socket.write_buffer_mut();
123+
124+
// Write the CopyData format code and reserve space for the length + sequence_id
125+
// This is safe even if empty, since we always need to send an empty packet at the end
126+
buf.put_slice(b"\0\0\0\0");
127+
128+
let read = buf.read_from(&mut source).await?;
129+
let read32 = i32::try_from(read)
130+
.map_err(|_| err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read))?;
131+
132+
// rewrite header (len + sequenceid)
133+
let mut header = read32.to_le_bytes();
134+
header[3] = self.sequence_id;
135+
self.sequence_id = self.sequence_id.wrapping_add(1);
136+
137+
buf.get_mut()[..4].copy_from_slice(&header);
138+
139+
self.socket.flush().await?;
140+
141+
if read32 == 0 {
142+
break;
143+
}
144+
}
145+
146+
Ok(())
147+
}
148+
111149
pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
112150
where
113151
T: ProtocolEncode<'en, Capabilities>,
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use bytes::{Buf, Bytes};
2+
use sqlx_core::io::{BufExt, ProtocolDecode};
3+
4+
use crate::error::Error;
5+
6+
/// Requests the client to send a file to the server, following a LOCAL INFILE statement
7+
///
8+
/// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html
9+
#[derive(Debug)]
10+
pub struct LocalInfilePacket {
11+
pub filename: Vec<u8>,
12+
}
13+
14+
impl ProtocolDecode<'_> for LocalInfilePacket {
15+
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
16+
let header = buf.get_u8();
17+
if header != 0xfb {
18+
return Err(err_protocol!(
19+
"expected 0xfb (LocalInfileRequest) but found 0x{:02x}",
20+
header
21+
));
22+
}
23+
24+
let filename = buf.get_bytes(buf.len()).to_vec();
25+
26+
Ok(Self { filename })
27+
}
28+
}
29+
30+
#[test]
31+
fn test_decode_localinfile_packet() {
32+
const DATA: &[u8] = b"\xfb\x64\x75\x6d\x6d\x79";
33+
34+
let p = LocalInfilePacket::decode(DATA.into()).unwrap();
35+
36+
assert_eq!(p.filename, b"dummy");
37+
}

sqlx-mysql/src/protocol/response/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
66
mod eof;
77
mod err;
8+
mod local_infile;
89
mod ok;
910
mod status;
1011

1112
pub use eof::EofPacket;
1213
pub use err::ErrPacket;
14+
pub use local_infile::LocalInfilePacket;
1315
pub use ok::OkPacket;
1416
pub use status::Status;
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
1,a
2+
2,b

tests/mysql/mysql.rs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use anyhow::Context;
22
use futures_util::TryStreamExt;
33
use sqlx::mysql::{MySql, MySqlConnection, MySqlPool, MySqlPoolOptions, MySqlRow};
4-
use sqlx::{Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo};
4+
use sqlx::{AssertSqlSafe, Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo};
55
use sqlx_core::connection::ConnectOptions;
66
use sqlx_core::types::Type;
77
use sqlx_mysql::MySqlConnectOptions;
@@ -599,7 +599,7 @@ async fn select_statement_count(conn: &mut MySqlConnection) -> Result<i64, sqlx:
599599
SELECT COUNT(*)
600600
FROM performance_schema.threads AS t
601601
INNER JOIN performance_schema.prepared_statements_instances AS psi
602-
ON psi.OWNER_THREAD_ID = t.THREAD_ID
602+
ON psi.OWNER_THREAD_ID = t.THREAD_ID
603603
WHERE t.processlist_id = CONNECTION_ID()
604604
"#,
605605
)
@@ -727,3 +727,43 @@ async fn any_blob_conversions() -> anyhow::Result<()> {
727727

728728
Ok(())
729729
}
730+
731+
#[sqlx_macros::test]
732+
async fn it_can_load_a_file() -> anyhow::Result<()> {
733+
let mut conn = new::<MySql>().await?;
734+
735+
let _ = conn
736+
.execute(
737+
r#"
738+
CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, name TEXT);
739+
"#,
740+
)
741+
.await?;
742+
743+
let _ = conn.execute("SET GLOBAL local_infile = 1;").await?;
744+
745+
let file_path = env::current_dir()
746+
.unwrap()
747+
.join("tests/mysql/fixtures/load_data_infile.txt");
748+
749+
// Execute LOAD DATA LOCAL INFILE
750+
let load_query = format!(
751+
"LOAD DATA LOCAL INFILE '{}' INTO TABLE users FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n'",
752+
file_path.display()
753+
);
754+
755+
let result = conn.execute(AssertSqlSafe(load_query)).await;
756+
757+
if let Err(e) = result {
758+
assert!(false, "{:?}", e)
759+
}
760+
761+
let name = sqlx::query("SELECT name FROM users WHERE id = 1")
762+
.try_map(|row: MySqlRow| row.try_get::<String, _>(0))
763+
.fetch_one(&mut conn)
764+
.await?;
765+
766+
assert_eq!("a", name);
767+
768+
Ok(())
769+
}

0 commit comments

Comments
 (0)