Skip to content

Commit f492657

Browse files
authored
Fix uprating of cloned parameter subtrees (#492)
1 parent 20ebee2 commit f492657

3 files changed

Lines changed: 353 additions & 15 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed uprating dependency sorting for cloned parameter subtrees with duplicate parameter names.

policyengine_core/parameters/operations/uprate_parameters.py

Lines changed: 191 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,13 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
3232
for parameter in root.get_descendants()
3333
if isinstance(parameter, Parameter)
3434
]
35+
parameter_paths = get_parameter_paths(root)
3536

36-
for parameter in sort_parameters_by_uprating_dependencies(parameters):
37-
uprate_parameter(parameter, root)
37+
for parameter in sort_parameters_by_uprating_dependencies(
38+
parameters,
39+
parameter_paths,
40+
):
41+
uprate_parameter(parameter, root, parameter_paths)
3842
return root
3943

4044

@@ -57,35 +61,196 @@ def get_uprating_dependency_name(parameter: Parameter) -> Optional[str]:
5761
return dependency_name
5862

5963

64+
def get_parameter_paths(root: ParameterNode) -> dict[int, str]:
65+
paths = {}
66+
67+
def visit_parameter_node(node: ParameterNode, path: str) -> None:
68+
for child_name, child in node.children.items():
69+
child_path = f"{path}.{child_name}" if path else child_name
70+
visit_child(child, child_path)
71+
72+
def visit_child(child, path: str) -> None:
73+
if isinstance(child, Parameter):
74+
paths[id(child)] = path
75+
elif isinstance(child, ParameterNode):
76+
visit_parameter_node(child, path)
77+
else:
78+
brackets = getattr(child, "__dict__", {}).get("brackets")
79+
if brackets is not None:
80+
for index, bracket in enumerate(brackets):
81+
visit_parameter_node(bracket, f"{path}[{index}]")
82+
83+
visit_parameter_node(root, "")
84+
return paths
85+
86+
87+
def join_parameter_path(prefix: str, suffix: str) -> str:
88+
if not prefix:
89+
return suffix
90+
if not suffix:
91+
return prefix
92+
return f"{prefix}.{suffix}"
93+
94+
95+
def get_parameter_scope_prefixes(
96+
parameter: Parameter,
97+
parameter_paths: dict[int, str],
98+
) -> Optional[tuple[str, str]]:
99+
parameter_path = parameter_paths.get(id(parameter))
100+
if parameter_path is None or parameter_path == parameter.name:
101+
return None
102+
parameter_name_parts = parameter.name.split(".")
103+
parameter_path_parts = parameter_path.split(".")
104+
common_suffix_length = 0
105+
max_common_suffix_length = min(
106+
len(parameter_name_parts),
107+
len(parameter_path_parts),
108+
)
109+
while (
110+
common_suffix_length < max_common_suffix_length
111+
and parameter_name_parts[-common_suffix_length - 1]
112+
== parameter_path_parts[-common_suffix_length - 1]
113+
):
114+
common_suffix_length += 1
115+
if common_suffix_length == 0:
116+
return None
117+
original_prefix = ".".join(parameter_name_parts[:-common_suffix_length])
118+
current_prefix = ".".join(parameter_path_parts[:-common_suffix_length])
119+
return original_prefix, current_prefix
120+
121+
122+
def map_original_parameter_path(
123+
parameter_path: str,
124+
original_prefix: str,
125+
current_prefix: str,
126+
) -> Optional[str]:
127+
if not original_prefix:
128+
return join_parameter_path(current_prefix, parameter_path)
129+
if parameter_path == original_prefix:
130+
return current_prefix
131+
original_prefix_with_separator = f"{original_prefix}."
132+
if parameter_path.startswith(original_prefix_with_separator):
133+
return join_parameter_path(
134+
current_prefix,
135+
parameter_path[len(original_prefix_with_separator) :],
136+
)
137+
return None
138+
139+
140+
def get_scoped_uprating_dependency_names(
141+
parameter: Parameter,
142+
dependency_name: str,
143+
parameter_paths: dict[int, str],
144+
) -> list[str]:
145+
parameter_path = parameter_paths.get(id(parameter))
146+
if parameter_path is None or parameter_path == parameter.name:
147+
return [dependency_name]
148+
149+
dependency_names = []
150+
151+
def add_dependency_name(candidate: Optional[str]) -> None:
152+
if candidate and candidate not in dependency_names:
153+
dependency_names.append(candidate)
154+
155+
scope_prefixes = get_parameter_scope_prefixes(parameter, parameter_paths)
156+
if scope_prefixes is not None:
157+
original_prefix, current_prefix = scope_prefixes
158+
add_dependency_name(
159+
map_original_parameter_path(
160+
dependency_name,
161+
original_prefix,
162+
current_prefix,
163+
)
164+
)
165+
166+
add_dependency_name(dependency_name)
167+
return dependency_names
168+
169+
170+
def get_parameter_lookup_names(
171+
parameter: Parameter,
172+
parameter_paths: dict[int, str],
173+
) -> set[str]:
174+
parameter_path = parameter_paths.get(id(parameter))
175+
if parameter_path is None:
176+
return {parameter.name}
177+
return {parameter_path}
178+
179+
180+
def get_uprating_parameter(
181+
root: ParameterNode,
182+
parameter: Parameter,
183+
dependency_name: str,
184+
parameter_paths: dict[int, str],
185+
) -> Parameter:
186+
for scoped_dependency_name in get_scoped_uprating_dependency_names(
187+
parameter,
188+
dependency_name,
189+
parameter_paths,
190+
):
191+
try:
192+
return get_parameter(root, scoped_dependency_name)
193+
except ValueError:
194+
continue
195+
return get_parameter(root, dependency_name)
196+
197+
60198
def sort_parameters_by_uprating_dependencies(
61199
parameters: list[Parameter],
200+
parameter_paths: Optional[dict[int, str]] = None,
62201
) -> list[Parameter]:
202+
if parameter_paths is None:
203+
parameter_paths = {}
63204
parameters_to_uprate = [
64205
parameter
65206
for parameter in parameters
66207
if parameter.metadata.get("uprating") is not None
67208
]
68-
parameter_by_name = {
69-
parameter.name: parameter for parameter in parameters_to_uprate
70-
}
209+
parameters_by_name = {}
210+
for parameter in parameters_to_uprate:
211+
for name in get_parameter_lookup_names(parameter, parameter_paths):
212+
parameters_by_name.setdefault(name, []).append(parameter)
71213
ordered_parameters = []
72214
visited = set()
73215
visiting = []
216+
visiting_ids = set()
74217

