Skip to content

Commit c5c1fde

Browse files
committed
fix: reduce cognitive complexity across diffctx pipeline (SonarCloud S3776)
1 parent de2ba0b commit c5c1fde

17 files changed

Lines changed: 588 additions & 451 deletions

File tree

benchmarks/summarize_results.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,25 @@
77
from common import load_results
88

99

10+
def _print_txt_section(txt: Path, prefix: str, title: str, markers: tuple[str, ...]) -> None:
11+
mode = txt.stem.replace(prefix, "")
12+
print(f"### {title} ({mode})\n```")
13+
for line in txt.read_text().splitlines():
14+
if line.startswith(markers):
15+
print(line)
16+
print("```\n")
17+
18+
1019
def main() -> None:
1120
results_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("results")
1221

1322
print("## Benchmark Results\n")
1423

1524
for txt in sorted(results_dir.glob("cb_*.txt")):
16-
mode = txt.stem.replace("cb_", "")
17-
print(f"### ContextBench ({mode})\n```")
18-
for line in txt.read_text().splitlines():
19-
if line.startswith(("Avg ", "Total:")):
20-
print(line)
21-
print("```\n")
25+
_print_txt_section(txt, "cb_", "ContextBench", ("Avg ", "Total:"))
2226

2327
for txt in sorted(results_dir.glob("loo_*.txt")):
24-
mode = txt.stem.replace("loo_", "")
25-
print(f"### LOO ({mode})\n```")
26-
for line in txt.read_text().splitlines():
27-
if line.startswith(("Total LOO", "Found")):
28-
print(line)
29-
print("```\n")
28+
_print_txt_section(txt, "loo_", "LOO", ("Total LOO", "Found"))
3029

3130
for jf in sorted(results_dir.glob("loo_*.json")):
3231
mode = jf.stem.replace("loo_", "")

src/treemapper/cli.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,34 +63,42 @@ def _resolve_root_dir(directory: str) -> Path:
6363
_exit_error(f"Cannot access '{directory}': {e}")
6464

6565

66-
def _expand_paths(raw_paths: list[str]) -> tuple[list[Path], list[Path]]:
66+
def _resolve_glob_pattern(pattern: str) -> list[str]:
6767
import glob as globmod
6868

69+
matches = sorted(globmod.glob(pattern, recursive=True))
70+
if matches:
71+
return matches
72+
try:
73+
p = Path(pattern).resolve(strict=True)
74+
except FileNotFoundError:
75+
_exit_error(f"No matches for '{pattern}'")
76+
except OSError as e:
77+
_exit_error(f"Cannot access '{pattern}': {e}")
78+
return [str(p)]
79+
80+
81+
def _classify_resolved(resolved: Path, dirs: list[Path], files: list[Path]) -> None:
82+
if resolved.is_dir():
83+
dirs.append(resolved)
84+
elif resolved.is_file():
85+
files.append(resolved)
86+
87+
88+
def _expand_paths(raw_paths: list[str]) -> tuple[list[Path], list[Path]]:
6989
dirs: list[Path] = []
7090
files: list[Path] = []
7191
seen: set[Path] = set()
7292
for pattern in raw_paths:
73-
matches = sorted(globmod.glob(pattern, recursive=True))
74-
if not matches:
75-
try:
76-
p = Path(pattern).resolve(strict=True)
77-
except FileNotFoundError:
78-
_exit_error(f"No matches for '{pattern}'")
79-
except OSError as e:
80-
_exit_error(f"Cannot access '{pattern}': {e}")
81-
matches = [str(p)]
82-
for m in matches:
93+
for m in _resolve_glob_pattern(pattern):
8394
try:
8495
resolved = Path(m).resolve()
8596
except OSError as e:
8697
_exit_error(f"Cannot access '{m}': {e}")
8798
if resolved in seen:
8899
continue
89100
seen.add(resolved)
90-
if resolved.is_dir():
91-
dirs.append(resolved)
92-
elif resolved.is_file():
93-
files.append(resolved)
101+
_classify_resolved(resolved, dirs, files)
94102
return dirs, files
95103

96104

src/treemapper/diffctx/edges/semantic/javascript.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,21 @@ def build(self, fragments: list[Fragment], repo_root: Path | None = None) -> Edg
472472
_IMPORT_WEIGHT = 0.55
473473
_REEXPORT_WEIGHT_FACTOR = 0.8
474474

