Skip to content

Commit 0d5d41c

Browse files
committed
refactor: reduce cognitive complexity across diffctx modules
1 parent 45cbe43 commit 0d5d41c

8 files changed

Lines changed: 344 additions & 216 deletions

File tree

src/treemapper/diffctx/__init__.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,33 @@ def _validate_inputs(root_dir: Path, alpha: float, tau: float, budget_tokens: in
316316
raise ValueError(f"budget_tokens must be > 0, got {budget_tokens}")
317317

318318

319+
def _find_dangling_semantic_names(
    selected: list[Fragment],
    graph: Graph,
    frag_by_id: dict[FragmentId, Fragment],
    selected_ids: set[FragmentId],
) -> set[str]:
    """Collect lower-cased symbol names that the selection refers to but lacks.

    A name is "dangling" when a selected fragment has a neighbour in ``graph``
    that is not in ``selected_ids``, the edge between them is categorised as
    ``"semantic"``, and the neighbour resolves (via ``frag_by_id``) to a
    fragment with a non-empty ``symbol_name``.
    """
    names: set[str] = set()
    for fragment in selected:
        source_id = fragment.id
        for neighbor_id in graph.neighbors(source_id):
            # Already part of the selection, or not a semantic edge: nothing dangles.
            if neighbor_id in selected_ids or graph.edge_categories.get((source_id, neighbor_id), "") != "semantic":
                continue
            neighbor = frag_by_id.get(neighbor_id)
            if neighbor and neighbor.symbol_name:
                names.add(neighbor.symbol_name.lower())
    return names
336+
337+
338+
def _pick_best_fragment(candidates: list[Fragment], selected_ids: set[FragmentId]) -> Fragment | None:
    """Choose one fragment to add for a symbol, preferring signature fragments.

    Returns ``None`` when any candidate is already selected or when there are
    no candidates. Otherwise returns the first candidate whose ``kind``
    contains ``"_signature"``, falling back to the first candidate without it.
    """
    for candidate in candidates:
        if candidate.id in selected_ids:
            return None

    signature = next((c for c in candidates if "_signature" in c.kind), None)
    if signature is not None:
        return signature
    return next((c for c in candidates if "_signature" not in c.kind), None)
344+
345+
319346
def _coherence_post_pass(
320347
result: SelectionResult,
321348
all_fragments: list[Fragment],
@@ -331,32 +358,15 @@ def _coherence_post_pass(
331358
name_to_frags.setdefault(f.symbol_name.lower(), []).append(f)
332359

333360
frag_by_id: dict[FragmentId, Fragment] = {f.id: f for f in all_fragments}
334-
335-
dangling_names: set[str] = set()
336-
for frag in result.selected:
337-
for nbr_id in graph.neighbors(frag.id):
338-
if nbr_id in selected_ids:
339-
continue
340-
cat = graph.edge_categories.get((frag.id, nbr_id), "")
341-
if cat == "semantic":
342-
nbr_frag = frag_by_id.get(nbr_id)
343-
if nbr_frag and nbr_frag.symbol_name:
344-
dangling_names.add(nbr_frag.symbol_name.lower())
361+
dangling_names = _find_dangling_semantic_names(result.selected, graph, frag_by_id, selected_ids)
345362

346363
added: list[Fragment] = []
347364
for name in dangling_names:
348-
candidates = name_to_frags.get(name, [])
349-
for c in candidates:
350-
if c.id in selected_ids:
351-
break
352-
else:
353-
sig_candidates = [f for f in candidates if "_signature" in f.kind]
354-
full_candidates = [f for f in candidates if "_signature" not in f.kind]
355-
pick = next(iter(sig_candidates or full_candidates), None)
356-
if pick and pick.token_count <= remaining and pick.id not in selected_ids:
357-
added.append(pick)
358-
selected_ids.add(pick.id)
359-
remaining -= pick.token_count
365+
pick = _pick_best_fragment(name_to_frags.get(name, []), selected_ids)
366+
if pick and pick.token_count <= remaining and pick.id not in selected_ids:
367+
added.append(pick)
368+
selected_ids.add(pick.id)
369+
remaining -= pick.token_count
360370

361371
if not added:
362372
return result

src/treemapper/diffctx/edges/base.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,31 @@
2222
_INDEX_FILE_STEMS = frozenset({"__init__", "index", "mod"})
2323

2424

25+
def _strip_source_prefix(parts: list[str]) -> list[str]:
26+
for i, part in enumerate(parts):
27+
if part in ("src", "lib", "packages"):
28+
return parts[i + 1 :]
29+
return parts
30+
31+
32+
def _strip_file_extension(stem: str) -> str:
    """Return ``stem`` with the longest matching entry of ``_STRIP_EXTENSIONS`` removed.

    If no entry of ``_STRIP_EXTENSIONS`` is a suffix of ``stem``, the input is
    returned unchanged.
    """
    longest = max(
        (ext for ext in _STRIP_EXTENSIONS if stem.endswith(ext)),
        key=len,
        default=None,
    )
    if longest is None:
        return stem
    return stem[: -len(longest)]
37+
38+
2539
def path_to_module(path: Path, repo_root: Path | None = None) -> str:
2640
if repo_root and path.is_absolute():
2741
try:
2842
path = path.relative_to(repo_root)
2943
except ValueError:
3044
pass
3145

32-
parts = list(path.parts)
33-
34-
for i, part in enumerate(parts):
35-
if part in ("src", "lib", "packages"):
36-
parts = parts[i + 1 :]
37-
break
46+
parts = _strip_source_prefix(list(path.parts))
3847

3948
if parts:
40-
stem = parts[-1]
41-
for ext in sorted(_STRIP_EXTENSIONS, key=len, reverse=True):
42-
if stem.endswith(ext):
43-
parts[-1] = stem[: -len(ext)]
44-
break
49+
parts[-1] = _strip_file_extension(parts[-1])
4550
if parts and parts[-1] in _INDEX_FILE_STEMS:
4651
parts = parts[:-1]
4752

src/treemapper/diffctx/edges/semantic/jvm.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
re.MULTILINE,
2727
)
2828
_KOTLIN_FUN_RE = re.compile(
29-
r"^\s*(?:(?:public|private|internal|protected|abstract|open|final|override|inline|external|tailrec|operator|infix|suspend|expect|actual)\s+)*fun\s+(?:<[^>]+>\s+)?([a-z]\w*)",
29+
r"^\s*(?:\w+\s+)*fun\s+(?:<[^>]+>\s+)?([a-z]\w*)",
3030
re.MULTILINE,
3131
)
3232

@@ -38,9 +38,7 @@
3838
_SCALA_DEF_RE = re.compile(r"^\s*(?:private |protected )?def\s+([a-z]\w*)", re.MULTILINE)
3939

4040
_TYPE_REF_RE = re.compile(r"(?<![a-z_])([A-Z]\w*)\b")
41-
_KOTLIN_INHERIT_RE = re.compile(
42-
r"(?:class|interface|object)\s+\w+(?:<[^>]*>)?(?:\([^)]*\))?\s*:\s*([A-Z]\w*(?:\s*,\s*[A-Z]\w*)*)"
43-
)
41+
_KOTLIN_INHERIT_RE = re.compile(r"(?:class|interface|object)\s+\w+[^:\n]*:\s*([A-Z]\w*(?:\s*,\s*[A-Z]\w*)*)")
4442
_SCALA_WITH_RE = re.compile(r"\bwith\s+([A-Z]\w*)")
4543

4644
_ANNOTATION_RE = re.compile(r"@([A-Z]\w*)")
@@ -88,27 +86,25 @@ def _extract_classes(content: str, path: Path) -> set[str]:
8886
return classes
8987

9088

91-
def _extract_inheritance(content: str, path: Path) -> set[str]:
89+
def _split_class_list(regex: re.Pattern[str], content: str) -> set[str]:
9290
refs: set[str] = set()
91+
for m in regex.finditer(content):
92+
for cls in m.group(1).split(","):
93+
stripped = cls.strip()
94+
if stripped:
95+
refs.add(stripped)
96+
return refs
97+
98+
99+
def _extract_inheritance(content: str, path: Path) -> set[str]:
93100
if _is_kotlin(path):
94-
for m in _KOTLIN_INHERIT_RE.finditer(content):
95-
for cls in m.group(1).split(","):
96-
cls = cls.strip()
97-
if cls:
98-
refs.add(cls)
99-
elif _is_scala(path):
100-
for m in _JAVA_EXTENDS_RE.finditer(content):
101-
for cls in m.group(1).split(","):
102-
refs.add(cls.strip())
103-
for m in _SCALA_WITH_RE.finditer(content):
104-
refs.add(m.group(1))
105-
else:
106-
for m in _JAVA_EXTENDS_RE.finditer(content):
107-
for cls in m.group(1).split(","):
108-
refs.add(cls.strip())
109-
for m in _JAVA_IMPLEMENTS_RE.finditer(content):
110-
for cls in m.group(1).split(","):
111-
refs.add(cls.strip())
101+
return _split_class_list(_KOTLIN_INHERIT_RE, content)
102+
if _is_scala(path):
103+
refs = _split_class_list(_JAVA_EXTENDS_RE, content)
104+
refs.update(m.group(1) for m in _SCALA_WITH_RE.finditer(content))
105+
return refs
106+
refs = _split_class_list(_JAVA_EXTENDS_RE, content)
107+
refs.update(_split_class_list(_JAVA_IMPLEMENTS_RE, content))
112108
return refs
113109

114110

src/treemapper/diffctx/edges/structural/test.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -100,32 +100,34 @@ def discover_related_files(
100100
repo_root: Path | None = None,
101101
) -> list[Path]:
102102
changed_set = set(changed_files)
103-
discovered: list[Path] = []
104-
105103
candidate_by_stem: dict[str, list[Path]] = defaultdict(list)
106104
for c in all_candidate_files:
107105
if c not in changed_set:
108106
candidate_by_stem[c.stem.lower()].append(c)
109107

108+
discovered: list[Path] = []
110109
for changed in changed_files:
111-
stem = changed.stem.lower()
112110
suffix = changed.suffix.lower()
113-
114111
if _is_test_file(changed):
115-
target = _extract_target_name_from_test(changed.stem)
116-
if target:
117-
for cand in candidate_by_stem.get(target, []):
118-
if cand.suffix.lower() == suffix:
119-
discovered.append(cand)
112+
discovered.extend(self._find_source_for_test(changed, suffix, candidate_by_stem))
120113
else:
121-
test_stems = [f"test_{stem}", f"{stem}_test"]
122-
for ts in test_stems:
123-
for cand in candidate_by_stem.get(ts, []):
124-
if cand.suffix.lower() == suffix and _is_test_file(cand):
125-
discovered.append(cand)
126-
114+
discovered.extend(self._find_tests_for_source(changed.stem.lower(), suffix, candidate_by_stem))
127115
return discovered
128116

117+
@staticmethod
def _find_source_for_test(changed: Path, suffix: str, candidate_by_stem: dict[str, list[Path]]) -> list[Path]:
    """List candidate files that a changed test file appears to exercise.

    The target stem is derived from ``changed.stem`` by
    ``_extract_target_name_from_test``; candidates registered under that stem
    are kept when their lower-cased suffix equals ``suffix``. An empty list
    is returned when no target stem can be derived.
    """
    target_stem = _extract_target_name_from_test(changed.stem)
    if target_stem:
        same_suffix = (cand for cand in candidate_by_stem.get(target_stem, []) if cand.suffix.lower() == suffix)
        return list(same_suffix)
    return []
123+
124+
@staticmethod
def _find_tests_for_source(stem: str, suffix: str, candidate_by_stem: dict[str, list[Path]]) -> list[Path]:
    """List candidate test files for a changed source file.

    Looks up candidates under ``test_<stem>`` and then ``<stem>_test`` and
    keeps those whose lower-cased suffix equals ``suffix`` and which
    ``_is_test_file`` accepts. Results keep that lookup order.
    """
    return [
        cand
        for test_stem in (f"test_{stem}", f"{stem}_test")
        for cand in candidate_by_stem.get(test_stem, [])
        if cand.suffix.lower() == suffix and _is_test_file(cand)
    ]
130+
129131
def build(self, fragments: list[Fragment], repo_root: Path | None = None) -> EdgeDict:
130132
edges: EdgeDict = {}
131133

src/treemapper/diffctx/parsers/markdown.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,28 +63,42 @@ def _find_all_headings(self, lines: list[str]) -> list[tuple[int, int]]:
6363
headings: list[tuple[int, int]] = []
6464
i = 0
6565
while i < len(lines):
66-
stripped = lines[i].lstrip()
67-
if stripped.startswith("#"):
68-
level = len(stripped) - len(stripped.lstrip("#"))
69-
if level <= 6 and (len(stripped) == level or stripped[level] == " "):
70-
headings.append((i + 1, level))
66+
result = self._try_atx_heading(lines[i])
67+
if result is not None:
68+
headings.append((i + 1, result))
7169
i += 1
7270
continue
7371

74-
if stripped and i + 1 < len(lines):
75-
next_stripped = lines[i + 1].strip()
76-
if next_stripped and all(c == "=" for c in next_stripped):
77-
headings.append((i + 1, 1))
78-
i += 2
79-
continue
80-
if len(next_stripped) >= 2 and all(c == "-" for c in next_stripped):
81-
headings.append((i + 1, 2))
82-
i += 2
83-
continue
72+
setext = self._try_setext_heading(lines, i)
73+
if setext is not None:
74+
headings.append((i + 1, setext))
75+
i += 2
76+
continue
8477

8578
i += 1
8679
return headings
8780

81+
@staticmethod
82+
def _try_atx_heading(line: str) -> int | None:
83+
stripped = line.lstrip()
84+
if not stripped.startswith("#"):
85+
return None
86+
level = len(stripped) - len(stripped.lstrip("#"))
87+
if level <= 6 and (len(stripped) == level or stripped[level] == " "):
88+
return level
89+
return None
90+
91+
@staticmethod
92+
def _try_setext_heading(lines: list[str], i: int) -> int | None:
93+
if not lines[i].lstrip() or i + 1 >= len(lines):
94+
return None
95+
next_stripped = lines[i + 1].strip()
96+
if next_stripped and all(c == "=" for c in next_stripped):
97+
return 1
98+
if len(next_stripped) >= 2 and all(c == "-" for c in next_stripped):
99+
return 2
100+
return None
101+
88102
def _find_section_end(self, lines: list[str], _start_line: int, level: int, remaining_headings: list[tuple[int, int]]) -> int:
89103
for next_line, next_level in remaining_headings:
90104
if next_level <= level:

0 commit comments

Comments
 (0)