Skip to content

Commit 9dbe9bc

Browse files
committed
refactor(diffctx): replace reinvented wheels per audit
- Python imports: replace regex parser with ast.parse + importlib.util.resolve_name - Tarjan SCC: replace 48-line hand-rolled algorithm with nx.strongly_connected_components - Identifier extraction: merge concepts_from_diff_text into extract_identifiers with extra_stopwords param
1 parent 9aa3e1e commit 9dbe9bc

4 files changed

Lines changed: 66 additions & 116 deletions

File tree

src/treemapper/diffctx/edges/semantic/python.py

Lines changed: 43 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# pylint: disable=duplicate-code
22
from __future__ import annotations
33

4-
import re
4+
import ast
55
from collections import defaultdict
66
from pathlib import Path
77

@@ -17,74 +17,62 @@
1717
_SYMBOL_REF_WEIGHT = _PY_WEIGHTS.symbol_ref
1818
_TYPE_REF_WEIGHT = _PY_WEIGHTS.type_ref
1919

20-
_PY_IMPORT_RE = re.compile(
21-
r"(?:from\s{1,20}(\.{0,3}[\w.]{0,200})\s{1,20}import|import\s{1,20}([\w.]{1,200}(?:\s*,\s*[\w.]{1,200})*))"
22-
)
23-
2420

2521
def _is_python_file(path: Path) -> bool:
2622
return path.suffix.lower() in _PYTHON_EXTS
2723

2824

29-
def _count_leading_dots(s: str) -> int:
30-
return len(s) - len(s.lstrip("."))
31-
32-
33-
def _resolve_relative_import(imported: str, source_path: Path, repo_root: Path | None = None) -> str | None:
34-
if not imported.startswith("."):
35-
return imported
36-
37-
dots = _count_leading_dots(imported)
38-
relative_module = imported[dots:]
39-
40-
if repo_root and source_path.is_absolute():
41-
try:
42-
source_path = source_path.relative_to(repo_root)
43-
except ValueError:
44-
pass
45-
46-
parent_parts = _strip_source_prefix(list(source_path.parent.parts))
47-
48-
if parent_parts and parent_parts[-1] == "__pycache__":
49-
parent_parts = parent_parts[:-1]
50-
51-
for _ in range(dots - 1):
52-
if parent_parts:
53-
parent_parts.pop()
54-
55-
if relative_module:
56-
parent_parts.extend(relative_module.split("."))
57-
58-
return ".".join(parent_parts) if parent_parts else None
59-
60-
6125
def _add_import_with_prefixes(imports: set[str], imported: str) -> None:
6226
imports.add(imported)
6327
parts = imported.split(".")
6428
for i in range(1, len(parts) + 1):
6529
imports.add(".".join(parts[:i]))
6630

6731

32+
def _resolve_relative(name: str, source_path: Path, repo_root: Path | None) -> str | None:
33+
try:
34+
import importlib.util
35+
36+
pkg_parts = _strip_source_prefix(list(source_path.parent.parts))
37+
if pkg_parts and pkg_parts[-1] == "__pycache__":
38+
pkg_parts = pkg_parts[:-1]
39+
if repo_root and source_path.is_absolute():
40+
try:
41+
source_path = source_path.relative_to(repo_root)
42+
pkg_parts = _strip_source_prefix(list(source_path.parent.parts))
43+
except ValueError:
44+
pass
45+
package = ".".join(pkg_parts) if pkg_parts else None
46+
if not package:
47+
return None
48+
return importlib.util.resolve_name(name, package)
49+
except (ImportError, ValueError):
50+
return None
51+
52+
6853
def _extract_imports_from_content(content: str, source_path: Path | None = None, repo_root: Path | None = None) -> set[str]:
6954
imports: set[str] = set()
70-
for match in _PY_IMPORT_RE.finditer(content):
71-
from_module = match.group(1)
72-
bare_imports = match.group(2)
73-
74-
if from_module:
75-
imported = from_module
76-
if imported.startswith(".") and source_path:
77-
resolved = _resolve_relative_import(imported, source_path, repo_root)
78-
if resolved:
79-
imported = resolved
80-
else:
81-
continue
82-
_add_import_with_prefixes(imports, imported)
83-
elif bare_imports:
84-
for name in bare_imports.split(","):
85-
name = name.strip()
86-
if name:
87-
_add_import_with_prefixes(imports, name)
55+
try:
56+
tree = ast.parse(content)
57+
except SyntaxError:
58+
return imports
59+
60+
for node in ast.walk(tree):
61+
if isinstance(node, ast.Import):
62+
for alias in node.names:
63+
if alias.name:
64+
_add_import_with_prefixes(imports, alias.name)
65+
elif isinstance(node, ast.ImportFrom):
66+
module = node.module or ""
67+
if node.level and node.level > 0:
68+
dots = "." * node.level
69+
relative = dots + module
70+
if source_path:
71+
resolved = _resolve_relative(relative, source_path, repo_root)
72+
if resolved:
73+
_add_import_with_prefixes(imports, resolved)
74+
elif module:
75+
_add_import_with_prefixes(imports, module)
8876
return imports
8977

9078

src/treemapper/diffctx/graph_analytics.py

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import subprocess
44
from collections import defaultdict
5-
from collections.abc import Iterator
65
from dataclasses import dataclass, field
76
from pathlib import Path
87

