@@ -51,6 +51,7 @@ class PercentileSemanticChunker:
5151 """
5252
5353 SENTENCE_BOUNDARY_RE = re .compile (r"(?<=[。!?;])|(?<=[.?!;])\s+" )
54+ SUPPORTED_SEMANTIC_UNITS = frozenset ({"sentence" , "paragraph" })
5455
5556 def __init__ (
5657 self ,
@@ -62,6 +63,7 @@ def __init__(
6263 overlap_tokens : int = 50 ,
6364 overlap_percentage : float | None = None ,
6465 min_distance_gate : float = 0.25 ,
66+ semantic_unit : str = "sentence" ,
6567 ):
6668 """
6769 初始化语义切片器及其阈值、长度约束与 overlap 配置。
@@ -75,10 +77,12 @@ def __init__(
7577 overlap_tokens: 相邻 Chunk 的固定 token overlap 上限。
7678 overlap_percentage: 可选的 overlap 百分比配置;当 `overlap_tokens` 为 0 时启用。
7779 min_distance_gate: 绝对最小语义距离阈值,用于避免过度切分。
80+ semantic_unit: 语义相似度计算粒度,支持 `sentence` 或 `paragraph`。
7881
7982 Returns:
8083 None.
8184 """
85+ semantic_unit = semantic_unit .strip ().lower ()
8286 if not 0 < percentile <= 100 :
8387 raise ValueError ("percentile must be in (0, 100]." )
8488 if min_chunk_tokens <= 0 :
@@ -93,6 +97,9 @@ def __init__(
9397 raise ValueError ("overlap_percentage must be in [0, 1)." )
9498 if min_distance_gate < 0 :
9599 raise ValueError ("min_distance_gate cannot be negative." )
100+ if semantic_unit not in self .SUPPORTED_SEMANTIC_UNITS :
101+ supported = ", " .join (sorted (self .SUPPORTED_SEMANTIC_UNITS ))
102+ raise ValueError (f"semantic_unit must be one of: { supported } ." )
96103
97104 self .embedder = embedder
98105 self .tokenizer = tokenizer
@@ -102,6 +109,7 @@ def __init__(
102109 self .overlap_tokens = overlap_tokens
103110 self .overlap_percentage = overlap_percentage
104111 self .min_distance_gate = min_distance_gate
112+ self .semantic_unit = semantic_unit
105113 self .last_stats = SemanticChunkingStats ()
106114
107115 def _resolve_overlap_tokens (self ) -> int :
@@ -243,6 +251,18 @@ def _split_by_sentences(self, text: str) -> List[str]:
243251 atoms .append (current )
244252 return atoms
245253
254+ def _split_paragraphs (self , text : str ) -> List [str ]:
255+ """
256+ 按 Markdown 段落边界切分文本,过滤空段落并保留段落内部换行。
257+
258+ Args:
259+ text: 待切分的大文本块。
260+
261+ Returns:
262+ List[str]: 非空段落列表。
263+ """
264+ return [paragraph .strip () for paragraph in re .split (r"\n{2,}" , text ) if paragraph .strip ()]
265+
246266 def _atomize_text (self , text : str ) -> List [str ]:
247267 """
248268 执行原子化拆解,优先按段落切分,必要时降级为按行或按句切分。
@@ -253,7 +273,10 @@ def _atomize_text(self, text: str) -> List[str]:
253273 Returns:
254274 List[str]: 原子化后的文本单元列表。
255275 """
256- paragraphs = [paragraph .strip () for paragraph in re .split (r"\n{2,}" , text ) if paragraph .strip ()]
276+ paragraphs = self ._split_paragraphs (text )
277+ if self .semantic_unit == "paragraph" :
278+ return paragraphs
279+
257280 atoms : List [str ] = []
258281
259282 for paragraph in paragraphs :
@@ -274,6 +297,25 @@ def _atomize_text(self, text: str) -> List[str]:
274297
275298 return [atom for atom in atoms if atom .strip ()]
276299
300+ def _append_limited_chunk (self , chunks : list [str ], text : str ) -> None :
301+ """
302+ 追加最终 Chunk 文本;若文本超长,则只做长度保底拆分。
303+
304+ Args:
305+ chunks: 待追加的最终 Chunk 列表。
306+ text: 当前待落盘的 Chunk 文本。
307+
308+ Returns:
309+ None.
310+ """
311+ cleaned = text .strip ()
312+ if not cleaned :
313+ return
314+ if self ._count_tokens (cleaned ) <= self .max_chunk_tokens :
315+ chunks .append (cleaned )
316+ return
317+ chunks .extend (self ._split_oversized_text (cleaned ))
318+
277319 @staticmethod
278320 def _compute_distances (embeddings : Sequence [Sequence [float ]]) -> list [float ]:
279321 """
@@ -391,7 +433,9 @@ def _group_atom_indices(
391433
392434 for idx in range (1 , len (atoms )):
393435 next_atom = atoms [idx ].strip ()
394- distance = distances [idx - 1 ] if distances is not None and idx - 1 < len (distances ) else None
436+ distance = (
437+ distances [idx - 1 ] if distances is not None and idx - 1 < len (distances ) else None
438+ )
395439
396440 semantic_breakpoint = (
397441 distance is not None
@@ -454,7 +498,9 @@ def _merge_atoms(
454498
455499 for idx in range (1 , len (atoms )):
456500 next_atom = atoms [idx ].strip ()
457- distance = distances [idx - 1 ] if distances is not None and idx - 1 < len (distances ) else None
501+ distance = (
502+ distances [idx - 1 ] if distances is not None and idx - 1 < len (distances ) else None
503+ )
458504
459505 semantic_breakpoint = (
460506 distance is not None
@@ -469,15 +515,15 @@ def _merge_atoms(
469515 if overflow_forced or (
470516 semantic_breakpoint and self ._count_tokens (current_text ) >= self .min_chunk_tokens
471517 ):
472- chunks . append ( current_text )
518+ self . _append_limited_chunk ( chunks , current_text )
473519 if semantic_breakpoint and not overflow_forced :
474520 breakpoints .append (idx - 1 )
475521 current_text = self ._build_next_chunk (current_text , next_atom )
476522 else :
477523 current_text = merged_candidate
478524
479525 if current_text :
480- chunks . append ( current_text )
526+ self . _append_limited_chunk ( chunks , current_text )
481527
482528 self .last_stats = SemanticChunkingStats (
483529 atom_count = len (atoms ),
@@ -509,7 +555,8 @@ async def group_texts(self, texts: Sequence[str]) -> list[list[int]]:
509555
510556 try :
511557 embedding_result = await self .embedder .embed (list (atoms ))
512- embeddings = [list (map (float , vector )) for vector in embedding_result .embeddings ]
558+ raw_embeddings = getattr (embedding_result , "embeddings" , None ) or []
559+ embeddings = [list (map (float , vector )) for vector in raw_embeddings ]
513560 if len (embeddings ) != len (atoms ) or any (not vector for vector in embeddings ):
514561 raise ValueError (
515562 f"Embedding shape mismatch: got { len (embeddings )} vectors, expected { len (atoms )} ."
@@ -547,7 +594,8 @@ async def split(self, text_block: str) -> List[str]:
547594
548595 try :
549596 embedding_result = await self .embedder .embed (list (atoms ))
550- embeddings = [list (map (float , vector )) for vector in embedding_result .embeddings ]
597+ raw_embeddings = getattr (embedding_result , "embeddings" , None ) or []
598+ embeddings = [list (map (float , vector )) for vector in raw_embeddings ]
551599 if len (embeddings ) != len (atoms ) or any (not vector for vector in embeddings ):
552600 raise ValueError (
553601 f"Embedding shape mismatch: got { len (embeddings )} vectors, expected { len (atoms )} ."
0 commit comments