75218
def visit(parameter: Parameter):
76-
if parameter.name in visited:
219+
parameter_id = id(parameter)
220+
if parameter_id in visited:
77221
return
78-
if parameter.name in visiting:
79-
cycle = visiting[visiting.index(parameter.name) :] + [parameter.name]
222+
if parameter_id in visiting_ids:
223+
cycle_start = next(
224+
index
225+
for index, visiting_parameter in enumerate(visiting)
226+
if id(visiting_parameter) == parameter_id
227+
)
228+
cycle = visiting[cycle_start:] + [parameter]
80229
raise ValueError(
81-
"Cyclic uprating dependency detected: " + " -> ".join(cycle)
230+
"Cyclic uprating dependency detected: "
231+
+ " -> ".join(parameter.name for parameter in cycle)
82232
)
83-
visiting.append(parameter.name)
233+
visiting.append(parameter)
234+
visiting_ids.add(parameter_id)
84235
dependency_name = get_uprating_dependency_name(parameter)
85-
if dependency_name in parameter_by_name:
86-
visit(parameter_by_name[dependency_name])
236+
dependency_parameters = []
237+
if dependency_name is not None:
238+
for scoped_dependency_name in get_scoped_uprating_dependency_names(
239+
parameter,
240+
dependency_name,
241+
parameter_paths,
242+
):
243+
dependency_parameters = parameters_by_name.get(
244+
scoped_dependency_name,
245+
[],
246+
)
247+
if dependency_parameters:
248+
break
249+
for dependency in dependency_parameters:
250+
visit(dependency)
87251
visiting.pop()
88-
visited.add(parameter.name)
252+
visiting_ids.remove(parameter_id)
253+
visited.add(parameter_id)
89254
ordered_parameters.append(parameter)
90255

91256
for parameter in sorted(parameters_to_uprate, key=lambda p: p.name):
@@ -94,7 +259,13 @@ def visit(parameter: Parameter):
94259
return ordered_parameters
95260

96261

97-
def uprate_parameter(parameter: Parameter, root: ParameterNode) -> None:
262+
def uprate_parameter(
263+
parameter: Parameter,
264+
root: ParameterNode,
265+
parameter_paths: Optional[dict[int, str]] = None,
266+
) -> None:
267+
if parameter_paths is None:
268+
parameter_paths = {}
98269
# Pull the uprating definition dict
99270
meta = normalize_uprating_metadata(parameter.metadata["uprating"])
100271

@@ -106,7 +277,12 @@ def uprate_parameter(parameter: Parameter, root: ParameterNode) -> None:
106277
)
107278
# Otherwise, pull uprating table from YAML
108279
else:
109-
uprating_parameter = get_parameter(root, meta["parameter"])
280+
uprating_parameter = get_uprating_parameter(
281+
root,
282+
parameter,
283+
meta["parameter"],
284+
parameter_paths,
285+
)
110286

111287
# If uprating with a set candence, ensure that all
112288
# required values are present

0 commit comments

Comments
 (0)