Skip to content

Commit 126c672

Browse files
authored
Merge pull request #698 from FalkorDB/dvirdukhan/analyzer-import-override-edges
feat(analyzers): syntactic IMPORTS + derived OVERRIDES + tree-sitter resolver name→def fix
2 parents c4d3454 + f54f818 commit 126c672

7 files changed

Lines changed: 327 additions & 26 deletions

File tree

api/analyzers/analyzer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,39 @@ def needs_lsp(self) -> bool:
8585
"""
8686
return True
8787

88+
def build_import_index(self, files: dict[Path, File], root: Path) -> object:
    """
    Produce the lookup structure that ``resolve_imports`` consumes when
    mapping import statements onto in-repo files.

    This default implementation performs no import resolution and always
    returns ``None``; a language analyzer that supports import resolution
    overrides it.

    Args:
        files (dict[Path, File]): All parsed files keyed by absolute path.
        root (Path): The analyzed repository root.

    Returns:
        object: Opaque index, or ``None`` when unsupported.
    """

    # No index for this language by default.
    return None
103+
104+
def resolve_imports(self, file: File, root: Path, index: object) -> list[File]:
    """
    Map the import statements of ``file`` onto the in-repo files it
    depends on. The orchestrator links every returned File to ``file``
    with an ``IMPORTS`` edge.

    This default implementation resolves nothing and returns an empty
    list; it is purely syntactic by contract (no LSP).

    Args:
        file (File): The importing file (already parsed; ``file.tree`` set).
        root (Path): The analyzed repository root.
        index (object): The structure returned by ``build_import_index``.

    Returns:
        list[File]: In-repo files imported by ``file`` (deduped, self excluded).
    """

    # A fresh list on every call so callers may mutate the result freely.
    return []
120+
88121
@abstractmethod
89122
def add_dependencies(self, path: Path, files: list[Path]):
90123
"""

api/analyzers/python/analyzer.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,116 @@ def add_symbols(self, entity: Entity) -> None:
136136
def is_dependency(self, file_path: str) -> bool:
    """Report whether ``file_path`` contains the substring ``"venv"``.

    NOTE(review): this is a plain substring test on the whole path, so it
    also matches file names that merely contain "venv" — confirm that is
    intended.
    """
    contains_venv = "venv" in file_path
    return contains_venv
138138

139+
def _module_parts(self, file_path: Path, root: Path) -> Optional[list[str]]:
    """Split ``file_path`` into dotted-module components relative to ``root``.

    The file suffix is dropped and a trailing ``__init__`` component is
    removed, so ``pkg/__init__.py`` yields ``['pkg']``. Returns ``None``
    when ``file_path`` is not located under ``root``.
    """
    try:
        relative = file_path.relative_to(root)
    except ValueError:
        # relative_to() raises when file_path is outside root.
        return None

    components = list(relative.with_suffix('').parts)
    is_package_init = bool(components) and components[-1] == '__init__'
    return components[:-1] if is_package_init else components
149+
150+
def build_import_index(self, files: dict[Path, File], root: Path) -> object:
    """Build dotted-module-name lookup tables for the in-repo Python files.

    The result is ``{'exact': ..., 'suffix': ...}``:

    * ``exact`` maps the full dotted path from ``root`` to its File.
    * ``suffix`` maps every trailing sub-path of that dotted path to a
      File; on collisions the first file encountered keeps the key. This
      lets an import such as ``matplotlib.axes`` find a file stored as
      ``lib/matplotlib/axes.py`` in ``src/``/``lib/`` style layouts.

    ``files`` holds every analyzed source file regardless of language, so
    anything that is not ``.py`` is skipped: ``import pkg.mod`` must never
    resolve to ``pkg/mod.java``. Paths flagged by ``is_dependency`` and
    paths with no module components are skipped as well.
    """
    by_full_name: dict[str, File] = {}
    by_tail: dict[str, File] = {}

    for source_path, source_file in files.items():
        indexable = source_path.suffix == '.py' and not self.is_dependency(str(source_path))
        if not indexable:
            continue

        components = self._module_parts(source_path, root)
        if not components:
            continue

        by_full_name.setdefault('.'.join(components), source_file)

        # Register the full path first, then progressively shorter tails.
        tail = list(components)
        while tail:
            by_tail.setdefault('.'.join(tail), source_file)
            tail = tail[1:]

    return {'exact': by_full_name, 'suffix': by_tail}
