Skip to content

Commit b235149

Browse files
sbryngelson and claude committed
AST-extract constraints from case_validator.py for parameters.md
Move the AST analyzer (Rule, CaseValidatorAnalyzer, classify_message, etc.) from gen_case_constraints_docs.py into a shared module at params/ast_analyzer.py.

Add f-string message extraction (106 previously invisible rules are now captured, 333 total), trigger-param detection via method guards and condition analysis, and an analyze_case_validator() convenience function.

docs_gen.py now reads constraints from the AST-extracted rules instead of the manually maintained DEPENDENCIES dict, populating 61% of table rows with validator constraints. Parameter names in messages are wrapped in backticks using word-boundary matching.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 23f17be commit b235149

4 files changed

Lines changed: 571 additions & 345 deletions

File tree

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -685,10 +685,12 @@ if (MFC_DOCUMENTATION)
685685
)
686686

687687
# Generate parameters.md from parameter registry
688+
# docs_gen.py now AST-parses case_validator.py, so it must be a dependency
688689
file(GLOB_RECURSE params_SRCs CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/toolchain/mfc/params/*.py")
689690
add_custom_command(
690691
OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/docs/documentation/parameters.md"
691692
DEPENDS "${params_SRCs}"
693+
"${CMAKE_CURRENT_SOURCE_DIR}/toolchain/mfc/case_validator.py"
692694
COMMAND "bash" "${CMAKE_CURRENT_SOURCE_DIR}/docs/gen_parameters.sh"
693695
"${CMAKE_CURRENT_SOURCE_DIR}"
694696
COMMENT "Generating parameters.md"

toolchain/mfc/gen_case_constraints_docs.py

Lines changed: 8 additions & 302 deletions
Original file line number | Diff line number | Diff line change
@@ -10,13 +10,12 @@
1010

1111
from __future__ import annotations
1212

13-
import ast
1413
import json
1514
import sys
1615
import subprocess
17-
from dataclasses import dataclass, field
16+
from dataclasses import dataclass
1817
from pathlib import Path
19-
from typing import Dict, List, Set, Iterable, Any
18+
from typing import Dict, List, Iterable, Any
2019
from collections import defaultdict
2120

2221
HERE = Path(__file__).resolve().parent
@@ -30,291 +29,10 @@
3029
sys.path.insert(0, _toolchain_dir)
3130

3231
from mfc.params import CONSTRAINTS, DEPENDENCIES, get_value_label # noqa: E402 pylint: disable=wrong-import-position
33-
34-
35-
# ---------------------------------------------------------------------------
36-
# Data structures
37-
# ---------------------------------------------------------------------------
38-
39-
@dataclass
40-
class Rule:
41-
method: str # e.g. "check_igr_simulation"
42-
lineno: int # line number of the prohibit call
43-
params: List[str] # case parameter names used in condition
44-
message: str # user-friendly error message
45-
stages: Set[str] = field(default_factory=set) # e.g. {"simulation", "pre_process"}
46-
47-
48-
# ---------------------------------------------------------------------------
49-
# AST analysis: methods, call graph, rules
50-
# ---------------------------------------------------------------------------
51-
52-
class CaseValidatorAnalyzer(ast.NodeVisitor):
53-
"""
54-
Analyzes the CaseValidator class:
55-
56-
- collects all methods
57-
- builds a call graph between methods
58-
- extracts all self.prohibit(...) rules
59-
"""
60-
61-
def __init__(self):
62-
super().__init__()
63-
self.in_case_validator = False
64-
self.current_method: str | None = None
65-
66-
self.methods: Dict[str, ast.FunctionDef] = {}
67-
self.call_graph: Dict[str, Set[str]] = defaultdict(set)
68-
self.rules: List[Rule] = []
69-
70-
# Stack of {local_name -> param_name} maps, one per method
71-
self.local_param_stack: List[Dict[str, str]] = []
72-
73-
# --- top-level entrypoint ---
74-
75-
def visit_ClassDef(self, node: ast.ClassDef):
76-
if node.name == "CaseValidator":
77-
self.in_case_validator = True
78-
# collect methods
79-
for item in node.body:
80-
if isinstance(item, ast.FunctionDef):
81-
self.methods[item.name] = item
82-
# now analyze all methods
83-
for method in self.methods.values():
84-
self._analyze_method(method)
85-
self.in_case_validator = False
86-
else:
87-
self.generic_visit(node)
88-
89-
# --- per-method analysis ---
90-
91-
def _analyze_method(self, func: ast.FunctionDef):
92-
"""Analyze a single method: local param mapping, call graph, rules."""
93-
self.current_method = func.name
94-
local_param_map = self._build_local_param_map(func)
95-
self.local_param_stack.append(local_param_map)
96-
self.generic_visit(func)
97-
self.local_param_stack.pop()
98-
self.current_method = None
99-
100-
def _build_local_param_map(self, func: ast.FunctionDef) -> Dict[str, str]: # pylint: disable=too-many-nested-blocks
101-
"""
102-
Look for assignments like:
103-
igr = self.get('igr', 'F') == 'T'
104-
model_eqns = self.get('model_eqns')
105-
and record local_name -> 'param_name'.
106-
"""
107-
m: Dict[str, str] = {}
108-
for stmt in func.body: # pylint: disable=too-many-nested-blocks
109-
if isinstance(stmt, ast.Assign):
110-
# Handle both direct calls and comparisons
111-
value = stmt.value
112-
# Unwrap comparisons like "self.get('igr', 'F') == 'T'"
113-
if isinstance(value, ast.Compare):
114-
value = value.left
115-
116-
if isinstance(value, ast.Call):
117-
call = value
118-
if ( # pylint: disable=too-many-boolean-expressions
119-
isinstance(call.func, ast.Attribute)
120-
and isinstance(call.func.value, ast.Name)
121-
and call.func.value.id == "self"
122-
and call.func.attr == "get"
123-
and call.args
124-
and isinstance(call.args[0], ast.Constant)
125-
and isinstance(call.args[0].value, str)
126-
):
127-
param_name = call.args[0].value
128-
for target in stmt.targets:
129-
if isinstance(target, ast.Name):
130-
m[target.id] = param_name
131-
return m
132-
133-
# --- visit calls to build call graph + rules ---
134-
135-
def visit_Call(self, node: ast.Call):
136-
# record method call edges: self.some_method(...)
137-
if (
138-
isinstance(node.func, ast.Attribute)
139-
and isinstance(node.func.value, ast.Name)
140-
and node.func.value.id == "self"
141-
and isinstance(node.func.attr, str)
142-
):
143-
callee = node.func.attr
144-
if self.current_method is not None:
145-
# method call on self
146-
self.call_graph[self.current_method].add(callee)
147-
148-
# detect self.prohibit(<condition>, "<message>")
149-
if (
150-
isinstance(node.func, ast.Attribute)
151-
and isinstance(node.func.value, ast.Name)
152-
and node.func.value.id == "self"
153-
and node.func.attr == "prohibit"
154-
and len(node.args) >= 2
155-
):
156-
condition, msg_node = node.args[0], node.args[1]
157-
if isinstance(msg_node, ast.Constant) and isinstance(msg_node.value, str):
158-
params = sorted(self._extract_params(condition))
159-
rule = Rule(
160-
method=self.current_method or "<unknown>",
161-
lineno=node.lineno,
162-
params=params,
163-
message=msg_node.value,
164-
)
165-
self.rules.append(rule)
166-
167-
self.generic_visit(node)
168-
169-
def _extract_params(self, condition: ast.AST) -> Set[str]:
170-
"""
171-
Collect parameter names used in the condition via:
172-
- local variables mapped from self.get(...)
173-
- direct self.get('param_name', ...) calls
174-
"""
175-
params: Set[str] = set()
176-
local_map = self.local_param_stack[-1] if self.local_param_stack else {}
177-
178-
for node in ast.walk(condition):
179-
# local names
180-
if isinstance(node, ast.Name) and node.id in local_map:
181-
params.add(local_map[node.id])
182-
183-
# direct self.get('param_name')
184-
if isinstance(node, ast.Call):
185-
if ( # pylint: disable=too-many-boolean-expressions
186-
isinstance(node.func, ast.Attribute)
187-
and isinstance(node.func.value, ast.Name)
188-
and node.func.value.id == "self"
189-
and node.func.attr == "get"
190-
and node.args
191-
and isinstance(node.args[0], ast.Constant)
192-
and isinstance(node.args[0].value, str)
193-
):
194-
params.add(node.args[0].value)
195-
196-
return params
197-
198-
199-
# ---------------------------------------------------------------------------
200-
# Stage inference from validate_* roots and call graph
201-
# ---------------------------------------------------------------------------
202-
203-
STAGE_ROOTS: Dict[str, List[str]] = {
204-
"common": ["validate_common"],
205-
"simulation": ["validate_simulation"],
206-
"pre_process": ["validate_pre_process"],
207-
"post_process": ["validate_post_process"],
208-
}
209-
210-
211-
def compute_method_stages(call_graph: Dict[str, Set[str]]) -> Dict[str, Set[str]]:
212-
"""
213-
For each stage (simulation/pre_process/post_process/common), starting from
214-
validate_* roots, walk the call graph and record which methods belong to which stages.
215-
"""
216-
method_stages: Dict[str, Set[str]] = defaultdict(set)
217-
218-
def dfs(start: str, stage: str):
219-
stack = [start]
220-
visited: Set[str] = set()
221-
while stack:
222-
m = stack.pop()
223-
if m in visited:
224-
continue
225-
visited.add(m)
226-
method_stages[m].add(stage)
227-
for nxt in call_graph.get(m, ()):
228-
if nxt not in visited:
229-
stack.append(nxt)
230-
231-
for stage, roots in STAGE_ROOTS.items():
232-
for root in roots:
233-
dfs(root, stage)
234-
235-
return method_stages
236-
237-
238-
# ---------------------------------------------------------------------------
239-
# Classification of messages for nicer grouping
240-
# ---------------------------------------------------------------------------
241-
242-
def classify_message(msg: str) -> str:
243-
"""
244-
Roughly classify rule messages for nicer grouping.
245-
246-
Returns one of: "requirement", "incompatibility", "range", "other".
247-
"""
248-
text = msg.lower()
249-
250-
if ( # pylint: disable=too-many-boolean-expressions
251-
"not compatible" in text
252-
or "does not support" in text
253-
or "cannot be used" in text
254-
or "must not" in text
255-
or "is not supported" in text
256-
or "incompatible" in text
257-
or "untested" in text
258-
):
259-
return "incompatibility"
260-
261-
if ( # pylint: disable=too-many-boolean-expressions
262-
"requires" in text
263-
or "must be set if" in text
264-
or "must be specified" in text
265-
or "must be set with" in text
266-
or "can only be enabled if" in text
267-
or "must be set for" in text
268-
):
269-
return "requirement"
270-
271-
if ( # pylint: disable=too-many-boolean-expressions
272-
"must be between" in text
273-
or "must be positive" in text
274-
or "must be non-negative" in text
275-
or "must be greater than" in text
276-
or "must be less than" in text
277-
or "must be at least" in text
278-
or "must be <=" in text
279-
or "must be >=" in text
280-
or "must be odd" in text
281-
or "divisible by" in text
282-
):
283-
return "range"
284-
285-
return "other"
286-
287-
288-
# Optional: nicer display names / categories (you can extend this)
289-
FEATURE_META = {
290-
"igr": {"title": "Iterative Generalized Riemann (IGR)", "category": "solver"},
291-
"bubbles_euler": {"title": "Euler–Euler Bubble Model", "category": "bubbles"},
292-
"bubbles_lagrange": {"title": "Euler–Lagrange Bubble Model", "category": "bubbles"},
293-
"qbmm": {"title": "Quadrature-Based Moment Method (QBMM)", "category": "bubbles"},
294-
"polydisperse": {"title": "Polydisperse Bubble Dynamics", "category": "bubbles"},
295-
"mhd": {"title": "Magnetohydrodynamics (MHD)", "category": "physics"},
296-
"alt_soundspeed": {"title": "Alternative Sound Speed", "category": "physics"},
297-
"surface_tension": {"title": "Surface Tension Model", "category": "physics"},
298-
"hypoelasticity": {"title": "Hypoelasticity", "category": "physics"},
299-
"hyperelasticity": {"title": "Hyperelasticity", "category": "physics"},
300-
"relax": {"title": "Phase Change (Relaxation)", "category": "physics"},
301-
"viscous": {"title": "Viscosity", "category": "physics"},
302-
"acoustic_source": {"title": "Acoustic Sources", "category": "physics"},
303-
"ib": {"title": "Immersed Boundaries", "category": "geometry"},
304-
"cyl_coord": {"title": "Cylindrical Coordinates", "category": "geometry"},
305-
"weno_order": {"title": "WENO Order", "category": "numerics"},
306-
"muscl_order": {"title": "MUSCL Order", "category": "numerics"},
307-
"riemann_solver": {"title": "Riemann Solver", "category": "numerics"},
308-
"model_eqns": {"title": "Model Equations", "category": "fundamentals"},
309-
"num_fluids": {"title": "Number of Fluids", "category": "fundamentals"},
310-
}
311-
312-
313-
def feature_title(param: str) -> str:
314-
meta = FEATURE_META.get(param)
315-
if meta and "title" in meta:
316-
return meta["title"]
317-
return param
32+
from mfc.params.ast_analyzer import ( # noqa: E402 pylint: disable=wrong-import-position
33+
Rule, classify_message, feature_title,
34+
analyze_case_validator,
35+
)
31836

31937

32038
# ---------------------------------------------------------------------------
@@ -1023,20 +741,8 @@ def _render_cond_parts(trigger_str, cond_dict):
1023741
# ---------------------------------------------------------------------------
1024742

1025743
def main() -> None:
1026-
src = CASE_VALIDATOR_PATH.read_text(encoding="utf-8")
1027-
tree = ast.parse(src, filename=str(CASE_VALIDATOR_PATH))
1028-
1029-
analyzer = CaseValidatorAnalyzer()
1030-
analyzer.visit(tree)
1031-
1032-
# Infer stages per method from call graph
1033-
method_stages = compute_method_stages(analyzer.call_graph)
1034-
1035-
# Attach stages to rules
1036-
for r in analyzer.rules:
1037-
r.stages = method_stages.get(r.method, set())
1038-
1039-
md = render_markdown(analyzer.rules)
744+
analysis = analyze_case_validator(CASE_VALIDATOR_PATH)
745+
md = render_markdown(analysis["rules"])
1040746
print(md)
1041747

1042748

0 commit comments

Comments
 (0)