Skip to content

Commit e1fa8c0

Browse files
fix: cache branch structure from baseline for deterministic per-test branch counts
Branch detection via _analyze() depends on observed arcs, which vary with thread timing. Now the baseline snapshot caches branch point structure (totals per line), and per-test snapshots reuse that cache — only computing covered counts from the current test's arcs. This eliminates flaky branch totals (was 12/18 or 12/22 randomly, now consistently 4/10).
1 parent 06b6c27 commit e1fa8c0

1 file changed

Lines changed: 58 additions & 2 deletions

File tree

drift/core/coverage_server.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
_source_root: str | None = None
2626
_lock = threading.Lock()
2727

28+
# Cache branch structure from baseline to ensure deterministic branch counts.
29+
# Branch detection via _analyze() depends on observed arcs, which vary with
30+
# thread timing. By caching from the baseline (which has the fullest data),
31+
# per-test snapshots report consistent totals.
32+
_branch_cache: dict[str, dict] | None = None
33+
2834

2935
def start_coverage_collection() -> bool:
3036
"""Initialize coverage.py collection if TUSK_COVERAGE is set.
@@ -85,14 +91,15 @@ def start_coverage_collection() -> bool:
8591

8692
def stop_coverage_collection() -> None:
8793
"""Stop coverage collection and clean up. Thread-safe."""
88-
global _cov_instance
94+
global _cov_instance, _branch_cache
8995
with _lock:
9096
if _cov_instance is not None:
9197
try:
9298
_cov_instance.stop()
9399
except Exception:
94100
pass
95101
_cov_instance = None
102+
_branch_cache = None
96103

97104

98105
def take_coverage_snapshot(baseline: bool = False) -> dict:
@@ -115,7 +122,10 @@ def take_coverage_snapshot(baseline: bool = False) -> dict:
115122
coverage = {}
116123

117124
try:
125+
global _branch_cache
118126
if baseline:
127+
# Baseline: compute fresh branch data and cache it for per-test reuse
128+
_branch_cache = {}
119129
data = _cov_instance.get_data()
120130
for filename in data.measured_files():
121131
if not _is_user_file(filename):
@@ -127,6 +137,7 @@ def take_coverage_snapshot(baseline: bool = False) -> dict:
127137
for line in statements:
128138
lines_map[str(line)] = 0 if line in missing_set else 1
129139
branch_data = _get_branch_data(data, filename)
140+
_branch_cache[filename] = branch_data
130141
if lines_map:
131142
coverage[filename] = {"lines": lines_map, **branch_data}
132143
except Exception as e:
@@ -139,7 +150,12 @@ def take_coverage_snapshot(baseline: bool = False) -> dict:
139150
continue
140151
lines = data.lines(filename)
141152
if lines:
142-
branch_data = _get_branch_data(data, filename)
153+
# Use cached branch data from baseline for stable totals.
154+
# Fall back to live _analyze() if no cache (e.g., no baseline taken).
155+
if _branch_cache is not None and filename in _branch_cache:
156+
branch_data = _get_per_test_branch_data(data, filename, _branch_cache[filename])
157+
else:
158+
branch_data = _get_branch_data(data, filename)
143159
coverage[filename] = {
144160
"lines": {str(line): 1 for line in lines},
145161
**branch_data,
@@ -217,3 +233,43 @@ def _get_branch_data(data, filename: str) -> dict:
217233
}
218234
except Exception:
219235
return {"totalBranches": 0, "coveredBranches": 0, "branches": {}}
236+
237+
238+
def _get_per_test_branch_data(data, filename: str, cached: dict) -> dict:
239+
"""Compute per-test branch coverage using cached branch structure from baseline.
240+
241+
Uses the cached branch point set (from baseline) for stable totals,
242+
but computes covered counts from the current test's executed arcs.
243+
This avoids flaky branch totals caused by non-deterministic arc detection.
244+
"""
245+
try:
246+
if not data.has_arcs():
247+
return {"totalBranches": 0, "coveredBranches": 0, "branches": {}}
248+
249+
executed_arcs = set(data.arcs(filename) or [])
250+
251+
# Group executed arcs by from_line (skip negative entry arcs)
252+
executed_by_line: dict[int, list[int]] = {}
253+
for from_line, to_line in executed_arcs:
254+
if from_line < 0:
255+
continue
256+
executed_by_line.setdefault(from_line, []).append(to_line)
257+
258+
# Use cached branch points — only compute covered from current arcs
259+
cached_branches = cached.get("branches", {})
260+
branches: dict[str, dict] = {}
261+
total_covered = 0
262+
263+
for line_str, info in cached_branches.items():
264+
total = info["total"]
265+
covered = min(len(executed_by_line.get(int(line_str), [])), total)
266+
branches[line_str] = {"total": total, "covered": covered}
267+
total_covered += covered
268+
269+
return {
270+
"totalBranches": cached.get("totalBranches", 0),
271+
"coveredBranches": total_covered,
272+
"branches": branches,
273+
}
274+
except Exception:
275+
return {"totalBranches": 0, "coveredBranches": 0, "branches": {}}

0 commit comments

Comments
 (0)