77import os
88import shutil
99import tempfile
10+ import uuid
1011
1112import numpy as np
1213
@@ -27,27 +28,46 @@ def _safe_temp_dir() -> str:
2728 1. %SystemRoot%\\ Temp(Windows 系统临时目录,如 C:\\ WINDOWS\\ TEMP)
2829 2. tempfile.gettempdir()(当其为纯 ASCII 时)
2930 3. 当前工作目录
31+ 4. 非 Windows 平台使用 tempfile.gettempdir()
3032 """
31- candidates = []
32- # C:\Windows\Temp — 所有 Windows 安装的纯 ASCII 路径
33- root = os .environ .get ("SystemRoot" , r"C:\Windows" )
34- candidates .append (os .path .join (root , "Temp" ))
35- # 标准 tempdir(可能是 ASCII 也可能不是)
36- candidates .append (tempfile .gettempdir ())
37- # 兜底:当前目录
38- try :
39- candidates .append (os .getcwd ())
40- except OSError :
41- pass
42-
43- for d in candidates :
44- if d .isascii () and os .path .isdir (d ) and os .access (d , os .W_OK ):
45- return d
46-
47- # 终极兜底:创建新目录
48- fallback = r"C:\Windows\Temp"
49- os .makedirs (fallback , exist_ok = True )
50- return fallback
33+ # Windows 专属硬编码
34+ if os .name == "nt" :
35+ candidates = []
36+ root = os .environ .get ("SystemRoot" , r"C:\Windows" )
37+ candidates .append (os .path .join (root , "Temp" ))
38+ candidates .append (tempfile .gettempdir ())
39+ try :
40+ candidates .append (os .getcwd ())
41+ except OSError :
42+ pass
43+
44+ for d in candidates :
45+ if d .isascii () and os .path .isdir (d ) and os .access (d , os .W_OK ):
46+ return d
47+
48+ # 所有候选都不行时抛异常,不再静默兜底
49+ raise OSError (
50+ f"_safe_temp_dir: 无法找到可写的纯 ASCII 临时目录。"
51+ f"检查过: { candidates } "
52+ )
53+
54+ # 非 Windows(Linux / macOS):tempfile 足够
55+ return tempfile .gettempdir ()
56+
57+
58+ def _make_temp_file (prefix : str ) -> str :
59+ """Create a unique temporary ``.faiss`` file for the Faiss bridge and return its path.
60+
61+ Uses tempfile.mkstemp (in the directory from _safe_temp_dir) plus a UUID fragment in the name so concurrent threads/coroutines get distinct files.
62+ """
63+ safe_dir = _safe_temp_dir ()
64+ fd , path = tempfile .mkstemp (
65+ prefix = f"{ prefix } _{ uuid .uuid4 ().hex [:8 ]} _" ,
66+ suffix = ".faiss" ,
67+ dir = safe_dir ,
68+ )
69+ os .close (fd ) # only the path is handed back, so release the descriptor mkstemp opened
70+ return path
5171
5272
5373class EmbeddingStorage :
@@ -74,15 +94,16 @@ def _read_index(path: str) -> "faiss.Index":
7494 except RuntimeError :
7595 pass # 不吞其他异常类型
7696
77- tmp = os . path . join ( _safe_temp_dir (), f"_faiss_read_ { os . getpid () } .faiss " )
97+ tmp = _make_temp_file ( "_faiss_read " )
7898 try :
7999 shutil .copy2 (path , tmp )
80100 return faiss .read_index (tmp )
81101 finally :
82- try :
83- os .remove (tmp )
84- except OSError :
85- pass
102+ if os .path .exists (tmp ):
103+ try :
104+ os .remove (tmp )
105+ except OSError :
106+ pass
86107
87108 @staticmethod
88109 def _write_index (index : "faiss.Index" , path : str ) -> None :
@@ -94,18 +115,21 @@ def _write_index(index: "faiss.Index", path: str) -> None:
94115
95116 写入前先确保目标目录存在,防止 shutil.move 时目录缺失。
96117 """
97- os .makedirs (os .path .dirname (path ), exist_ok = True )
118+ dirname = os .path .dirname (path )
119+ if dirname :
120+ os .makedirs (dirname , exist_ok = True )
98121
99- tmp = os . path . join ( _safe_temp_dir (), f"_faiss_write_ { os . getpid () } .faiss " )
122+ tmp = _make_temp_file ( "_faiss_write " )
100123 try :
101124 faiss .write_index (index , tmp )
102125 # Windows 同盘 move 是原子 rename,跨盘则走 copy+delete
103126 shutil .move (tmp , path )
104127 finally :
105- try :
106- os .remove (tmp )
107- except OSError :
108- pass
128+ if os .path .exists (tmp ):
129+ try :
130+ os .remove (tmp )
131+ except OSError :
132+ pass
109133
110134 async def insert (self , vector : np .ndarray , id : int ) -> None :
111135 """插入向量
@@ -122,7 +146,7 @@ async def insert(self, vector: np.ndarray, id: int) -> None:
122146 raise ValueError (
123147 f"向量维度不匹配, 期望: { self .dimension } , 实际: { vector .shape [0 ]} " ,
124148 )
125- self .index .add_with_ids (vector .reshape (1 , - 1 ), np .array ([id ]))
149+ self .index .add_with_ids (vector .reshape (1 , - 1 ), np .array ([id ], dtype = np . int64 ))
126150 await self .save_index ()
127151
128152 async def insert_batch (self , vectors : np .ndarray , ids : list [int ]) -> None :
@@ -140,7 +164,7 @@ async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
140164 raise ValueError (
141165 f"向量维度不匹配, 期望: { self .dimension } , 实际: { vectors .shape [1 ]} " ,
142166 )
143- self .index .add_with_ids (vectors , np .array (ids ))
167+ self .index .add_with_ids (vectors , np .array (ids , dtype = np . int64 ))
144168 await self .save_index ()
145169
146170 async def search (self , vector : np .ndarray , k : int ) -> tuple :
@@ -154,8 +178,9 @@ async def search(self, vector: np.ndarray, k: int) -> tuple:
154178
155179 """
156180 assert self .index is not None , "FAISS index is not initialized."
157- faiss .normalize_L2 (vector )
158- distances , indices = self .index .search (vector , k )
181+ # IndexFlatL2 是欧氏距离索引,不进行归一化,
182+ # 确保与 insert/insert_batch 的一致性
183+ distances , indices = self .index .search (vector .reshape (1 , - 1 ), k )
159184 return distances , indices
160185
161186 async def delete (self , ids : list [int ]) -> None :
@@ -172,6 +197,6 @@ async def delete(self, ids: list[int]) -> None:
172197
173198 async def save_index (self ) -> None :
174199 """Save the index (compatible with Windows paths that contain non-ASCII characters)."""
175- if self .index is None :
200+ if self .index is None or not self . path : # nothing to persist without an index or a target path
176201 return
177202 self ._write_index (self .index , self .path )
0 commit comments