Skip to content

Commit acb5218

Browse files
committed
Force payjoin-cli resume to choose a random relay for each session
This commit ensures that each resumed payjoin-cli session is using a separate instance of the RelayManager to then check the ohttp connection independently. This fixes a bug where resuming would converge all existing sessions to one ohttp relay.
1 parent f7b415a commit acb5218

2 files changed

Lines changed: 68 additions & 58 deletions

File tree

payjoin-cli/src/app/v2/mod.rs

Lines changed: 63 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::fmt;
2-
use std::sync::{Arc, Mutex};
2+
use std::sync::Arc;
33

44
use anyhow::{anyhow, Context, Result};
55
use payjoin::bitcoin::consensus::encode::serialize_hex;
@@ -40,7 +40,6 @@ pub(crate) struct App {
4040
db: Arc<Database>,
4141
wallet: BitcoindWallet,
4242
interrupt: watch::Receiver<()>,
43-
relay_manager: Arc<Mutex<RelayManager>>,
4443
}
4544

4645
trait StatusText {
@@ -140,11 +139,10 @@ impl<Status: StatusText> fmt::Display for SessionHistoryRow<Status> {
140139
impl AppTrait for App {
141140
async fn new(config: Config) -> Result<Self> {
142141
let db = Arc::new(Database::create(&config.db_path)?);
143-
let relay_manager = Arc::new(Mutex::new(RelayManager::new()));
144142
let (interrupt_tx, interrupt_rx) = watch::channel(());
145143
tokio::spawn(handle_interrupt(interrupt_tx));
146144
let wallet = BitcoindWallet::new(&config.bitcoind).await?;
147-
let app = Self { config, db, wallet, interrupt: interrupt_rx, relay_manager };
145+
let app = Self { config, db, wallet, interrupt: interrupt_rx };
148146
app.wallet()
149147
.network()
150148
.context("Failed to connect to bitcoind. Check config RPC connection.")?;
@@ -254,10 +252,10 @@ impl AppTrait for App {
254252

255253
async fn receive_payjoin(&self, amount: Amount) -> Result<()> {
256254
let address = self.wallet().get_new_address()?;
257-
let ohttp_keys =
258-
unwrap_ohttp_keys_or_else_fetch(&self.config, None, self.relay_manager.clone())
259-
.await?
260-
.ohttp_keys;
255+
let mut relay_manager = RelayManager::new();
256+
let ohttp_keys = unwrap_ohttp_keys_or_else_fetch(&self.config, None, &mut relay_manager)
257+
.await?
258+
.ohttp_keys;
261259
let persister = ReceiverPersister::new(self.db.clone())?;
262260
let session =
263261
ReceiverBuilder::new(address, self.config.v2()?.pj_directory.as_str(), ohttp_keys)?
@@ -276,7 +274,6 @@ impl AppTrait for App {
276274
Ok(())
277275
}
278276

279-
#[allow(clippy::incompatible_msrv)]
280277
async fn resume_payjoins(&self) -> Result<()> {
281278
let recv_session_ids = self.db.get_recv_session_ids()?;
282279
let send_session_ids = self.db.get_send_session_ids()?;
@@ -480,11 +477,12 @@ impl App {
480477
session: SendSession,
481478
persister: &SenderPersister,
482479
) -> Result<()> {
480+
let mut relay_manager = RelayManager::new();
483481
match session {
484482
SendSession::WithReplyKey(context) =>
485-
self.post_original_proposal(context, persister).await?,
483+
self.post_original_proposal(context, persister, &mut relay_manager).await?,
486484
SendSession::PollingForProposal(context) =>
487-
self.get_proposed_payjoin_psbt(context, persister).await?,
485+
self.get_proposed_payjoin_psbt(context, persister, &mut relay_manager).await?,
488486
SendSession::Closed(SenderSessionOutcome::Success(proposal)) => {
489487
self.process_pj_response(proposal)?;
490488
return Ok(());
@@ -498,22 +496,27 @@ impl App {
498496
&self,
499497
sender: Sender<WithReplyKey>,
500498
persister: &SenderPersister,
499+
relay_manager: &mut RelayManager,
501500
) -> Result<()> {
502501
let (req, ctx) = sender.create_v2_post_request(
503-
self.unwrap_relay_or_else_fetch(Some(&sender.endpoint())).await?.as_str(),
502+
self.unwrap_relay_or_else_fetch(Some(&sender.endpoint()), relay_manager)
503+
.await?
504+
.as_str(),
504505
)?;
505506
let response = self.post_request(req).await?;
506507
println!("Posted original proposal...");
507508
let sender = sender.process_response(&response.bytes().await?, ctx).save(persister)?;
508-
self.get_proposed_payjoin_psbt(sender, persister).await
509+
self.get_proposed_payjoin_psbt(sender, persister, relay_manager).await
509510
}
510511

511512
async fn get_proposed_payjoin_psbt(
512513
&self,
513514
sender: Sender<PollingForProposal>,
514515
persister: &SenderPersister,
516+
relay_manager: &mut RelayManager,
515517
) -> Result<()> {
516-
let ohttp_relay = self.unwrap_relay_or_else_fetch(Some(&sender.endpoint())).await?;
518+
let ohttp_relay =
519+
self.unwrap_relay_or_else_fetch(Some(&sender.endpoint()), relay_manager).await?;
517520
let mut session = sender.clone();
518521
// Long poll until we get a response
519522
loop {
@@ -544,9 +547,11 @@ impl App {
544547
&self,
545548
session: Receiver<Initialized>,
546549
persister: &ReceiverPersister,
550+
relay_manager: &mut RelayManager,
547551
) -> Result<Receiver<UncheckedOriginalPayload>> {
548-
let ohttp_relay =
549-
self.unwrap_relay_or_else_fetch(Some(&session.pj_uri().extras.endpoint())).await?;
552+
let ohttp_relay = self
553+
.unwrap_relay_or_else_fetch(Some(&session.pj_uri().extras.endpoint()), relay_manager)
554+
.await?;
550555

551556
let mut session = session;
552557
loop {
@@ -575,30 +580,31 @@ impl App {
575580
session: ReceiveSession,
576581
persister: &ReceiverPersister,
577582
) -> Result<()> {
583+
let mut relay_manager = RelayManager::new();
578584
let res = {
579585
match session {
580586
ReceiveSession::Initialized(proposal) =>
581-
self.read_from_directory(proposal, persister).await,
587+
self.read_from_directory(proposal, persister, &mut relay_manager).await,
582588
ReceiveSession::UncheckedOriginalPayload(proposal) =>
583-
self.check_proposal(proposal, persister).await,
589+
self.check_proposal(proposal, persister, &mut relay_manager).await,
584590
ReceiveSession::MaybeInputsOwned(proposal) =>
585-
self.check_inputs_not_owned(proposal, persister).await,
591+
self.check_inputs_not_owned(proposal, persister, &mut relay_manager).await,
586592
ReceiveSession::MaybeInputsSeen(proposal) =>
587-
self.check_no_inputs_seen_before(proposal, persister).await,
593+
self.check_no_inputs_seen_before(proposal, persister, &mut relay_manager).await,
588594
ReceiveSession::OutputsUnknown(proposal) =>
589-
self.identify_receiver_outputs(proposal, persister).await,
595+
self.identify_receiver_outputs(proposal, persister, &mut relay_manager).await,
590596
ReceiveSession::WantsOutputs(proposal) =>
591-
self.commit_outputs(proposal, persister).await,
597+
self.commit_outputs(proposal, persister, &mut relay_manager).await,
592598
ReceiveSession::WantsInputs(proposal) =>
593-
self.contribute_inputs(proposal, persister).await,
599+
self.contribute_inputs(proposal, persister, &mut relay_manager).await,
594600
ReceiveSession::WantsFeeRange(proposal) =>
595-
self.apply_fee_range(proposal, persister).await,
601+
self.apply_fee_range(proposal, persister, &mut relay_manager).await,
596602
ReceiveSession::ProvisionalProposal(proposal) =>
597-
self.finalize_proposal(proposal, persister).await,
603+
self.finalize_proposal(proposal, persister, &mut relay_manager).await,
598604
ReceiveSession::PayjoinProposal(proposal) =>
599-
self.send_payjoin_proposal(proposal, persister).await,
605+
self.send_payjoin_proposal(proposal, persister, &mut relay_manager).await,
600606
ReceiveSession::HasReplyableError(error) =>
601-
self.handle_error(error, persister).await,
607+
self.handle_error(error, persister, &mut relay_manager).await,
602608
ReceiveSession::Monitor(proposal) =>
603609
self.monitor_payjoin_proposal(proposal, persister).await,
604610
ReceiveSession::Closed(_) => return Err(anyhow!("Session closed")),
@@ -612,22 +618,24 @@ impl App {
612618
&self,
613619
session: Receiver<Initialized>,
614620
persister: &ReceiverPersister,
621+
relay_manager: &mut RelayManager,
615622
) -> Result<()> {
616623
let mut interrupt = self.interrupt.clone();
617624
let receiver = tokio::select! {
618-
res = self.long_poll_fallback(session, persister) => res,
625+
res = self.long_poll_fallback(session, persister, relay_manager) => res,
619626
_ = interrupt.changed() => {
620627
println!("Interrupted. Call the `resume` command to resume all sessions.");
621628
return Err(anyhow!("Interrupted"));
622629
}
623630
}?;
624-
self.check_proposal(receiver, persister).await
631+
self.check_proposal(receiver, persister, relay_manager).await
625632
}
626633

627634
async fn check_proposal(
628635
&self,
629636
proposal: Receiver<UncheckedOriginalPayload>,
630637
persister: &ReceiverPersister,
638+
relay_manager: &mut RelayManager,
631639
) -> Result<()> {
632640
let wallet = self.wallet();
633641
let proposal = proposal
@@ -640,13 +648,14 @@ impl App {
640648

641649
println!("Fallback transaction received. Consider broadcasting this to get paid if the Payjoin fails:");
642650
println!("{}", serialize_hex(&proposal.extract_tx_to_schedule_broadcast()));
643-
self.check_inputs_not_owned(proposal, persister).await
651+
self.check_inputs_not_owned(proposal, persister, relay_manager).await
644652
}
645653

646654
async fn check_inputs_not_owned(
647655
&self,
648656
proposal: Receiver<MaybeInputsOwned>,
649657
persister: &ReceiverPersister,
658+
relay_manager: &mut RelayManager,
650659
) -> Result<()> {
651660
let wallet = self.wallet();
652661
let proposal = proposal
@@ -656,26 +665,28 @@ impl App {
656665
.map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))
657666
})
658667
.save(persister)?;
659-
self.check_no_inputs_seen_before(proposal, persister).await
668+
self.check_no_inputs_seen_before(proposal, persister, relay_manager).await
660669
}
661670

662671
async fn check_no_inputs_seen_before(
663672
&self,
664673
proposal: Receiver<MaybeInputsSeen>,
665674
persister: &ReceiverPersister,
675+
relay_manager: &mut RelayManager,
666676
) -> Result<()> {
667677
let proposal = proposal
668678
.check_no_inputs_seen_before(&mut |input| {
669679
Ok(self.db.insert_input_seen_before(*input)?)
670680
})
671681
.save(persister)?;
672-
self.identify_receiver_outputs(proposal, persister).await
682+
self.identify_receiver_outputs(proposal, persister, relay_manager).await
673683
}
674684

675685
async fn identify_receiver_outputs(
676686
&self,
677687
proposal: Receiver<OutputsUnknown>,
678688
persister: &ReceiverPersister,
689+
relay_manager: &mut RelayManager,
679690
) -> Result<()> {
680691
let wallet = self.wallet();
681692
let proposal = proposal
@@ -685,22 +696,24 @@ impl App {
685696
.map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))
686697
})
687698
.save(persister)?;
688-
self.commit_outputs(proposal, persister).await
699+
self.commit_outputs(proposal, persister, relay_manager).await
689700
}
690701

691702
async fn commit_outputs(
692703
&self,
693704
proposal: Receiver<WantsOutputs>,
694705
persister: &ReceiverPersister,
706+
relay_manager: &mut RelayManager,
695707
) -> Result<()> {
696708
let proposal = proposal.commit_outputs().save(persister)?;
697-
self.contribute_inputs(proposal, persister).await
709+
self.contribute_inputs(proposal, persister, relay_manager).await
698710
}
699711

700712
async fn contribute_inputs(
701713
&self,
702714
proposal: Receiver<WantsInputs>,
703715
persister: &ReceiverPersister,
716+
relay_manager: &mut RelayManager,
704717
) -> Result<()> {
705718
let wallet = self.wallet();
706719
let candidate_inputs = wallet.list_unspent()?;
@@ -714,22 +727,24 @@ impl App {
714727
let selected_input = proposal.try_preserving_privacy(candidate_inputs)?;
715728
let proposal =
716729
proposal.contribute_inputs(vec![selected_input])?.commit_inputs().save(persister)?;
717-
self.apply_fee_range(proposal, persister).await
730+
self.apply_fee_range(proposal, persister, relay_manager).await
718731
}
719732

720733
async fn apply_fee_range(
721734
&self,
722735
proposal: Receiver<WantsFeeRange>,
723736
persister: &ReceiverPersister,
737+
relay_manager: &mut RelayManager,
724738
) -> Result<()> {
725739
let proposal = proposal.apply_fee_range(None, self.config.max_fee_rate).save(persister)?;
726-
self.finalize_proposal(proposal, persister).await
740+
self.finalize_proposal(proposal, persister, relay_manager).await
727741
}
728742

729743
async fn finalize_proposal(
730744
&self,
731745
proposal: Receiver<ProvisionalProposal>,
732746
persister: &ReceiverPersister,
747+
relay_manager: &mut RelayManager,
733748
) -> Result<()> {
734749
let wallet = self.wallet();
735750
let proposal = proposal
@@ -739,16 +754,19 @@ impl App {
739754
.map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))
740755
})
741756
.save(persister)?;
742-
self.send_payjoin_proposal(proposal, persister).await
757+
self.send_payjoin_proposal(proposal, persister, relay_manager).await
743758
}
744759

745760
async fn send_payjoin_proposal(
746761
&self,
747762
proposal: Receiver<PayjoinProposal>,
748763
persister: &ReceiverPersister,
764+
relay_manager: &mut RelayManager,
749765
) -> Result<()> {
750766
let (req, ohttp_ctx) = proposal
751-
.create_post_request(self.unwrap_relay_or_else_fetch(None::<&str>).await?.as_str())
767+
.create_post_request(
768+
self.unwrap_relay_or_else_fetch(None::<&str>, relay_manager).await?.as_str(),
769+
)
752770
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
753771
let res = self.post_request(req).await?;
754772
let payjoin_psbt = proposal.psbt().clone();
@@ -813,14 +831,13 @@ impl App {
813831
async fn unwrap_relay_or_else_fetch(
814832
&self,
815833
directory: Option<impl payjoin::IntoUrl>,
834+
relay_manager: &mut RelayManager,
816835
) -> Result<url::Url> {
817836
let directory = directory.map(|url| url.into_url()).transpose()?;
818-
let selected_relay =
819-
self.relay_manager.lock().expect("Lock should not be poisoned").get_selected_relay();
820-
let ohttp_relay = match selected_relay {
837+
let ohttp_relay = match relay_manager.get_selected_relay() {
821838
Some(relay) => relay,
822839
None =>
823-
unwrap_ohttp_keys_or_else_fetch(&self.config, directory, self.relay_manager.clone())
840+
unwrap_ohttp_keys_or_else_fetch(&self.config, directory, relay_manager)
824841
.await?
825842
.relay_url,
826843
};
@@ -832,9 +849,11 @@ impl App {
832849
&self,
833850
session: Receiver<HasReplyableError>,
834851
persister: &ReceiverPersister,
852+
relay_manager: &mut RelayManager,
835853
) -> Result<()> {
836-
let (err_req, err_ctx) = session
837-
.create_error_request(self.unwrap_relay_or_else_fetch(None::<&str>).await?.as_str())?;
854+
let (err_req, err_ctx) = session.create_error_request(
855+
self.unwrap_relay_or_else_fetch(None::<&str>, relay_manager).await?.as_str(),
856+
)?;
838857

839858
let err_response = match self.post_request(err_req).await {
840859
Ok(response) => response,

payjoin-cli/src/app/v2/ohttp.rs

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::sync::{Arc, Mutex};
2-
31
use anyhow::{anyhow, Result};
42

53
use super::Config;
@@ -30,7 +28,7 @@ pub(crate) struct ValidatedOhttpKeys {
3028
pub(crate) async fn unwrap_ohttp_keys_or_else_fetch(
3129
config: &Config,
3230
directory: Option<url::Url>,
33-
relay_manager: Arc<Mutex<RelayManager>>,
31+
relay_manager: &mut RelayManager,
3432
) -> Result<ValidatedOhttpKeys> {
3533
if let Some(ohttp_keys) = config.v2()?.ohttp_keys.clone() {
3634
println!("Using OHTTP Keys from config");
@@ -47,15 +45,14 @@ pub(crate) async fn unwrap_ohttp_keys_or_else_fetch(
4745
async fn fetch_ohttp_keys(
4846
config: &Config,
4947
directory: Option<url::Url>,
50-
relay_manager: Arc<Mutex<RelayManager>>,
48+
relay_manager: &mut RelayManager,
5149
) -> Result<ValidatedOhttpKeys> {
5250
use payjoin::bitcoin::secp256k1::rand::prelude::SliceRandom;
5351
let payjoin_directory = directory.unwrap_or(config.v2()?.pj_directory.clone());
5452
let relays = config.v2()?.ohttp_relays.clone();
5553

5654
loop {
57-
let failed_relays =
58-
relay_manager.lock().expect("Lock should not be poisoned").get_failed_relays();
55+
let failed_relays = relay_manager.get_failed_relays();
5956

6057
let remaining_relays: Vec<_> =
6158
relays.iter().filter(|r| !failed_relays.contains(r)).cloned().collect();
@@ -70,10 +67,7 @@ async fn fetch_ohttp_keys(
7067
None => return Err(anyhow!("Failed to select from remaining relays")),
7168
};
7269

73-
relay_manager
74-
.lock()
75-
.expect("Lock should not be poisoned")
76-
.set_selected_relay(selected_relay.clone());
70+
relay_manager.set_selected_relay(selected_relay.clone());
7771

7872
let ohttp_keys = {
7973
#[cfg(feature = "_manual-tls")]
@@ -106,10 +100,7 @@ async fn fetch_ohttp_keys(
106100
}
107101
Err(e) => {
108102
tracing::debug!("Failed to connect to relay: {selected_relay}, {e:?}");
109-
relay_manager
110-
.lock()
111-
.expect("Lock should not be poisoned")
112-
.add_failed_relay(selected_relay);
103+
relay_manager.add_failed_relay(selected_relay);
113104
}
114105
}
115106
}

0 commit comments

Comments
 (0)