diff --git a/crates/barbacane-wasm/src/http_client.rs b/crates/barbacane-wasm/src/http_client.rs index 9416d99..135df0d 100644 --- a/crates/barbacane-wasm/src/http_client.rs +++ b/crates/barbacane-wasm/src/http_client.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use std::hash::{Hash, Hasher}; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; @@ -116,6 +116,11 @@ impl HttpClient { // Disable redirect following: a permitted host could otherwise 3xx // to an internal/metadata target, bypassing the SSRF guard below. .redirect(reqwest::redirect::Policy::none()) + // Enforce the SSRF guard at connect-time resolution to close the + // DNS-rebinding TOCTOU window. + .dns_resolver(Arc::new(GuardedResolver { + allow_internal: config.allow_internal_egress, + })) .build() .map_err(HttpClientError::BuildError)?; @@ -165,7 +170,11 @@ impl HttpClient { .pool_idle_timeout(self.base_config.pool_idle_timeout) .connect_timeout(self.base_config.connect_timeout) .timeout(self.base_config.default_timeout) - .redirect(reqwest::redirect::Policy::none()); + .redirect(reqwest::redirect::Policy::none()) + // Enforce the SSRF guard at connect-time resolution (DNS-rebinding). + .dns_resolver(Arc::new(GuardedResolver { + allow_internal: self.base_config.allow_internal_egress, + })); // Add client certificate (mTLS) if let (Some(cert_path), Some(key_path)) = (&tls_config.client_cert, &tls_config.client_key) @@ -544,6 +553,54 @@ fn is_forbidden_outbound_header(name: &str) -> bool { FORBIDDEN.iter().any(|h| name.eq_ignore_ascii_case(h)) } +/// A reqwest DNS resolver that enforces the SSRF guard at the moment of +/// resolution. reqwest connects to exactly the addresses this returns and does +/// not resolve again, which closes the DNS-rebinding TOCTOU window: a hostile +/// resolver cannot answer with a public IP for the pre-flight [`ssrf_guard`] +/// check and then rebind to an internal IP when the connection is made, because +/// the connect-time resolution is this one and it filters internal addresses. +#[derive(Debug, Clone)] +struct GuardedResolver { + allow_internal: bool, +} + +/// Keep only the addresses a plugin is permitted to connect to. With egress +/// disallowed, internal/loopback/link-local/metadata addresses are dropped; an +/// empty result means every resolved address was internal (the caller treats +/// that as blocked). +fn permitted_addrs( + resolved: impl Iterator, + allow_internal: bool, +) -> Vec { + if allow_internal { + resolved.collect() + } else { + resolved.filter(|a| !ip_is_internal(&a.ip())).collect() + } +} + +impl reqwest::dns::Resolve for GuardedResolver { + fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving { + let allow_internal = self.allow_internal; + Box::pin(async move { + let host = name.as_str().to_string(); + // The port is irrelevant to resolution; reqwest applies the real + // target port to the returned IPs. + let resolved = tokio::net::lookup_host((host.as_str(), 0u16)).await?; + let addrs = permitted_addrs(resolved, allow_internal); + if addrs.is_empty() { + let err: Box = format!( + "no permitted address for host '{host}': all resolved \ + addresses are internal, or none were returned" + ) + .into(); + return Err(err); + } + Ok(Box::new(addrs.into_iter()) as reqwest::dns::Addrs) + }) + } +} + /// SSRF guard for the HTTP client: reject requests whose target resolves to an /// internal address. async fn ssrf_guard(url: &reqwest::Url, allow_internal: bool) -> Result<(), HttpClientError> { @@ -1141,6 +1198,28 @@ mod tests { assert_ne!(tls1.cache_key(), tls3.cache_key()); } + #[test] + fn permitted_addrs_filters_internal_unless_egress_allowed() { + let internal: SocketAddr = "127.0.0.1:0".parse().unwrap(); + let link_local: SocketAddr = "169.254.169.254:0".parse().unwrap(); + let external: SocketAddr = "8.8.8.8:0".parse().unwrap(); + + // Egress disallowed: internal + link-local (cloud metadata) dropped. + assert_eq!( + permitted_addrs(vec![internal, link_local, external].into_iter(), false), + vec![external] + ); + + // Disallowed and every address internal -> empty (caller blocks). + assert!(permitted_addrs(vec![internal, link_local].into_iter(), false).is_empty()); + + // Egress allowed: everything passes through. + assert_eq!( + permitted_addrs(vec![internal, external].into_iter(), true), + vec![internal, external] + ); + } + #[test] fn forbidden_outbound_headers_are_denied_case_insensitively() { for h in [