Skip to content

Commit e4cd6e8

Browse files
committed
feat: add api_base field to VolcanicEngineTTIModelCredential and update related classes
1 parent 666065e commit e4cd6e8

2 files changed

Lines changed: 7 additions & 1 deletion

File tree

  • apps/models_provider/impl/volcanic_engine_model_provider

apps/models_provider/impl/volcanic_engine_model_provider/credential/tti.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from models_provider.base_model_provider import BaseModelCredential, ValidCode
1010
from common.utils.logger import maxkb_logger
1111

12+
1213
class VolcanicEngineTTIModelGeneralParams(BaseForm):
1314
size = forms.SingleSelect(
1415
TooltipLabel(_('Image size'),
@@ -32,6 +33,8 @@ class VolcanicEngineTTIModelGeneralParams(BaseForm):
3233

3334

3435
class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential):
36+
volcanic_api_url = forms.TextInputField('API URL', required=True,
37+
default_value='https://ark.cn-beijing.volces.com/api/v3')
3538
api_key = forms.PasswordInputField('Api key', required=True)
3639

3740
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,

apps/models_provider/impl/volcanic_engine_model_provider/model/tti.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
1818
api_key: str
19+
api_base: str
1920
model_version: str
2021
params: dict
2122

2223
def __init__(self, **kwargs):
2324
super().__init__(**kwargs)
2425
self.api_key = kwargs.get('api_key')
26+
self.api_base = kwargs.get('api_base')
2527
self.model_version = kwargs.get('model_version')
2628
self.params = kwargs.get('params')
2729

@@ -38,6 +40,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3840
return VolcanicEngineTextToImage(
3941
model_version=model_name,
4042
api_key=model_credential.get('api_key'),
43+
api_base=model_credential.get('api_base') or 'https://ark-api.volcengine.com',
4144
**optional_params
4245
)
4346

@@ -47,7 +50,7 @@ def check_auth(self):
4750
def generate_image(self, prompt: str, negative_prompt: str = None):
4851
client = Ark(
4952
# 此为默认路径,您可根据业务所在地域进行配置
50-
base_url="https://ark.cn-beijing.volces.com/api/v3",
53+
base_url=self.api_base,
5154
# 从环境变量中获取您的 API Key。此为默认方式,您可根据需要进行修改
5255
api_key=self.api_key,
5356
)

0 commit comments

Comments
 (0)