Skip to content

Commit 9cfd9e9

Browse files
zouyongheshuiping233
authored and committed
feat: optimize async io performance and benchmark coverage (#5737)
* docs: align deployment sections across multilingual readmes
* docs: normalize deployment punctuation and AUR guidance
* docs: fix french and russian deployment wording
* perf: optimize async io hot paths and extend benchmarks
* fix: address async io review feedback
* fix: address follow-up async io review comments
* fix: align base64 io error handling in message components
* fix: harden attachment export ids and tune io chunking
* fix: preserve best-effort attachment export and batch writes
* test: expand path conversion and helper coverage
1 parent 6d7b62d commit 9cfd9e9

File tree

9 files changed

+824
-66
lines changed

9 files changed

+824
-66
lines changed

astrbot/core/backup/exporter.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,10 @@ async def export_all(
161161
# 3. 导出配置文件
162162
if progress_callback:
163163
await progress_callback("config", 0, 100, "正在导出配置文件...")
164-
if await asyncio.to_thread(os.path.exists, self.config_path):
165-
config_content = await asyncio.to_thread(
166-
Path(self.config_path).read_text, encoding="utf-8"
167-
)
164+
config_content = await asyncio.to_thread(
165+
self._read_text_if_exists, self.config_path
166+
)
167+
if config_content is not None:
168168
zf.writestr("config/cmd_config.json", config_content)
169169
self._add_checksum("config/cmd_config.json", config_content)
170170
if progress_callback:
@@ -361,17 +361,44 @@ async def _export_attachments(
361361
self, zf: zipfile.ZipFile, attachments: list[dict]
362362
) -> None:
363363
"""导出附件文件"""
364+
await asyncio.to_thread(self._export_attachments_sync, zf, attachments)
365+
366+
def _export_attachments_sync(
    self, zf: zipfile.ZipFile, attachments: list[dict]
) -> None:
    """Write every attachment into the archive from a single worker thread.

    Batching the whole loop into one synchronous call avoids a thread hop
    per attachment. Export is best-effort: a failure on one attachment is
    logged (or silently skipped when the file is missing) and the loop
    moves on to the next entry.

    Args:
        zf: Open, writable zip archive that receives the files.
        attachments: Attachment records; the ``path`` and ``attachment_id``
            keys are read from each one.
    """
    for entry in attachments:
        file_path = entry.get("path", "")
        attachment_id = entry.get("attachment_id")
        try:
            if file_path and attachment_id:
                # The archive member is named after attachment_id and keeps
                # the source file's extension.
                _, ext = os.path.splitext(file_path)
                zf.write(file_path, f"files/attachments/{attachment_id}{ext}")
            elif file_path:
                # A path without an id cannot be given an archive name.
                logger.warning(
                    f"跳过附件导出:attachment_id 为空 (path={file_path})"
                )
            # Entries without a path are skipped without logging.
        except FileNotFoundError:
            # Same as the previous logic: a missing file is simply skipped.
            pass
        except OSError as e:
            logger.warning(
                f"导出附件失败 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}"
            )
        except Exception as e:
            logger.warning(
                f"导出附件时发生非预期错误,已跳过 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}"
            )
396+
397+
def _read_text_if_exists(self, file_path: str) -> str | None:
398+
"""Read text file when it exists in a single synchronous call."""
399+
if not os.path.exists(file_path):
400+
return None
401+
return Path(file_path).read_text(encoding="utf-8")
375402

376403
def _model_to_dict(self, record: Any) -> dict:
377404
"""将 SQLModel 实例转换为字典

astrbot/core/message/components.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@
4040
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
4141

4242

43+
def _absolute_path(path: str) -> str:
    """Return the absolute form of ``path`` via ``os.path.abspath``."""
    absolute = os.path.abspath(path)
    return absolute
45+
46+
47+
def _absolute_path_if_exists(path: str | None) -> str | None:
48+
if not path or not os.path.exists(path):
49+
return None
50+
return os.path.abspath(path)
51+
52+
4353
class ComponentType(StrEnum):
4454
# Basic Segment Types
4555
Plain = "Plain" # plain text message
@@ -159,17 +169,18 @@ async def convert_to_file_path(self) -> str:
159169
return self.file[8:]
160170
if self.file.startswith("http"):
161171
file_path = await download_image_by_url(self.file)
162-
return await asyncio.to_thread(os.path.abspath, file_path)
172+
return await asyncio.to_thread(_absolute_path, file_path)
163173
if self.file.startswith("base64://"):
164174
bs64_data = self.file.removeprefix("base64://")
165175
image_bytes = base64.b64decode(bs64_data)
166176
file_path = os.path.join(
167177
get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
168178
)
169179
await asyncio.to_thread(Path(file_path).write_bytes, image_bytes)
170-
return await asyncio.to_thread(os.path.abspath, file_path)
171-
if await asyncio.to_thread(os.path.exists, self.file):
172-
return await asyncio.to_thread(os.path.abspath, self.file)
180+
return await asyncio.to_thread(_absolute_path, file_path)
181+
local_path = await asyncio.to_thread(_absolute_path_if_exists, self.file)
182+
if local_path:
183+
return local_path
173184
raise Exception(f"not a valid file: {self.file}")
174185

175186
async def convert_to_base64(self) -> str:
@@ -189,10 +200,11 @@ async def convert_to_base64(self) -> str:
189200
bs64_data = await file_to_base64(file_path)
190201
elif self.file.startswith("base64://"):
191202
bs64_data = self.file
192-
elif await asyncio.to_thread(os.path.exists, self.file):
193-
bs64_data = await file_to_base64(self.file)
194203
else:
195-
raise Exception(f"not a valid file: {self.file}")
204+
try:
205+
bs64_data = await file_to_base64(self.file)
206+
except OSError as exc:
207+
raise Exception(f"not a valid file: {self.file}") from exc
196208
bs64_data = bs64_data.removeprefix("base64://")
197209
return bs64_data
198210

@@ -256,11 +268,15 @@ async def convert_to_file_path(self) -> str:
256268
get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}"
257269
)
258270
await download_file(url, video_file_path)
259-
if await asyncio.to_thread(os.path.exists, video_file_path):
260-
return await asyncio.to_thread(os.path.abspath, video_file_path)
271+
local_path = await asyncio.to_thread(
272+
_absolute_path_if_exists, video_file_path
273+
)
274+
if local_path:
275+
return local_path
261276
raise Exception(f"download failed: {url}")
262-
if await asyncio.to_thread(os.path.exists, url):
263-
return await asyncio.to_thread(os.path.abspath, url)
277+
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
278+
if local_path:
279+
return local_path
264280
raise Exception(f"not a valid file: {url}")
265281

266282
async def register_to_file_service(self) -> str:
@@ -449,17 +465,18 @@ async def convert_to_file_path(self) -> str:
449465
return url[8:]
450466
if url.startswith("http"):
451467
image_file_path = await download_image_by_url(url)
452-
return await asyncio.to_thread(os.path.abspath, image_file_path)
468+
return await asyncio.to_thread(_absolute_path, image_file_path)
453469
if url.startswith("base64://"):
454470
bs64_data = url.removeprefix("base64://")
455471
image_bytes = base64.b64decode(bs64_data)
456472
image_file_path = os.path.join(
457473
get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg"
458474
)
459475
await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes)
460-
return await asyncio.to_thread(os.path.abspath, image_file_path)
461-
if await asyncio.to_thread(os.path.exists, url):
462-
return await asyncio.to_thread(os.path.abspath, url)
476+
return await asyncio.to_thread(_absolute_path, image_file_path)
477+
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
478+
if local_path:
479+
return local_path
463480
raise Exception(f"not a valid file: {url}")
464481

465482
async def convert_to_base64(self) -> str:
@@ -480,10 +497,11 @@ async def convert_to_base64(self) -> str:
480497
bs64_data = await file_to_base64(image_file_path)
481498
elif url.startswith("base64://"):
482499
bs64_data = url
483-
elif await asyncio.to_thread(os.path.exists, url):
484-
bs64_data = await file_to_base64(url)
485500
else:
486-
raise Exception(f"not a valid file: {url}")
501+
try:
502+
bs64_data = await file_to_base64(url)
503+
except OSError as exc:
504+
raise Exception(f"not a valid file: {url}") from exc
487505
bs64_data = bs64_data.removeprefix("base64://")
488506
return bs64_data
489507

@@ -734,8 +752,9 @@ async def get_file(self, allow_return_url: bool = False) -> str:
734752
):
735753
path = path[1:]
736754

737-
if await asyncio.to_thread(os.path.exists, path):
738-
return await asyncio.to_thread(os.path.abspath, path)
755+
local_path = await asyncio.to_thread(_absolute_path_if_exists, path)
756+
if local_path:
757+
return local_path
739758

740759
if self.url:
741760
await self._download_file()
@@ -750,7 +769,7 @@ async def get_file(self, allow_return_url: bool = False) -> str:
750769
and path[2] == ":"
751770
):
752771
path = path[1:]
753-
return await asyncio.to_thread(os.path.abspath, path)
772+
return await asyncio.to_thread(_absolute_path, path)
754773

755774
return ""
756775

@@ -766,7 +785,7 @@ async def _download_file(self) -> None:
766785
filename = f"fileseg_{uuid.uuid4().hex}"
767786
file_path = os.path.join(download_dir, filename)
768787
await download_file(self.url, file_path)
769-
self.file_ = await asyncio.to_thread(os.path.abspath, file_path)
788+
self.file_ = await asyncio.to_thread(_absolute_path, file_path)
770789

771790
async def register_to_file_service(self) -> str:
772791
"""将文件注册到文件服务。

astrbot/core/utils/io.py

Lines changed: 66 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import uuid
1010
import zipfile
1111
from pathlib import Path
12+
from typing import BinaryIO
1213

1314
import aiohttp
1415
import certifi
@@ -18,6 +19,8 @@
1819
from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path
1920

2021
logger = logging.getLogger("astrbot")
22+
_DOWNLOAD_READ_CHUNK_SIZE = 64 * 1024
23+
_DOWNLOAD_FLUSH_THRESHOLD = 256 * 1024
2124

2225

2326
def on_error(func, path, exc_info) -> None:
@@ -134,29 +137,18 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
134137
if resp.status != 200:
135138
raise Exception(f"下载文件失败: {resp.status}")
136139
total_size = int(resp.headers.get("content-length", 0))
137-
downloaded_size = 0
138140
start_time = time.time()
139141
if show_progress:
140142
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
141143
file_obj = await asyncio.to_thread(Path(path).open, "wb")
142144
try:
143-
while True:
144-
chunk = await resp.content.read(8192)
145-
if not chunk:
146-
break
147-
await asyncio.to_thread(file_obj.write, chunk)
148-
downloaded_size += len(chunk)
149-
if show_progress:
150-
elapsed_time = (
151-
time.time() - start_time
152-
if time.time() - start_time > 0
153-
else 1
154-
)
155-
speed = downloaded_size / 1024 / elapsed_time # KB/s
156-
print(
157-
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
158-
end="",
159-
)
145+
await _stream_to_file(
146+
resp.content,
147+
file_obj,
148+
total_size=total_size,
149+
start_time=start_time,
150+
show_progress=show_progress,
151+
)
160152
finally:
161153
await asyncio.to_thread(file_obj.close)
162154
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
@@ -176,31 +168,73 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
176168
async with aiohttp.ClientSession() as session:
177169
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
178170
total_size = int(resp.headers.get("content-length", 0))
179-
downloaded_size = 0
180171
start_time = time.time()
181172
if show_progress:
182173
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
183174
file_obj = await asyncio.to_thread(Path(path).open, "wb")
184175
try:
185-
while True:
186-
chunk = await resp.content.read(8192)
187-
if not chunk:
188-
break
189-
await asyncio.to_thread(file_obj.write, chunk)
190-
downloaded_size += len(chunk)
191-
if show_progress:
192-
elapsed_time = time.time() - start_time
193-
speed = downloaded_size / 1024 / elapsed_time # KB/s
194-
print(
195-
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
196-
end="",
197-
)
176+
await _stream_to_file(
177+
resp.content,
178+
file_obj,
179+
total_size=total_size,
180+
start_time=start_time,
181+
show_progress=show_progress,
182+
)
198183
finally:
199184
await asyncio.to_thread(file_obj.close)
200185
if show_progress:
201186
print()
202187

203188

189+
async def _stream_to_file(
    stream: aiohttp.StreamReader,
    file_obj: BinaryIO,
    *,
    total_size: int,
    start_time: float,
    show_progress: bool,
) -> None:
    """Stream an HTTP response body into a file using buffered writes.

    Chunks are accumulated in memory and handed to a worker thread only
    once the buffer reaches the flush threshold, so the number of thread
    hops stays small. Whatever is still buffered when the loop ends,
    normally or through an exception/cancellation, is flushed before
    returning.

    Args:
        stream: Response body reader to consume until it returns no data.
        file_obj: Binary file object, already open for writing. It is not
            closed here.
        total_size: Expected size in bytes; a value <= 0 means unknown.
        start_time: ``time.time()`` timestamp of the download start, used
            for the speed display.
        show_progress: Whether to print progress after each chunk.
    """
    downloaded_size = 0
    known_total = total_size if total_size > 0 else None
    buffered = bytearray()

    try:
        while True:
            chunk = await stream.read(_DOWNLOAD_READ_CHUNK_SIZE)
            if not chunk:
                break

            buffered.extend(chunk)
            downloaded_size += len(chunk)

            if len(buffered) >= _DOWNLOAD_FLUSH_THRESHOLD:
                # Take ownership of the pending bytes BEFORE awaiting the
                # write. If this await is cancelled or the write fails, the
                # worker thread may already have the data; leaving it in
                # the buffer would make the final flush write it twice.
                payload = bytes(buffered)
                buffered.clear()
                await asyncio.to_thread(file_obj.write, payload)

            if show_progress:
                _print_download_progress(downloaded_size, known_total, start_time)
    finally:
        if buffered:
            # Ensure buffered data is flushed even on cancellation.
            await asyncio.shield(asyncio.to_thread(file_obj.write, bytes(buffered)))
221+
222+
223+
def _print_download_progress(
224+
downloaded_size: int, total_size: int | None, start_time: float
225+
) -> None:
226+
elapsed_time = max(time.time() - start_time, 1e-6)
227+
speed = downloaded_size / 1024 / elapsed_time # KB/s
228+
229+
if total_size:
230+
percent = downloaded_size / total_size
231+
msg = f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s"
232+
else:
233+
msg = f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s"
234+
235+
print(msg, end="")
236+
237+
204238
async def file_to_base64(file_path: str) -> str:
205239
data_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
206240
base64_str = base64.b64encode(data_bytes).decode()

tests/fixtures/helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dataclasses import dataclass, field
88
from pathlib import Path
99
from typing import Any, Callable
10+
from urllib.parse import urlparse
1011
from unittest.mock import AsyncMock, MagicMock
1112

1213
from astrbot.core.message.components import BaseMessageComponent
@@ -24,6 +25,25 @@ def __await__(self):
2425
return None
2526

2627

28+
def get_bound_tcp_port(site: Any) -> int:
    """Resolve the bound aiohttp TCP site port for tests.

    We prefer the public ``site.name`` first. Some aiohttp test setups with
    ephemeral ports may not expose a usable port there, so we fall back to
    ``site._server.sockets`` as a test-only compatibility path.

    Args:
        site: Object with an optional ``name`` URL string and an optional
            ``_server`` whose ``sockets`` are bound sockets.

    Returns:
        The positive port parsed from ``site.name``, otherwise the port
        reported by the first server socket.

    Raises:
        RuntimeError: If neither source yields a port.
    """
    try:
        port = urlparse(getattr(site, "name", "")).port
    except ValueError:
        # urlparse/.port raise ValueError for a malformed URL or a
        # non-numeric/out-of-range port; treat that as "no usable port" so
        # the socket fallback below still gets a chance.
        port = None
    if port is not None and port > 0:
        return port

    server = getattr(site, "_server", None)
    sockets = getattr(server, "sockets", None) if server else None
    if sockets:
        # getsockname() is indexed at 1 to get the port.
        return sockets[0].getsockname()[1]

    raise RuntimeError("Unable to resolve bound TCP port from aiohttp site")
45+
46+
2747
# ============================================================
2848
# 平台配置工厂
2949
# ============================================================

0 commit comments

Comments
 (0)