Skip to content

Commit 78f3373

Browse files
DvirDukhanCopilot
andcommitted
fix(analyzer): correct tree-sitter resolver name→def pairing
The per-module symbol table in `_index_file` was built by zipping two independently-grouped `QueryCursor.captures()` lists (`@name` and `@def`). When `@def` positions shift relative to `@name` (e.g. decorated defs), the zip mis-pairs names with definitions, so imported-call resolution attaches CALLS edges to the wrong target — producing phantom edges to functions whose token never appears at the call site. Fix: iterate per-match via a `_matches()` helper wrapping `QueryCursor.matches()`, which guarantees each match's `@name`/`@def` captures belong together. Applied across all four indexing loops (top-level funcs, classes, assigns, class methods). Impact (deterministic graph-vs-jedi-oracle caller bench, n=40, paired, identical harness — only the resolver differs): uxarray CALLS macro-F1 0.178 → 0.713 (median 0.0 → 0.94) arkouda CALLS macro-F1 0.031 → 0.262 Adds two regression tests asserting each imported call resolves to the def whose name matches exactly (10 top-level functions, 8 classes). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 7d12be1 commit 78f3373

2 files changed

Lines changed: 97 additions & 26 deletions

File tree

api/analyzers/python/ts_resolver.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,21 @@ def _captures(query, root: Node) -> dict[str, list[Node]]:
167167
return cursor.captures(root)
168168

169169

170+
def _matches(query, root: Node) -> list[tuple[int, dict[str, list[Node]]]]:
171+
"""Return per-match capture groups.
172+
173+
Unlike :func:`_captures` (which groups *all* nodes by capture name into
174+
parallel lists that are **not** guaranteed to be index-aligned across
175+
different capture names), this yields one dict per match so that, e.g.,
176+
a ``@name`` capture is always paired with the ``@def`` capture from the
177+
*same* match. Zipping the two independent lists from ``captures()`` mis-
178+
aligns names and definitions whenever the per-capture node orderings
179+
diverge, scrambling the module symbol table.
180+
"""
181+
cursor = QueryCursor(query)
182+
return cursor.matches(root)
183+
184+
170185
# ---------------------------------------------------------------------------
171186
# Public resolver
172187
# ---------------------------------------------------------------------------
@@ -217,46 +232,50 @@ def _ensure_built(self, files: dict[Path, File], project_root: Path) -> None:
217232

218233
def _index_file(self, mi: _ModuleIndex, root: Node) -> None:
219234
# Top-level functions
220-
caps = _captures(self._queries.top_level_func, root)
221-
names = caps.get("name", [])
222-
defs = caps.get("def", [])
223-
for name_node, def_node in zip(names, defs):
224-
name = name_node.text.decode("utf-8")
225-
d = _Definition(mi.file_path, _strip_decorator(def_node), "func")
235+
for _, caps in _matches(self._queries.top_level_func, root):
236+
name_nodes = caps.get("name", [])
237+
def_nodes = caps.get("def", [])
238+
if not name_nodes or not def_nodes:
239+
continue
240+
name = name_nodes[0].text.decode("utf-8")
241+
d = _Definition(mi.file_path, _strip_decorator(def_nodes[0]), "func")
226242
mi.top_level[name] = d
227243
self._by_name[name].append(d)
228244

229245
# Top-level classes
230-
caps = _captures(self._queries.top_level_class, root)
231-
names = caps.get("name", [])
232-
defs = caps.get("def", [])
233-
for name_node, def_node in zip(names, defs):
234-
name = name_node.text.decode("utf-8")
235-
d = _Definition(mi.file_path, _strip_decorator(def_node), "class")
246+
for _, caps in _matches(self._queries.top_level_class, root):
247+
name_nodes = caps.get("name", [])
248+
def_nodes = caps.get("def", [])
249+
if not name_nodes or not def_nodes:
250+
continue
251+
name = name_nodes[0].text.decode("utf-8")
252+
d = _Definition(mi.file_path, _strip_decorator(def_nodes[0]), "class")
236253
mi.top_level[name] = d
237254
self._by_name[name].append(d)
238255

239256
# Top-level assignments (for class aliases like ``Foo = OtherFoo``)
240-
caps = _captures(self._queries.top_level_assign, root)
241-
names = caps.get("name", [])
242-
defs = caps.get("def", [])
243-
for name_node, def_node in zip(names, defs):
244-
name = name_node.text.decode("utf-8")
257+
for _, caps in _matches(self._queries.top_level_assign, root):
258+
name_nodes = caps.get("name", [])
259+
def_nodes = caps.get("def", [])
260+
if not name_nodes or not def_nodes:
261+
continue
262+
name = name_nodes[0].text.decode("utf-8")
245263
if name in mi.top_level:
246264
continue
247-
d = _Definition(mi.file_path, def_node, "var")
265+
d = _Definition(mi.file_path, def_nodes[0], "var")
248266
mi.top_level[name] = d
249267
self._by_name[name].append(d)
250268

