|
51 | 51 | "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline\n", |
52 | 52 | "import torch\n", |
53 | 53 | "import warnings\n", |
| 54 | + "\n", |
54 | 55 | "warnings.filterwarnings(\"ignore\")\n", |
55 | 56 | "\n", |
56 | 57 | "PROMPTS_FILE = \"data/test_data.json\"\n", |
|
94 | 95 | "outputs": [], |
95 | 96 | "source": [ |
96 | 97 | "embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)\n", |
97 | | - "splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)" |
| 98 | + "splitter = RecursiveCharacterTextSplitter(\n", |
| 99 | + " chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP\n", |
| 100 | + ")" |
98 | 101 | ] |
99 | 102 | }, |
100 | 103 | { |
|
136 | 139 | " device=device,\n", |
137 | 140 | " )\n", |
138 | 141 | " return HuggingFacePipeline(pipeline=pipe)\n", |
| 142 | + "\n", |
| 143 | + "\n", |
139 | 144 | "llm = initialize_local_llm()" |
140 | 145 | ] |
141 | 146 | }, |
|
208 | 213 | "source": [ |
209 | 214 | "def build_chroma_store(docs, persist_dir=PERSIST_DIR):\n", |
210 | 215 | " db = Chroma.from_documents(\n", |
211 | | - " documents=docs,\n", |
212 | | - " embedding=embedder,\n", |
213 | | - " persist_directory=persist_dir\n", |
| 216 | + " documents=docs, embedding=embedder, persist_directory=persist_dir\n", |
214 | 217 | " )\n", |
215 | 218 | " db.persist()\n", |
216 | | - " return db\n" |
| 219 | + " return db" |
217 | 220 | ] |
218 | 221 | }, |
219 | 222 | { |
|
283 | 286 | "source": [ |
284 | 287 | "def query_database(query_text, k=TOP_K_RESULTS, threshold=RELEVANCE_THRESHOLD):\n", |
285 | 288 | " results = db.similarity_search_with_relevance_scores(query_text, k=k)\n", |
286 | | - " \n", |
| 289 | + "\n", |
287 | 290 | " if len(results) == 0 or results[0][1] < threshold:\n", |
288 | 291 | " return []\n", |
289 | | - " \n", |
| 292 | + "\n", |
290 | 293 | " return results\n", |
291 | 294 | "\n", |
| 295 | + "\n", |
292 | 296 | "def generate_rag_response(\n", |
293 | 297 | " query_text, k=TOP_K_RESULTS, threshold=RELEVANCE_THRESHOLD, verbose=False\n", |
294 | 298 | "):\n", |
|
329 | 333 | " \"prompt\": prompt,\n", |
330 | 334 | " \"scores\": [score for _, score in results],\n", |
331 | 335 | " }\n", |
332 | | - " \n", |
| 336 | + "\n", |
| 337 | + "\n", |
333 | 338 | "def ask(query_text):\n", |
334 | 339 | " result = generate_rag_response(query_text, verbose=True)\n", |
335 | 340 | " return result[\"answer\"]" |
|
407 | 412 | "\n", |
408 | 413 | " print(f\" Model Answer: {answer}\")\n", |
409 | 414 | " if expected:\n", |
410 | | - " print(f\" Expected: {expected}\")\n" |
| 415 | + " print(f\" Expected: {expected}\")" |
411 | 416 | ] |
412 | 417 | } |
413 | 418 | ], |
|
0 commit comments