Skip to content

Commit 55dfaf3

Browse files
committed
fix: address async io review feedback
1 parent f365501 commit 55dfaf3

3 files changed

Lines changed: 46 additions & 13 deletions

File tree

astrbot/core/backup/exporter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,20 +368,22 @@ def _export_attachments_sync(
368368
) -> None:
369369
"""在单个线程中批量导出附件,减少高频线程切换。"""
370370
for attachment in attachments:
371+
file_path = attachment.get("path", "")
372+
attachment_id = attachment.get("attachment_id", "")
371373
try:
372-
file_path = attachment.get("path", "")
373374
if not file_path:
374375
continue
375376
# 使用 attachment_id 作为文件名
376-
attachment_id = attachment.get("attachment_id", "")
377377
ext = os.path.splitext(file_path)[1]
378378
archive_path = f"files/attachments/{attachment_id}{ext}"
379379
zf.write(file_path, archive_path)
380380
except FileNotFoundError:
381381
# 和旧逻辑保持一致:缺失文件直接跳过。
382382
continue
383-
except Exception as e:
384-
logger.warning(f"导出附件失败: {e}")
383+
except OSError as e:
384+
logger.warning(
385+
f"导出附件失败 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}"
386+
)
385387

386388
def _read_text_if_exists(self, file_path: str) -> str | None:
387389
"""Read text file when it exists in a single synchronous call."""

astrbot/core/utils/io.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import uuid
1010
import zipfile
1111
from pathlib import Path
12+
from typing import BinaryIO
1213

1314
import aiohttp
1415
import certifi
@@ -185,18 +186,18 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
185186

186187
async def _stream_to_file(
187188
stream: aiohttp.StreamReader,
188-
file_obj,
189+
file_obj: BinaryIO,
189190
*,
190191
total_size: int,
191192
start_time: float,
192193
show_progress: bool,
193194
chunk_size: int = 8192,
194195
flush_threshold: int = 256 * 1024,
195-
) -> int:
196+
) -> None:
196197
"""Stream HTTP response into file with buffered thread-offloaded writes."""
197198
downloaded_size = 0
198199
buffered = bytearray()
199-
progress_total = max(total_size, 1)
200+
progress_total = total_size if total_size > 0 else None
200201

201202
while True:
202203
chunk = await stream.read(chunk_size)
@@ -213,16 +214,21 @@ async def _stream_to_file(
213214
if show_progress:
214215
elapsed_time = max(time.time() - start_time, 1e-6)
215216
speed = downloaded_size / 1024 / elapsed_time # KB/s
216-
print(
217-
f"\r下载进度: {downloaded_size / progress_total:.2%} 速度: {speed:.2f} KB/s",
218-
end="",
219-
)
217+
if progress_total:
218+
percent = downloaded_size / progress_total
219+
print(
220+
f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s",
221+
end="",
222+
)
223+
else:
224+
print(
225+
f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s",
226+
end="",
227+
)
220228

221229
if buffered:
222230
await asyncio.to_thread(file_obj.write, bytes(buffered))
223231

224-
return downloaded_size
225-
226232

227233
async def file_to_base64(file_path: str) -> str:
228234
data_bytes = await asyncio.to_thread(Path(file_path).read_bytes)

tests/unit/test_message_components_paths.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import os
23

34
import pytest
@@ -52,3 +53,27 @@ async def test_record_convert_to_base64_raises_on_missing_file(tmp_path):
5253

5354
with pytest.raises(Exception, match="not a valid file"):
5455
await record.convert_to_base64()
56+
57+
58+
@pytest.mark.asyncio
59+
async def test_image_convert_to_base64_reads_existing_local_file(tmp_path):
60+
raw = b"image-bytes"
61+
file_path = tmp_path / "exists_image.bin"
62+
file_path.write_bytes(raw)
63+
64+
image = Image(file=str(file_path))
65+
encoded = await image.convert_to_base64()
66+
67+
assert base64.b64decode(encoded) == raw
68+
69+
70+
@pytest.mark.asyncio
71+
async def test_record_convert_to_base64_reads_existing_local_file(tmp_path):
72+
raw = b"record-bytes"
73+
file_path = tmp_path / "exists_record.bin"
74+
file_path.write_bytes(raw)
75+
76+
record = Record(file=str(file_path))
77+
encoded = await record.convert_to_base64()
78+
79+
assert base64.b64decode(encoded) == raw

0 commit comments

Comments
 (0)