Skip to content

Commit c4af721

Browse files
committed
fix: address review feedback — L2 consistency, reshape, tempfile, platform guard
1 parent 987cb2a commit c4af721

1 file changed

Lines changed: 61 additions & 36 deletions

File tree

app/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import shutil
99
import tempfile
10+
import uuid
1011

1112
import numpy as np
1213

@@ -27,27 +28,46 @@ def _safe_temp_dir() -> str:
2728
1. %SystemRoot%\\Temp(Windows 系统临时目录,如 C:\\WINDOWS\\TEMP)
2829
2. tempfile.gettempdir()(当其为纯 ASCII 时)
2930
3. 当前工作目录
31+
4. 非 Windows 平台使用 tempfile.gettempdir()
3032
"""
31-
candidates = []
32-
# C:\Windows\Temp — 所有 Windows 安装的纯 ASCII 路径
33-
root = os.environ.get("SystemRoot", r"C:\Windows")
34-
candidates.append(os.path.join(root, "Temp"))
35-
# 标准 tempdir(可能是 ASCII 也可能不是)
36-
candidates.append(tempfile.gettempdir())
37-
# 兜底:当前目录
38-
try:
39-
candidates.append(os.getcwd())
40-
except OSError:
41-
pass
42-
43-
for d in candidates:
44-
if d.isascii() and os.path.isdir(d) and os.access(d, os.W_OK):
45-
return d
46-
47-
# 终极兜底:创建新目录
48-
fallback = r"C:\Windows\Temp"
49-
os.makedirs(fallback, exist_ok=True)
50-
return fallback
33+
# Windows 专属硬编码
34+
if os.name == "nt":
35+
candidates = []
36+
root = os.environ.get("SystemRoot", r"C:\Windows")
37+
candidates.append(os.path.join(root, "Temp"))
38+
candidates.append(tempfile.gettempdir())
39+
try:
40+
candidates.append(os.getcwd())
41+
except OSError:
42+
pass
43+
44+
for d in candidates:
45+
if d.isascii() and os.path.isdir(d) and os.access(d, os.W_OK):
46+
return d
47+
48+
# 所有候选都不行时抛异常,不再静默兜底
49+
raise OSError(
50+
f"_safe_temp_dir: 无法找到可写的纯 ASCII 临时目录。"
51+
f"检查过: {candidates}"
52+
)
53+
54+
# 非 Windows(Linux / macOS):tempfile 足够
55+
return tempfile.gettempdir()
56+
57+
58+
def _make_temp_file(prefix: str) -> str:
    """Create a unique, empty temporary file for the Faiss bridge.

    The file is created with ``tempfile.mkstemp`` inside the directory
    returned by ``_safe_temp_dir()``. Its name has the form
    ``<prefix>_<8 hex chars>_<mkstemp random part>.faiss``, so concurrent
    callers never share a path.

    Args:
        prefix: Leading part of the temporary file name.

    Returns:
        The path of the newly created file. The file already exists on
        disk (empty); the caller is responsible for removing it.

    Raises:
        OSError: Propagated from ``_safe_temp_dir()`` or
            ``tempfile.mkstemp`` when no usable file can be created.
    """
    safe_dir = _safe_temp_dir()
    fd, path = tempfile.mkstemp(
        prefix=f"{prefix}_{uuid.uuid4().hex[:8]}_",
        suffix=".faiss",
        dir=safe_dir,
    )
    # Only the path is needed: close the descriptor immediately so the
    # file can be reopened by path (an open handle would block that on
    # Windows) and so the descriptor is not leaked.
    os.close(fd)
    return path
5171

5272

5373
class EmbeddingStorage:
@@ -74,15 +94,16 @@ def _read_index(path: str) -> "faiss.Index":
7494
except RuntimeError:
7595
pass # 不吞其他异常类型
7696

77-
tmp = os.path.join(_safe_temp_dir(), f"_faiss_read_{os.getpid()}.faiss")
97+
tmp = _make_temp_file("_faiss_read")
7898
try:
7999
shutil.copy2(path, tmp)
80100
return faiss.read_index(tmp)
81101
finally:
82-
try:
83-
os.remove(tmp)
84-
except OSError:
85-
pass
102+
if os.path.exists(tmp):
103+
try:
104+
os.remove(tmp)
105+
except OSError:
106+
pass
86107

87108
@staticmethod
88109
def _write_index(index: "faiss.Index", path: str) -> None:
@@ -94,18 +115,21 @@ def _write_index(index: "faiss.Index", path: str) -> None:
94115
95116
写入前先确保目标目录存在,防止 shutil.move 时目录缺失。
96117
"""
97-
os.makedirs(os.path.dirname(path), exist_ok=True)
118+
dirname = os.path.dirname(path)
119+
if dirname:
120+
os.makedirs(dirname, exist_ok=True)
98121

99-
tmp = os.path.join(_safe_temp_dir(), f"_faiss_write_{os.getpid()}.faiss")
122+
tmp = _make_temp_file("_faiss_write")
100123
try:
101124
faiss.write_index(index, tmp)
102125
# 同盘 move 会先尝试 os.rename;在 Windows 上若目标文件已存在,os.rename 会失败,shutil.move 将退回 copy+delete(非原子);跨盘同样走 copy+delete
103126
shutil.move(tmp, path)
104127
finally:
105-
try:
106-
os.remove(tmp)
107-
except OSError:
108-
pass
128+
if os.path.exists(tmp):
129+
try:
130+
os.remove(tmp)
131+
except OSError:
132+
pass
109133

110134
async def insert(self, vector: np.ndarray, id: int) -> None:
111135
"""插入向量
@@ -122,7 +146,7 @@ async def insert(self, vector: np.ndarray, id: int) -> None:
122146
raise ValueError(
123147
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
124148
)
125-
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
149+
self.index.add_with_ids(vector.reshape(1, -1), np.array([id], dtype=np.int64))
126150
await self.save_index()
127151

128152
async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
@@ -140,7 +164,7 @@ async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
140164
raise ValueError(
141165
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
142166
)
143-
self.index.add_with_ids(vectors, np.array(ids))
167+
self.index.add_with_ids(vectors, np.array(ids, dtype=np.int64))
144168
await self.save_index()
145169

146170
async def search(self, vector: np.ndarray, k: int) -> tuple:
@@ -154,8 +178,9 @@ async def search(self, vector: np.ndarray, k: int) -> tuple:
154178
155179
"""
156180
assert self.index is not None, "FAISS index is not initialized."
157-
faiss.normalize_L2(vector)
158-
distances, indices = self.index.search(vector, k)
181+
# IndexFlatL2 是欧氏距离索引,不进行归一化,
182+
# 确保与 insert/insert_batch 的一致性
183+
distances, indices = self.index.search(vector.reshape(1, -1), k)
159184
return distances, indices
160185

161186
async def delete(self, ids: list[int]) -> None:
@@ -172,6 +197,6 @@ async def delete(self, ids: list[int]) -> None:
172197

173198
async def save_index(self) -> None:
    """Save the index to ``self.path``.

    Writing goes through ``self._write_index``, which is compatible with
    Windows paths containing non-ASCII characters.

    This is a no-op when the index has not been initialized
    (``self.index is None``) or when no target path is configured
    (``self.path`` is empty or ``None``).
    """
    # Nothing to persist, or nowhere to persist it to.
    if self.index is None or not self.path:
        return
    self._write_index(self.index, self.path)

0 commit comments

Comments
 (0)