Skip to content

Commit 5106dcf

Browse files
committed
fix: Faiss read/write on Windows with non-ASCII paths
Bridge Faiss C++ fopen() ANSI codepage limitation through pure ASCII temp files using Python shutil. Also fix dtype=np.int64 for IDs, vector.reshape for search, and remove incorrect normalize_L2 on IndexFlatL2.
1 parent ff28eca commit 5106dcf

1 file changed

Lines changed: 109 additions & 49 deletions

File tree

Lines changed: 109 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,155 @@
11
try:
22
import faiss
3-
except ModuleNotFoundError:
3+
except ImportError as e:
44
raise ImportError(
55
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。",
66
)
77
import os
8+
import shutil
9+
import tempfile
10+
import uuid
811

912
import numpy as np
1013

1114

15+
# ── Faiss C++ fopen() 在 Windows 上使用 ANSI codepage ──
16+
# Python 传给 Faiss 的路径是 UTF-8 字节,但 Windows fopen 期望 ANSI 编码,
17+
# 导致含非 ASCII 字符的路径(如 C:\Users\中文用户名\...)被解读为乱码而失败。
18+
# 本模块通过"纯 ASCII 临时文件桥接"规避此问题。
19+
20+
21+
def _safe_temp_dir() -> str:
22+
"""返回保证纯 ASCII 且可写的临时目录,用于 Faiss I/O 桥接。
23+
24+
优先级:
25+
1. %%SystemRoot%%\\Temp(Windows 系统临时目录)
26+
2. tempfile.gettempdir()(当其为纯 ASCII 时)
27+
3. 当前工作目录
28+
4. 非 Windows 平台使用 tempfile.gettempdir()
29+
"""
30+
if os.name == "nt":
31+
candidates = []
32+
root = os.environ.get("SystemRoot", r"C:\Windows")
33+
candidates.append(os.path.join(root, "Temp"))
34+
candidates.append(tempfile.gettempdir())
35+
try:
36+
candidates.append(os.getcwd())
37+
except OSError:
38+
pass
39+
40+
for d in candidates:
41+
if d.isascii() and os.path.isdir(d) and os.access(d, os.W_OK):
42+
return d
43+
44+
raise OSError(
45+
f"_safe_temp_dir: 无法找到可写的纯 ASCII 临时目录。"
46+
f"检查过: {candidates}"
47+
)
48+
49+
return tempfile.gettempdir()
50+
51+
52+
def _make_temp_file(prefix: str) -> str:
53+
"""创建用于 Faiss 桥接的唯一临时文件,返回路径。"""
54+
safe_dir = _safe_temp_dir()
55+
fd, path = tempfile.mkstemp(
56+
prefix=f"{prefix}_{uuid.uuid4().hex[:8]}_",
57+
suffix=".faiss",
58+
dir=safe_dir,
59+
)
60+
os.close(fd)
61+
return path
62+
63+
1264
class EmbeddingStorage:
1365
def __init__(self, dimension: int, path: str | None = None) -> None:
1466
self.dimension = dimension
1567
self.path = path
1668
self.index = None
1769
if path and os.path.exists(path):
18-
self.index = faiss.read_index(path)
70+
self.index = self._read_index(path)
1971
else:
2072
base_index = faiss.IndexFlatL2(dimension)
2173
self.index = faiss.IndexIDMap(base_index)
2274

