Skip to content

Commit 30fb287

Browse files
authored
fix(api): normalize torch default dtype to float32 after concurrent model init (#2167)
transformers' dtype context manager (entered by SentenceTransformer / CrossEncoder / from_pretrained) does a non-thread-safe save/restore of the process-global default dtype. When an fp16 embedding model and an fp32 reranker/query-analyzer load in parallel during MemoryEngine.initialize(), an unlucky interleave can leave the global default stuck at float16. Every later encode() then emits NaN vectors that pgvector rejects ("NaN not allowed in vector") on MPS, or raises "c10::Half != float" on CPU -- non-deterministically across restarts. Keep the model loads fully parallel and, once asyncio.gather() has joined every load thread, normalize the global default dtype back to float32 -- the inference state a healthy boot already converges to. The reset is race-free (all threads have finished) and only touches torch if a local provider actually loaded it. Fixes #2162
1 parent f0802b8 commit 30fb287

2 files changed

Lines changed: 113 additions & 0 deletions

File tree

hindsight-api-slim/hindsight_api/engine/memory_engine.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import inspect
1616
import json
1717
import logging
18+
import sys
1819
import time
1920
import uuid
2021
from collections.abc import Awaitable, Callable
@@ -2567,6 +2568,28 @@ async def verify_llm():
25672568
f"first-time model download legitimately needs more time."
25682569
) from e
25692570

2571+
# Normalize torch's process-global default dtype back to float32 after the
2572+
# concurrent local model loads. transformers' dtype context manager (entered
2573+
# by SentenceTransformer / CrossEncoder / from_pretrained) does a
2574+
# NON-thread-safe save/restore of the global default dtype: when an fp16 and
2575+
# an fp32 model load in parallel above, an unlucky interleave can leave the
2576+
# default stuck at float16, after which every encode() emits NaN vectors that
2577+
# pgvector rejects ("NaN not allowed in vector") on MPS, or raises
2578+
# "c10::Half != float" on CPU — non-deterministically across restarts. By the
2579+
# time gather() returns, all load threads have joined, so resetting the
2580+
# default here is race-free, keeps the loads fully parallel, and converges on
2581+
# the float32 inference state a healthy boot already reaches. torch is only
2582+
# imported (in sys.modules) if a local provider actually loaded a model.
2583+
# See https://github.com/vectorize-io/hindsight/issues/2162.
2584+
torch_mod = sys.modules.get("torch")
2585+
if torch_mod is not None and torch_mod.get_default_dtype() != torch_mod.float32:
2586+
logger.warning(
2587+
"torch default dtype was left at %s after concurrent model init; "
2588+
"restoring float32 to avoid NaN embedding vectors (issue #2162).",
2589+
torch_mod.get_default_dtype(),
2590+
)
2591+
torch_mod.set_default_dtype(torch_mod.float32)
2592+
25702593
# Run database migrations if enabled
25712594
if self._run_migrations:
25722595
if not self.db_url:
Lines changed: 90 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,90 @@
1+
"""
2+
Startup must leave torch's global default dtype at float32 regardless of how the
3+
concurrent local model loads interleave.
4+
5+
Covers issue #2162: transformers' dtype context manager (entered by
6+
SentenceTransformer / CrossEncoder / from_pretrained) does a NON-thread-safe
7+
save/restore of the *process-global* default dtype. When an fp16 embedding model
8+
and an fp32 reranker/query-analyzer load in parallel at startup, an unlucky
9+
interleave leaves the global default stuck at float16 — every later encode() then
10+
emits NaN vectors that pgvector rejects ("NaN not allowed in vector") on MPS, or
11+
raises "c10::Half != float" on CPU, non-deterministically across restarts.
12+
13+
MemoryEngine.initialize() loads the models in parallel (for speed) and then, once
14+
the gather has joined every load thread, normalizes the global default dtype back
15+
to float32 — the inference state a healthy boot already converges to. This test
16+
simulates the poisoning by having a model load flip the default to float16, then
17+
asserts initialize() leaves it at float32.
18+
"""
19+
20+
import pytest
21+
22+
from hindsight_api import MemoryEngine
23+
from hindsight_api.engine.task_backend import SyncTaskBackend
24+
25+
26+
class _StopInit(Exception):
    """Sentinel exception used to cut ``initialize()`` short.

    The test replaces the engine backend's ``initialize`` with a coroutine
    that raises this, so startup stops right after the model-load gather.
    """
