@@ -32,9 +32,13 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
3232 for parameter in root .get_descendants ()
3333 if isinstance (parameter , Parameter )
3434 ]
35+ parameter_paths = get_parameter_paths (root )
3536
36- for parameter in sort_parameters_by_uprating_dependencies (parameters ):
37- uprate_parameter (parameter , root )
37+ for parameter in sort_parameters_by_uprating_dependencies (
38+ parameters ,
39+ parameter_paths ,
40+ ):
41+ uprate_parameter (parameter , root , parameter_paths )
3842 return root
3943
4044
@@ -57,35 +61,196 @@ def get_uprating_dependency_name(parameter: Parameter) -> Optional[str]:
5761 return dependency_name
5862
5963
64+ def get_parameter_paths (root : ParameterNode ) -> dict [int , str ]:
65+ paths = {}
66+
67+ def visit_parameter_node (node : ParameterNode , path : str ) -> None :
68+ for child_name , child in node .children .items ():
69+ child_path = f"{ path } .{ child_name } " if path else child_name
70+ visit_child (child , child_path )
71+
72+ def visit_child (child , path : str ) -> None :
73+ if isinstance (child , Parameter ):
74+ paths [id (child )] = path
75+ elif isinstance (child , ParameterNode ):
76+ visit_parameter_node (child , path )
77+ else :
78+ brackets = getattr (child , "__dict__" , {}).get ("brackets" )
79+ if brackets is not None :
80+ for index , bracket in enumerate (brackets ):
81+ visit_parameter_node (bracket , f"{ path } [{ index } ]" )
82+
83+ visit_parameter_node (root , "" )
84+ return paths
85+
86+
87+ def join_parameter_path (prefix : str , suffix : str ) -> str :
88+ if not prefix :
89+ return suffix
90+ if not suffix :
91+ return prefix
92+ return f"{ prefix } .{ suffix } "
93+
94+
95+ def get_parameter_scope_prefixes (
96+ parameter : Parameter ,
97+ parameter_paths : dict [int , str ],
98+ ) -> Optional [tuple [str , str ]]:
99+ parameter_path = parameter_paths .get (id (parameter ))
100+ if parameter_path is None or parameter_path == parameter .name :
101+ return None
102+ parameter_name_parts = parameter .name .split ("." )
103+ parameter_path_parts = parameter_path .split ("." )
104+ common_suffix_length = 0
105+ max_common_suffix_length = min (
106+ len (parameter_name_parts ),
107+ len (parameter_path_parts ),
108+ )
109+ while (
110+ common_suffix_length < max_common_suffix_length
111+ and parameter_name_parts [- common_suffix_length - 1 ]
112+ == parameter_path_parts [- common_suffix_length - 1 ]
113+ ):
114+ common_suffix_length += 1
115+ if common_suffix_length == 0 :
116+ return None
117+ original_prefix = "." .join (parameter_name_parts [:- common_suffix_length ])
118+ current_prefix = "." .join (parameter_path_parts [:- common_suffix_length ])
119+ return original_prefix , current_prefix
120+
121+
122+ def map_original_parameter_path (
123+ parameter_path : str ,
124+ original_prefix : str ,
125+ current_prefix : str ,
126+ ) -> Optional [str ]:
127+ if not original_prefix :
128+ return join_parameter_path (current_prefix , parameter_path )
129+ if parameter_path == original_prefix :
130+ return current_prefix
131+ original_prefix_with_separator = f"{ original_prefix } ."
132+ if parameter_path .startswith (original_prefix_with_separator ):
133+ return join_parameter_path (
134+ current_prefix ,
135+ parameter_path [len (original_prefix_with_separator ) :],
136+ )
137+ return None
138+
139+
140+ def get_scoped_uprating_dependency_names (
141+ parameter : Parameter ,
142+ dependency_name : str ,
143+ parameter_paths : dict [int , str ],
144+ ) -> list [str ]:
145+ parameter_path = parameter_paths .get (id (parameter ))
146+ if parameter_path is None or parameter_path == parameter .name :
147+ return [dependency_name ]
148+
149+ dependency_names = []
150+
151+ def add_dependency_name (candidate : Optional [str ]) -> None :
152+ if candidate and candidate not in dependency_names :
153+ dependency_names .append (candidate )
154+
155+ scope_prefixes = get_parameter_scope_prefixes (parameter , parameter_paths )
156+ if scope_prefixes is not None :
157+ original_prefix , current_prefix = scope_prefixes
158+ add_dependency_name (
159+ map_original_parameter_path (
160+ dependency_name ,
161+ original_prefix ,
162+ current_prefix ,
163+ )
164+ )
165+
166+ add_dependency_name (dependency_name )
167+ return dependency_names
168+
169+
170+ def get_parameter_lookup_names (
171+ parameter : Parameter ,
172+ parameter_paths : dict [int , str ],
173+ ) -> set [str ]:
174+ parameter_path = parameter_paths .get (id (parameter ))
175+ if parameter_path is None :
176+ return {parameter .name }
177+ return {parameter_path }
178+
179+
180+ def get_uprating_parameter (
181+ root : ParameterNode ,
182+ parameter : Parameter ,
183+ dependency_name : str ,
184+ parameter_paths : dict [int , str ],
185+ ) -> Parameter :
186+ for scoped_dependency_name in get_scoped_uprating_dependency_names (
187+ parameter ,
188+ dependency_name ,
189+ parameter_paths ,
190+ ):
191+ try :
192+ return get_parameter (root , scoped_dependency_name )
193+ except ValueError :
194+ continue
195+ return get_parameter (root , dependency_name )
196+
197+
60198def sort_parameters_by_uprating_dependencies (
61199 parameters : list [Parameter ],
200+ parameter_paths : Optional [dict [int , str ]] = None ,
62201) -> list [Parameter ]:
202+ if parameter_paths is None :
203+ parameter_paths = {}
63204 parameters_to_uprate = [
64205 parameter
65206 for parameter in parameters
66207 if parameter .metadata .get ("uprating" ) is not None
67208 ]
68- parameter_by_name = {
69- parameter .name : parameter for parameter in parameters_to_uprate
70- }
209+ parameters_by_name = {}
210+ for parameter in parameters_to_uprate :
211+ for name in get_parameter_lookup_names (parameter , parameter_paths ):
212+ parameters_by_name .setdefault (name , []).append (parameter )
71213 ordered_parameters = []
72214 visited = set ()
73215 visiting = []
216+ visiting_ids = set ()
74217
75218 def visit (parameter : Parameter ):
76- if parameter .name in visited :
219+ parameter_id = id (parameter )
220+ if parameter_id in visited :
77221 return
78- if parameter .name in visiting :
79- cycle = visiting [visiting .index (parameter .name ) :] + [parameter .name ]
222+ if parameter_id in visiting_ids :
223+ cycle_start = next (
224+ index
225+ for index , visiting_parameter in enumerate (visiting )
226+ if id (visiting_parameter ) == parameter_id
227+ )
228+ cycle = visiting [cycle_start :] + [parameter ]
80229 raise ValueError (
81- "Cyclic uprating dependency detected: " + " -> " .join (cycle )
230+ "Cyclic uprating dependency detected: "
231+ + " -> " .join (parameter .name for parameter in cycle )
82232 )
83- visiting .append (parameter .name )
233+ visiting .append (parameter )
234+ visiting_ids .add (parameter_id )
84235 dependency_name = get_uprating_dependency_name (parameter )
85- if dependency_name in parameter_by_name :
86- visit (parameter_by_name [dependency_name ])
236+ dependency_parameters = []
237+ if dependency_name is not None :
238+ for scoped_dependency_name in get_scoped_uprating_dependency_names (
239+ parameter ,
240+ dependency_name ,
241+ parameter_paths ,
242+ ):
243+ dependency_parameters = parameters_by_name .get (
244+ scoped_dependency_name ,
245+ [],
246+ )
247+ if dependency_parameters :
248+ break
249+ for dependency in dependency_parameters :
250+ visit (dependency )
87251 visiting .pop ()
88- visited .add (parameter .name )
252+ visiting_ids .remove (parameter_id )
253+ visited .add (parameter_id )
89254 ordered_parameters .append (parameter )
90255
91256 for parameter in sorted (parameters_to_uprate , key = lambda p : p .name ):
@@ -94,7 +259,13 @@ def visit(parameter: Parameter):
94259 return ordered_parameters
95260
96261
97- def uprate_parameter (parameter : Parameter , root : ParameterNode ) -> None :
262+ def uprate_parameter (
263+ parameter : Parameter ,
264+ root : ParameterNode ,
265+ parameter_paths : Optional [dict [int , str ]] = None ,
266+ ) -> None :
267+ if parameter_paths is None :
268+ parameter_paths = {}
98269 # Pull the uprating definition dict
99270 meta = normalize_uprating_metadata (parameter .metadata ["uprating" ])
100271
@@ -106,7 +277,12 @@ def uprate_parameter(parameter: Parameter, root: ParameterNode) -> None:
106277 )
107278 # Otherwise, pull uprating table from YAML
108279 else :
109- uprating_parameter = get_parameter (root , meta ["parameter" ])
280+ uprating_parameter = get_uprating_parameter (
281+ root ,
282+ parameter ,
283+ meta ["parameter" ],
284+ parameter_paths ,
285+ )
110286
111287 # If uprating with a set candence, ensure that all
112288 # required values are present
0 commit comments