-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Expand file tree
/
Copy pathbase_model_provider.py
More file actions
264 lines (206 loc) · 8.67 KB
/
base_model_provider.py
File metadata and controls
264 lines (206 loc) · 8.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: base_model_provider.py
@date:2023/10/31 16:19
@desc:
"""
from abc import ABC, abstractmethod
from enum import Enum
from functools import reduce
from typing import Dict, Iterator, Type, List
from pydantic import BaseModel
from common.exception.app_exception import AppApiException
from django.utils.translation import gettext_lazy as _
from common.util.common import encryption
class DownModelChunkStatus(Enum):
    """
    Status value attached to a model-download progress chunk (see ``DownModelChunk.status``).
    """

    success = "success"
    error = "error"
    pulling = "pulling"
    unknown = "unknown"
class ValidCode(Enum):
    """
    Numeric codes for validation outcomes.
    """

    # Validation failed.
    valid_error = 500

    # NOTE(review): the member name is misspelled (presumably "model not found");
    # it is kept unchanged because callers reference it by this exact name.
    model_not_fount = 404
class DownModelChunk:
    """
    One progress record emitted while a model is being downloaded.
    """

    def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int):
        """
        :param status: download status of this chunk
        :param digest: digest string associated with this chunk
        :param progress: progress value of this chunk
        :param details: free-form detail text
        :param index: position of this chunk
        """
        self.status = status
        self.digest = digest
        self.progress = progress
        self.details = details
        self.index = index

    def to_dict(self):
        """
        Convert the chunk into a plain dictionary.

        :return: dict with the keys details, status, digest, progress and index;
                 ``status`` is flattened to the enum member's ``value``
        """
        return dict(
            details=self.details,
            status=self.status.value,
            digest=self.digest,
            progress=self.progress,
            index=self.index,
        )
class IModelProvider(ABC):
    """
    Base interface of a model provider.

    Subclasses supply a model-info manager and the provider description; every other
    method here answers by delegating to that manager.
    """

    @abstractmethod
    def get_model_info_manage(self):
        """
        :return: the manager object the other methods use to look up model information
        """
        pass

    @abstractmethod
    def get_model_provide_info(self):
        """
        :return: description of this provider (defined by the implementation)
        """
        pass

    def get_model_type_list(self):
        """
        :return: the model type list reported by the model-info manager
        """
        manage = self.get_model_info_manage()
        return manage.get_model_type_list()

    def get_model_list(self, model_type):
        """
        List the models registered for one model type.

        :param model_type: model type to filter by; must not be None
        :return: the manager's model list for ``model_type``
        :raises AppApiException: (code 500) when ``model_type`` is None
        """
        if model_type is not None:
            return self.get_model_info_manage().get_model_list_by_model_type(model_type)
        raise AppApiException(500, _('Model type cannot be empty'))

    def get_model_credential(self, model_type, model_name):
        """
        :param model_type: model type
        :param model_name: model name
        :return: the ``model_credential`` object of the matching model info
        """
        return self.get_model_info_manage().get_model_info(model_type, model_name).model_credential

    def get_model_params(self, model_type, model_name):
        """
        :param model_type: model type
        :param model_name: model name
        :return: the ``model_credential`` object of the matching model info
                 (the same object ``get_model_credential`` returns)
        """
        return self.get_model_info_manage().get_model_info(model_type, model_name).model_credential

    def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object],
                            model_params: Dict[str, object], raise_exception=False):
        """
        Validate a credential by delegating to the model's credential object.

        :param model_type: model type
        :param model_name: model name
        :param model_credential: credential data to validate
        :param model_params: model parameters passed through to the validator
        :param raise_exception: forwarded to the validator as ``raise_exception``
        :return: whatever the credential object's ``is_valid`` returns
        """
        credential = self.get_model_info_manage().get_model_info(model_type, model_name).model_credential
        return credential.is_valid(model_type, model_name, model_credential, model_params, self,
                                   raise_exception=raise_exception)

    def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
        """
        Build a model instance through the model class registered for the model.

        :param model_type: model type
        :param model_name: model name
        :param model_credential: credential data handed to the model class
        :param model_kwargs: extra keyword arguments forwarded to ``new_instance``
        :return: the object created by the model class's ``new_instance``
        """
        model_class = self.get_model_info_manage().get_model_info(model_type, model_name).model_class
        return model_class.new_instance(model_type, model_name, model_credential, **model_kwargs)

    def get_dialogue_number(self):
        """
        :return: the dialogue number, 3 in this base implementation
        """
        return 3

    def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
        """
        Download a model. This base implementation always refuses.

        :param model_type: model type
        :param model_name: model name
        :param model_credential: credential data
        :raises AppApiException: (code 500) always, in this base implementation
        """
        raise AppApiException(500, _('The current platform does not support downloading models'))
class MaxKBBaseModel(ABC):
    """
    Abstract base for model classes: defines the factory entry point and two static helpers.
    """

    @staticmethod
    @abstractmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """
        Factory hook that concrete model classes implement.

        :param model_type: model type
        :param model_name: model name
        :param model_credential: credential data
        :param model_kwargs: extra keyword arguments
        """
        pass

    @staticmethod
    def is_cache_model():
        """
        :return: True in this base implementation
        """
        return True

    @staticmethod
    def filter_optional_params(model_kwargs):
        """
        Strip internal keys from ``model_kwargs`` and flatten ``extra_body``.

        :param model_kwargs: mapping of keyword arguments
        :return: new dict without the internal keys; when ``extra_body`` is a dict its
                 entries are merged into the result (overriding earlier keys), otherwise
                 ``extra_body`` is copied like any other key
        """
        internal_keys = ('model_id', 'use_local', 'streaming', 'show_ref_label')
        filtered = {}
        for name, setting in model_kwargs.items():
            if name in internal_keys:
                continue
            if name == 'extra_body' and isinstance(setting, dict):
                # Flatten the nested mapping into the top level.
                filtered.update(setting)
                continue
            filtered[name] = setting
        return filtered
class BaseModelCredential(ABC):
    """
    Abstract base for a model's credential handling: validation, masking of
    sensitive fields and the optional parameter-settings form.
    """

    @abstractmethod
    def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider,
                 raise_exception=True):
        """
        Validate credential data; implemented by subclasses.

        :param model_type: model type
        :param model_name: model name
        :param model: credential data
        :param model_params: model parameters
        :param provider: provider object performing the validation
        :param raise_exception: flag supplied by the caller (handling is up to the subclass)
        """

    @abstractmethod
    def encryption_dict(self, model_info: Dict[str, object]):
        """
        :param model_info: model data
        :return: the data after encryption
        """

    def get_model_params_setting_form(self, model_name):
        """
        Model parameter settings form.

        :param model_name: model name
        :return: None in this base implementation
        """
        return None

    @staticmethod
    def encryption(message: str):
        """
        Encrypt sensitive field data. Per the original note, a password such as
        1234567890 is handed to the frontend as 123******890.

        :param message: text to mask
        :return: result of the module-level ``encryption`` helper
        """
        # Inside the function body the name resolves to the imported module-level helper,
        # not to this static method.
        masked = encryption(message)
        return masked
class ModelTypeConst(Enum):
    """
    Supported model types. Each value is a dict holding the type ``code`` and a
    translatable display ``message``.
    """

    LLM = dict(code='LLM', message=_('LLM'))
    EMBEDDING = dict(code='EMBEDDING', message=_('Embedding Model'))
    STT = dict(code='STT', message=_('Speech2Text'))
    TTS = dict(code='TTS', message=_('TTS'))
    IMAGE = dict(code='IMAGE', message=_('Vision Model'))
    TTI = dict(code='TTI', message=_('Image Generation'))
    RERANKER = dict(code='RERANKER', message=_('Rerank'))
class ModelInfo:
    """
    Description of one model: name, description, type, credential handler and model class,
    plus any extra attributes supplied as keyword arguments.
    """

    def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
                 model_class: Type[MaxKBBaseModel],
                 **keywords):
        """
        :param name: model name
        :param desc: model description
        :param model_type: model type enum member; only its ``name`` is stored
        :param model_credential: credential handler for the model
        :param model_class: class used to instantiate the model
        :param keywords: extra attributes, each set on the instance under its own key
        """
        self.name = name
        self.desc = desc
        self.model_type = model_type.name
        self.model_credential = model_credential
        self.model_class = model_class
        for extra_name, extra_value in keywords.items():
            setattr(self, extra_name, extra_value)

    def get_name(self):
        """
        Get the model name.

        :return: model name
        """
        return self.name

    def get_desc(self):
        """
        Get the model description.

        :return: model description
        """
        return self.desc

    def get_model_type(self):
        """
        :return: the stored model type name
        """
        return self.model_type

    def get_model_class(self):
        """
        :return: the model class
        """
        return self.model_class

    def to_dict(self):
        """
        Export the instance attributes as a dictionary.

        :return: dict of all instance attributes except ``model_credential``,
                 ``model_class`` and names starting with a double underscore
        """
        excluded = ('model_credential', 'model_class')
        return {
            field: getattr(self, field)
            for field in vars(self)
            if not (field.startswith("__") or field in excluded)
        }
class ModelInfoManage:
    """
    Registry of ``ModelInfo`` objects, indexed by model type and model name, with an
    optional default model per model type.
    """

    def __init__(self):
        self.model_dict = {}
        self.model_list = []
        self.default_model_list = []
        self.default_model_dict = {}

    def append_model_info(self, model_info: ModelInfo):
        """
        Register a model.

        :param model_info: model to add to the list and to the type/name index
        """
        self.model_list.append(model_info)
        # Create the per-type mapping on first use, then index the model by name.
        self.model_dict.setdefault(model_info.model_type, {})[model_info.name] = model_info

    def append_default_model_info(self, model_info: ModelInfo):
        """
        Register a default model; a later default for the same type replaces the earlier
        one in the per-type mapping.

        :param model_info: model to use as the default for its model type
        """
        self.default_model_list.append(model_info)
        self.default_model_dict[model_info.model_type] = model_info

    def get_model_list(self):
        """
        :return: ``to_dict()`` of every registered model, in registration order
        """
        return [registered.to_dict() for registered in self.model_list]

    def get_model_list_by_model_type(self, model_type):
        """
        :param model_type: model type to filter by
        :return: ``to_dict()`` of every registered model whose type equals ``model_type``
        """
        return [registered.to_dict() for registered in self.model_list if registered.model_type == model_type]

    def get_model_type_list(self):
        """
        :return: one ``{'key': message, 'value': code}`` entry for each ``ModelTypeConst``
                 member that has at least one registered model
        """
        return [
            {'key': type_const.value.get('message'), 'value': type_const.value.get('code')}
            for type_const in ModelTypeConst
            if any(registered.model_type == type_const.name for registered in self.model_list)
        ]

    def get_model_info(self, model_type, model_name) -> ModelInfo:
        """
        Look up a model, falling back to the default model of its type.

        :param model_type: model type
        :param model_name: model name
        :return: the matching model info, or the default one registered for ``model_type``
        :raises AppApiException: (code 500) when neither exists
        """
        models_of_type = self.model_dict.get(model_type, {})
        fallback = self.default_model_dict.get(model_type)
        found = models_of_type.get(model_name, fallback)
        if found is None:
            raise AppApiException(500, _('The model does not support'))
        return found

    class builder:
        """
        Fluent builder that fills a fresh ``ModelInfoManage``.
        """

        def __init__(self):
            self.modelInfoManage = ModelInfoManage()

        def append_model_info(self, model_info: ModelInfo):
            """
            :param model_info: model to register
            :return: this builder, for chaining
            """
            self.modelInfoManage.append_model_info(model_info)
            return self

        def append_model_info_list(self, model_info_list: List[ModelInfo]):
            """
            :param model_info_list: models to register, in order
            :return: this builder, for chaining
            """
            for item in model_info_list:
                self.modelInfoManage.append_model_info(item)
            return self

        def append_default_model_info(self, model_info: ModelInfo):
            """
            :param model_info: model to register as a default
            :return: this builder, for chaining
            """
            self.modelInfoManage.append_default_model_info(model_info)
            return self

        def build(self):
            """
            :return: the populated ``ModelInfoManage``
            """
            return self.modelInfoManage
class ModelProvideInfo:
    """
    Descriptive information about a model provider.
    """

    def __init__(self, provider: str, name: str, icon: str):
        """
        :param provider: provider identifier
        :param name: provider name
        :param icon: provider icon
        """
        self.provider = provider
        self.name = name
        self.icon = icon

    def to_dict(self):
        """
        Export the instance attributes as a dictionary.

        :return: dict of every instance attribute whose name does not start with a
                 double underscore
        """
        return {
            field: getattr(self, field)
            for field in vars(self)
            if not field.startswith("__")
        }