1111import os
1212import shutil
1313import zipfile
14- from collections .abc import Callable
1514from dataclasses import dataclass , field
1615from datetime import datetime , timezone
1716from pathlib import Path
@@ -62,35 +61,7 @@ def _get_major_version(version_str: str) -> str:
6261
6362CMD_CONFIG_FILE_PATH = os .path .join (get_astrbot_data_path (), "cmd_config.json" )
6463KB_PATH = get_astrbot_knowledge_base_path ()
65- DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
66- PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = (
67- "ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT"
68- )
69-
70-
71- def _load_platform_stats_invalid_count_warn_limit () -> int :
72- raw_value = os .getenv (PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV )
73- if raw_value is None :
74- return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
75- try :
76- value = int (raw_value )
77- if value < 0 :
78- raise ValueError ("negative" )
79- return value
80- except (TypeError , ValueError ):
81- logger .warning (
82- "Invalid env %s=%r, fallback to default %d" ,
83- PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV ,
84- raw_value ,
85- DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT ,
86- )
87- return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
88-
89-
90- # Warning limit per _merge_platform_stats_rows invocation; configurable by env.
91- PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = (
92- _load_platform_stats_invalid_count_warn_limit ()
93- )
64+ PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
9465
9566
9667@dataclass
@@ -200,11 +171,6 @@ def __init__(
200171 self .kb_manager = kb_manager
201172 self .config_path = config_path
202173 self .kb_root_dir = kb_root_dir
203- self ._main_table_preprocessors : dict [
204- str , Callable [[list [dict [str , Any ]]], list [dict [str , Any ]]]
205- ] = {
206- "platform_stats" : self ._merge_platform_stats_rows ,
207- }
208174
209175 def pre_check (self , zip_path : str ) -> ImportPreCheckResult :
210176 """预检查备份文件
@@ -559,16 +525,15 @@ async def _import_main_database(
559525 def _preprocess_main_table_rows (
560526 self , table_name : str , rows : list [dict [str , Any ]]
561527 ) -> list [dict [str , Any ]]:
562- preprocessor = self ._main_table_preprocessors .get (table_name )
563- if preprocessor is None :
564- return rows
565- normalized_rows = preprocessor (rows )
566- duplicate_count = len (rows ) - len (normalized_rows )
567- if duplicate_count > 0 :
568- logger .warning (
569- f"检测到 { table_name } 重复键 { duplicate_count } 条,已在导入前聚合"
570- )
571- return normalized_rows
528+ if table_name == "platform_stats" :
529+ normalized_rows = self ._merge_platform_stats_rows (rows )
530+ duplicate_count = len (rows ) - len (normalized_rows )
531+ if duplicate_count > 0 :
532+ logger .warning (
533+ f"检测到 { table_name } 重复键 { duplicate_count } 条,已在导入前聚合"
534+ )
535+ return normalized_rows
536+ return rows
572537
573538 def _merge_platform_stats_rows (
574539 self , rows : list [dict [str , Any ]]
@@ -584,28 +549,10 @@ def _merge_platform_stats_rows(
584549 non_mergeable : list [dict [str , Any ]] = []
585550 invalid_count_warned = 0
586551
587- def parse_count (raw_count : Any , key : tuple [str , str , str ]) -> int :
588- nonlocal invalid_count_warned
589- try :
590- return int (raw_count )
591- except (TypeError , ValueError ):
592- if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT :
593- logger .warning (
594- "platform_stats count 非法,已按 0 处理: value=%r, key=%s" ,
595- raw_count ,
596- key ,
597- )
598- invalid_count_warned += 1
599- return 0
600-
601552 for row in rows :
602- normalized_row = dict (row )
603- normalized_timestamp , is_timestamp_valid = (
604- self ._normalize_platform_stats_timestamp (
605- normalized_row .get ("timestamp" )
606- )
553+ normalized_row , normalized_timestamp , is_timestamp_valid = (
554+ self ._normalize_platform_stats_row (row )
607555 )
608- normalized_row ["timestamp" ] = normalized_timestamp
609556
610557 platform_id = normalized_row .get ("platform_id" )
611558 platform_type = normalized_row .get ("platform_type" )
@@ -614,7 +561,11 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int:
614561 repr (platform_id ),
615562 repr (platform_type ),
616563 )
617- count = parse_count (normalized_row .get ("count" , 0 ), key_for_log )
564+ count , invalid_count_warned = self ._parse_platform_stats_count (
565+ normalized_row .get ("count" , 0 ),
566+ key_for_log ,
567+ invalid_count_warned ,
568+ )
618569 normalized_row ["count" ] = count
619570
620571 if not is_timestamp_valid :
@@ -634,32 +585,70 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int:
634585
635586 return [* non_mergeable , * merged .values ()]
636587
588+ def _parse_platform_stats_count (
589+ self ,
590+ raw_count : Any ,
591+ key_for_log : tuple [str , str , str ],
592+ warned_count : int ,
593+ ) -> tuple [int , int ]:
594+ if warned_count >= PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT :
595+ try :
596+ return int (raw_count ), warned_count
597+ except (TypeError , ValueError ):
598+ return 0 , warned_count
599+ try :
600+ return int (raw_count ), warned_count
601+ except (TypeError , ValueError ):
602+ logger .warning (
603+ "platform_stats count 非法,已按 0 处理: value=%r, key=%s" ,
604+ raw_count ,
605+ key_for_log ,
606+ )
607+ return 0 , warned_count + 1
608+
def _normalize_platform_stats_row(
    self, row: dict[str, Any]
) -> tuple[dict[str, Any], str, bool]:
    """Return a copy of ``row`` with its ``timestamp`` field normalized.

    The input row is never mutated. When the timestamp can be normalized by
    ``_normalize_platform_stats_timestamp`` the copy carries that value;
    otherwise it carries a string fallback derived from the raw value
    (stripped string, ``""`` for ``None``, ``str(value)`` for anything else).

    Args:
        row: A single platform_stats row.

    Returns:
        ``(row_copy, timestamp, is_valid)`` where ``timestamp`` equals
        ``row_copy["timestamp"]`` and ``is_valid`` tells whether the
        timestamp could be normalized.
    """
    copied = dict(row)
    original_ts = copied.get("timestamp")
    parsed = self._normalize_platform_stats_timestamp(original_ts)

    # Success path first: the timestamp was recognised.
    if parsed is not None:
        copied["timestamp"] = parsed
        return copied, parsed, True

    # Unparseable timestamp: keep a string form of whatever was there.
    if original_ts is None:
        fallback = ""
    elif isinstance(original_ts, str):
        fallback = original_ts.strip()
    else:
        fallback = str(original_ts)
    copied["timestamp"] = fallback
    return copied, fallback, False
625+
637626 def _to_utc_iso (self , dt : datetime ) -> str :
638627 if dt .tzinfo is None :
639628 dt = dt .replace (tzinfo = timezone .utc )
640629 else :
641630 dt = dt .astimezone (timezone .utc )
642631 return dt .isoformat ()
643632
644- def _normalize_platform_stats_timestamp (self , value : Any ) -> tuple [ str , bool ] :
633+ def _normalize_platform_stats_timestamp (self , value : Any ) -> str | None :
645634 if isinstance (value , datetime ):
646- return self ._to_utc_iso (value ), True
635+ return self ._to_utc_iso (value )
647636
648637 if isinstance (value , str ):
649638 timestamp = value .strip ()
650639 if not timestamp :
651- return "" , False
640+ return None
652641 if timestamp .endswith ("Z" ):
653642 timestamp = f"{ timestamp [:- 1 ]} +00:00"
654643 try :
655- return self ._to_utc_iso (datetime .fromisoformat (timestamp )), True
644+ return self ._to_utc_iso (datetime .fromisoformat (timestamp ))
656645 except ValueError :
657- return timestamp , False
646+ return None
658647
659648 if value is None :
660- return "" , False
649+ return None
661650
662- return str ( value ), False
651+ return None
663652
664653 async def _import_knowledge_bases (
665654 self ,
0 commit comments