Skip to content

Commit ec57b98

Browse files
committed
refactor: simplify platform stats merge helpers
1 parent 7de3164 commit ec57b98

2 files changed

Lines changed: 64 additions & 109 deletions

File tree

astrbot/core/backup/importer.py

Lines changed: 62 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import os
1212
import shutil
1313
import zipfile
14-
from collections.abc import Callable
1514
from dataclasses import dataclass, field
1615
from datetime import datetime, timezone
1716
from pathlib import Path
@@ -62,35 +61,7 @@ def _get_major_version(version_str: str) -> str:
6261

6362
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
6463
KB_PATH = get_astrbot_knowledge_base_path()
65-
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
66-
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = (
67-
"ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT"
68-
)
69-
70-
71-
def _load_platform_stats_invalid_count_warn_limit() -> int:
72-
raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV)
73-
if raw_value is None:
74-
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
75-
try:
76-
value = int(raw_value)
77-
if value < 0:
78-
raise ValueError("negative")
79-
return value
80-
except (TypeError, ValueError):
81-
logger.warning(
82-
"Invalid env %s=%r, fallback to default %d",
83-
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
84-
raw_value,
85-
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
86-
)
87-
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
88-
89-
90-
# Warning limit per _merge_platform_stats_rows invocation; configurable by env.
91-
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = (
92-
_load_platform_stats_invalid_count_warn_limit()
93-
)
64+
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
9465

9566

9667
@dataclass
@@ -200,11 +171,6 @@ def __init__(
200171
self.kb_manager = kb_manager
201172
self.config_path = config_path
202173
self.kb_root_dir = kb_root_dir
203-
self._main_table_preprocessors: dict[
204-
str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]]
205-
] = {
206-
"platform_stats": self._merge_platform_stats_rows,
207-
}
208174

