-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Expand file tree
/
Copy pathbase_vector.py
More file actions
190 lines (159 loc) · 5.91 KB
/
base_vector.py
File metadata and controls
190 lines (159 loc) · 5.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: base_vector.py
@date:2023/10/18 19:16
@desc:
"""
import threading
from abc import ABC, abstractmethod
from functools import reduce
from typing import List, Dict
from langchain_core.embeddings import Embeddings
from common.chunk import text_to_chunk
from common.utils.common import sub_array
from knowledge.models import SourceType, SearchMode
lock = threading.Lock()
def chunk_data(data: Dict) -> List[Dict]:
    """Expand a paragraph-type record into one record per text chunk.

    :param data: record dict; expected keys include 'source_type', 'text'
                 and optionally a pre-computed 'chunks' list.
    :return: list of record dicts — one per chunk for paragraph sources,
             or the original record unchanged (wrapped in a list) otherwise.
    """
    # Only paragraph sources are chunked; everything else is stored as-is.
    if str(data.get('source_type')) != str(SourceType.PARAGRAPH.value):
        return [data]
    # Prefer pre-computed chunks; fall back to splitting the raw text.
    # (`or` also covers an empty 'chunks' list, matching the original
    # truthiness check, and avoids the redundant second dict lookup.)
    chunk_list = data.get('chunks') or text_to_chunk(data.get('text'))
    return [{**data, 'text': chunk} for chunk in chunk_list]
def chunk_data_list(data_list: List[Dict]) -> List[Dict]:
    """Chunk every record in ``data_list`` and flatten the results.

    :param data_list: list of record dicts (see :func:`chunk_data`).
    :return: single flat list of chunked record dicts.
    """
    # One-pass flatten; the previous reduce(lambda x, y: [*x, *y], ...)
    # rebuilt the accumulator list on every step, i.e. O(n^2).
    return [chunk for data in data_list for chunk in chunk_data(data)]
class BaseVectorStore(ABC):
    """Abstract base class for vector-store backends.

    Concrete implementations provide creation, insert, query, update and
    delete operations against the underlying vector database. The concrete
    ``save``/``batch_save`` helpers here chunk the data and delegate to the
    abstract ``_batch_save``.
    """
    # Process-wide flag: once any instance has verified (or created) the
    # vector store, subsequent saves skip the existence check.
    vector_exists = False

    @abstractmethod
    def vector_is_create(self) -> bool:
        """Return whether the vector store has already been created.

        :return: True if the vector store exists
        """
        pass

    @abstractmethod
    def vector_create(self):
        """Create the vector store.

        :return:
        """
        pass

    def save_pre_handler(self):
        """Pre-insert hook: ensure the vector store exists, creating it if needed.

        NOTE(review): the check-then-set of the class-level flag is not
        guarded by the module-level `lock` (which is currently unused);
        concurrent first saves could both call vector_create() — confirm
        that creation is idempotent or add locking.
        :return: True
        """
        if not BaseVectorStore.vector_exists:
            if not self.vector_is_create():
                self.vector_create()
            BaseVectorStore.vector_exists = True
        return True

    def save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
             source_id: str,
             is_active: bool,
             embedding: Embeddings):
        """Insert one record, chunking paragraph text and batch-saving the chunks.

        :param source_id: source id
        :param knowledge_id: knowledge base id
        :param text: text to embed
        :param source_type: source type
        :param document_id: document id
        :param is_active: active flag (original doc read "whether disabled" —
                          the name suggests the inverse; verify against callers)
        :param embedding: embedding model used to vectorize the text
        :param paragraph_id: paragraph id
        :return: None
        """
        self.save_pre_handler()
        data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'knowledge_id': knowledge_id,
                'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text}
        chunk_list = chunk_data(data)
        # sub_array bounds the size of each backend call.
        result = sub_array(chunk_list)
        for child_array in result:
            # Single-record save is never interrupted, hence the constant-False probe.
            self._batch_save(child_array, embedding, lambda: False)

    def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """Batch insert records, chunking each and saving in bounded sub-batches.

        @param data_list: list of record dicts
        @param embedding: embedding model used to vectorize the text
        @param is_the_task_interrupted: zero-arg callable; returns True to abort
        :return: None
        """
        self.save_pre_handler()
        chunk_list = chunk_data_list(data_list)
        result = sub_array(chunk_list)
        for child_array in result:
            # Re-check for interruption between sub-batches so a long batch
            # can be cancelled promptly.
            if not is_the_task_interrupted():
                self._batch_save(child_array, embedding, is_the_task_interrupted)
            else:
                break

    @abstractmethod
    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
              source_id: str,
              is_active: bool,
              embedding: Embeddings):
        # Backend-specific single-record insert.
        pass

    @abstractmethod
    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        # Backend-specific batch insert; should honor is_the_task_interrupted().
        pass

    def search(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str],
               exclude_paragraph_list: list[str],
               is_active: bool,
               embedding: Embeddings):
        """Embed the query text and return the first query result.

        Returns [] when no knowledge bases are given.

        NOTE(review): this passes 8 positional arguments to ``self.query``,
        but the abstract ``query`` below declares 10 parameters
        (query_text, query_embedding, ..., search_mode). The two signatures
        appear out of sync — confirm against the concrete implementations.
        The literals presumably mean top_n=1(?), 3(?), similarity=0.65 —
        verify their mapping.
        """
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            return []
        embedding_query = embedding.embed_query(query_text)
        result = self.query(embedding_query, knowledge_id_list, exclude_document_id_list, exclude_paragraph_list,
                            is_active, 1, 3, 0.65)
        return result[0]

    @abstractmethod
    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              document_id_list: list[str] | None,
              exclude_document_id_list: list[str],
              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
              search_mode: SearchMode):
        # Backend-specific similarity search over the given knowledge bases.
        pass

    @abstractmethod
    def hit_test(self, query_text, knowledge_id: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        # Hit-test search; note the singular name `knowledge_id` despite the
        # list[str] annotation — kept for interface compatibility.
        pass

    @abstractmethod
    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        # Update vector records belonging to one paragraph.
        pass

    @abstractmethod
    def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict):
        # Update vector records for multiple paragraphs.
        # (Annotation corrected from `str` — callers pass a list, mirroring
        # update_by_source_ids.)
        pass

    @abstractmethod
    def update_by_source_id(self, source_id: str, instance: Dict):
        # Update vector records belonging to one source.
        pass

    @abstractmethod
    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        # Update vector records for multiple sources.
        pass

    @abstractmethod
    def delete_by_knowledge_id(self, knowledge_id: str):
        # Delete all vector records of one knowledge base.
        pass

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        # Delete all vector records of one document.
        pass

    @abstractmethod
    def delete_by_document_id_list(self, document_id_list: List[str]):
        # Delete vector records for multiple documents.
        pass

    @abstractmethod
    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        # Delete vector records for multiple knowledge bases.
        pass

    @abstractmethod
    def delete_by_source_id(self, source_id: str, source_type: str):
        # Delete vector records of one source of the given type.
        pass

    @abstractmethod
    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        # Delete vector records for multiple sources of the given type.
        pass

    @abstractmethod
    def delete_by_paragraph_id(self, paragraph_id: str):
        # Delete vector records of one paragraph.
        pass

    @abstractmethod
    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        # Delete vector records for multiple paragraphs.
        pass