Skip to content

Commit f3275b6

Browse files
committed
perf(bench): inverted-index BM25 discovery, label-keyed cell dicts
1 parent 47abf3e commit f3275b6

2 files changed

Lines changed: 78 additions & 46 deletions

File tree

benchmarks/adapters/calibrate.py

Lines changed: 20 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -174,35 +174,40 @@ def evaluate_grid_cached( # noqa: C901 — pool teardown + per-cell demux + ret
174174

175175
evaluator = UniversalEvaluator()
176176
points = list(spec.points())
177+
# Index by `params.label()` because `RunParams.extra_env: dict` makes
178+
# the dataclass unhashable — can't use the params object itself as a
179+
# dict key. Labels are stable and uniquely identify a grid cell.
180+
points_by_label: dict[str, RunParams] = {p.label(): p for p in points}
177181

178-
ckpts: dict[RunParams, Path | None] = {
179-
p: (checkpoint_dir / f"{p.label()}.jsonl") if checkpoint_dir is not None else None for p in points
182+
ckpts: dict[str, Path | None] = {
183+
lbl: (checkpoint_dir / f"{lbl}.jsonl") if checkpoint_dir is not None else None for lbl in points_by_label
180184
}
181-
done_ids: dict[RunParams, set[str]] = {p: read_checkpoint(c) if c is not None else set() for p, c in ckpts.items()}
182-
results_by_cell: dict[RunParams, list[EvalResult]] = {
183-
p: (_load_existing_results(c, done_ids[p]) if c is not None else []) for p, c in ckpts.items()
185+
done_ids: dict[str, set[str]] = {lbl: read_checkpoint(c) if c is not None else set() for lbl, c in ckpts.items()}
186+
results_by_cell: dict[str, list[EvalResult]] = {
187+
lbl: (_load_existing_results(c, done_ids[lbl]) if c is not None else []) for lbl, c in ckpts.items()
184188
}
185189

186190
pending: list[tuple[BenchmarkInstance, list[RunParams]]] = []
187191
for inst in instances:
188-
needed = [p for p in points if inst.instance_id not in done_ids[p]]
192+
needed = [p for p in points if inst.instance_id not in done_ids[p.label()]]
189193
if needed:
190194
pending.append((inst, needed))
191195

192196
def _make_pool() -> ProcessPoolExecutor:
193197
ctx = mp.get_context("spawn")
194-
p = ProcessPoolExecutor(max_workers=workers, mp_context=ctx, max_tasks_per_child=50)
198+
p = ProcessPoolExecutor(max_workers=workers, mp_context=ctx, max_tasks_per_child=20)
195199
list(p.map(int, range(workers)))
196200
return p
197201

198202
def _record_per_cell(per_cell_results: list[tuple[RunParams, EvalResult]]) -> None:
199203
for params, result in per_cell_results:
200-
ckpt = ckpts.get(params)
204+
lbl = params.label()
205+
ckpt = ckpts.get(lbl)
201206
if ckpt is not None:
202207
err = str((result.extra or {}).get("error", ""))
203208
if "BrokenProcessPool" not in err:
204209
append_checkpoint(ckpt, result)
205-
results_by_cell[params].append(result)
210+
results_by_cell[lbl].append(result)
206211

207212
def _drain(pool: ProcessPoolExecutor) -> None:
208213
futures: dict = {}
@@ -258,11 +263,11 @@ def _drain(pool: ProcessPoolExecutor) -> None:
258263
# Recompute pending for the rebuild from current
259264
# checkpoint state — instances completed since last
260265
# rebuild should be skipped.
261-
done_ids_now = {p: read_checkpoint(c) if c is not None else set() for p, c in ckpts.items()}
266+
done_ids_now = {lbl: read_checkpoint(c) if c is not None else set() for lbl, c in ckpts.items()}
262267
pending[:] = [
263-
(inst, [p for p in points if inst.instance_id not in done_ids_now[p]])
268+
(inst, [p for p in points if inst.instance_id not in done_ids_now[p.label()]])
264269
for inst, _ in pending
265-
if any(inst.instance_id not in done_ids_now[p] for p in points)
270+
if any(inst.instance_id not in done_ids_now[p.label()] for p in points)
266271
]
267272
elif pending and pool is None:
268273
# workers == 1: serial fallback
@@ -278,11 +283,12 @@ def _drain(pool: ProcessPoolExecutor) -> None:
278283

279284
out: list[TrialResult] = []
280285
for i, params in enumerate(points):
281-
agg = evaluator.aggregate_per_benchmark(results_by_cell[params])
286+
cell_results = results_by_cell[params.label()]
287+
agg = evaluator.aggregate_per_benchmark(cell_results)
282288
trial = TrialResult(
283289
params=params,
284290
per_benchmark=agg,
285-
raw_results=tuple(results_by_cell[params]),
291+
raw_results=tuple(cell_results),
286292
)
287293
out.append(trial)
288294
if on_trial is not None:

diffctx/src/discovery.rs

Lines changed: 58 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -184,42 +184,54 @@ impl DiscoveryStrategy for BM25Discovery {
184184
if query_tokens.is_empty() {
185185
return Vec::new();
186186
}
187+
let query_set: FxHashSet<String> = query_tokens.into_iter().collect();
187188

188189
let changed_set: FxHashSet<&Path> = ctx.changed_files.iter().map(|p| p.as_path()).collect();
189-
let mut corpus: Vec<Vec<String>> = Vec::new();
190-
let mut paths: Vec<PathBuf> = Vec::new();
191190

192-
for f in &ctx.all_candidates {
193-
if changed_set.contains(f.as_path()) {
194-
continue;
195-
}
196-
let content = match ctx.read_file(f) {
197-
Some(c) => c,
198-
None => continue,
199-
};
200-
corpus.push(extract_identifier_list(
201-
&content,
202-
BM25.min_query_token_length,
203-
));
204-
paths.push(f.clone());
205-
}
191+
// Parallel tokenization: previously a serial loop, the dominant
192+
// cost on mega-repos (vscode/mui ~5k TS files). par_iter saturates
193+
// available rayon threads.
194+
let pairs: Vec<(PathBuf, Vec<String>)> = ctx
195+
.all_candidates
196+
.par_iter()
197+
.filter(|f| !changed_set.contains(f.as_path()))
198+
.filter_map(|f| {
199+
let content = ctx.read_file(f)?;
200+
Some((
201+
f.clone(),
202+
extract_identifier_list(&content, BM25.min_query_token_length),
203+
))
204+
})
205+
.collect();
206206

207-
if corpus.is_empty() {
207+
if pairs.is_empty() {
208208
return Vec::new();
209209
}
210+
let n_docs = pairs.len();
211+
if n_docs > 5000 {
212+
tracing::warn!(
213+
"BM25Discovery: large candidate corpus ({n_docs} docs) — using inverted-index fast path"
214+
);
215+
}
210216

211-
let n_docs = corpus.len();
212-
let avgdl = corpus.iter().map(|d| d.len()).sum::<usize>() as f64 / n_docs as f64;
213-
217+
// Single pass: compute df globally + inverted-index posting lists
218+
// for query terms only (skip indexing terms not in the query — they
219+
// are never needed and would balloon memory on large repos).
214220
let mut df: FxHashMap<String, usize> = FxHashMap::default();
215-
for doc in &corpus {
221+
let mut postings: FxHashMap<String, Vec<usize>> = FxHashMap::default();
222+
let mut total_len: usize = 0;
223+
for (doc_id, (_, doc)) in pairs.iter().enumerate() {
224+
total_len += doc.len();
216225
let unique: FxHashSet<&str> = doc.iter().map(|s| s.as_str()).collect();
217226
for term in unique {
218227
*df.entry(term.to_string()).or_insert(0) += 1;
228+
if query_set.contains(term) {
229+
postings.entry(term.to_string()).or_default().push(doc_id);
230+
}
219231
}
220232
}
233+
let avgdl = total_len as f64 / n_docs as f64;
221234

222-
let query_set: FxHashSet<String> = query_tokens.into_iter().collect();
223235
let idf: FxHashMap<String, f64> = query_set
224236
.iter()
225237
.map(|t| {
@@ -230,23 +242,37 @@ impl DiscoveryStrategy for BM25Discovery {
230242
})
231243
.collect();
232244

233-
let scores: Vec<f64> = corpus
245+
// Candidate doc-ids = union of posting lists for query terms. Docs
246+
// not in this set contain zero query terms and would score 0 — skip
247+
// them. This is the algorithmic win: scoring shrinks from O(N_docs)
248+
// to O(|posting-list union|), typically ~10-100× smaller on big
249+
// corpora where the query is sparse against the corpus vocabulary.
250+
let mut candidate_ids: FxHashSet<usize> = FxHashSet::default();
251+
for term in &query_set {
252+
if let Some(p) = postings.get(term) {
253+
candidate_ids.extend(p);
254+
}
255+
}
256+
if candidate_ids.is_empty() {
257+
return Vec::new();
258+
}
259+
260+
let candidate_vec: Vec<usize> = candidate_ids.into_iter().collect();
261+
let scored: Vec<(usize, f64)> = candidate_vec
234262
.par_iter()
235-
.map(|doc| Self::bm25_score(doc, &query_set, &idf, avgdl))
263+
.map(|&doc_id| {
264+
let s = Self::bm25_score(&pairs[doc_id].1, &query_set, &idf, avgdl);
265+
(doc_id, s)
266+
})
236267
.collect();
237268

238-
let mut ranked: Vec<usize> = (0..scores.len()).collect();
239-
ranked.sort_by(|&a, &b| {
240-
scores[b]
241-
.partial_cmp(&scores[a])
242-
.unwrap_or(std::cmp::Ordering::Equal)
243-
});
269+
let mut ranked: Vec<(usize, f64)> = scored.into_iter().filter(|(_, s)| *s > 0.0).collect();
270+
ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
244271

245272
ranked
246273
.into_iter()
247274
.take(self.top_k)
248-
.filter(|&i| scores[i] > 0.0)
249-
.map(|i| paths[i].clone())
275+
.map(|(i, _)| pairs[i].0.clone())
250276
.collect()
251277
}
252278
}

0 commit comments

Comments
 (0)