|
| 1 | +from typing import Dict |
| 2 | + |
| 3 | +from django.utils.translation import gettext_lazy as _, gettext |
| 4 | +from langchain_core.messages import HumanMessage |
| 5 | + |
| 6 | +from common import forms |
| 7 | +from common.exception.app_exception import AppApiException |
| 8 | +from common.forms import BaseForm, TooltipLabel |
| 9 | +from models_provider.base_model_provider import ValidCode, BaseModelCredential |
| 10 | +from common.utils.logger import maxkb_logger |
| 11 | + |
| 12 | + |
| 13 | +class BedrockImageModelParams(BaseForm): |
| 14 | + temperature = forms.SliderField(TooltipLabel(_('Temperature'), |
| 15 | + _('Higher values make the output more random, while lower values make it more focused and deterministic')), |
| 16 | + required=True, default_value=0.7, |
| 17 | + _min=0.1, |
| 18 | + _max=1.0, |
| 19 | + _step=0.01, |
| 20 | + precision=2) |
| 21 | + |
| 22 | + max_tokens = forms.SliderField( |
| 23 | + TooltipLabel(_('Output the maximum Tokens'), |
| 24 | + _('Specify the maximum number of tokens that the model can generate')), |
| 25 | + required=True, default_value=1024, |
| 26 | + _min=1, |
| 27 | + _max=100000, |
| 28 | + _step=1, |
| 29 | + precision=0) |
| 30 | + |
| 31 | + |
| 32 | +class BedrockVLModelCredential(BaseForm, BaseModelCredential): |
| 33 | + |
| 34 | + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, |
| 35 | + raise_exception=False): |
| 36 | + model_type_list = provider.get_model_type_list() |
| 37 | + if not any(mt.get('value') == model_type for mt in model_type_list): |
| 38 | + if raise_exception: |
| 39 | + raise AppApiException(ValidCode.valid_error.value, |
| 40 | + gettext('{model_type} Model type is not supported').format(model_type=model_type)) |
| 41 | + return False |
| 42 | + |
| 43 | + required_keys = ['region_name', 'access_key_id', 'secret_access_key'] |
| 44 | + if not all(key in model_credential for key in required_keys): |
| 45 | + if raise_exception: |
| 46 | + raise AppApiException(ValidCode.valid_error.value, |
| 47 | + gettext('The following fields are required: {keys}').format( |
| 48 | + keys=", ".join(required_keys))) |
| 49 | + return False |
| 50 | + |
| 51 | + try: |
| 52 | + model = provider.get_model(model_type, model_name, model_credential, **model_params) |
| 53 | + model.invoke([HumanMessage(content=gettext('Hello'))]) |
| 54 | + except AppApiException: |
| 55 | + raise |
| 56 | + except Exception as e: |
| 57 | + maxkb_logger.error(f'Exception: {e}', exc_info=True) |
| 58 | + if raise_exception: |
| 59 | + raise AppApiException(ValidCode.valid_error.value, |
| 60 | + gettext( |
| 61 | + 'Verification failed, please check whether the parameters are correct: {error}').format( |
| 62 | + error=str(e))) |
| 63 | + return False |
| 64 | + |
| 65 | + return True |
| 66 | + |
| 67 | + def encryption_dict(self, model: Dict[str, object]): |
| 68 | + return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))} |
| 69 | + |
| 70 | + region_name = forms.TextInputField('Region Name', required=True) |
| 71 | + access_key_id = forms.TextInputField('Access Key ID', required=True) |
| 72 | + secret_access_key = forms.PasswordInputField('Secret Access Key', required=True) |
| 73 | + base_url = forms.TextInputField('Proxy URL', required=False) |
| 74 | + |
| 75 | + def get_model_params_setting_form(self, model_name): |
| 76 | + return BedrockImageModelParams() |
0 commit comments