Skip to content

Commit 18772e5

Browse files
committed
fix(rerank): unswallow exception, shim Qwen2 prepare_for_model, cache reranker on backend
Three independent retrieval-path bugs surfaced by Phase 10 --full LongMemEval_S:

1. tuned_hybrid.py was swallowing every reranker exception with bare `except Exception`, hiding the actual failure cause and violating the project hard constraint (CLAUDE.md: "NEVER suppress errors in indexing/retrieval paths"). Replaced with type+message logging plus full traceback.

2. mxbai-rerank 0.1.6 calls `tokenizer.prepare_for_model(...)` unconditionally, but transformers >=4.50 no longer exposes this method on the slow Qwen2Tokenizer (and PreTrainedTokenizerBase no longer provides a fallback impl). mxbai-rerank upstream is effectively unmaintained (only 1 PR ever merged, no fix released). Bound a minimal hand-written prepare_for_model on the tokenizer instance covering the exact call signature mxbai uses (add_special_tokens=False, padding=False, truncation="only_second").

3. tuned_hybrid.query() called load_reranker() on every retrieval call, constructing a fresh wrapper with `_model=None` each time, which then triggered a fresh model weight load. Cache the reranker on the backend instance keyed by name so a long-lived process loads weights exactly once.

Verified via uv run pytest (620 passed, 4 skipped, 1 upstream warning) and end-to-end via supamem eval --suite longmemeval_s --full (470 questions; rerank now actually runs, weights load once per process).
1 parent 7b3541c commit 18772e5

2 files changed

Lines changed: 93 additions & 8 deletions

File tree

src/supamem/rerankers/mxbai_v2.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,66 @@ def _ensure(self) -> Any:
3535
from mxbai_rerank import MxbaiRerankV2 # noqa: PLC0415
3636

3737
self._model = MxbaiRerankV2(self.config.reranker_model_id)
38+
39+
# Compatibility shim: mxbai-rerank 0.1.6 calls
40+
# `tokenizer.prepare_for_model(ids, pair_ids, ...)` unconditionally,
41+
# but the slow Qwen2Tokenizer in transformers >=4.50 no longer
42+
# exposes this method (and PreTrainedTokenizerBase no longer
43+
# provides a fallback impl as of recent releases). Upstream
44+
# mxbai-rerank is unmaintained (no fix released as of 2026-05-04).
45+
# Implement the minimal call signature mxbai uses:
46+
# add_special_tokens=False, padding=False, truncation="only_second".
47+
try:
48+
tok = getattr(self._model, "tokenizer", None)
49+
if tok is not None and not hasattr(tok, "prepare_for_model"):
50+
def _shim_prepare_for_model(
51+
ids,
52+
pair_ids=None,
53+
*,
54+
truncation=None,
55+
max_length=None,
56+
padding=False, # noqa: ARG001
57+
return_attention_mask=False,
58+
return_token_type_ids=False,
59+
add_special_tokens=False, # noqa: ARG001
60+
**_kwargs,
61+
):
62+
a = list(ids)
63+
b = list(pair_ids) if pair_ids is not None else []
64+
if max_length is not None:
65+
if truncation == "only_second":
66+
budget = max_length - len(a)
67+
if budget < 0:
68+
a = a[:max_length]
69+
b = []
70+
elif len(b) > budget:
71+
b = b[:budget]
72+
elif truncation in (
73+
True,
74+
"longest_first",
75+
"only_first",
76+
):
77+
while len(a) + len(b) > max_length:
78+
if truncation == "only_first" or len(a) > len(b):
79+
a.pop()
80+
else:
81+
b.pop()
82+
combined = a + b
83+
out: dict[str, Any] = {"input_ids": combined}
84+
if return_attention_mask:
85+
out["attention_mask"] = [1] * len(combined)
86+
if return_token_type_ids:
87+
out["token_type_ids"] = [0] * len(a) + [1] * len(b)
88+
return out
89+
90+
tok.prepare_for_model = _shim_prepare_for_model
91+
except Exception as _shim_exc: # noqa: BLE001
92+
err_console.print(
93+
"[supamem.warn]prepare_for_model shim failed "
94+
f"({type(_shim_exc).__name__}: {_shim_exc}); "
95+
"rerank may still raise"
96+
)
97+
3898
elapsed_ms = (time.perf_counter() - t0) * 1000.0
3999
try:
40100
from supamem.stats.counter import bump # noqa: PLC0415

src/supamem/retrieval/tuned_hybrid.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,12 @@ def __init__(self, *, config: ResolvedConfig, minimal_setup: bool = False) -> No
171171
self._client: Any | None = None
172172
self._dense: Any | None = None
173173
self._sparse: Any | None = None
174+
# D-POOL: cache the reranker plugin instance on the backend so a
175+
# process-long retrieval session reuses one model load. load_reranker
176+
# constructs a fresh wrapper per call; without this cache the
177+
# wrapper's _ensure() lazy-load reloads the model every query.
178+
self._reranker: Any | None = None
179+
self._reranker_name: str | None = None
174180
self._minimal_setup = minimal_setup
175181

176182
def _ensure(self):
@@ -213,13 +219,25 @@ def query(
213219
from supamem.rerankers import load_reranker # noqa: PLC0415
214220

215221
reranker_name = getattr(self.config, "reranker_name", "off")
216-
try:
217-
reranker = load_reranker(reranker_name, self.config)
218-
except LookupError:
219-
# Fail-soft: treat unknown reranker as off; ResolvedConfig's
220-
# validation gate (load_config) is the canonical fail-closed
221-
# surface — at backend.query() time we never abort retrieval.
222+
# D-POOL: reuse cached reranker if name unchanged. load_reranker
223+
# constructs a fresh wrapper each call — caching here keeps the
224+
# underlying model loaded for the lifetime of the backend.
225+
if reranker_name == "off":
222226
reranker = None
227+
elif (
228+
self._reranker is not None and self._reranker_name == reranker_name
229+
):
230+
reranker = self._reranker
231+
else:
232+
try:
233+
reranker = load_reranker(reranker_name, self.config)
234+
except LookupError:
235+
# Fail-soft: treat unknown reranker as off; ResolvedConfig's
236+
# validation gate (load_config) is the canonical fail-closed
237+
# surface — at backend.query() time we never abort retrieval.
238+
reranker = None
239+
self._reranker = reranker
240+
self._reranker_name = reranker_name
223241

224242
# D-POOL-01: widen prefetch only when reranker is on.
225243
prefetch_limit = (
@@ -299,14 +317,21 @@ def query(
299317
t0 = _time.perf_counter()
300318
try:
301319
reranked = reranker.rerank(text, pre_rerank)
302-
except Exception:
320+
except Exception as _rerank_exc:
303321
# Plugin failure → fall through to off-path semantics
304322
# (T-RERANK-INVAR mitigation: never silently drop hits).
323+
# Surface exception class+message so plugin authors and
324+
# bench runs can see WHY rerank failed, not just that it did.
325+
import traceback as _tb # noqa: PLC0415
326+
305327
from supamem.console import err_console # noqa: PLC0415
306328

307329
err_console.print(
308-
"[supamem.warn]reranker raised — falling back to off-branch"
330+
f"[supamem.warn]reranker raised "
331+
f"({type(_rerank_exc).__name__}: {_rerank_exc}) "
332+
"— falling back to off-branch"
309333
)
334+
err_console.print(_tb.format_exc())
310335
reranker = None
311336
reranked = []
312337
elapsed_ms = (_time.perf_counter() - t0) * 1000.0

0 commit comments

Comments
 (0)