Skip to content

Commit e986ced

Browse files
committed
feat: stream downloaded figures to a temp file instead of buffering in memory
1 parent 86a5617 commit e986ced

1 file changed

Lines changed: 111 additions & 9 deletions

File tree

src/core.rs

Lines changed: 111 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -511,23 +511,45 @@ async fn download_figure(
511511
return None;
512512
}
513513

514+
let suffix = content_type_to_suffix(content_type.as_deref())
515+
.or_else(|| url_suffix(remote_url))
516+
.unwrap_or_else(|| ".img".to_string());
517+
518+
let filename = format!("{base_name}{suffix}");
519+
let output = figures_dir.join(&filename);
520+
let seed = SystemTime::now()
521+
.duration_since(UNIX_EPOCH)
522+
.unwrap_or_default()
523+
.as_nanos();
524+
let temp_path = output.with_extension(format!("tmp-{}-{seed}", std::process::id()));
525+
let mut temp_file = match fs::File::create(&temp_path).await {
526+
Ok(file) => file,
527+
Err(_) => return None,
528+
};
529+
514530
let mut stream = response.bytes_stream();
515-
let mut bytes: Vec<u8> = Vec::new();
531+
let mut written = 0u64;
516532
while let Some(chunk) = stream.next().await {
517533
let chunk = chunk.ok()?;
518-
bytes.extend_from_slice(&chunk);
519-
if bytes.len() as u64 > max_download_bytes {
534+
written += chunk.len() as u64;
535+
if written > max_download_bytes {
536+
let _ = fs::remove_file(&temp_path).await;
537+
return None;
538+
}
539+
if temp_file.write_all(&chunk).await.is_err() {
540+
let _ = fs::remove_file(&temp_path).await;
520541
return None;
521542
}
522543
}
523544

524-
let suffix = content_type_to_suffix(content_type.as_deref())
525-
.or_else(|| url_suffix(remote_url))
526-
.unwrap_or_else(|| ".img".to_string());
545+
if temp_file.flush().await.is_err() {
546+
let _ = fs::remove_file(&temp_path).await;
547+
return None;
548+
}
549+
drop(temp_file);
527550

528-
let filename = format!("{base_name}{suffix}");
529-
let output = figures_dir.join(&filename);
530-
if atomic_write_bytes(&output, &bytes).await.is_err() {
551+
if fs::rename(&temp_path, &output).await.is_err() {
552+
let _ = fs::remove_file(&temp_path).await;
531553
return None;
532554
}
533555

@@ -636,6 +658,10 @@ mod tests {
636658
use std::sync::atomic::{AtomicUsize, Ordering};
637659
use std::sync::{Mutex, OnceLock};
638660
use tempfile::TempDir;
661+
#[cfg(feature = "net-tests")]
662+
use tokio::io::AsyncReadExt;
663+
#[cfg(feature = "net-tests")]
664+
use tokio::net::TcpListener;
639665

640666
fn env_lock() -> &'static Mutex<()> {
641667
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@@ -934,6 +960,35 @@ mod tests {
934960
assert_eq!(round3(Duration::from_millis(1234)), 1.234);
935961
}
936962

963+
#[cfg(feature = "net-tests")]
964+
async fn start_chunked_image_server(
965+
chunks: Vec<Vec<u8>>,
966+
) -> (String, tokio::task::JoinHandle<()>) {
967+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
968+
let addr = listener.local_addr().unwrap();
969+
let handle = tokio::spawn(async move {
970+
let Ok((mut socket, _)) = listener.accept().await else {
971+
return;
972+
};
973+
let mut request_buf = [0u8; 1024];
974+
let _ = socket.read(&mut request_buf).await;
975+
let _ = socket
976+
.write_all(
977+
b"HTTP/1.1 200 OK\r\nContent-Type: image/png\r\nTransfer-Encoding: chunked\r\nConnection: close\r\n\r\n",
978+
)
979+
.await;
980+
for chunk in chunks {
981+
let _ = socket
982+
.write_all(format!("{:X}\r\n", chunk.len()).as_bytes())
983+
.await;
984+
let _ = socket.write_all(&chunk).await;
985+
let _ = socket.write_all(b"\r\n").await;
986+
}
987+
let _ = socket.write_all(b"0\r\n\r\n").await;
988+
});
989+
(format!("http://{addr}"), handle)
990+
}
991+
937992
#[cfg(feature = "net-tests")]
938993
#[test]
939994
fn download_figure_accepts_image_from_mock_server() {
@@ -995,6 +1050,53 @@ mod tests {
9951050
assert!(result.is_none());
9961051
}
9971052

1053+
#[cfg(feature = "net-tests")]
1054+
#[test]
1055+
fn download_figure_streams_chunked_response_and_cleans_up_on_limit() {
1056+
let rt = tokio::runtime::Runtime::new().unwrap();
1057+
let (url, server_task) = rt.block_on(start_chunked_image_server(vec![
1058+
vec![1, 2, 3, 4],
1059+
vec![5, 6, 7, 8],
1060+
]));
1061+
1062+
let tmp = TempDir::new().unwrap();
1063+
let client = reqwest::Client::builder()
1064+
.timeout(Duration::from_secs(2))
1065+
.build()
1066+
.unwrap();
1067+
let result = rt.block_on(download_figure(&client, &url, tmp.path(), "fig-001-001", 3));
1068+
1069+
assert!(result.is_none());
1070+
assert!(std::fs::read_dir(tmp.path()).unwrap().next().is_none());
1071+
rt.block_on(server_task).unwrap();
1072+
}
1073+
1074+
#[cfg(feature = "net-tests")]
1075+
#[test]
1076+
fn download_figure_streams_chunked_response_successfully() {
1077+
let rt = tokio::runtime::Runtime::new().unwrap();
1078+
let (url, server_task) =
1079+
rt.block_on(start_chunked_image_server(vec![vec![1, 2], vec![3, 4]]));
1080+
1081+
let tmp = TempDir::new().unwrap();
1082+
let client = reqwest::Client::builder()
1083+
.timeout(Duration::from_secs(2))
1084+
.build()
1085+
.unwrap();
1086+
let name = rt.block_on(download_figure(
1087+
&client,
1088+
&url,
1089+
tmp.path(),
1090+
"fig-001-001",
1091+
1024,
1092+
));
1093+
1094+
assert_eq!(name.as_deref(), Some("fig-001-001.png"));
1095+
let file = tmp.path().join(name.unwrap());
1096+
assert_eq!(std::fs::read(&file).unwrap(), vec![1, 2, 3, 4]);
1097+
rt.block_on(server_task).unwrap();
1098+
}
1099+
9981100
#[cfg(feature = "net-tests")]
9991101
#[test]
10001102
fn localize_figures_rewrites_markdown_and_tracks_progress() {

0 commit comments

Comments
 (0)