diff --git a/payjoin-ffi/src/ohttp.rs b/payjoin-ffi/src/ohttp.rs index a901750a8..c3d8edb79 100644 --- a/payjoin-ffi/src/ohttp.rs +++ b/payjoin-ffi/src/ohttp.rs @@ -4,7 +4,7 @@ pub mod error { #[derive(Debug, thiserror::Error, uniffi::Object)] #[uniffi::export(Debug, Display)] #[error(transparent)] - pub struct OhttpError(#[from] ohttp::Error); + pub struct OhttpError(#[from] payjoin::OhttpKeysError); } impl From for OhttpKeys { @@ -28,15 +28,15 @@ impl OhttpKeys { use std::sync::Mutex; #[derive(uniffi::Object)] -pub struct ClientResponse(Mutex>); +pub struct ClientResponse(Mutex>); -impl From<&ClientResponse> for ohttp::ClientResponse { +impl From<&ClientResponse> for payjoin::OhttpResponse { fn from(value: &ClientResponse) -> Self { let mut data_guard = value.0.lock().unwrap(); Option::take(&mut *data_guard).expect("ClientResponse moved out of memory") } } -impl From for ClientResponse { - fn from(value: ohttp::ClientResponse) -> Self { Self(Mutex::new(Some(value))) } +impl From for ClientResponse { + fn from(value: payjoin::OhttpResponse) -> Self { Self(Mutex::new(Some(value))) } } diff --git a/payjoin-ffi/src/receive/error.rs b/payjoin-ffi/src/receive/error.rs index 542d39aa4..c3fbf8419 100644 --- a/payjoin-ffi/src/receive/error.rs +++ b/payjoin-ffi/src/receive/error.rs @@ -300,16 +300,15 @@ mod tests { use payjoin::persist::InMemoryPersister; use payjoin::receive::v2::{ReceiverBuilder, SessionEvent}; use payjoin::OhttpKeys; - use payjoin_test_utils::{EXAMPLE_URL, KEM, KEY_ID, SYMMETRIC}; + use payjoin_test_utils::EXAMPLE_URL; // Build a receiver whose session is already expired, then surface the // expiry error through the dedicated create-request error. let address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4") .expect("valid address") .assume_checked(); - let ohttp_keys = OhttpKeys( - ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid keys"), - ); + let ohttp_keys = OhttpKeys::decode(&payjoin_test_utils::ohttp_key_config_bytes()) + .expect("valid ohttp keys"); let persister = InMemoryPersister::::default(); let receiver = ReceiverBuilder::new(address, EXAMPLE_URL, ohttp_keys) .expect("valid builder") diff --git a/payjoin-test-utils/src/v2.rs b/payjoin-test-utils/src/v2.rs index 61c6ec6a0..835460150 100644 --- a/payjoin-test-utils/src/v2.rs +++ b/payjoin-test-utils/src/v2.rs @@ -221,3 +221,26 @@ pub const KEY_ID: KeyId = 1; pub const KEM: Kem = Kem::K256Sha256; pub const SYMMETRIC: &[SymmetricSuite] = &[ohttp::SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)]; + +/// Derive the test OHTTP key config deterministically so that +/// [`ohttp_key_config_bytes`] and [`ohttp_server`] agree on the same key pair. +fn test_key_config() -> ohttp::KeyConfig { + ohttp::KeyConfig::derive(KEY_ID, KEM, SYMMETRIC.to_vec(), &crate::DUMMY32) + .expect("valid test key config") +} + +/// The encoded OHTTP key config for tests, decodable via the public +/// `OhttpKeys::decode`. +/// +/// Returns raw bytes rather than `OhttpKeys` so it is usable from payjoin's own +/// unit tests, where a helper returning a `payjoin` type would resolve to a +/// separate crate instance through the dev-dependency cycle. +pub fn ohttp_key_config_bytes() -> Vec { + test_key_config().encode().expect("valid key config encoding") +} + +/// The OHTTP server matching [`ohttp_key_config_bytes`], for tests that emulate +/// the directory's OHTTP gateway and must decapsulate client requests. +pub fn ohttp_server() -> ohttp::Server { + ohttp::Server::new(test_key_config()).expect("valid ohttp server") +} diff --git a/payjoin/src/core/hpke.rs b/payjoin/src/core/hpke.rs index aa12756af..a66f73140 100644 --- a/payjoin/src/core/hpke.rs +++ b/payjoin/src/core/hpke.rs @@ -1,6 +1,5 @@ use core::fmt; use std::error; -use std::ops::Deref; use bitcoin::key::constants::{ELLSWIFT_ENCODING_SIZE, PUBLIC_KEY_SIZE}; use bitcoin::secp256k1; @@ -57,7 +56,7 @@ fn pubkey_from_compressed_bytes(pk_bytes: &[u8]) -> Result [u8; PUBLIC_KEY_SIZE] { - let reply_pk_uncompressed = pk.to_bytes(); + let reply_pk_uncompressed = pk.0.to_bytes(); secp256k1::PublicKey::from_slice(&reply_pk_uncompressed[..]) .expect("parsing a pubkey immediately after serializing it must not fail") .serialize() @@ -81,13 +80,7 @@ fn ellswift_bytes_from_encapped_key( } #[derive(Clone, PartialEq, Eq)] -pub struct HpkeSecretKey(pub SecretKey); - -impl Deref for HpkeSecretKey { - type Target = SecretKey; - - fn deref(&self) -> &Self::Target { &self.0 } -} +pub struct HpkeSecretKey(SecretKey); impl fmt::Debug for HpkeSecretKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -118,7 +111,7 @@ impl<'de> serde::Deserialize<'de> for HpkeSecretKey { } #[derive(Clone, PartialEq, Eq)] -pub struct HpkePublicKey(pub PublicKey); +pub struct HpkePublicKey(PublicKey); impl HpkePublicKey { pub fn to_compressed_bytes(&self) -> [u8; 33] { @@ -135,12 +128,6 @@ impl HpkePublicKey { } } -impl Deref for HpkePublicKey { - type Target = PublicKey; - - fn deref(&self) -> &Self::Target { &self.0 } -} - impl fmt::Debug for HpkePublicKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "SecpHpkePublicKey({:?})", self.0) diff --git a/payjoin/src/core/io.rs b/payjoin/src/core/io.rs index 2cbc0010b..05994c1a6 100644 --- a/payjoin/src/core/io.rs +++ b/payjoin/src/core/io.rs @@ -182,8 +182,6 @@ impl From for Error { #[cfg(test)] mod tests { - use std::str::FromStr; - use http::StatusCode; use reqwest::Response; @@ -195,11 +193,7 @@ mod tests { #[tokio::test] async fn test_parse_success_response() { - let valid_keys = - OhttpKeys::from_str("OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC") - .expect("valid keys") - .encode() - .expect("encodevalid keys"); + let valid_keys = payjoin_test_utils::ohttp_key_config_bytes(); let response = mock_response(StatusCode::OK, valid_keys); assert!(parse_ohttp_keys_response(response).await.is_ok(), "expected valid keys response"); diff --git a/payjoin/src/core/mod.rs b/payjoin/src/core/mod.rs index ec64e9963..6c449256c 100644 --- a/payjoin/src/core/mod.rs +++ b/payjoin/src/core/mod.rs @@ -37,7 +37,7 @@ pub use crate::hpke::{HpkeKeyPair, HpkePublicKey}; #[cfg(feature = "v2")] pub(crate) mod ohttp; #[cfg(feature = "v2")] -pub use crate::ohttp::OhttpKeys; +pub use crate::ohttp::{OhttpKeys, OhttpKeysError, OhttpResponse}; #[cfg(feature = "io")] #[cfg_attr(docsrs, doc(cfg(feature = "io")))] diff --git a/payjoin/src/core/ohttp.rs b/payjoin/src/core/ohttp.rs index 2fbd0d7d4..af61f8e06 100644 --- a/payjoin/src/core/ohttp.rs +++ b/payjoin/src/core/ohttp.rs @@ -1,4 +1,3 @@ -use std::ops::{Deref, DerefMut}; use std::{error, fmt}; use bitcoin::bech32::{self, EncodeError}; @@ -14,13 +13,13 @@ pub const PADDED_BHTTP_REQ_BYTES: usize = ENCAPSULATED_MESSAGE_BYTES - (N_ENC + N_T + OHTTP_REQ_HEADER_BYTES); pub(crate) fn ohttp_encapsulate( - ohttp_keys: &ohttp::KeyConfig, + ohttp_keys: &OhttpKeys, method: &str, target_resource: &str, body: Option<&[u8]>, ) -> Result<([u8; ENCAPSULATED_MESSAGE_BYTES], ohttp::ClientResponse), OhttpEncapsulationError> { use std::fmt::Write; - let mut ohttp_keys = ohttp_keys.clone(); + let mut ohttp_keys = ohttp_keys.0.clone(); let ctx = ohttp::ClientRequest::from_config(&mut ohttp_keys)?; let url = crate::core::Url::parse(target_resource)?; @@ -210,16 +209,21 @@ impl error::Error for OhttpEncapsulationError { } #[derive(Debug, Clone)] -pub struct OhttpKeys(pub ohttp::KeyConfig); +pub struct OhttpKeys(ohttp::KeyConfig); impl OhttpKeys { /// Decode an OHTTP KeyConfig - pub fn decode(bytes: &[u8]) -> Result { - ohttp::KeyConfig::decode(bytes).map(Self) + pub fn decode(bytes: &[u8]) -> Result { + ohttp::KeyConfig::decode(bytes).map(Self).map_err(|e| OhttpKeysError::Decode(Box::new(e))) } - pub fn to_bytes(&self) -> Result, ohttp::Error> { - let bytes = self.encode()?; + /// Encode the OHTTP KeyConfig, decodable via [`OhttpKeys::decode`]. + pub fn encode(&self) -> Result, OhttpKeysError> { + self.0.encode().map_err(|e| OhttpKeysError::Encode(Box::new(e))) + } + + pub fn to_bytes(&self) -> Result, OhttpKeysError> { + let bytes = self.0.encode().map_err(|e| OhttpKeysError::Encode(Box::new(e)))?; let key_id = bytes[0]; let uncompressed_pubkey = &bytes[3..68]; @@ -235,6 +239,20 @@ impl OhttpKeys { } } +/// An opaque OHTTP client context. +/// +/// Returned alongside the [`Request`](crate::Request) by a `create_*_request` +/// method and consumed by the paired `process_*` method to decapsulate the +/// directory's response. Callers hold it between the two calls without +/// inspecting it. +pub struct OhttpResponse(ohttp::ClientResponse); + +impl OhttpResponse { + pub(crate) fn new(inner: ohttp::ClientResponse) -> Self { Self(inner) } + + pub(crate) fn into_inner(self) -> ohttp::ClientResponse { self.0 } +} + const KEM_ID: &[u8] = b"\x00\x16"; // DHKEM(secp256k1, HKDF-SHA256) const SYMMETRIC_LEN: &[u8] = b"\x00\x04"; // 4 bytes const SYMMETRIC_KDF_AEAD: &[u8] = b"\x00\x01\x00\x03"; // KDF(HKDF-SHA256), AEAD(ChaCha20Poly1305) @@ -254,17 +272,17 @@ impl fmt::Display for OhttpKeys { } impl TryFrom<&[u8]> for OhttpKeys { - type Error = ParseOhttpKeysError; + type Error = OhttpKeysError; fn try_from(bytes: &[u8]) -> Result { let buf: [u8; 34] = - bytes.try_into().map_err(|_| ParseOhttpKeysError::IncorrectLength(bytes.len()))?; + bytes.try_into().map_err(|_| OhttpKeysError::IncorrectLength(bytes.len()))?; let key_id = buf[0]; let compressed_pk = &buf[1..]; let pubkey = bitcoin::secp256k1::PublicKey::from_slice(compressed_pk) - .map_err(|_| ParseOhttpKeysError::InvalidPublicKey)?; + .map_err(|_| OhttpKeysError::InvalidPublicKey)?; let mut buf = vec![key_id]; buf.extend_from_slice(KEM_ID); @@ -272,13 +290,13 @@ impl TryFrom<&[u8]> for OhttpKeys { buf.extend_from_slice(SYMMETRIC_LEN); buf.extend_from_slice(SYMMETRIC_KDF_AEAD); - ohttp::KeyConfig::decode(&buf).map(Self).map_err(ParseOhttpKeysError::DecodeKeyConfig) + ohttp::KeyConfig::decode(&buf).map(Self).map_err(|e| OhttpKeysError::Decode(Box::new(e))) } } #[cfg(test)] impl std::str::FromStr for OhttpKeys { - type Err = ParseOhttpKeysError; + type Err = OhttpKeysError; /// Parses a base64URL-encoded string into OhttpKeys. /// The string format is: key_id || compressed_public_key @@ -287,10 +305,10 @@ impl std::str::FromStr for OhttpKeys { bech32::Hrp::parse("OH").expect("parsing a valid HRP constant should never fail"); let (hrp, bytes) = - crate::bech32::nochecksum::decode(s).map_err(|_| ParseOhttpKeysError::InvalidFormat)?; + crate::bech32::nochecksum::decode(s).map_err(|_| OhttpKeysError::InvalidFormat)?; if hrp != oh_hrp { - return Err(ParseOhttpKeysError::InvalidFormat); + return Err(OhttpKeysError::InvalidFormat); } Self::try_from(&bytes[..]) @@ -299,9 +317,9 @@ impl std::str::FromStr for OhttpKeys { impl PartialEq for OhttpKeys { fn eq(&self, other: &Self) -> bool { - match (self.encode(), other.encode()) { + match (self.0.encode(), other.0.encode()) { (Ok(self_encoded), Ok(other_encoded)) => self_encoded == other_encoded, - // If OhttpKeys::encode(&self) is Err, return false + // If the key config fails to encode, return false _ => false, } } @@ -309,16 +327,6 @@ impl PartialEq for OhttpKeys { impl Eq for OhttpKeys {} -impl Deref for OhttpKeys { - type Target = ohttp::KeyConfig; - - fn deref(&self) -> &Self::Target { &self.0 } -} - -impl DerefMut for OhttpKeys { - fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } -} - impl<'de> serde::Deserialize<'de> for OhttpKeys { fn deserialize(deserializer: D) -> Result where @@ -334,38 +342,46 @@ impl serde::Serialize for OhttpKeys { where S: serde::Serializer, { - let bytes = self.encode().map_err(serde::ser::Error::custom)?; + let bytes = self.0.encode().map_err(serde::ser::Error::custom)?; bytes.serialize(serializer) } } +/// Error encoding or decoding [`OhttpKeys`]. #[derive(Debug)] -pub enum ParseOhttpKeysError { +#[non_exhaustive] +pub enum OhttpKeysError { + /// The provided bytes were not the expected length. IncorrectLength(usize), + /// The bytes did not encode a valid public key. InvalidPublicKey, - DecodeKeyConfig(ohttp::Error), + /// The bytes could not be decoded as an OHTTP key config. + Decode(Box), + /// The OHTTP key config could not be encoded. + Encode(Box), #[cfg(test)] InvalidFormat, } -impl std::fmt::Display for ParseOhttpKeysError { +impl std::fmt::Display for OhttpKeysError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use ParseOhttpKeysError::*; + use OhttpKeysError::*; match self { IncorrectLength(l) => write!(f, "Invalid length, got {l} expected 34"), InvalidPublicKey => write!(f, "Invalid public key"), - DecodeKeyConfig(e) => write!(f, "Failed to decode KeyConfig: {e}"), + Decode(e) => write!(f, "Failed to decode OHTTP keys: {e}"), + Encode(e) => write!(f, "Failed to encode OHTTP keys: {e}"), #[cfg(test)] InvalidFormat => write!(f, "Invalid format"), } } } -impl std::error::Error for ParseOhttpKeysError { +impl std::error::Error for OhttpKeysError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - use ParseOhttpKeysError::*; + use OhttpKeysError::*; match self { - DecodeKeyConfig(e) => Some(e), + Decode(e) | Encode(e) => Some(e.as_ref()), IncorrectLength(_) | InvalidPublicKey => None, #[cfg(test)] InvalidFormat => None, diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index feaf596ab..7c9261cf3 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -53,6 +53,7 @@ use crate::error::{InternalReplayError, ReplayError}; use crate::hpke::{decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey}; use crate::ohttp::{ ohttp_encapsulate, process_get_res, process_post_res, OhttpEncapsulationError, OhttpKeys, + OhttpResponse, }; use crate::output_substitution::OutputSubstitution; use crate::persist::{ @@ -539,13 +540,13 @@ impl Receiver { pub fn create_poll_request( &self, ohttp_relay: impl IntoUrl, - ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { + ) -> Result<(Request, OhttpResponse), CreateRequestError> { if self.session_context.expiration.elapsed() { return Err(InternalCreateRequestError::Expired(self.session_context.expiration).into()); } let (body, ohttp_ctx) = self.fallback_req_body()?; let req = Request::new_v2(&self.session_context.full_relay_url(ohttp_relay)?, &body); - Ok((req, ohttp_ctx)) + Ok((req, OhttpResponse::new(ohttp_ctx))) } /// Process the response to the Original PSBT poll from the Payjoin Directory. @@ -559,7 +560,7 @@ impl Receiver { pub fn process_response( self, body: &[u8], - context: ohttp::ClientResponse, + context: OhttpResponse, ) -> MaybeFatalTransitionWithNoResults< SessionEvent, Receiver, @@ -567,7 +568,7 @@ impl Receiver { ProtocolError, > { let current_state = self.clone(); - let proposal = match self.inner_process_res(body, context) { + let proposal = match self.inner_process_res(body, context.into_inner()) { Ok(proposal) => proposal, Err(e) => match e { ProtocolError::V2(SessionError(InternalSessionError::DirectoryResponse( @@ -1276,7 +1277,7 @@ impl Receiver { pub fn create_post_request( &self, ohttp_relay: impl IntoUrl, - ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { + ) -> Result<(Request, OhttpResponse), CreateRequestError> { if self.session_context.expiration.elapsed() { return Err(InternalCreateRequestError::Expired(self.session_context.expiration).into()); } @@ -1308,7 +1309,7 @@ impl Receiver { )?; let req = Request::new_v2(&self.session_context.full_relay_url(ohttp_relay)?, &body); - Ok((req, ctx)) + Ok((req, OhttpResponse::new(ctx))) } /// Processes the response for the final POST message from the receiver client in the v2 Payjoin protocol. @@ -1321,14 +1322,14 @@ impl Receiver { pub fn process_response( self, res: &[u8], - ohttp_context: ohttp::ClientResponse, + ohttp_context: OhttpResponse, ) -> MaybeFatalTransition< SessionEvent, Receiver, ProtocolError, Receiver, > { - match process_post_res(res, ohttp_context) { + match process_post_res(res, ohttp_context.into_inner()) { Ok(_) => MaybeFatalTransition::success( SessionEvent::PostedPayjoinProposal(), Receiver { @@ -1387,7 +1388,7 @@ impl Receiver { pub fn create_error_request( &self, ohttp_relay: impl IntoUrl, - ) -> Result<(Request, ohttp::ClientResponse), SessionError> { + ) -> Result<(Request, OhttpResponse), SessionError> { let session_context = &self.session_context; if session_context.expiration.elapsed() { return Err(InternalSessionError::Expired(session_context.expiration).into()); @@ -1410,10 +1411,10 @@ impl Receiver { } }; let (body, ohttp_ctx) = - ohttp_encapsulate(&session_context.ohttp_keys.0, "POST", mailbox.as_str(), Some(&body)) + ohttp_encapsulate(&session_context.ohttp_keys, "POST", mailbox.as_str(), Some(&body)) .map_err(InternalSessionError::OhttpEncapsulation)?; let req = Request::new_v2(&session_context.full_relay_url(ohttp_relay)?, &body); - Ok((req, ohttp_ctx)) + Ok((req, OhttpResponse::new(ohttp_ctx))) } /// Process an OHTTP Encapsulated HTTP POST Error response @@ -1421,7 +1422,7 @@ impl Receiver { pub fn process_error_response( &self, res: &[u8], - ohttp_context: ohttp::ClientResponse, + ohttp_context: OhttpResponse, ) -> MaybeTerminalSuccessTransition, ProtocolError> { let pending = self.pending_fallback_after_protocol_failure(); @@ -1432,7 +1433,7 @@ impl Receiver { let protocol_error = |e| ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()); - match (process_post_res(res, ohttp_context), pending) { + match (process_post_res(res, ohttp_context.into_inner()), pending) { (Ok(_), Some(pending_fallback)) => MaybeTerminalSuccessTransition::advance(event, pending_fallback), (Ok(_), None) => MaybeTerminalSuccessTransition::terminate(event), @@ -1581,8 +1582,8 @@ pub mod test { use bitcoin::{Amount, FeeRate, ScriptBuf, Witness}; use once_cell::sync::Lazy; use payjoin_test_utils::{ - BoxError, EXAMPLE_URL, KEM, KEY_ID, ORIGINAL_PSBT, PARSED_ORIGINAL_PSBT, - PARSED_PAYJOIN_PROPOSAL, QUERY_PARAMS, SYMMETRIC, + BoxError, EXAMPLE_URL, ORIGINAL_PSBT, PARSED_ORIGINAL_PSBT, PARSED_PAYJOIN_PROPOSAL, + QUERY_PARAMS, }; use super::*; @@ -1599,9 +1600,8 @@ pub mod test { .expect("valid address") .assume_checked(), directory: Url::from_str(EXAMPLE_URL).expect("Could not parse Url"), - ohttp_keys: OhttpKeys( - ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid key config"), - ), + ohttp_keys: OhttpKeys::decode(&payjoin_test_utils::ohttp_key_config_bytes()) + .expect("valid ohttp keys"), expiration: Time::from_now(Duration::from_secs(60)).expect("Valid timestamp"), receiver_key: HpkeKeyPair::gen_keypair(), reply_key: None, @@ -1659,8 +1659,7 @@ pub mod test { } fn ohttp_response_for(req_body: &[u8], status: http::StatusCode) -> Vec { - let server = ohttp::Server::new(SHARED_CONTEXT.ohttp_keys.0.clone()) - .expect("test OHTTP server should be valid"); + let server = payjoin_test_utils::ohttp_server(); let (_, probe_response) = server.decapsulate(req_body).expect("request should decapsulate"); let response_overhead = probe_response.encapsulate(&[]).expect("probe should encrypt").len(); diff --git a/payjoin/src/core/send/v2/mod.rs b/payjoin/src/core/send/v2/mod.rs index e2dda4a11..0c8491a86 100644 --- a/payjoin/src/core/send/v2/mod.rs +++ b/payjoin/src/core/send/v2/mod.rs @@ -44,7 +44,7 @@ use super::*; use crate::core::Url; use crate::error::{InternalReplayError, ReplayError}; use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeSecretKey}; -use crate::ohttp::{ohttp_encapsulate, process_get_res, process_post_res}; +use crate::ohttp::{ohttp_encapsulate, process_get_res, process_post_res, OhttpResponse}; use crate::persist::{ MaybeFatalTransition, MaybeSuccessTransitionWithNoResults, NextStateTransition, TerminalTransition, @@ -330,7 +330,7 @@ impl SendSession { } /// A payjoin V2 sender, allowing the construction of a payjoin V2 request -/// and the resulting [`ClientResponse`]. +/// and the resulting [`OhttpResponse`]. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct WithReplyKey; @@ -358,7 +358,7 @@ impl Sender { pub fn create_v2_post_request( &self, ohttp_relay: impl IntoUrl, - ) -> Result<(Request, ClientResponse), CreateRequestError> { + ) -> Result<(Request, OhttpResponse), CreateRequestError> { if self.session_context.pj_param.expiration().elapsed() { return Err(InternalCreateRequestError::Expired( self.session_context.pj_param.expiration(), @@ -375,7 +375,7 @@ impl Sender { self.session_context.psbt_ctx.min_fee_rate, )?; let (request, ohttp_ctx) = extract_request(&self.session_context, ohttp_relay, body)?; - Ok((request, ohttp_ctx)) + Ok((request, OhttpResponse::new(ohttp_ctx))) } /// Processes the response for the initial POST message from the sender @@ -391,9 +391,9 @@ impl Sender { pub fn process_response( self, response: &[u8], - post_ctx: ClientResponse, + post_ctx: OhttpResponse, ) -> MaybeFatalTransition, DecapsulationError> { - match process_post_res(response, post_ctx) { + match process_post_res(response, post_ctx.into_inner()) { Ok(()) => {} Err(e) => if e.is_fatal() { @@ -481,7 +481,7 @@ impl Sender { pub fn create_poll_request( &self, ohttp_relay: impl IntoUrl, - ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { + ) -> Result<(Request, OhttpResponse), CreateRequestError> { if self.session_context.pj_param.expiration().elapsed() { return Err(InternalCreateRequestError::Expired( self.session_context.pj_param.expiration(), @@ -510,7 +510,10 @@ impl Sender { let (body, ohttp_ctx) = ohttp_encapsulate(ohttp_keys, "GET", url.as_str(), Some(&body)) .map_err(InternalCreateRequestError::OhttpEncapsulation)?; - Ok((Request::new_v2(&self.session_context.full_relay_url(ohttp_relay)?, &body), ohttp_ctx)) + Ok(( + Request::new_v2(&self.session_context.full_relay_url(ohttp_relay)?, &body), + OhttpResponse::new(ohttp_ctx), + )) } /// Processes the response for the final GET message from the sender client @@ -526,14 +529,14 @@ impl Sender { pub fn process_response( self, response: &[u8], - ohttp_ctx: ohttp::ClientResponse, + ohttp_ctx: OhttpResponse, ) -> MaybeSuccessTransitionWithNoResults< SessionEvent, Psbt, Sender, ResponseError, > { - let body = match process_get_res(response, ohttp_ctx) { + let body = match process_get_res(response, ohttp_ctx.into_inner()) { Ok(Some(body)) => body, Ok(None) => return MaybeSuccessTransitionWithNoResults::no_results(self.clone()), Err(e) => @@ -617,7 +620,7 @@ mod test { use bitcoin::hex::FromHex; use bitcoin::Address; - use payjoin_test_utils::{BoxError, EXAMPLE_URL, KEM, KEY_ID, PARSED_ORIGINAL_PSBT, SYMMETRIC}; + use payjoin_test_utils::{BoxError, EXAMPLE_URL, PARSED_ORIGINAL_PSBT}; use super::*; use crate::persist::InMemoryPersister; @@ -635,9 +638,8 @@ mod test { endpoint, crate::uri::ShortId::try_from(&b"12345670"[..]).expect("valid short id"), expiration, - OhttpKeys( - ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid key config"), - ), + OhttpKeys::decode(&payjoin_test_utils::ohttp_key_config_bytes()) + .expect("valid ohttp keys"), HpkeKeyPair::gen_keypair().1, ); Ok(super::Sender { @@ -728,9 +730,8 @@ mod test { .expect("valid address") .assume_checked(); let directory = EXAMPLE_URL; - let ohttp_keys = OhttpKeys( - ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid key config"), - ); + let ohttp_keys = OhttpKeys::decode(&payjoin_test_utils::ohttp_key_config_bytes()) + .expect("valid ohttp keys"); let pj_uri = ReceiverBuilder::new(address.clone(), directory, ohttp_keys) .expect("constructor on test vector should not fail") .build() diff --git a/payjoin/src/core/send/v2/session.rs b/payjoin/src/core/send/v2/session.rs index fc6955951..a0325d58c 100644 --- a/payjoin/src/core/send/v2/session.rs +++ b/payjoin/src/core/send/v2/session.rs @@ -176,7 +176,7 @@ mod tests { use std::time::{Duration, SystemTime}; use bitcoin::{FeeRate, ScriptBuf}; - use payjoin_test_utils::{KEM, KEY_ID, PARSED_ORIGINAL_PSBT, SYMMETRIC}; + use payjoin_test_utils::PARSED_ORIGINAL_PSBT; use super::*; use crate::core::Url; @@ -201,9 +201,8 @@ mod tests { endpoint, id, expiration, - crate::OhttpKeys( - ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid key config"), - ), + crate::OhttpKeys::decode(&payjoin_test_utils::ohttp_key_config_bytes()) + .expect("valid ohttp keys"), HpkeKeyPair::gen_keypair().1, ); let sender_with_reply_key = Sender { @@ -400,9 +399,8 @@ mod tests { endpoint, id, expiration, - crate::OhttpKeys( - ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid key config"), - ), + crate::OhttpKeys::decode(&payjoin_test_utils::ohttp_key_config_bytes()) + .expect("valid ohttp keys"), HpkeKeyPair::gen_keypair().1, ); diff --git a/payjoin/src/core/uri/v2.rs b/payjoin/src/core/uri/v2.rs index 7f75bd386..9b9da5a5f 100644 --- a/payjoin/src/core/uri/v2.rs +++ b/payjoin/src/core/uri/v2.rs @@ -291,7 +291,7 @@ impl std::error::Error for PjParseError { pub(super) enum ParseOhttpKeysParamError { MissingOhttpKeys, InvalidFormat, - InvalidOhttpKeys(crate::ohttp::ParseOhttpKeysError), + InvalidOhttpKeys(crate::ohttp::OhttpKeysError), InvalidFragment(ParseFragmentError), } @@ -433,7 +433,7 @@ mod tests { assert!(matches!( ohttp(&too_long_ohttp_url), Err(ParseOhttpKeysParamError::InvalidOhttpKeys( - crate::ohttp::ParseOhttpKeysError::IncorrectLength(_) + crate::ohttp::OhttpKeysError::IncorrectLength(_) )) )); @@ -443,7 +443,7 @@ mod tests { assert!(matches!( ohttp(&too_short_ohttp_url), Err(ParseOhttpKeysParamError::InvalidOhttpKeys( - crate::ohttp::ParseOhttpKeysError::IncorrectLength(_) + crate::ohttp::OhttpKeysError::IncorrectLength(_) )) )); }