177+
178+
def _resolve_dotted(self, dotted: str, index: dict) -> Optional[File]:
    """Look up ``dotted`` in the import index, preferring exact matches.

    The name is tried against ``index['exact']`` and then
    ``index['suffix']``. If that finds nothing and the name has a dot, the
    last component is dropped and the lookup is repeated once, because the
    final part may be a symbol defined inside a module rather than a
    module itself. Returns ``None`` for an empty name or when nothing
    matches.
    """
    if not dotted:
        return None

    def lookup(name):
        # The suffix table is only consulted when the exact table has no usable hit.
        return index['exact'].get(name) or index['suffix'].get(name)

    candidates = [dotted]
    if '.' in dotted:
        candidates.append(dotted.rsplit('.', 1)[0])

    for name in candidates:
        hit = lookup(name)
        if hit is not None:
            return hit
    return None
187+
188+
def _import_requests(self, file: File) -> list[tuple[str, int]]:
    """Collect ``(dotted, level)`` lookup requests from the file's imports.

    ``import`` statements contribute one request per dotted name with
    ``level`` 0. ``from`` statements contribute a request for the module
    itself and one for each imported name joined onto it; for a relative
    import ``level`` is the length of the ``import_prefix`` text (1 when no
    prefix node is found) and the base is the dotted part, if any.
    """
    def unaliased(node):
        # For ``... as alias`` the real name is held in the ``name`` field.
        if node.type == 'aliased_import':
            return node.child_by_field_name('name')
        return node

    def decoded(node):
        return node.text.decode('utf-8')

    found = self._captures(
        "(import_statement) @i (import_from_statement) @f",
        file.tree.root_node,
    )
    collected: list[tuple[str, int]] = []

    # Plain ``import a.b`` / ``import a.b as c`` statements.
    for stmt in found.get('i', []):
        for entry in stmt.named_children:
            name_node = unaliased(entry)
            if name_node is None or name_node.type != 'dotted_name':
                continue
            collected.append((decoded(name_node), 0))

    # ``from X import ...`` statements, absolute or relative.
    for stmt in found.get('f', []):
        level, base = 0, ''
        module = stmt.child_by_field_name('module_name')
        if module is not None and module.type == 'relative_import':
            prefix = next((c for c in module.children if c.type == 'import_prefix'), None)
            level = 1 if prefix is None else len(decoded(prefix))
            dotted_part = next((c for c in module.named_children if c.type == 'dotted_name'), None)
            base = '' if dotted_part is None else decoded(dotted_part)
        elif module is not None:
            base = decoded(module)

        collected.append((base, level))

        for imported in stmt.children_by_field_name('name'):
            leaf = unaliased(imported)
            if leaf is None:
                continue
            leaf_name = decoded(leaf)
            collected.append((f"{base}.{leaf_name}" if base else leaf_name, level))

    return collected
223+
224+
def resolve_imports(self, file: File, root: Path, index: object) -> list[File]:
    """Resolve the import statements of ``file`` to in-repo files.

    Each ``(dotted, level)`` request from ``_import_requests`` is turned
    into a dotted name and looked up with ``_resolve_dotted``. Relative
    requests (``level > 0``) are anchored at the importing file's package,
    climbing ``level - 1`` packages up from it.

    Args:
        file (File): The importing file; ``file.path`` locates it under
            ``root``.
        root (Path): The analyzed repository root.
        index (object): The structure returned by ``build_import_index``.

    Returns:
        list[File]: Resolved files in first-seen order, deduplicated by
        path, excluding ``file`` itself and anything ``is_dependency``
        flags. Empty when ``index`` is falsy or ``file`` is not under
        ``root``.
    """
    if not index:
        return []
    module_parts = self._module_parts(file.path, root)
    if module_parts is None:
        return []

    # ``_module_parts`` already strips a trailing ``__init__``, so for a
    # package initialiser its result IS the package. Only a regular module
    # needs its own name dropped to obtain the containing package.
    if file.path.stem == '__init__':
        package_parts = list(module_parts)
    else:
        package_parts = module_parts[:-1]

    seen: set[Path] = set()
    targets: list[File] = []
    for dotted, level in self._import_requests(file):
        if level:
            climb = level - 1
            if climb > len(package_parts):
                # Relative import reaches above the analyzed root. A negative
                # slice bound would silently wrap to a wrong package, so skip.
                continue
            base = package_parts[: len(package_parts) - climb]
            full = '.'.join([*base, dotted]) if dotted else '.'.join(base)
        else:
            full = dotted

        resolved = self._resolve_dotted(full, index)
        if resolved is None or resolved.path == file.path or resolved.path in seen:
            continue
        if self.is_dependency(str(resolved.path)):
            continue
        seen.add(resolved.path)
        targets.append(resolved)
    return targets