251269
# Class methods
252-
caps = _captures(self._queries.class_methods, root)
253-
class_names = caps.get("class_name", [])
254-
method_names = caps.get("method_name", [])
255-
method_defs = caps.get("method_def", [])
256-
for cls_node, mname_node, mdef_node in zip(class_names, method_names, method_defs):
257-
class_name = cls_node.text.decode("utf-8")
258-
method_name = mname_node.text.decode("utf-8")
259-
d = _Definition(mi.file_path, _strip_decorator(mdef_node), "method")
270+
for _, caps in _matches(self._queries.class_methods, root):
271+
class_nodes = caps.get("class_name", [])
272+
mname_nodes = caps.get("method_name", [])
273+
mdef_nodes = caps.get("method_def", [])
274+
if not class_nodes or not mname_nodes or not mdef_nodes:
275+
continue
276+
class_name = class_nodes[0].text.decode("utf-8")
277+
method_name = mname_nodes[0].text.decode("utf-8")
278+
d = _Definition(mi.file_path, _strip_decorator(mdef_nodes[0]), "method")
260279
mi.class_methods.setdefault(class_name, {})[method_name] = d
261280
self._by_name[method_name].append(d)
262281

tests/analyzers/test_ts_python_resolver.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,58 @@ def test_resolver_unknown_name_returns_empty(tmp_path: Path):
227227
assert r.resolve(files, mod, tmp_path.resolve(), name) == []
228228

229229

230+
def test_resolver_many_defs_name_def_alignment(tmp_path: Path):
231+
"""Regression for the scrambled module symbol table.
232+
233+
With several top-level definitions in one module, pairing the ``@name``
234+
and ``@def`` captures by zipping two independently-grouped lists mis-
235+
aligned names with definitions (e.g. an imported ``arange`` call resolved
236+
to the ``array`` def node). Each imported call must resolve to the def
237+
whose name actually matches the call name.
238+
"""
239+
lib_src = "".join(f"def fn_{i}():\n return {i}\n\n" for i in range(10))
240+
import_line = "from lib import " + ", ".join(f"fn_{i}" for i in range(10))
241+
call_lines = "\n".join(f" fn_{i}()" for i in range(10))
242+
app_src = f"{import_line}\n\ndef use():\n{call_lines}\n"
243+
files = _make_project(tmp_path, {"lib.py": lib_src, "app.py": app_src})
244+
r = TreeSitterPythonResolver(_PY)
245+
app_path = (tmp_path / "app.py").resolve()
246+
lib_path = (tmp_path / "lib.py").resolve()
247+
root = files[app_path].tree.root_node
248+
for i in range(10):
249+
call = _find_call_node(root, f"fn_{i}(")
250+
out = r.resolve(
251+
files, app_path, tmp_path.resolve(), call.child_by_field_name("function")
252+
)
253+
assert len(out) == 1, f"fn_{i} did not resolve uniquely"
254+
file, def_node = out[0]
255+
assert file.path == lib_path
256+
resolved_name = def_node.child_by_field_name("name").text.decode("utf-8")
257+
assert resolved_name == f"fn_{i}", (
258+
f"call fn_{i} resolved to wrong def {resolved_name}"
259+
)
260+
261+
262+
def test_resolver_many_classes_name_def_alignment(tmp_path: Path):
263+
"""Same alignment regression for top-level classes."""
264+
lib_src = "".join(f"class Cls{i}:\n pass\n\n" for i in range(8))
265+
import_line = "from lib import " + ", ".join(f"Cls{i}" for i in range(8))
266+
body = "\n".join(f" Cls{i}()" for i in range(8))
267+
app_src = f"{import_line}\n\ndef use():\n{body}\n"
268+
files = _make_project(tmp_path, {"lib.py": lib_src, "app.py": app_src})
269+
r = TreeSitterPythonResolver(_PY)
270+
app_path = (tmp_path / "app.py").resolve()
271+
root = files[app_path].tree.root_node
272+
for i in range(8):
273+
call = _find_call_node(root, f"Cls{i}(")
274+
out = r.resolve(
275+
files, app_path, tmp_path.resolve(), call.child_by_field_name("function")
276+
)
277+
assert len(out) == 1
278+
resolved_name = out[0][1].child_by_field_name("name").text.decode("utf-8")
279+
assert resolved_name == f"Cls{i}"
280+
281+
230282
# ---------------------------------------------------------------------------
231283
# PythonAnalyzer integration via env var
232284
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)