Skip to content

Commit 4abea2b

Browse files
Clhikari and zouyonghe authored
fix: harden backup import for duplicate platform stats (#5594)
* fix: harden backup import for duplicate platform stats - 修复 replace 模式下主库清空失败仍继续导入的问题。 - 导入前对 platform_stats 重复键做聚合(count 累加),并统一时间戳判重格式。 - 非法 count 按 0 处理并告警(限流),补充对应测试。 * refactor: improve robustness and readability of platform stats import - 告警上限魔法数字提取为模块常量 PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - 抽取 parse_count 内联函数,消除重复的 try/except 分支 - 存储行的 timestamp 同步写入规范化值,避免落库格式混用 - 补充测试:已有行 count 非法、告警限流、replace 模式中断断言 * fix: normalize invalid platform_stats count for non-duplicate rows * fix: avoid merging invalid platform_stats timestamps * refactor: simplify platform stats merge and normalize naive UTC * refactor: inline platform stats merge helpers * refactor: flatten platform stats merge flow * refactor: harden platform stats merge key handling * refactor: streamline platform stats preprocessing * refactor: simplify platform stats merge helpers * refactor: inline platform stats merge normalization * refactor: extract platform stats merge helpers * refactor: simplify platform stats preprocessing flow * refactor: flatten platform stats preprocess helpers * refactor: streamline platform stats merge helpers * refactor: isolate platform stats warning limiter --------- Co-authored-by: 邹永赫 <1259085392@qq.com>
1 parent 267abfd commit 4abea2b

File tree

2 files changed

+512
-4
lines changed

2 files changed

+512
-4
lines changed

astrbot/core/backup/importer.py

Lines changed: 188 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import shutil
1313
import zipfile
1414
from dataclasses import dataclass, field
15-
from datetime import datetime
15+
from datetime import datetime, timezone
1616
from pathlib import Path
1717
from typing import TYPE_CHECKING, Any
1818

@@ -61,6 +61,69 @@ def _get_major_version(version_str: str) -> str:
6161

6262
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
6363
KB_PATH = get_astrbot_knowledge_base_path()
64+
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
65+
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = (
66+
"ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT"
67+
)
68+
69+
70+
def _load_platform_stats_invalid_count_warn_limit() -> int:
71+
raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV)
72+
if raw_value is None:
73+
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
74+
75+
try:
76+
value = int(raw_value)
77+
if value < 0:
78+
raise ValueError("negative")
79+
return value
80+
except (TypeError, ValueError):
81+
logger.warning(
82+
"Invalid env %s=%r, fallback to default %d",
83+
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
84+
raw_value,
85+
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
86+
)
87+
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
88+
89+
90+
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = (
91+
_load_platform_stats_invalid_count_warn_limit()
92+
)
93+
94+
95+
class _InvalidCountWarnLimiter:
96+
"""Rate-limit warnings for invalid platform_stats count values."""
97+
98+
def __init__(self, limit: int) -> None:
99+
self.limit = limit
100+
self._count = 0
101+
self._suppression_logged = False
102+
103+
def warn_invalid_count(self, value: Any, key_for_log: tuple[Any, ...]) -> None:
104+
if self.limit > 0:
105+
if self._count < self.limit:
106+
logger.warning(
107+
"platform_stats count 非法,已按 0 处理: value=%r, key=%s",
108+
value,
109+
key_for_log,
110+
)
111+
self._count += 1
112+
if self._count == self.limit and not self._suppression_logged:
113+
logger.warning(
114+
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
115+
self.limit,
116+
)
117+
self._suppression_logged = True
118+
return
119+
120+
if not self._suppression_logged:
121+
# limit <= 0: emit only one suppression warning.
122+
logger.warning(
123+
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
124+
self.limit,
125+
)
126+
self._suppression_logged = True
64127

65128

66129
@dataclass
@@ -138,6 +201,10 @@ def to_dict(self) -> dict:
138201
}
139202

140203

204+
class DatabaseClearError(RuntimeError):
    """Signals that replace-mode failed to wipe the main database tables."""
