Skip to content

Commit e8c8d21

Browse files
authored
closes #7
feat: add OCR-specific concurrency control for batch processing
2 parents b9576c7 + e100965 commit e8c8d21

7 files changed

Lines changed: 199 additions & 14 deletions

File tree

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Therefore, this project was created because, while [`docling`](https://github.co
2929

3030
- Async OCR requests and batch PDF processing using the Z.AI API.
3131
- Concurrent figure downloads for each PDF.
32-
- Fast processing: approximately 25 seconds per batch of 32 PDFs. Speed depends on the z.ai API availability. See the cost section for more details on spending.
32+
- Fast processing with separate controls for total pipeline concurrency and OCR API concurrency.
3333

3434
> [!note]
3535
> This tool was designed to be used with academic papers written in English. Parsing other PDFs, heavy in tables or figures, or in other languages rather than English has not been tested.
@@ -45,9 +45,11 @@ paperdown --input path/to/paper.pdf
4545
My preferred method is batch directory processing:
4646

4747
```bash
48-
paperdown --input pdf/ --output md/ --workers 4 --overwrite
48+
paperdown --input pdf/ --output md/ --workers 32 --ocr-workers 2 --overwrite
4949
```
5050

51+
`--workers` controls how many PDFs are processed concurrently in batch mode. `--ocr-workers` controls concurrent OCR API calls. Effective OCR concurrency is `min(--workers, --ocr-workers)`.
52+
5153
## Installation
5254

5355
Install from crates.io:
@@ -87,6 +89,7 @@ Options:
8789
--timeout <TIMEOUT> HTTP timeout in seconds for OCR requests and figure downloads. [default: 180]
8890
--max-download-bytes <MAX_DOWNLOAD_BYTES> Maximum allowed size (bytes) for each downloaded figure file. [default: 20971520]
8991
--workers <WORKERS> Maximum number of PDFs processed concurrently in batch mode. [default: 32]
92+
--ocr-workers <OCR_WORKERS> Maximum number of concurrent OCR API calls in batch mode; effective OCR concurrency is min(--workers, --ocr-workers). [default: 2]
9093
-v, --verbose Enable verbose progress messages on stderr.
9194
--overwrite Replace existing managed output artifacts (index.md, figures/, and tables/ when enabled).
9295
--normalize-tables Normalize OCR HTML tables into Markdown and store raw HTML under tables/.

src/cli.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ pub struct Cli {
7272
)]
7373
pub workers: usize,
7474

75+
#[arg(
76+
long = "ocr-workers",
77+
default_value_t = 2usize,
78+
value_parser = parse_positive_usize,
79+
help = "Maximum number of concurrent OCR API calls in batch mode; effective OCR concurrency is min(--workers, --ocr-workers)."
80+
)]
81+
pub ocr_workers: usize,
82+
7583
#[arg(
7684
short = 'v',
7785
long,
@@ -96,10 +104,7 @@ pub struct Cli {
96104
}
97105

98106
pub fn default_workers() -> usize {
99-
let cpu = std::thread::available_parallelism()
100-
.map(|n| n.get())
101-
.unwrap_or(4);
102-
(cpu * 4).clamp(4, 32)
107+
32
103108
}
104109

