-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Expand file tree
/
Copy pathi_chat_step.py
More file actions
110 lines (90 loc) · 4.53 KB
/
i_chat_step.py
File metadata and controls
110 lines (90 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: i_chat_step.py
@date:2024/1/9 18:17
@desc: 对话
"""
from abc import abstractmethod
from typing import Type, List
from django.utils.translation import gettext_lazy as _
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.serializers.application import NoReferencesSetting
from common.field.common import InstanceField
class ModelField(serializers.Field):
    """Serializer field that accepts only langchain ``BaseChatModel`` instances.

    The value is passed through unchanged in both directions; validation only
    checks the type.
    """

    # DRF's ``Field.fail(key)`` looks ``key`` up in ``error_messages`` (built from
    # ``default_error_messages`` across the MRO). The original code passed the
    # message text itself as the key, which makes DRF raise an AssertionError
    # ("error key ... does not exist") instead of a ValidationError.
    default_error_messages = {
        'invalid': _('Model type error'),
    }

    def to_internal_value(self, data):
        """Return ``data`` unchanged if it is a ``BaseChatModel``.

        :raises serializers.ValidationError: if ``data`` is not a ``BaseChatModel``.
        """
        if not isinstance(data, BaseChatModel):
            self.fail('invalid', value=data)
        return data

    def to_representation(self, value):
        """Return ``value`` unchanged."""
        return value
class MessageField(serializers.Field):
    """Serializer field that accepts only langchain ``BaseMessage`` instances.

    The value is passed through unchanged in both directions; validation only
    checks the type.
    """

    # ``Field.fail`` takes an error *key* that must exist in ``error_messages``;
    # passing the message text as the key (as before) triggered DRF's
    # "missing error key" AssertionError rather than a ValidationError.
    default_error_messages = {
        'invalid': _('Message type error'),
    }

    def to_internal_value(self, data):
        """Return ``data`` unchanged if it is a ``BaseMessage``.

        :raises serializers.ValidationError: if ``data`` is not a ``BaseMessage``.
        """
        if not isinstance(data, BaseMessage):
            self.fail('invalid', value=data)
        return data

    def to_representation(self, value):
        """Return ``value`` unchanged."""
        return value
class PostResponseHandler:
    """Hook interface for post-processing a chat response.

    Subclasses override :meth:`handler`; the base implementation does nothing
    and returns ``None``.

    NOTE(review): this class does not derive from ``abc.ABC`` (no ``ABCMeta``),
    so the ``@abstractmethod`` marker on :meth:`handler` is not enforced when
    instantiating subclasses.
    """

    @abstractmethod
    def handler(self,
                chat_id,
                chat_record_id,
                paragraph_list: List[ParagraphPipelineModel],
                problem_text: str,
                answer_text,
                manage,
                step,
                padding_problem_text: str = None,
                **kwargs):
        """Process a finished chat exchange.

        Presumably invoked once ``answer_text`` is available for the question
        ``problem_text`` — confirm with callers.

        :param chat_id: conversation id.
        :param chat_record_id: id of the chat record being handled.
        :param paragraph_list: paragraphs associated with the answer.
        :param problem_text: the user's question.
        :param answer_text: the generated answer.
        :param manage: the pipeline manager.
        :param step: the pipeline step that triggered the handler.
        :param padding_problem_text: optional completed/rewritten question.
        """
class IChatStep(IBaseChatPipelineStep):
    """Abstract chat pipeline step.

    ``get_step_serializer`` returns :class:`InstanceSerializer`, whose fields
    mirror the parameters of :meth:`execute`. ``_run`` calls :meth:`execute`
    with ``self.context['step_args']`` plus ``manage`` and stores the result in
    ``manage.context['chat_result']``.
    """

    class InstanceSerializer(serializers.Serializer):
        """Validates the arguments of a chat step."""

        # Conversation message list; each entry must be a langchain BaseMessage.
        message_list = serializers.ListField(required=True, child=MessageField(required=True),
                                             label=_("Conversation list"))
        model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id"))
        # Paragraph list
        paragraph_list = serializers.ListField(label=_("Paragraph List"))
        # Conversation id
        chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
        # The user's question
        problem_text = serializers.CharField(required=True, label=_("User Questions"))
        # Post-processor; must be a PostResponseHandler instance
        post_response_handler = InstanceField(model_type=PostResponseHandler,
                                              label=_("Post-processor"))
        # Completed (padded) question
        padding_problem_text = serializers.CharField(required=False,
                                                     label=_("Completion Question"))
        # Whether to stream the output
        stream = serializers.BooleanField(required=False, label=_("Streaming Output"))
        chat_user_id = serializers.CharField(required=True, label=_("Chat user id"))
        chat_user_type = serializers.CharField(required=True, label=_("Chat user Type"))
        # Settings applied when no referenced paragraphs are found
        no_references_setting = NoReferencesSetting(required=True,
                                                    label=_("No reference segment settings"))
        user_id = serializers.UUIDField(required=True, label=_("User ID"))
        model_setting = serializers.DictField(required=True, allow_null=True,
                                              label=_("Model settings"))
        model_params_setting = serializers.DictField(required=False, allow_null=True,
                                                     label=_("Model parameter settings"))

        def is_valid(self, *, raise_exception=False) -> bool:
            """Validate the data and ensure every message is a ``BaseMessage``.

            NOTE: ``raise_exception`` is accepted for signature compatibility with
            DRF but ignored — validation always raises on failure.

            :returns: ``True`` when validation succeeds (DRF's ``is_valid``
                contract is to return a bool; the previous version returned
                ``None``, which is falsy).
            :raises serializers.ValidationError: if field validation fails.
            :raises Exception: if any entry of ``message_list`` is not a
                ``BaseMessage``.
            """
            super().is_valid(raise_exception=True)
            # ``message_list`` is a required field, so it is present once
            # super().is_valid() has passed without raising.
            message_list: List = self.initial_data.get('message_list')
            for message in message_list:
                if not isinstance(message, BaseMessage):
                    raise Exception(_("message type error"))
            # Reaching this point means no errors were collected.
            return True

    def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this step's arguments."""
        return self.InstanceSerializer

    def _run(self, manage: PipelineManage):
        """Run :meth:`execute` and store its result in ``manage.context['chat_result']``."""
        chat_result = self.execute(**self.context['step_args'], manage=manage)
        manage.context['chat_result'] = chat_result

    @abstractmethod
    def execute(self, message_list: List[BaseMessage],
                chat_id, problem_text,
                post_response_handler: PostResponseHandler,
                model_id: str = None,
                user_id: str = None,
                paragraph_list=None,
                manage: PipelineManage = None,
                padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
                no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
        """Produce the chat result; implemented by concrete chat steps.

        The return value is stored by :meth:`_run` in
        ``manage.context['chat_result']``.
        """
        pass