-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag_mistral_single_ip.py
More file actions
117 lines (97 loc) · 3.13 KB
/
rag_mistral_single_ip.py
File metadata and controls
117 lines (97 loc) · 3.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from langchain_community.vectorstores import FAISS
import sys
import re
import os
import json
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from transformers import pipeline
# Step 1: Load the prebuilt FAISS vector store from disk.
# The query-time embedding model must match the one the index was built with
# (assumed to be MiniLM-L6-v2 here — verify against the indexing script).
VECTORSTORE_DIR = "./data/vectordb/rag_vectorstore_db_v5"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

# allow_dangerous_deserialization opts in to unpickling the saved docstore.
# Only acceptable for index files this project produced itself; never point
# VECTORSTORE_DIR at downloaded or otherwise untrusted data.
vectorstore = FAISS.load_local(
    VECTORSTORE_DIR,
    embedding_model,
    allow_dangerous_deserialization=True,
)
# Step 2: Load Mistral-7B-Instruct through a Hugging Face text-generation
# pipeline and wrap it for LangChain.
# The generation settings are defined once so the pipeline and the LangChain
# wrapper always agree (they were previously duplicated by hand).
GENERATION_KWARGS = {
    "max_new_tokens": 512,
    "temperature": 0.3,
    "top_p": 0.9,
}

text_gen = pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.1",
    device=0,  # GPU only: first CUDA device, no CPU fallback
    torch_dtype="auto",
    do_sample=True,  # temperature/top_p only take effect when sampling
    **GENERATION_KWARGS,
)

llm = HuggingFacePipeline(
    pipeline=text_gen,
    # Pass a copy so nothing downstream can mutate the shared settings.
    model_kwargs=dict(GENERATION_KWARGS),
)
# Step 3: Define the Mistral-Instruct prompt used by the "stuff" QA chain.
# {context} receives the retrieved documents and {question} the user query.
# The leading "<s>" (BOS) marker is deliberately NOT written into the text:
# Mistral's tokenizer prepends BOS itself when special tokens are enabled
# (the default — confirm for your transformers version), so a literal "<s>"
# would give the model a duplicated BOS.
# NOTE: the output parser locates "Context:\n", "Question:\n" and "[/INST]"
# in the echoed prompt, so keep those markers intact when editing this text.
prompt_template = PromptTemplate.from_template(
    """[INST] Use the following context to answer the question.
If you don't know the answer, just say "I don't know."
Context:
{context}
Question:
{question}
[/INST]"""
)
# Step 4: Combine retriever and LLM into a "stuff" RetrievalQA chain, which
# concatenates the retrieved documents into the prompt's {context} slot.
# Only the single best match is fetched (k=1) to keep the stuffed prompt
# from overflowing the model's context window.
retriever = vectorstore.as_retriever(search_kwargs=dict(k=1))

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs=dict(prompt=prompt_template),
)
# Step 5: Ask one question and split the model output into its parts.
# Example query:
#   How is the 'condition' column preprocessed before training CatBoost in Zillow samples?
query = input("Input your query to Mistral about CatBoost usage.")
if not query.strip():
    # Nothing to retrieve or answer; avoid spending a GPU generation on it.
    sys.exit("No query entered; nothing to do.")
result = qa_chain.invoke(query)

# The parsing below assumes result["result"] echoes the filled-in prompt
# before the completion:
#   "...Context:\n<docs>\nQuestion:\n<question>\n[/INST]<answer>"
text = result["result"]
context_match = re.search(r"Context:\n(.+?)\nQuestion:", text, re.DOTALL)
question_match = re.search(r"Question:\n(.+?)\n\[/INST\]", text, re.DOTALL)

context = context_match.group(1).strip() if context_match else ""
# If the prompt was not echoed back, the user's own input is the question.
question = question_match.group(1).strip() if question_match else query.strip()

# Everything after the first [/INST] marker is the model's answer. Without the
# marker (prompt not echoed), the whole output IS the answer — keep it instead
# of silently recording an empty string.
_, inst_marker, completion = text.partition("[/INST]")
answer = completion.strip() if inst_marker else text.strip()

print("############ QUESTION ############\n", question)
print("\n############ CONTEXT RETRIEVED #############\n", context)
print("\n############ ANSWER ###########\n", answer)
# Append this question/context/answer record to the JSON dataset: a list of
# objects keyed by a unique integer "id". Missing or empty files start a new list.
data_path = "output/data_new.json"
if os.path.exists(data_path) and os.path.getsize(data_path) != 0:
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
else:
    data = []

# Next id is one past the largest existing id (1 for an empty dataset).
new_id = max((record["id"] for record in data), default=0) + 1

entry = {
    "id": new_id,
    "question": question,
    "context": context,
    "answer": answer,
}
data.append(entry)

# Create output/ on first run, then write to a sibling temp file and swap it
# in atomically, so an interrupted write cannot truncate earlier entries.
os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
tmp_path = data_path + ".tmp"
with open(tmp_path, "w", encoding="utf-8") as f:
    # ensure_ascii=False keeps non-ASCII text readable in the UTF-8 file.
    json.dump(data, f, indent=2, ensure_ascii=False)
os.replace(tmp_path, data_path)

print(f"✅ Added result as id={new_id}")