|
4 | 4 | from __future__ import annotations |
5 | 5 |
|
6 | 6 | from collections import deque |
| 7 | +from typing import Any |
7 | 8 |
|
8 | 9 | from jinja2 import nodes as j_nodes |
9 | 10 |
|
| 11 | +# An access chain is a (root_name, accessors) pair, where accessors is the |
| 12 | +# ordered list of attribute names / subscript keys applied to the root. |
| 13 | +# E.g. ``{{ person.address["street"] }}`` -> ``("person", ["address", "street"])``. |
| 14 | +AccessChain = tuple[str, list[str | int]] |
| 15 | + |
10 | 16 |
|
11 | 17 | def ast_max_depth(node: j_nodes.Node) -> int: |
12 | 18 | """Calculate the depth of a Jinja AST from a given node. |
@@ -63,3 +69,117 @@ def ast_count_name_references(ast: j_nodes.Node, name: str) -> int: |
63 | 69 | """ |
64 | 70 | referenced_names = [node.name for node in ast.find_all(j_nodes.Name) if node.name == name] |
65 | 71 | return len(referenced_names) |
| 72 | + |
| 73 | + |
| 74 | +def ast_extract_access_chains(root: j_nodes.Node) -> list[AccessChain]: |
| 75 | + """Extract every top-level access chain rooted at a named variable. |
| 76 | +
|
| 77 | + Each output tuple is ``(root_name, accessors)`` where ``accessors`` |
| 78 | + is the ordered list of attribute/subscript keys applied to the root. |
| 79 | + Top-level means the chain is not contained inside a longer chain, |
| 80 | + so ``{{ a.b.c }}`` yields ``("a", ["b", "c"])`` rather than three |
| 81 | + overlapping entries. Dynamic subscripts inside a chain (``a[b].c``) |
| 82 | + cause the chain to be skipped; the inner ``b`` is still extracted |
| 83 | + as its own chain. |
| 84 | +
|
| 85 | + Args: |
| 86 | + root: A parsed Jinja2 AST node (typically the ``Template``). |
| 87 | +
|
| 88 | + Returns: |
| 89 | + Every extractable access chain, in source-order. Duplicates are |
| 90 | + preserved so the caller can decide how to dedupe. |
| 91 | + """ |
| 92 | + chains: list[AccessChain] = [] |
| 93 | + |
| 94 | + def visit(node: j_nodes.Node, in_chain: bool) -> None: |
| 95 | + if isinstance(node, (j_nodes.Getattr, j_nodes.Getitem)): |
| 96 | + if not in_chain: |
| 97 | + chain = _build_access_chain(node) |
| 98 | + if chain is not None: |
| 99 | + chains.append(chain) |
| 100 | + # Descend through ``.node`` as "in chain" so we don't re-emit |
| 101 | + # the prefixes ``a`` and ``a.b`` for ``{{ a.b.c }}``. |
| 102 | + visit(node.node, in_chain=True) |
| 103 | + if isinstance(node, j_nodes.Getitem): |
| 104 | + # The subscript expression is a separate scope and may |
| 105 | + # contain its own variable references. |
| 106 | + visit(node.arg, in_chain=False) |
| 107 | + return |
| 108 | + if isinstance(node, j_nodes.Name): |
| 109 | + if not in_chain: |
| 110 | + chains.append((node.name, [])) |
| 111 | + return |
| 112 | + for child in node.iter_child_nodes(): |
| 113 | + visit(child, in_chain=False) |
| 114 | + |
| 115 | + visit(root, in_chain=False) |
| 116 | + return chains |
| 117 | + |
| 118 | + |
| 119 | +def resolve_access_chain(record: dict, name: str, accessors: list[str | int]) -> tuple[bool, Any, list[str | int]]: |
| 120 | + """Walk an access chain against a record dict. |
| 121 | +
|
| 122 | + Args: |
| 123 | + record: The sanitized record dict that would be used as template |
| 124 | + context. Values are expected to be JSON-compatible types. |
| 125 | + name: Root variable name. |
| 126 | + accessors: Ordered attribute names / subscript keys to apply. |
| 127 | +
|
| 128 | + Returns: |
| 129 | + A tuple ``(resolved, value, prefix)``: |
| 130 | + - ``resolved`` is ``True`` iff every accessor matched. |
| 131 | + - ``value`` is the final value when ``resolved`` is True, |
| 132 | + otherwise ``None``. |
| 133 | + - ``prefix`` is the longest accessor list that did match. |
| 134 | + When ``resolved`` is False, the next accessor (``accessors |
| 135 | + [len(prefix)]``) is the one that broke the chain. |
| 136 | + """ |
| 137 | + if name not in record: |
| 138 | + return (False, None, []) |
| 139 | + current: Any = record[name] |
| 140 | + prefix: list[str | int] = [] |
| 141 | + for acc in accessors: |
| 142 | + if isinstance(current, dict): |
| 143 | + if not isinstance(acc, str) or acc not in current: |
| 144 | + return (False, None, prefix) |
| 145 | + current = current[acc] |
| 146 | + elif isinstance(current, list): |
| 147 | + if not isinstance(acc, int): |
| 148 | + return (False, None, prefix) |
| 149 | + if acc >= len(current) or acc < -len(current): |
| 150 | + return (False, None, prefix) |
| 151 | + current = current[acc] |
| 152 | + else: |
| 153 | + # The chain wants to go deeper but the value is a scalar. |
| 154 | + return (False, None, prefix) |
| 155 | + prefix.append(acc) |
| 156 | + return (True, current, prefix) |
| 157 | + |
| 158 | + |
| 159 | +def _build_access_chain(node: j_nodes.Node) -> AccessChain | None: |
| 160 | + """Reduce a Getattr/Getitem/Name node to a top-level access chain. |
| 161 | +
|
| 162 | + Walks down the ``.node`` field of nested ``Getattr``/``Getitem`` nodes |
| 163 | + until reaching the root ``Name``. Returns ``None`` if the chain is |
| 164 | + rooted in something other than a ``Name`` (e.g. a function call) or |
| 165 | + if a ``Getitem`` uses a non-constant subscript expression (e.g. |
| 166 | + ``a[b]`` where ``b`` is itself a variable). |
| 167 | + """ |
| 168 | + accessors: list[str | int] = [] |
| 169 | + current: j_nodes.Node = node |
| 170 | + while True: |
| 171 | + if isinstance(current, j_nodes.Getattr): |
| 172 | + accessors.append(current.attr) |
| 173 | + current = current.node |
| 174 | + elif isinstance(current, j_nodes.Getitem): |
| 175 | + arg = current.arg |
| 176 | + if isinstance(arg, j_nodes.Const) and isinstance(arg.value, (str, int)): |
| 177 | + accessors.append(arg.value) |
| 178 | + current = current.node |
| 179 | + else: |
| 180 | + # Dynamic subscript like ``a[b]`` -- not a fixed access chain. |
| 181 | + return None |
| 182 | + elif isinstance(current, j_nodes.Name): |
| 183 | + return (current.name, list(reversed(accessors))) |
| 184 | + else: |
| 185 | + return None |
0 commit comments