-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Expand file tree
/
Copy pathvllm_model_provider.py
More file actions
110 lines (92 loc) · 4.98 KB
/
vllm_model_provider.py
File metadata and controls
110 lines (92 loc) · 4.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# coding=utf-8
import os
from urllib.parse import urlparse, ParseResult
import requests
from common.utils.common import get_file_content
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
ModelInfoManage
from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential
from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential
from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential
from models_provider.impl.vllm_model_provider.credential.reranker import VllmRerankerCredential
from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential
from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel
from models_provider.impl.vllm_model_provider.model.image import VllmImage
from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker
from models_provider.impl.vllm_model_provider.model.whisper_sst import VllmWhisperSpeechToText
# One credential validator per model type; every ModelInfo of that type below
# shares the same instance.
v_llm_model_credential = VLLMModelCredential()
image_model_credential = VllmImageModelCredential()
embedding_model_credential = VllmEmbeddingCredential()
whisper_model_credential = VLLMWhisperModelCredential()
rerank_model_credential = VllmRerankerCredential()

# Built-in LLMs as (model name, description) pairs. `_` is Django's gettext,
# so the descriptions are resolved once, when this module is imported.
# NOTE(review): the AquilaChat-7B description says "13B parameter mode"; it
# looks like it should read "7B parameter model". It is left unchanged because
# the text is the gettext msgid -- fix it together with the translation files.
model_info_list = [
    ModelInfo(llm_name, llm_description, ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel)
    for llm_name, llm_description in (
        ('facebook/opt-125m', _('Facebook’s 125M parameter model')),
        ('BAAI/Aquila-7B', _('BAAI’s 7B parameter model')),
        ('BAAI/AquilaChat-7B', _('BAAI’s 13B parameter mode')),
    )
]

# Vision (image-understanding) models.
image_model_info_list = [
    ModelInfo('Qwen/Qwen2-VL-2B-Instruct', '', ModelTypeConst.IMAGE,
              image_model_credential, VllmImage),
]

# Text-embedding models.
embedding_model_info_list = [
    ModelInfo('HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5', '',
              ModelTypeConst.EMBEDDING, embedding_model_credential, VllmEmbeddingModel),
]

# Speech-to-text models; order matters because the first entry is used as the
# STT default when the registry is built.
whisper_model_info_list = [
    ModelInfo(stt_name, '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText)
    for stt_name in ('whisper-tiny', 'whisper-large-v3-turbo', 'whisper-small', 'whisper-large-v3')
]

# Reranker models.
reranker_model_info_list = [
    ModelInfo('bge-reranker-v2-m3', '', ModelTypeConst.RERANKER,
              rerank_model_credential, VllmBgeReranker),
]
# Registry of every model this provider exposes, plus one default per model
# type. Each default is the first entry of its list; the LLM default reuses
# model_info_list[0] instead of re-declaring an identical ModelInfo, so the
# two definitions can no longer drift apart.
model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_default_model_info(model_info_list[0])
    .append_model_info_list(image_model_info_list)
    .append_default_model_info(image_model_info_list[0])
    .append_model_info_list(embedding_model_info_list)
    .append_default_model_info(embedding_model_info_list[0])
    .append_model_info_list(whisper_model_info_list)
    .append_default_model_info(whisper_model_info_list[0])
    .append_model_info_list(reranker_model_info_list)
    .append_default_model_info(reranker_model_info_list[0])
    .build()
)
def get_base_url(url: str) -> str:
    """Reduce *url* to ``scheme://netloc/path`` with no trailing slash.

    Params, query string and fragment are discarded, and every trailing '/'
    is removed, so callers can append path segments (e.g. '/v1') without
    producing '//' in the result.

    :param url: the user-supplied server URL.
    :return: the normalized base URL ('' for an empty input).
    """
    parsed = urlparse(url)
    bare_url = ParseResult(scheme=parsed.scheme, netloc=parsed.netloc, path=parsed.path,
                           params='', query='', fragment='').geturl()
    # rstrip removes *all* trailing slashes; stripping only one left inputs
    # such as 'http://host/v1//' ending in '/', which defeated the caller's
    # endswith('/v1') check and produced URLs like '.../v1//v1/models'.
    return bare_url.rstrip('/')
class VllmModelProvider(IModelProvider):
    """Model provider for a self-hosted vLLM server."""

    def get_model_info_manage(self):
        """Return the module-level registry of vLLM models and per-type defaults."""
        return model_info_manage

    def get_model_provide_info(self):
        """Describe this provider: its id, display name and icon file contents."""
        icon_path = os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'vllm_model_provider', 'icon',
                                 'vllm_icon_svg')
        return ModelProvideInfo(provider='model_vllm_provider', name='vLLM', icon=get_file_content(icon_path))

    @staticmethod
    def get_base_model_list(api_base, api_key, timeout=5):
        """Fetch the list of models served at *api_base*.

        :param api_base: server URL; '/v1' is appended unless it already ends with it.
        :param api_key: optional key, sent as a Bearer token when truthy.
        :param timeout: seconds to wait for the HTTP request (default 5, as before).
        :return: the ``data`` entry of the JSON response, or None when it is
                 absent or the body is not a JSON object.
        :raises requests.HTTPError: when the server answers with a non-2xx status.
        :raises requests.RequestException: on connection errors or timeouts.
        """
        base_url = get_base_url(api_base)
        if not base_url.endswith('/v1'):
            base_url += '/v1'
        headers = {'Authorization': f"Bearer {api_key}"} if api_key else {}
        response = requests.get(f"{base_url}/models", headers=headers, timeout=timeout)
        response.raise_for_status()
        payload = response.json()
        # A non-object body (e.g. a bare list or string) previously crashed
        # with AttributeError on .get; report "no models" instead.
        return payload.get('data') if isinstance(payload, dict) else None

    @staticmethod
    def get_model_info_by_name(model_list, model_name):
        """Return the entries of *model_list* whose 'id' equals *model_name*.

        :param model_list: list of model dicts (as returned by get_base_model_list), or None.
        :param model_name: the model id to look for.
        :return: matching entries; an empty list when *model_list* is None.
        """
        if model_list is None:
            return []
        return [model for model in model_list if model.get('id') == model_name]