Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions astrbot/core/backup/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ async def export_all(
# 3. 导出配置文件
if progress_callback:
await progress_callback("config", 0, 100, "正在导出配置文件...")
if await asyncio.to_thread(os.path.exists, self.config_path):
config_content = await asyncio.to_thread(
Path(self.config_path).read_text, encoding="utf-8"
)
config_content = await asyncio.to_thread(
self._read_text_if_exists, self.config_path
)
if config_content is not None:
zf.writestr("config/cmd_config.json", config_content)
self._add_checksum("config/cmd_config.json", config_content)
if progress_callback:
Expand Down Expand Up @@ -361,17 +361,35 @@ async def _export_attachments(
self, zf: zipfile.ZipFile, attachments: list[dict]
) -> None:
"""导出附件文件"""
await asyncio.to_thread(self._export_attachments_sync, zf, attachments)

def _export_attachments_sync(
self, zf: zipfile.ZipFile, attachments: list[dict]
) -> None:
"""在单个线程中批量导出附件,减少高频线程切换。"""
for attachment in attachments:
file_path = attachment.get("path", "")
attachment_id = attachment.get("attachment_id", "")
try:
file_path = attachment.get("path", "")
if file_path and await asyncio.to_thread(os.path.exists, file_path):
# 使用 attachment_id 作为文件名
attachment_id = attachment.get("attachment_id", "")
ext = os.path.splitext(file_path)[1]
archive_path = f"files/attachments/{attachment_id}{ext}"
zf.write(file_path, archive_path)
except Exception as e:
logger.warning(f"导出附件失败: {e}")
if not file_path:
continue
# 使用 attachment_id 作为文件名
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
ext = os.path.splitext(file_path)[1]
archive_path = f"files/attachments/{attachment_id}{ext}"
zf.write(file_path, archive_path)
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
except FileNotFoundError:
# 和旧逻辑保持一致:缺失文件直接跳过。
continue
except OSError as e:
logger.warning(
f"导出附件失败 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}"
)

def _read_text_if_exists(self, file_path: str) -> str | None:
"""Read text file when it exists in a single synchronous call."""
if not os.path.exists(file_path):
return None
return Path(file_path).read_text(encoding="utf-8")

def _model_to_dict(self, record: Any) -> dict:
"""将 SQLModel 实例转换为字典
Expand Down
63 changes: 41 additions & 22 deletions astrbot/core/message/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64


def _absolute_path(path: str) -> str:
    """Return the absolute form of *path*.

    This is a plain wrapper around ``os.path.abspath``; it does not check
    whether the path exists.
    """
    resolved = os.path.abspath(path)
    return resolved


def _absolute_path_if_exists(path: str | None) -> str | None:
if not path or not os.path.exists(path):
return None
return os.path.abspath(path)


class ComponentType(StrEnum):
# Basic Segment Types
Plain = "Plain" # plain text message
Expand Down Expand Up @@ -159,17 +169,18 @@ async def convert_to_file_path(self) -> str:
return self.file[8:]
if self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return await asyncio.to_thread(os.path.abspath, file_path)
return await asyncio.to_thread(_absolute_path, file_path)
if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
file_path = os.path.join(
get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
)
await asyncio.to_thread(Path(file_path).write_bytes, image_bytes)
return await asyncio.to_thread(os.path.abspath, file_path)
if await asyncio.to_thread(os.path.exists, self.file):
return await asyncio.to_thread(os.path.abspath, self.file)
return await asyncio.to_thread(_absolute_path, file_path)
local_path = await asyncio.to_thread(_absolute_path_if_exists, self.file)
if local_path:
return local_path
Comment on lines +181 to +183
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The message components (Record, Video, Image, File) allow the use of the file:/// prefix or direct absolute paths in their file or url attributes. These paths are resolved using os.path.abspath or returned directly without any validation against a safe directory (e.g., the bot's data or temp directory). Since these components are often constructed from untrusted external input (e.g., incoming messages from various platforms), an attacker can craft a message that causes the bot to read arbitrary files from the server's filesystem. This can lead to the leakage of sensitive information such as configuration files, database files, or system credentials.

raise Exception(f"not a valid file: {self.file}")

async def convert_to_base64(self) -> str:
Expand All @@ -189,10 +200,11 @@ async def convert_to_base64(self) -> str:
bs64_data = await file_to_base64(file_path)
elif self.file.startswith("base64://"):
bs64_data = self.file
elif await asyncio.to_thread(os.path.exists, self.file):
bs64_data = await file_to_base64(self.file)
else:
raise Exception(f"not a valid file: {self.file}")
try:
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
bs64_data = await file_to_base64(self.file)
except OSError as exc:
raise Exception(f"not a valid file: {self.file}") from exc
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data

Expand Down Expand Up @@ -256,11 +268,15 @@ async def convert_to_file_path(self) -> str:
get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}"
)
await download_file(url, video_file_path)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The convert_to_file_path and convert_to_base64 methods in Record, Video, and Image components automatically download content from URLs provided in the file or url attributes. If these attributes are controlled by an untrusted user, an attacker can cause the bot to make requests to internal network services, potentially bypassing firewalls or accessing internal APIs.

if await asyncio.to_thread(os.path.exists, video_file_path):
return await asyncio.to_thread(os.path.abspath, video_file_path)
local_path = await asyncio.to_thread(
_absolute_path_if_exists, video_file_path
)
if local_path:
return local_path
raise Exception(f"download failed: {url}")
if await asyncio.to_thread(os.path.exists, url):
return await asyncio.to_thread(os.path.abspath, url)
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
if local_path:
return local_path
raise Exception(f"not a valid file: {url}")

async def register_to_file_service(self) -> str:
Expand Down Expand Up @@ -449,17 +465,18 @@ async def convert_to_file_path(self) -> str:
return url[8:]
if url.startswith("http"):
image_file_path = await download_image_by_url(url)
return await asyncio.to_thread(os.path.abspath, image_file_path)
return await asyncio.to_thread(_absolute_path, image_file_path)
if url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
image_file_path = os.path.join(
get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg"
)
await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes)
return await asyncio.to_thread(os.path.abspath, image_file_path)
if await asyncio.to_thread(os.path.exists, url):
return await asyncio.to_thread(os.path.abspath, url)
return await asyncio.to_thread(_absolute_path, image_file_path)
local_path = await asyncio.to_thread(_absolute_path_if_exists, url)
if local_path:
return local_path
raise Exception(f"not a valid file: {url}")

async def convert_to_base64(self) -> str:
Expand All @@ -480,10 +497,11 @@ async def convert_to_base64(self) -> str:
bs64_data = await file_to_base64(image_file_path)
elif url.startswith("base64://"):
bs64_data = url
elif await asyncio.to_thread(os.path.exists, url):
bs64_data = await file_to_base64(url)
else:
raise Exception(f"not a valid file: {url}")
try:
bs64_data = await file_to_base64(url)
except OSError as exc:
raise Exception(f"not a valid file: {url}") from exc
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data

Expand Down Expand Up @@ -734,8 +752,9 @@ async def get_file(self, allow_return_url: bool = False) -> str:
):
path = path[1:]

if await asyncio.to_thread(os.path.exists, path):
return await asyncio.to_thread(os.path.abspath, path)
local_path = await asyncio.to_thread(_absolute_path_if_exists, path)
if local_path:
return local_path

if self.url:
await self._download_file()
Expand All @@ -750,7 +769,7 @@ async def get_file(self, allow_return_url: bool = False) -> str:
and path[2] == ":"
):
path = path[1:]
return await asyncio.to_thread(os.path.abspath, path)
return await asyncio.to_thread(_absolute_path, path)

return ""

Expand All @@ -766,7 +785,7 @@ async def _download_file(self) -> None:
filename = f"fileseg_{uuid.uuid4().hex}"
file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path)
self.file_ = await asyncio.to_thread(os.path.abspath, file_path)
self.file_ = await asyncio.to_thread(_absolute_path, file_path)

async def register_to_file_service(self) -> str:
"""将文件注册到文件服务。
Expand Down
93 changes: 61 additions & 32 deletions astrbot/core/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import uuid
import zipfile
from pathlib import Path
from typing import BinaryIO

import aiohttp
import certifi
Expand Down Expand Up @@ -134,29 +135,18 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
total_size = int(resp.headers.get("content-length", 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
file_obj = await asyncio.to_thread(Path(path).open, "wb")
try:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
await asyncio.to_thread(file_obj.write, chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = (
time.time() - start_time
if time.time() - start_time > 0
else 1
)
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
)
await _stream_to_file(
resp.content,
file_obj,
total_size=total_size,
start_time=start_time,
show_progress=show_progress,
)
finally:
await asyncio.to_thread(file_obj.close)
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Expand All @@ -176,31 +166,70 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
total_size = int(resp.headers.get("content-length", 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
file_obj = await asyncio.to_thread(Path(path).open, "wb")
try:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
await asyncio.to_thread(file_obj.write, chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
)
await _stream_to_file(
resp.content,
file_obj,
total_size=total_size,
start_time=start_time,
show_progress=show_progress,
)
finally:
await asyncio.to_thread(file_obj.close)
if show_progress:
print()


async def _stream_to_file(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
stream: aiohttp.StreamReader,
file_obj: BinaryIO,
*,
total_size: int,
start_time: float,
show_progress: bool,
chunk_size: int = 8192,
flush_threshold: int = 256 * 1024,
) -> None:
"""Stream HTTP response into file with buffered thread-offloaded writes."""
downloaded_size = 0
buffered = bytearray()
progress_total = total_size if total_size > 0 else None

while True:
chunk = await stream.read(chunk_size)
if not chunk:
break
buffered.extend(chunk)
downloaded_size += len(chunk)

if len(buffered) >= flush_threshold:
chunk_to_write = bytes(buffered)
buffered.clear()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

在这里通过 bytes(buffered) 创建 chunk_to_write 实际上创建了一个不必要的内存副本。file_obj.write 方法可以直接接受 bytearray 类型的参数。由于 await asyncio.to_thread 会阻塞直到写操作完成,因此之后再调用 buffered.clear() 是线程安全的。这样可以减少内存分配和复制的开销。

同样的优化也适用于函数末尾对剩余缓冲区的写入(第 222 行)。

Suggested change
if len(buffered) >= flush_threshold:
chunk_to_write = bytes(buffered)
buffered.clear()
await asyncio.to_thread(file_obj.write, buffered)
buffered.clear()

await asyncio.to_thread(file_obj.write, chunk_to_write)

if show_progress:
elapsed_time = max(time.time() - start_time, 1e-6)
speed = downloaded_size / 1024 / elapsed_time # KB/s
if progress_total:
percent = downloaded_size / progress_total
print(
f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s",
end="",
)
else:
print(
f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s",
end="",
)

if buffered:
await asyncio.to_thread(file_obj.write, bytes(buffered))


async def file_to_base64(file_path: str) -> str:
data_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
base64_str = base64.b64encode(data_bytes).decode()
Expand Down
Loading