diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index e0b847ddc..d5fa27ee3 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -2768,13 +2768,13 @@ dependencies = [ "http", "once_cell", "payjoin-test-utils", + "percent-encoding-rfc3986", "reqwest", "rustls 0.23.37", "serde", "serde_json", "tokio", "tracing", - "url", "web-time", ] @@ -2805,7 +2805,6 @@ dependencies = [ "tokio-rustls 0.26.4", "tracing", "tracing-subscriber", - "url", ] [[package]] @@ -4980,7 +4979,6 @@ dependencies = [ "idna", "percent-encoding", "serde", - "serde_derive", ] [[package]] diff --git a/Cargo-recent.lock b/Cargo-recent.lock index 9c7295510..4fd70512f 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -2736,13 +2736,13 @@ dependencies = [ "http", "once_cell", "payjoin-test-utils", + "percent-encoding-rfc3986", "reqwest", "rustls 0.23.31", "serde", "serde_json", "tokio", "tracing", - "url", "web-time", ] @@ -2773,7 +2773,6 @@ dependencies = [ "tokio-rustls 0.26.2", "tracing", "tracing-subscriber", - "url", ] [[package]] diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index a3a988ec5..f41caf8be 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -28,3 +28,9 @@ name = "uri_deserialize_pjuri" path = "fuzz_targets/uri/deserialize_pjuri.rs" doc = false bench = false + +[[bin]] +name = "url_decode_url" +path = "fuzz_targets/url/decode_url.rs" +doc = false +bench = false diff --git a/fuzz/cycle.sh b/fuzz/cycle.sh index 23743ebd2..a318370f6 100755 --- a/fuzz/cycle.sh +++ b/fuzz/cycle.sh @@ -2,11 +2,13 @@ # Continuously cycle over fuzz targets running each for 1 hour. # It uses chrt SCHED_IDLE so that other process takes priority. +# A number of concurrent forks can be applied for parallelization. Be sure to leave one or two available CPUs open for the OS. # # For cargo-fuzz usage see https://github.com/rust-fuzz/cargo-fuzz?tab=readme-ov-file#usage set -euo pipefail +FORKS=${1:-1} REPO_DIR=$(git rev-parse --show-toplevel) # can't find the file because of the ENV var # shellcheck source=/dev/null @@ -17,7 +19,7 @@ while :; do targetName=$(targetFileToName "$targetFile") echo "Fuzzing target $targetName ($targetFile)" # fuzz for one hour - cargo +nightly fuzz run "$targetName" -- -max_total_time=3600 + cargo +nightly fuzz run "$targetName" -- -max_total_time=3600 -fork="$FORKS" # minimize the corpus cargo +nightly fuzz cmin "$targetName" done diff --git a/fuzz/fuzz.sh b/fuzz/fuzz.sh index 839fba6d4..45f5fb7ff 100755 --- a/fuzz/fuzz.sh +++ b/fuzz/fuzz.sh @@ -1,16 +1,23 @@ #!/usr/bin/env bash # This script is used to briefly fuzz every target when no target is provided. Otherwise, it will briefly fuzz the provided target +# When fuzzing with a specific target a number of concurrent forks can be applied. Be sure to leave one or two available CPUs open for the OS. set -euo pipefail TARGET="" +FORKS=1 if [[ $# -gt 0 ]]; then TARGET="$1" shift fi +if [[ $# -gt 0 ]]; then + FORKS="$1" + shift +fi + REPO_DIR=$(git rev-parse --show-toplevel) # can't find the file because of the ENV var @@ -26,5 +33,5 @@ fi for targetFile in $targetFiles; do targetName=$(targetFileToName "$targetFile") echo "Fuzzing target $targetName ($targetFile)" - cargo fuzz run "$targetName" -- -max_total_time=30 + cargo fuzz run "$targetName" -- -max_total_time=30 -fork="$FORKS" done diff --git a/fuzz/fuzz_targets/url/decode_url.rs b/fuzz/fuzz_targets/url/decode_url.rs new file mode 100644 index 000000000..15b16a030 --- /dev/null +++ b/fuzz/fuzz_targets/url/decode_url.rs @@ -0,0 +1,68 @@ +#![no_main] + +use std::str; + +use libfuzzer_sys::fuzz_target; +// Adjust this path to wherever your Url module lives in your crate. +use payjoin::Url; + +fn do_test(data: &[u8]) { + let Ok(s) = str::from_utf8(data) else { return }; + + let Ok(mut url) = Url::parse(s) else { return }; + + let _ = url.scheme(); + let _ = url.domain(); + let _ = url.port(); + let _ = url.path(); + let _ = url.query(); + let _ = url.fragment(); + let _ = url.as_str(); + let _ = url.to_string(); + if let Some(segs) = url.path_segments() { + let _ = segs.collect::>(); + } + + // Cross-check IPv4/IPv6 parsing against std::net + let host_str = url.host_str(); + if let Ok(std_addr) = host_str.parse::() { + assert!(url.domain().is_none(), "domain() must be None for IPv4 host"); + let _ = std_addr.octets(); + } + let bracketed = host_str.trim_start_matches('[').trim_end_matches(']'); + if let Ok(std_addr) = bracketed.parse::() { + assert!(url.domain().is_none(), "domain() must be None for IPv6 host"); + let _ = std_addr.segments(); + } + + let raw = url.as_str().to_owned(); + if let Ok(reparsed) = Url::parse(&raw) { + assert_eq!( + reparsed.as_str(), + raw, + "round-trip mismatch: first={raw:?} second={:?}", + reparsed.as_str() + ); + } + + url.set_port(Some(8080)); + url.set_port(None); + url.set_fragment(Some("fuzz")); + url.set_fragment(None); + url.query_pairs_mut().append_pair("k", "v"); + url.clear_query(); + url.query_pairs_mut().append_pair("fuzz_key", "fuzz_val"); + + if let Some(mut segs) = url.path_segments_mut() { + segs.push("fuzz_segment"); + } + + let _ = url.join("relative/path"); + let _ = url.join("/absolute/path"); + let _ = url.join("../dotdot"); + let _ = url.join("https://other.example.com/new"); +} + +fuzz_target!(|data| { + do_test(data); +}); diff --git a/payjoin-cli/Cargo.toml b/payjoin-cli/Cargo.toml index 7722f892f..35f8391ea 100644 --- a/payjoin-cli/Cargo.toml +++ b/payjoin-cli/Cargo.toml @@ -51,7 +51,6 @@ tokio-rustls = { version = "0.26.2", features = [ ], default-features = false, optional = true } tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } -url = { version = "2.5.4", features = ["serde"] } [dev-dependencies] nix = { version = "0.30.1", features = ["aio", "process", "signal"] } diff --git a/payjoin-cli/src/app/config.rs b/payjoin-cli/src/app/config.rs index 6b7f8acb8..cdef7ca6c 100644 --- a/payjoin-cli/src/app/config.rs +++ b/payjoin-cli/src/app/config.rs @@ -4,9 +4,8 @@ use anyhow::Result; use config::builder::DefaultState; use config::{ConfigError, File, FileFormat}; use payjoin::bitcoin::FeeRate; -use payjoin::Version; +use payjoin::{Url, Version}; use serde::Deserialize; -use url::Url; use crate::cli::{Cli, Commands}; use crate::db; @@ -317,10 +316,16 @@ fn handle_subcommands(config: Builder, cli: &Cli) -> Result { let query_string = req.uri().query().unwrap_or(""); tracing::trace!("{:?}, {query_string:?}", req.method()); - let query_params: HashMap<_, _> = - url::form_urlencoded::parse(query_string.as_bytes()).into_owned().collect(); + let query_url = payjoin::Url::parse(&format!("http://localhost/?{query_string}")) + .expect("valid query URL"); + let query_params: HashMap = + query_url.query_pairs().into_iter().collect(); let amount = query_params.get("amount").map(|amt| { Amount::from_btc(amt.parse().expect("Failed to parse amount")).unwrap() }); diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index 73f4dcf85..324121267 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -813,7 +813,7 @@ impl App { async fn unwrap_relay_or_else_fetch( &self, directory: Option, - ) -> Result { + ) -> Result { let directory = directory.map(|url| url.into_url()).transpose()?; let selected_relay = self.relay_manager.lock().expect("Lock should not be poisoned").get_selected_relay(); diff --git a/payjoin-cli/src/app/v2/ohttp.rs b/payjoin-cli/src/app/v2/ohttp.rs index 4ee637ecb..d6cb490f6 100644 --- a/payjoin-cli/src/app/v2/ohttp.rs +++ b/payjoin-cli/src/app/v2/ohttp.rs @@ -1,35 +1,36 @@ use std::sync::{Arc, Mutex}; use anyhow::{anyhow, Result}; +use payjoin::Url; use super::Config; #[derive(Debug, Clone)] pub struct RelayManager { - selected_relay: Option, - failed_relays: Vec, + selected_relay: Option, + failed_relays: Vec, } impl RelayManager { pub fn new() -> Self { RelayManager { selected_relay: None, failed_relays: Vec::new() } } - pub fn set_selected_relay(&mut self, relay: url::Url) { self.selected_relay = Some(relay); } + pub fn set_selected_relay(&mut self, relay: Url) { self.selected_relay = Some(relay); } - pub fn get_selected_relay(&self) -> Option { self.selected_relay.clone() } + pub fn get_selected_relay(&self) -> Option { self.selected_relay.clone() } - pub fn add_failed_relay(&mut self, relay: url::Url) { self.failed_relays.push(relay); } + pub fn add_failed_relay(&mut self, relay: Url) { self.failed_relays.push(relay); } - pub fn get_failed_relays(&self) -> Vec { self.failed_relays.clone() } + pub fn get_failed_relays(&self) -> Vec { self.failed_relays.clone() } } pub(crate) struct ValidatedOhttpKeys { pub(crate) ohttp_keys: payjoin::OhttpKeys, - pub(crate) relay_url: url::Url, + pub(crate) relay_url: Url, } pub(crate) async fn unwrap_ohttp_keys_or_else_fetch( config: &Config, - directory: Option, + directory: Option, relay_manager: Arc>, ) -> Result { if let Some(ohttp_keys) = config.v2()?.ohttp_keys.clone() { @@ -46,7 +47,7 @@ pub(crate) async fn unwrap_ohttp_keys_or_else_fetch( async fn fetch_ohttp_keys( config: &Config, - directory: Option, + directory: Option, relay_manager: Arc>, ) -> Result { use payjoin::bitcoin::secp256k1::rand::prelude::SliceRandom; diff --git a/payjoin-cli/src/cli/mod.rs b/payjoin-cli/src/cli/mod.rs index 1f35979b5..8b5e42fa7 100644 --- a/payjoin-cli/src/cli/mod.rs +++ b/payjoin-cli/src/cli/mod.rs @@ -3,8 +3,8 @@ use std::path::PathBuf; use clap::{value_parser, Parser, Subcommand}; use payjoin::bitcoin::amount::ParseAmountError; use payjoin::bitcoin::{Amount, FeeRate}; +use payjoin::Url; use serde::Deserialize; -use url::Url; #[derive(Debug, Clone, Deserialize, Parser)] pub struct Flags { @@ -114,13 +114,13 @@ pub enum Commands { #[cfg(feature = "v1")] /// The `pj=` endpoint to receive the payjoin request - #[arg(long = "pj-endpoint", value_parser = value_parser!(Url))] - pj_endpoint: Option, + #[arg(long = "pj-endpoint", value_parser = parse_boxed_url)] + pj_endpoint: Option>, #[cfg(feature = "v2")] /// The directory to store payjoin requests - #[arg(long = "pj-directory", value_parser = value_parser!(Url))] - pj_directory: Option, + #[arg(long = "pj-directory", value_parser = parse_boxed_url)] + pj_directory: Option>, #[cfg(feature = "v2")] /// The path to the ohttp keys file @@ -144,3 +144,7 @@ pub fn parse_fee_rate_in_sat_per_vb(s: &str) -> Result Result, String> { + s.parse::().map(Box::new).map_err(|e| e.to_string()) +} diff --git a/payjoin/Cargo.toml b/payjoin/Cargo.toml index b1d7c17b0..a2e0c87f2 100644 --- a/payjoin/Cargo.toml +++ b/payjoin/Cargo.toml @@ -26,7 +26,7 @@ _core = [ "bitcoin/rand-std", "dep:http", "serde_json", - "url/serde", + "dep:percent-encoding-rfc3986", "bitcoin_uri", "serde", "bitcoin/serde", @@ -46,6 +46,7 @@ bitcoin_uri = { version = "0.1.0", optional = true } hpke = { package = "bitcoin-hpke", version = "0.13.0", optional = true } http = { version = "1.3.1", optional = true } ohttp = { package = "bitcoin-ohttp", version = "0.6.0", optional = true } +percent-encoding-rfc3986 = { version = "0.1.3", optional = true } reqwest = { version = "0.12.23", default-features = false, optional = true } rustls = { version = "0.23.31", optional = true, default-features = false, features = [ "ring", @@ -53,9 +54,6 @@ rustls = { version = "0.23.31", optional = true, default-features = false, featu serde = { version = "1.0.219", default-features = false, optional = true } serde_json = { version = "1.0.142", optional = true } tracing = "0.1.41" -url = { version = "2.5.4", optional = true, default-features = false, features = [ - "serde", -] } [target.'cfg(target_arch = "wasm32")'.dependencies] web-time = "1.1.0" diff --git a/payjoin/src/core/into_url.rs b/payjoin/src/core/into_url.rs index fc4537bf5..9d45755a5 100644 --- a/payjoin/src/core/into_url.rs +++ b/payjoin/src/core/into_url.rs @@ -1,9 +1,9 @@ -use url::{ParseError, Url}; +use crate::core::{Url, UrlParseError}; #[derive(Debug, PartialEq, Eq)] pub enum Error { BadScheme, - ParseError(ParseError), + ParseError(UrlParseError), } impl std::fmt::Display for Error { @@ -19,8 +19,8 @@ impl std::fmt::Display for Error { impl std::error::Error for Error {} -impl From for Error { - fn from(err: ParseError) -> Error { Error::ParseError(err) } +impl From for Error { + fn from(err: UrlParseError) -> Error { Error::ParseError(err) } } type Result = core::result::Result; @@ -53,13 +53,7 @@ impl IntoUrlSealed for &Url { } impl IntoUrlSealed for Url { - fn into_url(self) -> Result { - if self.has_host() { - Ok(self) - } else { - Err(Error::BadScheme) - } - } + fn into_url(self) -> Result { Ok(self) } fn as_str(&self) -> &str { self.as_ref() } } @@ -101,13 +95,19 @@ mod tests { #[test] fn into_url_file_scheme() { let err = "file:///etc/hosts".into_url().unwrap_err(); - assert_eq!(err.to_string(), "URL scheme is not allowed"); + assert_eq!(err.to_string(), "empty host"); } #[test] fn into_url_blob_scheme() { let err = "blob:https://example.com".into_url().unwrap_err(); - assert_eq!(err.to_string(), "URL scheme is not allowed"); + assert_eq!(err.to_string(), "invalid format"); + } + + #[test] + fn into_url_rejects_userinfo() { + let err = "http://user@example.com/".into_url().unwrap_err(); + assert_eq!(err.to_string(), "invalid host"); } #[test] diff --git a/payjoin/src/core/io.rs b/payjoin/src/core/io.rs index d335a3fce..2cbc0010b 100644 --- a/payjoin/src/core/io.rs +++ b/payjoin/src/core/io.rs @@ -23,7 +23,7 @@ pub async fn fetch_ohttp_keys( let proxy = Proxy::all(ohttp_relay.into_url()?.as_str())?; let client = Client::builder().proxy(proxy).http1_only().build()?; let res = client - .get(ohttp_keys_url) + .get(ohttp_keys_url.as_str()) .timeout(Duration::from_secs(10)) .header(ACCEPT, "application/ohttp-keys") .send() @@ -56,7 +56,7 @@ pub async fn fetch_ohttp_keys_with_cert( .http1_only() .build()?; let res = client - .get(ohttp_keys_url) + .get(ohttp_keys_url.as_str()) .timeout(Duration::from_secs(10)) .header(ACCEPT, "application/ohttp-keys") .send() @@ -98,8 +98,8 @@ enum InternalErrorInner { InvalidOhttpKeys(String), } -impl From for Error { - fn from(value: url::ParseError) -> Self { +impl From for Error { + fn from(value: crate::core::UrlParseError) -> Self { Self::Internal(InternalError(InternalErrorInner::ParseUrl(value.into()))) } } diff --git a/payjoin/src/core/mod.rs b/payjoin/src/core/mod.rs index 425bbae91..ec64e9963 100644 --- a/payjoin/src/core/mod.rs +++ b/payjoin/src/core/mod.rs @@ -16,6 +16,8 @@ pub mod send; pub use request::*; pub(crate) mod into_url; pub use into_url::{Error as IntoUrlError, IntoUrl}; +pub(crate) mod url; +pub use url::{ParseError as UrlParseError, Url}; #[cfg(feature = "v2")] pub mod time; pub mod uri; diff --git a/payjoin/src/core/ohttp.rs b/payjoin/src/core/ohttp.rs index 2036a9f5a..2fbd0d7d4 100644 --- a/payjoin/src/core/ohttp.rs +++ b/payjoin/src/core/ohttp.rs @@ -23,14 +23,14 @@ pub(crate) fn ohttp_encapsulate( let mut ohttp_keys = ohttp_keys.clone(); let ctx = ohttp::ClientRequest::from_config(&mut ohttp_keys)?; - let url = url::Url::parse(target_resource)?; - let authority_bytes = url.host().map_or_else(Vec::new, |host| { - let mut authority = host.to_string(); + let url = crate::core::Url::parse(target_resource)?; + let authority_bytes = { + let mut authority = url.host_str(); if let Some(port) = url.port() { write!(authority, ":{port}").unwrap(); } authority.into_bytes() - }); + }; let mut bhttp_message = bhttp::Message::request( method.as_bytes().to_vec(), url.scheme().as_bytes().to_vec(), @@ -164,7 +164,7 @@ pub enum OhttpEncapsulationError { Http(http::Error), Ohttp(ohttp::Error), Bhttp(bhttp::Error), - ParseUrl(url::ParseError), + ParseUrl(crate::core::UrlParseError), } impl From for OhttpEncapsulationError { @@ -179,8 +179,8 @@ impl From for OhttpEncapsulationError { fn from(value: bhttp::Error) -> Self { Self::Bhttp(value) } } -impl From for OhttpEncapsulationError { - fn from(value: url::ParseError) -> Self { Self::ParseUrl(value) } +impl From for OhttpEncapsulationError { + fn from(value: crate::core::UrlParseError) -> Self { Self::ParseUrl(value) } } impl fmt::Display for OhttpEncapsulationError { diff --git a/payjoin/src/core/receive/error.rs b/payjoin/src/core/receive/error.rs index 5d9b02083..d7bd23e39 100644 --- a/payjoin/src/core/receive/error.rs +++ b/payjoin/src/core/receive/error.rs @@ -243,6 +243,8 @@ impl From<&PayloadError> for JsonReply { } super::optional_parameters::Error::FeeRate => JsonReply::new(OriginalPsbtRejected, e), + super::optional_parameters::Error::MalformedQuery => + JsonReply::new(OriginalPsbtRejected, e), }, } } diff --git a/payjoin/src/core/receive/mod.rs b/payjoin/src/core/receive/mod.rs index a2be3bc0e..2790a6325 100644 --- a/payjoin/src/core/receive/mod.rs +++ b/payjoin/src/core/receive/mod.rs @@ -239,8 +239,7 @@ pub(crate) fn parse_payload( let psbt = unchecked_psbt.validate().map_err(InternalPayloadError::InconsistentPsbt)?; tracing::trace!("Received original psbt: {psbt:?}"); - let pairs = url::form_urlencoded::parse(query.as_bytes()); - let params = Params::from_query_pairs(pairs, supported_versions) + let params = Params::from_query_str(query, supported_versions) .map_err(InternalPayloadError::SenderParams)?; tracing::trace!("Received request with params: {params:?}"); @@ -484,9 +483,8 @@ pub(crate) mod tests { use crate::psbt::NON_WITNESS_INPUT_WEIGHT; pub(crate) fn original_from_test_vector() -> OriginalPayload { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::One]) - .expect("Could not parse params from query pairs"); + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params } } diff --git a/payjoin/src/core/receive/optional_parameters.rs b/payjoin/src/core/receive/optional_parameters.rs index 0da46e00a..bb2df9bb0 100644 --- a/payjoin/src/core/receive/optional_parameters.rs +++ b/payjoin/src/core/receive/optional_parameters.rs @@ -121,12 +121,22 @@ impl Params { tracing::trace!("parsed optional parameters: {params:?}"); Ok(params) } + + pub fn from_query_str( + query: &str, + supported_versions: &'static [Version], + ) -> Result { + let url = crate::Url::parse(&format!("http://localhost/?{query}")) + .map_err(|_| Error::MalformedQuery)?; + Self::from_query_pairs(url.query_pairs().into_iter(), supported_versions) + } } #[derive(Debug, PartialEq, Eq)] pub(crate) enum Error { UnknownVersion { supported_versions: &'static [Version] }, FeeRate, + MalformedQuery, } impl fmt::Display for Error { @@ -134,6 +144,7 @@ impl fmt::Display for Error { match self { Error::UnknownVersion { .. } => write!(f, "unknown version"), Error::FeeRate => write!(f, "could not parse feerate"), + Error::MalformedQuery => write!(f, "malformed query parameter encoding"), } } } @@ -152,9 +163,8 @@ pub(crate) mod test { #[test] fn test_parse_params() { - let pairs = url::form_urlencoded::parse(b"&maxadditionalfeecontribution=182&additionalfeeoutputindex=0&minfeerate=2&disableoutputsubstitution=true&optimisticmerge=true"); - let params = Params::from_query_pairs(pairs, &[Version::One]) - .expect("Could not parse params from query pairs"); + let params = Params::from_query_str("&maxadditionalfeecontribution=182&additionalfeeoutputindex=0&minfeerate=2&disableoutputsubstitution=true&optimisticmerge=true", &[Version::One]) + .expect("Could not parse params from query str"); assert_eq!(params.v, Version::One); assert_eq!(params.output_substitution, OutputSubstitution::Disabled); assert_eq!(params.additional_fee_contribution, Some((Amount::from_sat(182), 0))); diff --git a/payjoin/src/core/receive/v1/mod.rs b/payjoin/src/core/receive/v1/mod.rs index 7c2e937c3..99e8aa3a9 100644 --- a/payjoin/src/core/receive/v1/mod.rs +++ b/payjoin/src/core/receive/v1/mod.rs @@ -405,18 +405,16 @@ mod tests { } fn unchecked_proposal_from_test_vector() -> UncheckedOriginalPayload { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::One]) - .expect("Could not parse params from query pairs"); + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); UncheckedOriginalPayload { original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params }, } } fn maybe_inputs_owned_from_test_vector() -> MaybeInputsOwned { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::One]) - .expect("Could not parse params from query pairs"); + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); MaybeInputsOwned { original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params }, } diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 80078ca6c..90fabd4ba 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -39,7 +39,6 @@ pub use session::{ replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome, SessionStatus, }; -use url::Url; #[cfg(target_arch = "wasm32")] use web_time::Duration; @@ -47,6 +46,7 @@ use super::error::{Error, InputContributionError}; use super::{ common, InternalPayloadError, JsonReply, OutputSubstitutionError, ProtocolError, SelectionError, }; +use crate::core::Url; use crate::error::{InternalReplayError, ReplayError}; use crate::hpke::{decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey}; use crate::ohttp::{ @@ -73,7 +73,7 @@ static TWENTY_FOUR_HOURS_DEFAULT_EXPIRATION: Duration = Duration::from_secs(60 * pub struct SessionContext { #[serde(deserialize_with = "deserialize_address_assume_checked")] address: Address, - directory: url::Url, + directory: Url, ohttp_keys: OhttpKeys, expiration: Time, amount: Option, @@ -1463,8 +1463,7 @@ pub mod test { }); pub(crate) fn unchecked_proposal_v2_from_test_vector() -> UncheckedOriginalPayload { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::Two]) + let params = Params::from_query_str(QUERY_PARAMS, &[Version::Two]) .expect("Test utils query params should not fail"); UncheckedOriginalPayload { original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params }, @@ -1472,8 +1471,7 @@ pub mod test { } pub(crate) fn maybe_inputs_owned_v2_from_test_vector() -> MaybeInputsOwned { - let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes()); - let params = Params::from_query_pairs(pairs, &[Version::Two]) + let params = Params::from_query_str(QUERY_PARAMS, &[Version::Two]) .expect("Test utils query params should not fail"); MaybeInputsOwned { original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params }, diff --git a/payjoin/src/core/request.rs b/payjoin/src/core/request.rs index b4db22acb..abe51b611 100644 --- a/payjoin/src/core/request.rs +++ b/payjoin/src/core/request.rs @@ -1,4 +1,4 @@ -use url::Url; +use crate::core::Url; #[cfg(feature = "v1")] const V1_REQ_CONTENT_TYPE: &str = "text/plain"; diff --git a/payjoin/src/core/send/mod.rs b/payjoin/src/core/send/mod.rs index 276d928b4..5aa3521d8 100644 --- a/payjoin/src/core/send/mod.rs +++ b/payjoin/src/core/send/mod.rs @@ -20,8 +20,8 @@ use bitcoin::psbt::Psbt; use bitcoin::{Amount, FeeRate, Script, ScriptBuf, TxOut, Weight}; pub use error::{BuildSenderError, ResponseError, ValidationError, WellKnownError}; pub(crate) use error::{InternalBuildSenderError, InternalProposalError, InternalValidationError}; -use url::Url; +use crate::core::Url; use crate::output_substitution::OutputSubstitution; use crate::psbt::{AddressTypeError, PsbtExt, NON_WITNESS_INPUT_WEIGHT}; use crate::Version; @@ -684,9 +684,9 @@ mod test { BoxError, PARSED_ORIGINAL_PSBT, PARSED_PAYJOIN_PROPOSAL, PARSED_PAYJOIN_PROPOSAL_WITH_SENDER_INFO, }; - use url::Url; use super::*; + use crate::core::Url; use crate::output_substitution::OutputSubstitution; use crate::psbt::PsbtExt; use crate::send::{AdditionalFeeContribution, InternalBuildSenderError, InternalProposalError}; diff --git a/payjoin/src/core/send/v2/mod.rs b/payjoin/src/core/send/v2/mod.rs index 7a36553e3..a2daef428 100644 --- a/payjoin/src/core/send/v2/mod.rs +++ b/payjoin/src/core/send/v2/mod.rs @@ -38,10 +38,10 @@ pub use session::{ replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome, SessionStatus, }; -use url::Url; use super::error::BuildSenderError; use super::*; +use crate::core::Url; use crate::error::{InternalReplayError, ReplayError}; use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeSecretKey}; use crate::ohttp::{ohttp_encapsulate, process_get_res, process_post_res}; diff --git a/payjoin/src/core/send/v2/session.rs b/payjoin/src/core/send/v2/session.rs index 1c78c7827..8888243bb 100644 --- a/payjoin/src/core/send/v2/session.rs +++ b/payjoin/src/core/send/v2/session.rs @@ -172,9 +172,9 @@ pub enum SessionOutcome { mod tests { use bitcoin::{FeeRate, ScriptBuf}; use payjoin_test_utils::{KEM, KEY_ID, PARSED_ORIGINAL_PSBT, SYMMETRIC}; - use url::Url; use super::*; + use crate::core::Url; use crate::output_substitution::OutputSubstitution; use crate::persist::test_utils::{InMemoryAsyncTestPersister, InMemoryTestPersister}; use crate::persist::NoopSessionPersister; @@ -190,7 +190,7 @@ mod tests { fn test_sender_session_event_serialization_roundtrip() { let keypair = HpkeKeyPair::gen_keypair(); let id = crate::uri::ShortId::try_from(&b"12345670"[..]).expect("valid short id"); - let endpoint = url::Url::parse("http://localhost:1234").expect("valid url"); + let endpoint = Url::parse("http://localhost:1234").expect("valid url"); let expiration = Time::from_now(std::time::Duration::from_secs(60)).expect("expiration should be valid"); let pj_param = crate::uri::v2::PjParam::new( diff --git a/payjoin/src/core/uri/mod.rs b/payjoin/src/core/uri/mod.rs index 5c151dbdd..e70e9e7a8 100644 --- a/payjoin/src/core/uri/mod.rs +++ b/payjoin/src/core/uri/mod.rs @@ -49,7 +49,7 @@ impl PjParam { pub fn endpoint(&self) -> String { self.endpoint_url().to_string() } - pub(crate) fn endpoint_url(&self) -> url::Url { + pub(crate) fn endpoint_url(&self) -> crate::core::Url { match self { #[cfg(feature = "v1")] PjParam::V1(url) => url.endpoint(), @@ -65,12 +65,12 @@ impl std::fmt::Display for PjParam { // unfortunately Url normalizes these to be lowercase let endpoint = &self.endpoint_url(); let scheme = endpoint.scheme(); - let host = endpoint.host_str().expect("host must be set"); + let host = endpoint.host_str(); let endpoint_str = self .endpoint() .as_str() .replacen(scheme, &scheme.to_uppercase(), 1) - .replacen(host, &host.to_uppercase(), 1); + .replacen(&host, &host.to_uppercase(), 1); write!(f, "{endpoint_str}") } } diff --git a/payjoin/src/core/uri/v1.rs b/payjoin/src/core/uri/v1.rs index 702a671fb..b373d5b22 100644 --- a/payjoin/src/core/uri/v1.rs +++ b/payjoin/src/core/uri/v1.rs @@ -1,8 +1,7 @@ //! Payjoin v1 URI functionality -use url::Url; - use super::PjParseError; +use crate::core::Url; use crate::uri::error::InternalPjParseError; /// Payjoin v1 parameter containing the endpoint URL diff --git a/payjoin/src/core/uri/v2.rs b/payjoin/src/core/uri/v2.rs index 2346bd816..da97fa8b6 100644 --- a/payjoin/src/core/uri/v2.rs +++ b/payjoin/src/core/uri/v2.rs @@ -4,8 +4,8 @@ use std::collections::BTreeMap; use std::str::FromStr; use bitcoin::bech32::Hrp; -use url::Url; +use crate::core::Url; use crate::hpke::HpkePublicKey; use crate::ohttp::OhttpKeys; use crate::time::{ParseTimeError, Time}; diff --git a/payjoin/src/core/url.rs b/payjoin/src/core/url.rs new file mode 100644 index 000000000..7628bf04a --- /dev/null +++ b/payjoin/src/core/url.rs @@ -0,0 +1,746 @@ +//! Minimal URL type used internally by `payjoin`. +//! +//! This module provides a small, dependency-free URL parser that covers the +//! subset of RFC 3986 needed by the payjoin protocol (`http`, `https`, and +//! `bitcoin:` style URIs). It is not a full replacement for the `url` crate — +//! only the surface used by this library is implemented. +//! +//! The primary entry point is [`Url`], with parse errors surfaced through +//! [`ParseError`] (re-exported at the crate root as `UrlParseError`). + +use core::fmt; +use core::str::FromStr; + +/// A parsed URL. +/// +/// Construct one with [`Url::parse`] or via the [`FromStr`] impl. The parser +/// accepts an absolute URL of the form `scheme://host[:port][/path][?query][#fragment]`. +/// When no path is supplied, `/` is stored so that round-tripping through +/// [`Url::as_str`] always yields a normalised form. +/// +/// # Example +/// +/// ```ignore +/// use payjoin::UrlParseError; +/// # fn demo() -> Result<(), UrlParseError> { +/// let url: payjoin::Url = "https://example.com/pj?v=2".parse()?; +/// assert_eq!(url.scheme(), "https"); +/// assert_eq!(url.host_str(), "example.com"); +/// assert_eq!(url.path(), "/pj"); +/// assert_eq!(url.query(), Some("v=2")); +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Url { + raw: String, + scheme: String, + host: Host, + port: Option, + path: String, + query: Option, + fragment: Option, +} + +/// Iterator over the `/`-separated segments of a URL path. +/// +/// Returned by [`Url::path_segments`]. The leading `/` is stripped before +/// splitting, so `"/a/b"` yields `["a", "b"]` and `"/"` yields no segments. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PathSegments<'a> { + segments: Vec<&'a str>, + index: usize, +} + +/// The host component of a [`Url`]. +/// +/// Parsed into one of three shapes depending on the input: a registered +/// domain name, a dotted-quad IPv4 address, or a bracketed IPv6 literal. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Host { + /// A registered domain name, e.g. `example.com`. + Domain(String), + /// An IPv4 address in network byte order. + Ipv4([u8; 4]), + /// An IPv6 address as eight 16-bit groups in network byte order. + Ipv6([u16; 8]), +} + +impl<'a> Iterator for PathSegments<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + if self.index >= self.segments.len() { + None + } else { + let item = self.segments[self.index]; + self.index += 1; + Some(item) + } + } +} + +/// Mutable handle for appending path segments to a [`Url`]. +/// +/// Obtained via [`Url::path_segments_mut`]. The underlying URL's serialised +/// form is rebuilt when this handle is dropped. +pub struct PathSegmentsMut<'a> { + url: &'a mut Url, +} + +impl<'a> PathSegmentsMut<'a> { + /// Append a single path segment, inserting a `/` separator if needed. + /// + /// The segment is pushed verbatim — no percent-encoding is applied. + pub fn push(&mut self, segment: &str) { + if !self.url.path.ends_with('/') && !self.url.path.is_empty() { + self.url.path.push('/'); + } + self.url.path.push_str(segment); + } +} + +impl<'a> Drop for PathSegmentsMut<'a> { + fn drop(&mut self) { self.url.rebuild_raw(); } +} + +impl Url { + /// Return a mutable handle for appending path segments. + /// + /// Always returns `Some`; the `Option` mirrors the `url` crate's API so + /// callers can migrate without changes. + pub fn path_segments_mut(&mut self) -> Option> { + Some(PathSegmentsMut { url: self }) + } +} + +/// Mutable handle for appending `key=value` pairs to a [`Url`]'s query string. +/// +/// Obtained via [`Url::query_pairs_mut`]. Each [`append_pair`](Self::append_pair) +/// call rewrites the underlying URL's serialised form. +pub struct UrlQueryPairs<'a> { + url: &'a mut Url, +} + +impl<'a> UrlQueryPairs<'a> { + /// Append a single `key=value` pair to the query string. + /// + /// Key and value are written verbatim — the caller is responsible for any + /// percent-encoding. Returns `&mut self` to support fluent chaining. + pub fn append_pair(&mut self, key: &str, value: &str) -> &mut UrlQueryPairs<'a> { + let new_pair = format!("{}={}", key, value); + if let Some(ref mut query) = self.url.query { + query.push_str(&format!("&{}", new_pair)); + } else { + self.url.query = Some(new_pair); + } + self.url.rebuild_raw(); + self + } +} + +/// Errors produced by [`Url::parse`]. +/// +/// Re-exported at the crate root as `UrlParseError`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ParseError { + /// The authority section had no host between `://` and the path. + EmptyHost, + /// The scheme was empty. + InvalidScheme, + /// The overall structure did not match `scheme://host...`. + InvalidFormat, + /// The port was present but did not parse as a `u16`. + InvalidPort, + /// The host was not a valid domain, IPv4 literal, or IPv6 literal. + InvalidHost, +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParseError::EmptyHost => write!(f, "empty host"), + ParseError::InvalidScheme => write!(f, "invalid scheme"), + ParseError::InvalidFormat => write!(f, "invalid format"), + ParseError::InvalidPort => write!(f, "invalid port"), + ParseError::InvalidHost => write!(f, "invalid host"), + } + } +} + +impl std::error::Error for ParseError {} + +impl FromStr for Url { + type Err = ParseError; + + fn from_str(s: &str) -> Result { Self::parse(s) } +} + +impl Url { + /// Parse an absolute URL. + /// + /// The input must be of the form `scheme://host[:port][/path][?query][#fragment]`. + /// An empty path is normalised to `/`. The scheme is lower-cased. + pub fn parse(input: &str) -> Result { + let (rest, scheme) = parse_scheme(input)?; + let (_rest, host, port, path, query, fragment) = + if let Some(rest) = rest.strip_prefix("://") { + let (rest, host) = parse_host(rest)?; + let (rest, port) = parse_port(rest).unwrap_or((rest, None)); + let (path, query, fragment) = parse_path_query_fragment(rest); + (rest, host, port, path, query, fragment) + } else { + return Err(ParseError::InvalidFormat); + }; + + let path = if path.is_empty() { "/".to_string() } else { path }; + + let mut url = Url { raw: String::new(), scheme, host, port, path, query, fragment }; + url.rebuild_raw(); + Ok(url) + } + + /// The URL's scheme, lower-cased (e.g. `"https"`). + pub fn scheme(&self) -> &str { &self.scheme } + + /// The host as a domain name, or `None` for IP literals. + pub fn domain(&self) -> Option<&str> { + match &self.host { + Host::Domain(s) => Some(s.as_str()), + _ => None, + } + } + + /// The host rendered as a string. + /// + /// Domains are returned as-is; IPv4 addresses as dotted-quad; IPv6 + /// addresses in `[…]` bracket form. + pub fn host_str(&self) -> String { + match &self.host { + Host::Domain(d) => d.clone(), + Host::Ipv4(octets) => + format!("{}.{}.{}.{}", octets[0], octets[1], octets[2], octets[3]), + Host::Ipv6(segs) => { + let s = segs.iter().map(|s| format!("{:x}", s)).collect::>().join(":"); + format!("[{}]", s) + } + } + } + + /// The explicit port, if one was given. + pub fn port(&self) -> Option { self.port } + + /// Replace the port. Pass `None` to remove it. + pub fn set_port(&mut self, port: Option) { + self.port = port; + self.rebuild_raw(); + } + + /// The path component, always starting with `/`. + pub fn path(&self) -> &str { &self.path } + + /// The fragment (without the leading `#`), if any. + pub fn fragment(&self) -> Option<&str> { self.fragment.as_deref() } + + /// Replace the fragment. Pass `None` to remove it. + pub fn set_fragment(&mut self, fragment: Option<&str>) { + self.fragment = fragment.map(|s| s.to_string()); + self.rebuild_raw(); + } + + /// Iterate over the `/`-separated path segments. + /// + /// Always returns `Some`; the `Option` mirrors the `url` crate's API. + pub fn path_segments(&self) -> Option> { + if self.path.is_empty() || self.path == "/" { + return Some(PathSegments { segments: vec![], index: 0 }); + } + let segments: Vec<&str> = self.path.trim_start_matches('/').split('/').collect(); + Some(PathSegments { segments, index: 0 }) + } + + /// The query string (without the leading `?`), if any. + pub fn query(&self) -> Option<&str> { self.query.as_deref() } + + /// Clear the query string. + pub fn clear_query(&mut self) { + self.query = None; + self.rebuild_raw(); + } + + /// Return a handle for appending `key=value` pairs to the query string. + pub fn query_pairs_mut(&mut self) -> UrlQueryPairs<'_> { UrlQueryPairs { url: self } } + + /// Return parsed query pairs as a Vec of Strings + pub fn query_pairs(&self) -> Vec<(String, String)> { + let Some(query) = &self.query else { return vec![] }; + query + .split('&') + .filter(|s| !s.is_empty()) + .filter_map(|pair| { + let (k, v) = pair.split_once('=')?; + let key = + percent_encoding_rfc3986::percent_decode_str(k).ok()?.decode_utf8().ok()?; + let val = + percent_encoding_rfc3986::percent_decode_str(v).ok()?.decode_utf8().ok()?; + Some((key.into_owned(), val.into_owned())) + }) + .collect() + } + + /// Resolve a reference against this URL per RFC 3986. + /// + /// - A `segment` with a scheme (`scheme://…`) is parsed as a new absolute URL. + /// - A segment starting with `/` replaces the path and clears the query + /// and fragment. + /// - Otherwise the segment is merged relative to the base, resolving + /// `.` and `..` dot-segments, and the query and fragment are cleared. + pub fn join(&self, segment: &str) -> Result { + // If the segment is a full URL (scheme://...), parse it independently. + // Only treat it as a full URL if no / appears before :// (i.e. in scheme position). + if let Some(pos) = segment.find("://") { + if !segment[..pos].contains('/') { + return Url::parse(segment); + } + } + + let mut new_url = self.clone(); + + if segment.starts_with('/') { + // Absolute path reference: replace entire path, clear query/fragment + new_url.path = segment.to_string(); + new_url.query = None; + new_url.fragment = None; + } else { + // Relative reference: merge per RFC 3986 + // Remove everything after the last '/' in the base path, then append segment + let base_path = + if let Some(pos) = new_url.path.rfind('/') { &new_url.path[..=pos] } else { "/" }; + let merged = format!("{}{}", base_path, segment); + + // Resolve dot segments + let mut output_segments: Vec<&str> = Vec::new(); + for part in merged.split('/') { + match part { + "." => {} + ".." => { + output_segments.pop(); + } + _ => output_segments.push(part), + } + } + new_url.path = output_segments.join("/"); + if !new_url.path.starts_with('/') { + new_url.path.insert(0, '/'); + } + new_url.query = None; + new_url.fragment = None; + } + + new_url.rebuild_raw(); + Ok(new_url) + } + + fn rebuild_raw(&mut self) { + let mut raw = String::new(); + raw.push_str(&self.scheme); + raw.push_str("://"); + raw.push_str(&self.host_str()); + + if let Some(port) = self.port { + raw.push(':'); + raw.push_str(&port.to_string()); + } + + raw.push_str(&self.path); + if let Some(ref query) = self.query { + raw.push('?'); + raw.push_str(query); + } + if let Some(ref fragment) = self.fragment { + raw.push('#'); + raw.push_str(fragment); + } + self.raw = raw; + } +} + +impl AsRef for Url { + fn as_ref(&self) -> &str { &self.raw } +} + +impl Url { + /// The URL in its serialised form. + /// + /// Equivalent to the [`Display`](fmt::Display) output and to the + /// [`AsRef`] impl. + pub fn as_str(&self) -> &str { &self.raw } +} + +impl fmt::Display for Url { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.raw) } +} + +fn parse_scheme(input: &str) -> Result<(&str, String), ParseError> { + let chars = input.chars(); + let mut scheme = String::new(); + + for c in chars { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '+' | '-' | '.' => { + scheme.push(c); + } + ':' => break, + _ => return Err(ParseError::InvalidScheme), + } + } + + if scheme.is_empty() { + return Err(ParseError::InvalidScheme); + } + + let scheme = scheme.to_lowercase(); + Ok((&input[scheme.len()..], scheme)) +} + +fn parse_host(input: &str) -> Result<(&str, Host), ParseError> { + // IPv6 literal: [xxxx:...] + if input.starts_with('[') { + let end = input.find(']').ok_or(ParseError::InvalidHost)?; + let ipv6_str = &input[1..end]; + let rest = &input[end + 1..]; + return Ok((rest, parse_ipv6(ipv6_str)?)); + } + + // Split at the first ':', '/', '?', or '#' to separate host from port/path/query/fragment + let mut end = input.len(); + for (i, c) in input.char_indices() { + if c == ':' || c == '/' || c == '?' || c == '#' { + end = i; + break; + } + } + let host_str = &input[..end]; + let rest = &input[end..]; + + if let Some(host) = try_parse_ipv4(host_str) { + return Ok((rest, host)); + } + + if host_str.is_empty() { + return Err(ParseError::EmptyHost); + } + if !host_str.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.') { + return Err(ParseError::InvalidHost); + } + + Ok((rest, Host::Domain(host_str.to_string()))) +} + +fn try_parse_ipv4(s: &str) -> Option { + let parts: Vec<&str> = s.split('.').collect(); + if parts.len() != 4 { + return None; + } + let octets: [u8; 4] = [ + parts[0].parse().ok()?, + parts[1].parse().ok()?, + parts[2].parse().ok()?, + parts[3].parse().ok()?, + ]; + Some(Host::Ipv4(octets)) +} + +fn parse_ipv6(s: &str) -> Result { + let mut groups = [0u16; 8]; + if let Some((left, right)) = s.split_once("::") { + let left_parts = parse_ipv6_groups(left)?; + let right_parts = parse_ipv6_groups(right)?; + if left_parts.len() + right_parts.len() > 7 { + return Err(ParseError::InvalidHost); + } + for (i, &v) in left_parts.iter().enumerate() { + groups[i] = v; + } + let offset = 8 - right_parts.len(); + for (i, &v) in right_parts.iter().enumerate() { + groups[offset + i] = v; + } + } else { + let parts = parse_ipv6_groups(s)?; + if parts.len() != 8 { + return Err(ParseError::InvalidHost); + } + for (i, &v) in parts.iter().enumerate() { + groups[i] = v; + } + } + Ok(Host::Ipv6(groups)) +} + +fn parse_ipv6_groups(s: &str) -> Result, ParseError> { + if s.is_empty() { + return Ok(vec![]); + } + s.split(':').map(|p| u16::from_str_radix(p, 16).map_err(|_| ParseError::InvalidHost)).collect() +} + +fn parse_port(input: &str) -> Result<(&str, Option), ParseError> { + if !input.starts_with(':') { + return Ok((input, None)); + } + + let rest = &input[1..]; + let mut port_str = String::new(); + + for c in rest.chars() { + match c { + '0'..='9' => port_str.push(c), + '/' | '?' | '#' => break, + _ => return Err(ParseError::InvalidPort), + } + } + + if port_str.is_empty() { + return Ok((rest, None)); + } + + let port: u16 = port_str.parse().map_err(|_| ParseError::InvalidPort)?; + let remaining = &rest[port_str.len()..]; + Ok((remaining, Some(port))) +} + +fn parse_path_query_fragment(input: &str) -> (String, Option, Option) { + let (before_fragment, fragment) = match input.find('#') { + Some(pos) => (&input[..pos], Some(input[pos + 1..].to_string())), + None => (input, None), + }; + let (path, query) = match before_fragment.find('?') { + Some(pos) => + (before_fragment[..pos].to_string(), Some(before_fragment[pos + 1..].to_string())), + None => (before_fragment.to_string(), None), + }; + + (path, query, fragment) +} + +impl serde::Serialize for Url { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.raw) + } +} + +impl<'de> serde::Deserialize<'de> for Url { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Url::from_str(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_url_with_path() { + let url = Url::parse("https://example.com/path/to/resource").unwrap(); + assert_eq!(url.scheme(), "https"); + assert_eq!(url.domain(), Some("example.com")); + assert_eq!( + url.path_segments().map(|s| s.collect::>()), + Some(vec!["path", "to", "resource"]) + ); + } + + #[test] + fn test_set_fragment() { + let mut url = Url::parse("https://example.com/path").unwrap(); + url.set_fragment(Some("newfragment")); + assert_eq!(url.fragment(), Some("newfragment")); + assert!(url.as_ref().contains("#newfragment")); + } + + #[test] + fn test_join() { + let base = Url::parse("http://example.com/base/").unwrap(); + let joined = base.join("next").unwrap(); + assert_eq!(joined.path, "/base/next"); + } + + #[test] + fn test_parse_url_with_port_and_fragment() { + let input = "http://localhost:1234/PATH#FRAGMENT"; + let url = Url::parse(input).unwrap(); + assert_eq!(url.scheme(), "http"); + assert_eq!(url.domain(), Some("localhost")); + assert_eq!(url.port(), Some(1234)); + assert_eq!(url.path(), "/PATH"); + assert_eq!(url.fragment(), Some("FRAGMENT")); + assert_eq!(url.as_str(), "http://localhost:1234/PATH#FRAGMENT"); + } + + #[test] + fn test_empty_host_rejected() { + assert!(matches!(Url::parse("http:///path"), Err(ParseError::EmptyHost))); + } + + #[test] + fn test_path_segments_mut_push_adds_separator() { + let mut url = Url::parse("http://example.com/base").unwrap(); + { + let mut segs = url.path_segments_mut().unwrap(); + segs.push("child"); + } + assert_eq!(url.path(), "/base/child"); + assert_eq!(url.as_str(), "http://example.com/base/child"); + } + + #[test] + fn test_host_str() { + let url = Url::parse("http://example.com/").unwrap(); + assert_eq!(url.host_str(), "example.com".to_string()); + } + + #[test] + fn test_set_port() { + let mut url = Url::parse("http://example.com/path").unwrap(); + url.set_port(Some(9090)); + assert_eq!(url.port(), Some(9090)); + assert_eq!(url.as_str(), "http://example.com:9090/path"); + url.set_port(None); + assert_eq!(url.port(), None); + assert_eq!(url.as_str(), "http://example.com/path"); + } + + #[test] + fn test_path_segments_root() { + let url = Url::parse("http://example.com/").unwrap(); + let segs: Vec<_> = url.path_segments().unwrap().collect(); + assert!(segs.is_empty()); + + let url = Url::parse("http://example.com").unwrap(); + let segs: Vec<_> = url.path_segments().unwrap().collect(); + assert!(segs.is_empty()); + assert_eq!(url.as_str(), "http://example.com/"); + } + + #[test] + fn test_set_query() { + let mut url = Url::parse("http://example.com/path").unwrap(); + url.query_pairs_mut().append_pair("key", "value"); + assert_eq!(url.query(), Some("key=value")); + assert_eq!(url.as_str(), "http://example.com/path?key=value"); + url.clear_query(); + assert_eq!(url.query(), None); + assert_eq!(url.as_str(), "http://example.com/path"); + } + + #[test] + fn test_join_dot_segments() { + let base = Url::parse("http://example.com/a/b/c").unwrap(); + + let joined = base.join("./d").unwrap(); + assert_eq!(joined.path(), "/a/b/d"); + + let joined = base.join("../d").unwrap(); + assert_eq!(joined.path(), "/a/d"); + } + + #[test] + fn test_parse_query_and_fragment() { + let url = Url::parse("http://example.com/path?q=1#frag").unwrap(); + assert_eq!(url.path(), "/path"); + assert_eq!(url.query(), Some("q=1")); + assert_eq!(url.fragment(), Some("frag")); + } + + #[test] + fn test_parse_ipv4_with_port() { + let url = Url::parse("http://127.0.0.1:8080/path").unwrap(); + assert_eq!(url.host, Host::Ipv4([127, 0, 0, 1])); + assert_eq!(url.port(), Some(8080)); + assert_eq!(url.as_str(), "http://127.0.0.1:8080/path"); + } + + #[test] + fn test_parse_ipv6_full() { + let url = Url::parse("http://[2001:db8:85a3:0:0:8a2e:370:7334]/").unwrap(); + assert_eq!( + url.host, + Host::Ipv6([0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334]) + ); + assert_eq!(url.as_str(), "http://[2001:db8:85a3:0:0:8a2e:370:7334]/"); + } + + #[test] + fn test_parse_ipv6_with_port() { + let url = Url::parse("http://[::1]:8080/path").unwrap(); + assert_eq!(url.host, Host::Ipv6([0, 0, 0, 0, 0, 0, 0, 1])); + assert_eq!(url.port(), Some(8080)); + assert_eq!(url.as_str(), "http://[0:0:0:0:0:0:0:1]:8080/path"); + } + + #[test] + fn test_parse_ipv6_unclosed_bracket() { + assert!(matches!(Url::parse("http://[::1/"), Err(ParseError::InvalidHost))); + } + + #[test] + fn test_ipv6_matches_std_parser() { + let url = Url::parse("http://[::1]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "::1".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + assert_eq!(url.domain(), None); + assert_eq!(url.host_str(), "[0:0:0:0:0:0:0:1]".to_string()); + assert_eq!(url.as_str(), "http://[0:0:0:0:0:0:0:1]/"); + + let url = Url::parse("http://[1::1]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "1::1".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + + let url = Url::parse("http://[1:2:3::4:5:6]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "1:2:3::4:5:6".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + + let url = Url::parse("http://[1:2:3:4:5:6:7::]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "1:2:3:4:5:6:7::".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + + let url = Url::parse("http://[2001:db8:85a3::8a2e:370:7334]/").unwrap(); + let std_addr: std::net::Ipv6Addr = "2001:db8:85a3::8a2e:370:7334".parse().unwrap(); + assert_eq!(url.host, Host::Ipv6(std_addr.segments())); + } + + #[test] + fn test_ipv4_matches_std_parser() { + let url = Url::parse("http://127.0.0.1/").unwrap(); + let std_addr: std::net::Ipv4Addr = "127.0.0.1".parse().unwrap(); + assert_eq!(url.host, Host::Ipv4(std_addr.octets())); + assert_eq!(url.domain(), None); + assert_eq!(url.host_str(), "127.0.0.1".to_string()); + assert_eq!(url.as_str(), "http://127.0.0.1/"); + + let url = Url::parse("http://192.168.1.1/").unwrap(); + let std_addr: std::net::Ipv4Addr = "192.168.1.1".parse().unwrap(); + assert_eq!(url.host, Host::Ipv4(std_addr.octets())); + + let url = Url::parse("http://0.0.0.0/").unwrap(); + let std_addr: std::net::Ipv4Addr = "0.0.0.0".parse().unwrap(); + assert_eq!(url.host, Host::Ipv4(std_addr.octets())); + + let url = Url::parse("http://255.255.255.255/").unwrap(); + let std_addr: std::net::Ipv4Addr = "255.255.255.255".parse().unwrap(); + assert_eq!(url.host, Host::Ipv4(std_addr.octets())); + } + + #[test] + fn test_parse_ipv6_too_many_groups_rejected() { + assert!(matches!(Url::parse("http://[1:2:3:4::5:6:7:8]/"), Err(ParseError::InvalidHost))); + } +} diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index 0936319c9..13d98669d 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -9,7 +9,7 @@ mod integration { use bitcoin::{Amount, FeeRate, OutPoint, TxIn, TxOut, Weight}; use payjoin::receive::v1::build_v1_pj_uri; use payjoin::receive::InputPair; - use payjoin::{ImplementationError, OutputSubstitution, PjUri, Request, Uri}; + use payjoin::{ImplementationError, OutputSubstitution, PjUri, Request, Uri, Url}; use payjoin_test_utils::corepc_node::vtype::ListUnspentItem; use payjoin_test_utils::corepc_node::AddressType; use payjoin_test_utils::{corepc_node, init_bitcoind_sender_receiver, init_tracing, BoxError}; @@ -1487,7 +1487,7 @@ mod integration { // Receiver receive payjoin proposal, IRL it will be an HTTP request (over ssl or onion) let proposal = payjoin::receive::v1::UncheckedOriginalPayload::from_request( req.body.as_slice(), - url::Url::from_str(&req.url).expect("Could not parse url").query().unwrap_or(""), + Url::from_str(&req.url).expect("Could not parse url").query().unwrap_or(""), headers, )?; let proposal =