23-
async def insert(self, vector: np.ndarray, id: int) -> None:
24-
"""插入向量
25-
26-
Args:
27-
vector (np.ndarray): 要插入的向量
28-
id (int): 向量的ID
29-
Raises:
30-
ValueError: 如果向量的维度与存储的维度不匹配
75+
@staticmethod
76+
def _read_index(path: str) -> "faiss.Index":
77+
"""读取 Faiss 索引,兼容含非 ASCII 字符的 Windows 路径。"""
78+
try:
79+
return faiss.read_index(path)
80+
except RuntimeError:
81+
pass
82+
83+
tmp = _make_temp_file("_faiss_read")
84+
try:
85+
shutil.copy2(path, tmp)
86+
return faiss.read_index(tmp)
87+
finally:
88+
if os.path.exists(tmp):
89+
try:
90+
os.remove(tmp)
91+
except OSError:
92+
pass
93+
94+
@staticmethod
95+
def _write_index(index: "faiss.Index", path: str) -> None:
96+
"""保存 Faiss 索引,兼容含非 ASCII 字符的 Windows 路径。"""
97+
dirname = os.path.dirname(path)
98+
if dirname:
99+
os.makedirs(dirname, exist_ok=True)
100+
101+
tmp = _make_temp_file("_faiss_write")
102+
try:
103+
faiss.write_index(index, tmp)
104+
shutil.move(tmp, path)
105+
finally:
106+
if os.path.exists(tmp):
107+
try:
108+
os.remove(tmp)
109+
except OSError:
110+
pass
31111

32-
"""
112+
async def insert(self, vector: np.ndarray, id: int) -> None:
113+
"""插入向量"""
33114
assert self.index is not None, "FAISS index is not initialized."
34115
if vector.shape[0] != self.dimension:
35116
raise ValueError(
36117
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
37118
)
38-
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
119+
self.index.add_with_ids(vector.reshape(1, -1), np.array([id], dtype=np.int64))
39120
await self.save_index()
40121

41122
async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
42-
"""批量插入向量
43-
44-
Args:
45-
vectors (np.ndarray): 要插入的向量数组
46-
ids (list[int]): 向量的ID列表
47-
Raises:
48-
ValueError: 如果向量的维度与存储的维度不匹配
49-
50-
"""
123+
"""批量插入向量"""
51124
assert self.index is not None, "FAISS index is not initialized."
125+
if len(vectors.shape) != 2:
126+
raise ValueError(
127+
f"向量必须是二维数组, 当前维度: {len(vectors.shape)}",
128+
)
52129
if vectors.shape[1] != self.dimension:
53130
raise ValueError(
54131
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
55132
)
56-
self.index.add_with_ids(vectors, np.array(ids))
133+
self.index.add_with_ids(vectors, np.array(ids, dtype=np.int64))
57134
await self.save_index()
58135

59136
async def search(self, vector: np.ndarray, k: int) -> tuple:
60-
"""搜索最相似的向量
61-
62-
Args:
63-
vector (np.ndarray): 查询向量
64-
k (int): 返回的最相似向量的数量
65-
Returns:
66-
tuple: (距离, 索引)
67-
68-
"""
137+
"""搜索向量"""
69138
assert self.index is not None, "FAISS index is not initialized."
70-
faiss.normalize_L2(vector)
71-
distances, indices = self.index.search(vector, k)
139+
distances, indices = self.index.search(vector.reshape(1, -1), k)
72140
return distances, indices
73141

74142
async def delete(self, ids: list[int]) -> None:
75-
"""删除向量
76-
77-
Args:
78-
ids (list[int]): 要删除的向量ID列表
79-
80-
"""
143+
"""删除向量"""
81144
assert self.index is not None, "FAISS index is not initialized."
82-
id_array = np.array(ids, dtype=np.int64)
83-
self.index.remove_ids(id_array)
145+
try:
146+
self.index.remove_ids(np.array(ids, dtype=np.int64))
147+
except RuntimeError:
148+
pass
84149
await self.save_index()
85150

86151
async def save_index(self) -> None:
87-
"""保存索引
88-
89-
Args:
90-
path (str): 保存索引的路径
91-
92-
"""
93-
if self.index is None:
152+
"""保存索引(兼容含非 ASCII 字符的 Windows 路径)"""
153+
if self.index is None or not self.path:
94154
return
95-
faiss.write_index(self.index, self.path)
155+
self._write_index(self.index, self.path)

0 commit comments

Comments
 (0)