Skip to content

Commit fc2af1f

Browse files
committed
fix: harden backup import for duplicate platform stats
- 修复 replace 模式下主库清空失败仍继续导入的问题。
- 导入前对 platform_stats 重复键做聚合(count 累加),并统一时间戳判重格式。
- 非法 count 按 0 处理并告警(限流),补充对应测试。
1 parent d561046 commit fc2af1f

2 files changed

Lines changed: 182 additions & 4 deletions

File tree

astrbot/core/backup/importer.py

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import shutil
1313
import zipfile
1414
from dataclasses import dataclass, field
15-
from datetime import datetime
15+
from datetime import datetime, timezone
1616
from pathlib import Path
1717
from typing import TYPE_CHECKING, Any
1818

@@ -452,7 +452,7 @@ async def _clear_main_db(self) -> None:
452452
await session.execute(delete(model_class))
453453
logger.debug(f"已清空表 {table_name}")
454454
except Exception as e:
455-
logger.warning(f"清空表 {table_name} 失败: {e}")
455+
raise RuntimeError(f"清空表 {table_name} 失败: {e}") from e
456456

457457
async def _clear_kb_data(self) -> None:
458458
"""清空知识库数据"""
@@ -494,9 +494,18 @@ async def _import_main_database(
494494
if not model_class:
495495
logger.warning(f"未知的表: {table_name}")
496496
continue
497+
normalized_rows = rows
498+
if table_name == "platform_stats":
499+
normalized_rows, duplicate_count = (
500+
self._merge_platform_stats_rows(rows)
501+
)
502+
if duplicate_count > 0:
503+
logger.warning(
504+
f"检测到 platform_stats 重复键 {duplicate_count} 条,已在导入前聚合"
505+
)
497506

498507
count = 0
499-
for row in rows:
508+
for row in normalized_rows:
500509
try:
501510
# 转换 datetime 字符串为 datetime 对象
502511
row = self._convert_datetime_fields(row, model_class)
@@ -511,6 +520,86 @@ async def _import_main_database(
511520

512521
return imported
513522

523+
def _merge_platform_stats_rows(
524+
self, rows: list[dict[str, Any]]
525+
) -> tuple[list[dict[str, Any]], int]:
526+
merged: dict[tuple[str, str, str], dict[str, Any]] = {}
527+
timestamp_cache: dict[str, str] = {}
528+
invalid_count_warned = 0
529+
invalid_count_warn_limit = 5
530+
duplicate_count = 0
531+
for row in rows:
532+
raw_timestamp = row.get("timestamp")
533+
if isinstance(raw_timestamp, str):
534+
normalized_timestamp = timestamp_cache.get(raw_timestamp)
535+
if normalized_timestamp is None:
536+
normalized_timestamp = self._normalize_platform_stats_timestamp(
537+
raw_timestamp
538+
)
539+
timestamp_cache[raw_timestamp] = normalized_timestamp
540+
else:
541+
normalized_timestamp = self._normalize_platform_stats_timestamp(
542+
raw_timestamp
543+
)
544+
key = (
545+
normalized_timestamp,
546+
str(row.get("platform_id")),
547+
str(row.get("platform_type")),
548+
)
549+
existing = merged.get(key)
550+
if existing is None:
551+
merged[key] = dict(row)
552+
continue
553+
duplicate_count += 1
554+
existing_raw_count = existing.get("count", 0)
555+
try:
556+
existing_count = int(existing_raw_count)
557+
except (TypeError, ValueError):
558+
existing_count = 0
559+
if invalid_count_warned < invalid_count_warn_limit:
560+
logger.warning(
561+
"platform_stats count 非法,已按 0 处理: "
562+
f"value={existing_raw_count!r}, key={key}"
563+
)
564+
invalid_count_warned += 1
565+
566+
incoming_raw_count = row.get("count", 0)
567+
try:
568+
incoming_count = int(incoming_raw_count)
569+
except (TypeError, ValueError):
570+
incoming_count = 0
571+
if invalid_count_warned < invalid_count_warn_limit:
572+
logger.warning(
573+
"platform_stats count 非法,已按 0 处理: "
574+
f"value={incoming_raw_count!r}, key={key}"
575+
)
576+
invalid_count_warned += 1
577+
existing["count"] = existing_count + incoming_count
578+
return list(merged.values()), duplicate_count
579+
580+
def _normalize_platform_stats_timestamp(self, value: Any) -> str:
581+
if isinstance(value, datetime):
582+
dt = value
583+
if dt.tzinfo is not None:
584+
dt = dt.astimezone(timezone.utc)
585+
return dt.isoformat()
586+
if isinstance(value, str):
587+
timestamp = value.strip()
588+
if not timestamp:
589+
return ""
590+
if timestamp.endswith("Z"):
591+
timestamp = f"{timestamp[:-1]}+00:00"
592+
try:
593+
dt = datetime.fromisoformat(timestamp)
594+
if dt.tzinfo is not None:
595+
dt = dt.astimezone(timezone.utc)
596+
return dt.isoformat()
597+
except ValueError:
598+
return value.strip()
599+
if value is None:
600+
return ""
601+
return str(value)
602+
514603
async def _import_knowledge_bases(
515604
self,
516605
zf: zipfile.ZipFile,

tests/test_backup.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
import zipfile
77
from datetime import datetime
8-
from unittest.mock import AsyncMock, MagicMock
8+
from unittest.mock import AsyncMock, MagicMock, patch
99

1010
import pytest
1111

@@ -308,6 +308,69 @@ def test_convert_datetime_fields(self):
308308
assert isinstance(result["created_at"], datetime)
309309
assert isinstance(result["updated_at"], datetime)
310310

311+
def test_merge_platform_stats_rows(self):
    """Duplicate platform_stats keys are aggregated before import."""

    def make_row(row_id, timestamp, platform_id, count):
        # Every row in this test shares the same platform_type.
        return {
            "id": row_id,
            "timestamp": timestamp,
            "platform_id": platform_id,
            "platform_type": "unknown",
            "count": count,
        }

    rows = [
        # These two rows name the same instant with different spellings.
        make_row(1, "2025-12-13T20:00:00Z", "webchat", 14),
        make_row(80, "2025-12-13T20:00:00+00:00", "webchat", 3),
        make_row(2, "2025-12-13T21:00:00", "aiocqhttp", 1),
    ]
    importer = AstrBotImporter(main_db=MagicMock())

    merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows)

    assert duplicate_count == 1
    assert len(merged_rows) == 2
    # The first-seen row is kept and carries the summed count.
    kept = merged_rows[0]
    assert kept["timestamp"] == "2025-12-13T20:00:00Z"
    assert kept["platform_id"] == "webchat"
    assert kept["platform_type"] == "unknown"
    assert kept["count"] == 17
347+
348+
def test_merge_platform_stats_rows_warns_on_invalid_count(self):
349+
"""测试 platform_stats count 非法时会告警并按 0 处理"""
350+
importer = AstrBotImporter(main_db=MagicMock())
351+
rows = [
352+
{
353+
"timestamp": "2025-12-13T20:00:00+00:00",
354+
"platform_id": "webchat",
355+
"platform_type": "unknown",
356+
"count": 5,
357+
},
358+
{
359+
"timestamp": "2025-12-13T20:00:00Z",
360+
"platform_id": "webchat",
361+
"platform_type": "unknown",
362+
"count": "bad-count",
363+
},
364+
]
365+
366+
with patch("astrbot.core.backup.importer.logger.warning") as warning_mock:
367+
merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows)
368+
369+
assert duplicate_count == 1
370+
assert len(merged_rows) == 1
371+
assert merged_rows[0]["count"] == 5
372+
assert warning_mock.called
373+
311374
@pytest.mark.asyncio
312375
async def test_import_file_not_exists(self, mock_main_db, tmp_path):
313376
"""测试导入不存在的文件"""
@@ -365,6 +428,32 @@ async def test_import_major_version_mismatch(self, mock_main_db, tmp_path):
365428
assert result.success is False
366429
assert any("主版本不兼容" in err for err in result.errors)
367430

431+
@pytest.mark.asyncio
async def test_import_replace_fails_when_clear_main_db_fails(
    self, mock_main_db, tmp_path
):
    """In replace mode, failing to clear the main DB aborts the import."""
    archive_path = tmp_path / "valid_backup.zip"
    archive_entries = {
        "manifest.json": {
            "version": "1.1",
            "astrbot_version": VERSION,
            "tables": {"platform_stats": 0},
        },
        "databases/main_db.json": {"platform_stats": []},
    }
    with zipfile.ZipFile(archive_path, "w") as zf:
        for entry_name, payload in archive_entries.items():
            zf.writestr(entry_name, json.dumps(payload))

    importer = AstrBotImporter(main_db=mock_main_db)
    # Make the clearing step fail the way a real DB error would surface.
    clear_error = RuntimeError("清空表 platform_stats 失败: db locked")
    importer._clear_main_db = AsyncMock(side_effect=clear_error)

    result = await importer.import_all(str(archive_path), mode="replace")

    assert result.success is False
    assert any("清空表 platform_stats 失败" in err for err in result.errors)
456+
368457

369458
class TestSecureFilename:
370459
"""安全文件名函数测试"""

0 commit comments

Comments
 (0)