generate.py
import logging
import traceback

from celery_once import QueueOnce
from django.db.models import QuerySet
from django.db.models.functions import Reverse, Substr
from django.utils.translation import gettext_lazy as _
from langchain_core.messages import HumanMessage

from common.config.embedding_config import ModelManage
from common.event import ListenerManagement
from common.util.page_utils import page, page_desc
from dataset.models import Paragraph, Document, Status, TaskType, State
from dataset.task.tools import save_problem
from ops import celery_app
from setting.models import Model
from setting.models_provider import get_model

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")


def get_llm_model(model_id):
    model = QuerySet(Model).filter(id=model_id).first()
    # ModelManage.get_model appears to cache instances by id; the lambda builds the
    # LLM from the model record fetched above when no cached instance is available.
    return ModelManage.get_model(model_id, lambda _id: get_model(model))


def generate_problem_by_paragraph(paragraph, llm_model, prompt):
    try:
        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
                                         State.STARTED)
        # Fill the prompt placeholders with the paragraph content and title, then ask the LLM.
        res = llm_model.invoke(
            [HumanMessage(content=prompt.replace('{data}', paragraph.content).replace('{title}', paragraph.title))])
        if (res.content is None) or (len(res.content) == 0):
            return
        # The model is expected to return one generated problem per line.
        problems = res.content.split('\n')
        for problem in problems:
            save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
                                         State.SUCCESS)
    except Exception as e:
        # Failures are recorded on the paragraph status; the exception is not re-raised.
        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
                                         State.FAILURE)
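
# Illustrative only: the prompt passed to generate_problem_by_paragraph is expected to
# contain the literal placeholders '{data}' and '{title}', and the model response is parsed
# as one question per line. A minimal example prompt (hypothetical wording) might look like:
#
#   EXAMPLE_PROMPT = (
#       "Based on the following content, titled {title}, "
#       "generate a list of questions, one per line:\n{data}"
#   )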


def get_generate_problem(llm_model, prompt, post_apply=lambda: None, is_the_task_interrupted=lambda: False):
    def generate_problem(paragraph_list):
        for paragraph in paragraph_list:
            if is_the_task_interrupted():
                return
            generate_problem_by_paragraph(paragraph, llm_model, prompt)
        post_apply()

    return generate_problem


def get_is_the_task_interrupted(document_id):
    def is_the_task_interrupted():
        document = QuerySet(Document).filter(id=document_id).first()
        if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
            return True
        return False

    return is_the_task_interrupted


@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']},
                 name='celery:generate_related_by_dataset')
def generate_related_by_dataset_id(dataset_id, model_id, prompt, state_list=None):
    document_list = QuerySet(Document).filter(dataset_id=dataset_id)
    for document in document_list:
        try:
            generate_related_by_document_id.delay(document.id, model_id, prompt, state_list)
        except Exception as e:
            # Ignore enqueue failures for a single document so the remaining documents are still dispatched.
            pass


@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
                 name='celery:generate_related_by_document')
def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
    if state_list is None:
        state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
                      State.REVOKE.value,
                      State.REVOKED.value, State.IGNORED.value]
    try:
        is_the_task_interrupted = get_is_the_task_interrupted(document_id)
        if is_the_task_interrupted():
            return
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                         TaskType.GENERATE_PROBLEM,
                                         State.STARTED)
        llm_model = get_llm_model(model_id)
        # Function that generates problems for a page of paragraphs
        generate_problem = get_generate_problem(llm_model, prompt,
                                                ListenerManagement.get_aggregation_document_status(
                                                    document_id), is_the_task_interrupted)
        # The per-task state is encoded as one character of the status string: reverse it and
        # take the character at position TaskType.GENERATE_PROBLEM.value, then keep only the
        # paragraphs whose generate-problem state is in state_list.
        query_set = QuerySet(Paragraph).annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
                                    1),
        ).filter(task_type_status__in=state_list, document_id=document_id)
        page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
    except Exception as e:
        max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(
            document_id=document_id, error=str(e), traceback=traceback.format_exc()))
    finally:
        ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)
        max_kb.info(_('End--->Generate problem: {document_id}').format(document_id=document_id))


@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
                 name='celery:generate_related_by_paragraph_list')
def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt):
    try:
        is_the_task_interrupted = get_is_the_task_interrupted(document_id)
        if is_the_task_interrupted():
            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                             TaskType.GENERATE_PROBLEM,
                                             State.REVOKED)
            return
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                         TaskType.GENERATE_PROBLEM,
                                         State.STARTED)
        llm_model = get_llm_model(model_id)
        # Function that generates problems for a page of paragraphs
        generate_problem = get_generate_problem(llm_model, prompt, ListenerManagement.get_aggregation_document_status(
            document_id))

        # Re-check the document status while paging so a revoked task stops promptly.
        def is_the_task_interrupted():
            document = QuerySet(Document).filter(id=document_id).first()
            if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
                return True
            return False

        page(QuerySet(Paragraph).filter(id__in=paragraph_id_list), 10, generate_problem, is_the_task_interrupted)
    finally:
        ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)
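

# A minimal dispatch sketch (illustrative, not part of this module): the functions above are
# plain Celery tasks, so callers are assumed to enqueue them with .delay() or .apply_async(),
# e.g. from a view or service layer. The ids and prompt below are hypothetical placeholders.
#
#   generate_related_by_dataset_id.delay(dataset_id, model_id, prompt)
#   generate_related_by_document_id.delay(document_id, model_id, prompt,
#                                         state_list=[State.FAILURE.value])  # e.g. retry only failed paragraphs
#   generate_related_by_paragraph_id_list.delay(document_id, [paragraph_id], model_id, prompt)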