Skip to content

Commit f365501

Browse files
committed
perf: optimize async io hot paths and extend benchmarks
1 parent c25c558 commit f365501

7 files changed

Lines changed: 512 additions & 64 deletions

File tree

astrbot/core/backup/exporter.py

Lines changed: 26 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -161,10 +161,10 @@ async def export_all(
161161
# 3. 导出配置文件
162162
if progress_callback:
163163
await progress_callback("config", 0, 100, "正在导出配置文件...")
164-
if await asyncio.to_thread(os.path.exists, self.config_path):
165-
config_content = await asyncio.to_thread(
166-
Path(self.config_path).read_text, encoding="utf-8"
167-
)
164+
config_content = await asyncio.to_thread(
165+
self._read_text_if_exists, self.config_path
166+
)
167+
if config_content is not None:
168168
zf.writestr("config/cmd_config.json", config_content)
169169
self._add_checksum("config/cmd_config.json", config_content)
170170
if progress_callback:
@@ -361,18 +361,34 @@ async def _export_attachments(
361361
self, zf: zipfile.ZipFile, attachments: list[dict]
362362
) -> None:
363363
"""导出附件文件"""
364+
await asyncio.to_thread(self._export_attachments_sync, zf, attachments)
365+
366+
def _export_attachments_sync(
367+
self, zf: zipfile.ZipFile, attachments: list[dict]
368+
) -> None:
369+
"""在单个线程中批量导出附件,减少高频线程切换。"""
364370
for attachment in attachments:
365371
try:
366372
file_path = attachment.get("path", "")
367-
if file_path and await asyncio.to_thread(os.path.exists, file_path):
368-
# 使用 attachment_id 作为文件名
369-
attachment_id = attachment.get("attachment_id", "")
370-
ext = os.path.splitext(file_path)[1]
371-
archive_path = f"files/attachments/{attachment_id}{ext}"
372-
zf.write(file_path, archive_path)
373+
if not file_path:
374+
continue
375+
# 使用 attachment_id 作为文件名
376+
attachment_id = attachment.get("attachment_id", "")
377+
ext = os.path.splitext(file_path)[1]
378+
archive_path = f"files/attachments/{attachment_id}{ext}"
379+
zf.write(file_path, archive_path)
380+
except FileNotFoundError:
381+
# 和旧逻辑保持一致:缺失文件直接跳过。
382+
continue
373383
except Exception as e:
374384
logger.warning(f"导出附件失败: {e}")
375385

386+
def _read_text_if_exists(self, file_path: str) -> str | None:
387+
"""Read text file when it exists in a single synchronous call."""
388+
if not os.path.exists(file_path):
389+
return None
390+
return Path(file_path).read_text(encoding="utf-8")
391+
376392
def _model_to_dict(self, record: Any) -> dict:
377393
"""将 SQLModel 实例转换为字典
378394

astrbot/core/message/components.py

Lines changed: 41 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,16 @@
4040
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
4141

4242

43+
def _absolute_path(path: str) -> str:
44+
return os.path.abspath(path)
45+
46+
47+
def _absolute_path_if_exists(path: str | None) -> str | None:
48+
if not path or not os.path.exists(path):
49+
return None
50+
return os.path.abspath(path)
51+
52+
4353
class ComponentType(StrEnum):
4454
# Basic Segment Types
4555
Plain = "Plain" # plain text message
@@ -159,17 +169,18 @@ async def convert_to_file_path(self) -> str:
159169
return self.file[8:]
160170
if self.file.startswith("http"):
161171
file_path = await download_image_by_url(self.file)
162-
return await asyncio.to_thread(os.path.abspath, file_path)
172+
return await asyncio.to_thread(_absolute_path, file_path)
163173
if self.file.startswith("base64://"):
164174
bs64_data = self.file.removeprefix("base64://")
165175
image_bytes = base64.b64decode(bs64_data)
166176
file_path = os.path.join(
167177
get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
168178
)
169179
await asyncio.to_thread(Path(file_path).write_bytes, image_bytes)
170-
return await asyncio.to_thread(os.path.abspath, file_path)
171-
if await asyncio.to_thread(os.path.exists, self.file):
172-
return await asyncio.to_thread(os.path.abspath, self.file)
180+
return await asyncio.to_thread(_absolute_path, file_path)
181+
local_path = await asyncio.to_thread(_absolute_path_if_exists, self.file)
182+
if local_path:
183+
return local_path
173184
raise Exception(f"not a valid file: {self.file}")
174185

175186
async def convert_to_base64(self) -> str:
@@ -189,10 +200,11 @@ async def convert_to_base64(self) -> str:
189200
bs64_data = await file_to_base64(file_path)
190201
elif self.file.startswith("base64://"):
191202
bs64_data = self.file
192-
elif await asyncio.to_thread(os.path.exists, self.file):
193-
bs64_data = await file_to_base64(self.file)
194203
else:
195-
raise Exception(f"not a valid file: {self.file}")
204+
try:
205+
bs64_data = await file_to_base64(self.file)
206+
except OSError as exc:
207+
raise Exception(f"not a valid file: {self.file}") from exc
196208
bs64_data = bs64_data.removeprefix("base64://")
197209
return bs64_data
198210

@@ -256,11 +268,15 @@ async def convert_to_file_path(self) -> str:
256268
get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}"
257269
)
258270
await download_file(url, video_file_path)
259-
if await asyncio.to_thread(os.path.exists, video_file_path):
260-
return await asyncio.to_thread(os.path.abspath, video_file_path)
271+
local_path = await asyncio.to_thread(
272+
_absolute_path_if_exists, video_file_path
273+
)
274+
if local_path:
275+
return local_path
261276
raise Exception(f"download failed: {url}")
262-
if await asyncio.to_thread(os.path.exists, url):
263-
return await asyncio.to_thread(os.path.abspath, url)
277+
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
278+
if local_path:
279+
return local_path
264280
raise Exception(f"not a valid file: {url}")
265281

266282
async def register_to_file_service(self) -> str:
@@ -449,17 +465,18 @@ async def convert_to_file_path(self) -> str:
449465
return url[8:]
450466
if url.startswith("http"):
451467
image_file_path = await download_image_by_url(url)
452-
return await asyncio.to_thread(os.path.abspath, image_file_path)
468+
return await asyncio.to_thread(_absolute_path, image_file_path)
453469
if url.startswith("base64://"):
454470
bs64_data = url.removeprefix("base64://")
455471
image_bytes = base64.b64decode(bs64_data)
456472
image_file_path = os.path.join(
457473
get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg"
458474
)
459475
await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes)
460-
return await asyncio.to_thread(os.path.abspath, image_file_path)
461-
if await asyncio.to_thread(os.path.exists, url):
462-
return await asyncio.to_thread(os.path.abspath, url)
476+
return await asyncio.to_thread(_absolute_path, image_file_path)
477+
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
478+
if local_path:
479+
return local_path
463480
raise Exception(f"not a valid file: {url}")
464481

465482
async def convert_to_base64(self) -> str:
@@ -480,10 +497,11 @@ async def convert_to_base64(self) -> str:
480497
bs64_data = await file_to_base64(image_file_path)
481498
elif url.startswith("base64://"):
482499
bs64_data = url
483-
elif await asyncio.to_thread(os.path.exists, url):
484-
bs64_data = await file_to_base64(url)
485500
else:
486-
raise Exception(f"not a valid file: {url}")
501+
try:
502+
bs64_data = await file_to_base64(url)
503+
except OSError as exc:
504+
raise Exception(f"not a valid file: {url}") from exc
487505
bs64_data = bs64_data.removeprefix("base64://")
488506
return bs64_data
489507

@@ -734,8 +752,9 @@ async def get_file(self, allow_return_url: bool = False) -> str:
734752
):
735753
path = path[1:]
736754

737-
if await asyncio.to_thread(os.path.exists, path):
738-
return await asyncio.to_thread(os.path.abspath, path)
755+
local_path = await asyncio.to_thread(_absolute_path_if_exists, path)
756+
if local_path:
757+
return local_path
739758

740759
if self.url:
741760
await self._download_file()
@@ -750,7 +769,7 @@ async def get_file(self, allow_return_url: bool = False) -> str:
750769
and path[2] == ":"
751770
):
752771
path = path[1:]
753-
return await asyncio.to_thread(os.path.abspath, path)
772+
return await asyncio.to_thread(_absolute_path, path)
754773

755774
return ""
756775

@@ -766,7 +785,7 @@ async def _download_file(self) -> None:
766785
filename = f"fileseg_{uuid.uuid4().hex}"
767786
file_path = os.path.join(download_dir, filename)
768787
await download_file(self.url, file_path)
769-
self.file_ = await asyncio.to_thread(os.path.abspath, file_path)
788+
self.file_ = await asyncio.to_thread(_absolute_path, file_path)
770789

771790
async def register_to_file_service(self) -> str:
772791
"""将文件注册到文件服务。

astrbot/core/utils/io.py

Lines changed: 55 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -134,29 +134,18 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
134134
if resp.status != 200:
135135
raise Exception(f"下载文件失败: {resp.status}")
136136
total_size = int(resp.headers.get("content-length", 0))
137-
downloaded_size = 0
138137
start_time = time.time()
139138
if show_progress:
140139
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
141140
file_obj = await asyncio.to_thread(Path(path).open, "wb")
142141
try:
143-
while True:
144-
chunk = await resp.content.read(8192)
145-
if not chunk:
146-
break
147-
await asyncio.to_thread(file_obj.write, chunk)
148-
downloaded_size += len(chunk)
149-
if show_progress:
150-
elapsed_time = (
151-
time.time() - start_time
152-
if time.time() - start_time > 0
153-
else 1
154-
)
155-
speed = downloaded_size / 1024 / elapsed_time # KB/s
156-
print(
157-
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
158-
end="",
159-
)
142+
await _stream_to_file(
143+
resp.content,
144+
file_obj,
145+
total_size=total_size,
146+
start_time=start_time,
147+
show_progress=show_progress,
148+
)
160149
finally:
161150
await asyncio.to_thread(file_obj.close)
162151
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
@@ -176,31 +165,65 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
176165
async with aiohttp.ClientSession() as session:
177166
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
178167
total_size = int(resp.headers.get("content-length", 0))
179-
downloaded_size = 0
180168
start_time = time.time()
181169
if show_progress:
182170
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
183171
file_obj = await asyncio.to_thread(Path(path).open, "wb")
184172
try:
185-
while True:
186-
chunk = await resp.content.read(8192)
187-
if not chunk:
188-
break
189-
await asyncio.to_thread(file_obj.write, chunk)
190-
downloaded_size += len(chunk)
191-
if show_progress:
192-
elapsed_time = time.time() - start_time
193-
speed = downloaded_size / 1024 / elapsed_time # KB/s
194-
print(
195-
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
196-
end="",
197-
)
173+
await _stream_to_file(
174+
resp.content,
175+
file_obj,
176+
total_size=total_size,
177+
start_time=start_time,
178+
show_progress=show_progress,
179+
)
198180
finally:
199181
await asyncio.to_thread(file_obj.close)
200182
if show_progress:
201183
print()
202184

203185

186+
async def _stream_to_file(
187+
stream: aiohttp.StreamReader,
188+
file_obj,
189+
*,
190+
total_size: int,
191+
start_time: float,
192+
show_progress: bool,
193+
chunk_size: int = 8192,
194+
flush_threshold: int = 256 * 1024,
195+
) -> int:
196+
"""Stream HTTP response into file with buffered thread-offloaded writes."""
197+
downloaded_size = 0
198+
buffered = bytearray()
199+
progress_total = max(total_size, 1)
200+
201+
while True:
202+
chunk = await stream.read(chunk_size)
203+
if not chunk:
204+
break
205+
buffered.extend(chunk)
206+
downloaded_size += len(chunk)
207+
208+
if len(buffered) >= flush_threshold:
209+
chunk_to_write = bytes(buffered)
210+
buffered.clear()
211+
await asyncio.to_thread(file_obj.write, chunk_to_write)
212+
213+
if show_progress:
214+
elapsed_time = max(time.time() - start_time, 1e-6)
215+
speed = downloaded_size / 1024 / elapsed_time # KB/s
216+
print(
217+
f"\r下载进度: {downloaded_size / progress_total:.2%} 速度: {speed:.2f} KB/s",
218+
end="",
219+
)
220+
221+
if buffered:
222+
await asyncio.to_thread(file_obj.write, bytes(buffered))
223+
224+
return downloaded_size
225+
226+
204227
async def file_to_base64(file_path: str) -> str:
205228
data_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
206229
base64_str = base64.b64encode(data_bytes).decode()

0 commit comments

Comments
 (0)