-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutility.py
More file actions
571 lines (445 loc) · 19.5 KB
/
utility.py
File metadata and controls
571 lines (445 loc) · 19.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
# utility.py
"""
Core retrieval, embedding, and indexing utilities for AskMe-FAQ-Bot.
This module provides:
- CSV I/O for the FAQ knowledge base
- Embedding via Gemini with retries
- FAISS index build/load and dimension checks
- Hybrid retrieval (semantic + lexical) with score fusion
- Exact normalized lookup fast path
- Fallback capture and conversational fallback generation
Design notes:
* Import-time side effects are minimized. Gemini client configuration is guarded;
if configuration fails at import time, we log it and defer hard failures until
an embedding/fallback call is actually attempted.
* All functions log key steps and errors via the shared app logger
(see logger.setup_logging()).
"""
from __future__ import annotations

import hashlib
import os
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, List, Optional, Tuple

import faiss
import numpy as np
import pandas as pd
from rapidfuzz import fuzz
from tqdm import tqdm

from config import (
    BASE_FAQ_PATH, FALLBACK_PATH,
    EMBED_MODEL, GENERATE_MODEL,
    TOP_K, THRESHOLD, HYBRID_ALPHA,
    FAISS_DIR, GEMINI_API_KEY,
)
from logger import get_logger
from preprocessing import normalize_for_match, preprocess
log = get_logger()
# ----------------------- Model-name normalization -----------------------
def _fix_model_name(name: str) -> str:
"""Ensure Google Generative AI model id is prefixed correctly.
Accepts either 'text-embedding-004' or 'models/text-embedding-004'
and returns the prefixed form. Also passes through tuned models.
Args:
name: Raw model name from config/.env.
Returns:
Normalized model name beginning with 'models/' or 'tunedModels/'.
"""
if not name:
return name
if name.startswith(("models/", "tunedModels/")):
return name
return f"models/{name}"
# ---- Safe Gemini import & configuration (no hard crash at import time) ----
# _GEMINI_OK gates every embed/generate call in this module; if either the
# import or configure() fails, we log and defer the hard failure to the first
# actual embedding/fallback attempt (see _embed_one / natural_fallback).
try:
    import google.generativeai as genai  # type: ignore

    try:
        genai.configure(api_key=GEMINI_API_KEY)
        _GEMINI_OK = True
        log.info("Gemini client configured.")
        # Log final model ids so you can confirm the prefix fix at runtime
        log.info(
            "Using models => embed=%s | generate=%s",
            _fix_model_name(EMBED_MODEL),
            _fix_model_name(GENERATE_MODEL),
        )
    except Exception as e:
        # Import succeeded but configure() failed (e.g. bad/missing API key).
        _GEMINI_OK = False
        log.error("Gemini configure() failed: %s", e)
except Exception as e:
    # Package not installed / import error: leave genai as None so later
    # references fail loudly only when actually used.
    _GEMINI_OK = False
    genai = None  # type: ignore
    log.error("Failed to import google.generativeai: %s", e)
def _model_tag() -> str:
    """Derive a short, stable fingerprint of the embedding model name.

    Returns:
        str: First 8 hex chars of SHA1 over the EMBED_MODEL string, used to
        keep per-model index files from colliding on disk.
    """
    digest = hashlib.sha1(EMBED_MODEL.encode()).hexdigest()
    return digest[:8]


# Paths for the FAISS index and metadata (per-embed-model)
INDEX_PATH = str(FAISS_DIR / f"index_{_model_tag()}.bin")
INDEX_META = str(FAISS_DIR / f"meta_{_model_tag()}.parquet")
@dataclass
class Store:
    """In-memory holders for the FAISS index and metadata.

    A single module-level instance (`store`, below) is shared across UI
    callbacks; it is mutated in place by the index build/load functions.

    Attributes:
        index (Optional[faiss.IndexFlatIP]): In-memory FAISS index (inner
            product over L2-normalized vectors, i.e. cosine similarity).
        meta (Optional[pd.DataFrame]): Metadata with columns
            ['question_id', 'question', 'answer'], row-aligned with `index`.
        norm_map (Optional[dict]): Normalized (lower/alpha-num) question →
            answer map used for the exact-lookup fast path.
    """
    index: Optional[faiss.IndexFlatIP]
    meta: Optional[pd.DataFrame]
    norm_map: Optional[Dict[str, str]]