8+
import networkx as nx
9+
910
from .project_graph import ProjectGraph, _relative_path
1011
from .types import Fragment, FragmentId
1112

@@ -123,53 +124,12 @@ def to_mermaid(qg: QuotientGraph, top_n: int = 20) -> str:
123124

124125

125126
def _tarjan_scc(adjacency: dict[str, set[str]]) -> list[list[str]]:
126-
index_counter = [0]
127-
stack: list[str] = []
128-
on_stack: set[str] = set()
129-
index: dict[str, int] = {}
130-
lowlink: dict[str, int] = {}
131-
result: list[list[str]] = []
132-
133-
for start in adjacency:
134-
if start in index:
135-
continue
136-
137-
index[start] = lowlink[start] = index_counter[0]
138-
index_counter[0] += 1
139-
stack.append(start)
140-
on_stack.add(start)
141-
142-
work: list[tuple[str, Iterator[str]]] = [(start, iter(adjacency.get(start, set())))]
143-
144-
while work:
145-
v, it = work[-1]
146-
try:
147-
w = next(it)
148-
if w not in index:
149-
index[w] = lowlink[w] = index_counter[0]
150-
index_counter[0] += 1
151-
stack.append(w)
152-
on_stack.add(w)
153-
work.append((w, iter(adjacency.get(w, set()))))
154-
elif w in on_stack:
155-
lowlink[v] = min(lowlink[v], index[w])
156-
except StopIteration:
157-
work.pop()
158-
if work:
159-
parent = work[-1][0]
160-
lowlink[parent] = min(lowlink[parent], lowlink[v])
161-
if lowlink[v] == index[v]:
162-
component: list[str] = []
163-
while True:
164-
w = stack.pop()
165-
on_stack.discard(w)
166-
component.append(w)
167-
if w == v:
168-
break
169-
if len(component) > 1:
170-
result.append(component)
171-
172-
return result
127+
g = nx.DiGraph()
128+
for node, neighbors in adjacency.items():
129+
g.add_node(node)
130+
for nbr in neighbors:
131+
g.add_edge(node, nbr)
132+
return [list(c) for c in nx.strongly_connected_components(g) if len(c) > 1]
173133

174134

175135
def detect_cycles(

src/treemapper/diffctx/types.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,21 @@ def extract_identifiers(
101101
*,
102102
skip_stopwords: bool = False,
103103
use_nlp: bool = False,
104+
extra_stopwords: frozenset[str] | None = None,
105+
min_length: int | None = None,
104106
) -> frozenset[str]:
105107
if use_nlp and profile != "code":
106108
return _extract_tokens_nlp(text, profile=profile, use_nlp=True)
107109

108110
raw = _IDENT_RE.findall(text)
109-
min_len = TokenProfile.get_min_len(profile)
111+
min_len = min_length if min_length is not None else TokenProfile.get_min_len(profile)
112+
stopwords: frozenset[str] = frozenset()
110113
if skip_stopwords:
111114
stopwords = TokenProfile.get_stopwords(profile)
115+
if extra_stopwords:
116+
stopwords = stopwords | extra_stopwords
117+
if stopwords:
112118
return frozenset({ident.lower() for ident in raw if len(ident) >= min_len and ident.lower() not in stopwords})
113-
# Normalize to lowercase to match concepts (also lowercase)
114119
return frozenset({ident.lower() for ident in raw if len(ident) >= min_len})
115120

116121

src/treemapper/diffctx/utility.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@
1111
from .edges.structural.testing import _is_test_file
1212
from .stopwords import _DOCS_STOPWORDS, CODE_STOPWORDS
1313
from .tokenizer import extract_tokens
14-
from .types import Fragment, FragmentId
14+
from .types import Fragment, FragmentId, extract_identifiers
1515

1616
_EXPANSION_STOPWORDS = CODE_STOPWORDS | _DOCS_STOPWORDS
1717

1818
if TYPE_CHECKING:
1919
from .graph import Graph
2020

21-
_CONCEPT_RE = re.compile(r"[A-Za-z_]\w*")
2221
_CALL_RE = re.compile(r"(\w+)\s*\(")
2322
_TYPE_REF_RE = re.compile(r"(?::|->)\s*([A-Z]\w+)")
2423
_GENERIC_TYPE_RE = re.compile(r"[\[<,]\s*([A-Z]\w*)")
@@ -265,15 +264,13 @@ def concepts_from_diff_text(
265264
if use_nlp and profile != "code":
266265
return extract_tokens(text, profile=profile, use_nlp=True)
267266

268-
raw = _CONCEPT_RE.findall(text)
269-
result: set[str] = set()
270-
for ident in raw:
271-
if len(ident) < 3:
272-
continue
273-
low = ident.lower()
274-
if low not in _EXPANSION_STOPWORDS and low not in _LANGUAGE_BUILTINS:
275-
result.add(low)
276-
return frozenset(result)
267+
return extract_identifiers(
268+
text,
269+
profile=profile,
270+
skip_stopwords=True,
271+
extra_stopwords=_EXPANSION_STOPWORDS | _LANGUAGE_BUILTINS,
272+
min_length=3,
273+
)
277274

278275

279276
_CLOSURE_EDGE_CATEGORIES = frozenset({"structural", "semantic"})

0 commit comments

Comments
 (0)