From 0f25a9c08d68f00e67c2bcfe512885d3569a9df6 Mon Sep 17 00:00:00 2001 From: Eduardo Flores Date: Tue, 7 Apr 2026 15:39:19 -0700 Subject: [PATCH] Introduce backend abstraction and Reqwest backend --- Cargo.toml | 12 +- examples/usage.rs | 2 +- src/backend/error.rs | 57 +++++++++ src/backend/mod.rs | 127 +++++++++++++++++++ src/backend/reqwest_backend.rs | 220 +++++++++++++++++++++++++++++++++ src/context.rs | 91 ++++++++++---- src/download_manager.rs | 147 +++++++++++++++++----- src/error.rs | 36 +++--- src/events.rs | 2 +- src/request.rs | 8 +- src/worker.rs | 145 +++++++++++----------- 11 files changed, 694 insertions(+), 153 deletions(-) create mode 100644 src/backend/error.rs create mode 100644 src/backend/mod.rs create mode 100644 src/backend/reqwest_backend.rs diff --git a/Cargo.toml b/Cargo.toml index 440f116..2b391f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,13 +6,23 @@ edition.workspace = true [lib] path = "src/download_manager.rs" +[features] +default = ["reqwest"] +reqwest = ["dep:reqwest"] + [dependencies] -reqwest.workspace = true +reqwest = { workspace = true, optional = true } + tokio.workspace = true tokio-stream.workspace = true tokio-util.workspace = true futures-core.workspace = true futures-util.workspace = true + +url.workspace = true +http.workspace = true +bytes.workspace = true + thiserror.workspace = true anyhow.workspace = true derive_builder.workspace = true diff --git a/examples/usage.rs b/examples/usage.rs index 5e3bb5f..7af734e 100644 --- a/examples/usage.rs +++ b/examples/usage.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; use bottles_download_manager::{DownloadManager, prelude::*}; use futures_util::StreamExt; -use reqwest::Url; +use url::Url; // use std::fmt::Debug; use tracing::{error, info}; use tracing_subscriber::EnvFilter; diff --git a/src/backend/error.rs b/src/backend/error.rs new file mode 100644 index 0000000..f928f1f --- /dev/null +++ b/src/backend/error.rs @@ -0,0 +1,57 @@ +use thiserror::Error; + +/// 
Transport-agnostic error type returned by [`super::DownloadBackend`] implementations. +/// +/// Backends are responsible for converting their native error types into this +/// enum so that the rest of the download pipeline can reason about +/// retryability without coupling itself to any particular HTTP library. +#[derive(Debug, Error)] +pub enum BackendError { + /// TCP / socket-level connection to the remote host failed. + #[error("Connection failed: {0}")] + Connect(String), + + /// The request or response exceeded a time limit. + #[error("Request timed out: {0}")] + Timeout(String), + + /// A well-formed request was sent but could not be completed. + #[error("Request error: {0}")] + Request(String), + + /// The server returned an HTTP 5xx status code. + #[error("HTTP server error {status}: {message}")] + ServerError { status: u16, message: String }, + + /// The operation was aborted via the download's [`tokio_util::sync::CancellationToken`]. + #[error("Cancelled")] + Cancelled, + + /// Any other transport-layer error that does not fit a specific category. + #[error("Network error: {0}")] + Other(String), +} + +impl BackendError { + /// Returns `true` if the scheduler should schedule a retry after this error. + /// + /// The following variants are considered **retryable** (transient failures + /// that have a reasonable chance of succeeding on a subsequent attempt): + /// + /// - [`BackendError::Connect`] + /// - [`BackendError::Timeout`] + /// - [`BackendError::Request`] + /// - [`BackendError::ServerError`] (HTTP 5xx) + /// + /// [`BackendError::Cancelled`] and [`BackendError::Other`] are **not** + /// retryable. + pub fn is_retryable(&self) -> bool { + matches!( + self, + BackendError::Connect(_) + | BackendError::Timeout(_) + | BackendError::Request(_) + | BackendError::ServerError { .. 
} + ) + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs new file mode 100644 index 0000000..b89ba37 --- /dev/null +++ b/src/backend/mod.rs @@ -0,0 +1,127 @@ +mod error; +#[cfg(feature = "reqwest")] +mod reqwest_backend; + +pub use error::BackendError; +#[cfg(feature = "reqwest")] +pub use reqwest_backend::ReqwestBackend; + +use crate::download::RemoteInfo; +use bytes::Bytes; +use futures_core::Stream; +use http::HeaderMap; +use std::{future::Future, pin::Pin}; +use tokio_util::sync::CancellationToken; +use url::Url; + +/// The response yielded by a successful [`DownloadBackend::fetch`] call. +/// +/// Carries the optional total size advertised by the server and a pinned +/// stream of raw byte chunks. +pub struct BackendResponse { + /// Total byte count as reported by the server's `Content-Length` header, + /// or `None` when the server did not include one. + pub content_length: Option, + + /// Stream of raw byte chunks arriving from the server. + /// + /// Each item is `Ok(Bytes)` on success or `Err(BackendError)` when the + /// underlying transport encounters an error mid-stream. + pub stream: Pin> + Send + 'static>>, +} + +/// Abstraction over the transport layer used to perform downloads. +/// +/// Implement this trait to provide a custom HTTP (or non-HTTP) backend. +/// The [`ReqwestBackend`] provided behind the `reqwest` Cargo feature is the +/// built-in default. +/// +/// # Object safety +/// +/// The trait is object-safe, so it can be used as `Arc<dyn DownloadBackend>`. +/// Both methods return boxed futures so that the trait can be used as a +/// dynamic dispatch target without requiring `async_trait`. +/// +/// # Cancellation +/// +/// Both methods receive a [`CancellationToken`] for the in-flight download. +/// Implementations should honour cancellation as early as possible (e.g. +/// aborting the TCP connection rather than waiting for a response). 
When a +/// probe is cancelled `probe_head` should return `None`; when a fetch is +/// cancelled before the response headers arrive `fetch` should return +/// `Err(BackendError::Cancelled)`. Chunk-level cancellation during streaming +/// is handled by the worker via `tokio::select!` so backends are not required +/// to poll the token inside the returned stream. +/// +/// # Example — minimal backend stub +/// +/// ```rust,ignore +/// use std::{future::Future, pin::Pin}; +/// use bytes::Bytes; +/// use http::HeaderMap; +/// use tokio_util::sync::CancellationToken; +/// use url::Url; +/// use bottles_download_manager::backend::{ +/// BackendError, BackendResponse, DownloadBackend, +/// }; +/// use bottles_download_manager::download::RemoteInfo; +/// +/// struct MyBackend; +/// +/// impl DownloadBackend for MyBackend { +/// fn probe_head<'a>( +/// &'a self, +/// _url: &'a Url, +/// _headers: &'a HeaderMap, +/// _cancel: &'a CancellationToken, +/// ) -> Pin> + Send + 'a>> { +/// Box::pin(async { None }) +/// } +/// +/// fn fetch<'a>( +/// &'a self, +/// _url: &'a Url, +/// _headers: &'a HeaderMap, +/// _cancel: &'a CancellationToken, +/// ) -> Pin> + Send + 'a>> { +/// Box::pin(async { +/// Err(BackendError::Other("not implemented".into())) +/// }) +/// } +/// } +/// ``` +pub trait DownloadBackend: Send + Sync + 'static { + /// Perform a best-effort `HEAD` probe and return metadata about the remote + /// resource. + /// + /// Implementations should return `None` when: + /// - the server does not support `HEAD`, + /// - the response indicates an error, + /// - or the cancellation token is triggered before a response arrives. + /// + /// Returning `None` is safe; the download will proceed without pre-flight + /// metadata and no [`crate::events::Event::Probed`] event will be emitted. 
+ fn probe_head<'a>( + &'a self, + url: &'a Url, + headers: &'a HeaderMap, + cancel: &'a CancellationToken, + ) -> Pin> + Send + 'a>>; + + /// Issue a `GET` request and return a streaming response. + /// + /// The implementation is responsible for: + /// - establishing the connection, + /// - reading and surfacing the `Content-Length` header (if present), + /// - returning a `Stream` of `Bytes` chunks. + /// + /// Returns `Err(BackendError::Cancelled)` if the cancellation token fires + /// before the response headers are received. Mid-stream cancellation is + /// managed by the caller. + fn fetch<'a>( + &'a self, + url: &'a Url, + headers: &'a HeaderMap, + cancel: &'a CancellationToken, + ) -> Pin> + Send + 'a>>; +} diff --git a/src/backend/reqwest_backend.rs b/src/backend/reqwest_backend.rs new file mode 100644 index 0000000..5ba3878 --- /dev/null +++ b/src/backend/reqwest_backend.rs @@ -0,0 +1,220 @@ +use std::{future::Future, pin::Pin}; + +use bytes::Bytes; +use futures_util::StreamExt as _; +use http::HeaderMap; +use reqwest::{Client, Method}; +use tokio_util::sync::CancellationToken; +use tracing::{debug, trace}; +use url::Url; + +use super::{BackendError, BackendResponse, DownloadBackend}; +use crate::download::RemoteInfo; + +impl From for BackendError { + fn from(e: reqwest::Error) -> Self { + if e.is_timeout() { + BackendError::Timeout(e.to_string()) + } else if e.is_connect() { + BackendError::Connect(e.to_string()) + } else if e.is_request() { + BackendError::Request(e.to_string()) + } else if let Some(status) = e.status() { + if status.is_server_error() { + BackendError::ServerError { + status: status.as_u16(), + message: e.to_string(), + } + } else { + BackendError::Other(e.to_string()) + } + } else { + BackendError::Other(e.to_string()) + } + } +} + +/// A [`DownloadBackend`] backed by [`reqwest`]. +/// +/// Uses a shared [`reqwest::Client`] internally for connection pooling and +/// keep-alive reuse across downloads. 
+/// +/// # Construction +/// +/// | Method | When to use | +/// |---|---| +/// | [`ReqwestBackend::default()`] | Quick start — plain client, no extra config. | +/// | [`ReqwestBackend::from_client`] | Reuse an existing client with custom TLS, proxy, or timeout settings. | +/// +/// # Example +/// +/// ```rust,ignore +/// use reqwest::ClientBuilder; +/// use std::time::Duration; +/// use bottles_download_manager::{DownloadManager, backend::ReqwestBackend}; +/// +/// let client = ClientBuilder::new() +/// .timeout(Duration::from_secs(30)) +/// .user_agent("my-app/1.0") +/// .build() +/// .unwrap(); +/// +/// let manager = DownloadManager::with_backend(ReqwestBackend::from_client(client)); +/// ``` +pub struct ReqwestBackend { + client: Client, +} + +impl Default for ReqwestBackend { + fn default() -> Self { + ReqwestBackend { + client: Client::new(), + } + } +} + +impl ReqwestBackend { + /// Wrap an existing [`reqwest::Client`]. + /// + /// Prefer this over [`Default`] when you need custom TLS roots, a proxy, + /// connection timeouts, or a specific `User-Agent`. + pub fn from_client(client: Client) -> Self { + ReqwestBackend { client } + } + + /// Return a reference to the underlying [`reqwest::Client`]. + pub fn client(&self) -> &Client { + &self.client + } +} + +impl std::fmt::Debug for ReqwestBackend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ReqwestBackend").finish_non_exhaustive() + } +} + +impl DownloadBackend for ReqwestBackend { + /// Send an HTTP `HEAD` request and surface the response metadata. + /// + /// Cancellation is handled cooperatively: if the token fires before the + /// server responds, the future returns `None` immediately and the + /// in-flight request is dropped. 
+ fn probe_head<'a>( + &'a self, + url: &'a Url, + headers: &'a HeaderMap, + cancel: &'a CancellationToken, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + use reqwest::header as rh; + + debug!(%url, "ReqwestBackend: HEAD probe"); + + let req = self + .client + .request(Method::HEAD, url.as_str()) + .headers(headers.clone()) + .send(); + + let resp = tokio::select! { + biased; + _ = cancel.cancelled() => { + debug!(%url, "HEAD probe cancelled before response"); + return None; + } + result = req => match result { + Ok(r) => match r.error_for_status() { + Ok(r) => r, + Err(e) => { + debug!(%url, error = %e, "HEAD probe returned error status"); + return None; + } + }, + Err(e) => { + debug!(%url, error = %e, "HEAD probe request failed"); + return None; + } + }, + }; + + let content_length = resp.content_length(); + trace!(%url, content_length = ?content_length, "HEAD response received"); + + let h = resp.headers(); + + Some(RemoteInfo { + content_length, + accept_ranges: h + .get(rh::ACCEPT_RANGES) + .and_then(|v| v.to_str().ok()) + .map(str::to_owned), + etag: h + .get(rh::ETAG) + .and_then(|v| v.to_str().ok()) + .map(str::to_owned), + last_modified: h + .get(rh::LAST_MODIFIED) + .and_then(|v| v.to_str().ok()) + .map(str::to_owned), + content_type: h + .get(rh::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .map(str::to_owned), + }) + }) + } + + /// Send an HTTP `GET` request and return a streaming [`BackendResponse`]. + /// + /// - If the cancellation token fires before the response headers arrive, + /// returns `Err(BackendError::Cancelled)`. + /// - The returned `stream` yields `Ok(Bytes)` chunks as they arrive from + /// the server, or `Err(BackendError)` if the connection drops mid-stream. + /// - Chunk-level cancellation during streaming is handled externally by + /// the worker via `tokio::select!`. 
+ fn fetch<'a>( + &'a self, + url: &'a Url, + headers: &'a HeaderMap, + cancel: &'a CancellationToken, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + debug!(%url, "ReqwestBackend: GET fetch"); + + let req = self + .client + .request(Method::GET, url.as_str()) + .headers(headers.clone()) + .send(); + + let resp = tokio::select! { + biased; + _ = cancel.cancelled() => { + debug!(%url, "GET fetch cancelled before response headers"); + return Err(BackendError::Cancelled); + } + result = req => { + result + .map_err(BackendError::from)? + .error_for_status() + .map_err(BackendError::from)? + } + }; + + let content_length = resp.content_length(); + debug!(%url, content_length = ?content_length, "GET response headers received"); + + // Map each reqwest chunk error to a BackendError so the stream + // item type matches the trait's contract. + let stream = resp + .bytes_stream() + .map(|result: Result| result.map_err(BackendError::from)); + + Ok(BackendResponse { + content_length, + stream: Box::pin(stream), + }) + }) + } +} diff --git a/src/context.rs b/src/context.rs index 34a9bec..6a98677 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,4 +1,3 @@ -use reqwest::Client; use std::sync::{ Arc, atomic::{AtomicU64, AtomicUsize, Ordering}, @@ -7,73 +6,113 @@ use tokio::sync::Semaphore; use tokio_util::sync::CancellationToken; use tracing::info; -use crate::{DownloadManagerConfig, events::EventBus}; +use crate::{DownloadManagerConfig, backend::DownloadBackend, events::EventBus}; /// Unique identifier for a download; monotonically increasing u64. pub type DownloadID = u64; /// Shared runtime context for coordinating downloads. Internal to the crate. -/// Holds the concurrency semaphore, root cancellation token, HTTP client, -/// atomic counters, and the global [DownloadEvent] broadcast sender. -/// Cloned and shared across scheduler and workers. 
-#[derive(Debug)] +/// +/// Holds the concurrency semaphore, root cancellation token, pluggable download +/// backend, atomic counters, and the global [`Event`](crate::events::Event) +/// broadcast sender. Cloned (via [`Arc`]) and shared across the scheduler and +/// every spawned worker task. pub(crate) struct Context { - /// Semaphore limiting concurrent active downloads. + /// Semaphore limiting the number of concurrently active downloads. pub semaphore: Arc, - /// Root cancellation token; children inherit via [Context::child_token()]. + + /// Root cancellation token. Cancelling it cascades to all child tokens, + /// cooperatively stopping every in-flight download. pub cancel_root: CancellationToken, - /// Shared reqwest client reused across attempts. - pub client: Client, - // Counters - /// Monotonic counter for generating DownloadID values. + /// Pluggable transport backend (e.g. `ReqwestBackend`). + /// + /// Wrapped in an [`Arc`] so it can be shared cheaply across tasks without + /// requiring the backend to be `Clone`. + pub backend: Arc, + + /// Monotonic counter used to generate [`DownloadID`] values. Starts at 1. pub id_counter: AtomicU64, - /// Number of currently active (running) downloads. + + /// Number of currently active (running, not queued) downloads. pub active: AtomicUsize, - /// Configured maximum concurrency. Not automatically updated if semaphore changes. + + /// Configured maximum concurrency. Informational; the semaphore is the + /// true enforcement mechanism. pub max_concurrent: AtomicUsize, - /// Global [DownloadEvent] broadcaster (buffered). Slow subscribers may miss events. + /// Global [`Event`](crate::events::Event) broadcaster (bounded buffer, 1 024 + /// slots). Slow subscribers may miss events; use + /// [`EventBus::events`](crate::events::EventBus::events) for a + /// lagged-message-safe stream. 
pub events: EventBus, } +// Manual Debug impl because `Arc<dyn DownloadBackend>` is not `Debug` +// (the trait does not require it) and we don't want to force that bound on +// every implementor. +impl std::fmt::Debug for Context { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Context") + .field( + "max_concurrent", + &self.max_concurrent.load(Ordering::Relaxed), + ) + .field("active", &self.active.load(Ordering::Relaxed)) + .field("id_counter", &self.id_counter.load(Ordering::Relaxed)) + .finish_non_exhaustive() + } +} + impl Context { - /// Create a new shared Context. - /// - Initializes the semaphore with `max_concurrent` permits. - /// - Creates a root [CancellationToken] and a broadcast channel (capacity 1024). - /// - Constructs a shared [reqwest::Client]. - pub fn new(config: DownloadManagerConfig, cancel_root: CancellationToken) -> Arc { + /// Create a new shared [`Context`]. + /// + /// - Initialises the semaphore with `config.max_concurrent` permits. + /// - Stores the provided `cancel_root` token (the manager's shutdown token + /// child). + /// - Stores the caller-provided, already [`Arc`]-wrapped `backend` as-is. + /// - Creates a fresh broadcast channel for global events (capacity 1 024). + pub fn new( + config: DownloadManagerConfig, + cancel_root: CancellationToken, + backend: Arc, + ) -> Arc { let ctx = Arc::new(Self { semaphore: Arc::new(Semaphore::new(config.max_concurrent)), max_concurrent: AtomicUsize::new(config.max_concurrent), cancel_root, active: AtomicUsize::new(0), id_counter: AtomicU64::new(1), - client: Client::new(), + backend, events: EventBus::new(), }); + info!( max_concurrent = config.max_concurrent, "Context initialized" ); + ctx } - /// Atomically generate the next [DownloadID] (relaxed ordering). - /// Unique within the lifetime of this Context; starts at 1. + /// Atomically generate the next [`DownloadID`] (relaxed ordering). 
+ /// + /// IDs are unique within the lifetime of this [`Context`] and start at 1. #[inline] pub fn next_id(&self) -> DownloadID { self.id_counter.fetch_add(1, Ordering::Relaxed) } - /// Create a child [CancellationToken] tied to the manager's root token. - /// Cancelling the root cascades to all children. + /// Create a child [`CancellationToken`] tied to the manager's root token. + /// + /// Cancelling the root (e.g. on [`DownloadManager::shutdown`]) cascades to + /// all children. #[inline] pub fn child_token(&self) -> CancellationToken { self.cancel_root.child_token() } - /// Cancel the root token, cooperatively cancelling all in-flight downloads. + /// Cancel the root token, cooperatively stopping all in-flight downloads. pub fn cancel_all(&self) { self.cancel_root.cancel(); } diff --git a/src/download_manager.rs b/src/download_manager.rs index 67b4ddc..79b41ff 100644 --- a/src/download_manager.rs +++ b/src/download_manager.rs @@ -1,3 +1,4 @@ +pub mod backend; mod context; mod download; mod error; @@ -8,15 +9,20 @@ mod worker; pub mod prelude { pub use crate::{ + backend::{BackendError, DownloadBackend}, context::DownloadID, download::{Download, DownloadResult}, error::DownloadError, events::{Event, Progress}, request::Request, }; + + #[cfg(feature = "reqwest")] + pub use crate::backend::ReqwestBackend; } use crate::{ + backend::DownloadBackend, context::Context, request::RequestBuilder, scheduler::{Scheduler, SchedulerCmd}, @@ -24,7 +30,6 @@ use crate::{ use derive_builder::Builder; use futures_core::Stream; use prelude::*; -use reqwest::Url; use std::{ path::Path, sync::{Arc, atomic::Ordering}, @@ -32,17 +37,40 @@ use std::{ use tokio::sync::{broadcast, mpsc}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tracing::{debug, info, instrument, trace, warn}; +use url::Url; /// Entry point for scheduling, observing, and cancelling downloads. /// -/// Behavior +/// # Behavior /// - Enforces a global concurrency limit across all downloads. 
-/// - Publishes global DownloadEvent notifications and exposes per-download streams via Download. +/// - Publishes global [`Event`] notifications and exposes per-download streams +/// via [`Download`]. +/// +/// # Backend injection +/// +/// By default (when the `reqwest` Cargo feature is enabled) the manager uses +/// [`ReqwestBackend`]. To supply a custom backend implement [`DownloadBackend`] +/// and use [`DownloadManager::with_backend`] or [`DownloadManager::new`]. +/// +/// ```rust,ignore +/// use bottles_download_manager::{DownloadManager, DownloadManagerConfig}; +/// use my_crate::MyBackend; +/// +/// // Custom backend, default config +/// let manager = DownloadManager::with_backend(MyBackend::new()); /// -/// Notes -/// - Events are delivered over a broadcast channel with a bounded buffer; slow consumers can miss events. -/// - Use events() to get a fallible-safe stream that drops lagged messages. -/// - Use shutdown() for a graceful stop: it cancels all work and waits for workers to finish. +/// // Custom backend + custom config +/// let config = DownloadManagerConfig::default(); +/// let manager = DownloadManager::new(config, MyBackend::new()); +/// ``` +/// +/// # Notes +/// - Events are delivered over a broadcast channel with a bounded buffer; slow +/// consumers can miss events. +/// - Use [`DownloadManager::events`] to get a stream that drops lagged messages +/// gracefully. +/// - Use [`DownloadManager::shutdown`] for a graceful stop: it cancels all work +/// and waits for workers to finish. pub struct DownloadManager { scheduler_tx: mpsc::Sender, ctx: Arc, @@ -50,6 +78,9 @@ pub struct DownloadManager { shutdown_token: CancellationToken, } +/// `Default` is only available when the `reqwest` feature is enabled, because +/// it requires a concrete backend to be chosen automatically. 
+#[cfg(feature = "reqwest")] impl Default for DownloadManager { #[instrument(level = "debug")] fn default() -> Self { @@ -58,16 +89,48 @@ impl Default for DownloadManager { } impl DownloadManager { - /// Create a new builder for DownloadManager. + /// Create a manager with a custom [`DownloadManagerConfig`] and the default + /// [`ReqwestBackend`]. /// - /// You must set a positive max_concurrent on the builder before build(). - /// If you want a sensible default quickly, see [DownloadManager::default()]. + /// Only available when the `reqwest` Cargo feature is enabled (it is on by + /// default). For a fully custom backend use [`DownloadManager::new`]. + #[cfg(feature = "reqwest")] #[instrument(level = "info", skip(config))] pub fn with_config(config: DownloadManagerConfig) -> DownloadManager { + use crate::backend::ReqwestBackend; + DownloadManager::new(config, ReqwestBackend::default()) + } + + /// Create a manager with the default [`DownloadManagerConfig`] and a custom + /// backend. + /// + /// This is a convenience wrapper around [`DownloadManager::new`] for when + /// you only need to override the backend. + /// + /// ```rust,ignore + /// let manager = DownloadManager::with_backend(MyBackend::new()); + /// ``` + #[instrument(level = "info", skip(backend))] + pub fn with_backend(backend: B) -> DownloadManager { + DownloadManager::new(DownloadManagerConfig::default(), backend) + } + + /// Create a manager with both a custom [`DownloadManagerConfig`] and a + /// custom backend. + /// + /// This is the most general constructor; all other constructors delegate + /// here. + /// + /// # Panics + /// + /// Panics if called outside a Tokio runtime: the scheduler task is spawned + /// here via `TaskTracker::spawn`. A `max_concurrent` of 0 is rejected earlier, by the config builder. 
+ #[instrument(level = "info", skip(config, backend))] + pub fn new(config: DownloadManagerConfig, backend: B) -> DownloadManager { let (cmd_tx, cmd_rx) = mpsc::channel(1024); let tracker = TaskTracker::new(); let shutdown_token = CancellationToken::new(); - let ctx = Context::new(config, shutdown_token.child_token()); + let ctx = Context::new(config, shutdown_token.child_token(), Arc::new(backend)); let scheduler = Scheduler::new(shutdown_token.clone(), ctx.clone(), tracker.clone(), cmd_rx); @@ -79,6 +142,7 @@ impl DownloadManager { }; tracker.spawn(async move { scheduler.run().await }); + let max = ctx.max_concurrent.load(Ordering::Relaxed); info!( max_concurrent = max, @@ -90,9 +154,12 @@ impl DownloadManager { /// Start a download with default request settings. /// - /// - Returns a [Download] handle which is also a Future yielding [DownloadResult] or [DownloadError]. - /// - You can stream progress and per-download events from the returned handle. - /// - Cancellation: call [Download::cancel()] on the handle, or [DownloadManager::cancel(id)]. + /// - Returns a [`Download`] handle which is also a `Future` yielding + /// [`DownloadResult`] or [`DownloadError`]. + /// - You can stream progress and per-download events from the returned + /// handle. + /// - Cancellation: call [`Download::cancel`] on the handle or + /// [`DownloadManager::cancel`] with its ID. #[instrument(level = "info", skip(self, destination), fields(url = %url))] pub fn download(&self, url: Url, destination: impl AsRef) -> anyhow::Result { self.download_builder() @@ -101,9 +168,11 @@ impl DownloadManager { .start() } - /// Create a [RequestBuilder] to customize a download (headers, retries, overwrite, callbacks). + /// Create a [`RequestBuilder`] to customise a download (headers, retries, + /// overwrite, callbacks). /// - /// Use this if you need non-default behavior or want to hook into progress/event callbacks before start(). 
+ /// Use this if you need non-default behaviour or want to hook into + /// progress / event callbacks before `start()`. #[instrument(level = "debug", skip(self))] pub fn download_builder(&self) -> RequestBuilder { Request::builder(self) @@ -112,7 +181,8 @@ impl DownloadManager { /// Best-effort attempt to request cancellation for a download by ID. /// /// - No-op if the job is already finished or missing. - /// - Returns an error if the internal command channel is unavailable or the buffer is full. + /// - Returns an error if the internal command channel is unavailable or the + /// buffer is full. #[instrument(level = "info", skip(self), fields(id = id))] pub fn try_cancel(&self, id: DownloadID) -> anyhow::Result<()> { match self.scheduler_tx.try_send(SchedulerCmd::Cancel { id }) { @@ -147,7 +217,8 @@ impl DownloadManager { /// Number of currently active (running) downloads. /// - /// Does not include queued or delayed retries. Reflects active semaphore permits. + /// Does not include queued or delayed retries. Reflects active semaphore + /// permits. #[instrument(level = "trace", skip(self))] pub fn active_downloads(&self) -> usize { let n = self.ctx.active.load(Ordering::Relaxed); @@ -157,31 +228,35 @@ impl DownloadManager { /// Cancel all queued and in-flight downloads managed by this instance. /// - /// This triggers cooperative cancellation for workers and removes partial files. + /// This triggers cooperative cancellation for workers and removes partial + /// files. #[instrument(level = "info", skip(self))] pub fn cancel_all(&self) { info!("Cancelling all downloads"); self.ctx.cancel_all(); } - /// Return a child [CancellationToken] tied to the manager's root token. + /// Return a child [`CancellationToken`] tied to the manager's root token. #[instrument(level = "trace", skip(self))] pub fn child_token(&self) -> CancellationToken { self.ctx.child_token() } - /// Subscribe to all [DownloadEvent] notifications across the manager. 
+ /// Subscribe to all [`Event`] notifications across the manager. /// - /// The underlying broadcast channel has a bounded buffer (1024). Slow consumers may lag and - /// miss events. Consider using [DownloadManager::events()] for a stream that skips lagged messages gracefully. + /// The underlying broadcast channel has a bounded buffer (1 024). Slow + /// consumers may lag and miss events. Consider using + /// [`DownloadManager::events`] for a stream that skips lagged messages + /// gracefully. #[instrument(level = "debug", skip(self))] pub fn subscribe(&self) -> broadcast::Receiver { self.ctx.events.subscribe() } - /// A fallible-safe stream of global [DownloadEvent] values. + /// A fallible-safe stream of global [`Event`] values. /// - /// Internally wraps the broadcast receiver and filters out lagged/closed errors. + /// Internally wraps the broadcast receiver and filters out lagged / closed + /// errors. #[instrument(level = "debug", skip(self))] pub fn events(&self) -> impl Stream + 'static { self.ctx.events.events() @@ -189,9 +264,12 @@ impl DownloadManager { /// Gracefully stop the manager. /// - /// - Cancels all in-flight work ([DownloadManager::cancel_all()]). - /// - Prevents new tasks from being scheduled and waits for all worker tasks to finish. - /// Call this before dropping the manager if you need deterministic teardown. + /// - Cancels all in-flight work ([`DownloadManager::cancel_all`]). + /// - Prevents new tasks from being scheduled and waits for all worker tasks + /// to finish. + /// + /// Call this before dropping the manager if you need deterministic + /// teardown. #[instrument(level = "info", skip(self))] pub async fn shutdown(&self) { info!("Shutting down DownloadManager"); @@ -204,6 +282,9 @@ impl DownloadManager { #[derive(Builder)] pub struct DownloadManagerConfig { + /// Maximum number of downloads that may run concurrently. + /// + /// Must be greater than 0. Defaults to 3. 
#[builder(default = 3, setter(custom))] max_concurrent: usize, } @@ -216,14 +297,16 @@ impl Default for DownloadManagerConfig { } impl DownloadManagerConfigBuilder { + /// Set the maximum concurrent downloads. + /// + /// Returns an error if `value` is 0. #[instrument(level = "debug", skip(self))] fn max_concurrent(&mut self, value: usize) -> anyhow::Result<&mut Self> { - let value = (value != 0).then(|| value).ok_or(anyhow::anyhow!( - "Max concurrent downloads must be set and greater than 0" - ))?; + let value = (value != 0) + .then_some(value) + .ok_or_else(|| anyhow::anyhow!("Max concurrent downloads must be greater than 0"))?; self.max_concurrent = Some(value); - Ok(self) } } diff --git a/src/error.rs b/src/error.rs index 7cbf9e4..bea82ab 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,22 +2,35 @@ use std::path::PathBuf; use thiserror::Error; use tracing::instrument; -#[derive(Error, Debug)] +use crate::backend::BackendError; + +#[derive(Debug, Error)] pub enum DownloadError { + /// A transport-level error reported by the active [`crate::backend::DownloadBackend`]. + /// + /// Retryability is determined by [`BackendError::is_retryable`] so the + /// scheduler does not need to know which backend is in use. #[error("Network error: {0}")] - Network(#[from] reqwest::Error), + Network(BackendError), + #[error("I/O error: {0}")] Io(#[from] std::io::Error), + #[error("Download was cancelled")] Cancelled, + #[error("Retry limit exceeded: {last_error}")] RetriesExhausted { last_error: Box }, + #[error("Download manager has been shut down")] ManagerShutdown, + #[error("File already exists: {path}")] FileExists { path: PathBuf }, + #[error("Invalid URL: {0}")] InvalidUrl(String), + #[error("Unknown error: {0}")] Unknown(String), } @@ -25,21 +38,16 @@ pub enum DownloadError { impl DownloadError { /// Classify whether this error should be retried by the scheduler. /// - /// Returns true for transient reqwest errors (timeout, connect, request) and HTTP 5xx. 
- /// If the HTTP status is unavailable, the error is treated as retryable by default. - /// Returns false for Cancelled, Io, and other non-transient variants. + /// For [`DownloadError::Network`] variants the decision is delegated to + /// [`BackendError::is_retryable`], keeping retry logic co-located with the + /// error classification that backends are responsible for. + /// + /// [`DownloadError::Cancelled`] and [`DownloadError::Io`] are never + /// retried; all other variants are treated as terminal by default. #[instrument(level = "trace", skip(self))] pub fn is_retryable(&self) -> bool { match self { - Self::Network(network_err) => { - network_err.is_timeout() - || network_err.is_connect() - || network_err.is_request() - || network_err - .status() - .map(|status_code| status_code.is_server_error()) - .unwrap_or(true) - } + Self::Network(backend_err) => backend_err.is_retryable(), Self::Cancelled | Self::Io(_) => false, _ => false, } diff --git a/src/events.rs b/src/events.rs index c33fd91..258b921 100644 --- a/src/events.rs +++ b/src/events.rs @@ -1,10 +1,10 @@ use crate::DownloadID; use crate::download::RemoteInfo; -use reqwest::Url; use std::path::PathBuf; use std::time::{Duration, Instant}; use tokio::sync::broadcast; use tracing::{debug, warn}; +use url::Url; #[derive(Debug, Clone)] pub(crate) struct EventBus(broadcast::Sender); diff --git a/src/request.rs b/src/request.rs index d0b3a15..b361794 100644 --- a/src/request.rs +++ b/src/request.rs @@ -3,10 +3,7 @@ use crate::{ scheduler::SchedulerCmd, }; use derive_builder::Builder; -use reqwest::{ - Url, - header::{HeaderMap, IntoHeaderName}, -}; +use http::{HeaderMap, header::IntoHeaderName}; use std::{ path::{Path, PathBuf}, sync::Arc, @@ -14,6 +11,7 @@ use std::{ use tokio::sync::{mpsc, oneshot, watch}; use tokio_util::sync::CancellationToken; use tracing::{debug, instrument, trace}; +use url::Url; /// Immutable description of a single download request. 
/// @@ -157,7 +155,7 @@ impl RequestBuilder { /// Convenience for setting the User-Agent header. pub fn user_agent(self, user_agent: impl AsRef) -> Self { - self.header(reqwest::header::USER_AGENT, user_agent) + self.header(http::header::USER_AGENT, user_agent) } /// Control whether an existing destination file may be overwritten. diff --git a/src/worker.rs b/src/worker.rs index 3140645..27510a9 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,12 +1,12 @@ use std::sync::Arc; -use reqwest::{Client, Method}; +use futures_util::StreamExt as _; use tokio::{fs::File, io::AsyncWriteExt, sync::mpsc}; -use tracing::{debug, error, info, instrument, trace, warn}; +use tracing::{debug, error, info, instrument, warn}; use crate::{ + backend::BackendError, context::{Context, DownloadID}, - download::RemoteInfo, error::DownloadError, events::{Event, Progress}, prelude::DownloadResult, @@ -26,7 +26,8 @@ pub(crate) async fn run( ctx: Arc, worker_tx: mpsc::Sender, ) { - let result = attempt_download(request.as_ref(), ctx.client.clone()).await; + let result = attempt_download(request.as_ref(), &ctx).await; + if result.is_ok() { info!(id = %request.id(), "Download attempt finished successfully"); } else { @@ -41,55 +42,26 @@ pub(crate) async fn run( .await; } -#[instrument(level = "debug", skip(request, client), fields(id = %request.id(), url = %request.url()))] -pub(crate) async fn probe_head(request: &Request, client: &Client) -> Option { - use reqwest::header; - debug!("Probing remote with HTTP HEAD"); - let req = client - .request(Method::HEAD, request.url().as_ref()) - .headers(request.config().headers().clone()) - .send(); - - let resp = tokio::select! 
{ - resp = req => resp.ok()?.error_for_status().ok()?, - _ = request.cancel_token.cancelled() => return None, - }; - - let headers = resp.headers(); - let content_length = resp.content_length(); - trace!(content_length = ?content_length, "Got HEAD response"); - let accept_ranges = headers - .get(header::ACCEPT_RANGES) - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()); - let etag = headers - .get(header::ETAG) - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()); - let last_modified = headers - .get(header::LAST_MODIFIED) - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()); - let content_type = headers - .get(header::CONTENT_TYPE) - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()); - - Some(RemoteInfo { - content_length, - accept_ranges, - etag, - last_modified, - content_type, - }) -} - -#[instrument(level = "info", skip(request, client), fields(id = %request.id(), url = %request.url(), destination = ?request.destination()))] +#[instrument( + level = "info", + skip(request, ctx), + fields(id = %request.id(), url = %request.url(), destination = ?request.destination()) +)] pub(crate) async fn attempt_download( request: &Request, - client: Client, + ctx: &Context, ) -> Result { - if let Some(info) = probe_head(request, &client).await { + // The backend handles cancellation internally: if the token fires before + // the server responds it returns None, and no Probed event is emitted. 
+ if let Some(info) = ctx + .backend + .probe_head( + request.url(), + request.config().headers(), + &request.cancel_token, + ) + .await + { request.emit(Event::Probed { id: request.id(), info, @@ -100,25 +72,39 @@ pub(crate) async fn attempt_download( tokio::fs::create_dir_all(parent).await?; } if request.destination().exists() && !request.config().overwrite() { - warn!(destination = ?request.destination(), "Destination exists and overwrite=false; failing"); + warn!( + destination = ?request.destination(), + "Destination exists and overwrite=false; failing" + ); return Err(DownloadError::FileExists { path: request.destination().to_path_buf(), }); } - let req = client - .request(Method::GET, request.url().as_ref()) - .headers(request.config().headers().clone()) - .send(); - - let mut response = tokio::select! { - resp = req => Ok(resp?.error_for_status()?), - _ = request.cancel_token.cancelled() => Err(DownloadError::Cancelled), - }?; - let total_bytes = response.content_length(); - debug!(total_bytes = ?total_bytes, "Server accepted download"); - + // + // If the cancellation token fires before response headers are received the + // backend returns Err(BackendError::Cancelled), which we surface as + // DownloadError::Cancelled rather than DownloadError::Network so the + // scheduler does not retry it. 
+ let backend_resp = ctx + .backend + .fetch( + request.url(), + request.config().headers(), + &request.cancel_token, + ) + .await + .map_err(|e| match e { + BackendError::Cancelled => DownloadError::Cancelled, + other => DownloadError::Network(other), + })?; + + let total_bytes = backend_resp.content_length; + debug!(total_bytes = ?total_bytes, "Server accepted GET request"); + + let mut stream = backend_resp.stream; let mut file = File::create(request.destination()).await?; + request.emit(Event::Started { id: request.id(), url: request.url().clone(), @@ -127,29 +113,37 @@ pub(crate) async fn attempt_download( }); let mut progress = Progress::new(total_bytes); + loop { tokio::select! { _ = request.cancel_token.cancelled() => { - warn!(destination = ?request.destination(), "Cancellation received; cleaning up partial file"); + warn!( + destination = ?request.destination(), + "Cancellation received during streaming; removing partial file" + ); drop(file); tokio::fs::remove_file(request.destination()).await?; return Err(DownloadError::Cancelled); } - chunk = response.chunk() => { + chunk = stream.next() => { match chunk { - Ok(Some(chunk)) => { - file.write_all(&chunk).await?; - if progress.update(chunk.len() as u64) { + Some(Ok(bytes)) => { + file.write_all(&bytes).await?; + if progress.update(bytes.len() as u64) { request.update_progress(progress); } } - Ok(None) => break, - Err(e) => { - error!(error = %e, destination = ?request.destination(), "Error while reading response chunk; removing partial file"); + Some(Err(e)) => { + error!( + error = %e, + destination = ?request.destination(), + "Error reading response chunk; removing partial file" + ); drop(file); tokio::fs::remove_file(request.destination()).await?; - return Err(e.into()); + return Err(DownloadError::Network(e)); } + None => break, } } } @@ -158,7 +152,12 @@ pub(crate) async fn attempt_download( progress.force_update(); let _ = request.update_progress(progress); file.sync_all().await?; - 
info!(destination = ?request.destination(), bytes = progress.bytes_downloaded(), "Download completed successfully"); + + info!( + destination = ?request.destination(), + bytes = progress.bytes_downloaded(), + "Download completed successfully" + ); Ok(DownloadResult { path: request.destination().to_path_buf(),