28+
29+
30+
class _PoisoningEmbeddings:
    """Embeddings stand-in whose load leaves torch's default dtype at float16.

    It imitates the end state of an fp16 model load whose racy dtype restore
    left the process-global default poisoned.
    """

    provider_name = "local"

    async def initialize(self) -> None:
        """Set torch's process-global default dtype to float16."""
        # Imported inside the method, as in the original stub, so torch is
        # only pulled in when the load actually runs.
        import torch as torch_mod

        # Same symptom as transformers' non-thread-safe dtype restore: the
        # global default stays at half precision once the load is over.
        poisoned_dtype = torch_mod.float16
        torch_mod.set_default_dtype(poisoned_dtype)
41+
42+
43+
class _NoopCrossEncoder:
    """Cross-encoder stand-in whose load does nothing."""

    provider_name = "local"

    async def initialize(self) -> None:
        """Complete immediately without loading anything."""
48+
49+
50+
class _NoopQueryAnalyzer:
    """Query-analyzer stand-in whose ``load()`` does nothing."""

    def load(self) -> None:
        """Return immediately without loading anything."""
53+
54+
55+
@pytest.mark.asyncio
async def test_global_default_dtype_restored_to_float32_after_init():
    """A load that leaves the torch default at float16 is normalized back to float32.

    Skipped when torch is not installed: the behaviour under test only exists
    once torch has been imported.
    """
    # Skip (rather than error) on installs without torch.
    torch = pytest.importorskip("torch")

    original = torch.get_default_dtype()

    # Wrap the poisoning load so the test can prove it actually ran. Without
    # this, a startup path that never awaits the embeddings load would leave
    # the default at float32 and the final assertion would pass vacuously.
    embeddings = _PoisoningEmbeddings()
    poison_load = embeddings.initialize
    load_calls: list[bool] = []

    async def _tracked_initialize() -> None:
        await poison_load()
        load_calls.append(True)

    embeddings.initialize = _tracked_initialize  # type: ignore[method-assign]

    try:
        engine = MemoryEngine(
            # Non-pg0 URL so start_pg0() is a no-op and __init__ never connects.
            db_url="postgresql://u:p@localhost:5999/db",
            memory_llm_provider="none",
            memory_llm_api_key=None,
            memory_llm_model="none",
            embeddings=embeddings,
            cross_encoder=_NoopCrossEncoder(),
            query_analyzer=_NoopQueryAnalyzer(),
            run_migrations=False,
            skip_llm_verification=True,
            lazy_reranker=False,  # load the cross-encoder eagerly, in the gather
            task_backend=SyncTaskBackend(),
        )

        # Abort right after the post-gather dtype restore, before any real DB work.
        async def _stop(*args, **kwargs):
            raise _StopInit

        engine._backend.initialize = _stop  # type: ignore[method-assign]

        with pytest.raises(_StopInit):
            await engine.initialize()

        # Guard against a vacuous pass: the poisoning load must have executed.
        assert load_calls, "embedding load never ran; the dtype was never poisoned"

        # The embedding load poisoned the default to float16; initialize() must
        # have normalized it back so later encode() can't emit NaN vectors.
        assert torch.get_default_dtype() == torch.float32
    finally:
        # Never leak a modified process-global dtype into other tests.
        torch.set_default_dtype(original)

0 commit comments

Comments
 (0)