-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
feat: optimize async io performance and benchmark coverage #5737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
5547cda
2a3868f
5697792
c25c558
f365501
55dfaf3
25dc3d6
553e5b7
0321474
92e5280
8bf20db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,16 @@ | |
| from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 | ||
|
|
||
|
|
||
| def _absolute_path(path: str) -> str: | ||
| return os.path.abspath(path) | ||
|
|
||
|
|
||
| def _absolute_path_if_exists(path: str | None) -> str | None: | ||
| if not path or not os.path.exists(path): | ||
| return None | ||
| return os.path.abspath(path) | ||
|
|
||
|
|
||
| class ComponentType(StrEnum): | ||
| # Basic Segment Types | ||
| Plain = "Plain" # plain text message | ||
|
|
@@ -159,17 +169,18 @@ async def convert_to_file_path(self) -> str: | |
| return self.file[8:] | ||
| if self.file.startswith("http"): | ||
| file_path = await download_image_by_url(self.file) | ||
| return await asyncio.to_thread(os.path.abspath, file_path) | ||
| return await asyncio.to_thread(_absolute_path, file_path) | ||
| if self.file.startswith("base64://"): | ||
| bs64_data = self.file.removeprefix("base64://") | ||
| image_bytes = base64.b64decode(bs64_data) | ||
| file_path = os.path.join( | ||
| get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg" | ||
| ) | ||
| await asyncio.to_thread(Path(file_path).write_bytes, image_bytes) | ||
| return await asyncio.to_thread(os.path.abspath, file_path) | ||
| if await asyncio.to_thread(os.path.exists, self.file): | ||
| return await asyncio.to_thread(os.path.abspath, self.file) | ||
| return await asyncio.to_thread(_absolute_path, file_path) | ||
| local_path = await asyncio.to_thread(_absolute_path_if_exists, self.file) | ||
| if local_path: | ||
| return local_path | ||
|
Comment on lines
+181
to
+183
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The message components ( |
||
| raise Exception(f"not a valid file: {self.file}") | ||
|
|
||
| async def convert_to_base64(self) -> str: | ||
|
|
@@ -189,10 +200,11 @@ async def convert_to_base64(self) -> str: | |
| bs64_data = await file_to_base64(file_path) | ||
| elif self.file.startswith("base64://"): | ||
| bs64_data = self.file | ||
| elif await asyncio.to_thread(os.path.exists, self.file): | ||
| bs64_data = await file_to_base64(self.file) | ||
| else: | ||
| raise Exception(f"not a valid file: {self.file}") | ||
| try: | ||
|
sourcery-ai[bot] marked this conversation as resolved.
|
||
| bs64_data = await file_to_base64(self.file) | ||
| except OSError as exc: | ||
| raise Exception(f"not a valid file: {self.file}") from exc | ||
| bs64_data = bs64_data.removeprefix("base64://") | ||
| return bs64_data | ||
|
|
||
|
|
@@ -256,11 +268,15 @@ async def convert_to_file_path(self) -> str: | |
| get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}" | ||
| ) | ||
| await download_file(url, video_file_path) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| if await asyncio.to_thread(os.path.exists, video_file_path): | ||
| return await asyncio.to_thread(os.path.abspath, video_file_path) | ||
| local_path = await asyncio.to_thread( | ||
| _absolute_path_if_exists, video_file_path | ||
| ) | ||
| if local_path: | ||
| return local_path | ||
| raise Exception(f"download failed: {url}") | ||
| if await asyncio.to_thread(os.path.exists, url): | ||
| return await asyncio.to_thread(os.path.abspath, url) | ||
| local_path = await asyncio.to_thread(_absolute_path_if_exists, url) | ||
| if local_path: | ||
| return local_path | ||
| raise Exception(f"not a valid file: {url}") | ||
|
|
||
| async def register_to_file_service(self) -> str: | ||
|
|
@@ -449,17 +465,18 @@ async def convert_to_file_path(self) -> str: | |
| return url[8:] | ||
| if url.startswith("http"): | ||
| image_file_path = await download_image_by_url(url) | ||
| return await asyncio.to_thread(os.path.abspath, image_file_path) | ||
| return await asyncio.to_thread(_absolute_path, image_file_path) | ||
| if url.startswith("base64://"): | ||
| bs64_data = url.removeprefix("base64://") | ||
| image_bytes = base64.b64decode(bs64_data) | ||
| image_file_path = os.path.join( | ||
| get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg" | ||
| ) | ||
| await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes) | ||
| return await asyncio.to_thread(os.path.abspath, image_file_path) | ||
| if await asyncio.to_thread(os.path.exists, url): | ||
| return await asyncio.to_thread(os.path.abspath, url) | ||
| return await asyncio.to_thread(_absolute_path, image_file_path) | ||
| local_path = await asyncio.to_thread(_absolute_path_if_exists, url) | ||
| if local_path: | ||
| return local_path | ||
| raise Exception(f"not a valid file: {url}") | ||
|
|
||
| async def convert_to_base64(self) -> str: | ||
|
|
@@ -480,10 +497,11 @@ async def convert_to_base64(self) -> str: | |
| bs64_data = await file_to_base64(image_file_path) | ||
| elif url.startswith("base64://"): | ||
| bs64_data = url | ||
| elif await asyncio.to_thread(os.path.exists, url): | ||
| bs64_data = await file_to_base64(url) | ||
| else: | ||
| raise Exception(f"not a valid file: {url}") | ||
| try: | ||
| bs64_data = await file_to_base64(url) | ||
| except OSError as exc: | ||
| raise Exception(f"not a valid file: {url}") from exc | ||
| bs64_data = bs64_data.removeprefix("base64://") | ||
| return bs64_data | ||
|
|
||
|
|
@@ -734,8 +752,9 @@ async def get_file(self, allow_return_url: bool = False) -> str: | |
| ): | ||
| path = path[1:] | ||
|
|
||
| if await asyncio.to_thread(os.path.exists, path): | ||
| return await asyncio.to_thread(os.path.abspath, path) | ||
| local_path = await asyncio.to_thread(_absolute_path_if_exists, path) | ||
| if local_path: | ||
| return local_path | ||
|
|
||
| if self.url: | ||
| await self._download_file() | ||
|
|
@@ -750,7 +769,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: | |
| and path[2] == ":" | ||
| ): | ||
| path = path[1:] | ||
| return await asyncio.to_thread(os.path.abspath, path) | ||
| return await asyncio.to_thread(_absolute_path, path) | ||
|
|
||
| return "" | ||
|
|
||
|
|
@@ -766,7 +785,7 @@ async def _download_file(self) -> None: | |
| filename = f"fileseg_{uuid.uuid4().hex}" | ||
| file_path = os.path.join(download_dir, filename) | ||
| await download_file(self.url, file_path) | ||
| self.file_ = await asyncio.to_thread(os.path.abspath, file_path) | ||
| self.file_ = await asyncio.to_thread(_absolute_path, file_path) | ||
|
|
||
| async def register_to_file_service(self) -> str: | ||
| """将文件注册到文件服务。 | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |||||||||||
| import uuid | ||||||||||||
| import zipfile | ||||||||||||
| from pathlib import Path | ||||||||||||
| from typing import BinaryIO | ||||||||||||
|
|
||||||||||||
| import aiohttp | ||||||||||||
| import certifi | ||||||||||||
|
|
@@ -134,29 +135,18 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non | |||||||||||
| if resp.status != 200: | ||||||||||||
| raise Exception(f"下载文件失败: {resp.status}") | ||||||||||||
| total_size = int(resp.headers.get("content-length", 0)) | ||||||||||||
| downloaded_size = 0 | ||||||||||||
| start_time = time.time() | ||||||||||||
| if show_progress: | ||||||||||||
| print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") | ||||||||||||
| file_obj = await asyncio.to_thread(Path(path).open, "wb") | ||||||||||||
| try: | ||||||||||||
| while True: | ||||||||||||
| chunk = await resp.content.read(8192) | ||||||||||||
| if not chunk: | ||||||||||||
| break | ||||||||||||
| await asyncio.to_thread(file_obj.write, chunk) | ||||||||||||
| downloaded_size += len(chunk) | ||||||||||||
| if show_progress: | ||||||||||||
| elapsed_time = ( | ||||||||||||
| time.time() - start_time | ||||||||||||
| if time.time() - start_time > 0 | ||||||||||||
| else 1 | ||||||||||||
| ) | ||||||||||||
| speed = downloaded_size / 1024 / elapsed_time # KB/s | ||||||||||||
| print( | ||||||||||||
| f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", | ||||||||||||
| end="", | ||||||||||||
| ) | ||||||||||||
| await _stream_to_file( | ||||||||||||
| resp.content, | ||||||||||||
| file_obj, | ||||||||||||
| total_size=total_size, | ||||||||||||
| start_time=start_time, | ||||||||||||
| show_progress=show_progress, | ||||||||||||
| ) | ||||||||||||
| finally: | ||||||||||||
| await asyncio.to_thread(file_obj.close) | ||||||||||||
| except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): | ||||||||||||
|
sourcery-ai[bot] marked this conversation as resolved.
|
||||||||||||
|
|
@@ -176,31 +166,70 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non | |||||||||||
| async with aiohttp.ClientSession() as session: | ||||||||||||
| async with session.get(url, ssl=ssl_context, timeout=120) as resp: | ||||||||||||
| total_size = int(resp.headers.get("content-length", 0)) | ||||||||||||
| downloaded_size = 0 | ||||||||||||
| start_time = time.time() | ||||||||||||
| if show_progress: | ||||||||||||
| print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") | ||||||||||||
| file_obj = await asyncio.to_thread(Path(path).open, "wb") | ||||||||||||
| try: | ||||||||||||
| while True: | ||||||||||||
| chunk = await resp.content.read(8192) | ||||||||||||
| if not chunk: | ||||||||||||
| break | ||||||||||||
| await asyncio.to_thread(file_obj.write, chunk) | ||||||||||||
| downloaded_size += len(chunk) | ||||||||||||
| if show_progress: | ||||||||||||
| elapsed_time = time.time() - start_time | ||||||||||||
| speed = downloaded_size / 1024 / elapsed_time # KB/s | ||||||||||||
| print( | ||||||||||||
| f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", | ||||||||||||
| end="", | ||||||||||||
| ) | ||||||||||||
| await _stream_to_file( | ||||||||||||
| resp.content, | ||||||||||||
| file_obj, | ||||||||||||
| total_size=total_size, | ||||||||||||
| start_time=start_time, | ||||||||||||
| show_progress=show_progress, | ||||||||||||
| ) | ||||||||||||
| finally: | ||||||||||||
| await asyncio.to_thread(file_obj.close) | ||||||||||||
| if show_progress: | ||||||||||||
| print() | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| async def _stream_to_file( | ||||||||||||
|
sourcery-ai[bot] marked this conversation as resolved.
|
||||||||||||
| stream: aiohttp.StreamReader, | ||||||||||||
| file_obj: BinaryIO, | ||||||||||||
| *, | ||||||||||||
| total_size: int, | ||||||||||||
| start_time: float, | ||||||||||||
| show_progress: bool, | ||||||||||||
| chunk_size: int = 8192, | ||||||||||||
| flush_threshold: int = 256 * 1024, | ||||||||||||
| ) -> None: | ||||||||||||
| """Stream HTTP response into file with buffered thread-offloaded writes.""" | ||||||||||||
| downloaded_size = 0 | ||||||||||||
| buffered = bytearray() | ||||||||||||
| progress_total = total_size if total_size > 0 else None | ||||||||||||
|
|
||||||||||||
| while True: | ||||||||||||
| chunk = await stream.read(chunk_size) | ||||||||||||
| if not chunk: | ||||||||||||
| break | ||||||||||||
| buffered.extend(chunk) | ||||||||||||
| downloaded_size += len(chunk) | ||||||||||||
|
|
||||||||||||
| if len(buffered) >= flush_threshold: | ||||||||||||
| chunk_to_write = bytes(buffered) | ||||||||||||
| buffered.clear() | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在这里通过 同样的优化也适用于函数末尾对剩余缓冲区的写入(第 222 行)。
Suggested change
|
||||||||||||
| await asyncio.to_thread(file_obj.write, chunk_to_write) | ||||||||||||
|
|
||||||||||||
| if show_progress: | ||||||||||||
| elapsed_time = max(time.time() - start_time, 1e-6) | ||||||||||||
| speed = downloaded_size / 1024 / elapsed_time # KB/s | ||||||||||||
| if progress_total: | ||||||||||||
| percent = downloaded_size / progress_total | ||||||||||||
| print( | ||||||||||||
| f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s", | ||||||||||||
| end="", | ||||||||||||
| ) | ||||||||||||
| else: | ||||||||||||
| print( | ||||||||||||
| f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s", | ||||||||||||
| end="", | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| if buffered: | ||||||||||||
| await asyncio.to_thread(file_obj.write, bytes(buffered)) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| async def file_to_base64(file_path: str) -> str: | ||||||||||||
| data_bytes = await asyncio.to_thread(Path(file_path).read_bytes) | ||||||||||||
| base64_str = base64.b64encode(data_bytes).decode() | ||||||||||||
|
|
||||||||||||
Uh oh!
There was an error while loading. Please reload this page.