Skip to content

Commit 43b940f

Browse files
committed
feat: add api_base field to model credentials and update related API calls
1 parent bc31ed0 commit 43b940f

File tree

20 files changed

+96
-32
lines changed

20 files changed

+96
-32
lines changed

apps/models_provider/base_model_provider.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,22 @@ def get_model_list(self, model_type):
6262

6363
def get_model_credential(self, model_type, model_name):
6464
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
65-
return model_info.model_credential
65+
model_credential = model_info.model_credential
66+
67+
if model_type == 'TTI' and model_name.startswith(('qwen', 'wan2.6', 'wan')):
68+
if hasattr(model_credential, 'api_base'):
69+
api_base = model_credential.api_base
70+
if hasattr(api_base, 'default_value'):
71+
default_value_map = {
72+
'qwen': "https://dashscope.aliyuncs.com/v1",
73+
'wan2.6': "https://dashscope.aliyuncs.com/api/v1",
74+
'wan': "https://dashscope.aliyuncs.com/compatible-mode/v1"
75+
}
76+
prefix = next((k for k in default_value_map if model_name.startswith(k)), None)
77+
if prefix:
78+
api_base.default_value = default_value_map[prefix]
79+
80+
return model_credential
6681

6782
def get_model_params(self, model_type, model_name):
6883
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
@@ -147,13 +162,11 @@ class ModelTypeConst(Enum):
147162
IMAGE = {'code': 'IMAGE', 'message': _('Vision Model')}
148163
TTI = {'code': 'TTI', 'message': _('Image Generation')}
149164
RERANKER = {'code': 'RERANKER', 'message': _('Rerank')}
150-
#文生视频 图生视频
165+
# 文生视频 图生视频
151166
TTV = {'code': 'TTV', 'message': _('Text to Video')}
152167
ITV = {'code': 'ITV', 'message': _('Image to Video')}
153168

154169

155-
156-
157170
class ModelInfo:
158171
def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
159172
model_class: Type[MaxKBBaseModel],

apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
1818
from common.utils.logger import maxkb_logger
1919

