-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhaystack_research.py
More file actions
322 lines (286 loc) · 12.8 KB
/
haystack_research.py
File metadata and controls
322 lines (286 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
# PJB: Use the VENV in Development/HayStack
#
# -----------------------------------------------------------------------------
# We could extract names (and other metadata) from the query
# using an LLM and use them in the retrieval.
# Uses researcher name, title and abstract (as generated by lucris-rs).
# -----------------------------------------------------------------------------
#
import sys
import re
import ollama
from haystack import Pipeline
from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.converters import TextFileToDocument
from haystack.components.preprocessors import DocumentCleaner
from haystack.components.preprocessors import DocumentSplitter
from haystack.document_stores.types import DuplicatePolicy
from haystack.components.writers import DocumentWriter
from haystack.components.rankers import LostInTheMiddleRanker
from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.rankers import SentenceTransformersDiversityRanker
from haystack_integrations.components.generators.ollama import OllamaGenerator
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.builders import PromptBuilder
from haystack.components.joiners import DocumentJoiner
import argparse
import logging
import pprint
# Module logger 'foo': everything from DEBUG up is written to haystack.log,
# INFO and above is echoed to the console with a shorter timestamped format.
logger = logging.getLogger('foo')
logger.setLevel(logging.DEBUG)

file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('haystack.log')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(file_formatter)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# `formatter` ends up bound to the console format, as before.
formatter = logging.Formatter('%(asctime)s - %(message)s', "%Y-%m-%d %H:%M:%S")
console_handler.setFormatter(formatter)

# File handler first, then console handler.
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# Command-line options, parsed once at import time into the module-level `args`
# that the query functions read.
# NOTE(review): --query is parsed, but the interactive loop reads queries from stdin.
_CLI_OPTIONS = (
    (("-m", "--model"), {"help": "Model for text generation.", "default": "llama3.1"}),
    (("-e", "--extractionmodel"), {"help": "Model for text extraction.", "default": "mistral"}),
    (("-E", "--embeddings"), {"action": 'store_true', "help": "Use embeddings.", "default": False}),
    (("-q", "--query"), {"help": "query.", "default": None}),
    (("-s", "--storename"), {"help": "Document store.", "default": "docs_research.store"}),
    (("-p", "--showprompt"), {"action": 'store_true', "help": "Show LLM prompts.", "default": False}),
    (("-t", "--temp"), {"type": float, "help": "Generator temperature.", "default": 0.1}),
    (("--top_k",), {"type": int, "help": "Retriever top_k.", "default": 29}),
    (("--rank_k",), {"type": int, "help": "Ranker top_k.", "default": 19}),
)

parser = argparse.ArgumentParser()
for _flags, _kwargs in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_kwargs)
args = parser.parse_args()
logger.debug(args)
# -----------------------------------------------------------------------------
# For extraction, mistral seems to work better than llama, at least on the test cases.
# Ollama options shared by the extraction and classification calls.
_EXTRACTION_OPTIONS = {
    'temperature': 0.0,  # greedy decoding: keep the structured answers stable
    'top_k': 10,
    'num_ctx': 8096,  # NOTE(review): possibly meant 8192 — confirm before changing.
    'repeat_last_n': -1,
}


def _ask_extraction_model(prompt) -> str:
    """Send *prompt* to ``args.extractionmodel`` via Ollama and return the raw reply text.

    The prompt is printed first when ``--showprompt`` is given.
    """
    if args.showprompt:
        print(prompt)
    output = ollama.generate(
        model=args.extractionmodel,
        options=dict(_EXTRACTION_OPTIONS),  # copy so the shared dict is never mutated
        prompt=prompt,
    )
    return output['response']


def extract_persons(a_text) -> str:
    """Ask the extraction model which people are named in *a_text*.

    Args:
        a_text: The user's query text (may be empty).

    Returns:
        The model's raw reply. The prompt asks for a JSON list such as
        ``[{"person": "John Doe"}]`` or ``[{}]`` when there are no names.
        The reply is not parsed or validated here.
    """
    prompt = (
        "Your task is to extract the names of the people mentioned in the users input after TEXT:\n"
        "names start with a capital letter.\n"
        "Only reply with the json structure.\n"
        "Do not repeat the input text.\n"
        "Remove titles like Mr. or Mrs.\n"
        "If you cannot find any persons, reply with an empty structure like this: [{}].\n"
        "If the text is empty, reply with an empty structure like this: [{}].\n"
        "Format your output as a list of json with the following structure.\n"
        "[{\n"
        " \"person\": The name of the person\n"
        "}]\n"
        "Example user input: \"TEXT: What is Mr. John Doe working on?\n"
        "Example output: [{\"person\": \"John Doe\"}]\n"
        "Example user input: \"TEXT: who is working on second language acquisition?\n"
        "Example output: [{}]\n\n"
    )
    return _ask_extraction_model(prompt + "TEXT:" + a_text + ".\n")


