Skip to content

Commit 25dc3d6

Browse files
committed
fix: address follow-up async io review comments
1 parent 55dfaf3 commit 25dc3d6

5 files changed

Lines changed: 40 additions & 35 deletions

File tree

astrbot/core/message/components.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ async def convert_to_base64(self) -> str:
203203
else:
204204
try:
205205
bs64_data = await file_to_base64(self.file)
206-
except OSError as exc:
206+
except (FileNotFoundError, IsADirectoryError) as exc:
207207
raise Exception(f"not a valid file: {self.file}") from exc
208208
bs64_data = bs64_data.removeprefix("base64://")
209209
return bs64_data
@@ -500,7 +500,7 @@ async def convert_to_base64(self) -> str:
500500
else:
501501
try:
502502
bs64_data = await file_to_base64(url)
503-
except OSError as exc:
503+
except (FileNotFoundError, IsADirectoryError) as exc:
504504
raise Exception(f"not a valid file: {url}") from exc
505505
bs64_data = bs64_data.removeprefix("base64://")
506506
return bs64_data

astrbot/core/utils/io.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -191,43 +191,35 @@ async def _stream_to_file(
191191
total_size: int,
192192
start_time: float,
193193
show_progress: bool,
194-
chunk_size: int = 8192,
195-
flush_threshold: int = 256 * 1024,
196194
) -> None:
197-
"""Stream HTTP response into file with buffered thread-offloaded writes."""
195+
"""Stream HTTP response into file with thread-offloaded writes."""
198196
downloaded_size = 0
199-
buffered = bytearray()
200-
progress_total = total_size if total_size > 0 else None
197+
known_total = total_size if total_size > 0 else None
201198

202199
while True:
203-
chunk = await stream.read(chunk_size)
200+
chunk = await stream.read(8192)
204201
if not chunk:
205202
break
206-
buffered.extend(chunk)
203+
await asyncio.to_thread(file_obj.write, chunk)
204+
207205
downloaded_size += len(chunk)
206+
if show_progress:
207+
_print_download_progress(downloaded_size, known_total, start_time)
208208

209-
if len(buffered) >= flush_threshold:
210-
chunk_to_write = bytes(buffered)
211-
buffered.clear()
212-
await asyncio.to_thread(file_obj.write, chunk_to_write)
213209

214-
if show_progress:
215-
elapsed_time = max(time.time() - start_time, 1e-6)
216-
speed = downloaded_size / 1024 / elapsed_time # KB/s
217-
if progress_total:
218-
percent = downloaded_size / progress_total
219-
print(
220-
f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s",
221-
end="",
222-
)
223-
else:
224-
print(
225-
f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s",
226-
end="",
227-
)
210+
def _print_download_progress(
211+
downloaded_size: int, total_size: int | None, start_time: float
212+
) -> None:
213+
elapsed_time = max(time.time() - start_time, 1e-6)
214+
speed = downloaded_size / 1024 / elapsed_time # KB/s
215+
216+
if total_size:
217+
percent = downloaded_size / total_size
218+
msg = f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s"
219+
else:
220+
msg = f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s"
228221

229-
if buffered:
230-
await asyncio.to_thread(file_obj.write, bytes(buffered))
222+
print(msg, end="")
231223

232224

233225
async def file_to_base64(file_path: str) -> str:

tests/fixtures/helpers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dataclasses import dataclass, field
88
from pathlib import Path
99
from typing import Any, Callable
10+
from urllib.parse import urlparse
1011
from unittest.mock import AsyncMock, MagicMock
1112

1213
from astrbot.core.message.components import BaseMessageComponent
@@ -24,6 +25,20 @@ def __await__(self):
2425
return None
2526

2627

28+
def get_bound_tcp_port(site: Any) -> int:
29+
"""Resolve bound aiohttp TCP site port with public API first."""
30+
parsed = urlparse(getattr(site, "name", ""))
31+
if parsed.port is not None and parsed.port > 0:
32+
return parsed.port
33+
34+
server = getattr(site, "_server", None)
35+
sockets = getattr(server, "sockets", None) if server else None
36+
if sockets:
37+
return sockets[0].getsockname()[1]
38+
39+
raise RuntimeError("Unable to resolve bound TCP port from aiohttp site")
40+
41+
2742
# ============================================================
2843
# 平台配置工厂
2944
# ============================================================

tests/performance/test_benchmarks.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from astrbot.core.backup.exporter import AstrBotExporter
2727
from astrbot.core.message.components import File, Image, Record
2828
from astrbot.core.utils.io import download_file, file_to_base64
29+
from tests.fixtures.helpers import get_bound_tcp_port
2930

3031

3132
@dataclass(slots=True)
@@ -155,9 +156,7 @@ async def handle_download(_request):
155156
await runner.setup()
156157
site = web.TCPSite(runner, "127.0.0.1", 0)
157158
await site.start()
158-
sockets = site._server.sockets # noqa: SLF001
159-
assert sockets
160-
port = sockets[0].getsockname()[1]
159+
port = get_bound_tcp_port(site)
161160
download_url = f"http://127.0.0.1:{port}/download.bin"
162161

163162
async def bench_file_to_base64() -> None:

tests/unit/test_io_download_file.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from astrbot.core.utils.io import download_file
5+
from tests.fixtures.helpers import get_bound_tcp_port
56

67

78
@pytest.mark.asyncio
@@ -19,9 +20,7 @@ async def handle(_request):
1920
await site.start()
2021

2122
try:
22-
sockets = site._server.sockets # noqa: SLF001
23-
assert sockets
24-
port = sockets[0].getsockname()[1]
23+
port = get_bound_tcp_port(site)
2524
url = f"http://127.0.0.1:{port}/file.bin"
2625

2726
out = tmp_path / "downloaded.bin"

0 commit comments

Comments
 (0)