-
Notifications
You must be signed in to change notification settings - Fork 83
Expand file tree
/
Copy pathbase_splitter.py
More file actions
135 lines (124 loc) · 5.01 KB
/
base_splitter.py
File metadata and controls
135 lines (124 loc) · 5.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import copy
import re
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass
from typing import Callable, Iterable, List, Literal, Optional, Union

from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger
@dataclass
class BaseSplitter(ABC):
    """
    Abstract base class for splitting text into smaller chunks.

    Subclasses implement :meth:`split_text`. This base class supplies helpers
    to split on a regex separator (:meth:`_split_text_with_regex`), to merge
    small pieces into size-bounded, overlapping chunks (:meth:`_merge_splits`),
    and to wrap the results in ``Chunk`` objects (:meth:`create_chunks`).

    Attributes (dataclass fields):
        chunk_size: Maximum size of a merged chunk, measured with
            ``length_function``. A single piece that is itself longer than
            this is still emitted (a warning is logged).
        chunk_overlap: Upper bound on the trailing content that
            ``_merge_splits`` carries over from one chunk into the next.
        length_function: Measures the size of a string (``len`` by default).
        keep_separator: Separator-retention mode. The base class does not read
            it itself; presumably subclasses pass it to
            ``_split_text_with_regex``, which accepts ``False``,
            ``True``/``"start"`` or ``"end"``.
        add_start_index: If true, ``create_chunks`` stores each chunk's
            character offset within its source text as
            ``metadata["start_index"]``.
        strip_whitespace: If true, merged chunks are stripped of leading and
            trailing whitespace before being returned.
    """

    chunk_size: int = 1024
    chunk_overlap: int = 100
    length_function: Callable[[str], int] = len
    # Annotated to match the values _split_text_with_regex accepts.
    keep_separator: Union[bool, Literal["start", "end"]] = False
    add_start_index: bool = False
    strip_whitespace: bool = True

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """
        Split the input text into smaller chunks.
        :param text: The input text to be split.
        :return: A list of text chunks.
        """

    def create_chunks(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Chunk]:
        """Create chunks from a list of texts.

        :param texts: Source texts; each one is split with :meth:`split_text`.
        :param metadatas: Optional metadata dicts aligned with ``texts`` by
            index (must have at least ``len(texts)`` entries, otherwise
            ``IndexError`` is raised). Each chunk receives a deep copy of its
            text's dict, so chunks never share mutable metadata.
        :return: One ``Chunk`` per piece produced by :meth:`split_text`, in
            input order.
        """
        _metadatas = metadatas or [{}] * len(texts)
        chunks = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self.add_start_index:
                    # Consecutive chunks may overlap, so begin the search just
                    # before where the previous chunk ended (backing off by
                    # chunk_overlap) instead of from the start of the text.
                    offset = index + previous_chunk_len - self.chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    # NOTE(review): str.find returns -1 when the chunk is not a
                    # verbatim substring of ``text`` (e.g. altered by joining or
                    # stripping); that -1 is stored as-is.
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_chunk = Chunk(content=chunk, metadata=metadata)
                chunks.append(new_chunk)
        return chunks

    def _join_chunks(self, chunks: Iterable[str], separator: str) -> Optional[str]:
        """Join ``chunks`` with ``separator``.

        The result is stripped when ``strip_whitespace`` is set.

        :return: The joined text, or ``None`` if it is empty.
        """
        text = separator.join(chunks)
        if self.strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Greedily merge small pieces into chunks of at most ``chunk_size``.

        Pieces are accumulated (joined with ``separator``) until adding the
        next one would exceed ``chunk_size``. The accumulated chunk is then
        emitted, and pieces are dropped from its front until at most
        ``chunk_overlap`` remains and the next piece fits; what remains becomes
        the overlapping start of the next chunk. All sizes are measured with
        ``length_function``.

        :param splits: Pieces to merge, in order.
        :param separator: Text inserted between pieces when joining.
        :return: The merged chunks; chunks that are empty after optional
            stripping are dropped.
        """
        separator_len = self.length_function(separator)
        chunks: List[str] = []
        # A deque makes evicting from the front O(1); slicing a list
        # (``lst[1:]``) copied the whole window on every eviction.
        current_chunk: deque = deque()
        total = 0
        for d in splits:
            _len = self.length_function(d)
            if (
                total + _len + (separator_len if current_chunk else 0)
                > self.chunk_size
            ):
                if total > self.chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s",
                        total,
                        self.chunk_size,
                    )
                if current_chunk:
                    chunk = self._join_chunks(current_chunk, separator)
                    if chunk is not None:
                        chunks.append(chunk)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    # The ``current_chunk`` guard stops the loop once the window
                    # is empty (previously an IndexError when chunk_overlap < 0).
                    while current_chunk and (
                        total > self.chunk_overlap
                        or (
                            total + _len + (separator_len if current_chunk else 0)
                            > self.chunk_size
                            and total > 0
                        )
                    ):
                        total -= self.length_function(current_chunk[0]) + (
                            separator_len if len(current_chunk) > 1 else 0
                        )
                        current_chunk.popleft()
            current_chunk.append(d)
            total += _len + (separator_len if len(current_chunk) > 1 else 0)
        chunk = self._join_chunks(current_chunk, separator)
        if chunk is not None:
            chunks.append(chunk)
        return chunks

    @staticmethod
    def _split_text_with_regex(
        text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
    ) -> List[str]:
        """Split ``text`` on the regular expression ``separator``.

        :param text: Text to split.
        :param separator: A regex pattern, not a literal string (escape literal
            text with ``re.escape``). An empty separator splits ``text`` into
            single characters.
        :param keep_separator: ``False`` drops the separators; ``"end"`` appends
            each separator to the piece before it; any other truthy value
            (``True`` or ``"start"``) prepends it to the piece after it.
        :return: The non-empty pieces, in order.
        """
        if not separator:
            splits = list(text)
        elif not keep_separator:
            splits = re.split(separator, text)
        else:
            # The capturing group makes re.split keep the delimiters; when
            # ``separator`` has no groups of its own the result alternates
            # [piece, sep, piece, sep, ..., piece].
            _splits = re.split(f"({separator})", text)
            if keep_separator == "end":
                splits = [
                    _splits[i] + _splits[i + 1]
                    for i in range(0, len(_splits) - 1, 2)
                ]
            else:
                splits = [
                    _splits[i] + _splits[i + 1]
                    for i in range(1, len(_splits), 2)
                ]
            # NOTE(review): with only the wrapping group the list length is
            # always odd, so this can only fire when ``separator`` contains
            # capturing groups of its own.
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            if keep_separator == "end":
                splits = splits + [_splits[-1]]
            else:
                splits = [_splits[0]] + splits
        return [s for s in splits if s != ""]