# Global store (shared across UI callbacks)
store = Store(index=None, meta=None, norm_map=None)
# ------------------------------- CSV I/O -------------------------------
def load_base_csv() -> Optional[pd.DataFrame]:
    """Load the base FAQ CSV (`basic_faq.csv`) if present.

    The CSV is expected to contain the columns 'question_id', 'question',
    'answer' (case-insensitive; surrounding whitespace tolerated). Extra
    columns are ignored.

    Returns:
        Optional[pd.DataFrame]: A dataframe with exactly
        ['question_id', 'question', 'answer'] or None if the file doesn't exist.

    Raises:
        ValueError: If the file exists but is missing required columns.
    """
    if not os.path.exists(BASE_FAQ_PATH):
        log.warning("basic_faq.csv not found at %s", BASE_FAQ_PATH)
        return None
    df = pd.read_csv(BASE_FAQ_PATH)
    # BUG FIX: normalize headers unconditionally. The previous code only
    # renamed when the *lowered* set check failed, so a file with headers
    # like 'Question_ID' passed the check yet kept its original casing and
    # the positional selection below raised KeyError.
    df.columns = [c.strip().lower() for c in df.columns]
    expected = {"question_id", "question", "answer"}
    if not expected.issubset(df.columns):
        raise ValueError(f"basic_faq.csv must have columns: {expected}")
    df = df[["question_id", "question", "answer"]].dropna(subset=["question", "answer"])
    log.info("Loaded base CSV: rows=%d", len(df))
    return df
def save_base_csv(df: pd.DataFrame) -> None:
    """Persist the (question_id, question, answer) dataframe to `basic_faq.csv`.

    Args:
        df: Dataframe with columns ['question_id', 'question', 'answer'].
    """
    # index=False keeps the file round-trippable by load_base_csv().
    df.to_csv(BASE_FAQ_PATH, index=False)
    log.info("Base CSV saved -> %s (rows=%d)", BASE_FAQ_PATH, len(df))
# ---------------------------- Embeddings ----------------------------
def _embed_one(text: str, task_type: str) -> np.ndarray:
    """Embed a single text string using Gemini embeddings with retries.

    Retries up to 3 times with a linearly increasing backoff (0.5 s, 1.0 s).

    Args:
        text: The text to embed.
        task_type: One of 'retrieval_document' or 'retrieval_query'.

    Returns:
        2D float32 array of shape (1, d).

    Raises:
        RuntimeError: If the Gemini client is not available or if all retries fail.
    """
    if not _GEMINI_OK:
        raise RuntimeError("Gemini client not available; cannot embed.")
    attempts = 3
    last_err: Optional[Exception] = None
    for attempt in range(attempts):
        try:
            res = genai.embed_content(  # type: ignore
                model=_fix_model_name(EMBED_MODEL),  # <-- normalized here
                content=text,
                task_type=task_type,
            )
            return np.array(res["embedding"], dtype=np.float32).reshape(1, -1)
        except Exception as e:
            last_err = e
            log.warning("Embed attempt %d failed: %s", attempt + 1, e)
            # FIX: only back off when another attempt remains; the previous
            # code slept 1.5 s after the final failure, right before raising.
            if attempt < attempts - 1:
                time.sleep(0.5 * (attempt + 1))
    raise RuntimeError(f"Embedding failed after retries: {last_err}")
