Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions apps/application/chat_pipeline/I_base_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@

from rest_framework import serializers

from dataset.models import Paragraph
from knowledge.models import Paragraph


class ParagraphPipelineModel:

def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
def __init__(self, _id: str, document_id: str, knowledge_id: str, content: str, title: str, status: str,
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
self.id = _id
self.document_id = document_id
self.dataset_id = dataset_id
self.knowledge_id = knowledge_id
self.content = content
self.title = title
self.status = status,
Expand All @@ -39,7 +39,7 @@ def to_dict(self):
return {
'id': self.id,
'document_id': self.document_id,
'dataset_id': self.dataset_id,
'knowledge_id': self.knowledge_id,
'content': self.content,
'title': self.title,
'status': self.status,
Expand All @@ -66,7 +66,7 @@ def add_paragraph(self, paragraph):
if isinstance(paragraph, Paragraph):
self.paragraph = {'id': paragraph.id,
'document_id': paragraph.document_id,
'dataset_id': paragraph.dataset_id,
'knowledge_id': paragraph.knowledge_id,
'content': paragraph.content,
'title': paragraph.title,
'status': paragraph.status,
Expand Down Expand Up @@ -106,7 +106,7 @@ def add_meta(self, meta: dict):

def build(self):
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
str(self.paragraph.get('dataset_id')),
str(self.paragraph.get('knowledge_id')),
self.paragraph.get('content'), self.paragraph.get('title'),
self.paragraph.get('status'),
self.paragraph.get('is_active'),
Expand Down
9 changes: 5 additions & 4 deletions apps/application/chat_pipeline/step/chat_step/i_chat_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class PostResponseHandler:
@abstractmethod
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
answer_text,
manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
manage, step, padding_problem_text: str = None, **kwargs):
pass


Expand All @@ -68,8 +68,9 @@ class InstanceSerializer(serializers.Serializer):
label=_("Completion Question"))
# 是否使用流的形式输出
stream = serializers.BooleanField(required=False, label=_("Streaming Output"))
client_id = serializers.CharField(required=True, label=_("Client id"))
client_type = serializers.CharField(required=True, label=_("Client Type"))
chat_user_id = serializers.CharField(required=True, label=_("Chat user id"))

chat_user_type = serializers.CharField(required=True, label=_("Chat user Type"))
# 未查询到引用分段
no_references_setting = NoReferencesSetting(required=True,
label=_("No reference segment settings"))
Expand Down Expand Up @@ -104,6 +105,6 @@ def execute(self, message_list: List[BaseMessage],
user_id: str = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,16 @@
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning
from application.models.application_api_key import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from application.models import ApplicationChatUserStats, ChatUserType
from models_provider.tools import get_model_instance_by_model_user_id


def add_access_num(client_id=None, client_type=None, application_id=None):
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None:
application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id,
application_id=application_id)
def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
chat_user_type) and application_id is not None:
application_public_access_client = (QuerySet(ApplicationChatUserStats).filter(chat_user_id=chat_user_id,
chat_user_type=chat_user_type,
application_id=application_id)
.first())
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
Expand Down Expand Up @@ -124,11 +125,9 @@ def event_content(response,
request_token = 0
response_token = 0
write_context(step, manage, request_token, response_token, all_text)
asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id,
reasoning_content=reasoning_content if reasoning_content_enable else ''
, asker=asker)
all_text, manage, step, padding_problem_text,
reasoning_content=reasoning_content if reasoning_content_enable else '')
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], '', True,
request_token, response_token,
Expand All @@ -139,10 +138,8 @@ def event_content(response,
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
all_text = 'Exception:' + str(e)
write_context(step, manage, 0, 0, all_text)
asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id, reasoning_content='',
asker=asker)
all_text, manage, step, padding_problem_text, reasoning_content='')
add_access_num(client_id, client_type, manage.context.get('application_id'))
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], all_text,
Expand All @@ -165,7 +162,7 @@ def execute(self, message_list: List[BaseMessage],
manage: PipelineManage = None,
padding_problem_text: str = None,
stream: bool = True,
client_id=None, client_type=None,
chat_user_id=None, chat_user_type=None,
no_references_setting=None,
model_params_setting=None,
model_setting=None,
Expand All @@ -175,12 +172,13 @@ def execute(self, message_list: List[BaseMessage],
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting,
manage, padding_problem_text, chat_user_id, chat_user_type,
no_references_setting,
model_setting)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting,
manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
model_setting)