206+
207+
141208
class AstrBotImporter:
142209
"""AstrBot 数据导入器
143210
@@ -342,6 +409,9 @@ async def import_all(
342409

343410
imported = await self._import_main_database(main_data)
344411
result.imported_tables.update(imported)
412+
except DatabaseClearError as e:
413+
result.add_error(f"清空主数据库失败: {e}")
414+
return result
345415
except Exception as e:
346416
result.add_error(f"导入主数据库失败: {e}")
347417
return result
@@ -452,7 +522,9 @@ async def _clear_main_db(self) -> None:
452522
await session.execute(delete(model_class))
453523
logger.debug(f"已清空表 {table_name}")
454524
except Exception as e:
455-
logger.warning(f"清空表 {table_name} 失败: {e}")
525+
raise DatabaseClearError(
526+
f"清空表 {table_name} 失败: {e}"
527+
) from e
456528

457529
async def _clear_kb_data(self) -> None:
458530
"""清空知识库数据"""
@@ -494,9 +566,10 @@ async def _import_main_database(
494566
if not model_class:
495567
logger.warning(f"未知的表: {table_name}")
496568
continue
569+
normalized_rows = self._preprocess_main_table_rows(table_name, rows)
497570

498571
count = 0
499-
for row in rows:
572+
for row in normalized_rows:
500573
try:
501574
# 转换 datetime 字符串为 datetime 对象
502575
row = self._convert_datetime_fields(row, model_class)
@@ -511,6 +584,118 @@ async def _import_main_database(
511584

512585
return imported
513586

587+
def _preprocess_main_table_rows(
588+
self, table_name: str, rows: list[dict[str, Any]]
589+
) -> list[dict[str, Any]]:
590+
if table_name == "platform_stats":
591+
normalized_rows = self._merge_platform_stats_rows(rows)
592+
duplicate_count = len(rows) - len(normalized_rows)
593+
if duplicate_count > 0:
594+
logger.warning(
595+
"检测到 %s 重复键 %d 条,已在导入前聚合",
596+
table_name,
597+
duplicate_count,
598+
)
599+
return normalized_rows
600+
return rows
601+
602+
def _merge_platform_stats_rows(
    self, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Aggregate duplicate platform_stats rows sharing the same merge key.

    The key is (normalized timestamp, platform_id, platform_type). Rows
    whose timestamp cannot be normalized, or whose platform fields are not
    strings, are never merged and pass through as distinct entries.
    Invalid-count warnings are rate-limited within a single call.
    """
    limiter = _InvalidCountWarnLimiter(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT)
    by_key: dict[tuple[str, str, str], dict[str, Any]] = {}
    output: list[dict[str, Any]] = []

    for raw_row in rows:
        row, ts, row_count = self._normalize_platform_stats_entry(raw_row, limiter)
        pid = row.get("platform_id")
        ptype = row.get("platform_type")

        mergeable = (
            ts is not None and isinstance(pid, str) and isinstance(ptype, str)
        )
        if not mergeable:
            # Keep unmergeable rows verbatim rather than risk a bad merge.
            output.append(row)
            continue

        key = (ts, pid, ptype)
        if key in by_key:
            # Duplicate key: fold the count into the first occurrence.
            by_key[key]["count"] += row_count
        else:
            by_key[key] = row
            output.append(row)

    return output
640+
641+
def _normalize_platform_stats_entry(
    self,
    row: dict[str, Any],
    warn_limiter: _InvalidCountWarnLimiter,
) -> tuple[dict[str, Any], str | None, int]:
    """Return a copy of *row* with timestamp and count coerced to storable forms.

    Returns ``(normalized_row, normalized_timestamp_or_None, count)``. A
    count that cannot be parsed as an int is replaced with 0 and reported
    through *warn_limiter*.
    """
    entry = dict(row)  # never mutate the caller's row
    original_ts = entry.get("timestamp")
    canonical_ts = self._normalize_platform_stats_timestamp(original_ts)

    # Persist the canonical form when available; otherwise store a
    # best-effort string so the column stays uniformly typed.
    if canonical_ts is not None:
        stored_ts = canonical_ts
    elif original_ts is None:
        stored_ts = ""
    elif isinstance(original_ts, str):
        stored_ts = original_ts.strip()
    else:
        stored_ts = str(original_ts)
    entry["timestamp"] = stored_ts

    raw_count = entry.get("count", 0)
    try:
        safe_count = int(raw_count)
    except (TypeError, ValueError):
        warn_limiter.warn_invalid_count(
            raw_count,
            (
                entry.get("timestamp"),
                repr(entry.get("platform_id")),
                repr(entry.get("platform_type")),
            ),
        )
        safe_count = 0

    entry["count"] = safe_count
    return entry, canonical_ts, safe_count
673+
674+
def _normalize_platform_stats_timestamp(self, value: Any) -> str | None:
675+
if isinstance(value, datetime):
676+
dt = value
677+
if dt.tzinfo is None:
678+
dt = dt.replace(tzinfo=timezone.utc)
679+
else:
680+
dt = dt.astimezone(timezone.utc)
681+
return dt.isoformat()
682+
if isinstance(value, str):
683+
timestamp = value.strip()
684+
if not timestamp:
685+
return None
686+
if timestamp.endswith("Z"):
687+
timestamp = f"{timestamp[:-1]}+00:00"
688+
try:
689+
dt = datetime.fromisoformat(timestamp)
690+
if dt.tzinfo is None:
691+
dt = dt.replace(tzinfo=timezone.utc)
692+
else:
693+
dt = dt.astimezone(timezone.utc)
694+
return dt.isoformat()
695+
except ValueError:
696+
return None
697+
return None
698+
514699
async def _import_knowledge_bases(
515700
self,
516701
zf: zipfile.ZipFile,

0 commit comments

Comments
 (0)