def embed_texts(texts: List[str], task_type: str, show_progress: bool = False) -> np.ndarray:
    """Embed a batch of texts into a single 2D numpy array.

    Delegates each item to `_embed_one`, optionally wrapping the loop in a
    terminal tqdm progress bar.

    Args:
        texts: List of input strings.
        task_type: 'retrieval_document' or 'retrieval_query'.
        show_progress: If True, display a terminal tqdm bar.

    Returns:
        2D float32 array of shape (N, d).
    """
    bar = tqdm(texts, desc=f"Embedding ({task_type})", unit="q", disable=not show_progress)
    rows = []
    for item in bar:
        rows.append(_embed_one(item, task_type))
    return np.vstack(rows).astype(np.float32)
# -------------------------- FAISS Indexing --------------------------
def _build_norm_map(df: pd.DataFrame) -> Dict[str, str]:
    """Build a normalized question → answer map for the exact-match fast path.

    Args:
        df: Dataframe with ['question', 'answer'] columns.

    Returns:
        Mapping from normalized question text to its answer. If two questions
        normalize identically, the later row wins.
    """
    mapping: Dict[str, str] = {}
    questions = df["question"].astype(str)
    answers = df["answer"].astype(str)
    for q, a in zip(questions, answers):
        mapping[normalize_for_match(q)] = a
    return mapping
def build_index_from_df(df: pd.DataFrame) -> None:
    """Create a FAISS index + metadata parquet from a (qid,question,answer) dataframe.

    Steps:
        1) Normalize column names and select required columns.
        2) Build `store.norm_map` for exact lookups.
        3) Embed all questions (shows tqdm in terminal).
        4) L2-normalize vectors and add to an IndexFlatIP (cosine via IP on L2-normalized).
        5) Persist index (binary) and metadata (parquet) to disk.
        6) Update the in-memory `store`.

    Args:
        df: Dataframe with columns ['question_id', 'question', 'answer'].

    Raises:
        ValueError: If `df` is empty.
    """
    if df is None or df.empty:
        raise ValueError("FAQ dataframe is empty; cannot build index.")
    df = df.rename(columns={c: c.strip().lower() for c in df.columns})
    # FIX: reset_index(drop=True) after dropna so pandas labels match row
    # positions. FAISS ids are positional; previously the in-memory meta kept
    # gapped labels (while the parquet round-trip dropped them), so in-session
    # state and reloaded state disagreed.
    df = (
        df[["question_id", "question", "answer"]]
        .dropna(subset=["question", "answer"])
        .reset_index(drop=True)
    )
    # Build normalized map for the exact-lookup path
    store.norm_map = _build_norm_map(df)
    log.info("[Indexing] Starting embeddings for %d questions…", len(df))
    vecs = embed_texts(df["question"].astype(str).tolist(), task_type="retrieval_document", show_progress=True)
    log.info("[Indexing] Embeddings complete. vecs.shape=%s", vecs.shape)
    log.info("[Indexing] Normalizing vectors (L2)…")
    faiss.normalize_L2(vecs)
    log.info("[Indexing] Building FAISS (IndexFlatIP)…")
    index = faiss.IndexFlatIP(vecs.shape[1])
    index.add(vecs)
    log.info("[Indexing] Writing index + metadata to disk…")
    faiss.write_index(index, INDEX_PATH)
    df.to_parquet(INDEX_META, index=False)
    store.index, store.meta = index, df
    log.info("[Indexing] Done. FAISS d=%d, items=%d -> %s, %s", index.d, index.ntotal, INDEX_PATH, INDEX_META)
def load_index_if_exists() -> None:
    """Load an existing FAISS index and metadata parquet from disk, if present.

    Side Effects:
        Populates `store.index`, `store.meta`, and `store.norm_map` when both
        files exist.

    Notes:
        Does nothing if either the index file or metadata file is missing.
    """
    if not (os.path.exists(INDEX_PATH) and os.path.exists(INDEX_META)):
        return
    loaded_index = faiss.read_index(INDEX_PATH)
    loaded_meta = pd.read_parquet(INDEX_META)
    store.index, store.meta = loaded_index, loaded_meta
    store.norm_map = _build_norm_map(loaded_meta)
    log.info("Loaded FAISS: d=%d, items=%d", loaded_index.d, loaded_index.ntotal)