def get_details(self, manage, **kwargs):
Expand Down Expand Up @@ -235,7 +233,7 @@ def execute_stream(self, message_list: List[BaseMessage],
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None,
chat_user_id=None, chat_user_type=None,
no_references_setting=None,
model_setting=None):
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
Expand All @@ -244,7 +242,8 @@ def execute_stream(self, message_list: List[BaseMessage],
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
padding_problem_text, chat_user_id, chat_user_type, is_ai_chat,
model_setting),
content_type='text/event-stream;charset=utf-8')

r['Cache-Control'] = 'no-cache'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
IGenerateHumanMessageStep
from application.models import ChatRecord
from common.util.split_model import flat_map
from common.utils.common import flat_map


class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class InstanceSerializer(serializers.Serializer):
padding_problem_text = serializers.CharField(required=False,
label=_("System completes question text"))
# 需要查询的数据集id列表
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
label=_("Dataset id list"))
knowledge_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
label=_("Dataset id list"))
# 需要排除的文档id
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
label=_("List of document ids to exclude"))
Expand Down Expand Up @@ -55,7 +55,7 @@ def _run(self, manage: PipelineManage):
self.context['paragraph_list'] = paragraph_list

@abstractmethod
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
Expand All @@ -65,7 +65,7 @@ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_documen
:param similarity: 相关性
:param top_n: 查询多少条
:param problem_text: 用户问题
:param dataset_id_list: 需要查询的数据集id列表
:param knowledge_id_list: 需要查询的数据集id列表
:param exclude_document_id_list: 需要排除的文档id
:param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,42 +35,33 @@ def get_model_by_id(_id, user_id):
return model


def get_embedding_id(knowledge_id_list):
    """Return the single embedding model id shared by the given knowledge bases.

    :param knowledge_id_list: ids of the knowledge bases that will be searched
    :return: the ``embedding_mode_id`` of the first matching knowledge base
    :raises Exception: if the matching knowledge bases reference more than one
        distinct ``embedding_mode_id``, or if no knowledge base matches the ids
    """
    knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
    # All knowledge bases queried together must share one embedding model id;
    # the error text says segments cannot be recalled otherwise.
    if len({knowledge.embedding_mode_id for knowledge in knowledge_list}) > 1:
        raise Exception(
            _("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
    # None of the requested ids matched an existing knowledge base.
    if len(knowledge_list) == 0:
        raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
    return knowledge_list[0].embedding_mode_id


class BaseSearchDatasetStep(ISearchDatasetStep):

def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
if len(dataset_id_list) == 0:
if len(knowledge_id_list) == 0:
return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
model_id = get_embedding_id(dataset_id_list)
model_id = get_embedding_id(knowledge_id_list)
model = get_model_by_id(model_id, user_id)
self.context['model_name'] = model.name
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
if embedding_list is None:
return []
Expand Down
24 changes: 12 additions & 12 deletions apps/application/flow/i_step_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from rest_framework.exceptions import ValidationError, ErrorDetail

from application.flow.common import Answer, NodeChunk
from application.models import ChatRecord
from application.models import ApplicationChatClientStats
from application.models import ChatRecord, ChatUserType
from application.models import ApplicationChatUserStats
from common.constants.authentication_type import AuthenticationType
from common.field.common import InstanceField

Expand All @@ -45,10 +45,10 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict):


class WorkFlowPostHandler:
def __init__(self, chat_info, client_id, client_type):
def __init__(self, chat_info, chat_user_id, chat_user_type):
self.chat_info = chat_info
self.client_id = client_id
self.client_type = client_type
self.chat_user_id = chat_user_id
self.chat_user_type = chat_user_type

def handler(self, chat_id,
chat_record_id,
Expand Down Expand Up @@ -84,13 +84,13 @@ def handler(self, chat_id,
run_time=time.time() - workflow.context['start_time'],
index=0)
asker = workflow.context.get('asker', None)
self.chat_info.append_chat_record(chat_record, self.client_id, asker)
# 重新设置缓存
chat_cache.set(chat_id,
self.chat_info, timeout=60 * 30)
if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
application_public_access_client = (QuerySet(ApplicationChatClientStats)
.filter(client_id=self.client_id,
self.chat_info.append_chat_record(chat_record)
self.chat_info.set_cahce()
if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
self.chat_user_type):
application_public_access_client = (QuerySet(ApplicationChatUserStats)
.filter(chat_user_id=self.chat_user_id,
chat_user_type=self.chat_user_type,
application_id=self.chat_info.application.id).first())
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
Expand Down
Loading