105110
fn parse_positive_usize(value: &str) -> Result<usize, String> {
@@ -118,9 +123,8 @@ mod tests {
118123
use clap::{CommandFactory, Parser};
119124

120125
#[test]
121-
fn default_workers_formula_bounds() {
122-
let workers = default_workers();
123-
assert!((4..=32).contains(&workers));
126+
fn default_workers_is_32() {
127+
assert_eq!(default_workers(), 32);
124128
}
125129

126130
#[test]
@@ -132,6 +136,7 @@ mod tests {
132136
assert_eq!(cli.timeout, 180);
133137
assert_eq!(cli.max_download_bytes, 20_971_520);
134138
assert_eq!(cli.workers, default_workers());
139+
assert_eq!(cli.ocr_workers, 2);
135140
assert!(!cli.verbose);
136141
assert!(!cli.overwrite);
137142
assert!(!cli.normalize_tables);
@@ -151,6 +156,9 @@ mod tests {
151156
.is_err()
152157
);
153158
assert!(Cli::try_parse_from(["paperdown", "--input", "in.pdf", "--workers", "0"]).is_err());
159+
assert!(
160+
Cli::try_parse_from(["paperdown", "--input", "in.pdf", "--ocr-workers", "0"]).is_err()
161+
);
154162
}
155163

156164
#[test]
@@ -166,5 +174,6 @@ mod tests {
166174
assert!(env_second.is_some());
167175
assert!(file_first.unwrap() < env_second.unwrap());
168176
assert!(help.contains("single .pdf file or a directory"));
177+
assert!(help.contains("min(--workers, --ocr-workers)"));
169178
}
170179
}

src/core.rs

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
use anyhow::{Context, Result, anyhow};
22
use serde::Serialize;
33
use serde_json::{Value, json};
4+
use std::future::Future;
45
use std::path::Path;
56
use std::sync::Arc;
67
use std::time::{Duration, Instant};
78
use time::OffsetDateTime;
89
use time::format_description::well_known::Rfc3339;
10+
use tokio::sync::Semaphore;
911

1012
mod assets;
1113
mod input;
@@ -56,6 +58,17 @@ pub async fn process_pdf(
5658
output_root: &Path,
5759
env_file: &Path,
5860
options: ProcessPdfOptions,
61+
) -> Result<PdfSummary> {
62+
process_pdf_with_ocr_limiter(pdf_path, output_root, env_file, options, None).await
63+
}
64+
65+
#[doc(hidden)]
66+
pub async fn process_pdf_with_ocr_limiter(
67+
pdf_path: &Path,
68+
output_root: &Path,
69+
env_file: &Path,
70+
options: ProcessPdfOptions,
71+
ocr_limiter: Option<Arc<Semaphore>>,
5972
) -> Result<PdfSummary> {
6073
let run_started = Instant::now();
6174
let pdf_path = pdf_path
@@ -78,7 +91,10 @@ pub async fn process_pdf(
7891
let payload = ocr::build_payload(&pdf_path).await?;
7992
fire(&options.progress, ProgressEvent::OcrStarted);
8093
let ocr_started = Instant::now();
81-
let response = ocr::call_layout_parsing(&client, &api_key, payload).await?;
94+
let response = run_with_ocr_limiter(ocr_limiter, async {
95+
ocr::call_layout_parsing(&client, &api_key, payload).await
96+
})
97+
.await?;
8298
let ocr_seconds = ocr_started.elapsed();
8399
fire(&options.progress, ProgressEvent::OcrFinished);
84100

@@ -169,6 +185,71 @@ fn round3(duration: Duration) -> f64 {
169185
((duration.as_secs_f64() * 1000.0).round()) / 1000.0
170186
}
171187

188+
async fn run_with_ocr_limiter<T, F>(limiter: Option<Arc<Semaphore>>, future: F) -> Result<T>
189+
where
190+
F: Future<Output = Result<T>>,
191+
{
192+
if let Some(limiter) = limiter {
193+
let _permit = limiter
194+
.acquire_owned()
195+
.await
196+
.context("OCR limiter closed unexpectedly")?;
197+
future.await
198+
} else {
199+
future.await
200+
}
201+
}
202+
203+
#[cfg(test)]
204+
mod tests {
205+
use super::*;
206+
use std::sync::atomic::{AtomicUsize, Ordering};
207+
use tokio::time::{Duration, sleep};
208+
209+
#[test]
210+
fn run_with_ocr_limiter_caps_parallelism() {
211+
let runtime = tokio::runtime::Runtime::new().expect("runtime");
212+
runtime.block_on(async {
213+
let limiter = Arc::new(Semaphore::new(2));
214+
let active = Arc::new(AtomicUsize::new(0));
215+
let peak = Arc::new(AtomicUsize::new(0));
216+
217+
let mut tasks = Vec::new();
218+
for _ in 0..8 {
219+
let limiter = Some(limiter.clone());
220+
let active = active.clone();
221+
let peak = peak.clone();
222+
tasks.push(tokio::spawn(async move {
223+
run_with_ocr_limiter(limiter, async {
224+
let current = active.fetch_add(1, Ordering::SeqCst) + 1;
225+
loop {
226+
let seen = peak.load(Ordering::SeqCst);
227+
if current <= seen {
228+
break;
229+
}
230+
if peak
231+
.compare_exchange(seen, current, Ordering::SeqCst, Ordering::SeqCst)
232+
.is_ok()
233+
{
234+
break;
235+
}
236+
}
237+
sleep(Duration::from_millis(30)).await;
238+
active.fetch_sub(1, Ordering::SeqCst);
239+
Ok::<(), anyhow::Error>(())
240+
})
241+
.await
242+
}));
243+
}
244+
245+
for task in tasks {
246+
task.await.expect("join").expect("task result");
247+
}
248+
assert!(peak.load(Ordering::SeqCst) <= 2);
249+
});
250+
}
251+
}
252+
172253
#[cfg(feature = "internal-testing")]
173254
#[doc(hidden)]
174255
pub mod testing {

src/core/ocr.rs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,18 @@ pub(crate) async fn call_layout_parsing(
2121
client: &reqwest::Client,
2222
api_key: &str,
2323
payload: Value,
24+
) -> Result<Value> {
25+
call_layout_parsing_at_url(client, api_key, payload, API_URL).await
26+
}
27+
28+
pub(crate) async fn call_layout_parsing_at_url(
29+
client: &reqwest::Client,
30+
api_key: &str,
31+
payload: Value,
32+
api_url: &str,
2433
) -> Result<Value> {
2534
let response = client
26-
.post(API_URL)
35+
.post(api_url)
2736
.header("Authorization", format!("Bearer {api_key}"))
2837
.json(&payload)
2938
.send()
@@ -33,6 +42,11 @@ pub(crate) async fn call_layout_parsing(
3342
let status = response.status();
3443
let text = response.text().await?;
3544
if !status.is_success() {
45+
if status.as_u16() == 429 {
46+
return Err(anyhow!(
47+
"Z.AI OCR rate limit (HTTP 429). Lower --ocr-workers (e.g. 1) or reduce concurrent jobs sharing this API key."
48+
));
49+
}
3650
return Err(anyhow!(
3751
"Z.AI OCR request failed with HTTP {}: {}",
3852
status.as_u16(),
@@ -64,3 +78,51 @@ pub(crate) fn validate_layout_response(data: Value) -> Result<(String, Vec<Value
6478
let usage = data.get("usage").filter(|v| v.is_object()).cloned();
6579
Ok((markdown, layout_details, usage))
6680
}
81+
82+
#[cfg(test)]
83+
mod tests {
84+
use super::*;
85+
use serde_json::json;
86+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
87+
use tokio::net::TcpListener;
88+
89+
#[test]
90+
fn call_layout_parsing_429_returns_actionable_error() {
91+
let runtime = tokio::runtime::Runtime::new().expect("runtime");
92+
runtime.block_on(async {
93+
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
94+
let addr = listener.local_addr().expect("local addr");
95+
96+
let server = tokio::spawn(async move {
97+
let (mut stream, _) = listener.accept().await.expect("accept");
98+
let mut read_buf = [0u8; 4096];
99+
let _ = stream.read(&mut read_buf).await.expect("read request");
100+
let body = r#"{"error":{"code":"1302","message":"Rate limit reached for requests"}}"#;
101+
let response = format!(
102+
"HTTP/1.1 429 Too Many Requests\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
103+
body.len(),
104+
body
105+
);
106+
stream
107+
.write_all(response.as_bytes())
108+
.await
109+
.expect("write response");
110+
});
111+
112+
let client = reqwest::Client::new();
113+
let err = call_layout_parsing_at_url(
114+
&client,
115+
"test-key",
116+
json!({"model": "glm-ocr", "file": "data:application/pdf;base64,AA=="}),
117+
&format!("http://{addr}"),
118+
)
119+
.await
120+
.expect_err("expected 429 error")
121+
.to_string();
122+
123+
server.await.expect("server done");
124+
assert!(err.contains("Z.AI OCR rate limit (HTTP 429)"));
125+
assert!(err.contains("--ocr-workers"));
126+
});
127+
}
128+
}

src/main.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,19 @@ async fn run() -> Result<i32> {
5858
}
5959

6060
let workers = args.workers.min(pdfs.len()).max(1);
61-
eprintln!("Processing {} PDFs with {} workers...", pdfs.len(), workers);
61+
let ocr_workers = effective_ocr_workers(workers, args.ocr_workers);
62+
eprintln!(
63+
"Processing {} PDFs with {} workers (OCR concurrency: {})...",
64+
pdfs.len(),
65+
workers,
66+
ocr_workers
67+
);
6268

6369
let semaphore = Arc::new(Semaphore::new(workers));
70+
let ocr_semaphore = Arc::new(Semaphore::new(ocr_workers));
6471
let results = stream::iter(pdfs.into_iter().map(|pdf| {
6572
let permit_pool = semaphore.clone();
73+
let ocr_limiter = ocr_semaphore.clone();
6674
let output = args.output.clone();
6775
let env_file = args.env_file.clone();
6876
let progress = progress.clone();
@@ -75,7 +83,14 @@ async fn run() -> Result<i32> {
7583
};
7684
async move {
7785
let _permit = permit_pool.acquire_owned().await.expect("semaphore");
78-
let res = core::process_pdf(&pdf, &output, &env_file, options).await;
86+
let res = core::process_pdf_with_ocr_limiter(
87+
&pdf,
88+
&output,
89+
&env_file,
90+
options,
91+
Some(ocr_limiter),
92+
)
93+
.await;
7994
(pdf, res)
8095
}
8196
}))
@@ -111,6 +126,10 @@ fn stderr_is_tty() -> bool {
111126
std::io::stderr().is_terminal()
112127
}
113128

129+
fn effective_ocr_workers(workers: usize, ocr_workers: usize) -> usize {
130+
workers.min(ocr_workers).max(1)
131+
}
132+
114133
fn format_error_for_stderr(message: &str) -> String {
115134
if stderr_is_tty() {
116135
return message.replace("--overwrite", "\x1b[1;33m--overwrite\x1b[0m");
@@ -299,4 +318,11 @@ mod tests {
299318
callback(ProgressEvent::FigureDownloadFinished);
300319
callback(ProgressEvent::FigureDownloadFinished);
301320
}
321+
322+
#[test]
323+
fn effective_ocr_workers_caps_to_total_workers() {
324+
assert_eq!(effective_ocr_workers(32, 2), 2);
325+
assert_eq!(effective_ocr_workers(8, 32), 8);
326+
assert_eq!(effective_ocr_workers(1, 2), 1);
327+
}
302328
}

tests/cli_coverage.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ fn cli_batch_reports_failed_count() {
5656
"--env-file",
5757
env_file.to_str().unwrap(),
5858
"--workers",
59-
"2",
59+
"1",
60+
"--ocr-workers",
61+
"5",
6062
])
6163
.output()
6264
.unwrap();
@@ -66,4 +68,5 @@ fn cli_batch_reports_failed_count() {
6668
let stderr = String::from_utf8_lossy(&output.stderr);
6769
assert!(stdout.contains("Batch Complete processed: 0 failed: 2 figures: 0"));
6870
assert!(stderr.contains("failed:"));
71+
assert!(stderr.contains("OCR concurrency: 1"));
6972
}

tests/cli_existing_output.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ fn batch_existing_outputs_fail_before_env_or_ocr() {
4747
assert!(stderr.contains("a.pdf"));
4848
assert!(stderr.contains("b.pdf"));
4949
assert!(stderr.contains("Re-run with --overwrite"));
50+
assert!(stderr.contains("OCR concurrency:"));
5051

5152
assert!(!stderr.contains("ZAI_API_KEY"));
5253
assert!(!stdout.contains("\u{1b}["));

0 commit comments

Comments
 (0)