209175
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
210176
"""预检查备份文件
@@ -559,16 +525,15 @@ async def _import_main_database(
559525
def _preprocess_main_table_rows(
560526
self, table_name: str, rows: list[dict[str, Any]]
561527
) -> list[dict[str, Any]]:
562-
preprocessor = self._main_table_preprocessors.get(table_name)
563-
if preprocessor is None:
564-
return rows
565-
normalized_rows = preprocessor(rows)
566-
duplicate_count = len(rows) - len(normalized_rows)
567-
if duplicate_count > 0:
568-
logger.warning(
569-
f"检测到 {table_name} 重复键 {duplicate_count} 条,已在导入前聚合"
570-
)
571-
return normalized_rows
528+
if table_name == "platform_stats":
529+
normalized_rows = self._merge_platform_stats_rows(rows)
530+
duplicate_count = len(rows) - len(normalized_rows)
531+
if duplicate_count > 0:
532+
logger.warning(
533+
f"检测到 {table_name} 重复键 {duplicate_count} 条,已在导入前聚合"
534+
)
535+
return normalized_rows
536+
return rows
572537

573538
def _merge_platform_stats_rows(
574539
self, rows: list[dict[str, Any]]
@@ -584,28 +549,10 @@ def _merge_platform_stats_rows(
584549
non_mergeable: list[dict[str, Any]] = []
585550
invalid_count_warned = 0
586551

587-
def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int:
588-
nonlocal invalid_count_warned
589-
try:
590-
return int(raw_count)
591-
except (TypeError, ValueError):
592-
if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT:
593-
logger.warning(
594-
"platform_stats count 非法,已按 0 处理: value=%r, key=%s",
595-
raw_count,
596-
key,
597-
)
598-
invalid_count_warned += 1
599-
return 0
600-
601552
for row in rows:
602-
normalized_row = dict(row)
603-
normalized_timestamp, is_timestamp_valid = (
604-
self._normalize_platform_stats_timestamp(
605-
normalized_row.get("timestamp")
606-
)
553+
normalized_row, normalized_timestamp, is_timestamp_valid = (
554+
self._normalize_platform_stats_row(row)
607555
)
608-
normalized_row["timestamp"] = normalized_timestamp
609556

610557
platform_id = normalized_row.get("platform_id")
611558
platform_type = normalized_row.get("platform_type")
@@ -614,7 +561,11 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int:
614561
repr(platform_id),
615562
repr(platform_type),
616563
)
617-
count = parse_count(normalized_row.get("count", 0), key_for_log)
564+
count, invalid_count_warned = self._parse_platform_stats_count(
565+
normalized_row.get("count", 0),
566+
key_for_log,
567+
invalid_count_warned,
568+
)
618569
normalized_row["count"] = count
619570

620571
if not is_timestamp_valid:
@@ -634,32 +585,70 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int:
634585

635586
return [*non_mergeable, *merged.values()]
636587

588+
def _parse_platform_stats_count(
589+
self,
590+
raw_count: Any,
591+
key_for_log: tuple[str, str, str],
592+
warned_count: int,
593+
) -> tuple[int, int]:
594+
if warned_count >= PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT:
595+
try:
596+
return int(raw_count), warned_count
597+
except (TypeError, ValueError):
598+
return 0, warned_count
599+
try:
600+
return int(raw_count), warned_count
601+
except (TypeError, ValueError):
602+
logger.warning(
603+
"platform_stats count 非法,已按 0 处理: value=%r, key=%s",
604+
raw_count,
605+
key_for_log,
606+
)
607+
return 0, warned_count + 1
608+
609+
def _normalize_platform_stats_row(
610+
self, row: dict[str, Any]
611+
) -> tuple[dict[str, Any], str, bool]:
612+
normalized_row = dict(row)
613+
raw_timestamp = normalized_row.get("timestamp")
614+
normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp)
615+
if normalized_timestamp is None:
616+
if isinstance(raw_timestamp, str):
617+
normalized_row["timestamp"] = raw_timestamp.strip()
618+
elif raw_timestamp is None:
619+
normalized_row["timestamp"] = ""
620+
else:
621+
normalized_row["timestamp"] = str(raw_timestamp)
622+
return normalized_row, normalized_row["timestamp"], False
623+
normalized_row["timestamp"] = normalized_timestamp
624+
return normalized_row, normalized_timestamp, True
625+
637626
def _to_utc_iso(self, dt: datetime) -> str:
638627
if dt.tzinfo is None:
639628
dt = dt.replace(tzinfo=timezone.utc)
640629
else:
641630
dt = dt.astimezone(timezone.utc)
642631
return dt.isoformat()
643632

644-
def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]:
633+
def _normalize_platform_stats_timestamp(self, value: Any) -> str | None:
645634
if isinstance(value, datetime):
646-
return self._to_utc_iso(value), True
635+
return self._to_utc_iso(value)
647636

648637
if isinstance(value, str):
649638
timestamp = value.strip()
650639
if not timestamp:
651-
return "", False
640+
return None
652641
if timestamp.endswith("Z"):
653642
timestamp = f"{timestamp[:-1]}+00:00"
654643
try:
655-
return self._to_utc_iso(datetime.fromisoformat(timestamp)), True
644+
return self._to_utc_iso(datetime.fromisoformat(timestamp))
656645
except ValueError:
657-
return timestamp, False
646+
return None
658647

659648
if value is None:
660-
return "", False
649+
return None
661650

662-
return str(value), False
651+
return None
663652

664653
async def _import_knowledge_bases(
665654
self,

tests/test_backup.py

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,10 @@
1717
)
1818
from astrbot.core.backup.exporter import AstrBotExporter
1919
from astrbot.core.backup.importer import (
20-
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
2120
DatabaseClearError,
22-
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
2321
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
2422
AstrBotImporter,
2523
ImportResult,
26-
_load_platform_stats_invalid_count_warn_limit,
2724
_get_major_version,
2825
)
2926
from astrbot.core.config.default import VERSION
@@ -384,45 +381,14 @@ def test_normalize_platform_stats_timestamp_treats_naive_as_utc(self):
384381
"""测试 naive timestamp 会统一转为显式 UTC 偏移"""
385382
importer = AstrBotImporter(main_db=MagicMock())
386383

387-
normalized, is_valid = importer._normalize_platform_stats_timestamp(
388-
"2025-12-13T21:00:00"
389-
)
390-
assert is_valid is True
384+
normalized = importer._normalize_platform_stats_timestamp("2025-12-13T21:00:00")
391385
assert normalized == "2025-12-13T21:00:00+00:00"
392386

393-
normalized_dt, is_valid_dt = importer._normalize_platform_stats_timestamp(
387+
normalized_dt = importer._normalize_platform_stats_timestamp(
394388
datetime(2025, 12, 13, 21, 0, 0)
395389
)
396-
assert is_valid_dt is True
397390
assert normalized_dt == "2025-12-13T21:00:00+00:00"
398391

399-
def test_load_platform_stats_invalid_count_warn_limit(self, monkeypatch):
400-
"""测试告警阈值环境变量解析"""
401-
monkeypatch.delenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, raising=False)
402-
assert (
403-
_load_platform_stats_invalid_count_warn_limit()
404-
== DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
405-
)
406-
407-
monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "10")
408-
assert _load_platform_stats_invalid_count_warn_limit() == 10
409-
410-
with patch("astrbot.core.backup.importer.logger.warning") as warning_mock:
411-
monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "-1")
412-
assert (
413-
_load_platform_stats_invalid_count_warn_limit()
414-
== DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
415-
)
416-
assert warning_mock.call_count == 1
417-
418-
with patch("astrbot.core.backup.importer.logger.warning") as warning_mock:
419-
monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "bad")
420-
assert (
421-
_load_platform_stats_invalid_count_warn_limit()
422-
== DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
423-
)
424-
assert warning_mock.call_count == 1
425-
426392
def test_merge_platform_stats_rows_warns_on_invalid_count(self):
427393
"""测试 platform_stats count 非法时会告警并按 0 处理(含上限)"""
428394
importer = AstrBotImporter(main_db=MagicMock())

0 commit comments

Comments
 (0)