-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembedding.py
More file actions
129 lines (99 loc) · 3.58 KB
/
embedding.py
File metadata and controls
129 lines (99 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# embedding.py
#
# Generate embeddings for text chunks using a local LLM (Ollama).
# These embeddings enable semantic similarity search.
# Embeddings are cached to disk so re-runs skip already-embedded chunks.
import hashlib
import json
import requests
from pathlib import Path
from typing import List, Optional
from dataclasses import dataclass
from config import MODEL_NAME, DATA_DIR, OLLAMA_BASE_URL
from indexer import Chunk
# Ollama API endpoint for embeddings
OLLAMA_EMBED_URL = f"{OLLAMA_BASE_URL}/api/embeddings"
# Cache file path
CACHE_FILE = DATA_DIR / "embedding_cache.json"
def _chunk_key(text: str, model: str) -> str:
    """Return a stable cache key for a (model, text) pair.

    The key is the SHA-256 hex digest of "model:text", so identical text
    embedded under different models yields distinct cache entries.
    """
    hasher = hashlib.sha256()
    hasher.update(f"{model}:{text}".encode())
    return hasher.hexdigest()
def _load_cache() -> dict:
    """Load the embedding cache from disk.

    Returns:
        The cached {key: embedding} mapping, or an empty dict when the
        cache file is missing, unreadable, or contains invalid JSON.
    """
    if CACHE_FILE.exists():
        try:
            return json.loads(CACHE_FILE.read_text(encoding="utf-8"))
        # Narrowed from a bare `except Exception`: only I/O failures and
        # malformed JSON should be treated as a cache miss — anything else
        # (e.g. a programming error) should surface, not be swallowed.
        except (OSError, json.JSONDecodeError):
            return {}
    return {}
def _save_cache(cache: dict) -> None:
    """Write the embedding cache to disk as JSON, creating DATA_DIR if absent."""
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(cache)
    CACHE_FILE.write_text(payload, encoding="utf-8")
@dataclass
class EmbeddedChunk:
    """A chunk with its embedding vector attached."""
    # The original source chunk (imported from indexer).
    chunk: Chunk
    # Embedding vector returned by the Ollama embeddings endpoint.
    embedding: List[float]
def get_embedding(text: str, model: str = MODEL_NAME) -> Optional[List[float]]:
    """
    Get embedding vector for a piece of text using Ollama.

    Args:
        text: The text to embed.
        model: The model name to use for embeddings.

    Returns:
        A list of floats representing the embedding, or None on failure
        (network error, non-2xx status, malformed JSON, or a response
        that lacks a usable "embedding" field).
    """
    try:
        response = requests.post(
            OLLAMA_EMBED_URL,
            json={"model": model, "prompt": text},
            timeout=30,
        )
        response.raise_for_status()
        # BUG FIX: response.json() raises ValueError (json.JSONDecodeError),
        # which is NOT a requests.RequestException — previously a malformed
        # body would crash the caller instead of degrading to None.
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        print(f"[embedding] Error getting embedding: {e}")
        return None
    embedding = data.get("embedding")
    # Guard against a well-formed JSON response with a missing or
    # non-list "embedding" field.
    if not isinstance(embedding, list):
        print("[embedding] Response contained no embedding vector.")
        return None
    return embedding
def embed_chunks(chunks: List[Chunk], model: str = MODEL_NAME) -> List[EmbeddedChunk]:
    """
    Generate embeddings for a list of chunks, consulting the disk cache first.

    Args:
        chunks: List of Chunk objects to embed.
        model: The model name to use for embeddings.

    Returns:
        List of EmbeddedChunk objects (chunks that failed are skipped).
    """
    cache = _load_cache()
    results: List[EmbeddedChunk] = []
    total = len(chunks)
    fresh_count = 0

    for index, chunk in enumerate(chunks):
        cache_key = _chunk_key(chunk.text, model)
        vector = cache.get(cache_key)  # cached values are always lists, never None
        if vector is None:
            # Print progress once per 10 newly-computed embeddings.
            if fresh_count % 10 == 0:
                print(f"[embedding] Embedding chunk {index + 1}/{total}...")
            vector = get_embedding(chunk.text, model)
            if vector is None:
                # Failed chunks are dropped; the summary line reports the total.
                continue
            cache[cache_key] = vector
            fresh_count += 1
        results.append(EmbeddedChunk(chunk=chunk, embedding=vector))

    if fresh_count > 0:
        _save_cache(cache)
        print(f"[embedding] {fresh_count} new embeddings saved to cache.")
    else:
        print(f"[embedding] All {total} chunks loaded from cache.")
    print(f"[embedding] Successfully embedded {len(results)}/{total} chunks.")
    return results
def embed_query(query: str, model: str = MODEL_NAME) -> Optional[List[float]]:
    """
    Generate embedding for a search query.

    A query is embedded exactly like a document chunk, so this simply
    delegates to get_embedding.

    Args:
        query: The search query text.
        model: The model name to use for embeddings.

    Returns:
        Embedding vector for the query, or None on failure.
    """
    return get_embedding(text=query, model=model)