def ensure_index_ready_for_dim(expected_dim: int) -> None:
    """Ensure an in-memory FAISS index exists with the expected embedding dimension.

    Behavior:
        - Loads an existing on-disk index if the store is empty.
        - Builds a fresh index from `basic_faq.csv` when nothing is on disk.
        - Rebuilds from the base CSV when the loaded index's dimension differs
          from `expected_dim` (e.g. after switching embedding models).

    Args:
        expected_dim: The embedding dimension produced by the current model.

    Raises:
        RuntimeError: If no base CSV is available to build from.
    """
    load_index_if_exists()
    if store.index is None:
        fresh = load_base_csv()
        if fresh is None or fresh.empty:
            raise RuntimeError("No index on disk and no basic_faq.csv to rebuild from.")
        log.info("No index present. Building a fresh index…")
        build_index_from_df(fresh)
    elif getattr(store.index, "d", None) != expected_dim:
        log.warning("FAISS dim mismatch: index.d=%s, expected=%s. Rebuilding…", store.index.d, expected_dim)
        fresh = load_base_csv()
        if fresh is None or fresh.empty:
            raise RuntimeError("Dim mismatch and no basic_faq.csv available to rebuild.")
        build_index_from_df(fresh)
# ---------------------------- Retrieval ----------------------------
def semantic_search(q: str, k: int) -> List[Tuple[float, Dict, int]]:
    """Search the FAISS index semantically and return the top-k results.

    Args:
        q: The preprocessed query string (use `preprocess()` for semantic embedding).
        k: Number of nearest neighbors to retrieve (upper bound).

    Returns:
        A list of tuples (score, row_dict, row_idx), where:
            - score: inner-product score on L2-normalized vectors (≈ cosine).
            - row_dict: metadata row with keys ['question_id','question','answer'].
            - row_idx: positional index of the row in `store.meta`.
    """
    query_vec = embed_texts([q], task_type="retrieval_query").astype(np.float32)
    faiss.normalize_L2(query_vec)
    ensure_index_ready_for_dim(query_vec.shape[1])
    scores, neighbor_ids = store.index.search(query_vec, k)  # type: ignore
    hits: List[Tuple[float, Dict, int]] = []
    for sim, pos in zip(scores[0], neighbor_ids[0]):
        # FAISS pads with -1 when fewer than k items exist.
        if pos == -1:
            continue
        record = store.meta.iloc[int(pos)]
        hits.append((float(sim), dict(record), int(pos)))
    log.info("Semantic top sims: %s", ", ".join(f"{s:.3f}" for s, _, _ in hits[:3]))
    return hits
def lexical_score(a: str, b: str) -> float:
    """Compute a token-set fuzzy match ratio scaled to [0, 1].

    Args:
        a: Raw user query (unprocessed; keeps user phrasing).
        b: Candidate question from the corpus.

    Returns:
        Similarity score in [0, 1].
    """
    # rapidfuzz returns a percentage in [0, 100]; rescale to a unit interval.
    percent = fuzz.token_set_ratio(a, b)
    return percent / 100.0
