Skip to content

Commit 837ff11

Browse files
[Temporary] Removes Vec from MaybeInputsOwned non-blocking interface
1 parent 7a0c45f commit 837ff11

3 files changed

Lines changed: 115 additions & 71 deletions

File tree

payjoin/src/core/receive/mod.rs

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,67 @@ impl<'a> From<&'a InputPair> for InternalInputPair<'a> {
228228
fn from(pair: &'a InputPair) -> Self { Self { psbtin: &pair.psbtin, txin: &pair.txin } }
229229
}
230230

231+
pub struct Validator<T> {
232+
items: Vec<T>,
233+
positives: Vec<T>,
234+
negatives: Vec<T>,
235+
finalized: bool,
236+
}
237+
238+
impl<T> Validator<T>
239+
where
240+
T: Eq + Clone,
241+
{
242+
fn new(items: Vec<T>) -> Self {
243+
Validator { items, positives: vec![], negatives: vec![], finalized: false }
244+
}
245+
246+
pub fn is_finalized(&self) -> bool { self.finalized }
247+
248+
pub fn get_reference(&self) -> UntaggedReference<T> { UntaggedReference(self.items[0].clone()) }
249+
250+
pub fn mark_reference(&mut self, tagged_reference: TaggedReference<T>) -> Result<bool, Error> {
251+
if self.finalized {
252+
return Err(ImplementationError::from(
253+
"Validation state machine already validated, no references left to mark",
254+
)
255+
.into());
256+
} else if self.items[0] != tagged_reference.0 {
257+
return Err(ImplementationError::from(
258+
"Incorrect reference returned for current validation state",
259+
)
260+
.into());
261+
}
262+
263+
if tagged_reference.1 {
264+
self.positives.push(self.items.remove(0));
265+
} else {
266+
self.negatives.push(self.items.remove(0));
267+
}
268+
if self.items.is_empty() {
269+
self.finalized = true
270+
}
271+
Ok(self.finalized)
272+
}
273+
274+
fn get_positives(&self) -> Vec<T> { self.positives.clone() }
275+
}
276+
277+
pub struct UntaggedReference<T>(T);
278+
279+
pub struct TaggedReference<T>(T, bool);
280+
281+
impl<T> UntaggedReference<T>
282+
where
283+
T: Clone,
284+
{
285+
pub fn value(&self) -> T { self.0.clone() }
286+
287+
pub fn mark(&self, is_owned_seen: bool) -> TaggedReference<T> {
288+
TaggedReference(self.0.clone(), is_owned_seen)
289+
}
290+
}
291+
231292
/// Validate the payload of a Payjoin request for PSBT and Params sanity
232293
pub(crate) fn parse_payload(
233294
base64: &str,
@@ -399,15 +460,15 @@ impl OriginalPayload {
399460
&self,
400461
is_owned: &mut impl FnMut(&Script) -> Result<bool, ImplementationError>,
401462
) -> Result<(), Error> {
402-
let inputs_owned_result = self.gather_inputs_owned_result(is_owned)?;
403-
self.process_inputs_owned_result(inputs_owned_result)
463+
let validator = self.validate_inputs_not_owned(is_owned)?;
464+
self.process_inputs_owned_validator(validator)
404465
}
405466

406467
/// Expose extracting input txout scriptpubkey. This is intended to be used for callers
407468
/// that require non-blocking calls for checking if inputs are owned, it enables
408469
/// callers to check ownership in separate call outside of `check_inputs_not_owned`
409470
/// and return the result
410-
pub fn extract_input_scripts(&self) -> Result<Vec<ScriptBuf>, Error> {
471+
fn extract_input_scripts(&self) -> Result<Vec<ScriptBuf>, Error> {
411472
let mut err: Result<(), Error> = Ok(());
412473
let scripts: Vec<ScriptBuf> = self
413474
.psbt
@@ -424,58 +485,37 @@ impl OriginalPayload {
424485
Ok(scripts)
425486
}
426487

427-
/// Utility function to run callback to check if inputs are owned and gather the result
428-
pub fn gather_inputs_owned_result(
488+
/// Utility function to check inputs owned using Validator
489+
fn validate_inputs_not_owned(
429490
&self,
430491
is_owned: &mut impl FnMut(&Script) -> Result<bool, ImplementationError>,
431-
) -> Result<Vec<(ScriptBuf, bool)>, Error> {
432-
self.psbt
433-
.input_pairs()
434-
.map(|input| match input.previous_txout() {
435-
Ok(txout) => {
436-
let script = txout.script_pubkey.to_owned();
437-
let is_owned_result = is_owned(&script)?;
438-
Ok((script, is_owned_result))
439-
}
440-
Err(e) => Err(InternalPayloadError::PrevTxOut(e).into()),
441-
})
442-
.collect()
492+
) -> Result<Validator<ScriptBuf>, Error> {
493+
let input_scripts = self.extract_input_scripts()?;
494+
let mut validator = Validator::new(input_scripts);
495+
while !validator.is_finalized() {
496+
let untagged_reference = validator.get_reference();
497+
let input_script = untagged_reference.value();
498+
let input_script_owned = is_owned(&input_script)?;
499+
let tagged_reference = untagged_reference.mark(input_script_owned);
500+
validator.mark_reference(tagged_reference)?;
501+
}
502+
Ok(validator)
443503
}
444504

445-
/// Check that the original PSBT has no receiver-owned inputs.
505+
/// Process inputs owned validator
446506
///
447507
/// An attacker can try to spend the receiver's own inputs. This check prevents that.
448-
pub fn process_inputs_owned_result(
508+
pub fn process_inputs_owned_validator(
449509
&self,
450-
inputs_owned_result: Vec<(ScriptBuf, bool)>,
510+
validator: Validator<ScriptBuf>,
451511
) -> Result<(), Error> {
452-
let mut err: Result<(), Error> = Ok(());
453-
if let Some(e) = self
454-
.psbt
455-
.input_pairs()
456-
.scan(&mut err, |err, input| match input.previous_txout() {
457-
Ok(txout) => Some(txout.script_pubkey.to_owned()),
458-
Err(e) => {
459-
**err = Err(InternalPayloadError::PrevTxOut(e).into());
460-
None
461-
}
462-
})
463-
.find_map(|script| {
464-
match inputs_owned_result
465-
.iter()
466-
.find(|(is_owned_script, _)| script == *is_owned_script)
467-
{
468-
Some((_, false)) => None,
469-
Some((_, true)) => Some(InternalPayloadError::InputOwned(script).into()),
470-
None => Some(Error::Implementation(ImplementationError::from(
471-
format!("Input is owned result missing for script {script}").as_str(),
472-
))),
473-
}
474-
})
475-
{
476-
return Err(e);
512+
if !validator.is_finalized() {
513+
return Err(ImplementationError::from("Validator has not finished validation").into());
514+
}
515+
let owned_inputs = validator.get_positives();
516+
if !owned_inputs.is_empty() {
517+
return Err(InternalPayloadError::InputOwned(owned_inputs[0].clone()).into());
477518
}
478-
err?;
479519
Ok(())
480520
}
481521

payjoin/src/core/receive/v1/mod.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,9 @@ impl UncheckedOriginalPayload {
153153
/// to extract the signed original PSBT to schedule a fallback in case the Payjoin process fails.
154154
///
155155
/// Call [`Self::check_inputs_not_owned`] to proceed. If caller needs to use non-blocking calls
156-
/// for checking inputs are not owned call [`Self::extract_input_scripts`] to extract the scripts
157-
/// to be checked and [`Self::process_inputs_owned_result`] to return the result and proceed to the
156+
/// for checking inputs are not owned call [`Self::get_inputs_owned_validator`] to get
157+
/// a Validator instance that will provide the input scripts to be checked
158+
/// and [`Self::process_inputs_owned_validator`] to return a finalized Validator and proceed to the
158159
/// next state.
159160
#[derive(Debug, Clone)]
160161
pub struct MaybeInputsOwned {
@@ -178,25 +179,26 @@ impl MaybeInputsOwned {
178179
self,
179180
is_owned: &mut impl FnMut(&Script) -> Result<bool, ImplementationError>,
180181
) -> Result<MaybeInputsSeen, Error> {
181-
let inputs_owned_result = self.original.gather_inputs_owned_result(is_owned)?;
182-
self.process_inputs_owned_result(inputs_owned_result)
182+
let validator = self.original.validate_inputs_not_owned(is_owned)?;
183+
self.process_inputs_owned_validator(validator)
183184
}
184185

185-
/// Extracts the inputs txout script pubkeys
186+
/// Provides Validator for checking whether input is owned
186187
///
187188
/// Use this for using non-blocking calls to check whether inputs are owned
188-
pub fn extract_input_scripts(&self) -> Result<Vec<ScriptBuf>, Error> {
189-
self.original.extract_input_scripts()
189+
pub fn get_inputs_owned_validator(&self) -> Result<Validator<ScriptBuf>, Error> {
190+
let input_scripts = self.original.extract_input_scripts()?;
191+
Ok(Validator::new(input_scripts))
190192
}
191193

192-
/// Process result of whether the original PSBT has no receiver-owned inputs.
194+
/// Provides Validator for checking whether input is owned
193195
///
194196
/// Use this for using non-blocking calls to check whether inputs are owned
195-
pub fn process_inputs_owned_result(
197+
pub fn process_inputs_owned_validator(
196198
self,
197-
inputs_owned_result: Vec<(ScriptBuf, bool)>,
199+
validator: Validator<ScriptBuf>,
198200
) -> Result<MaybeInputsSeen, Error> {
199-
self.original.process_inputs_owned_result(inputs_owned_result)?;
201+
self.original.process_inputs_owned_validator(validator)?;
200202
Ok(MaybeInputsSeen { original: self.original })
201203
}
202204
}

payjoin/src/core/receive/v2/mod.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ use crate::persist::{
5757
MaybeFatalOrSuccessTransition, MaybeFatalTransition, MaybeFatalTransitionWithNoResults,
5858
MaybeSuccessTransition, MaybeTransientTransition, NextStateTransition,
5959
};
60-
use crate::receive::{parse_payload, InputPair, OriginalPayload, PsbtContext};
60+
use crate::receive::{parse_payload, InputPair, OriginalPayload, PsbtContext, Validator};
6161
use crate::time::Time;
6262
use crate::uri::ShortId;
6363
use crate::{ImplementationError, IntoUrl, IntoUrlError, Request, Version};
@@ -665,9 +665,10 @@ pub struct MaybeInputsOwned {
665665
/// to extract the signed original PSBT to schedule a fallback in case the Payjoin process fails.
666666
///
667667
/// Call [`Receiver<MaybeInputsOwned>::check_inputs_not_owned`] to proceed. If caller needs to use
668-
/// non-blocking calls for checking inputs are not owned call [`Receiver<MaybeInputsOwned>::extract_input_scripts`]
669-
/// to extract the scripts to be checked and [`Receiver<MaybeInputsOwned>::process_inputs_owned_result`]
670-
/// to return the result and proceed to the next state.
668+
/// non-blocking calls for checking inputs are not owned call
669+
/// [`Receiver<MaybeInputsOwned>::get_inputs_owned_validator`] to get a Validator instance that
670+
/// will provide the input scripts to be checked and [`Receiver<MaybeInputsOwned>::process_inputs_owned_validator`]
671+
/// to return a finalized Validator and proceed to the next state.
671672
impl Receiver<MaybeInputsOwned> {
672673
/// Extracts the original transaction received from the sender.
673674
///
@@ -690,9 +691,9 @@ impl Receiver<MaybeInputsOwned> {
690691
Error,
691692
Receiver<HasReplyableError>,
692693
> {
693-
let inputs_owned_result = self.original.gather_inputs_owned_result(is_owned);
694-
match inputs_owned_result {
695-
Ok(inputs_owned_result) => self.process_inputs_owned_result(inputs_owned_result),
694+
let validator_result = self.original.validate_inputs_not_owned(is_owned);
695+
match validator_result {
696+
Ok(validator) => self.process_inputs_owned_validator(validator),
696697
Err(e) => match e {
697698
Error::Implementation(_) => MaybeFatalTransition::transient(e),
698699
_ => MaybeFatalTransition::replyable_error(
@@ -707,26 +708,27 @@ impl Receiver<MaybeInputsOwned> {
707708
}
708709
}
709710

710-
/// Extracts the inputs txout script pubkeys
711+
/// Provides Validator for checking whether input is owned
711712
///
712713
/// Use this for using non-blocking calls to check whether inputs are owned
713-
pub fn extract_input_scripts(&self) -> Result<Vec<ScriptBuf>, Error> {
714-
self.original.extract_input_scripts()
714+
pub fn get_inputs_owned_validator(&self) -> Result<Validator<ScriptBuf>, Error> {
715+
let input_scripts = self.original.extract_input_scripts()?;
716+
Ok(Validator::new(input_scripts))
715717
}
716718

717-
/// Process result of whether the original PSBT inputs are owned.
719+
/// Process Validator for whether the original PSBT inputs are owned.
718720
///
719721
/// Use this for using non-blocking calls to check whether inputs are owned
720-
pub fn process_inputs_owned_result(
722+
pub fn process_inputs_owned_validator(
721723
self,
722-
inputs_owned_result: Vec<(ScriptBuf, bool)>,
724+
validator: Validator<ScriptBuf>,
723725
) -> MaybeFatalTransition<
724726
SessionEvent,
725727
Receiver<MaybeInputsSeen>,
726728
Error,
727729
Receiver<HasReplyableError>,
728730
> {
729-
match self.state.original.process_inputs_owned_result(inputs_owned_result) {
731+
match self.state.original.process_inputs_owned_validator(validator) {
730732
Ok(inner) => inner,
731733
Err(e) => match e {
732734
Error::Implementation(_) => {

0 commit comments

Comments
 (0)