Skip to content

Commit 7de3164

Browse files
committed
refactor: streamline platform stats preprocessing
1 parent 817286f commit 7de3164

2 files changed

Lines changed: 85 additions & 80 deletions

File tree

astrbot/core/backup/importer.py

Lines changed: 60 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import os
1212
import shutil
1313
import zipfile
14+
from collections.abc import Callable
1415
from dataclasses import dataclass, field
1516
from datetime import datetime, timezone
1617
from pathlib import Path
@@ -67,34 +68,23 @@ def _get_major_version(version_str: str) -> str:
6768
)
6869

6970

70-
def _resolve_platform_stats_invalid_count_warn_limit(
71-
raw_value: str | None,
72-
) -> tuple[int, bool]:
73-
"""Resolve warn limit value and return whether the input was valid."""
71+
def _load_platform_stats_invalid_count_warn_limit() -> int:
72+
raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV)
7473
if raw_value is None:
75-
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, True
74+
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
7675
try:
7776
value = int(raw_value)
77+
if value < 0:
78+
raise ValueError("negative")
79+
return value
7880
except (TypeError, ValueError):
79-
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, False
80-
if value < 0:
81-
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, False
82-
return value, True
83-
84-
85-
def _load_platform_stats_invalid_count_warn_limit() -> int:
86-
raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV)
87-
resolved_value, is_valid = _resolve_platform_stats_invalid_count_warn_limit(
88-
raw_value
89-
)
90-
if raw_value is not None and not is_valid:
9181
logger.warning(
9282
"Invalid env %s=%r, fallback to default %d",
9383
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
9484
raw_value,
9585
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
9686
)
97-
return resolved_value
87+
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
9888

9989

10090
# Warning limit per _merge_platform_stats_rows invocation; configurable by env.
@@ -210,6 +200,11 @@ def __init__(
210200
self.kb_manager = kb_manager
211201
self.config_path = config_path
212202
self.kb_root_dir = kb_root_dir
203+
self._main_table_preprocessors: dict[
204+
str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]]
205+
] = {
206+
"platform_stats": self._merge_platform_stats_rows,
207+
}
213208

214209
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
215210
"""预检查备份文件
@@ -543,14 +538,7 @@ async def _import_main_database(
543538
if not model_class:
544539
logger.warning(f"未知的表: {table_name}")
545540
continue
546-
normalized_rows = rows
547-
if table_name == "platform_stats":
548-
normalized_rows = self._merge_platform_stats_rows(rows)
549-
duplicate_count = len(rows) - len(normalized_rows)
550-
if duplicate_count > 0:
551-
logger.warning(
552-
f"检测到 platform_stats 重复键 {duplicate_count} 条,已在导入前聚合"
553-
)
541+
normalized_rows = self._preprocess_main_table_rows(table_name, rows)
554542

555543
count = 0
556544
for row in normalized_rows:
@@ -568,6 +556,20 @@ async def _import_main_database(
568556

569557
return imported
570558

559+
def _preprocess_main_table_rows(
560+
self, table_name: str, rows: list[dict[str, Any]]
561+
) -> list[dict[str, Any]]:
562+
preprocessor = self._main_table_preprocessors.get(table_name)
563+
if preprocessor is None:
564+
return rows
565+
normalized_rows = preprocessor(rows)
566+
duplicate_count = len(rows) - len(normalized_rows)
567+
if duplicate_count > 0:
568+
logger.warning(
569+
f"检测到 {table_name} 重复键 {duplicate_count} 条,已在导入前聚合"
570+
)
571+
return normalized_rows
572+
571573
def _merge_platform_stats_rows(
572574
self, rows: list[dict[str, Any]]
573575
) -> list[dict[str, Any]]:
@@ -579,8 +581,23 @@ def _merge_platform_stats_rows(
579581
- Invalid count warnings are rate-limited per function invocation.
580582
"""
581583
merged: dict[tuple[str, str, str], dict[str, Any]] = {}
582-
result: list[dict[str, Any]] = []
584+
non_mergeable: list[dict[str, Any]] = []
583585
invalid_count_warned = 0
586+
587+
def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int:
588+
nonlocal invalid_count_warned
589+
try:
590+
return int(raw_count)
591+
except (TypeError, ValueError):
592+
if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT:
593+
logger.warning(
594+
"platform_stats count 非法,已按 0 处理: value=%r, key=%s",
595+
raw_count,
596+
key,
597+
)
598+
invalid_count_warned += 1
599+
return 0
600+
584601
for row in rows:
585602
normalized_row = dict(row)
586603
normalized_timestamp, is_timestamp_valid = (
@@ -597,73 +614,51 @@ def _merge_platform_stats_rows(
597614
repr(platform_id),
598615
repr(platform_type),
599616
)
600-
count, invalid_count_warned = self._parse_platform_stats_count(
601-
normalized_row.get("count", 0), invalid_count_warned, key_for_log
602-
)
617+
count = parse_count(normalized_row.get("count", 0), key_for_log)
603618
normalized_row["count"] = count
604619

605-
# Invalid timestamps should never be merged.
606620
if not is_timestamp_valid:
607-
result.append(normalized_row)
621+
non_mergeable.append(normalized_row)
608622
continue
609623

610624
if not isinstance(platform_id, str) or not isinstance(platform_type, str):
611-
result.append(normalized_row)
625+
non_mergeable.append(normalized_row)
612626
continue
613627

614628
key = (normalized_timestamp, platform_id, platform_type)
615629
existing = merged.get(key)
616630
if existing is None:
617631
merged[key] = normalized_row
618-
result.append(normalized_row)
619632
else:
620633
existing["count"] += count
621634

622-
return result
635+
return [*non_mergeable, *merged.values()]
623636

624-
def _parse_platform_stats_count(
625-
self,
626-
raw_count: Any,
627-
invalid_count_warned: int,
628-
key: tuple[str, str, str],
629-
) -> tuple[int, int]:
630-
"""Safe int parse with per-call rate-limited warning."""
631-
try:
632-
return int(raw_count), invalid_count_warned
633-
except (TypeError, ValueError):
634-
if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT:
635-
logger.warning(
636-
"platform_stats count 非法,已按 0 处理: "
637-
f"value={raw_count!r}, key={key}"
638-
)
639-
invalid_count_warned += 1
640-
return 0, invalid_count_warned
637+
def _to_utc_iso(self, dt: datetime) -> str:
638+
if dt.tzinfo is None:
639+
dt = dt.replace(tzinfo=timezone.utc)
640+
else:
641+
dt = dt.astimezone(timezone.utc)
642+
return dt.isoformat()
641643

642644
def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]:
643645
if isinstance(value, datetime):
644-
dt = value
645-
if dt.tzinfo is None:
646-
dt = dt.replace(tzinfo=timezone.utc)
647-
else:
648-
dt = dt.astimezone(timezone.utc)
649-
return dt.isoformat(), True
646+
return self._to_utc_iso(value), True
647+
650648
if isinstance(value, str):
651649
timestamp = value.strip()
652650
if not timestamp:
653651
return "", False
654652
if timestamp.endswith("Z"):
655653
timestamp = f"{timestamp[:-1]}+00:00"
656654
try:
657-
dt = datetime.fromisoformat(timestamp)
658-
if dt.tzinfo is None:
659-
dt = dt.replace(tzinfo=timezone.utc)
660-
else:
661-
dt = dt.astimezone(timezone.utc)
662-
return dt.isoformat(), True
655+
return self._to_utc_iso(datetime.fromisoformat(timestamp)), True
663656
except ValueError:
664-
return value.strip(), False
657+
return timestamp, False
658+
665659
if value is None:
666660
return "", False
661+
667662
return str(value), False
668663

