Skip to content

Commit 0bffcfc

Browse files
feat: add splitter classes
1 parent a430183 commit 0bffcfc

10 files changed

Lines changed: 429 additions & 14 deletions

File tree

graphgen/bases/base_splitter.py

Lines changed: 102 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import copy
2+
import re
23
from abc import ABC, abstractmethod
34
from dataclasses import dataclass
4-
from typing import Any, Callable, Dict, List, Optional
5+
from typing import Callable, Iterable, List, Literal, Optional, Union
56

67
from graphgen.bases.datatypes import Chunk
8+
from graphgen.utils import logger
79

810

911
@dataclass
class BaseSplitter(ABC):
    """
    Abstract base class for splitting long text into smaller chunks.
    """

    chunk_size: int = 1024  # target maximum chunk length, per length_function
    chunk_overlap: int = 100  # amount of trailing text carried into the next chunk
    length_function: Callable[[str], int] = len  # how chunk length is measured
    keep_separator: bool = False  # keep separators attached to the splits
    add_start_index: bool = False  # record each chunk's start offset in metadata
    strip_whitespace: bool = True  # strip joined chunks before emitting them

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """
        Split the input text into smaller chunks.

        :param text: The input text to be split.
        :return: A list of text chunks.
        """

    def create_chunks(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Chunk]:
        """Create chunks from a list of texts."""
        meta_list = metadatas or [{}] * len(texts)
        result: List[Chunk] = []
        for i, text in enumerate(texts):
            search_from = 0
            prev_len = 0
            for piece in self.split_text(text):
                meta = copy.deepcopy(meta_list[i])
                if self.add_start_index:
                    # Resume the search near the previous hit so repeated
                    # substrings resolve to the correct occurrence.
                    lower_bound = search_from + prev_len - self.chunk_overlap
                    search_from = text.find(piece, max(0, lower_bound))
                    meta["start_index"] = search_from
                prev_len = len(piece)
                result.append(Chunk(content=piece, metadata=meta))
        return result

    def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]:
        """Join splits with ``separator``; return None when the result is empty."""
        joined = separator.join(chunks)
        if self.strip_whitespace:
            joined = joined.strip()
        return joined if joined != "" else None

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Greedily pack small splits into chunks of at most ``chunk_size``,
        keeping roughly ``chunk_overlap`` worth of trailing splits between
        consecutive chunks."""
        sep_len = self.length_function(separator)

        merged: List[str] = []
        window: List[str] = []  # splits accumulated for the current chunk
        running = 0  # length of `window` once joined with `separator`
        for piece in splits:
            piece_len = self.length_function(piece)
            projected = running + piece_len + (sep_len if window else 0)
            if projected > self.chunk_size:
                if running > self.chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s",
                        running,
                        self.chunk_size,
                    )
                if window:
                    joined = self._join_chunks(window, separator)
                    if joined is not None:
                        merged.append(joined)
                    # Pop from the front while:
                    # - we hold more than the allowed overlap, or
                    # - the next piece still would not fit and the window is non-empty.
                    while running > self.chunk_overlap or (
                        running + piece_len + (sep_len if window else 0)
                        > self.chunk_size
                        and running > 0
                    ):
                        running -= self.length_function(window[0]) + (
                            sep_len if len(window) > 1 else 0
                        )
                        window = window[1:]
            window.append(piece)
            running += piece_len + (sep_len if len(window) > 1 else 0)
        joined = self._join_chunks(window, separator)
        if joined is not None:
            merged.append(joined)
        return merged

    @staticmethod
    def _split_text_with_regex(
        text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
    ) -> List[str]:
        """Split ``text`` on a regex ``separator``; when ``keep_separator`` is
        truthy the delimiter is re-attached to the start (default) or end of
        each piece."""
        if not separator:
            # Empty separator: degrade to one piece per character.
            pieces = list(text)
        elif not keep_separator:
            pieces = re.split(separator, text)
        else:
            # The capturing group keeps the delimiters in the split result.
            raw = re.split(f"({separator})", text)
            if keep_separator == "end":
                pieces = [raw[i] + raw[i + 1] for i in range(0, len(raw) - 1, 2)]
            else:
                pieces = [raw[i] + raw[i + 1] for i in range(1, len(raw), 2)]
            if len(raw) % 2 == 0:
                pieces += raw[-1:]
            if keep_separator == "end":
                pieces = pieces + [raw[-1]]
            else:
                pieces = [raw[0]] + pieces
        return [p for p in pieces if p != ""]

graphgen/models/splitter/__init__.py

Whitespace-only changes.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import re
2+
from typing import Any, List
3+
4+
from graphgen.bases.base_splitter import BaseSplitter
5+
6+
7+
class CharacterSplitter(BaseSplitter):
    """Splitting text that looks at characters."""

    def __init__(
        self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._is_separator_regex = is_separator_regex

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # Naively split the large input on the configured separator, then
        # let the base class merge the pieces back into sized chunks.
        if self._is_separator_regex:
            pattern = self._separator
        else:
            pattern = re.escape(self._separator)
        pieces = self._split_text_with_regex(text, pattern, self.keep_separator)
        join_sep = "" if self.keep_separator else self._separator
        return self._merge_splits(pieces, join_sep)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Any
2+
3+
from graphgen.models.splitter.recursive_character_splitter import (
4+
RecursiveCharacterSplitter,
5+
)
6+
7+
8+
class MarkdownTextRefSplitter(RecursiveCharacterSplitter):
    """Attempts to split the text along Markdown-formatted headings."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a MarkdownTextRefSplitter."""
        # Separators ordered from most to least structural; the recursive
        # splitter tries each one in turn until it matches the text.
        separators = [
            # ATX headings, levels 1-6. The alternative "setext" syntax
            # (a heading underlined with --- or ===) is not handled here.
            "\n#{1,6} ",
            # End of a fenced code block
            "```\n",
            # Horizontal rules made of ***, ---, or ___
            "\n\\*\\*\\*+\n",
            "\n---+\n",
            "\n___+\n",
            # Paragraph break, line break, word break, character fallback
            "\n\n",
            "\n",
            " ",
            "",
        ]
        super().__init__(separators=separators, **kwargs)
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import re
2+
from typing import Any, List, Optional
3+
4+
from graphgen.bases.base_splitter import BaseSplitter
5+
6+
7+
class RecursiveCharacterSplitter(BaseSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one that works.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        out: List[str] = []
        # Pick the first separator that occurs in the text; keep the rest
        # for recursive splitting of any piece that is still too large.
        chosen = separators[-1]
        remaining: List[str] = []
        for idx, candidate in enumerate(separators):
            pattern = candidate if self._is_separator_regex else re.escape(candidate)
            if candidate == "":
                chosen = candidate
                break
            if re.search(pattern, text):
                chosen = candidate
                remaining = separators[idx + 1 :]
                break

        pattern = chosen if self._is_separator_regex else re.escape(chosen)
        pieces = self._split_text_with_regex(text, pattern, self.keep_separator)

        # Merge small pieces; recurse into pieces that exceed chunk_size.
        pending: List[str] = []
        join_sep = "" if self.keep_separator else chosen
        for piece in pieces:
            if self.length_function(piece) < self.chunk_size:
                pending.append(piece)
                continue
            if pending:
                out.extend(self._merge_splits(pending, join_sep))
                pending = []
            if remaining:
                out.extend(self._split_text(piece, remaining))
            else:
                # No finer separator left; emit the oversized piece as-is.
                out.append(piece)
        if pending:
            out.extend(self._merge_splits(pending, join_sep))
        return out

    def split_text(self, text: str) -> List[str]:
        return self._split_text(text, self._separators)
67+
68+
69+
class ChineseRecursiveTextSplitter(RecursiveCharacterSplitter):
    """Recursive splitter tuned for Chinese text: splits on sentence-ending
    punctuation and keeps delimiters attached to the end of each piece."""

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(keep_separator=keep_separator, **kwargs)
        # Fall back from paragraph/line breaks to CJK and ASCII sentence,
        # clause, and comma punctuation.
        self._separators = separators or [
            "\n\n",
            "\n",
            "。|!|?",
            r"\.\s|\!\s|\?\s",
            r";|;\s",
            r",|,\s",
        ]
        self._is_separator_regex = is_separator_regex

    def _split_text_with_regex_from_end(
        self, text: str, separator: str, keep_separator: bool
    ) -> List[str]:
        """Regex split that re-attaches each delimiter to the END of the
        preceding piece (the base-class helper attaches to the start)."""
        if not separator:
            # Empty separator: degrade to one piece per character.
            pieces = list(text)
        elif keep_separator:
            # The capturing group keeps the delimiters in the split result;
            # pair every text part with the delimiter that follows it.
            raw = re.split(f"({separator})", text)
            pieces = ["".join(pair) for pair in zip(raw[0::2], raw[1::2])]
            if len(raw) % 2 == 1:
                # Odd length means a trailing part with no delimiter after it.
                pieces += raw[-1:]
        else:
            pieces = re.split(separator, text)
        return [p for p in pieces if p != ""]

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return cleaned chunks."""
        collected: List[str] = []
        # Pick the first separator that occurs in the text; keep the rest
        # for recursive splitting of any piece that is still too large.
        chosen = separators[-1]
        remaining: List[str] = []
        for idx, candidate in enumerate(separators):
            pattern = candidate if self._is_separator_regex else re.escape(candidate)
            if candidate == "":
                chosen = candidate
                break
            if re.search(pattern, text):
                chosen = candidate
                remaining = separators[idx + 1 :]
                break

        pattern = chosen if self._is_separator_regex else re.escape(chosen)
        pieces = self._split_text_with_regex_from_end(
            text, pattern, self.keep_separator
        )

        # Merge small pieces; recurse into pieces that exceed chunk_size.
        pending: List[str] = []
        join_sep = "" if self.keep_separator else chosen
        for piece in pieces:
            if self.length_function(piece) < self.chunk_size:
                pending.append(piece)
                continue
            if pending:
                collected.extend(self._merge_splits(pending, join_sep))
                pending = []
            if remaining:
                collected.extend(self._split_text(piece, remaining))
            else:
                collected.append(piece)
        if pending:
            collected.extend(self._merge_splits(pending, join_sep))
        # Collapse runs of blank lines and drop whitespace-only chunks.
        return [
            re.sub(r"\n{2,}", "\n", chunk.strip())
            for chunk in collected
            if chunk.strip() != ""
        ]

tests/__init__.py

Whitespace-only changes.

tests/integration_tests/__init__.py

Whitespace-only changes.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
3+
from graphgen.models.splitter.character_splitter import CharacterSplitter
4+
5+
6+
@pytest.mark.parametrize(
    "text,chunk_size,chunk_overlap,expected",
    [
        (
            "This is a test.\n\nThis is only a test.\n\nIn the event of an actual emergency...",
            25,
            5,
            [
                "This is a test.",
                "This is only a test.",
                "In the event of an actual emergency...",
            ],
        ),
    ],
)
def test_character_splitter(text, chunk_size, chunk_overlap, expected):
    """CharacterSplitter cuts on the blank-line separator and drops it."""
    splitter = CharacterSplitter(
        separator="\n\n",
        is_separator_regex=False,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        keep_separator=False,
    )
    assert splitter.split_text(text) == expected

0 commit comments

Comments
 (0)