-
Notifications
You must be signed in to change notification settings - Fork 2.8k
feat: Generate Prompt #4002
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Generate Prompt #4002
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,12 +10,35 @@ | |
| from drf_spectacular.utils import OpenApiParameter | ||
|
|
||
| from application.serializers.application_chat_record import ChatRecordSerializerModel | ||
| from chat.serializers.chat import ChatMessageSerializers | ||
| from chat.serializers.chat import ChatMessageSerializers, GeneratePromptSerializers | ||
| from chat.serializers.chat_record import HistoryChatModel, EditAbstractSerializer | ||
| from common.mixins.api_mixin import APIMixin | ||
| from common.result import ResultSerializer, ResultPageSerializer, DefaultResultSerializer | ||
|
|
||
|
|
||
| class PromptGenerateAPI(APIMixin): | ||
| @staticmethod | ||
| def get_parameters(): | ||
| return [OpenApiParameter( | ||
| name="workspace_id", | ||
| description="工作空间id", | ||
| type=OpenApiTypes.STR, | ||
| location='path', | ||
| required=True, | ||
| ), | ||
| OpenApiParameter( | ||
| name="model_id", | ||
| description="模型id", | ||
| type=OpenApiTypes.STR, | ||
| location='path', | ||
| required=True,) | ||
| ] | ||
|
|
||
| @staticmethod | ||
| def get_request(): | ||
| return GeneratePromptSerializers | ||
|
|
||
|
|
||
| class ChatAPI(APIMixin): | ||
| @staticmethod | ||
| def get_parameters(): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The provided Python code snippet is incomplete and lacks several key components. However, I can highlight some general issues and suggest improvements: Key Issues
Improvements SuggestedTo make the code more readable and functional, consider the following changes: from drf_spectacular.utils import OpenApiParameter
# Import other serializers you need
from chat.serializers.chat import ChatMessageSerializers, GeneratePromptSerializers
from chat.serializers.chat_record import HistoryChatModel, EditAbstractSerializer, ChatRecordSerializerModel
from common.mixins.api_mixin import APIMixin
from common.result import ResultSerializer, ResultPageSerializer, DefaultResultSerializer
class PromptGenerateAPI(APIMixin):
"""
API endpoint for generating prompts based on workspace and model IDs.
Parameters:
- workspace_id: str
Workspace identifier.
- model_id: str
Model identifier.
"""
@staticmethod
def get_parameters():
return [
OpenApiParameter(
name="workspace_id",
description="工作空间identifier",
type=OpenApiTypes.STR,
location='path',
required=True,
),
OpenApiParameter(
name="model_id",
description="模型identifier",
type=OpenApiTypes.STR,
location='path',
required=True,
)
]
@staticmethod
def get_request():
return GeneratePromptSerializers
class ChatAPI(APIMixin):
"""
API endpoint for managing interactions within a chat session.
"""
@staticmethod
def get_parameters():
# Define your parameters here if needed
pass
@staticmethod
def get_request():
# Define your request serializer here if needed
passSummary of Changes
Make sure to implement the logic inside the |
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,13 +6,14 @@ | |
| @date:2025/6/9 11:23 | ||
| @desc: | ||
| """ | ||
|
|
||
| import json | ||
| from gettext import gettext | ||
| from typing import List, Dict | ||
|
|
||
| import uuid_utils.compat as uuid | ||
| from django.db.models import QuerySet | ||
| from django.utils.translation import gettext_lazy as _ | ||
| from langchain_core.messages import HumanMessage, AIMessage | ||
| from rest_framework import serializers | ||
|
|
||
| from application.chat_pipeline.pipeline_manage import PipelineManage | ||
|
|
@@ -24,6 +25,7 @@ | |
| from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep | ||
| from application.flow.common import Answer, Workflow | ||
| from application.flow.i_step_node import WorkFlowPostHandler | ||
| from application.flow.tools import to_stream_response_simple | ||
| from application.flow.workflow_manage import WorkflowManage | ||
| from application.models import Application, ApplicationTypeChoices, ApplicationKnowledgeMapping, \ | ||
| ChatUserType, ApplicationChatUserStats, ApplicationAccessToken, ChatRecord, Chat, ApplicationVersion | ||
|
|
@@ -37,7 +39,33 @@ | |
| from common.utils.common import flat_map | ||
| from knowledge.models import Document, Paragraph | ||
| from models_provider.models import Model, Status | ||
| from models_provider.tools import get_model_instance_by_model_workspace_id | ||
|
|
||
|
|
||
| class ChatMessagesSerializers(serializers.Serializer): | ||
| role = serializers.CharField(required=True, label=_("Role")) | ||
| content = serializers.CharField(required=True, label=_("Content")) | ||
|
|
||
|
|
||
| class GeneratePromptSerializers(serializers.Serializer): | ||
| prompt = serializers.CharField(required=True, label=_("Prompt template")) | ||
| messages = serializers.ListSerializer(child=ChatMessagesSerializers(), required=True, label=_("Chat context")) | ||
|
|
||
| def is_valid(self, *, raise_exception=False): | ||
| super().is_valid(raise_exception=True) | ||
| messages = self.data.get("messages") | ||
|
|
||
| if len(messages) > 30: | ||
| raise AppApiException(400, _("Too many messages")) | ||
|
|
||
| for index in range(len(messages)): | ||
| role = messages[index].get('role') | ||
| if role == 'ai' and index % 2 != 1: | ||
| raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct.")) | ||
| if role == 'user' and index % 2 != 0: | ||
| raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct.")) | ||
| if role not in ['user', 'ai']: | ||
| raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct.")) | ||
|
|
||
| class ChatMessageSerializers(serializers.Serializer): | ||
| message = serializers.CharField(required=True, label=_("User Questions")) | ||
|
|
@@ -113,6 +141,37 @@ def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToRespon | |
| }).chat(instance, base_to_response) | ||
|
|
||
|
|
||
| class PromptGenerateSerializer(serializers.Serializer): | ||
| workspace_id = serializers.CharField(required=False, label=_('Workspace ID')) | ||
| model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model")) | ||
|
|
||
| def generate_prompt(self, instance: dict, with_valid=True): | ||
| if with_valid: | ||
| self.is_valid(raise_exception=True) | ||
| GeneratePromptSerializers(data=instance).is_valid(raise_exception=True) | ||
| workspace_id = self.data.get('workspace_id') | ||
| model_id = self.data.get('model_id') | ||
| prompt = instance.get('prompt') | ||
| messages = instance.get('messages') | ||
|
|
||
| message = messages[-1]['content'] | ||
| q = prompt.replace("{userInput}", message) | ||
| messages[-1]['content'] = q | ||
|
|
||
| model_exist = QuerySet(Model).filter(workspace_id=workspace_id, id=model_id).exists() | ||
| if not model_exist: | ||
| raise Exception(_("model does not exists")) | ||
|
|
||
| def process(): | ||
| model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id) | ||
|
|
||
| for r in model.stream([HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage( | ||
| content=m.get('content')) for m in messages]): | ||
| yield 'data: ' + json.dumps({'content': r.content}) + '\n\n' | ||
|
|
||
| return to_stream_response_simple(process()) | ||
|
|
||
|
|
||
| class OpenAIMessage(serializers.Serializer): | ||
| content = serializers.CharField(required=True, label=_('content')) | ||
| role = serializers.CharField(required=True, label=_('Role')) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here's a review of the changes made to improve and optimize the code:
Overall, these revisions aim to make the code cleaner, more maintainable, and potentially faster or safer under certain conditions through improved error checking and input validation mechanisms. |
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The provided code snippet appears to be part of an OpenAI API implementation, specifically related to handling requests for chat generation and prompts. Here is a review with some suggestions:
Irregularities, Issues, and Optimization Suggestions
Import Statements:
PromptGenerateAPIwas added without importing it elsewhere in the file.PromptGenerateAPIis defined in its own module or at least used properly within this context.Class Definitions:
ChatViewclass handles both debug chats and prompt generation. Consider splitting them into separate classes if they serve different purposes.Method Signatures:
post(request)method does not have any additional logic between checking permissions and returning serialized data.Code Readability:
Security:
Error Handling:
HTTP_400_BAD_REQUEST,HTTP_403_FORBIDDEN, etc.).Here are more detailed optimizations and enhancements based on these points:
And similarly for
PromptGenerateView:These improvements focus on enhancing code clarity, robustness, and security while adhering to best practices for API development.