Skip to content

Commit 374dd8d

Browse files
committed
fix: Faiss read/write failure on Windows with non-ASCII user paths
1 parent 3290d75 commit 374dd8d

1 file changed

Lines changed: 120 additions & 13 deletions

File tree

astrbot/core/db/vec_db/faiss_impl/embedding_storage.py

Lines changed: 120 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,132 @@
55
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。",
66
)
77
import os
8+
import shutil
9+
import tempfile
10+
import uuid
811

912
import numpy as np
1013

1114

15+
# ── Faiss C++ fopen() 在 Windows 上使用 ANSI codepage ──
16+
# Python 传给 Faiss 的路径是 UTF-8 字节,但 Windows fopen 期望 ANSI 编码,
17+
# 导致含非 ASCII 字符的路径(如 C:\Users\中文用户名\...)被解读为乱码而失败。
18+
# 本模块通过"纯 ASCII 临时文件桥接"规避此问题。
19+
#
20+
# tempfile.gettempdir() 可能返回含中文用户的路径(取决于 TEMP 环境变量),
21+
# 所以 _safe_temp_dir() 硬编码一个保证纯 ASCII 且可写的目录。
22+
23+
24+
def _safe_temp_dir() -> str:
25+
"""返回保证纯 ASCII 且可写的临时目录,用于 Faiss I/O 桥接。
26+
27+
优先级:
28+
1. %SystemRoot%\\Temp(Windows 系统临时目录,如 C:\\WINDOWS\\TEMP)
29+
2. tempfile.gettempdir()(当其为纯 ASCII 时)
30+
3. 当前工作目录
31+
4. 非 Windows 平台使用 tempfile.gettempdir()
32+
"""
33+
# Windows 专属硬编码
34+
if os.name == "nt":
35+
candidates = []
36+
root = os.environ.get("SystemRoot", r"C:\Windows")
37+
candidates.append(os.path.join(root, "Temp"))
38+
candidates.append(tempfile.gettempdir())
39+
try:
40+
candidates.append(os.getcwd())
41+
except OSError:
42+
pass
43+
44+
for d in candidates:
45+
if d.isascii() and os.path.isdir(d) and os.access(d, os.W_OK):
46+
return d
47+
48+
# 所有候选都不行时抛异常,不再静默兜底
49+
raise OSError(
50+
f"_safe_temp_dir: 无法找到可写的纯 ASCII 临时目录。"
51+
f"检查过: {candidates}"
52+
)
53+
54+
# 非 Windows(Linux / macOS):tempfile 足够
55+
return tempfile.gettempdir()
56+
57+
58+
def _make_temp_file(prefix: str) -> str:
59+
"""创建用于 Faiss 桥接的唯一临时文件,返回路径。
60+
61+
使用 tempfile.mkstemp + UUID 保证多线程/多协程并发安全。
62+
"""
63+
safe_dir = _safe_temp_dir()
64+
fd, path = tempfile.mkstemp(
65+
prefix=f"{prefix}_{uuid.uuid4().hex[:8]}_",
66+
suffix=".faiss",
67+
dir=safe_dir,
68+
)
69+
os.close(fd)
70+
return path
71+
72+
1273
class EmbeddingStorage:
1374
def __init__(self, dimension: int, path: str | None = None) -> None:
1475
self.dimension = dimension
1576
self.path = path
1677
self.index = None
1778
if path and os.path.exists(path):
18-
self.index = faiss.read_index(path)
79+
self.index = self._read_index(path)
1980
else:
2081
base_index = faiss.IndexFlatL2(dimension)
2182
self.index = faiss.IndexIDMap(base_index)
2283

84+
@staticmethod
85+
def _read_index(path: str) -> "faiss.Index":
86+
"""读取 Faiss 索引,兼容含非 ASCII 字符的 Windows 路径。
87+
88+
Faiss C++ fopen() 使用 ANSI codepage,无法处理 Python 传入的
89+
UTF-8 编码非 ASCII 路径。应对:先尝试直接读;失败则用 Python
90+
shutil.copy2 复制到纯 ASCII 临时文件再读。
91+
"""
92+
try:
93+
return faiss.read_index(path)
94+
except RuntimeError:
95+
pass # 不吞其他异常类型
96+
97+
tmp = _make_temp_file("_faiss_read")
98+
try:
99+
shutil.copy2(path, tmp)
100+
return faiss.read_index(tmp)
101+
finally:
102+
if os.path.exists(tmp):
103+
try:
104+
os.remove(tmp)
105+
except OSError:
106+
pass
107+
108+
@staticmethod
109+
def _write_index(index: "faiss.Index", path: str) -> None:
110+
"""保存 Faiss 索引,兼容含非 ASCII 字符的 Windows 路径。
111+
112+
先写入纯 ASCII 临时文件,再用 Python shutil.move 移动到位。
113+
Python 文件操作使用 Windows wide-char API (CreateFileW),
114+
正确支持 Unicode 路径。
115+
116+
写入前先确保目标目录存在,防止 shutil.move 时目录缺失。
117+
"""
118+
dirname = os.path.dirname(path)
119+
if dirname:
120+
os.makedirs(dirname, exist_ok=True)
121+
122+
tmp = _make_temp_file("_faiss_write")
123+
try:
124+
faiss.write_index(index, tmp)
125+
# Windows 同盘 move 是原子 rename,跨盘则走 copy+delete
126+
shutil.move(tmp, path)
127+
finally:
128+
if os.path.exists(tmp):
129+
try:
130+
os.remove(tmp)
131+
except OSError:
132+
pass
133+
23134
async def insert(self, vector: np.ndarray, id: int) -> None:
24135
"""插入向量
25136
@@ -35,7 +146,7 @@ async def insert(self, vector: np.ndarray, id: int) -> None:
35146
raise ValueError(
36147
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
37148
)
38-
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
149+
self.index.add_with_ids(vector.reshape(1, -1), np.array([id], dtype=np.int64))
39150
await self.save_index()
40151

41152
async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
@@ -53,7 +164,7 @@ async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
53164
raise ValueError(
54165
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
55166
)
56-
self.index.add_with_ids(vectors, np.array(ids))
167+
self.index.add_with_ids(vectors, np.array(ids, dtype=np.int64))
57168
await self.save_index()
58169

59170
async def search(self, vector: np.ndarray, k: int) -> tuple:
@@ -67,8 +178,9 @@ async def search(self, vector: np.ndarray, k: int) -> tuple:
67178
68179
"""
69180
assert self.index is not None, "FAISS index is not initialized."
70-
faiss.normalize_L2(vector)
71-
distances, indices = self.index.search(vector, k)
181+
# IndexFlatL2 是欧氏距离索引,不进行归一化,
182+
# 确保与 insert/insert_batch 的一致性
183+
distances, indices = self.index.search(vector.reshape(1, -1), k)
72184
return distances, indices
73185

74186
async def delete(self, ids: list[int]) -> None:
@@ -84,12 +196,7 @@ async def delete(self, ids: list[int]) -> None:
84196
await self.save_index()
85197

86198
async def save_index(self) -> None:
87-
"""保存索引
88-
89-
Args:
90-
path (str): 保存索引的路径
91-
92-
"""
93-
if self.index is None:
199+
"""保存索引(兼容含非 ASCII 字符的 Windows 路径)"""
200+
if self.index is None or not self.path:
94201
return
95-
faiss.write_index(self.index, self.path)
202+
self._write_index(self.index, self.path)

0 commit comments

Comments
 (0)