def hybrid_search(q_raw: str, q_proc: str, k: int = TOP_K) -> Optional[Dict]:
    """Perform hybrid retrieval (semantic + lexical) and return the best hit.

    Procedure:
        1) Semantic search on the preprocessed query to obtain top-K candidates.
        2) Lexical fuzzy scoring:
            - If corpus is small (<= 200 items), score against ALL questions.
            - Else, score only semantic candidates.
        3) Score fusion: combined = HYBRID_ALPHA * semantic + (1 - HYBRID_ALPHA) * lexical.
        4) Return the top fused candidate if `combined >= THRESHOLD`; otherwise None.

    Args:
        q_raw: The raw user query (for lexical scoring).
        q_proc: The preprocessed query (for semantic scoring).
        k: Number of semantic neighbors to consider.

    Returns:
        The best metadata row (as dict) if above threshold; else None.
    """
    if store.meta is None or store.meta.empty:
        base = load_base_csv()
        if base is None or base.empty:
            return None
        build_index_from_df(base)
    sem = semantic_search(q_proc, k=max(k, 8))
    # Decide lexical scope (all vs semantic-only) based on corpus size.
    if len(store.meta) <= 200:
        # BUG FIX: the original iterated `store.meta.iterrows()`, which yields
        # index *labels*, but candidates are later resolved via `.iloc` and
        # FAISS ids (*positions*). After dropna() labels can have gaps, so
        # lexical scores were attributed to the wrong rows. Enumerate by
        # position instead.
        questions = store.meta["question"].astype(str).tolist()
        lex_scores: List[Tuple[int, float]] = [
            (pos, lexical_score(q_raw, qtext)) for pos, qtext in enumerate(questions)
        ]
        lex_scores.sort(key=lambda x: x[1], reverse=True)
        best_lex: Dict[int, float] = dict(lex_scores[: max(k, 8)])
    else:
        best_lex = {idx: lexical_score(q_raw, r["question"]) for _, r, idx in sem}
    # Union of semantic and lexical candidates
    candidates: Dict[int, Dict] = {}
    for s, r, idx in sem:
        candidates[idx] = {"semantic": s, "row": r}
    for idx, l in best_lex.items():
        if idx not in candidates:
            r = store.meta.iloc[idx].to_dict()
            candidates[idx] = {"semantic": 0.0, "row": r}
        candidates[idx]["lexical"] = l
    # Fusion: weighted sum of semantic and lexical scores per candidate.
    fused: List[Tuple[float, float, float, int, Dict]] = []
    for idx, obj in candidates.items():
        s = obj.get("semantic", 0.0)
        l = obj.get("lexical", 0.0)
        combined = HYBRID_ALPHA * s + (1.0 - HYBRID_ALPHA) * l
        fused.append((combined, s, l, idx, obj["row"]))
    fused.sort(key=lambda x: x[0], reverse=True)
    top = fused[0] if fused else None
    if top:
        combined, s, l, idx, row = top
        log.info("Hybrid top => combined=%.3f (semantic=%.3f, lexical=%.3f)", combined, s, l)
        if combined >= THRESHOLD:
            return row
    return None
# ----------------------------- Answering -----------------------------
def append_fallback(q: str) -> None:
    """Append an unanswered question to `fallback.csv` with a UTC timestamp.

    The timestamp is timezone-aware ISO-8601 (includes '+00:00').
    `datetime.utcnow()` is deprecated since Python 3.12 and produced a naive
    timestamp; `datetime.now(timezone.utc)` is the supported replacement.

    Args:
        q: The user question to record.
    """
    row = {"timestamp": datetime.now(timezone.utc).isoformat(), "question": q}
    pd.DataFrame([row]).to_csv(
        FALLBACK_PATH,
        mode="a",
        # Write the header only on first creation of the file.
        header=not os.path.exists(FALLBACK_PATH),
        index=False,
    )
    log.info("Fallback appended -> %s", FALLBACK_PATH)