475+
def _link_resolved_import(
476+
self,
477+
src_path: Path,
478+
resolved: Path,
479+
file_to_frags: dict[Path, list[FragmentId]],
480+
fragment_paths: set[Path],
481+
edges: EdgeDict,
482+
) -> None:
483+
if resolved == src_path:
484+
return
485+
target_ids = file_to_frags.get(resolved, [])
486+
if target_ids:
487+
self._link_import_pairs(file_to_frags[src_path], target_ids, edges)
488+
self._follow_reexports(resolved, file_to_frags[src_path], file_to_frags, fragment_paths, edges)
489+
475490
def _add_import_edges(
476491
self,
477492
js_frags: list[Fragment],
@@ -484,43 +499,17 @@ def _add_import_edges(
484499
file_to_frags[f.path].append(f.id)
485500

486501
fragment_paths = set(file_to_frags.keys())
487-
file_imports, alias_resolved = self._collect_imports(
488-
js_frags,
489-
info_cache,
490-
tsconfig_resolver,
491-
fragment_paths,
492-
)
502+
file_imports, alias_resolved = self._collect_imports(js_frags, info_cache, tsconfig_resolver, fragment_paths)
493503

494504
for src_path, import_sources in file_imports.items():
495505
for import_source in import_sources:
496506
resolved = _resolve_relative_import(src_path, import_source, fragment_paths)
497-
if resolved is None or resolved == src_path:
498-
continue
499-
target_ids = file_to_frags.get(resolved, [])
500-
if target_ids:
501-
self._link_import_pairs(file_to_frags[src_path], target_ids, edges)
502-
self._follow_reexports(
503-
resolved,
504-
file_to_frags[src_path],
505-
file_to_frags,
506-
fragment_paths,
507-
edges,
508-
)
507+
if resolved is not None:
508+
self._link_resolved_import(src_path, resolved, file_to_frags, fragment_paths, edges)
509509

510510
for src_path, resolved_targets in alias_resolved.items():
511511
for resolved in resolved_targets:
512-
if resolved == src_path:
513-
continue
514-
target_ids = file_to_frags.get(resolved, [])
515-
if target_ids:
516-
self._link_import_pairs(file_to_frags[src_path], target_ids, edges)
517-
self._follow_reexports(
518-
resolved,
519-
file_to_frags[src_path],
520-
file_to_frags,
521-
fragment_paths,
522-
edges,
523-
)
512+
self._link_resolved_import(src_path, resolved, file_to_frags, fragment_paths, edges)
524513

525514
@staticmethod
526515
def _collect_imports(

src/treemapper/diffctx/edges/semantic/jvm.py

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -295,56 +295,51 @@ def _compute_import_dirs(repo_root: Path | None, import_packages: set[str]) -> s
295295
import_dirs.add(repo_root / src_prefix / Path(*pkg.split(".")))
296296
return import_dirs
297297

298-
def _discover_single_hop(
299-
self,
300-
source_files: list[Path],
301-
candidates: list[Path],
302-
repo_root: Path | None,
303-
) -> list[Path]:
304-
type_refs, import_packages = self._collect_source_refs(source_files)
305-
source_dirs = {f.parent for f in source_files}
306-
eligible_dirs = source_dirs | self._compute_import_dirs(repo_root, import_packages)
307-
source_set = set(source_files)
308-
298+
@staticmethod
299+
def _collect_frontier_classes(source_files: list[Path]) -> set[str]:
309300
frontier_classes: set[str] = set()
310-
frontier_packages: set[str] = set()
311301
for f in source_files:
312302
try:
313303
content = f.read_text(encoding="utf-8")
314304
frontier_classes.update(_extract_classes(content, f))
315-
pkg = _extract_package(content)
316-
if pkg:
317-
frontier_packages.add(pkg)
318305
except (OSError, UnicodeDecodeError):
319306
pass
307+
return frontier_classes
320308

321-
discovered: list[Path] = []
322-
for candidate in candidates:
323-
if candidate in source_set:
324-
continue
325-
try:
326-
content = candidate.read_text(encoding="utf-8")
327-
cand_classes = _extract_classes(content, candidate)
328-
329-
if candidate.parent in eligible_dirs and cand_classes & type_refs:
330-
discovered.append(candidate)
331-
continue
332-
333-
cand_type_refs = _extract_type_refs(content)
334-
if cand_type_refs & frontier_classes:
335-
discovered.append(candidate)
336-
continue
337-
338-
cand_imports = _extract_imports(content, candidate)
339-
for imp in cand_imports:
340-
imp_class = imp.rsplit(".", 1)[-1]
341-
if imp_class in frontier_classes:
342-
discovered.append(candidate)
343-
break
344-
except (OSError, UnicodeDecodeError):
345-
pass
309+
@staticmethod
310+
def _candidate_matches_frontier(
311+
candidate: Path,
312+
eligible_dirs: set[Path],
313+
type_refs: set[str],
314+
frontier_classes: set[str],
315+
) -> bool:
316+
try:
317+
content = candidate.read_text(encoding="utf-8")
318+
cand_classes = _extract_classes(content, candidate)
319+
if candidate.parent in eligible_dirs and cand_classes & type_refs:
320+
return True
321+
if _extract_type_refs(content) & frontier_classes:
322+
return True
323+
return any(imp.rsplit(".", 1)[-1] in frontier_classes for imp in _extract_imports(content, candidate))
324+
except (OSError, UnicodeDecodeError):
325+
return False
326+
327+
def _discover_single_hop(
328+
self,
329+
source_files: list[Path],
330+
candidates: list[Path],
331+
repo_root: Path | None,
332+
) -> list[Path]:
333+
type_refs, import_packages = self._collect_source_refs(source_files)
334+
eligible_dirs = {f.parent for f in source_files} | self._compute_import_dirs(repo_root, import_packages)
335+
source_set = set(source_files)
336+
frontier_classes = self._collect_frontier_classes(source_files)
346337

347-
return discovered
338+
return [
339+
c
340+
for c in candidates
341+
if c not in source_set and self._candidate_matches_frontier(c, eligible_dirs, type_refs, frontier_classes)
342+
]
348343

349344
def build(self, fragments: list[Fragment], repo_root: Path | None = None) -> EdgeDict:
350345
jvm_frags = [f for f in fragments if _is_jvm_file(f.path)]

0 commit comments

Comments
 (0)