Skip to content

Commit 2b7b519

Browse files
committed
fix: add base_url parameter to various model classes and update initialization logic
1 parent a157562 commit 2b7b519

10 files changed

Lines changed: 69 additions & 34 deletions

File tree

apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44

55
from django.utils.translation import gettext_lazy as _, gettext
66

7+
from common import forms
78
from common.exception.app_exception import AppApiException
89
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
910
from models_provider.base_model_provider import BaseModelCredential, ValidCode
1011
from common.utils.logger import maxkb_logger
1112

13+
1214
class AliyunBaiLianTTSModelGeneralParams(BaseForm):
1315
"""
1416
Parameters class for the Aliyun BaiLian TTS (Text-to-Speech) model.
@@ -60,17 +62,18 @@ class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
6062
Credential class for the Aliyun BaiLian TTS (Text-to-Speech) model.
6163
Provides validation and encryption for the model credentials.
6264
"""
65+
api_base = forms.TextInputField(_('API URL'), required=True, default_value='https://dashscope.aliyuncs.com/api/v1')
6366

6467
api_key = PasswordInputField("API Key", required=True)
6568

6669
def is_valid(
67-
self,
68-
model_type: str,
69-
model_name: str,
70-
model_credential: Dict[str, object],
71-
model_params,
72-
provider,
73-
raise_exception: bool = False
70+
self,
71+
model_type: str,
72+
model_name: str,
73+
model_credential: Dict[str, object],
74+
model_params,
75+
provider,
76+
raise_exception: bool = False
7477
) -> bool:
7578
"""
7679
Validate the model credentials.
@@ -90,7 +93,7 @@ def is_valid(
9093
gettext('{model_type} Model type is not supported').format(model_type=model_type)
9194
)
9295

93-
required_keys = ['api_key']
96+
required_keys = ['api_key', 'api_base']
9497
for key in required_keys:
9598
if key not in model_credential:
9699
if raise_exception:

apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/ttv.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
from django.utils.translation import gettext_lazy as _, gettext
66

7+
from common import forms
78
from common.exception.app_exception import AppApiException
89
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
910
from common.forms.switch_field import SwitchField
1011
from models_provider.base_model_provider import BaseModelCredential, ValidCode
1112
from common.utils.logger import maxkb_logger
1213

14+
1315
class QwenModelParams(BaseForm):
1416
"""
1517
Parameters class for the Qwen Text-to-Video model.
@@ -42,7 +44,7 @@ class TextToVideoModelCredential(BaseForm, BaseModelCredential):
4244
Credential class for the Qwen Text-to-Video model.
4345
Provides validation and encryption for the model credentials.
4446
"""
45-
47+
api_base = forms.TextInputField(_('API URL'), required=True, default_value='https://dashscope.aliyuncs.com/api/v1')
4648
api_key = PasswordInputField('API Key', required=True)
4749

4850
def is_valid(
@@ -72,7 +74,7 @@ def is_valid(
7274
gettext('{model_type} Model type is not supported').format(model_type=model_type)
7375
)
7476

75-
required_keys = ['api_key']
77+
required_keys = ['api_key', 'api_base']
7678
for key in required_keys:
7779
if key not in model_credential:
7880
if raise_exception:

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111

1212
class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
1313
api_key: str
14+
base_url: str
1415
model: str
1516
params: dict
1617

1718
def __init__(self, **kwargs):
1819
super().__init__(**kwargs)
1920
self.api_key = kwargs.get('api_key')
21+
self.base_url = kwargs.get('base_url')
2022
self.model = kwargs.get('model')
2123
self.params = kwargs.get('params')
2224

@@ -34,6 +36,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3436
return AliyunBaiLianTextToSpeech(
3537
model=model_name,
3638
api_key=model_credential.get('api_key'),
39+
base_url=model_credential.get('api_base', "https://dashscope.aliyuncs.com/api/v1"),
3740
**optional_params,
3841
)
3942

@@ -42,6 +45,7 @@ def check_auth(self):
4245

4346
def text_to_speech(self, text):
4447
dashscope.api_key = self.api_key
48+
dashscope.base_http_api_url = self.base_url
4549
text = _remove_empty_lines(text)
4650
if 'sambert' in self.model:
4751
from dashscope.audio.tts import SpeechSynthesizer
@@ -55,4 +59,3 @@ def text_to_speech(self, text):
5559
if type(audio) == str:
5660
raise Exception(audio)
5761
return audio
58-

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/ttv.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo):
1414
api_key: str
15+
api_base: str
1516
model_name: str
1617
params: dict
1718
max_retries: int = 3
@@ -20,6 +21,7 @@ class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo):
2021
def __init__(self, **kwargs):
2122
super().__init__(**kwargs)
2223
self.api_key = kwargs.get('api_key')
24+
self.api_base = kwargs.get('api_base')
2325
self.model_name = kwargs.get('model_name')
2426
self.params = kwargs.get('params', {})
2527
self.max_retries = kwargs.get('max_retries', 3)
@@ -35,9 +37,13 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3537
for key, value in model_kwargs.items():
3638
if key not in ['model_id', 'use_local', 'streaming']:
3739
optional_params['params'][key] = value
40+
api_base = model_credential.get('api_base')
41+
if api_base is None:
42+
api_base = 'https://dashscope.aliyuncs.com/api/v1'
3843
return GenerationVideoModel(
3944
model_name=model_name,
4045
api_key=model_credential.get('api_key'),
46+
api_base=api_base,
4147
**optional_params,
4248
)
4349

@@ -66,6 +72,8 @@ def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, las
6672
last_frame_url: 结束关键帧图片 URL (KF2V 必填)
6773
如果没有提供last_frame_url,则表示只提供了first_frame_url,生成的是单关键帧视频(KFV) 参数是img_url
6874
"""
75+
import dashscope
76+
dashscope.base_http_api_url = self.api_base
6977