def classify_query(a_text) -> str:
    """Ask the extraction model whether *a_text* is about research or about a person.

    Args:
        a_text: The user's query text.

    Returns:
        The model's raw reply. The prompt asks for ``ResearchQuestion`` or
        ``PersonQuestion``; the reply is not validated here.
    """
    prompt = (
        "Your task is to classify the query.\n"
        "The following classes are available:\n"
        "ResearchQuestion, for a question about research,\n"
        "PersonQuestion, for a question about a researcher.\n"
        "Only return the classification.\n"
        "Example input: \"TEXT: What is Mr. John Doe working on?\n"
        "Example output: PersonQuestion\n"
        "Example input: \"TEXT: Who is working with eye-tracking?\n"
        "Example output: ResearchQuestion\n"
    )
    return _ask_extraction_model(prompt + "TEXT:" + a_text + ".\n")
# -----------------------------------------------------------------------------
# Test name extraction.
# Manual smoke test for extract_persons(): set to True to print the extraction
# result for a few sample sentences, then exit before the interactive loop.
RUN_EXTRACTION_SMOKE_TEST = False
if RUN_EXTRACTION_SMOKE_TEST:
    for sample_text in (
        "What is Quinten Berck working on?",
        "Tell me what John and Nisse Nissesson are researching?",
        "I did my shopping at ICAs",
        "",
        "We used site-directed mutagenesis by Van den Bosch and Smith to do this.",
    ):
        print(extract_persons(sample_text))
    sys.exit(0)
# -----------------------------------------------------------------------------
def load_document_store(filename):
    """Load a previously saved InMemoryDocumentStore from disk.

    Args:
        filename: Path of the saved store (passed to ``load_from_disk``).

    Returns:
        The loaded InMemoryDocumentStore.
    """
    logger.info("Loading document store...")
    # load_from_disk is a classmethod that constructs and returns its own store.
    # Call it on the class rather than on a throwaway empty instance.
    document_store = InMemoryDocumentStore.load_from_disk(filename)
    logger.info("Number of documents: %s.", document_store.count_documents())
    return document_store
# Strips <think>...</think> reasoning sections (emitted e.g. by deepseek models)
# from the generated answer.
_THINK_TAGS_RE = re.compile(r'<think>.*?</think>', flags=re.DOTALL)

# Prompt for the answer generator. Each context entry carries its 0-based rank
# so the model can actually "reference the index numbers" as instructed; the
# numbers match the ranked-document log lines.
# TODO(review): an earlier note said "Include query classification." —
# presumably classify_query()'s result should eventually feed into this prompt.
_RAG_TEMPLATE = """
Given the following context, answer the question at the end.
Do not make up facts. Do not use lists. When referring to research
mention the researchers names from the context. The name of the researcher will be given
first, followed by an abstract of the relevant research. The question will follow the context.
Reference the index numbers in the context when replying.
Context:
{% for document in documents %}
[{{ loop.index0 }}] Researcher: {{ document.meta.researcher_name }}. Research: {{ document.content }}
{% endfor %}
Question: {{question}}
"""

# Retrieval pipelines keyed by id() of their document store. The store is kept
# in the value so its id cannot be recycled while cached.
_retrieval_pipelines = {}
# Answer-generation pipeline. It depends only on the command-line args, so it is built once.
_rag_pipeline = None


def _build_hybrid_retrieval(document_store):
    """Build the BM25 + embedding retrieval pipeline, joined and re-ranked by a cross-encoder."""
    pipeline = Pipeline()
    pipeline.add_component(
        "text_embedder",
        SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
    )
    pipeline.add_component("embedding_retriever", InMemoryEmbeddingRetriever(document_store))
    pipeline.add_component("bm25_retriever", InMemoryBM25Retriever(document_store))
    pipeline.add_component("document_joiner", DocumentJoiner())
    pipeline.add_component("ranker", TransformersSimilarityRanker(model="BAAI/bge-reranker-base"))
    pipeline.connect("text_embedder", "embedding_retriever")
    pipeline.connect("embedding_retriever", "document_joiner")
    pipeline.connect("bm25_retriever", "document_joiner")
    pipeline.connect("document_joiner", "ranker")
    return pipeline


