|
10 | 10 |
|
11 | 11 | from __future__ import annotations |
12 | 12 |
|
13 | | -import ast |
14 | 13 | import json |
15 | 14 | import sys |
16 | 15 | import subprocess |
17 | | -from dataclasses import dataclass, field |
| 16 | +from dataclasses import dataclass |
18 | 17 | from pathlib import Path |
19 | | -from typing import Dict, List, Set, Iterable, Any |
| 18 | +from typing import Dict, List, Iterable, Any |
20 | 19 | from collections import defaultdict |
21 | 20 |
|
22 | 21 | HERE = Path(__file__).resolve().parent |
|
30 | 29 | sys.path.insert(0, _toolchain_dir) |
31 | 30 |
|
32 | 31 | from mfc.params import CONSTRAINTS, DEPENDENCIES, get_value_label # noqa: E402 pylint: disable=wrong-import-position |
33 | | - |
34 | | - |
35 | | -# --------------------------------------------------------------------------- |
36 | | -# Data structures |
37 | | -# --------------------------------------------------------------------------- |
38 | | - |
39 | | -@dataclass |
40 | | -class Rule: |
41 | | - method: str # e.g. "check_igr_simulation" |
42 | | - lineno: int # line number of the prohibit call |
43 | | - params: List[str] # case parameter names used in condition |
44 | | - message: str # user-friendly error message |
45 | | - stages: Set[str] = field(default_factory=set) # e.g. {"simulation", "pre_process"} |
46 | | - |
47 | | - |
48 | | -# --------------------------------------------------------------------------- |
49 | | -# AST analysis: methods, call graph, rules |
50 | | -# --------------------------------------------------------------------------- |
51 | | - |
52 | | -class CaseValidatorAnalyzer(ast.NodeVisitor): |
53 | | - """ |
54 | | - Analyzes the CaseValidator class: |
55 | | -
|
56 | | - - collects all methods |
57 | | - - builds a call graph between methods |
58 | | - - extracts all self.prohibit(...) rules |
59 | | - """ |
60 | | - |
61 | | - def __init__(self): |
62 | | - super().__init__() |
63 | | - self.in_case_validator = False |
64 | | - self.current_method: str | None = None |
65 | | - |
66 | | - self.methods: Dict[str, ast.FunctionDef] = {} |
67 | | - self.call_graph: Dict[str, Set[str]] = defaultdict(set) |
68 | | - self.rules: List[Rule] = [] |
69 | | - |
70 | | - # Stack of {local_name -> param_name} maps, one per method |
71 | | - self.local_param_stack: List[Dict[str, str]] = [] |
72 | | - |
73 | | - # --- top-level entrypoint --- |
74 | | - |
75 | | - def visit_ClassDef(self, node: ast.ClassDef): |
76 | | - if node.name == "CaseValidator": |
77 | | - self.in_case_validator = True |
78 | | - # collect methods |
79 | | - for item in node.body: |
80 | | - if isinstance(item, ast.FunctionDef): |
81 | | - self.methods[item.name] = item |
82 | | - # now analyze all methods |
83 | | - for method in self.methods.values(): |
84 | | - self._analyze_method(method) |
85 | | - self.in_case_validator = False |
86 | | - else: |
87 | | - self.generic_visit(node) |
88 | | - |
89 | | - # --- per-method analysis --- |
90 | | - |
91 | | - def _analyze_method(self, func: ast.FunctionDef): |
92 | | - """Analyze a single method: local param mapping, call graph, rules.""" |
93 | | - self.current_method = func.name |
94 | | - local_param_map = self._build_local_param_map(func) |
95 | | - self.local_param_stack.append(local_param_map) |
96 | | - self.generic_visit(func) |
97 | | - self.local_param_stack.pop() |
98 | | - self.current_method = None |
99 | | - |
100 | | - def _build_local_param_map(self, func: ast.FunctionDef) -> Dict[str, str]: # pylint: disable=too-many-nested-blocks |
101 | | - """ |
102 | | - Look for assignments like: |
103 | | - igr = self.get('igr', 'F') == 'T' |
104 | | - model_eqns = self.get('model_eqns') |
105 | | - and record local_name -> 'param_name'. |
106 | | - """ |
107 | | - m: Dict[str, str] = {} |
108 | | - for stmt in func.body: # pylint: disable=too-many-nested-blocks |
109 | | - if isinstance(stmt, ast.Assign): |
110 | | - # Handle both direct calls and comparisons |
111 | | - value = stmt.value |
112 | | - # Unwrap comparisons like "self.get('igr', 'F') == 'T'" |
113 | | - if isinstance(value, ast.Compare): |
114 | | - value = value.left |
115 | | - |
116 | | - if isinstance(value, ast.Call): |
117 | | - call = value |
118 | | - if ( # pylint: disable=too-many-boolean-expressions |
119 | | - isinstance(call.func, ast.Attribute) |
120 | | - and isinstance(call.func.value, ast.Name) |
121 | | - and call.func.value.id == "self" |
122 | | - and call.func.attr == "get" |
123 | | - and call.args |
124 | | - and isinstance(call.args[0], ast.Constant) |
125 | | - and isinstance(call.args[0].value, str) |
126 | | - ): |
127 | | - param_name = call.args[0].value |
128 | | - for target in stmt.targets: |
129 | | - if isinstance(target, ast.Name): |
130 | | - m[target.id] = param_name |
131 | | - return m |
132 | | - |
133 | | - # --- visit calls to build call graph + rules --- |
134 | | - |
135 | | - def visit_Call(self, node: ast.Call): |
136 | | - # record method call edges: self.some_method(...) |
137 | | - if ( |
138 | | - isinstance(node.func, ast.Attribute) |
139 | | - and isinstance(node.func.value, ast.Name) |
140 | | - and node.func.value.id == "self" |
141 | | - and isinstance(node.func.attr, str) |
142 | | - ): |
143 | | - callee = node.func.attr |
144 | | - if self.current_method is not None: |
145 | | - # method call on self |
146 | | - self.call_graph[self.current_method].add(callee) |
147 | | - |
148 | | - # detect self.prohibit(<condition>, "<message>") |
149 | | - if ( |
150 | | - isinstance(node.func, ast.Attribute) |
151 | | - and isinstance(node.func.value, ast.Name) |
152 | | - and node.func.value.id == "self" |
153 | | - and node.func.attr == "prohibit" |
154 | | - and len(node.args) >= 2 |
155 | | - ): |
156 | | - condition, msg_node = node.args[0], node.args[1] |
157 | | - if isinstance(msg_node, ast.Constant) and isinstance(msg_node.value, str): |
158 | | - params = sorted(self._extract_params(condition)) |
159 | | - rule = Rule( |
160 | | - method=self.current_method or "<unknown>", |
161 | | - lineno=node.lineno, |
162 | | - params=params, |
163 | | - message=msg_node.value, |
164 | | - ) |
165 | | - self.rules.append(rule) |
166 | | - |
167 | | - self.generic_visit(node) |
168 | | - |
169 | | - def _extract_params(self, condition: ast.AST) -> Set[str]: |
170 | | - """ |
171 | | - Collect parameter names used in the condition via: |
172 | | - - local variables mapped from self.get(...) |
173 | | - - direct self.get('param_name', ...) calls |
174 | | - """ |
175 | | - params: Set[str] = set() |
176 | | - local_map = self.local_param_stack[-1] if self.local_param_stack else {} |
177 | | - |
178 | | - for node in ast.walk(condition): |
179 | | - # local names |
180 | | - if isinstance(node, ast.Name) and node.id in local_map: |
181 | | - params.add(local_map[node.id]) |
182 | | - |
183 | | - # direct self.get('param_name') |
184 | | - if isinstance(node, ast.Call): |
185 | | - if ( # pylint: disable=too-many-boolean-expressions |
186 | | - isinstance(node.func, ast.Attribute) |
187 | | - and isinstance(node.func.value, ast.Name) |
188 | | - and node.func.value.id == "self" |
189 | | - and node.func.attr == "get" |
190 | | - and node.args |
191 | | - and isinstance(node.args[0], ast.Constant) |
192 | | - and isinstance(node.args[0].value, str) |
193 | | - ): |
194 | | - params.add(node.args[0].value) |
195 | | - |
196 | | - return params |
197 | | - |
198 | | - |
199 | | -# --------------------------------------------------------------------------- |
200 | | -# Stage inference from validate_* roots and call graph |
201 | | -# --------------------------------------------------------------------------- |
202 | | - |
203 | | -STAGE_ROOTS: Dict[str, List[str]] = { |
204 | | - "common": ["validate_common"], |
205 | | - "simulation": ["validate_simulation"], |
206 | | - "pre_process": ["validate_pre_process"], |
207 | | - "post_process": ["validate_post_process"], |
208 | | -} |
209 | | - |
210 | | - |
211 | | -def compute_method_stages(call_graph: Dict[str, Set[str]]) -> Dict[str, Set[str]]: |
212 | | - """ |
213 | | - For each stage (simulation/pre_process/post_process/common), starting from |
214 | | - validate_* roots, walk the call graph and record which methods belong to which stages. |
215 | | - """ |
216 | | - method_stages: Dict[str, Set[str]] = defaultdict(set) |
217 | | - |
218 | | - def dfs(start: str, stage: str): |
219 | | - stack = [start] |
220 | | - visited: Set[str] = set() |
221 | | - while stack: |
222 | | - m = stack.pop() |
223 | | - if m in visited: |
224 | | - continue |
225 | | - visited.add(m) |
226 | | - method_stages[m].add(stage) |
227 | | - for nxt in call_graph.get(m, ()): |
228 | | - if nxt not in visited: |
229 | | - stack.append(nxt) |
230 | | - |
231 | | - for stage, roots in STAGE_ROOTS.items(): |
232 | | - for root in roots: |
233 | | - dfs(root, stage) |
234 | | - |
235 | | - return method_stages |
236 | | - |
237 | | - |
238 | | -# --------------------------------------------------------------------------- |
239 | | -# Classification of messages for nicer grouping |
240 | | -# --------------------------------------------------------------------------- |
241 | | - |
242 | | -def classify_message(msg: str) -> str: |
243 | | - """ |
244 | | - Roughly classify rule messages for nicer grouping. |
245 | | -
|
246 | | - Returns one of: "requirement", "incompatibility", "range", "other". |
247 | | - """ |
248 | | - text = msg.lower() |
249 | | - |
250 | | - if ( # pylint: disable=too-many-boolean-expressions |
251 | | - "not compatible" in text |
252 | | - or "does not support" in text |
253 | | - or "cannot be used" in text |
254 | | - or "must not" in text |
255 | | - or "is not supported" in text |
256 | | - or "incompatible" in text |
257 | | - or "untested" in text |
258 | | - ): |
259 | | - return "incompatibility" |
260 | | - |
261 | | - if ( # pylint: disable=too-many-boolean-expressions |
262 | | - "requires" in text |
263 | | - or "must be set if" in text |
264 | | - or "must be specified" in text |
265 | | - or "must be set with" in text |
266 | | - or "can only be enabled if" in text |
267 | | - or "must be set for" in text |
268 | | - ): |
269 | | - return "requirement" |
270 | | - |
271 | | - if ( # pylint: disable=too-many-boolean-expressions |
272 | | - "must be between" in text |
273 | | - or "must be positive" in text |
274 | | - or "must be non-negative" in text |
275 | | - or "must be greater than" in text |
276 | | - or "must be less than" in text |
277 | | - or "must be at least" in text |
278 | | - or "must be <=" in text |
279 | | - or "must be >=" in text |
280 | | - or "must be odd" in text |
281 | | - or "divisible by" in text |
282 | | - ): |
283 | | - return "range" |
284 | | - |
285 | | - return "other" |
286 | | - |
287 | | - |
288 | | -# Optional: nicer display names / categories (you can extend this) |
289 | | -FEATURE_META = { |
290 | | - "igr": {"title": "Iterative Generalized Riemann (IGR)", "category": "solver"}, |
291 | | - "bubbles_euler": {"title": "Euler–Euler Bubble Model", "category": "bubbles"}, |
292 | | - "bubbles_lagrange": {"title": "Euler–Lagrange Bubble Model", "category": "bubbles"}, |
293 | | - "qbmm": {"title": "Quadrature-Based Moment Method (QBMM)", "category": "bubbles"}, |
294 | | - "polydisperse": {"title": "Polydisperse Bubble Dynamics", "category": "bubbles"}, |
295 | | - "mhd": {"title": "Magnetohydrodynamics (MHD)", "category": "physics"}, |
296 | | - "alt_soundspeed": {"title": "Alternative Sound Speed", "category": "physics"}, |
297 | | - "surface_tension": {"title": "Surface Tension Model", "category": "physics"}, |
298 | | - "hypoelasticity": {"title": "Hypoelasticity", "category": "physics"}, |
299 | | - "hyperelasticity": {"title": "Hyperelasticity", "category": "physics"}, |
300 | | - "relax": {"title": "Phase Change (Relaxation)", "category": "physics"}, |
301 | | - "viscous": {"title": "Viscosity", "category": "physics"}, |
302 | | - "acoustic_source": {"title": "Acoustic Sources", "category": "physics"}, |
303 | | - "ib": {"title": "Immersed Boundaries", "category": "geometry"}, |
304 | | - "cyl_coord": {"title": "Cylindrical Coordinates", "category": "geometry"}, |
305 | | - "weno_order": {"title": "WENO Order", "category": "numerics"}, |
306 | | - "muscl_order": {"title": "MUSCL Order", "category": "numerics"}, |
307 | | - "riemann_solver": {"title": "Riemann Solver", "category": "numerics"}, |
308 | | - "model_eqns": {"title": "Model Equations", "category": "fundamentals"}, |
309 | | - "num_fluids": {"title": "Number of Fluids", "category": "fundamentals"}, |
310 | | -} |
311 | | - |
312 | | - |
313 | | -def feature_title(param: str) -> str: |
314 | | - meta = FEATURE_META.get(param) |
315 | | - if meta and "title" in meta: |
316 | | - return meta["title"] |
317 | | - return param |
| 32 | +from mfc.params.ast_analyzer import ( # noqa: E402 pylint: disable=wrong-import-position |
| 33 | + Rule, classify_message, feature_title, |
| 34 | + analyze_case_validator, |
| 35 | +) |
318 | 36 |
|
319 | 37 |
|
320 | 38 | # --------------------------------------------------------------------------- |
@@ -1023,20 +741,8 @@ def _render_cond_parts(trigger_str, cond_dict): |
1023 | 741 | # --------------------------------------------------------------------------- |
1024 | 742 |
|
1025 | 743 | def main() -> None: |
1026 | | - src = CASE_VALIDATOR_PATH.read_text(encoding="utf-8") |
1027 | | - tree = ast.parse(src, filename=str(CASE_VALIDATOR_PATH)) |
1028 | | - |
1029 | | - analyzer = CaseValidatorAnalyzer() |
1030 | | - analyzer.visit(tree) |
1031 | | - |
1032 | | - # Infer stages per method from call graph |
1033 | | - method_stages = compute_method_stages(analyzer.call_graph) |
1034 | | - |
1035 | | - # Attach stages to rules |
1036 | | - for r in analyzer.rules: |
1037 | | - r.stages = method_stages.get(r.method, set()) |
1038 | | - |
1039 | | - md = render_markdown(analyzer.rules) |
| 744 | + analysis = analyze_case_validator(CASE_VALIDATOR_PATH) |
| 745 | + md = render_markdown(analysis["rules"]) |
1040 | 746 | print(md) |
1041 | 747 |
|
1042 | 748 |
|
|
0 commit comments