@@ -168,6 +168,23 @@ impl chaininterface::FeeEstimator for TestFeeEstimator {
168168 }
169169}
170170
171+ /// Override closure type for [`TestRouter::override_create_blinded_payment_paths`].
172+ ///
173+ /// This closure is called instead of the default [`Router::create_blinded_payment_paths`]
174+ /// implementation when set, receiving the actual [`ReceiveTlvs`] so tests can construct custom
175+ /// blinded payment paths using the same TLVs the caller generated.
176+ pub type BlindedPaymentPathOverrideFn = Box <
177+ dyn Fn (
178+ PublicKey ,
179+ ReceiveAuthKey ,
180+ Vec < ChannelDetails > ,
181+ ReceiveTlvs ,
182+ Option < u64 > ,
183+ ) -> Result < Vec < BlindedPaymentPath > , ( ) >
184+ + Send
185+ + Sync ,
186+ > ;
187+
171188pub struct TestRouter < ' a > {
172189 pub router : DefaultRouter <
173190 Arc < NetworkGraph < & ' a TestLogger > > ,
@@ -181,6 +198,7 @@ pub struct TestRouter<'a> {
181198 pub next_routes : Mutex < VecDeque < ( RouteParameters , Option < Result < Route , & ' static str > > ) > > ,
182199 pub next_blinded_payment_paths : Mutex < Vec < BlindedPaymentPath > > ,
183200 pub next_payment_context_metadata : Mutex < Option < BTreeMap < u64 , Vec < u8 > > > > ,
201+ pub override_create_blinded_payment_paths : Mutex < Option < BlindedPaymentPathOverrideFn > > ,
184202 pub scorer : & ' a RwLock < TestScorer > ,
185203}
186204
@@ -193,6 +211,7 @@ impl<'a> TestRouter<'a> {
193211 let next_routes = Mutex :: new ( VecDeque :: new ( ) ) ;
194212 let next_blinded_payment_paths = Mutex :: new ( Vec :: new ( ) ) ;
195213 let next_payment_context_metadata = Mutex :: new ( None ) ;
214+ let override_create_blinded_payment_paths = Mutex :: new ( None ) ;
196215 Self {
197216 router : DefaultRouter :: new (
198217 Arc :: clone ( & network_graph) ,
@@ -205,6 +224,7 @@ impl<'a> TestRouter<'a> {
205224 next_routes,
206225 next_blinded_payment_paths,
207226 next_payment_context_metadata,
227+ override_create_blinded_payment_paths,
208228 scorer,
209229 }
210230 }
@@ -338,6 +358,12 @@ impl<'a> Router for TestRouter<'a> {
338358 PaymentContext :: Bolt12Refund ( ctx) => ctx. payment_metadata = Some ( metadata) ,
339359 }
340360 }
361+ if let Some ( override_fn) =
362+ self . override_create_blinded_payment_paths . lock ( ) . unwrap ( ) . as_ref ( )
363+ {
364+ return override_fn ( recipient, local_node_receive_key, first_hops, tlvs, amount_msats) ;
365+ }
366+
341367 let mut expected_paths = self . next_blinded_payment_paths . lock ( ) . unwrap ( ) ;
342368 if expected_paths. is_empty ( ) {
343369 self . router . create_blinded_payment_paths (
@@ -383,6 +409,7 @@ pub enum TestMessageRouterInternal<'a> {
383409pub struct TestMessageRouter < ' a > {
384410 pub inner : TestMessageRouterInternal < ' a > ,
385411 pub peers_override : Mutex < Vec < PublicKey > > ,
412+ pub forward_node_scid_override : Mutex < HashMap < PublicKey , u64 > > ,
386413}
387414
388415impl < ' a > TestMessageRouter < ' a > {
@@ -395,6 +422,7 @@ impl<'a> TestMessageRouter<'a> {
395422 entropy_source,
396423 ) ) ,
397424 peers_override : Mutex :: new ( Vec :: new ( ) ) ,
425+ forward_node_scid_override : Mutex :: new ( new_hash_map ( ) ) ,
398426 }
399427 }
400428
@@ -407,6 +435,7 @@ impl<'a> TestMessageRouter<'a> {
407435 entropy_source,
408436 ) ) ,
409437 peers_override : Mutex :: new ( Vec :: new ( ) ) ,
438+ forward_node_scid_override : Mutex :: new ( new_hash_map ( ) ) ,
410439 }
411440 }
412441}
@@ -438,9 +467,13 @@ impl<'a> MessageRouter for TestMessageRouter<'a> {
438467 {
439468 let peers_override = self . peers_override . lock ( ) . unwrap ( ) ;
440469 if !peers_override. is_empty ( ) {
470+ let scid_override = self . forward_node_scid_override . lock ( ) . unwrap ( ) ;
441471 let peer_override_nodes: Vec < _ > = peers_override
442472 . iter ( )
443- . map ( |pk| MessageForwardNode { node_id : * pk, short_channel_id : None } )
473+ . map ( |pk| MessageForwardNode {
474+ node_id : * pk,
475+ short_channel_id : scid_override. get ( pk) . copied ( ) ,
476+ } )
444477 . collect ( ) ;
445478 peers = peer_override_nodes;
446479 }
0 commit comments