@@ -45,6 +45,7 @@ use lightning::events::HTLCHandlingFailureType;
4545use lightning:: ln:: channelmanager:: { AChannelManager , FailureCode , InterceptId } ;
4646use lightning:: ln:: msgs:: { ErrorAction , LightningError } ;
4747use lightning:: ln:: types:: ChannelId ;
48+ use lightning:: onion_message:: messenger:: OnionMessageInterceptor ;
4849use lightning:: util:: errors:: APIError ;
4950use lightning:: util:: logger:: Level ;
5051use lightning:: util:: ser:: Writeable ;
@@ -717,6 +718,7 @@ where
717718 total_pending_requests : AtomicUsize ,
718719 config : LSPS2ServiceConfig ,
719720 persistence_in_flight : AtomicUsize ,
721+ onion_message_interceptor : Option < Arc < dyn OnionMessageInterceptor + Send + Sync > > ,
720722}
721723
722724impl < CM : Deref , K : KVStore + Clone , T : BroadcasterInterface + Clone > LSPS2ServiceHandler < CM , K , T >
@@ -728,6 +730,7 @@ where
728730 per_peer_state : HashMap < PublicKey , Mutex < PeerState > > , pending_messages : Arc < MessageQueue > ,
729731 pending_events : Arc < EventQueue < K > > , channel_manager : CM , kv_store : K , tx_broadcaster : T ,
730732 config : LSPS2ServiceConfig ,
733+ onion_message_interceptor : Option < Arc < dyn OnionMessageInterceptor + Send + Sync > > ,
731734 ) -> Result < Self , lightning:: io:: Error > {
732735 let mut peer_by_intercept_scid = new_hash_map ( ) ;
733736 let mut peer_by_channel_id = new_hash_map ( ) ;
@@ -756,6 +759,17 @@ where
756759 }
757760 }
758761
762+ // Register all peers and SCIDs with active intercept SCIDs for onion message
763+ // interception, so that messages for offline peers are held rather than dropped.
764+ // Both peer-based and SCID-based registration are needed to support clients using
765+ // either pubkey or compact SCID encoding in their message blinded paths.
766+ if let Some ( ref interceptor) = onion_message_interceptor {
767+ for ( scid, node_id) in & peer_by_intercept_scid {
768+ interceptor. register_peer_for_interception ( * node_id) ;
769+ interceptor. register_scid_for_interception ( * scid, * node_id) ;
770+ }
771+ }
772+
759773 Ok ( Self {
760774 pending_messages,
761775 pending_events,
@@ -768,6 +782,7 @@ where
768782 kv_store,
769783 tx_broadcaster,
770784 config,
785+ onion_message_interceptor,
771786 } )
772787 }
773788
@@ -776,6 +791,33 @@ where
776791 & self . config
777792 }
778793
794+ /// Cleans up `peer_by_intercept_scid` entries for the given SCIDs, and deregisters the peer
795+ /// from onion message interception if they have no remaining active intercept SCIDs.
796+ fn cleanup_intercept_scids (
797+ & self , counterparty_node_id : & PublicKey , pruned_scids : & [ u64 ] , has_remaining_channels : bool ,
798+ ) {
799+ if pruned_scids. is_empty ( ) {
800+ return ;
801+ }
802+
803+ {
804+ let mut peer_by_intercept_scid = self . peer_by_intercept_scid . write ( ) . unwrap ( ) ;
805+ for scid in pruned_scids {
806+ peer_by_intercept_scid. remove ( scid) ;
807+ }
808+ }
809+
810+ if let Some ( ref interceptor) = self . onion_message_interceptor {
811+ for scid in pruned_scids {
812+ interceptor. deregister_scid_for_interception ( * scid) ;
813+ }
814+
815+ if !has_remaining_channels {
816+ interceptor. deregister_peer_for_interception ( counterparty_node_id) ;
817+ }
818+ }
819+ }
820+
779821 /// Returns whether the peer has any active LSPS2 requests.
780822 pub ( crate ) fn has_active_requests ( & self , counterparty_node_id : & PublicKey ) -> bool {
781823 let outer_state_lock = self . per_peer_state . read ( ) . unwrap ( ) ;
@@ -921,6 +963,14 @@ where
921963 peer_by_intercept_scid. insert ( intercept_scid, * counterparty_node_id) ;
922964 }
923965
966+ if let Some ( ref interceptor) = self . onion_message_interceptor {
967+ interceptor. register_peer_for_interception ( * counterparty_node_id) ;
968+ interceptor. register_scid_for_interception (
969+ intercept_scid,
970+ * counterparty_node_id,
971+ ) ;
972+ }
973+
924974 let outbound_jit_channel = OutboundJITChannel :: new (
925975 buy_request. payment_size_msat ,
926976 buy_request. opening_fee_params ,
@@ -990,17 +1040,17 @@ where
9901040 let event_queue_notifier = self . pending_events . notifier ( ) ;
9911041 let mut should_persist = None ;
9921042
993- if let Some ( counterparty_node_id) =
994- self . peer_by_intercept_scid . read ( ) . unwrap ( ) . get ( & intercept_scid)
995- {
1043+ let counterparty_node_id =
1044+ self . peer_by_intercept_scid . read ( ) . unwrap ( ) . get ( & intercept_scid) . copied ( ) ;
1045+ if let Some ( counterparty_node_id ) = counterparty_node_id {
9961046 let outer_state_lock = self . per_peer_state . read ( ) . unwrap ( ) ;
997- match outer_state_lock. get ( counterparty_node_id) {
1047+ match outer_state_lock. get ( & counterparty_node_id) {
9981048 Some ( inner_state_lock) => {
9991049 let mut peer_state = inner_state_lock. lock ( ) . unwrap ( ) ;
10001050 if let Some ( jit_channel) =
10011051 peer_state. outbound_channels_by_intercept_scid . get_mut ( & intercept_scid)
10021052 {
1003- should_persist = Some ( * counterparty_node_id) ;
1053+ should_persist = Some ( counterparty_node_id) ;
10041054 let htlc = InterceptedHTLC {
10051055 intercept_id,
10061056 expected_outbound_amount_msat,
@@ -1009,7 +1059,7 @@ where
10091059 match jit_channel. htlc_intercepted ( htlc) {
10101060 Ok ( Some ( HTLCInterceptedAction :: OpenChannel ( open_channel_params) ) ) => {
10111061 let event = LSPS2ServiceEvent :: OpenChannel {
1012- their_network_key : counterparty_node_id. clone ( ) ,
1062+ their_network_key : counterparty_node_id,
10131063 amt_to_forward_msat : open_channel_params. amt_to_forward_msat ,
10141064 opening_fee_msat : open_channel_params. opening_fee_msat ,
10151065 user_channel_id : jit_channel. user_channel_id ,
@@ -1021,7 +1071,7 @@ where
10211071 self . channel_manager . get_cm ( ) . forward_intercepted_htlc (
10221072 intercept_id,
10231073 & channel_id,
1024- * counterparty_node_id,
1074+ counterparty_node_id,
10251075 expected_outbound_amount_msat,
10261076 ) ?;
10271077 } ,
@@ -1038,7 +1088,7 @@ where
10381088 self . channel_manager . get_cm ( ) . forward_intercepted_htlc (
10391089 intercept_id,
10401090 & channel_id,
1041- * counterparty_node_id,
1091+ counterparty_node_id,
10421092 amount_to_forward_msat,
10431093 ) ?;
10441094 }
@@ -1051,7 +1101,13 @@ where
10511101 peer_state
10521102 . outbound_channels_by_intercept_scid
10531103 . remove ( & intercept_scid) ;
1054- // TODO: cleanup peer_by_intercept_scid
1104+ let has_remaining =
1105+ !peer_state. outbound_channels_by_intercept_scid . is_empty ( ) ;
1106+ self . cleanup_intercept_scids (
1107+ & counterparty_node_id,
1108+ & [ intercept_scid] ,
1109+ has_remaining,
1110+ ) ;
10551111 return Err ( APIError :: APIMisuseError { err : e. err } ) ;
10561112 } ,
10571113 }
@@ -1858,6 +1914,22 @@ where
18581914 debug_assert ! ( false ) ;
18591915 }
18601916 }
1917+ if future_opt. is_some ( ) {
1918+ // Clean up handler-level maps for the removed peer.
1919+ let removed_scids: Vec < u64 > = self
1920+ . peer_by_intercept_scid
1921+ . read ( )
1922+ . unwrap ( )
1923+ . iter ( )
1924+ . filter ( |( _, nid) | * * nid == counterparty_node_id)
1925+ . map ( |( scid, _) | * scid)
1926+ . collect ( ) ;
1927+ self . cleanup_intercept_scids ( & counterparty_node_id, & removed_scids, false ) ;
1928+ self . peer_by_channel_id
1929+ . write ( )
1930+ . unwrap ( )
1931+ . retain ( |_, node_id| * node_id != counterparty_node_id) ;
1932+ }
18611933 if let Some ( future) = future_opt {
18621934 future. await ?;
18631935 did_persist = true ;
0 commit comments