248+
139249
def _extract_type_target(self, node: Node) -> Optional[Node]:
140250
if node.type == 'attribute':
141251
return node.child_by_field_name('attribute')

api/analyzers/python/ts_resolver.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,21 @@ def _captures(query, root: Node) -> dict[str, list[Node]]:
168168
return cursor.captures(root)
169169

170170

171+
def _matches(query, root: Node) -> list[tuple[int, dict[str, list[Node]]]]:
    """Run ``query`` over ``root`` and return its matches one by one.

    :func:`_captures` pools *every* node under its capture name, producing
    separate lists whose orderings are **not** guaranteed to line up
    between capture names. Zipping those lists can pair a ``@name`` with a
    ``@def`` from a different match and scramble the module symbol table.
    Here each match carries its own capture dict, so a ``@name`` always
    travels with the ``@def`` captured in the *same* match.
    """
    return QueryCursor(query).matches(root)
184+
185+
171186
# ---------------------------------------------------------------------------
172187
# Public resolver
173188
# ---------------------------------------------------------------------------
@@ -242,46 +257,50 @@ def _index_file(
242257
by_name: dict[str, list[_Definition]],
243258
) -> None:
244259
# Top-level functions
245-
caps = _captures(self._queries.top_level_func, root)
246-
names = caps.get("name", [])
247-
defs = caps.get("def", [])
248-
for name_node, def_node in zip(names, defs):
249-
name = name_node.text.decode("utf-8")
250-
d = _Definition(mi.file_path, _strip_decorator(def_node), "func")
260+
for _, caps in _matches(self._queries.top_level_func, root):
261+
name_nodes = caps.get("name", [])
262+
def_nodes = caps.get("def", [])
263+
if not name_nodes or not def_nodes:
264+
continue
265+
name = name_nodes[0].text.decode("utf-8")
266+
d = _Definition(mi.file_path, _strip_decorator(def_nodes[0]), "func")
251267
mi.top_level[name] = d
252268
by_name[name].append(d)
253269

254270
# Top-level classes
255-
caps = _captures(self._queries.top_level_class, root)
256-
names = caps.get("name", [])
257-
defs = caps.get("def", [])
258-
for name_node, def_node in zip(names, defs):
259-
name = name_node.text.decode("utf-8")
260-
d = _Definition(mi.file_path, _strip_decorator(def_node), "class")
271+
for _, caps in _matches(self._queries.top_level_class, root):
272+
name_nodes = caps.get("name", [])
273+
def_nodes = caps.get("def", [])
274+
if not name_nodes or not def_nodes:
275+
continue
276+
name = name_nodes[0].text.decode("utf-8")
277+
d = _Definition(mi.file_path, _strip_decorator(def_nodes[0]), "class")
261278
mi.top_level[name] = d
262279
by_name[name].append(d)
263280

264281
# Top-level assignments (for class aliases like ``Foo = OtherFoo``)
265-
caps = _captures(self._queries.top_level_assign, root)
266-
names = caps.get("name", [])
267-
defs = caps.get("def", [])
268-
for name_node, def_node in zip(names, defs):
269-
name = name_node.text.decode("utf-8")
282+
for _, caps in _matches(self._queries.top_level_assign, root):
283+
name_nodes = caps.get("name", [])
284+
def_nodes = caps.get("def", [])
285+
if not name_nodes or not def_nodes:
286+
continue
287+
name = name_nodes[0].text.decode("utf-8")
270288
if name in mi.top_level:
271289
continue
272-
d = _Definition(mi.file_path, def_node, "var")
290+
d = _Definition(mi.file_path, def_nodes[0], "var")
273291
mi.top_level[name] = d
274292
by_name[name].append(d)
275293

