forked from VivatImperial/SlovarikDB_Hallucination_Detection
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpipeline.py
More file actions
69 lines (53 loc) · 1.74 KB
/
pipeline.py
File metadata and controls
69 lines (53 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from vsevolo_de_bert import VsevoloDeBERT
from classes.seporator_messeges import MessageSeparator
from classes.LLM import LLM_predictor
from classes.city_detector import CityExtractor
import gc
def pipeline(model_path: str, data: list, use_NER: bool = True, *,
             max_answer_len: int = 100) -> list:
    """Label each item in *data* by running it through a staged model cascade.

    Each item is a mapping read by key: ``'answer'`` is used by every stage,
    ``'summary'`` by the classification and LLM stages, and ``'question'`` by
    the LLM stage only.

    Stages (a model is constructed only if at least one item reaches its
    stage; it is then deleted and ``gc.collect()`` is run before the next
    model is loaded):

    1. Length check: answers longer than ``max_answer_len`` characters are
       predicted ``1`` and excluded from every later stage.
    2. Separation: ``MessageSeparator`` labels the remaining answers and, when
       ``use_NER`` is true, ``CityExtractor`` labels them as well.
    3. Classification: items whose separator and city labels are both ``0``
       are scored by ``VsevoloDeBERT`` loaded from ``model_path``.
    4. LLM: items with separator label ``1`` are re-scored by
       ``LLM_predictor`` in mode ``0``; otherwise, when ``use_NER`` is true,
       items with city label ``1`` are re-scored in mode ``1``.

    Args:
        model_path: Path handed to ``VsevoloDeBERT(model_path=...)``.
        data: Items to label; any iterable is accepted and materialized once.
        use_NER: Whether to run the city-detection stage.
        max_answer_len: Answer-length threshold for stage 1 (default 100).

    Returns:
        One prediction per item, in input order. Items decided by stage 1 get
        ``1``; others get whatever the last stage that touched them returned,
        or ``0`` if no later stage applied.
    """
    # Materialize once: the stages below make several passes over the items.
    items = list(data)
    count = len(items)

    # Stage 1: length check. -1 in the label lists marks an item as already
    # decided so that later stages skip it.
    pred = []
    sep_labels = []
    city_labels = []
    for item in items:
        is_long = len(item['answer']) > max_answer_len
        pred.append(1 if is_long else 0)
        sep_labels.append(-1 if is_long else 0)
        city_labels.append(-1 if is_long else 0)
    print('check1')

    # Stage 2: separation (+ optional city NER) on the undecided items only.
    pending = [num for num in range(count) if sep_labels[num] != -1]
    if pending:
        separator = MessageSeparator()
        city_detector = CityExtractor() if use_NER else None
        for num in pending:
            answer = items[num]['answer']
            sep_labels[num] = separator.predict(answer)
            if city_detector is not None:
                city_labels[num] = city_detector.predict(answer)
        # Release stage-2 models before the next model is constructed.
        del separator, city_detector
        gc.collect()
    print('check2')

    # Stage 3: classifier for items flagged by neither the separator nor NER.
    to_classify = [num for num in range(count)
                   if sep_labels[num] == 0 and city_labels[num] == 0]
    if to_classify:
        classifier = VsevoloDeBERT(model_path=model_path)
        for num in to_classify:
            item = items[num]
            pred[num] = classifier.predict([item['summary'], item['answer']])
        del classifier
        gc.collect()
    print('check3')

    # Stage 4: LLM. The separator flag takes priority (mode 0) over the
    # city flag (mode 1).
    llm_jobs = []
    for num in range(count):
        if sep_labels[num] == 1:
            llm_jobs.append((num, 0))
        elif use_NER and city_labels[num] == 1:
            llm_jobs.append((num, 1))
    if llm_jobs:
        llm = LLM_predictor()
        for num, mode in llm_jobs:
            item = items[num]
            pred[num] = llm.predict(
                [item['summary'], item['question'], item['answer']], mode)
        del llm
        gc.collect()
    return pred