Skip to content

Commit 1b4a359

Browse files
committed
Add a blinded-payment-path override to test utilities
Let integration tests force specific blinded payment paths so LSPS2 BOLT12 routing behavior can be exercised deterministically. Co-Authored-By: HAL 9000
1 parent a34d093 commit 1b4a359

1 file changed

Lines changed: 34 additions & 1 deletion

File tree

lightning/src/util/test_utils.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,23 @@ impl chaininterface::FeeEstimator for TestFeeEstimator {
168168
}
169169
}
170170

171+
/// Override closure type for [`TestRouter::override_create_blinded_payment_paths`].
172+
///
173+
/// This closure is called instead of the default [`Router::create_blinded_payment_paths`]
174+
/// implementation when set, receiving the actual [`ReceiveTlvs`] so tests can construct custom
175+
/// blinded payment paths using the same TLVs the caller generated.
176+
pub type BlindedPaymentPathOverrideFn = Box<
177+
dyn Fn(
178+
PublicKey,
179+
ReceiveAuthKey,
180+
Vec<ChannelDetails>,
181+
ReceiveTlvs,
182+
Option<u64>,
183+
) -> Result<Vec<BlindedPaymentPath>, ()>
184+
+ Send
185+
+ Sync,
186+
>;
187+
171188
pub struct TestRouter<'a> {
172189
pub router: DefaultRouter<
173190
Arc<NetworkGraph<&'a TestLogger>>,
@@ -181,6 +198,7 @@ pub struct TestRouter<'a> {
181198
pub next_routes: Mutex<VecDeque<(RouteParameters, Option<Result<Route, &'static str>>)>>,
182199
pub next_blinded_payment_paths: Mutex<Vec<BlindedPaymentPath>>,
183200
pub next_payment_context_metadata: Mutex<Option<BTreeMap<u64, Vec<u8>>>>,
201+
pub override_create_blinded_payment_paths: Mutex<Option<BlindedPaymentPathOverrideFn>>,
184202
pub scorer: &'a RwLock<TestScorer>,
185203
}
186204

@@ -193,6 +211,7 @@ impl<'a> TestRouter<'a> {
193211
let next_routes = Mutex::new(VecDeque::new());
194212
let next_blinded_payment_paths = Mutex::new(Vec::new());
195213
let next_payment_context_metadata = Mutex::new(None);
214+
let override_create_blinded_payment_paths = Mutex::new(None);
196215
Self {
197216
router: DefaultRouter::new(
198217
Arc::clone(&network_graph),
@@ -205,6 +224,7 @@ impl<'a> TestRouter<'a> {
205224
next_routes,
206225
next_blinded_payment_paths,
207226
next_payment_context_metadata,
227+
override_create_blinded_payment_paths,
208228
scorer,
209229
}
210230
}
@@ -338,6 +358,12 @@ impl<'a> Router for TestRouter<'a> {
338358
PaymentContext::Bolt12Refund(ctx) => ctx.payment_metadata = Some(metadata),
339359
}
340360
}
361+
if let Some(override_fn) =
362+
self.override_create_blinded_payment_paths.lock().unwrap().as_ref()
363+
{
364+
return override_fn(recipient, local_node_receive_key, first_hops, tlvs, amount_msats);
365+
}
366+
341367
let mut expected_paths = self.next_blinded_payment_paths.lock().unwrap();
342368
if expected_paths.is_empty() {
343369
self.router.create_blinded_payment_paths(
@@ -383,6 +409,7 @@ pub enum TestMessageRouterInternal<'a> {
383409
pub struct TestMessageRouter<'a> {
384410
pub inner: TestMessageRouterInternal<'a>,
385411
pub peers_override: Mutex<Vec<PublicKey>>,
412+
pub forward_node_scid_override: Mutex<HashMap<PublicKey, u64>>,
386413
}
387414

388415
impl<'a> TestMessageRouter<'a> {
@@ -395,6 +422,7 @@ impl<'a> TestMessageRouter<'a> {
395422
entropy_source,
396423
)),
397424
peers_override: Mutex::new(Vec::new()),
425+
forward_node_scid_override: Mutex::new(new_hash_map()),
398426
}
399427
}
400428

@@ -407,6 +435,7 @@ impl<'a> TestMessageRouter<'a> {
407435
entropy_source,
408436
)),
409437
peers_override: Mutex::new(Vec::new()),
438+
forward_node_scid_override: Mutex::new(new_hash_map()),
410439
}
411440
}
412441
}
@@ -438,9 +467,13 @@ impl<'a> MessageRouter for TestMessageRouter<'a> {
438467
{
439468
let peers_override = self.peers_override.lock().unwrap();
440469
if !peers_override.is_empty() {
470+
let scid_override = self.forward_node_scid_override.lock().unwrap();
441471
let peer_override_nodes: Vec<_> = peers_override
442472
.iter()
443-
.map(|pk| MessageForwardNode { node_id: *pk, short_channel_id: None })
473+
.map(|pk| MessageForwardNode {
474+
node_id: *pk,
475+
short_channel_id: scid_override.get(pk).copied(),
476+
})
444477
.collect();
445478
peers = peer_override_nodes;
446479
}

0 commit comments

Comments
 (0)