276294
# Class methods
277-
caps = _captures(self._queries.class_methods, root)
278-
class_names = caps.get("class_name", [])
279-
method_names = caps.get("method_name", [])
280-
method_defs = caps.get("method_def", [])
281-
for cls_node, mname_node, mdef_node in zip(class_names, method_names, method_defs):
282-
class_name = cls_node.text.decode("utf-8")
283-
method_name = mname_node.text.decode("utf-8")
284-
d = _Definition(mi.file_path, _strip_decorator(mdef_node), "method")
295+
for _, caps in _matches(self._queries.class_methods, root):
296+
class_nodes = caps.get("class_name", [])
297+
mname_nodes = caps.get("method_name", [])
298+
mdef_nodes = caps.get("method_def", [])
299+
if not class_nodes or not mname_nodes or not mdef_nodes:
300+
continue
301+
class_name = class_nodes[0].text.decode("utf-8")
302+
method_name = mname_nodes[0].text.decode("utf-8")
303+
d = _Definition(mi.file_path, _strip_decorator(mdef_nodes[0]), "method")
285304
mi.class_methods.setdefault(class_name, {})[method_name] = d
286305
by_name[method_name].append(d)
287306

api/analyzers/source_analyzer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,19 +315,49 @@ def _resolve_file(file_path: Path) -> Path:
315315
elif key == "parameters":
316316
graph.connect_entities("PARAMETERS", entity.id, resolved.id)
317317

318+
def link_imports(self, graph: Graph, root: Path) -> None:
    """Create File -> File ``IMPORTS`` edges using per-language resolution.

    Resolution is purely syntactic for Python (no LSP), so this is run
    after ``first_pass``, by which point every file has a graph id. The
    import index is built lazily, once per file suffix. Files without a
    registered analyzer, languages whose index is falsy, and endpoints
    that have no ``id`` are skipped silently.
    """
    index_by_suffix: dict[str, object] = {}
    for source_path, source_file in self.files.items():
        ext = source_path.suffix
        lang_analyzer = analyzers.get(ext)
        if lang_analyzer is None:
            continue

        # Build the index the first time this suffix is encountered.
        if ext not in index_by_suffix:
            index_by_suffix[ext] = lang_analyzer.build_import_index(self.files, root)
        import_index = index_by_suffix[ext]
        if not import_index:
            continue

        for imported in lang_analyzer.resolve_imports(source_file, root, import_index):
            both_in_graph = (
                getattr(source_file, "id", None) is not None
                and getattr(imported, "id", None) is not None
            )
            if both_in_graph:
                graph.connect_entities("IMPORTS", source_file.id, imported.id)
339+
318340
def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None:
    """Run the analysis pipeline over an explicit list of ``files``.

    Stages run in a fixed order: first pass (with an empty ignore list),
    import linking, second pass, then override derivation on ``graph``.
    """
    no_ignores: list[str] = []
    self.first_pass(path, files, no_ignores, graph)
    # IMPORTS edges are linked between the two passes.
    self.link_imports(graph, path)
    self.second_pass(graph, files, path)
    # OVERRIDES edges are derived last, after the second pass.
    graph.derive_overrides()
321345

322346
def analyze_sources(self, path: Path, ignore: list[str], graph: Graph) -> None:
    """Discover source files under ``path`` and run the analysis pipeline.

    ``path`` is resolved first, then searched recursively for ``*.java``,
    ``*.py``, ``*.cs``, ``*.js`` (skipping anything with a ``node_modules``
    path component), ``*.kt`` and ``*.kts`` files, in that order. The
    stages then run as: first pass, import linking, second pass, override
    derivation.
    """
    path = path.resolve()

    files: list[Path] = []
    for pattern in ("*.java", "*.py", "*.cs"):
        files.extend(path.rglob(pattern))
    # JavaScript is the only language filtered against vendored packages.
    files.extend(f for f in path.rglob("*.js") if "node_modules" not in f.parts)
    for pattern in ("*.kt", "*.kts"):
        files.extend(path.rglob(pattern))

    # First pass analysis of the source code
    self.first_pass(path, files, ignore, graph)

    # Link import edges (syntactic, language-specific, no LSP)
    self.link_imports(graph, path)

    # Second pass analysis of the source code
    self.second_pass(graph, files, path)

    # Derive override edges from the resolved class hierarchy
    graph.derive_overrides()