20+
2021
class BaiLianEmbeddingModelParams(BaseForm):
2122
dimensions = forms.SingleSelect(
2223
TooltipLabel(
@@ -55,7 +56,7 @@ def is_valid(
5556
ValidCode.valid_error.value,
5657
f"{model_type} Model type is not supported"
5758
)
58-
required_keys = ['dashscope_api_key']
59+
required_keys = ['dashscope_api_key', 'api_base']
5960
missing_keys = [key for key in required_keys if key not in model_credential]
6061
if missing_keys:
6162
if raise_exception:
@@ -88,8 +89,9 @@ def encryption_dict(self, model: Dict[str, Any]) -> Dict[str, Any]:
8889
api_key = model.get('dashscope_api_key', '')
8990
return {**model, 'dashscope_api_key': super().encryption(api_key)}
9091

91-
9292
def get_model_params_setting_form(self, model_name):
9393
return BaiLianEmbeddingModelParams()
9494

95+
api_base = forms.TextInputField(_('API URL'), required=True,
96+
default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
9597
dashscope_api_key = forms.PasswordInputField('API Key', required=True)

apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Dict, Any
44

55
from django.utils.translation import gettext_lazy as _, gettext
6-
6+
from common import forms
77
from common.exception.app_exception import AppApiException
88
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
99
from models_provider.base_model_provider import BaseModelCredential, ValidCode
@@ -67,7 +67,8 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
6767
Credential class for the Qwen Text-to-Image model.
6868
Provides validation and encryption for the model credentials.
6969
"""
70-
70+
api_base = forms.TextInputField(_('API URL'), required=True,
71+
default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
7172
api_key = PasswordInputField('API Key', required=True)
7273

7374
def is_valid(
@@ -97,7 +98,7 @@ def is_valid(
9798
gettext('{model_type} Model type is not supported').format(model_type=model_type)
9899
)
99100

100-
required_keys = ['api_key']
101+
required_keys = ['api_key', 'api_base']
101102
for key in required_keys:
102103
if key not in model_credential:
103104
if raise_exception:

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class AliyunBaiLianEmbedding(MaxKBBaseModel):
1717
model_name: str
1818
optional_params: dict
1919

20-
def __init__(self, api_key, model_name: str, optional_params: dict):
21-
self.client = OpenAI(api_key=api_key, base_url='https://dashscope.aliyuncs.com/compatible-mode/v1').embeddings
20+
def __init__(self, api_key, model_name: str, api_base: str, optional_params: dict):
21+
self.client = OpenAI(api_key=api_key, base_url=api_base).embeddings
2222
self.model_name = model_name
2323
self.optional_params = optional_params
2424

@@ -31,6 +31,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3131
return AliyunBaiLianEmbedding(
3232
api_key=model_credential.get('dashscope_api_key'),
3333
model_name=model_name,
34+
api_base=model_credential.get('api_base') or 'https://dashscope.aliyuncs.com/compatible-mode/v1',
3435
optional_params=optional_params
3536
)
3637

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tti.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Dict
44

55
from dashscope import ImageSynthesis, MultiModalConversation
6+
from dashscope.aigc.image_generation import ImageGeneration
67
from django.utils.translation import gettext
78
from langchain_community.chat_models import ChatTongyi
89
from langchain_core.messages import HumanMessage
@@ -17,10 +18,12 @@ class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
1718
api_key: str
1819
model_name: str
1920
params: dict
21+
api_base: str
2022

2123
def __init__(self, **kwargs):
2224
super().__init__(**kwargs)
2325
self.api_key = kwargs.get('api_key')
26+
self.api_base = kwargs.get('api_base')
2427
self.model_name = kwargs.get('model_name')
2528
self.params = kwargs.get('params')
2629

@@ -37,6 +40,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3740
chat_tong_yi = QwenTextToImageModel(
3841
model_name=model_name,
3942
api_key=model_credential.get('api_key'),
43+
api_base=model_credential.get('api_base'),
4044
**optional_params,
4145
)
4246
return chat_tong_yi
@@ -46,10 +50,39 @@ def check_auth(self):
4650
chat.invoke([HumanMessage([{"type": "text", "text": gettext('Hello')}])])
4751

4852
def generate_image(self, prompt: str, negative_prompt: str = None):
49-
if self.model_name.startswith("wan"):
53+
if self.model_name.startswith("wan2.6") or self.model_name.startswith("z"):
54+
from dashscope.api_entities.dashscope_response import Message
55+
# 以下为北京地域url,各地域的base_url不同
56+
message = Message(
57+
role="user",
58+
content=[
59+
{
60+
'text': prompt
61+
}
62+
]
63+
)
64+
rsp = ImageGeneration.call(
65+
model="z-image-turbo",
66+
api_key=self.api_key,
67+
base_url=self.api_base,
68+
messages=[message],
69+
negative_prompt=negative_prompt,
70+
**self.params
71+
)
72+
file_urls = []
73+
if rsp.status_code == HTTPStatus.OK:
74+
for result in rsp.output.results:
75+
file_urls.append(result.url)
76+
else:
77+
maxkb_logger.error('sync_call Failed, status_code: %s, code: %s, message: %s' %
78+
(rsp.status_code, rsp.code, rsp.message))
79+
raise Exception('sync_call Failed, status_code: %s, code: %s, message: %s' %
80+
(rsp.status_code, rsp.code, rsp.message))
81+
return file_urls
82+
elif self.model_name.startswith("wan"):
5083
rsp = ImageSynthesis.call(api_key=self.api_key,
5184
model=self.model_name,
52-
base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
85+
base_url=self.api_base,
5386
prompt=prompt,
5487
negative_prompt=negative_prompt,
5588
**self.params)
@@ -61,7 +94,7 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
6194
maxkb_logger.error('sync_call Failed, status_code: %s, code: %s, message: %s' %
6295
(rsp.status_code, rsp.code, rsp.message))
6396
raise Exception('sync_call Failed, status_code: %s, code: %s, message: %s' %
64-
(rsp.status_code, rsp.code, rsp.message))
97+
(rsp.status_code, rsp.code, rsp.message))
6598
return file_urls
6699
elif self.model_name.startswith("qwen"):
67100
messages = [
@@ -80,7 +113,7 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
80113
model=self.model_name,
81114
messages=messages,
82115
result_format='message',
83-
base_url='https://dashscope.aliyuncs.com/v1',
116+
base_url=self.api_base,
84117
stream=False,
85118
negative_prompt=negative_prompt,
86119
**self.params
@@ -93,5 +126,5 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
93126
maxkb_logger.error('sync_call Failed, status_code: %s, code: %s, message: %s' %
94127
(rsp.status_code, rsp.code, rsp.message))
95128
raise Exception('sync_call Failed, status_code: %s, code: %s, message: %s' %
96-
(rsp.status_code, rsp.code, rsp.message))
129+
(rsp.status_code, rsp.code, rsp.message))
97130
return file_urls

apps/models_provider/impl/deepseek_model_provider/credential/llm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from models_provider.base_model_provider import BaseModelCredential, ValidCode
1818
from common.utils.logger import maxkb_logger
1919

20+
2021
class DeepSeekLLMModelParams(BaseForm):
2122
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
2223
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
@@ -45,7 +46,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4546
raise AppApiException(ValidCode.valid_error.value,
4647
gettext('{model_type} Model type is not supported').format(model_type=model_type))
4748

48-
for key in ['api_key']:
49+
for key in ['api_key', 'api_base']:
4950
if key not in model_credential:
5051
if raise_exception:
5152
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
@@ -70,6 +71,8 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
7071
def encryption_dict(self, model: Dict[str, object]):
7172
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
7273

74+
api_base = forms.TextInputField(_('API URL'), required=True,
75+
default_value='https://api.deepseek.com')
7376
api_key = forms.PasswordInputField('API Key', required=True)
7477

7578
def get_model_params_setting_form(self, model_name):

apps/models_provider/impl/deepseek_model_provider/model/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2424

2525
deepseek_chat_open_ai = DeepSeekChatModel(
2626
model=model_name,
27-
openai_api_base='https://api.deepseek.com',
27+
openai_api_base=model_credential.get('api_base') or 'https://api.deepseek.com',
2828
openai_api_key=model_credential.get('api_key'),
2929
extra_body=optional_params
3030
)

apps/models_provider/impl/regolo_model_provider/credential/embedding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from models_provider.base_model_provider import BaseModelCredential, ValidCode
1717
from common.utils.logger import maxkb_logger
1818

19+
1920
class RegoloEmbeddingCredential(BaseForm, BaseModelCredential):
2021
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
2122
raise_exception=True):
@@ -24,7 +25,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
2425
raise AppApiException(ValidCode.valid_error.value,
2526
_('{model_type} Model type is not supported').format(model_type=model_type))
2627

27-
for key in ['api_key']:
28+
for key in ['api_key', 'api_base']:
2829
if key not in model_credential:
2930
if raise_exception:
3031
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
@@ -48,4 +49,6 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4849
def encryption_dict(self, model: Dict[str, object]):
4950
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
5051

52+
api_base = forms.TextInputField(_('API URL'), required=True,
53+
default_value='https://api.regolo.ai/v1')
5154
api_key = forms.PasswordInputField('API Key', required=True)

apps/models_provider/impl/regolo_model_provider/credential/image.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from models_provider.base_model_provider import BaseModelCredential, ValidCode
1414
from common.utils.logger import maxkb_logger
1515

16+
1617
class RegoloImageModelParams(BaseForm):
1718
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
1819
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
@@ -33,7 +34,7 @@ class RegoloImageModelParams(BaseForm):
3334

3435

3536
class RegoloImageModelCredential(BaseForm, BaseModelCredential):
36-
api_base = forms.TextInputField('API URL', required=True)
37+
api_base = forms.TextInputField('API URL', required=True, default_value='https://api.regolo.ai/v1')
3738
api_key = forms.PasswordInputField('API Key', required=True)
3839

3940
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
@@ -43,7 +44,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4344
raise AppApiException(ValidCode.valid_error.value,
4445
gettext('{model_type} Model type is not supported').format(model_type=model_type))
4546

46-
for key in ['api_key']:
47+
for key in ['api_key', 'api_base']:
4748
if key not in model_credential:
4849
if raise_exception:
4950
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))

apps/models_provider/impl/regolo_model_provider/credential/llm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from models_provider.base_model_provider import BaseModelCredential, ValidCode
1818
from common.utils.logger import maxkb_logger
1919

20+
2021
class RegoloLLMModelParams(BaseForm):
2122
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
2223
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
@@ -45,7 +46,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4546
raise AppApiException(ValidCode.valid_error.value,
4647
gettext('{model_type} Model type is not supported').format(model_type=model_type))
4748

48-
for key in ['api_key']:
49+
for key in ['api_key', 'api_base']:
4950
if key not in model_credential:
5051
if raise_exception:
5152
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
@@ -71,6 +72,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
7172
def encryption_dict(self, model: Dict[str, object]):
7273
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
7374

75+
api_base = forms.TextInputField('API URL', required=True, default_value='https://api.regolo.ai/v1')
7476
api_key = forms.PasswordInputField('API Key', required=True)
7577

7678
def get_model_params_setting_form(self, model_name):

0 commit comments

Comments
 (0)