7078
# 构建基础参数
7179
params = {"api_key": self.api_key, "prompt": prompt, "model": self.model_name,

apps/models_provider/impl/gemini_model_provider/model/stt.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from django.utils.translation import gettext as _
44
from langchain_core.messages import HumanMessage
55
from langchain_google_genai import ChatGoogleGenerativeAI
6+
from openai import base_url
67

78
from common.config.tokenizer_manage_config import TokenizerManage
89
from models_provider.base_model_provider import MaxKBBaseModel

apps/models_provider/impl/volcanic_engine_model_provider/credential/ttv.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ class VolcanicEngineTTVModelGeneralParams(BaseForm):
5353

5454

5555
class VolcanicEngineTTVModelCredential(BaseForm, BaseModelCredential):
56-
api_base = forms.TextInputField('API URL', required=True,
57-
default_value='https://ark.cn-beijing.volces.com/api/v3')
56+
base_url = forms.TextInputField('Base URL', required=True, default_value='https://ark.cn-beijing.volces.com/api/v3')
5857
api_key = forms.PasswordInputField('Api key', required=True)
5958

6059
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
@@ -64,7 +63,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
6463
raise AppApiException(ValidCode.valid_error.value,
6564
gettext('{model_type} Model type is not supported').format(model_type=model_type))
6665

67-
for key in ['api_key', 'api_base']:
66+
for key in ['api_key', 'base_url']:
6867
if key not in model_credential:
6968
if raise_exception:
7069
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))

apps/models_provider/impl/volcanic_engine_model_provider/model/embedding.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class VolcanicEngineEmbeddingModel(MaxKBBaseModel):
1010
api_base: str
1111
params: Dict[str, object]
1212

13-
def __init__(self, api_key: str, model: str, api_base: str, params: Dict[str, object] = None):
13+
def __init__(self, api_key: str, model: str, api_base: str, **params):
1414
self.client = Ark(
1515
api_key=api_key,
1616
base_url=api_base
@@ -37,25 +37,40 @@ def embed_query(self, text: str):
3737
return res[0]
3838

3939
def embed_documents(
40-
self, texts: List[str], chunk_size: int | None = None
40+
self, texts: List[str]
4141
) -> List[List[float]]:
4242
if self.model_name.startswith("doubao-embedding-vision-"):
43-
multimodal_inputs = []
43+
embeddings = []
4444
for text in texts:
45-
multimodal_inputs.append({
46-
"type": "text",
47-
"text": text
48-
})
49-
resp = self.client.multimodal_embeddings.create(
50-
model=self.model_name,
51-
input=multimodal_inputs,
52-
**(self.params or {})
53-
)
54-
return [resp.data.get('embedding')]
45+
multimodal_input = {"type": "text", "text": text}
46+
resp = self.client.multimodal_embeddings.create(
47+
model=self.model_name,
48+
input=[multimodal_input],
49+
encoding_format="float",
50+
**(self.params or {})
51+
)
52+
embedding = self._extract_embedding(resp.data)
53+
if embedding is not None:
54+
embeddings.append(embedding)
55+
return embeddings
5556
else:
5657
resp = self.client.embeddings.create(
5758
model=self.model_name,
5859
input=texts,
5960
**(self.params or {})
6061
)
6162
return [e.embedding for e in resp.data]
63+
64+
def _extract_embedding(self, data):
65+
if isinstance(data, list) and len(data) > 0:
66+
item = data[0]
67+
else:
68+
item = data
69+
70+
if hasattr(item, 'embedding'):
71+
return item.embedding
72+
elif isinstance(item, dict):
73+
return item.get('embedding')
74+
elif isinstance(item, list):
75+
return item
76+
return None

apps/models_provider/impl/volcanic_engine_model_provider/model/ttv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo):
1111
api_key: str
12-
api_base: str
12+
base_url: str
1313
model_name: str
1414
params: dict
1515
max_retries: int = 3
@@ -18,7 +18,7 @@ class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo):
1818
def __init__(self, **kwargs):
1919
super().__init__(**kwargs)
2020
self.api_key = kwargs.get('api_key')
21-
self.api_base = kwargs.get('api_base')
21+
self.base_url = kwargs.get('base_url')
2222
self.model_name = kwargs.get('model_name')
2323
self.params = kwargs.get('params', {})
2424
self.retry_delay = 5
@@ -36,7 +36,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3636
return GenerationVideoModel(
3737
model_name=model_name,
3838
api_key=model_credential.get('api_key'),
39-
api_base=model_credential.get('api_base') or 'https://ark.cn-beijing.volces.com/api/v3',
39+
base_url=model_credential.get('base_url', "https://ark.cn-beijing.volces.com/api/v3"),
4040
**optional_params,
4141
)
4242