def _get_hybrid_retrieval(document_store):
    """Return the retrieval pipeline for *document_store*, building it only on first use.

    Reusing the pipeline avoids reloading the embedding and re-ranker models on every query.
    """
    key = id(document_store)
    cached = _retrieval_pipelines.get(key)
    if cached is None or cached[0] is not document_store:
        cached = (document_store, _build_hybrid_retrieval(document_store))
        _retrieval_pipelines[key] = cached
    return cached[1]


def _get_rag_pipeline():
    """Return the prompt-builder -> Ollama pipeline, building it on first use."""
    global _rag_pipeline
    if _rag_pipeline is None:
        generator = OllamaGenerator(
            model=args.model,
            url="http://localhost:11434",
            generation_kwargs={
                "num_predict": 8000,
                "temperature": args.temp,  # Higher is more "creative".
                'num_ctx': 12028,  # NOTE(review): unusual size, maybe meant 12288 — confirm.
                'repeat_last_n': -1,
            },
        )
        pipeline = Pipeline()
        pipeline.add_component("prompt_builder", PromptBuilder(template=_RAG_TEMPLATE))
        pipeline.add_component("llm", generator)
        pipeline.connect("prompt_builder", "llm")
        _rag_pipeline = pipeline
    return _rag_pipeline


def _log_ranked_documents(documents):
    """Log one line per document: rank, score, researcher name and a content preview."""
    for rank, doc in enumerate(documents):
        logger.info(
            "%02d %.4f %s %s",
            rank, doc.score, doc.meta.get("researcher_name"), (doc.content or "")[:78],
        )


def handle_query(query, document_store):
    """Answer one user query: hybrid retrieval, re-ranking, then LLM generation.

    Args:
        query: The user's input line. "bye" ends the session; a blank line is skipped.
        document_store: Store to search, as returned by ``load_document_store``.

    Returns:
        False when the user typed "bye", otherwise True (keep prompting).
    """
    query = query.strip()
    if query == "bye":
        return False
    if not query:
        # Nothing to retrieve or answer for a blank line; ask again instead of running the models.
        return True

    logger.debug(args)
    logger.info(query)
    # Experimental: the extracted names and the query class are only logged for now.
    logger.info(extract_persons(query))
    logger.info(classify_query(query))

    # NOTE(review): args.embeddings is not consulted; retrieval is always hybrid.
    res = _get_hybrid_retrieval(document_store).run(
        {
            "text_embedder": {"text": query},
            "embedding_retriever": {"top_k": args.top_k},
            "bm25_retriever": {"query": query, "top_k": args.top_k},
            "ranker": {"query": query, "top_k": args.rank_k},
        }
    )
    documents = res["ranker"]["documents"]
    _log_ranked_documents(documents)
    logger.info("")
    logger.info("=" * 78)

    logger.info("Answering: " + query)
    response = _get_rag_pipeline().run(
        {"prompt_builder": {"question": query, "documents": documents}},
        include_outputs_from={"prompt_builder"},
    )
    prompt = response["prompt_builder"]["prompt"]
    meta = response["llm"]["meta"]
    # The "context" entry may be absent or None depending on the Ollama version.
    # Don't let a debug log line crash the query.
    context = (meta[0].get("context") if meta else None) or []
    logger.debug("Context len: %d", len(context))
    logger.info("Prompt length: %d", len(prompt))
    logger.info("-" * 78)

    # Remove deepseek-style reasoning and the blank lines it leaves behind.
    answer = _THINK_TAGS_RE.sub('', response["llm"]["replies"][0]).strip()
    logger.info(answer)
    logger.info("-" * 78)
    if args.showprompt:
        logger.info("Prompt builder prompt:")
        logger.info(prompt)
    logger.info("=" * 78)
    return True
if __name__ == '__main__':
    # Interactive loop: read one query per line until handle_query() returns
    # False (the user typed "bye") or input ends.
    document_store = load_document_store(args.storename)
    go_on = True
    while go_on:
        print("\nEnter Query:")
        try:
            query = input()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: leave cleanly instead of with a traceback.
            break
        go_on = handle_query(query, document_store)