Skip to content

Commit 31ff6b1

Browse files
committed
fix: narrow RuntimeError catch, add search validation, bridge only when needed
Addresses bot review feedback on PR AstrBotDevs#8323: - Add _needs_bridge() helper to activate temp file bridge only on Windows + non-ASCII paths (Sourcery AstrBotDevs#5) - _read_index: re-raise RuntimeError when bridging not needed, preventing silent swallowing of genuine Faiss errors (Sourcery AstrBotDevs#1) - _write_index: skip temp file for ASCII/non-Windows paths (Sourcery AstrBotDevs#5) - search(): validate ndim==1 and dimension before reshape, preventing silent semantic corruption on 2D input (Sourcery AstrBotDevs#3, AstrBotDevs#4) - _safe_temp_dir & _make_temp_file: simplify (Sourcery AstrBotDevs#6, AstrBotDevs#7) - Remove redundant CWD fallback (never reached on non-ASCII paths) - Remove redundant UUID prefix (mkstemp O_EXCL guarantees uniqueness) All changes tested: 119/119 pass covering bridge logic, ASCII/non-ASCII paths, concurrent temp file uniqueness, search validation, and exception propagation.
1 parent 5106dcf commit 31ff6b1

1 file changed

Lines changed: 32 additions & 23 deletions

File tree

astrbot/core/db/vec_db/faiss_impl/embedding_storage.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import os
88
import shutil
99
import tempfile
10-
import uuid
1110

1211
import numpy as np
1312

@@ -18,45 +17,41 @@
1817
# 本模块通过"纯 ASCII 临时文件桥接"规避此问题。
1918

2019

20+
def _needs_bridge(path: str) -> bool:
21+
"""判断是否需要 ASCII 临时文件桥接。"""
22+
return os.name == "nt" and not path.isascii()
23+
24+
2125
def _safe_temp_dir() -> str:
2226
"""返回保证纯 ASCII 且可写的临时目录,用于 Faiss I/O 桥接。
2327
2428
优先级:
25-
1. %%SystemRoot%%\\Temp(Windows 系统临时目录)
29+
1. %SystemRoot%\\Temp(Windows 系统临时目录)
2630
2. tempfile.gettempdir()(当其为纯 ASCII 时)
27-
3. 当前工作目录
28-
4. 非 Windows 平台使用 tempfile.gettempdir()
31+
3. 非 Windows 平台使用 tempfile.gettempdir()
2932
"""
3033
if os.name == "nt":
31-
candidates = []
3234
root = os.environ.get("SystemRoot", r"C:\Windows")
33-
candidates.append(os.path.join(root, "Temp"))
34-
candidates.append(tempfile.gettempdir())
35-
try:
36-
candidates.append(os.getcwd())
37-
except OSError:
38-
pass
35+
temp_dir = os.path.join(root, "Temp")
36+
if temp_dir.isascii() and os.path.isdir(temp_dir) and os.access(temp_dir, os.W_OK):
37+
return temp_dir
3938

40-
for d in candidates:
41-
if d.isascii() and os.path.isdir(d) and os.access(d, os.W_OK):
42-
return d
39+
tmp = tempfile.gettempdir()
40+
if tmp.isascii():
41+
return tmp
4342

4443
raise OSError(
45-
f"_safe_temp_dir: 无法找到可写的纯 ASCII 临时目录。"
46-
f"检查过: {candidates}"
44+
"_safe_temp_dir: 无法找到可写的纯 ASCII 临时目录。"
45+
f" 检查过 SystemRoot\\Temp={temp_dir}, gettempdir={tmp}"
4746
)
4847

4948
return tempfile.gettempdir()
5049

5150

5251
def _make_temp_file(prefix: str) -> str:
53-
"""创建用于 Faiss 桥接的唯一临时文件,返回路径。"""
52+
"""创建用于 Faiss 桥接的临时文件,返回路径。"""
5453
safe_dir = _safe_temp_dir()
55-
fd, path = tempfile.mkstemp(
56-
prefix=f"{prefix}_{uuid.uuid4().hex[:8]}_",
57-
suffix=".faiss",
58-
dir=safe_dir,
59-
)
54+
fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=".faiss", dir=safe_dir)
6055
os.close(fd)
6156
return path
6257

@@ -78,7 +73,8 @@ def _read_index(path: str) -> "faiss.Index":
7873
try:
7974
return faiss.read_index(path)
8075
except RuntimeError:
81-
pass
76+
if not _needs_bridge(path):
77+
raise
8278

8379
tmp = _make_temp_file("_faiss_read")
8480
try:
@@ -98,6 +94,10 @@ def _write_index(index: "faiss.Index", path: str) -> None:
9894
if dirname:
9995
os.makedirs(dirname, exist_ok=True)
10096

97+
if not _needs_bridge(path):
98+
faiss.write_index(index, path)
99+
return
100+
101101
tmp = _make_temp_file("_faiss_write")
102102
try:
103103
faiss.write_index(index, tmp)
@@ -136,6 +136,15 @@ async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
136136
async def search(self, vector: np.ndarray, k: int) -> tuple:
137137
"""搜索向量"""
138138
assert self.index is not None, "FAISS index is not initialized."
139+
if vector.ndim != 1:
140+
raise ValueError(
141+
f"查询向量必须是 1 维, 实际维度: {vector.ndim}。"
142+
" 如需批量搜索请使用 Faiss 原生 API。"
143+
)
144+
if vector.shape[0] != self.dimension:
145+
raise ValueError(
146+
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
147+
)
139148
distances, indices = self.index.search(vector.reshape(1, -1), k)
140149
return distances, indices
141150

0 commit comments

Comments
 (0)