Skip to content

Commit 407ef4b

Browse files
authored
Merge branch 'main' into nmulepati/fix-620-chat-completion-choices-n
2 parents b4bfd0b + bd0410b commit 407ef4b

7 files changed

Lines changed: 961 additions & 14 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/processing/ginja/ast.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@
44
from __future__ import annotations
55

66
from collections import deque
7+
from typing import Any
78

89
from jinja2 import nodes as j_nodes
910

11+
# An access chain is a (root_name, accessors) pair, where accessors is the
12+
# ordered list of attribute names / subscript keys applied to the root.
13+
# E.g. ``{{ person.address["street"] }}`` -> ``("person", ["address", "street"])``.
14+
AccessChain = tuple[str, list[str | int]]
15+
1016

1117
def ast_max_depth(node: j_nodes.Node) -> int:
1218
"""Calculate the depth of a Jinja AST from a given node.
@@ -63,3 +69,117 @@ def ast_count_name_references(ast: j_nodes.Node, name: str) -> int:
6369
"""
6470
referenced_names = [node.name for node in ast.find_all(j_nodes.Name) if node.name == name]
6571
return len(referenced_names)
72+
73+
74+
def ast_extract_access_chains(root: j_nodes.Node) -> list[AccessChain]:
75+
"""Extract every top-level access chain rooted at a named variable.
76+
77+
Each output tuple is ``(root_name, accessors)`` where ``accessors``
78+
is the ordered list of attribute/subscript keys applied to the root.
79+
Top-level means the chain is not contained inside a longer chain,
80+
so ``{{ a.b.c }}`` yields ``("a", ["b", "c"])`` rather than three
81+
overlapping entries. Dynamic subscripts inside a chain (``a[b].c``)
82+
cause the chain to be skipped; the inner ``b`` is still extracted
83+
as its own chain.
84+
85+
Args:
86+
root: A parsed Jinja2 AST node (typically the ``Template``).
87+
88+
Returns:
89+
Every extractable access chain, in source-order. Duplicates are
90+
preserved so the caller can decide how to dedupe.
91+
"""
92+
chains: list[AccessChain] = []
93+
94+
def visit(node: j_nodes.Node, in_chain: bool) -> None:
95+
if isinstance(node, (j_nodes.Getattr, j_nodes.Getitem)):
96+
if not in_chain:
97+
chain = _build_access_chain(node)
98+
if chain is not None:
99+
chains.append(chain)
100+
# Descend through ``.node`` as "in chain" so we don't re-emit
101+
# the prefixes ``a`` and ``a.b`` for ``{{ a.b.c }}``.
102+
visit(node.node, in_chain=True)
103+
if isinstance(node, j_nodes.Getitem):
104+
# The subscript expression is a separate scope and may
105+
# contain its own variable references.
106+
visit(node.arg, in_chain=False)
107+
return
108+
if isinstance(node, j_nodes.Name):
109+
if not in_chain:
110+
chains.append((node.name, []))
111+
return
112+
for child in node.iter_child_nodes():
113+
visit(child, in_chain=False)
114+
115+
visit(root, in_chain=False)
116+
return chains
117+
118+
119+
def resolve_access_chain(record: dict, name: str, accessors: list[str | int]) -> tuple[bool, Any, list[str | int]]:
120+
"""Walk an access chain against a record dict.
121+
122+
Args:
123+
record: The sanitized record dict that would be used as template
124+
context. Values are expected to be JSON-compatible types.
125+
name: Root variable name.
126+
accessors: Ordered attribute names / subscript keys to apply.
127+
128+
Returns:
129+
A tuple ``(resolved, value, prefix)``:
130+
- ``resolved`` is ``True`` iff every accessor matched.
131+
- ``value`` is the final value when ``resolved`` is True,
132+
otherwise ``None``.
133+
- ``prefix`` is the longest accessor list that did match.
134+
When ``resolved`` is False, the next accessor (``accessors
135+
[len(prefix)]``) is the one that broke the chain.
136+
"""
137+
if name not in record:
138+
return (False, None, [])
139+
current: Any = record[name]
140+
prefix: list[str | int] = []
141+
for acc in accessors:
142+
if isinstance(current, dict):
143+
if not isinstance(acc, str) or acc not in current:
144+
return (False, None, prefix)
145+
current = current[acc]
146+
elif isinstance(current, list):
147+
if not isinstance(acc, int):
148+
return (False, None, prefix)
149+
if acc >= len(current) or acc < -len(current):
150+
return (False, None, prefix)
151+
current = current[acc]
152+
else:
153+
# The chain wants to go deeper but the value is a scalar.
154+
return (False, None, prefix)
155+
prefix.append(acc)
156+
return (True, current, prefix)
157+
158+
159+
def _build_access_chain(node: j_nodes.Node) -> AccessChain | None:
160+
"""Reduce a Getattr/Getitem/Name node to a top-level access chain.
161+
162+
Walks down the ``.node`` field of nested ``Getattr``/``Getitem`` nodes
163+
until reaching the root ``Name``. Returns ``None`` if the chain is
164+
rooted in something other than a ``Name`` (e.g. a function call) or
165+
if a ``Getitem`` uses a non-constant subscript expression (e.g.
166+
``a[b]`` where ``b`` is itself a variable).
167+
"""
168+
accessors: list[str | int] = []
169+
current: j_nodes.Node = node
170+
while True:
171+
if isinstance(current, j_nodes.Getattr):
172+
accessors.append(current.attr)
173+
current = current.node
174+
elif isinstance(current, j_nodes.Getitem):
175+
arg = current.arg
176+
if isinstance(arg, j_nodes.Const) and isinstance(arg.value, (str, int)):
177+
accessors.append(arg.value)
178+
current = current.node
179+
else:
180+
# Dynamic subscript like ``a[b]`` -- not a fixed access chain.
181+
return None
182+
elif isinstance(current, j_nodes.Name):
183+
return (current.name, list(reversed(accessors)))
184+
else:
185+
return None

0 commit comments

Comments
 (0)