def natural_fallback(user_text: str) -> str:
    """Generate a polite fallback response via Gemini, or return a static message.

    Args:
        user_text: The original user query.

    Returns:
        A brief, friendly fallback text (static when Gemini is unavailable
        or the generation call fails).
    """
    default = "I don’t currently have an answer for that in my database."
    if not _GEMINI_OK:
        log.warning("Gemini fallback unavailable; returning static message.")
        return default
    prompt = (
        "You are a helpful FAQ assistant. "
        "We could not find an exact answer in the database. "
        f"User asked: '{user_text}'. "
        "Politely say you don’t currently have an answer in the database and that it will be reviewed. "
        "Be brief and friendly."
    )
    try:
        # Model id normalized to the 'models/' prefixed form.
        model = genai.GenerativeModel(_fix_model_name(GENERATE_MODEL))
        reply = model.generate_content(prompt)
        text = getattr(reply, "text", None)
        return text.strip() if text else default
    except Exception as e:
        log.warning("Gemini fallback failed: %s", e)
        return default
def try_exact_lookup(raw_text: str) -> Optional[str]:
    """Try exact normalized question lookup before costly retrieval.

    Lazily builds `store.norm_map` from the base CSV when it isn't loaded yet.

    Args:
        raw_text: The raw user query.

    Returns:
        The answer string if an exact normalized match is found; else None.
    """
    if store.norm_map is None:
        base = load_base_csv()
        if base is not None and not base.empty:
            store.norm_map = _build_norm_map(base)
    if not store.norm_map:
        return None
    answer = store.norm_map.get(normalize_for_match(raw_text))
    if answer is None:
        return None
    log.info("Exact normalized match hit.")
    return answer
def answer_query(user_text: str) -> str:
    """Full QA pipeline: exact lookup → hybrid retrieval → fallback.

    Args:
        user_text: The raw user question.

    Returns:
        The best-matched answer, or a polite fallback if none is found.
    """
    # 1) Exact normalized lookup (fast path, no embedding call).
    exact = try_exact_lookup(user_text)
    if exact is not None:
        return str(exact)
    # 2) Hybrid retrieval (semantic + lexical fusion).
    hit = hybrid_search(user_text, preprocess(user_text), k=TOP_K)
    if hit is not None:
        return str(hit["answer"])
    # 3) Record the miss, then respond conversationally.
    append_fallback(user_text)
    return natural_fallback(user_text)
# --------------------------- Train Helpers ---------------------------
def merge_and_rebuild(upload_df: pd.DataFrame) -> None:
    """Merge uploaded data into the base CSV and rebuild the FAISS index.

    The merge is keyed on 'question_id' (latest row wins for duplicates).

    Args:
        upload_df: Dataframe with at least ['question_id','question','answer'].
    """
    # BUG FIX: normalize the upload's headers BEFORE concatenating. The
    # original renamed after pd.concat, so an upload with e.g. 'Question_ID'
    # produced disjoint NaN-padded columns, and the late rename then created
    # duplicate column names.
    upload_df = upload_df.rename(columns={c: c.strip().lower() for c in upload_df.columns})
    upload_df = upload_df[["question_id", "question", "answer"]]
    base = load_base_csv()
    cur = pd.concat([base, upload_df], ignore_index=True) if base is not None else upload_df
    cur = cur.drop_duplicates(subset=["question_id"], keep="last")
    save_base_csv(cur[["question_id", "question", "answer"]])
    build_index_from_df(cur)
def replace_and_rebuild(upload_df: pd.DataFrame) -> None:
    """Replace the entire base CSV with uploaded data and rebuild the FAISS index.

    Args:
        upload_df: Dataframe with ['question_id','question','answer'].
    """
    # Normalize headers, then keep only the canonical columns.
    normalized = upload_df.rename(columns={c: c.strip().lower() for c in upload_df.columns})
    normalized = normalized[["question_id", "question", "answer"]]
    save_base_csv(normalized)
    build_index_from_df(normalized)
# Public API exported to consumers (Gradio app, Streamlit app, tests).
# Internal helpers (_embed_one, _build_norm_map, _fix_model_name, ...) are
# deliberately omitted.
__all__ = [
    "store",
    "load_index_if_exists",
    "build_index_from_df",
    "load_base_csv",
    "merge_and_rebuild",
    "replace_and_rebuild",
    "answer_query",
    "append_fallback",
]