Skip to content

Commit 660397c

Browse files
committed
fixing formatting errors
1 parent 7744549 commit 660397c

1 file changed

Lines changed: 14 additions & 9 deletions

File tree

rag_and_distilled_model/Apollo11_rag&distilled.ipynb

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline\n",
5252
"import torch\n",
5353
"import warnings\n",
54+
"\n",
5455
"warnings.filterwarnings(\"ignore\")\n",
5556
"\n",
5657
"PROMPTS_FILE = \"data/test_data.json\"\n",
@@ -94,7 +95,9 @@
9495
"outputs": [],
9596
"source": [
9697
"embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)\n",
97-
"splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)"
98+
"splitter = RecursiveCharacterTextSplitter(\n",
99+
" chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP\n",
100+
")"
98101
]
99102
},
100103
{
@@ -136,6 +139,8 @@
136139
" device=device,\n",
137140
" )\n",
138141
" return HuggingFacePipeline(pipeline=pipe)\n",
142+
"\n",
143+
"\n",
139144
"llm = initialize_local_llm()"
140145
]
141146
},
@@ -208,12 +213,10 @@
208213
"source": [
209214
"def build_chroma_store(docs, persist_dir=PERSIST_DIR):\n",
210215
" db = Chroma.from_documents(\n",
211-
" documents=docs,\n",
212-
" embedding=embedder,\n",
213-
" persist_directory=persist_dir\n",
216+
" documents=docs, embedding=embedder, persist_directory=persist_dir\n",
214217
" )\n",
215218
" db.persist()\n",
216-
" return db\n"
219+
" return db"
217220
]
218221
},
219222
{
@@ -283,12 +286,13 @@
283286
"source": [
284287
"def query_database(query_text, k=TOP_K_RESULTS, threshold=RELEVANCE_THRESHOLD):\n",
285288
" results = db.similarity_search_with_relevance_scores(query_text, k=k)\n",
286-
" \n",
289+
"\n",
287290
" if len(results) == 0 or results[0][1] < threshold:\n",
288291
" return []\n",
289-
" \n",
292+
"\n",
290293
" return results\n",
291294
"\n",
295+
"\n",
292296
"def generate_rag_response(\n",
293297
" query_text, k=TOP_K_RESULTS, threshold=RELEVANCE_THRESHOLD, verbose=False\n",
294298
"):\n",
@@ -329,7 +333,8 @@
329333
" \"prompt\": prompt,\n",
330334
" \"scores\": [score for _, score in results],\n",
331335
" }\n",
332-
" \n",
336+
"\n",
337+
"\n",
333338
"def ask(query_text):\n",
334339
" result = generate_rag_response(query_text, verbose=True)\n",
335340
" return result[\"answer\"]"
@@ -407,7 +412,7 @@
407412
"\n",
408413
" print(f\" Model Answer: {answer}\")\n",
409414
" if expected:\n",
410-
" print(f\" Expected: {expected}\")\n"
415+
" print(f\" Expected: {expected}\")"
411416
]
412417
}
413418
],

0 commit comments

Comments (0)