From 5593a08e2dd78edab80f37a4dc97dc09b6513535 Mon Sep 17 00:00:00 2001 From: Benalleng Date: Mon, 2 Mar 2026 11:32:39 -0500 Subject: [PATCH 1/3] Add internal Url type to replace url crate This commit implements the minimal native Url Struct and validation logic to be able to replace the Url dep from within payjoin excluding when utilizing any external Url interfaces like when using reqwest in the io feature. Co-authored-by: xstoicunicornx --- payjoin/src/core/url.rs | 746 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 746 insertions(+) create mode 100644 payjoin/src/core/url.rs diff --git a/payjoin/src/core/url.rs b/payjoin/src/core/url.rs new file mode 100644 index 000000000..7628bf04a --- /dev/null +++ b/payjoin/src/core/url.rs @@ -0,0 +1,746 @@ +//! Minimal URL type used internally by `payjoin`. +//! +//! This module provides a small, dependency-free URL parser that covers the +//! subset of RFC 3986 needed by the payjoin protocol (`http`, `https`, and +//! `bitcoin:` style URIs). It is not a full replacement for the `url` crate — +//! only the surface used by this library is implemented. +//! +//! The primary entry point is [`Url`], with parse errors surfaced through +//! [`ParseError`] (re-exported at the crate root as `UrlParseError`). + +use core::fmt; +use core::str::FromStr; + +/// A parsed URL. +/// +/// Construct one with [`Url::parse`] or via the [`FromStr`] impl. The parser +/// accepts an absolute URL of the form `scheme://host[:port][/path][?query][#fragment]`. +/// When no path is supplied, `/` is stored so that round-tripping through +/// [`Url::as_str`] always yields a normalised form. +/// +/// # Example +/// +/// ```ignore +/// use payjoin::UrlParseError; +/// # fn demo() -> Result<(), UrlParseError> { +/// let url: payjoin::Url = "https://example.com/pj?v=2".parse()?; +/// assert_eq!(url.scheme(), "https"); +/// assert_eq!(url.host_str(), "example.com"); +/// assert_eq!(url.path(), "/pj"); +/// assert_eq!(url.query(), Some("v=2")); +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Url { + raw: String, + scheme: String, + host: Host, + port: Option, + path: String, + query: Option, + fragment: Option, +} + +/// Iterator over the `/`-separated segments of a URL path. +/// +/// Returned by [`Url::path_segments`]. The leading `/` is stripped before +/// splitting, so `"/a/b"` yields `["a", "b"]` and `"/"` yields no segments. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PathSegments<'a> { + segments: Vec<&'a str>, + index: usize, +} + +/// The host component of a [`Url`]. +/// +/// Parsed into one of three shapes depending on the input: a registered +/// domain name, a dotted-quad IPv4 address, or a bracketed IPv6 literal. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Host { + /// A registered domain name, e.g. `example.com`. + Domain(String), + /// An IPv4 address in network byte order. + Ipv4([u8; 4]), + /// An IPv6 address as eight 16-bit groups in network byte order. + Ipv6([u16; 8]), +} + +impl<'a> Iterator for PathSegments<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + if self.index >= self.segments.len() { + None + } else { + let item = self.segments[self.index]; + self.index += 1; + Some(item) + } + } +} + +/// Mutable handle for appending path segments to a [`Url`]. +/// +/// Obtained via [`Url::path_segments_mut`]. The underlying URL's serialised +/// form is rebuilt when this handle is dropped. +pub struct PathSegmentsMut<'a> { + url: &'a mut Url, +} + +impl<'a> PathSegmentsMut<'a> { + /// Append a single path segment, inserting a `/` separator if needed. + /// + /// The segment is pushed verbatim — no percent-encoding is applied. + pub fn push(&mut self, segment: &str) { + if !self.url.path.ends_with('/') && !self.url.path.is_empty() { + self.url.path.push('/'); + } + self.url.path.push_str(segment); + } +} + +impl<'a> Drop for PathSegmentsMut<'a> { + fn drop(&mut self) { self.url.rebuild_raw(); } +} + +impl Url { + /// Return a mutable handle for appending path segments. + /// + /// Always returns `Some`; the `Option` mirrors the `url` crate's API so + /// callers can migrate without changes. + pub fn path_segments_mut(&mut self) -> Option> { + Some(PathSegmentsMut { url: self }) + } +} + +/// Mutable handle for appending `key=value` pairs to a [`Url`]'s query string. +/// +/// Obtained via [`Url::query_pairs_mut`]. Each [`append_pair`](Self::append_pair) +/// call rewrites the underlying URL's serialised form. +pub struct UrlQueryPairs<'a> { + url: &'a mut Url, +} + +impl<'a> UrlQueryPairs<'a> { + /// Append a single `key=value` pair to the query string. + /// + /// Key and value are written verbatim — the caller is responsible for any + /// percent-encoding. Returns `&mut self` to support fluent chaining. + pub fn append_pair(&mut self, key: &str, value: &str) -> &mut UrlQueryPairs<'a> { + let new_pair = format!("{}={}", key, value); + if let Some(ref mut query) = self.url.query { + query.push_str(&format!("&{}", new_pair)); + } else { + self.url.query = Some(new_pair); + } + self.url.rebuild_raw(); + self + } +} + +/// Errors produced by [`Url::parse`]. +/// +/// Re-exported at the crate root as `UrlParseError`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ParseError { + /// The authority section had no host between `://` and the path. + EmptyHost, + /// The scheme was empty. + InvalidScheme, + /// The overall structure did not match `scheme://host...`. + InvalidFormat, + /// The port was present but did not parse as a `u16`. + InvalidPort, + /// The host was not a valid domain, IPv4 literal, or IPv6 literal. + InvalidHost, +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParseError::EmptyHost => write!(f, "empty host"), + ParseError::InvalidScheme => write!(f, "invalid scheme"), + ParseError::InvalidFormat => write!(f, "invalid format"), + ParseError::InvalidPort => write!(f, "invalid port"), + ParseError::InvalidHost => write!(f, "invalid host"), + } + } +} + +impl std::error::Error for ParseError {} + +impl FromStr for Url { + type Err = ParseError; + + fn from_str(s: &str) -> Result { Self::parse(s) } +} + +impl Url { + /// Parse an absolute URL. + /// + /// The input must be of the form `scheme://host[:port][/path][?query][#fragment]`. + /// An empty path is normalised to `/`. The scheme is lower-cased. + pub fn parse(input: &str) -> Result { + let (rest, scheme) = parse_scheme(input)?; + let (_rest, host, port, path, query, fragment) = + if let Some(rest) = rest.strip_prefix("://") { + let (rest, host) = parse_host(rest)?; + let (rest, port) = parse_port(rest).unwrap_or((rest, None)); + let (path, query, fragment) = parse_path_query_fragment(rest); + (rest, host, port, path, query, fragment) + } else { + return Err(ParseError::InvalidFormat); + }; + + let path = if path.is_empty() { "/".to_string() } else { path }; + + let mut url = Url { raw: String::new(), scheme, host, port, path, query, fragment }; + url.rebuild_raw(); + Ok(url) + } + + /// The URL's scheme, lower-cased (e.g. `"https"`). + pub fn scheme(&self) -> &str { &self.scheme } + + /// The host as a domain name, or `None` for IP literals. + pub fn domain(&self) -> Option<&str> { + match &self.host { + Host::Domain(s) => Some(s.as_str()), + _ => None, + } + } + + /// The host rendered as a string. + /// + /// Domains are returned as-is; IPv4 addresses as dotted-quad; IPv6 + /// addresses in `[…]` bracket form. + pub fn host_str(&self) -> String { + match &self.host { + Host::Domain(d) => d.clone(), + Host::Ipv4(octets) => + format!("{}.{}.{}.{}", octets[0], octets[1], octets[2], octets[3]), + Host::Ipv6(segs) => { + let s = segs.iter().map(|s| format!("{:x}", s)).collect::>().join(":"); + format!("[{}]", s) + } + } + } + + /// The explicit port, if one was given. + pub fn port(&self) -> Option { self.port } + + /// Replace the port. Pass `None` to remove it. + pub fn set_port(&mut self, port: Option) { + self.port = port; + self.rebuild_raw(); + } + + /// The path component, always starting with `/`. + pub fn path(&self) -> &str { &self.path } + + /// The fragment (without the leading `#`), if any. + pub fn fragment(&self) -> Option<&str> { self.fragment.as_deref() } + + /// Replace the fragment. Pass `None` to remove it. + pub fn set_fragment(&mut self, fragment: Option<&str>) { + self.fragment = fragment.map(|s| s.to_string()); + self.rebuild_raw(); + } + + /// Iterate over the `/`-separated path segments. + /// + /// Always returns `Some`; the `Option` mirrors the `url` crate's API. + pub fn path_segments(&self) -> Option> { + if self.path.is_empty() || self.path == "/" { + return Some(PathSegments { segments: vec![], index: 0 }); + } + let segments: Vec<&str> = self.path.trim_start_matches('/').split('/').collect(); + Some(PathSegments { segments, index: 0 }) + } + + /// The query string (without the leading `?`), if any. + pub fn query(&self) -> Option<&str> { self.query.as_deref() } + + /// Clear the query string. + pub fn clear_query(&mut self) { + self.query = None; + self.rebuild_raw(); + } + + /// Return a handle for appending `key=value` pairs to the query string. + pub fn query_pairs_mut(&mut self) -> UrlQueryPairs<'_> { UrlQueryPairs { url: self } } + + /// Return parsed query pairs as a Vec of Strings + pub fn query_pairs(&self) -> Vec<(String, String)> { + let Some(query) = &self.query else { return vec![] }; + query + .split('&') + .filter(|s| !s.is_empty()) + .filter_map(|pair| { + let (k, v) = pair.split_once('=')?; + let key = + percent_encoding_rfc3986::percent_decode_str(k).ok()?.decode_utf8().ok()?; + let val = + percent_encoding_rfc3986::percent_decode_str(v).ok()?.decode_utf8().ok()?; + Some((key.into_owned(), val.into_owned())) + }) + .collect() + } + + /// Resolve a reference against this URL per RFC 3986. + /// + /// - A `segment` with a scheme (`scheme://…`) is parsed as a new absolute URL. + /// - A segment starting with `/` replaces the path and clears the query + /// and fragment. + /// - Otherwise the segment is merged relative to the base, resolving + /// `.` and `..` dot-segments, and the query and fragment are cleared. + pub fn join(&self, segment: &str) -> Result { + // If the segment is a full URL (scheme://...), parse it independently. + // Only treat it as a full URL if no / appears before :// (i.e. in scheme position). + if let Some(pos) = segment.find("://") { + if !segment[..pos].contains('/') { + return Url::parse(segment); + } + } + + let mut new_url = self.clone(); + + if segment.starts_with('/') { + // Absolute path reference: replace entire path, clear query/fragment + new_url.path = segment.to_string(); + new_url.query = None; + new_url.fragment = None; + } else { + // Relative reference: merge per RFC 3986 + // Remove everything after the last '/' in the base path, then append segment + let base_path = + if let Some(pos) = new_url.path.rfind('/') { &new_url.path[..=pos] } else { "/" }; + let merged = format!("{}{}", base_path, segment); + + // Resolve dot segments + let mut output_segments: Vec<&str> = Vec::new(); + for part in merged.split('/') { + match part { + "." => {} + ".." => { + output_segments.pop(); + } + _ => output_segments.push(part), + } + } + new_url.path = output_segments.join("/"); + if !new_url.path.starts_with('/') { + new_url.path.insert(0, '/'); + } + new_url.query = None; + new_url.fragment = None; + } + + new_url.rebuild_raw(); + Ok(new_url) + } + + fn rebuild_raw(&mut self) { + let mut raw = String::new(); + raw.push_str(&self.scheme); + raw.push_str("://"); + raw.push_str(&self.host_str()); + + if let Some(port) = self.port { + raw.push(':'); + raw.push_str(&port.to_string()); + } + + raw.push_str(&self.path); + if let Some(ref query) = self.query { + raw.push('?'); + raw.push_str(query); + } + if let Some(ref fragment) = self.fragment { + raw.push('#'); + raw.push_str(fragment); + } + self.raw = raw; + } +} + +impl AsRef for Url { + fn as_ref(&self) -> &str { &self.raw } +} + +impl Url { + /// The URL in its serialised form. + /// + /// Equivalent to the [`Display`](fmt::Display) output and to the + /// [`AsRef`] impl. + pub fn as_str(&self) -> &str { &self.raw } +} + +impl fmt::Display for Url { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.raw) } +} + +fn parse_scheme(input: &str) -> Result<(&str, String), ParseError> { + let chars = input.chars(); + let mut scheme = String::new(); + + for c in chars { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '+' | '-' | '.' => { + scheme.push(c); + } + ':' => break, + _ => return Err(ParseError::InvalidScheme), + } + } + + if scheme.is_empty() { + return Err(ParseError::InvalidScheme); + } + + let scheme = scheme.to_lowercase(); + Ok((&input[scheme.len()..], scheme)) +} + +fn parse_host(input: &str) -> Result<(&str, Host), ParseError> { + // IPv6 literal: [xxxx:...] + if input.starts_with('[') { + let end = input.find(']').ok_or(ParseError::InvalidHost)?; + let ipv6_str = &input[1..end]; + let rest = &input[end + 1..]; + return Ok((rest, parse_ipv6(ipv6_str)?)); + } + + // Split at the first ':', '/', '?', or '#' to separate host from port/path/query/fragment + let mut end = input.len(); + for (i, c) in input.char_indices() { + if c == ':' || c == '/' || c == '?' || c == '#' { + end = i; + break; + } + } + let host_str = &input[..end]; + let rest = &input[end..]; + + if let Some(host) = try_parse_ipv4(host_str) { + return Ok((rest, host)); + } + + if host_str.is_empty() { + return Err(ParseError::EmptyHost); + } + if !host_str.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.') { + return Err(ParseError::InvalidHost); + } + + Ok((rest, Host::Domain(host_str.to_string()))) +} + +fn try_parse_ipv4(s: &str) -> Option { + let parts: Vec<&str> = s.split('.').collect(); + if parts.len() != 4 { + return None; + } + let octets: [u8; 4] = [ + parts[0].parse().ok()?, + parts[1].parse().ok()?, + parts[2].parse().ok()?, + parts[3].parse().ok()?, + ]; + Some(Host::Ipv4(octets)) +} + +fn parse_ipv6(s: &str) -> Result { + let mut groups = [0u16; 8]; + if let Some((left, right)) = s.split_once("::") { + let left_parts = parse_ipv6_groups(left)?; + let right_parts = parse_ipv6_groups(right)?; + if left_parts.len() + right_parts.len() > 7 { + return Err(ParseError::InvalidHost); + } + for (i, &v) in left_parts.iter().enumerate() { + groups[i] = v; + } + let offset = 8 - right_parts.len(); + for (i, &v) in right_parts.iter().enumerate() { + groups[offset + i] = v; + } + } else { + let parts = parse_ipv6_groups(s)?; + if parts.len() != 8 { + return Err(ParseError::InvalidHost); + } + for (i, &v) in parts.iter().enumerate() { + groups[i] = v; + } + } + Ok(Host::Ipv6(groups)) +} + +fn parse_ipv6_groups(s: &str) -> Result, ParseError> { + if s.is_empty() { + return Ok(vec![]); + } + s.split(':').map(|p| u16::from_str_radix(p, 16).map_err(|_| ParseError::InvalidHost)).collect() +} + +fn parse_port(input: &str) -> Result<(&str, Option), ParseError> { + if !input.starts_with(':') { + return Ok((input, None)); + } + + let rest = &input[1..]; + let mut port_str = String::new(); + + for c in rest.chars() { + match c { + '0'..='9' => port_str.push(c), + '/' | '?' | '#' => break, + _ => return Err(ParseError::InvalidPort), + } + } + + if port_str.is_empty() { + return Ok((rest, None)); + } + + let port: u16 = port_str.parse().map_err(|_| ParseError::InvalidPort)?; + let remaining = &rest[port_str.len()..]; + Ok((remaining, Some(port))) +} + +fn parse_path_query_fragment(input: &str) -> (String, Option, Option) { + let (before_fragment, fragment) = match input.find('#') { + Some(pos) => (&input[..pos], Some(input[pos + 1..].to_string())), + None => (input, None), + }; + let (path, query) = match before_fragment.find('?') { + Some(pos) => + (before_fragment[..pos].to_string(), Some(before_fragment[pos + 1..].to_string())), + None => (before_fragment.to_string(), None), + }; + + (path, query, fragment) +} + +impl serde::Serialize for Url { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.raw) + } +} + +impl<'de> serde::Deserialize<'de> for Url { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Url::from_str(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_url_with_path() { + let url = Url::parse("https://example.com/path/to/resource").unwrap(); + assert_eq!(url.scheme(), "https"); + assert_eq!(url.domain(), Some("example.com")); + assert_eq!( + url.path_segments().map(|s| s.collect::>()), + Some(vec!["path", "to", "resource"]) + ); + } + + #[test] + fn test_set_fragment() { + let mut url = Url::parse("https://example.com/path").unwrap(); + url.set_fragment(Some("newfragment")); + assert_eq!(url.fragment(), Some("newfragment")); + assert!(url.as_ref().contains("#newfragment")); + } + + #[test] + fn test_join() { + let base = Url::parse("http://example.com/base/").unwrap(); + let joined = base.join("next").unwrap(); + assert_eq!(joined.path, "/base/next"); + } + + #[test] + fn test_parse_url_with_port_and_fragment() { + let input = "http://localhost:1234/PATH#FRAGMENT"; + let url = Url::parse(input).unwrap(); + assert_eq!(url.scheme(), "http"); + assert_eq!(url.domain(), Some("localhost")); + assert_eq!(url.port(), Some(1234)); + assert_eq!(url.path(), "/PATH"); + assert_eq!(url.fragment(), Some("FRAGMENT")); + assert_eq!(url.as_str(), "http://localhost:1234/PATH#FRAGMENT"); + } + + #[test] + fn test_empty_host_rejected() { + assert!(matches!(Url::parse("http:///path"), Err(ParseError::EmptyHost))); + } + + #[test] + fn test_path_segments_mut_push_adds_separator() { + let mut url = Url::parse("http://example.com/base").unwrap(); + { + let mut segs = url.path_segments_mut().unwrap(); + segs.push("child"); + } + assert_eq!(url.path(), "/base/child"); + assert_eq!(url.as_str(), "http://example.com/base/child"); + } + + #[test] + fn test_host_str() { + let url = Url::parse("http://example.com/").unwrap(); + assert_eq!(url.host_str(), "example.com".to_string()); + } + + #[test] + fn test_set_port() { + let mut url = Url::parse("http://example.com/path").unwrap(); + url.set_port(Some(9090)); + assert_eq!(url.port(), Some(9090)); + assert_eq!(url.as_str(), "http://example.com:9090/path"); + url.set_port(None); + assert_eq!(url.port(), None); + assert_eq!(url.as_str(), "http://example.com/path"); + } + + #[test] + fn test_path_segments_root() { + let url = Url::parse("http://example.com/").unwrap(); + let segs: Vec<_> = url.path_segments().unwrap().collect(); + assert!(segs.is_empty()); + + let url = Url::parse("http://example.com").unwrap(); + let segs: Vec<_> = url.path_segments().unwrap().collect(); + assert!(segs.is_empty()); + assert_eq!(url.as_str(), "http://example.com/"); + } + + #[test] + fn test_set_query() { + let mut url = Url::parse("http://example.com/path").unwrap(); + url.query_pairs_mut().append_pair("key", "value"); + assert_eq!(url.query(), Some("key=value")); + assert_eq!(url.as_str(), "http://example.com/path?key=value"); + url.clear_query(); + assert_eq!(url.query(), None); + assert_eq!(url.as_str(), "http://example.com/path"); + } + + #[test] + fn test_join_dot_segments() { + let base = Url::parse("http://example.com/a/b/c").unwrap(); + + let joined = base.join("./d").unwrap(); + assert_eq!(joined.path(), "/a/b/d"); + + let joined = base.join("../d").unwrap(); + assert_eq!(joined.path(), "/a/d"); + } + + #[test] + fn test_parse_query_and_fragment() { + let url = Url::parse("http://example.com/path?q=1#frag").unwrap(); + assert_eq!(url.path(), "/path"); + assert_eq!(url.query(), Some("q=1")); + assert_eq!(url.fragment(), Some("frag")); + } + + #[test] + fn test_parse_ipv4_with_port() { + let url = Url::parse("http://127.0.0.1:8080/path").unwrap(); + assert_eq!(url.host, Host::Ipv4([127, 0, 0, 1])); + assert_eq!(url.port(), Some(8080)); + assert_eq!(url.as_str(), "http://127.0.0.1:8080/path"); + } + + #[test] + fn test_parse_ipv6_full() { + let url = Url::parse("http://[2001:db8:85a3:0:0:8a2e:370:7334]/").unwrap(); + assert_eq!( + url.host, + Host::Ipv6([0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334]) + ); + assert_eq!(url.as_str(), "http://[2001:db8:85a3:0:0:8a2e:370:7334]/"); + } + + #[test] + fn test_parse_ipv6_with_port() { + let url = Url::parse("http://[::1]:8080/path").unwrap(); + assert_eq!(url.host, Host::Ipv6([0, 0, 0, 0, 0, 0, 0, 1])); + assert_eq!(url.port(), Some(8080)); + assert_eq!(url.as_str(), "http://[0:0:0:0:0:0:0:1]:8080/path"); + } + + #[test] + fn test_parse_ipv6_unclosed_bracket() { + assert!(matches!(Url::parse("http://[::1/"), Err(ParseError::InvalidHost))); + } + + #[test] + fn test_ipv6_matches_std_parser() { + let url = Url::parse("http://[::1]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "::1".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + assert_eq!(url.domain(), None); + assert_eq!(url.host_str(), "[0:0:0:0:0:0:0:1]".to_string()); + assert_eq!(url.as_str(), "http://[0:0:0:0:0:0:0:1]/"); + + let url = Url::parse("http://[1::1]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "1::1".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + + let url = Url::parse("http://[1:2:3::4:5:6]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "1:2:3::4:5:6".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + + let url = Url::parse("http://[1:2:3:4:5:6:7::]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "1:2:3:4:5:6:7::".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + + let url = Url::parse("http://[2001:db8:85a3::8a2e:370:7334]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "2001:db8:85a3::8a2e:370:7334".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + } + + #[test] + fn test_ipv4_matches_std_parser() { + let url = Url::parse("http://127.0.0.1/").unwrap(); + let std_addr: std::net::Ipv4Addr = "127.0.0.1".parse().unwrap(); + assert_eq!(url.host, Host::Ipv4(std_addr.octets())); + assert_eq!(url.domain(), None); + assert_eq!(url.host_str(), "127.0.0.1".to_string()); + assert_eq!(url.as_str(), "http://127.0.0.1/"); + + let url = Url::parse("http://192.168.1.1/").unwrap(); + let std_addr: std::net::Ipv4Addr = "192.168.1.1".parse().unwrap(); + assert_eq!(url.host, Host::Ipv4(std_addr.octets())); + + let url = Url::parse("http://0.0.0.0/").unwrap(); + let std_addr: std::net::Ipv4Addr = "0.0.0.0".parse().unwrap(); + assert_eq!(url.host, Host::Ipv4(std_addr.octets())); + + let url = Url::parse("http://255.255.255.255/").unwrap(); + let std_addr: std::net::Ipv4Addr = "255.255.255.255".parse().unwrap(); + assert_eq!(url.host, Host::Ipv4(std_addr.octets())); + } + + #[test] + fn test_parse_ipv6_too_many_groups_rejected() { + assert!(matches!(Url::parse("http://[1:2:3:4::5:6:7:8]/"), Err(ParseError::InvalidHost))); + } +} From 08ab90611c222494ce5003ba12a509edbecdbb4d Mon Sep 17 00:00:00 2001 From: Benalleng Date: Mon, 2 Mar 2026 11:36:54 -0500 Subject: [PATCH 2/3] Transition workspace to use native url struct This commit migrates the monorepo away from the external Url dep to use the new internal Url. Additionally due to the transition we need to add a dep for url encoding with `percent-encoding-rfc3986` which coincidentally get us inline with the bitcoin_uri crate. --- Cargo-minimal.lock | 4 +-- Cargo-recent.lock | 3 +-- payjoin-cli/Cargo.toml | 1 - payjoin-cli/src/app/config.rs | 13 +++++++--- payjoin-cli/src/app/v1.rs | 10 +++---- payjoin-cli/src/app/v2/mod.rs | 2 +- payjoin-cli/src/app/v2/ohttp.rs | 19 +++++++------- payjoin-cli/src/cli/mod.rs | 14 ++++++---- payjoin/Cargo.toml | 6 ++--- payjoin/src/core/into_url.rs | 26 +++++++++---------- payjoin/src/core/io.rs | 8 +++--- payjoin/src/core/mod.rs | 2 ++ payjoin/src/core/ohttp.rs | 14 +++++----- payjoin/src/core/receive/error.rs | 2 ++ payjoin/src/core/receive/mod.rs | 8 +++--- .../src/core/receive/optional_parameters.rs | 16 +++++++++--- payjoin/src/core/receive/v1/mod.rs | 10 +++---- payjoin/src/core/receive/v2/mod.rs | 10 +++---- payjoin/src/core/request.rs | 2 +- payjoin/src/core/send/mod.rs | 4 +-- payjoin/src/core/send/v2/mod.rs | 2 +- payjoin/src/core/send/v2/session.rs | 4 +-- payjoin/src/core/uri/mod.rs | 6 ++--- payjoin/src/core/uri/v1.rs | 3 +-- payjoin/src/core/uri/v2.rs | 2 +- payjoin/tests/integration.rs | 4 +-- 26 files changed, 103 insertions(+), 92 deletions(-) diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index e0b847ddc..d5fa27ee3 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -2768,13 +2768,13 @@ dependencies = [ "http", "once_cell", "payjoin-test-utils", + "percent-encoding-rfc3986", "reqwest", "rustls 0.23.37", "serde", "serde_json", "tokio", "tracing", - "url", "web-time", ] @@ -2805,7 +2805,6 @@ dependencies = [ "tokio-rustls 0.26.4", "tracing", "tracing-subscriber", - "url", ] [[package]] @@ -4980,7 +4979,6 @@ dependencies = [ "idna", "percent-encoding", "serde", - "serde_derive", ] [[package]] diff --git a/Cargo-recent.lock b/Cargo-recent.lock index 9c7295510..4fd70512f 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -2736,13 +2736,13 @@ dependencies = [ "http", "once_cell", "payjoin-test-utils", + "percent-encoding-rfc3986", "reqwest", "rustls 0.23.31", "serde", "serde_json", "tokio", "tracing", - "url", "web-time", ] @@ -2773,7 +2773,6 @@ dependencies = [ "tokio-rustls 0.26.2", "tracing", "tracing-subscriber", - "url", ] [[package]] diff --git a/payjoin-cli/Cargo.toml b/payjoin-cli/Cargo.toml index 7722f892f..35f8391ea 100644 --- a/payjoin-cli/Cargo.toml +++ b/payjoin-cli/Cargo.toml @@ -51,7 +51,6 @@ tokio-rustls = { version = "0.26.2", features = [ ], default-features = false, optional = true } tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } -url = { version = "2.5.4", features = ["serde"] } [dev-dependencies] nix = { version = "0.30.1", features = ["aio", "process", "signal"] } diff --git a/payjoin-cli/src/app/config.rs b/payjoin-cli/src/app/config.rs index 6b7f8acb8..cdef7ca6c 100644 --- a/payjoin-cli/src/app/config.rs +++ b/payjoin-cli/src/app/config.rs @@ -4,9 +4,8 @@ use anyhow::Result; use config::builder::DefaultState; use config::{ConfigError, File, FileFormat}; use payjoin::bitcoin::FeeRate; -use payjoin::Version; +use payjoin::{Url, Version}; use serde::Deserialize; -use url::Url; use crate::cli::{Cli, Commands}; use crate::db; @@ -317,10 +316,16 @@ fn handle_subcommands(config: Builder, cli: &Cli) -> Result { let query_string = req.uri().query().unwrap_or(""); tracing::trace!("{:?}, {query_string:?}", req.method()); - let query_params: HashMap<_, _> = - url::form_urlencoded::parse(query_string.as_bytes()).into_owned().collect(); + let query_url = payjoin::Url::parse(&format!("http://localhost/?{query_string}")) + .expect("valid query URL"); + let query_params: HashMap = + query_url.query_pairs().into_iter().collect(); let amount = query_params.get("amount").map(|amt| { Amount::from_btc(amt.parse().expect("Failed to parse amount")).unwrap() }); diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index 73f4dcf85..324121267 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -813,7 +813,7 @@ impl App { async fn unwrap_relay_or_else_fetch( &self, directory: Option, - ) -> Result { + ) -> Result { let directory = directory.map(|url| url.into_url()).transpose()?; let selected_relay = self.relay_manager.lock().expect("Lock should not be poisoned").get_selected_relay(); diff --git a/payjoin-cli/src/app/v2/ohttp.rs b/payjoin-cli/src/app/v2/ohttp.rs index 4ee637ecb..d6cb490f6 100644 --- a/payjoin-cli/src/app/v2/ohttp.rs +++ b/payjoin-cli/src/app/v2/ohttp.rs @@ -1,35 +1,36 @@ use std::sync::{Arc, Mutex}; use anyhow::{anyhow, Result}; +use payjoin::Url; use super::Config; #[derive(Debug, Clone)] pub struct RelayManager { - selected_relay: Option, - failed_relays: Vec, + selected_relay: Option, + failed_relays: Vec, } impl RelayManager { pub fn new() -> Self { RelayManager { selected_relay: None, failed_relays: Vec::new() } } - pub fn set_selected_relay(&mut self, relay: url::Url) { self.selected_relay = Some(relay); } + pub fn set_selected_relay(&mut self, relay: Url) { self.selected_relay = Some(relay); } - pub fn get_selected_relay(&self) -> Option { self.selected_relay.clone() } + pub fn get_selected_relay(&self) -> Option { self.selected_relay.clone() } - pub fn add_failed_relay(&mut self, relay: url::Url) { self.failed_relays.push(relay); } + pub fn add_failed_relay(&mut self, relay: Url) { self.failed_relays.push(relay); } - pub fn get_failed_relays(&self) -> Vec { self.failed_relays.clone() } + pub fn get_failed_relays(&self) -> Vec { self.failed_relays.clone() } } pub(crate) struct ValidatedOhttpKeys { pub(crate) ohttp_keys: payjoin::OhttpKeys, - pub(crate) relay_url: url::Url, + pub(crate) relay_url: Url, } pub(crate) async fn unwrap_ohttp_keys_or_else_fetch( config: &Config, - directory: Option, + directory: Option, relay_manager: Arc>, ) -> Result { if let Some(ohttp_keys) = config.v2()?.ohttp_keys.clone() { @@ -46,7 +47,7 @@ pub(crate) async fn unwrap_ohttp_keys_or_else_fetch( async fn fetch_ohttp_keys( config: &Config, - directory: Option, + directory: Option, relay_manager: Arc>, ) -> Result { use payjoin::bitcoin::secp256k1::rand::prelude::SliceRandom; diff --git a/payjoin-cli/src/cli/mod.rs b/payjoin-cli/src/cli/mod.rs index 1f35979b5..8b5e42fa7 100644 --- a/payjoin-cli/src/cli/mod.rs +++ b/payjoin-cli/src/cli/mod.rs @@ -3,8 +3,8 @@ use std::path::PathBuf; use clap::{value_parser, Parser, Subcommand}; use payjoin::bitcoin::amount::ParseAmountError; use payjoin::bitcoin::{Amount, FeeRate}; +use payjoin::Url; use serde::Deserialize; -use url::Url; #[derive(Debug, Clone, Deserialize, Parser)] pub struct Flags { @@ -114,13 +114,13 @@ pub enum Commands { #[cfg(feature = "v1")] /// The `pj=` endpoint to receive the payjoin request - #[arg(long = "pj-endpoint", value_parser = value_parser!(Url))] - pj_endpoint: Option, + #[arg(long = "pj-endpoint", value_parser = parse_boxed_url)] + pj_endpoint: Option>, #[cfg(feature = "v2")] /// The directory to store payjoin requests - #[arg(long = "pj-directory", value_parser = value_parser!(Url))] - pj_directory: Option, + #[arg(long = "pj-directory", value_parser = parse_boxed_url)] + pj_directory: Option>, #[cfg(feature = "v2")] /// The path to the ohttp keys file @@ -144,3 +144,7 @@ pub fn parse_fee_rate_in_sat_per_vb(s: &str) -> Result Result, String> { + s.parse::().map(Box::new).map_err(|e| e.to_string()) +} diff --git a/payjoin/Cargo.toml b/payjoin/Cargo.toml index b1d7c17b0..a2e0c87f2 100644 --- a/payjoin/Cargo.toml +++ b/payjoin/Cargo.toml @@ -26,7 +26,7 @@ _core = [ "bitcoin/rand-std", "dep:http", "serde_json", - "url/serde", + "dep:percent-encoding-rfc3986", "bitcoin_uri", "serde", "bitcoin/serde", @@ -46,6 +46,7 @@ bitcoin_uri = { version = "0.1.0", optional = true } hpke = { package = "bitcoin-hpke", version = "0.13.0", optional = true } http = { version = "1.3.1", optional = true } ohttp = { package = "bitcoin-ohttp", version = "0.6.0", optional = true } +percent-encoding-rfc3986 = { version = "0.1.3", optional = true } reqwest = { version = "0.12.23", default-features = false, optional = true } rustls = { version = "0.23.31", optional = true, default-features = false, features = [ "ring", @@ -53,9 +54,6 @@ rustls = { version = "0.23.31", optional = true, default-features = false, featu serde = { version = "1.0.219", default-features = false, optional = true } serde_json = { version = "1.0.142", optional = true } tracing = "0.1.41" -url = { version = "2.5.4", optional = true, default-features = false, features = [ - "serde", -] } [target.'cfg(target_arch = "wasm32")'.dependencies] web-time = "1.1.0" diff --git a/payjoin/src/core/into_url.rs b/payjoin/src/core/into_url.rs index fc4537bf5..9d45755a5 100644 --- a/payjoin/src/core/into_url.rs +++ b/payjoin/src/core/into_url.rs @@ -1,9 +1,9 @@ -use url::{ParseError, Url}; +use crate::core::{Url, UrlParseError}; #[derive(Debug, PartialEq, Eq)] pub enum Error { BadScheme, - ParseError(ParseError), + ParseError(UrlParseError), } impl std::fmt::Display for Error { @@ -19,8 +19,8 @@ impl std::fmt::Display for Error { impl std::error::Error for Error {} -impl From for Error { - fn from(err: ParseError) -> Error { Error::ParseError(err) } +impl From for Error { + fn from(err: UrlParseError) -> Error { Error::ParseError(err) } } type Result = core::result::Result; @@ -53,13 +53,7 @@ impl IntoUrlSealed for &Url { } impl IntoUrlSealed for Url { - fn into_url(self) -> Result { - if self.has_host() { - Ok(self) - } else { - Err(Error::BadScheme) - } - } + fn into_url(self) -> Result { Ok(self) } fn as_str(&self) -> &str { self.as_ref() } } @@ -101,13 +95,19 @@ mod tests { #[test] fn into_url_file_scheme() { let err = "file:///etc/hosts".into_url().unwrap_err(); - assert_eq!(err.to_string(), "URL scheme is not allowed"); + assert_eq!(err.to_string(), "empty host"); } #[test] fn into_url_blob_scheme() { let err = "blob:https://example.com".into_url().unwrap_err(); - assert_eq!(err.to_string(), "URL scheme is not allowed"); + assert_eq!(err.to_string(), "invalid format"); + } + + #[test] + fn into_url_rejects_userinfo() { + let err = "http://user@example.com/".into_url().unwrap_err(); + assert_eq!(err.to_string(), "invalid host"); } #[test] diff --git a/payjoin/src/core/io.rs b/payjoin/src/core/io.rs index d335a3fce..2cbc0010b 100644 --- a/payjoin/src/core/io.rs +++ b/payjoin/src/core/io.rs @@ -23,7 +23,7 @@ pub async fn fetch_ohttp_keys( let proxy = Proxy::all(ohttp_relay.into_url()?.as_str())?; let client = Client::builder().proxy(proxy).http1_only().build()?; let res = client - .get(ohttp_keys_url) + .get(ohttp_keys_url.as_str()) .timeout(Duration::from_secs(10)) .header(ACCEPT, "application/ohttp-keys") .send() @@ -56,7 +56,7 @@ pub async fn fetch_ohttp_keys_with_cert( .http1_only() .build()?; let res = client - .get(ohttp_keys_url) + .get(ohttp_keys_url.as_str()) .timeout(Duration::from_secs(10)) .header(ACCEPT, "application/ohttp-keys") .send() @@ -98,8 +98,8 @@ enum InternalErrorInner { InvalidOhttpKeys(String), } -impl From for Error { - fn from(value: url::ParseError) -> Self { +impl From for Error { + fn from(value: crate::core::UrlParseError) -> Self { Self::Internal(InternalError(InternalErrorInner::ParseUrl(value.into()))) } } diff --git a/payjoin/src/core/mod.rs b/payjoin/src/core/mod.rs index 425bbae91..ec64e9963 100644 --- a/payjoin/src/core/mod.rs +++ b/payjoin/src/core/mod.rs @@ -16,6 +16,8 @@ pub mod send; pub use request::*; pub(crate) mod into_url; pub use into_url::{Error as IntoUrlError, IntoUrl}; +pub(crate) mod url; +pub use url::{ParseError as UrlParseError, Url}; #[cfg(feature = "v2")] pub mod time; pub mod uri; diff --git a/payjoin/src/core/ohttp.rs b/payjoin/src/core/ohttp.rs index 2036a9f5a..2fbd0d7d4 100644 --- a/payjoin/src/core/ohttp.rs +++ b/payjoin/src/core/ohttp.rs @@ -23,14 +23,14 @@ pub(crate) fn ohttp_encapsulate( let mut ohttp_keys = ohttp_keys.clone(); let ctx = ohttp::ClientRequest::from_config(&mut ohttp_keys)?; - let url = url::Url::parse(target_resource)?; - let authority_bytes = url.host().map_or_else(Vec::new, |host| { - let mut authority = host.to_string(); + let url = crate::core::Url::parse(target_resource)?; + let authority_bytes = { + let mut authority = url.host_str(); if let Some(port) = url.port() { write!(authority, ":{port}").unwrap(); } authority.into_bytes() - }); + }; let mut bhttp_message = bhttp::Message::request( method.as_bytes().to_vec(), url.scheme().as_bytes().to_vec(), @@ -164,7 +164,7 @@ pub enum OhttpEncapsulationError { Http(http::Error), Ohttp(ohttp::Error), Bhttp(bhttp::Error), - ParseUrl(url::ParseError), + ParseUrl(crate::core::UrlParseError), } impl From for OhttpEncapsulationError { @@ -179,8 +179,8 @@ impl From for OhttpEncapsulationError { fn from(value: bhttp::Error) -> Self { Self::Bhttp(value) } } -impl From for OhttpEncapsulationError { - fn from(value: url::ParseError) -> Self { Self::ParseUrl(value) } +impl From for OhttpEncapsulationError { + fn from(value: crate::core::UrlParseError) -> Self { Self::ParseUrl(value) } } impl fmt::Display for OhttpEncapsulationError { diff --git a/payjoin/src/core/receive/error.rs b/payjoin/src/core/receive/error.rs index 5d9b02083..d7bd23e39 100644 --- a/payjoin/src/core/receive/error.rs +++ b/payjoin/src/core/receive/error.rs @@ -243,6 +243,8 @@ impl From<&PayloadError> for JsonReply { } super::optional_parameters::Error::FeeRate => JsonReply::new(OriginalPsbtRejected, e), + super::optional_parameters::Error::MalformedQuery => + JsonReply::new(OriginalPsbtRejected, e), }, } } diff --git a/payjoin/src/core/receive/mod.rs b/payjoin/src/core/receive/mod.rs index a2be3bc0e..2790a6325 100644 --- a/payjoin/src/core/receive/mod.rs +++ b/payjoin/src/core/receive/mod.rs @@ -239,8 +239,7 @@ pub(crate) fn parse_payload( let psbt = unchecked_psbt.validate().map_err(InternalPayloadError::InconsistentPsbt)?; tracing::trace!("Received original psbt: {psbt:?}"); - let pairs = url::form_urlencoded::parse(query.as_bytes()); - let params = Params::from_query_pairs(pairs, supported_versions) + let params = Params::from_query_str(query, supported_versions) .map_err(InternalPayloadError::SenderParams)?; tracing::trace!("Received request with params: {params:?}"); @@ -484,9 +483,8 @@ pub(crate) mod tests { use crate::psbt::NON_WITNESS_INPUT_WEIGHT; pub(crate) fn original_from_test_vector() -> OriginalPayload { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::One]) - .expect("Could not parse params from query pairs"); + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params } } diff --git a/payjoin/src/core/receive/optional_parameters.rs b/payjoin/src/core/receive/optional_parameters.rs index 0da46e00a..bb2df9bb0 100644 --- a/payjoin/src/core/receive/optional_parameters.rs +++ b/payjoin/src/core/receive/optional_parameters.rs @@ -121,12 +121,22 @@ impl Params { tracing::trace!("parsed optional parameters: {params:?}"); Ok(params) } + + pub fn from_query_str( + query: &str, + supported_versions: &'static [Version], + ) -> Result { + let url = crate::Url::parse(&format!("http://localhost/?{query}")) + .map_err(|_| Error::MalformedQuery)?; + Self::from_query_pairs(url.query_pairs().into_iter(), supported_versions) + } } #[derive(Debug, PartialEq, Eq)] pub(crate) enum Error { UnknownVersion { supported_versions: &'static [Version] }, FeeRate, + MalformedQuery, } impl fmt::Display for Error { @@ -134,6 +144,7 @@ impl fmt::Display for Error { match self { Error::UnknownVersion { .. } => write!(f, "unknown version"), Error::FeeRate => write!(f, "could not parse feerate"), + Error::MalformedQuery => write!(f, "malformed query parameter encoding"), } } } @@ -152,9 +163,8 @@ pub(crate) mod test { #[test] fn test_parse_params() { - let pairs = url::form_urlencoded::parse(b"&maxadditionalfeecontribution=182&additionalfeeoutputindex=0&minfeerate=2&disableoutputsubstitution=true&optimisticmerge=true"); - let params = Params::from_query_pairs(pairs, &[Version::One]) - .expect("Could not parse params from query pairs"); + let params = Params::from_query_str("&maxadditionalfeecontribution=182&additionalfeeoutputindex=0&minfeerate=2&disableoutputsubstitution=true&optimisticmerge=true", &[Version::One]) + .expect("Could not parse params from query str"); assert_eq!(params.v, Version::One); assert_eq!(params.output_substitution, OutputSubstitution::Disabled); assert_eq!(params.additional_fee_contribution, Some((Amount::from_sat(182), 0))); diff --git a/payjoin/src/core/receive/v1/mod.rs b/payjoin/src/core/receive/v1/mod.rs index 7c2e937c3..99e8aa3a9 100644 --- a/payjoin/src/core/receive/v1/mod.rs +++ b/payjoin/src/core/receive/v1/mod.rs @@ -405,18 +405,16 @@ mod tests { } fn unchecked_proposal_from_test_vector() -> UncheckedOriginalPayload { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::One]) - .expect("Could not parse params from query pairs"); + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); UncheckedOriginalPayload { original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params }, } } fn maybe_inputs_owned_from_test_vector() -> MaybeInputsOwned { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::One]) - .expect("Could not parse params from query pairs"); + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); MaybeInputsOwned { original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params }, } diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 80078ca6c..90fabd4ba 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -39,7 +39,6 @@ pub use session::{ replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome, SessionStatus, }; -use url::Url; #[cfg(target_arch = "wasm32")] use web_time::Duration; @@ -47,6 +46,7 @@ use super::error::{Error, InputContributionError}; use super::{ common, InternalPayloadError, JsonReply, OutputSubstitutionError, ProtocolError, SelectionError, }; +use crate::core::Url; use crate::error::{InternalReplayError, ReplayError}; use crate::hpke::{decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey}; use crate::ohttp::{ @@ -73,7 +73,7 @@ static TWENTY_FOUR_HOURS_DEFAULT_EXPIRATION: Duration = Duration::from_secs(60 * pub struct SessionContext { #[serde(deserialize_with = "deserialize_address_assume_checked")] address: Address, - directory: url::Url, + directory: Url, ohttp_keys: OhttpKeys, expiration: Time, amount: Option, @@ -1463,8 +1463,7 @@ pub mod test { }); pub(crate) fn unchecked_proposal_v2_from_test_vector() -> UncheckedOriginalPayload { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::Two]) + let params = Params::from_query_str(QUERY_PARAMS, &[Version::Two]) .expect("Test utils query params should not fail"); UncheckedOriginalPayload { original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params }, @@ -1472,8 +1471,7 @@ pub mod test { } pub(crate) fn maybe_inputs_owned_v2_from_test_vector() -> MaybeInputsOwned { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::Two]) + let params = Params::from_query_str(QUERY_PARAMS, &[Version::Two]) .expect("Test utils query params should not fail"); MaybeInputsOwned { original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params }, diff --git a/payjoin/src/core/request.rs b/payjoin/src/core/request.rs index b4db22acb..abe51b611 100644 --- a/payjoin/src/core/request.rs +++ b/payjoin/src/core/request.rs @@ -1,4 +1,4 @@ -use url::Url; +use crate::core::Url; #[cfg(feature = "v1")] const V1_REQ_CONTENT_TYPE: &str = "text/plain"; diff --git a/payjoin/src/core/send/mod.rs b/payjoin/src/core/send/mod.rs index 276d928b4..5aa3521d8 100644 --- a/payjoin/src/core/send/mod.rs +++ b/payjoin/src/core/send/mod.rs @@ -20,8 +20,8 @@ use bitcoin::psbt::Psbt; use bitcoin::{Amount, FeeRate, Script, ScriptBuf, TxOut, Weight}; pub use error::{BuildSenderError, ResponseError, ValidationError, WellKnownError}; pub(crate) use error::{InternalBuildSenderError, InternalProposalError, InternalValidationError}; -use url::Url; +use crate::core::Url; use crate::output_substitution::OutputSubstitution; use crate::psbt::{AddressTypeError, PsbtExt, NON_WITNESS_INPUT_WEIGHT}; use crate::Version; @@ -684,9 +684,9 @@ mod test { BoxError, PARSED_ORIGINAL_PSBT, PARSED_PAYJOIN_PROPOSAL, PARSED_PAYJOIN_PROPOSAL_WITH_SENDER_INFO, }; - use url::Url; use super::*; + use crate::core::Url; use crate::output_substitution::OutputSubstitution; use crate::psbt::PsbtExt; use crate::send::{AdditionalFeeContribution, InternalBuildSenderError, InternalProposalError}; diff --git a/payjoin/src/core/send/v2/mod.rs b/payjoin/src/core/send/v2/mod.rs index 7a36553e3..a2daef428 100644 --- a/payjoin/src/core/send/v2/mod.rs +++ b/payjoin/src/core/send/v2/mod.rs @@ -38,10 +38,10 @@ pub use session::{ replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome, SessionStatus, }; -use url::Url; use super::error::BuildSenderError; use super::*; +use crate::core::Url; use crate::error::{InternalReplayError, ReplayError}; use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeSecretKey}; use crate::ohttp::{ohttp_encapsulate, process_get_res, process_post_res}; diff --git a/payjoin/src/core/send/v2/session.rs b/payjoin/src/core/send/v2/session.rs index 1c78c7827..8888243bb 100644 --- a/payjoin/src/core/send/v2/session.rs +++ b/payjoin/src/core/send/v2/session.rs @@ -172,9 +172,9 @@ pub enum SessionOutcome { mod tests { use bitcoin::{FeeRate, ScriptBuf}; use payjoin_test_utils::{KEM, KEY_ID, PARSED_ORIGINAL_PSBT, SYMMETRIC}; - use url::Url; use super::*; + use crate::core::Url; use crate::output_substitution::OutputSubstitution; use crate::persist::test_utils::{InMemoryAsyncTestPersister, InMemoryTestPersister}; use crate::persist::NoopSessionPersister; @@ -190,7 +190,7 @@ mod tests { fn test_sender_session_event_serialization_roundtrip() { let keypair = HpkeKeyPair::gen_keypair(); let id = crate::uri::ShortId::try_from(&b"12345670"[..]).expect("valid short id"); - let endpoint = url::Url::parse("http://localhost:1234").expect("valid url"); + let endpoint = Url::parse("http://localhost:1234").expect("valid url"); let expiration = Time::from_now(std::time::Duration::from_secs(60)).expect("expiration should be valid"); let pj_param = crate::uri::v2::PjParam::new( diff --git a/payjoin/src/core/uri/mod.rs b/payjoin/src/core/uri/mod.rs index 5c151dbdd..e70e9e7a8 100644 --- a/payjoin/src/core/uri/mod.rs +++ b/payjoin/src/core/uri/mod.rs @@ -49,7 +49,7 @@ impl PjParam { pub fn endpoint(&self) -> String { self.endpoint_url().to_string() } - pub(crate) fn endpoint_url(&self) -> url::Url { + pub(crate) fn endpoint_url(&self) -> crate::core::Url { match self { #[cfg(feature = "v1")] PjParam::V1(url) => url.endpoint(), @@ -65,12 +65,12 @@ impl std::fmt::Display for PjParam { // unfortunately Url normalizes these to be lowercase let endpoint = &self.endpoint_url(); let scheme = endpoint.scheme(); - let host = endpoint.host_str().expect("host must be set"); + let host = endpoint.host_str(); let endpoint_str = self .endpoint() .as_str() .replacen(scheme, &scheme.to_uppercase(), 1) - .replacen(host, &host.to_uppercase(), 1); + .replacen(&host, &host.to_uppercase(), 1); write!(f, "{endpoint_str}") } } diff --git a/payjoin/src/core/uri/v1.rs b/payjoin/src/core/uri/v1.rs index 702a671fb..b373d5b22 100644 --- a/payjoin/src/core/uri/v1.rs +++ b/payjoin/src/core/uri/v1.rs @@ -1,8 +1,7 @@ //! Payjoin v1 URI functionality -use url::Url; - use super::PjParseError; +use crate::core::Url; use crate::uri::error::InternalPjParseError; /// Payjoin v1 parameter containing the endpoint URL diff --git a/payjoin/src/core/uri/v2.rs b/payjoin/src/core/uri/v2.rs index 2346bd816..da97fa8b6 100644 --- a/payjoin/src/core/uri/v2.rs +++ b/payjoin/src/core/uri/v2.rs @@ -4,8 +4,8 @@ use std::collections::BTreeMap; use std::str::FromStr; use bitcoin::bech32::Hrp; -use url::Url; +use crate::core::Url; use crate::hpke::HpkePublicKey; use crate::ohttp::OhttpKeys; use crate::time::{ParseTimeError, Time}; diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index 0936319c9..13d98669d 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -9,7 +9,7 @@ mod integration { use bitcoin::{Amount, FeeRate, OutPoint, TxIn, TxOut, Weight}; use payjoin::receive::v1::build_v1_pj_uri; use payjoin::receive::InputPair; - use payjoin::{ImplementationError, OutputSubstitution, PjUri, Request, Uri}; + use payjoin::{ImplementationError, OutputSubstitution, PjUri, Request, Uri, Url}; use payjoin_test_utils::corepc_node::vtype::ListUnspentItem; use payjoin_test_utils::corepc_node::AddressType; use payjoin_test_utils::{corepc_node, init_bitcoind_sender_receiver, init_tracing, BoxError}; @@ -1487,7 +1487,7 @@ mod integration { // Receiver receive payjoin proposal, IRL it will be an HTTP request (over ssl or onion) let proposal = payjoin::receive::v1::UncheckedOriginalPayload::from_request( req.body.as_slice(), - url::Url::from_str(&req.url).expect("Could not parse url").query().unwrap_or(""), + Url::from_str(&req.url).expect("Could not parse url").query().unwrap_or(""), headers, )?; let proposal = From eeb65e05a900b8092136f259c072cbff06f4c99e Mon Sep 17 00:00:00 2001 From: Benalleng Date: Mon, 2 Mar 2026 14:44:02 -0500 Subject: [PATCH 3/3] Add url fuzz target --- fuzz/Cargo.toml | 6 +++ fuzz/cycle.sh | 4 +- fuzz/fuzz.sh | 9 +++- fuzz/fuzz_targets/url/decode_url.rs | 68 +++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 fuzz/fuzz_targets/url/decode_url.rs diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index a3a988ec5..f41caf8be 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -28,3 +28,9 @@ name = "uri_deserialize_pjuri" path = "fuzz_targets/uri/deserialize_pjuri.rs" doc = false bench = false + +[[bin]] +name = "url_decode_url" +path = "fuzz_targets/url/decode_url.rs" +doc = false +bench = false diff --git a/fuzz/cycle.sh b/fuzz/cycle.sh index 23743ebd2..a318370f6 100755 --- a/fuzz/cycle.sh +++ b/fuzz/cycle.sh @@ -2,11 +2,13 @@ # Continuously cycle over fuzz targets running each for 1 hour. # It uses chrt SCHED_IDLE so that other process takes priority. +# A number of concurrent forks can be applied for parallelization. Be sure to leave one or two available CPUs open for the OS. # # For cargo-fuzz usage see https://github.com/rust-fuzz/cargo-fuzz?tab=readme-ov-file#usage set -euo pipefail +FORKS=${1:-1} REPO_DIR=$(git rev-parse --show-toplevel) # can't find the file because of the ENV var # shellcheck source=/dev/null @@ -17,7 +19,7 @@ while :; do targetName=$(targetFileToName "$targetFile") echo "Fuzzing target $targetName ($targetFile)" # fuzz for one hour - cargo +nightly fuzz run "$targetName" -- -max_total_time=3600 + cargo +nightly fuzz run "$targetName" -- -max_total_time=3600 -fork="$FORKS" # minimize the corpus cargo +nightly fuzz cmin "$targetName" done diff --git a/fuzz/fuzz.sh b/fuzz/fuzz.sh index 839fba6d4..45f5fb7ff 100755 --- a/fuzz/fuzz.sh +++ b/fuzz/fuzz.sh @@ -1,16 +1,23 @@ #!/usr/bin/env bash # This script is used to briefly fuzz every target when no target is provided. Otherwise, it will briefly fuzz the provided target +# When fuzzing with a specific target a number of concurrent forks can be applied. Be sure to leave one or two available CPUs open for the OS. set -euo pipefail TARGET="" +FORKS=1 if [[ $# -gt 0 ]]; then TARGET="$1" shift fi +if [[ $# -gt 0 ]]; then + FORKS="$1" + shift +fi + REPO_DIR=$(git rev-parse --show-toplevel) # can't find the file because of the ENV var @@ -26,5 +33,5 @@ fi for targetFile in $targetFiles; do targetName=$(targetFileToName "$targetFile") echo "Fuzzing target $targetName ($targetFile)" - cargo fuzz run "$targetName" -- -max_total_time=30 + cargo fuzz run "$targetName" -- -max_total_time=30 -fork="$FORKS" done diff --git a/fuzz/fuzz_targets/url/decode_url.rs b/fuzz/fuzz_targets/url/decode_url.rs new file mode 100644 index 000000000..15b16a030 --- /dev/null +++ b/fuzz/fuzz_targets/url/decode_url.rs @@ -0,0 +1,68 @@ +#![no_main] + +use std::str; + +use libfuzzer_sys::fuzz_target; +// Adjust this path to wherever your Url module lives in your crate. +use payjoin::Url; + +fn do_test(data: &[u8]) { + let Ok(s) = str::from_utf8(data) else { return }; + + let Ok(mut url) = Url::parse(s) else { return }; + + let _ = url.scheme(); + let _ = url.domain(); + let _ = url.port(); + let _ = url.path(); + let _ = url.query(); + let _ = url.fragment(); + let _ = url.as_str(); + let _ = url.to_string(); + if let Some(segs) = url.path_segments() { + let _ = segs.collect::>(); + } + + // Cross-check IPv4/IPv6 parsing against std::net + let host_str = url.host_str(); + if let Ok(std_addr) = host_str.parse::() { + assert!(url.domain().is_none(), "domain() must be None for IPv4 host"); + let _ = std_addr.octets(); + } + let bracketed = host_str.trim_start_matches('[').trim_end_matches(']'); + if let Ok(std_addr) = bracketed.parse::() { + assert!(url.domain().is_none(), "domain() must be None for IPv6 host"); + let _ = std_addr.segments(); + } + + let raw = url.as_str().to_owned(); + if let Ok(reparsed) = Url::parse(&raw) { + assert_eq!( + reparsed.as_str(), + raw, + "round-trip mismatch: first={raw:?} second={:?}", + reparsed.as_str() + ); + } + + url.set_port(Some(8080)); + url.set_port(None); + url.set_fragment(Some("fuzz")); + url.set_fragment(None); + url.query_pairs_mut().append_pair("k", "v"); + url.clear_query(); + url.query_pairs_mut().append_pair("fuzz_key", "fuzz_val"); + + if let Some(mut segs) = url.path_segments_mut() { + segs.push("fuzz_segment"); + } + + let _ = url.join("relative/path"); + let _ = url.join("/absolute/path"); + let _ = url.join("../dotdot"); + let _ = url.join("https://other.example.com/new"); +} + +fuzz_target!(|data| { + do_test(data); +});