@@ -76,7 +76,7 @@ def _poll_task(self, client: Ark, task_id: str, max_wait: int = 60, interval: in
7676

7777
# --- 通用异步生成函数 ---
7878
def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, last_frame_url=None, **kwargs):
79-
client = Ark(api_key=self.api_key, base_url=self.api_base)
79+
client = Ark(api_key=self.api_key,base_url=self.base_url)
8080
# 根据params设置其他参数 豆包的参数和别的不一样 需要拼接在text里
8181
# --rt 16:9 --dur 5 --fps 24 --rs 720p --wm true --cf false
8282
prompt = self._build_prompt(prompt)

apps/models_provider/impl/zhipu_model_provider/credential/tti.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class ZhiPuTTIModelParams(BaseForm):
2929

3030

3131
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
32+
base_url = forms.TextInputField('Base URL', required=True, default_value='https://open.bigmodel.cn/api/paas/v4')
3233
api_key = forms.PasswordInputField('API Key', required=True)
3334

3435
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
@@ -38,7 +39,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
3839
raise AppApiException(ValidCode.valid_error.value,
3940
gettext('{model_type} Model type is not supported').format(model_type=model_type))
4041

41-
for key in ['api_key']:
42+
for key in ['api_key', 'base_url']:
4243
if key not in model_credential:
4344
if raise_exception:
4445
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))

apps/models_provider/impl/zhipu_model_provider/model/tti.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ def custom_get_token_ids(text: str):
1717

1818
class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
1919
api_key: str
20+
base_url: str
2021
model: str
2122
params: dict
2223

2324
def __init__(self, **kwargs):
2425
super().__init__(**kwargs)
2526
self.api_key = kwargs.get('api_key')
27+
self.base_url = kwargs.get('base_url')
2628
self.model = kwargs.get('model')
2729
self.params = kwargs.get('params')
2830

@@ -39,6 +41,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3941
return ZhiPuTextToImage(
4042
model=model_name,
4143
api_key=model_credential.get('api_key'),
44+
base_url=model_credential.get('base_url', "https://open.bigmodel.cn/api/paas/v4"),
4245
**optional_params,
4346
)
4447

@@ -48,7 +51,7 @@ def is_cache_model(self):
4851
def check_auth(self):
4952
chat = ChatOpenAI(
5053
api_key=self.api_key,
51-
base_url='https://open.bigmodel.cn/api/paas/v4',
54+
base_url=self.base_url,
5255
model=self.model,
5356
)
5457
chat.invoke([HumanMessage([{"type": "text", "text": gettext('Hello')}])])
@@ -60,7 +63,7 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
6063
# zhipuai_api_key=self.api_key,
6164
# model_name=self.model,
6265
# )
63-
chat = ZhipuAI(api_key=self.api_key)
66+
chat = ZhipuAI(api_key=self.api_key, base_url=self.base_url)
6467
response = chat.images.generations(
6568
model=self.model, # 填写需要调用的模型编码
6669
prompt=prompt, # 填写需要生成图片的文本

0 commit comments

Comments
 (0)