|
9 | 9 | import tempfile |
10 | 10 | import zipfile |
11 | 11 | from typing import Dict |
12 | | -from django.core.cache import cache |
| 12 | + |
13 | 13 | import requests |
14 | 14 | import uuid_utils.compat as uuid |
15 | 15 | from django.core import validators |
| 16 | +from django.core.cache import cache |
16 | 17 | from django.db import transaction |
17 | 18 | from django.db.models import QuerySet, Q, Subquery, OuterRef, CharField, Value, When, Case |
18 | 19 | from django.http import HttpResponse |
19 | 20 | from django.utils import timezone |
20 | 21 | from django.utils.translation import gettext_lazy as _ |
| 22 | +from langchain_core.messages import HumanMessage, AIMessage |
21 | 23 | from langchain_mcp_adapters.client import MultiServerMCPClient |
22 | 24 | from pylint.lint import Run |
23 | 25 | from pylint.reporters import JSON2Reporter |
|
36 | 38 | from common.utils.tool_code import ToolExecutor |
37 | 39 | from knowledge.models import File, FileSourceType, Knowledge |
38 | 40 | from maxkb.const import PROJECT_DIR |
| 41 | +from models_provider.models import Model |
39 | 42 | from system_manage.models import AuthTargetType, WorkspaceUserResourcePermission |
40 | 43 | from system_manage.models.resource_mapping import ResourceMapping |
41 | 44 | from system_manage.serializers.resource_mapping_serializers import ResourceMappingSerializer |
@@ -1108,6 +1111,68 @@ def upload(self): |
1108 | 1111 | file.save(self.data.get('file').read()) |
1109 | 1112 | return file_id |
1110 | 1113 |
|
| 1114 | + class GenerateCodeSerializer(serializers.Serializer): |
| 1115 | + workspace_id = serializers.CharField(required=True, label=_('Workspace ID')) |
| 1116 | + model_id = serializers.UUIDField(required=True, label=_('Model ID')) |
| 1117 | + prompt = serializers.CharField(required=True, label=_('Prompt')) |
| 1118 | + messages = serializers.ListField(required=True, label=_('Messages')) |
| 1119 | + model_params_setting = serializers.DictField(required=False, default=dict, label=_('Model Params Setting')) |
| 1120 | + init_field_list = serializers.ListField(required=False, default=list, label=_('Init Field List')) |
| 1121 | + input_field_list = serializers.ListField(required=False, default=list, label=_('Input Field List')) |
| 1122 | + |
| 1123 | + def generate_code(self): |
| 1124 | + from models_provider.tools import get_model_instance_by_model_workspace_id |
| 1125 | + from application.flow.tools import to_stream_response_simple |
| 1126 | + |
| 1127 | + self.is_valid(raise_exception=True) |
| 1128 | + |
| 1129 | + workspace_id = self.data.get('workspace_id') |
| 1130 | + model_id = self.data.get('model_id') |
| 1131 | + prompt = self.data.get('prompt') |
| 1132 | + messages = self.data.get('messages') |
| 1133 | + model_params_setting = self.data.get('model_params_setting') |
| 1134 | + init_field_list = self.data.get('init_field_list') |
| 1135 | + input_field_list = self.data.get('input_field_list') |
| 1136 | + |
| 1137 | + message = messages[-1]['content'] |
| 1138 | + q = prompt.replace( |
| 1139 | + "{userInput}", message |
| 1140 | + ).replace( |
| 1141 | + "{initFieldList}", json.dumps(init_field_list) |
| 1142 | + ).replace( |
| 1143 | + "{inputFieldList}", json.dumps(input_field_list) |
| 1144 | + ) |
| 1145 | + |
| 1146 | + messages[-1]['content'] = q |
| 1147 | + SUPPORTED_MODEL_TYPES = ["LLM"] |
| 1148 | + model_exist = QuerySet(Model).filter( |
| 1149 | + id=model_id, |
| 1150 | + model_type__in=SUPPORTED_MODEL_TYPES |
| 1151 | + ).exists() |
| 1152 | + if not model_exist: |
| 1153 | + raise Exception(_("Model does not exists or is not an LLM model")) |
| 1154 | + |
| 1155 | + def process(): |
| 1156 | + model = get_model_instance_by_model_workspace_id( |
| 1157 | + model_id=model_id, workspace_id=workspace_id, **model_params_setting |
| 1158 | + ) |
| 1159 | + try: |
| 1160 | + for r in model.stream([ |
| 1161 | + # SystemMessage(content=SYSTEM_ROLE), |
| 1162 | + *[ |
| 1163 | + HumanMessage( |
| 1164 | + content=m.get('content') |
| 1165 | + ) if m.get('role') == 'user' else AIMessage( |
| 1166 | + content=m.get('content') |
| 1167 | + ) for m in messages |
| 1168 | + ] |
| 1169 | + ]): |
| 1170 | + yield 'data: ' + json.dumps({'content': r.content}) + '\n\n' |
| 1171 | + except Exception as e: |
| 1172 | + yield 'data: ' + json.dumps({'error': str(e)}) + '\n\n' |
| 1173 | + |
| 1174 | + return to_stream_response_simple(process()) |
| 1175 | + |
1111 | 1176 |
|
1112 | 1177 | class ToolTreeSerializer(serializers.Serializer): |
1113 | 1178 | class Query(serializers.Serializer): |
|
0 commit comments