1111import os
1212import shutil
1313import zipfile
14+ from collections .abc import Callable
1415from dataclasses import dataclass , field
1516from datetime import datetime , timezone
1617from pathlib import Path
@@ -67,34 +68,23 @@ def _get_major_version(version_str: str) -> str:
6768)
6869
6970
70- def _resolve_platform_stats_invalid_count_warn_limit (
71- raw_value : str | None ,
72- ) -> tuple [int , bool ]:
73- """Resolve warn limit value and return whether the input was valid."""
71+ def _load_platform_stats_invalid_count_warn_limit () -> int :
72+ raw_value = os .getenv (PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV )
7473 if raw_value is None :
75- return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT , True
74+ return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
7675 try :
7776 value = int (raw_value )
77+ if value < 0 :
78+ raise ValueError ("negative" )
79+ return value
7880 except (TypeError , ValueError ):
79- return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT , False
80- if value < 0 :
81- return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT , False
82- return value , True
83-
84-
85- def _load_platform_stats_invalid_count_warn_limit () -> int :
86- raw_value = os .getenv (PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV )
87- resolved_value , is_valid = _resolve_platform_stats_invalid_count_warn_limit (
88- raw_value
89- )
90- if raw_value is not None and not is_valid :
9181 logger .warning (
9282 "Invalid env %s=%r, fallback to default %d" ,
9383 PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV ,
9484 raw_value ,
9585 DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT ,
9686 )
97- return resolved_value
87+ return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
9888
9989
10090# Warning limit per _merge_platform_stats_rows invocation; configurable by env.
@@ -210,6 +200,11 @@ def __init__(
210200 self .kb_manager = kb_manager
211201 self .config_path = config_path
212202 self .kb_root_dir = kb_root_dir
203+ self ._main_table_preprocessors : dict [
204+ str , Callable [[list [dict [str , Any ]]], list [dict [str , Any ]]]
205+ ] = {
206+ "platform_stats" : self ._merge_platform_stats_rows ,
207+ }
213208
214209 def pre_check (self , zip_path : str ) -> ImportPreCheckResult :
215210 """预检查备份文件
@@ -543,14 +538,7 @@ async def _import_main_database(
543538 if not model_class :
544539 logger .warning (f"未知的表: { table_name } " )
545540 continue
546- normalized_rows = rows
547- if table_name == "platform_stats" :
548- normalized_rows = self ._merge_platform_stats_rows (rows )
549- duplicate_count = len (rows ) - len (normalized_rows )
550- if duplicate_count > 0 :
551- logger .warning (
552- f"检测到 platform_stats 重复键 { duplicate_count } 条,已在导入前聚合"
553- )
541+ normalized_rows = self ._preprocess_main_table_rows (table_name , rows )
554542
555543 count = 0
556544 for row in normalized_rows :
@@ -568,6 +556,20 @@ async def _import_main_database(
568556
569557 return imported
570558
559+ def _preprocess_main_table_rows (
560+ self , table_name : str , rows : list [dict [str , Any ]]
561+ ) -> list [dict [str , Any ]]:
562+ preprocessor = self ._main_table_preprocessors .get (table_name )
563+ if preprocessor is None :
564+ return rows
565+ normalized_rows = preprocessor (rows )
566+ duplicate_count = len (rows ) - len (normalized_rows )
567+ if duplicate_count > 0 :
568+ logger .warning (
569+ f"检测到 { table_name } 重复键 { duplicate_count } 条,已在导入前聚合"
570+ )
571+ return normalized_rows
572+
571573 def _merge_platform_stats_rows (
572574 self , rows : list [dict [str , Any ]]
573575 ) -> list [dict [str , Any ]]:
@@ -579,8 +581,23 @@ def _merge_platform_stats_rows(
579581 - Invalid count warnings are rate-limited per function invocation.
580582 """
581583 merged : dict [tuple [str , str , str ], dict [str , Any ]] = {}
582- result : list [dict [str , Any ]] = []
584+ non_mergeable : list [dict [str , Any ]] = []
583585 invalid_count_warned = 0
586+
587+ def parse_count (raw_count : Any , key : tuple [str , str , str ]) -> int :
588+ nonlocal invalid_count_warned
589+ try :
590+ return int (raw_count )
591+ except (TypeError , ValueError ):
592+ if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT :
593+ logger .warning (
594+ "platform_stats count 非法,已按 0 处理: value=%r, key=%s" ,
595+ raw_count ,
596+ key ,
597+ )
598+ invalid_count_warned += 1
599+ return 0
600+
584601 for row in rows :
585602 normalized_row = dict (row )
586603 normalized_timestamp , is_timestamp_valid = (
@@ -597,73 +614,51 @@ def _merge_platform_stats_rows(
597614 repr (platform_id ),
598615 repr (platform_type ),
599616 )
600- count , invalid_count_warned = self ._parse_platform_stats_count (
601- normalized_row .get ("count" , 0 ), invalid_count_warned , key_for_log
602- )
617+ count = parse_count (normalized_row .get ("count" , 0 ), key_for_log )
603618 normalized_row ["count" ] = count
604619
605- # Invalid timestamps should never be merged.
606620 if not is_timestamp_valid :
607- result .append (normalized_row )
621+ non_mergeable .append (normalized_row )
608622 continue
609623
610624 if not isinstance (platform_id , str ) or not isinstance (platform_type , str ):
611- result .append (normalized_row )
625+ non_mergeable .append (normalized_row )
612626 continue
613627
614628 key = (normalized_timestamp , platform_id , platform_type )
615629 existing = merged .get (key )
616630 if existing is None :
617631 merged [key ] = normalized_row
618- result .append (normalized_row )
619632 else :
620633 existing ["count" ] += count
621634
622- return result
635+ return [ * non_mergeable , * merged . values ()]
623636
624- def _parse_platform_stats_count (
625- self ,
626- raw_count : Any ,
627- invalid_count_warned : int ,
628- key : tuple [str , str , str ],
629- ) -> tuple [int , int ]:
630- """Safe int parse with per-call rate-limited warning."""
631- try :
632- return int (raw_count ), invalid_count_warned
633- except (TypeError , ValueError ):
634- if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT :
635- logger .warning (
636- "platform_stats count 非法,已按 0 处理: "
637- f"value={ raw_count !r} , key={ key } "
638- )
639- invalid_count_warned += 1
640- return 0 , invalid_count_warned
637+ def _to_utc_iso (self , dt : datetime ) -> str :
638+ if dt .tzinfo is None :
639+ dt = dt .replace (tzinfo = timezone .utc )
640+ else :
641+ dt = dt .astimezone (timezone .utc )
642+ return dt .isoformat ()
641643
642644 def _normalize_platform_stats_timestamp (self , value : Any ) -> tuple [str , bool ]:
643645 if isinstance (value , datetime ):
644- dt = value
645- if dt .tzinfo is None :
646- dt = dt .replace (tzinfo = timezone .utc )
647- else :
648- dt = dt .astimezone (timezone .utc )
649- return dt .isoformat (), True
646+ return self ._to_utc_iso (value ), True
647+
650648 if isinstance (value , str ):
651649 timestamp = value .strip ()
652650 if not timestamp :
653651 return "" , False
654652 if timestamp .endswith ("Z" ):
655653 timestamp = f"{ timestamp [:- 1 ]} +00:00"
656654 try :
657- dt = datetime .fromisoformat (timestamp )
658- if dt .tzinfo is None :
659- dt = dt .replace (tzinfo = timezone .utc )
660- else :
661- dt = dt .astimezone (timezone .utc )
662- return dt .isoformat (), True
655+ return self ._to_utc_iso (datetime .fromisoformat (timestamp )), True
663656 except ValueError :
664- return value .strip (), False
657+ return timestamp , False
658+
665659 if value is None :
666660 return "" , False
661+
667662 return str (value ), False
668663
669664 async def _import_knowledge_bases (
0 commit comments