669664
async def _import_knowledge_bases(

tests/test_backup.py

Lines changed: 25 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,12 @@
1919
from astrbot.core.backup.importer import (
2020
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
2121
DatabaseClearError,
22+
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
2223
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
2324
AstrBotImporter,
2425
ImportResult,
26+
_load_platform_stats_invalid_count_warn_limit,
2527
_get_major_version,
26-
_resolve_platform_stats_invalid_count_warn_limit,
2728
)
2829
from astrbot.core.config.default import VERSION
2930
from astrbot.core.db.po import (
@@ -395,23 +396,32 @@ def test_normalize_platform_stats_timestamp_treats_naive_as_utc(self):
395396
assert is_valid_dt is True
396397
assert normalized_dt == "2025-12-13T21:00:00+00:00"
397398

398-
def test_resolve_platform_stats_invalid_count_warn_limit(self):
399-
"""测试非法/合法告警阈值配置解析"""
400-
value, valid = _resolve_platform_stats_invalid_count_warn_limit(None)
401-
assert valid is True
402-
assert value == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
399+
def test_load_platform_stats_invalid_count_warn_limit(self, monkeypatch):
400+
"""测试告警阈值环境变量解析"""
401+
monkeypatch.delenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, raising=False)
402+
assert (
403+
_load_platform_stats_invalid_count_warn_limit()
404+
== DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
405+
)
403406

404-
value, valid = _resolve_platform_stats_invalid_count_warn_limit("10")
405-
assert valid is True
406-
assert value == 10
407+
monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "10")
408+
assert _load_platform_stats_invalid_count_warn_limit() == 10
407409

408-
value, valid = _resolve_platform_stats_invalid_count_warn_limit("-1")
409-
assert valid is False
410-
assert value == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
410+
with patch("astrbot.core.backup.importer.logger.warning") as warning_mock:
411+
monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "-1")
412+
assert (
413+
_load_platform_stats_invalid_count_warn_limit()
414+
== DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
415+
)
416+
assert warning_mock.call_count == 1
411417

412-
value, valid = _resolve_platform_stats_invalid_count_warn_limit("bad")
413-
assert valid is False
414-
assert value == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
418+
with patch("astrbot.core.backup.importer.logger.warning") as warning_mock:
419+
monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "bad")
420+
assert (
421+
_load_platform_stats_invalid_count_warn_limit()
422+
== DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
423+
)
424+
assert warning_mock.call_count == 1
415425

416426
def test_merge_platform_stats_rows_warns_on_invalid_count(self):
417427
"""测试 platform_stats count 非法时会告警并按 0 处理(含上限)"""

0 commit comments

Comments
 (0)