Skip to content

Commit 0d8b9db

Browse files
authored
feat(rbac): preload and propagate wildcard/pattern domain role links (#387)
1 parent 06159a3 commit 0d8b9db

3 files changed

Lines changed: 190 additions & 9 deletions

File tree

src/rbac/default_role_manager.rs

Lines changed: 183 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub struct DefaultRoleManager {
2727
domain_matching_fn: Option<MatchingFn>,
2828
}
2929

30-
#[derive(Debug)]
30+
#[derive(Clone, Debug)]
3131
enum EdgeVariant {
3232
Link,
3333
Match,
@@ -53,6 +53,9 @@ impl DefaultRoleManager {
5353
) -> NodeIndex<u32> {
5454
let domain = domain.unwrap_or(DEFAULT_DOMAIN);
5555

56+
// detect whether this is a new domain creation
57+
let is_new_domain = !self.all_domains.contains_key(domain);
58+
5659
let graph = self.all_domains.entry(domain.into()).or_default();
5760

5861
let role_entry = self
@@ -97,9 +100,157 @@ impl DefaultRoleManager {
97100
}
98101
}
99102

103+
// If domain matching function exists and this was a new domain, copy
104+
// role links from matching domains into the newly created domain so
105+
// that BFS will see inherited links in this domain's graph.
106+
if is_new_domain {
107+
if let Some(domain_matching_fn) = self.domain_matching_fn {
108+
let keys: Vec<String> =
109+
self.all_domains.keys().cloned().collect();
110+
for d in keys {
111+
if d != domain && (domain_matching_fn)(domain, &d) {
112+
self.copy_from_domain(&d, domain);
113+
}
114+
}
115+
}
116+
}
117+
100118
new_role_id
101119
}
102120

121+
// propagate a Link addition (name1 -> name2) from `domain` into all
122+
// affected/matching domains. This extracts the inline logic from
123+
// `add_link` so the code is clearer and avoids nested borrows.
124+
fn propagate_link_to_affected_domains(
125+
&mut self,
126+
name1: &str,
127+
name2: &str,
128+
domain: &str,
129+
) {
130+
let name1_owned = name1.to_string();
131+
let name2_owned = name2.to_string();
132+
let affected = self.affected_domain_names(domain);
133+
for d in affected {
134+
// obtain mutable graph and index map for the affected domain
135+
let g = self.all_domains.get_mut(&d).unwrap();
136+
let idx_map =
137+
self.all_domains_indices.entry(d.clone()).or_default();
138+
let idx1 = Self::ensure_node_in_graph(g, idx_map, &name1_owned);
139+
let idx2 = Self::ensure_node_in_graph(g, idx_map, &name2_owned);
140+
141+
// add Link edge if missing
142+
let has_link = g
143+
.edges_connecting(idx1, idx2)
144+
.any(|e| matches!(*e.weight(), EdgeVariant::Link));
145+
if !has_link {
146+
g.add_edge(idx1, idx2, EdgeVariant::Link);
147+
}
148+
}
149+
150+
#[cfg(feature = "cached")]
151+
self.cache.clear();
152+
}
153+
154+
// ensure a node with `name` exists in graph `g` and in the provided
155+
// `idx_map`. Returns the NodeIndex for the node.
156+
fn ensure_node_in_graph(
157+
g: &mut StableDiGraph<String, EdgeVariant>,
158+
idx_map: &mut HashMap<String, NodeIndex<u32>>,
159+
name: &str,
160+
) -> NodeIndex<u32> {
161+
if let Some(idx) = idx_map.get(name) {
162+
*idx
163+
} else if let Some(idx) = g.node_indices().find(|&i| g[i] == name) {
164+
idx_map.insert(name.to_string(), idx);
165+
idx
166+
} else {
167+
let ni = g.add_node(name.to_string());
168+
idx_map.insert(name.to_string(), ni);
169+
ni
170+
}
171+
}
172+
173+
// return the list of affected domain names (immutable) to avoid nested
174+
// mutable borrows when performing operations across domains
175+
fn affected_domain_names(&self, domain: &str) -> Vec<String> {
176+
let mut res = Vec::new();
177+
if let Some(domain_matching_fn) = self.domain_matching_fn {
178+
let keys: Vec<String> = self.all_domains.keys().cloned().collect();
179+
for d in keys {
180+
if d != domain && (domain_matching_fn)(&d, domain) {
181+
res.push(d);
182+
}
183+
}
184+
}
185+
res
186+
}
187+
188+
// copy all role links and nodes from `src_domain` graph into `dst_domain` graph
189+
fn copy_from_domain(&mut self, src_domain: &str, dst_domain: &str) {
190+
if src_domain == dst_domain {
191+
return;
192+
}
193+
194+
// ensure both graphs exist
195+
if !self.all_domains.contains_key(src_domain) {
196+
return;
197+
}
198+
199+
let src_graph = match self.all_domains.get(src_domain) {
200+
Some(g) => g.clone(),
201+
None => return,
202+
};
203+
204+
// ensure dst indices map exists
205+
let dst_indices = self
206+
.all_domains_indices
207+
.entry(dst_domain.into())
208+
.or_default();
209+
210+
let dst_graph = self.all_domains.entry(dst_domain.into()).or_default();
211+
212+
// copy nodes: ensure names exist in dst and capture mapping
213+
let mut id_map: HashMap<NodeIndex<u32>, NodeIndex<u32>> =
214+
HashMap::new();
215+
for src_idx in src_graph.node_indices() {
216+
let name = &src_graph[src_idx];
217+
let dst_idx = if let Some(idx) = dst_indices.get(name) {
218+
*idx
219+
} else {
220+
let new_idx = dst_graph.add_node(name.clone());
221+
dst_indices.insert(name.clone(), new_idx);
222+
new_idx
223+
};
224+
id_map.insert(src_idx, dst_idx);
225+
}
226+
227+
// copy edges: for each edge in src_graph, add equivalent edge in dst if missing
228+
for edge_idx in src_graph.edge_indices() {
229+
if let Some((src_s, src_t)) = src_graph.edge_endpoints(edge_idx) {
230+
if let Some(weight) = src_graph.edge_weight(edge_idx) {
231+
let dst_s = id_map.get(&src_s).unwrap();
232+
let dst_t = id_map.get(&src_t).unwrap();
233+
234+
let need_add = match dst_graph.find_edge(*dst_s, *dst_t) {
235+
Some(idx) => {
236+
// if existing edge is Match but source weight is Link, allow adding Link
237+
!matches!(dst_graph[idx], EdgeVariant::Match)
238+
|| !matches!(weight, &EdgeVariant::Match)
239+
}
240+
None => true,
241+
};
242+
243+
if need_add {
244+
dst_graph.add_edge(*dst_s, *dst_t, weight.clone());
245+
}
246+
}
247+
}
248+
}
249+
250+
#[cfg(feature = "cached")]
251+
self.cache.clear();
252+
}
253+
103254
fn matched_domains(&self, domain: Option<&str>) -> Vec<String> {
104255
let domain = domain.unwrap_or(DEFAULT_DOMAIN);
105256
if let Some(domain_matching_fn) = self.domain_matching_fn {
@@ -203,8 +354,14 @@ impl RoleManager for DefaultRoleManager {
203354
if add_link {
204355
graph.add_edge(role1, role2, EdgeVariant::Link);
205356

357+
if let Some(domain_str) = domain {
358+
self.propagate_link_to_affected_domains(
359+
name1, name2, domain_str,
360+
);
361+
}
362+
206363
#[cfg(feature = "cached")]
207-
self.cache.clear()
364+
self.cache.clear();
208365
}
209366
}
210367

@@ -787,4 +944,28 @@ mod tests {
787944

788945
assert_eq!(vec!["alice"], sort_unstable(rm.get_users("*", None)));
789946
}
947+
948+
#[test]
949+
fn test_cross_domain_role_inheritance_complex() {
950+
use crate::model::key_match;
951+
let mut rm = DefaultRoleManager::new(10);
952+
rm.matching_fn(None, Some(key_match));
953+
954+
rm.add_link("editor", "admin", Some("*"));
955+
rm.add_link("viewer", "editor", Some("*"));
956+
957+
rm.add_link("alice", "editor", Some("domain1"));
958+
rm.add_link("bob", "viewer", Some("domain2"));
959+
960+
assert!(rm.has_link("alice", "admin", Some("domain1")));
961+
assert!(rm.has_link("bob", "editor", Some("domain2")));
962+
assert!(rm.has_link("bob", "admin", Some("domain2")));
963+
964+
rm.add_link("charlie", "viewer", Some("domain3"));
965+
assert!(rm.has_link("charlie", "editor", Some("domain3")));
966+
assert!(rm.has_link("charlie", "admin", Some("domain3")));
967+
968+
rm.add_link("super_admin", "admin", Some("domain1"));
969+
assert!(rm.has_link("super_admin", "admin", Some("domain1")));
970+
}
790971
}

src/rbac_api.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ mod tests {
907907

908908
let adapter =
909909
FileAdapter::new("examples/rbac_with_hierarchy_policy.csv");
910-
let mut e = Enforcer::new(m, adapter).await.unwrap();
910+
let e = Enforcer::new(m, adapter).await.unwrap();
911911

912912
assert_eq!(
913913
vec![vec!["alice", "data1", "read"]],
@@ -944,7 +944,7 @@ mod tests {
944944

945945
let adapter =
946946
FileAdapter::new("examples/rbac_with_hierarchy_policy.csv");
947-
let mut e = Enforcer::new(m, adapter).await.unwrap();
947+
let e = Enforcer::new(m, adapter).await.unwrap();
948948

949949
assert_eq!(
950950
vec![vec!["alice", "data1", "read"]],
@@ -1051,7 +1051,7 @@ mod tests {
10511051
let adapter = FileAdapter::new(
10521052
"examples/rbac_with_hierarchy_with_domains_policy.csv",
10531053
);
1054-
let mut e = Enforcer::new(m, adapter).await.unwrap();
1054+
let e = Enforcer::new(m, adapter).await.unwrap();
10551055

10561056
assert_eq!(
10571057
vec![
@@ -1075,7 +1075,7 @@ mod tests {
10751075
tokio::test
10761076
)]
10771077
async fn test_pattern_matching_fn() {
1078-
let mut e = Enforcer::new(
1078+
let e = Enforcer::new(
10791079
"examples/rbac_with_pattern_model.conf",
10801080
"examples/rbac_with_pattern_policy.csv",
10811081
)
@@ -1120,7 +1120,7 @@ mod tests {
11201120
tokio::test
11211121
)]
11221122
async fn test_pattern_matching_fn_with_domain() {
1123-
let mut e = Enforcer::new(
1123+
let e = Enforcer::new(
11241124
"examples/rbac_with_pattern_domain_model.conf",
11251125
"examples/rbac_with_pattern_domain_policy.csv",
11261126
)
@@ -1164,7 +1164,7 @@ mod tests {
11641164
tokio::test
11651165
)]
11661166
async fn test_pattern_matching_basic_role() {
1167-
let mut e = Enforcer::new(
1167+
let e = Enforcer::new(
11681168
"examples/rbac_basic_role_model.conf",
11691169
"examples/rbac_basic_role_policy.csv",
11701170
)

src/util.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ pub fn escape_eval(m: &str) -> Cow<'_, str> {
3636
ESC_E.replace_all(m, "eval(escape_assertion(${1}))")
3737
}
3838

39-
pub fn parse_csv_line<S: AsRef<str>>(line: S) -> Option<Vec<String>> {
39+
pub fn parse_csv_line<'a, S: AsRef<str> + 'a>(line: S) -> Option<Vec<String>> {
4040
let line = line.as_ref().trim();
4141
if line.is_empty() || line.starts_with('#') {
4242
return None;

0 commit comments

Comments
 (0)