@@ -27,7 +27,7 @@ pub struct DefaultRoleManager {
2727 domain_matching_fn : Option < MatchingFn > ,
2828}
2929
30- #[ derive( Debug ) ]
30+ #[ derive( Clone , Debug ) ]
3131enum EdgeVariant {
3232 Link ,
3333 Match ,
@@ -53,6 +53,9 @@ impl DefaultRoleManager {
5353 ) -> NodeIndex < u32 > {
5454 let domain = domain. unwrap_or ( DEFAULT_DOMAIN ) ;
5555
56+ // detect whether this is a new domain creation
57+ let is_new_domain = !self . all_domains . contains_key ( domain) ;
58+
5659 let graph = self . all_domains . entry ( domain. into ( ) ) . or_default ( ) ;
5760
5861 let role_entry = self
@@ -97,9 +100,157 @@ impl DefaultRoleManager {
97100 }
98101 }
99102
103+ // If domain matching function exists and this was a new domain, copy
104+ // role links from matching domains into the newly created domain so
105+ // that BFS will see inherited links in this domain's graph.
106+ if is_new_domain {
107+ if let Some ( domain_matching_fn) = self . domain_matching_fn {
108+ let keys: Vec < String > =
109+ self . all_domains . keys ( ) . cloned ( ) . collect ( ) ;
110+ for d in keys {
111+ if d != domain && ( domain_matching_fn) ( domain, & d) {
112+ self . copy_from_domain ( & d, domain) ;
113+ }
114+ }
115+ }
116+ }
117+
100118 new_role_id
101119 }
102120
121+ // propagate a Link addition (name1 -> name2) from `domain` into all
122+ // affected/matching domains. This extracts the inline logic from
123+ // `add_link` so the code is clearer and avoids nested borrows.
124+ fn propagate_link_to_affected_domains (
125+ & mut self ,
126+ name1 : & str ,
127+ name2 : & str ,
128+ domain : & str ,
129+ ) {
130+ let name1_owned = name1. to_string ( ) ;
131+ let name2_owned = name2. to_string ( ) ;
132+ let affected = self . affected_domain_names ( domain) ;
133+ for d in affected {
134+ // obtain mutable graph and index map for the affected domain
135+ let g = self . all_domains . get_mut ( & d) . unwrap ( ) ;
136+ let idx_map =
137+ self . all_domains_indices . entry ( d. clone ( ) ) . or_default ( ) ;
138+ let idx1 = Self :: ensure_node_in_graph ( g, idx_map, & name1_owned) ;
139+ let idx2 = Self :: ensure_node_in_graph ( g, idx_map, & name2_owned) ;
140+
141+ // add Link edge if missing
142+ let has_link = g
143+ . edges_connecting ( idx1, idx2)
144+ . any ( |e| matches ! ( * e. weight( ) , EdgeVariant :: Link ) ) ;
145+ if !has_link {
146+ g. add_edge ( idx1, idx2, EdgeVariant :: Link ) ;
147+ }
148+ }
149+
150+ #[ cfg( feature = "cached" ) ]
151+ self . cache . clear ( ) ;
152+ }
153+
154+ // ensure a node with `name` exists in graph `g` and in the provided
155+ // `idx_map`. Returns the NodeIndex for the node.
156+ fn ensure_node_in_graph (
157+ g : & mut StableDiGraph < String , EdgeVariant > ,
158+ idx_map : & mut HashMap < String , NodeIndex < u32 > > ,
159+ name : & str ,
160+ ) -> NodeIndex < u32 > {
161+ if let Some ( idx) = idx_map. get ( name) {
162+ * idx
163+ } else if let Some ( idx) = g. node_indices ( ) . find ( |& i| g[ i] == name) {
164+ idx_map. insert ( name. to_string ( ) , idx) ;
165+ idx
166+ } else {
167+ let ni = g. add_node ( name. to_string ( ) ) ;
168+ idx_map. insert ( name. to_string ( ) , ni) ;
169+ ni
170+ }
171+ }
172+
173+ // return the list of affected domain names (immutable) to avoid nested
174+ // mutable borrows when performing operations across domains
175+ fn affected_domain_names ( & self , domain : & str ) -> Vec < String > {
176+ let mut res = Vec :: new ( ) ;
177+ if let Some ( domain_matching_fn) = self . domain_matching_fn {
178+ let keys: Vec < String > = self . all_domains . keys ( ) . cloned ( ) . collect ( ) ;
179+ for d in keys {
180+ if d != domain && ( domain_matching_fn) ( & d, domain) {
181+ res. push ( d) ;
182+ }
183+ }
184+ }
185+ res
186+ }
187+
188+ // copy all role links and nodes from `src_domain` graph into `dst_domain` graph
189+ fn copy_from_domain ( & mut self , src_domain : & str , dst_domain : & str ) {
190+ if src_domain == dst_domain {
191+ return ;
192+ }
193+
194+ // ensure both graphs exist
195+ if !self . all_domains . contains_key ( src_domain) {
196+ return ;
197+ }
198+
199+ let src_graph = match self . all_domains . get ( src_domain) {
200+ Some ( g) => g. clone ( ) ,
201+ None => return ,
202+ } ;
203+
204+ // ensure dst indices map exists
205+ let dst_indices = self
206+ . all_domains_indices
207+ . entry ( dst_domain. into ( ) )
208+ . or_default ( ) ;
209+
210+ let dst_graph = self . all_domains . entry ( dst_domain. into ( ) ) . or_default ( ) ;
211+
212+ // copy nodes: ensure names exist in dst and capture mapping
213+ let mut id_map: HashMap < NodeIndex < u32 > , NodeIndex < u32 > > =
214+ HashMap :: new ( ) ;
215+ for src_idx in src_graph. node_indices ( ) {
216+ let name = & src_graph[ src_idx] ;
217+ let dst_idx = if let Some ( idx) = dst_indices. get ( name) {
218+ * idx
219+ } else {
220+ let new_idx = dst_graph. add_node ( name. clone ( ) ) ;
221+ dst_indices. insert ( name. clone ( ) , new_idx) ;
222+ new_idx
223+ } ;
224+ id_map. insert ( src_idx, dst_idx) ;
225+ }
226+
227+ // copy edges: for each edge in src_graph, add equivalent edge in dst if missing
228+ for edge_idx in src_graph. edge_indices ( ) {
229+ if let Some ( ( src_s, src_t) ) = src_graph. edge_endpoints ( edge_idx) {
230+ if let Some ( weight) = src_graph. edge_weight ( edge_idx) {
231+ let dst_s = id_map. get ( & src_s) . unwrap ( ) ;
232+ let dst_t = id_map. get ( & src_t) . unwrap ( ) ;
233+
234+ let need_add = match dst_graph. find_edge ( * dst_s, * dst_t) {
235+ Some ( idx) => {
236+ // if existing edge is Match but source weight is Link, allow adding Link
237+ !matches ! ( dst_graph[ idx] , EdgeVariant :: Match )
238+ || !matches ! ( weight, & EdgeVariant :: Match )
239+ }
240+ None => true ,
241+ } ;
242+
243+ if need_add {
244+ dst_graph. add_edge ( * dst_s, * dst_t, weight. clone ( ) ) ;
245+ }
246+ }
247+ }
248+ }
249+
250+ #[ cfg( feature = "cached" ) ]
251+ self . cache . clear ( ) ;
252+ }
253+
103254 fn matched_domains ( & self , domain : Option < & str > ) -> Vec < String > {
104255 let domain = domain. unwrap_or ( DEFAULT_DOMAIN ) ;
105256 if let Some ( domain_matching_fn) = self . domain_matching_fn {
@@ -203,8 +354,14 @@ impl RoleManager for DefaultRoleManager {
203354 if add_link {
204355 graph. add_edge ( role1, role2, EdgeVariant :: Link ) ;
205356
357+ if let Some ( domain_str) = domain {
358+ self . propagate_link_to_affected_domains (
359+ name1, name2, domain_str,
360+ ) ;
361+ }
362+
206363 #[ cfg( feature = "cached" ) ]
207- self . cache . clear ( )
364+ self . cache . clear ( ) ;
208365 }
209366 }
210367
@@ -787,4 +944,28 @@ mod tests {
787944
788945 assert_eq ! ( vec![ "alice" ] , sort_unstable( rm. get_users( "*" , None ) ) ) ;
789946 }
947+
948+ #[ test]
949+ fn test_cross_domain_role_inheritance_complex ( ) {
950+ use crate :: model:: key_match;
951+ let mut rm = DefaultRoleManager :: new ( 10 ) ;
952+ rm. matching_fn ( None , Some ( key_match) ) ;
953+
954+ rm. add_link ( "editor" , "admin" , Some ( "*" ) ) ;
955+ rm. add_link ( "viewer" , "editor" , Some ( "*" ) ) ;
956+
957+ rm. add_link ( "alice" , "editor" , Some ( "domain1" ) ) ;
958+ rm. add_link ( "bob" , "viewer" , Some ( "domain2" ) ) ;
959+
960+ assert ! ( rm. has_link( "alice" , "admin" , Some ( "domain1" ) ) ) ;
961+ assert ! ( rm. has_link( "bob" , "editor" , Some ( "domain2" ) ) ) ;
962+ assert ! ( rm. has_link( "bob" , "admin" , Some ( "domain2" ) ) ) ;
963+
964+ rm. add_link ( "charlie" , "viewer" , Some ( "domain3" ) ) ;
965+ assert ! ( rm. has_link( "charlie" , "editor" , Some ( "domain3" ) ) ) ;
966+ assert ! ( rm. has_link( "charlie" , "admin" , Some ( "domain3" ) ) ) ;
967+
968+ rm. add_link ( "super_admin" , "admin" , Some ( "domain1" ) ) ;
969+ assert ! ( rm. has_link( "super_admin" , "admin" , Some ( "domain1" ) ) ) ;
970+ }
790971}
0 commit comments