@@ -37,13 +37,9 @@ pub struct ScidWithPeer {
3737
3838impl ScidWithPeer {
3939 pub fn new (
40- scid : u64 , peer_id : PublicKey ,
40+ scid : u64 , peer_id : PublicKey , policy : FeePolicy ,
4141 ) -> Self {
42- Self {
43- scid,
44- peer_id,
45- policy : FeePolicy :: Flat ( FeeTier :: Standard ) ,
46- }
42+ Self { scid, peer_id, policy }
4743 }
4844
4945 pub fn store_key ( & self ) -> String {
@@ -73,6 +69,7 @@ pub struct ScidStore<L: Deref, KV: Deref + Clone>
7369where L :: Target : Logger , KV :: Target : KVStoreSync {
7470 peer_by_scid : RwLock < HashMap < u64 , PublicKey > > ,
7571 scid_by_peer : RwLock < HashMap < PublicKey , u64 > > ,
72+ policy_by_peer : RwLock < HashMap < PublicKey , FeePolicy > > ,
7673 kv_store : KV ,
7774 logger : L
7875}
@@ -112,7 +109,11 @@ where L::Target: Logger, KV::Target: KVStoreSync {
112109 let scid_by_peer =
113110 RwLock :: new ( HashMap :: from_iter ( scids. iter ( ) . map ( |obj| ( obj. peer_id ( ) , obj. scid ( ) ) ) ) ) ;
114111
115- Ok ( Self { peer_by_scid, scid_by_peer, kv_store, logger } )
112+ let policy_by_peer = RwLock :: new ( HashMap :: from_iter (
113+ scids. iter ( ) . map ( |obj| ( obj. peer_id ( ) , obj. policy ( ) . clone ( ) ) ) ,
114+ ) ) ;
115+
116+ Ok ( Self { peer_by_scid, scid_by_peer, policy_by_peer, kv_store, logger } )
116117 }
117118
118119 pub ( crate ) fn insert ( & self , scid : ScidWithPeer ) -> Result < bool , io:: Error > {
@@ -125,8 +126,10 @@ where L::Target: Logger, KV::Target: KVStoreSync {
125126 // Then insert into the maps
126127 let mut locked_peer_by_scid = self . peer_by_scid . write ( ) . unwrap ( ) ;
127128 let mut locked_scid_by_peer = self . scid_by_peer . write ( ) . unwrap ( ) ;
129+ let mut locked_policy_by_peer = self . policy_by_peer . write ( ) . unwrap ( ) ;
128130 let updated = locked_peer_by_scid. insert ( scid. scid ( ) , scid. peer_id ( ) . clone ( ) ) . is_some ( ) ;
129131 locked_scid_by_peer. insert ( scid. peer_id ( ) . clone ( ) , scid. scid ( ) ) ;
132+ locked_policy_by_peer. insert ( scid. peer_id ( ) . clone ( ) , scid. policy ( ) . clone ( ) ) ;
130133
131134 log_info ! (
132135 self . logger,
@@ -142,10 +145,12 @@ where L::Target: Logger, KV::Target: KVStoreSync {
142145 pub ( crate ) fn remove ( & self , scid : u64 ) -> Result < ( ) , io:: Error > {
143146 let mut locked_peer_by_scid = self . peer_by_scid . write ( ) . unwrap ( ) ;
144147 let mut locked_scid_by_peer = self . scid_by_peer . write ( ) . unwrap ( ) ;
148+ let mut locked_policy_by_peer = self . policy_by_peer . write ( ) . unwrap ( ) ;
145149
146150 let removed = locked_peer_by_scid. remove ( & scid) ;
147151 if let Some ( peer_id) = removed {
148152 locked_scid_by_peer. remove ( & peer_id) ;
153+ locked_policy_by_peer. remove ( & peer_id) ;
149154 let store_key = utils:: to_string ( & scid. to_be_bytes ( ) ) ;
150155 self . kv_store
151156 . remove ( INTERCEPT_SCID_STORE_PERSISTENCE_PRIMARY_NAMESPACE , INTERCEPT_SCID_STORE_PERSISTENCE_SECONDARY_NAMESPACE , & store_key, false )
@@ -182,10 +187,7 @@ where L::Target: Logger, KV::Target: KVStoreSync {
182187 pub fn add_intercepted_scid (
183188 & self , scid : u64 , peer_id : PublicKey ,
184189 ) -> Result < bool , io:: Error > {
185- let scid = ScidWithPeer :: new (
186- scid,
187- peer_id,
188- ) ;
190+ let scid = ScidWithPeer :: new ( scid, peer_id, FeePolicy :: Flat ( FeeTier :: Standard ) ) ;
189191 self . insert ( scid)
190192 }
191193
@@ -212,6 +214,10 @@ where L::Target: Logger, KV::Target: KVStoreSync {
212214 ) ;
213215 result
214216 }
217+
218+ pub fn get_policy ( & self , peer_id : & PublicKey ) -> Option < FeePolicy > {
219+ self . policy_by_peer . read ( ) . unwrap ( ) . get ( peer_id) . cloned ( )
220+ }
215221}
216222
217223#[ cfg( test) ]
@@ -243,11 +249,11 @@ mod tests {
243249
244250 #[ test]
245251 fn round_trips_with_policy ( ) {
246- let record = ScidWithPeer :: new ( 42 , test_peer ( ) ) ;
252+ let record = ScidWithPeer :: new ( 42 , test_peer ( ) , FeePolicy :: Flat ( FeeTier :: ZeroFee ) ) ;
247253 let bytes = record. encode ( ) ;
248254 let decoded = ScidWithPeer :: read ( & mut & bytes[ ..] ) . unwrap ( ) ;
249255 assert_eq ! ( record, decoded) ;
250- assert_eq ! ( decoded. policy( ) , & FeePolicy :: Flat ( FeeTier :: Standard ) ) ;
256+ assert_eq ! ( decoded. policy( ) , & FeePolicy :: Flat ( FeeTier :: ZeroFee ) ) ;
251257 }
252258
253259 #[ test]
@@ -259,4 +265,50 @@ mod tests {
259265 assert_eq ! ( decoded. peer_id( ) , test_peer( ) ) ;
260266 assert_eq ! ( decoded. policy( ) , & FeePolicy :: Flat ( FeeTier :: Standard ) ) ;
261267 }
268+
269+ use bitcoin:: secp256k1:: { Secp256k1 , SecretKey } ;
270+ use lightning:: util:: test_utils:: { TestLogger , TestStore } ;
271+ use std:: sync:: Arc ;
272+
273+ fn other_peer ( ) -> PublicKey {
274+ PublicKey :: from_secret_key ( & Secp256k1 :: new ( ) , & SecretKey :: from_slice ( & [ 0x24 ; 32 ] ) . unwrap ( ) )
275+ }
276+
277+ fn test_store ( ) -> ScidStore < Arc < TestLogger > , Arc < TestStore > > {
278+ ScidStore :: new ( Arc :: new ( TestStore :: new ( false ) ) , Arc :: new ( TestLogger :: new ( ) ) ) . unwrap ( )
279+ }
280+
281+ #[ test]
282+ fn insert_with_policy_then_get_policy_returns_it ( ) {
283+ let store = test_store ( ) ;
284+ store
285+ . insert ( ScidWithPeer :: new ( 42 , test_peer ( ) , FeePolicy :: Flat ( FeeTier :: ZeroFee ) ) )
286+ . unwrap ( ) ;
287+
288+ assert_eq ! ( store. get_policy( & test_peer( ) ) , Some ( FeePolicy :: Flat ( FeeTier :: ZeroFee ) ) ) ;
289+ assert_eq ! ( store. get_policy( & other_peer( ) ) , None ) ;
290+ }
291+
292+ #[ test]
293+ fn load_rebuilds_policy_map ( ) {
294+ let kv_store = Arc :: new ( TestStore :: new ( false ) ) ;
295+ {
296+ let store =
297+ ScidStore :: new ( kv_store. clone ( ) , Arc :: new ( TestLogger :: new ( ) ) ) . unwrap ( ) ;
298+ store
299+ . insert ( ScidWithPeer :: new ( 42 , test_peer ( ) , FeePolicy :: Flat ( FeeTier :: ZeroFee ) ) )
300+ . unwrap ( ) ;
301+ }
302+
303+ let reloaded = ScidStore :: new ( kv_store, Arc :: new ( TestLogger :: new ( ) ) ) . unwrap ( ) ;
304+ assert_eq ! ( reloaded. get_policy( & test_peer( ) ) , Some ( FeePolicy :: Flat ( FeeTier :: ZeroFee ) ) ) ;
305+ }
306+
307+ #[ test]
308+ fn default_record_resolves_to_standard_policy ( ) {
309+ let store = test_store ( ) ;
310+ store. add_intercepted_scid ( 42 , test_peer ( ) ) . unwrap ( ) ;
311+
312+ assert_eq ! ( store. get_policy( & test_peer( ) ) , Some ( FeePolicy :: Flat ( FeeTier :: Standard ) ) ) ;
313+ }
262314}
0 commit comments