360+
331361
def analyze_local_folder(self, path: str, g: Graph, ignore: Optional[list[str]] = []) -> None:
332362
"""
333363
Analyze path.

api/graph.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,40 @@ def connect_entities(self, relation: str, src_id: int, dest_id: int, properties:
612612
params = {'src_id': src_id, 'dest_id': dest_id, "properties": properties}
613613
self._query(q, params)
614614

615+
def derive_overrides(self, max_depth: int = 3) -> int:
    """
    Add ``OVERRIDES`` edges inferred from the class hierarchy already in
    the graph.

    A method on a subclass is linked to every same-named method defined on
    an ancestor class reachable through at most ``max_depth`` ``EXTENDS``
    hops. Only existing ``EXTENDS`` and ``DEFINES`` edges are consulted, so
    the derivation is language-agnostic. A newly created edge stores the
    inheritance distance in its ``depth`` property for downstream filtering.

    The derivation is best-effort: if the derivation query raises, a
    warning is logged and 0 is returned.

    Args:
        max_depth (int): Maximum inheritance distance to bridge.

    Returns:
        int: Number of OVERRIDES edges after derivation.
    """

    hop_limit = int(max_depth)
    derive_query = f"""MATCH (sub:Class)-[x:EXTENDS*1..{hop_limit}]->(sup:Class)
               WHERE ID(sub) <> ID(sup)
               WITH DISTINCT sub, sup, length(x) AS depth
               MATCH (sub)-[:DEFINES]->(m:Function)
               MATCH (sup)-[:DEFINES]->(m2:Function)
               WHERE m.name = m2.name AND ID(m) <> ID(m2)
               MERGE (m)-[e:OVERRIDES]->(m2)
               ON CREATE SET e.depth = depth"""

    try:
        self._query(derive_query)
    except Exception as exc:  # noqa: BLE001 — derivation is best-effort
        logging.warning("derive_overrides failed: %s", exc)
        return 0
    else:
        # The count query runs outside the best-effort guard.
        rows = self._query("MATCH ()-[e:OVERRIDES]->() RETURN count(e)").result_set
        if not rows:
            return 0
        return int(rows[0][0])
648+
615649
def function_calls_function(self, caller_id: int, callee_id: int, pos: int) -> None:
616650
"""
617651
Establish a 'CALLS' relationship between two function nodes.

tests/analyzers/test_tree_sitter_base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,26 @@ def test_tree_sitter_multilanguage_fixture_graph_counts():
7575
{"Class": 3, "Function": 4, "Method": 2}
7676
)
7777
assert Counter(edge[0] for edge in graph.edges) == Counter({"DEFINES": 9})
78+
79+
80+
def test_build_import_index_skips_non_python_files():
    """``import pkg.mod`` in Python must never resolve to ``pkg/mod.java``.

    ``build_import_index`` is handed every analyzed file, whatever its
    language, so it has to restrict itself to ``.py`` files. Otherwise a
    non-Python file sharing the same dotted path would produce spurious
    ``IMPORTS`` edges.
    """
    analyzer = PythonAnalyzer()
    repo_root = Path("/repo")
    python_source = File(repo_root / "pkg" / "mod.py", None)
    java_source = File(repo_root / "pkg" / "mod.java", None)
    all_files = {f.path: f for f in (python_source, java_source)}

    index = analyzer.build_import_index(all_files, repo_root)

    assert index["exact"]["pkg.mod"] is python_source
    for dotted in ("pkg.mod", "mod"):
        assert index["suffix"][dotted] is python_source
    # The .java file must not have been indexed under any dotted name.
    for table in ("exact", "suffix"):
        assert java_source not in index[table].values()

0 commit comments

Comments
 (0)