Skip to content

Commit bbb0400

Browse files
committed
optimization
1 parent 1fa878c commit bbb0400

1 file changed

Lines changed: 101 additions & 45 deletions

File tree

devtron-docs-rag-server/vector_store.py

Lines changed: 101 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,26 @@ def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5"):
3838
logger.warning(f"Cache directory not found: {cache_dir}")
3939

4040
try:
41+
# Load model with optimizations for CPU inference
42+
import torch
43+
44+
# Disable gradient computation (we're only doing inference)
45+
torch.set_grad_enabled(False)
46+
4147
# Load model - it will use SENTENCE_TRANSFORMERS_HOME env var automatically
4248
self.model = SentenceTransformer(model_name)
49+
50+
# Set model to evaluation mode for faster inference
51+
self.model.eval()
52+
53+
# Enable CPU optimizations if available
54+
try:
55+
# Cap PyTorch's thread count (this limits threads only; it does not enable MKL)
56+
torch.set_num_threads(2) # Limit threads to avoid oversubscription
57+
logger.info(f"Set PyTorch threads to 2 for optimal CPU performance")
58+
except Exception:
59+
pass
60+
4361
self.dimension = self.model.get_sentence_embedding_dimension()
4462
logger.info(f"✓ Embedding model loaded (dimension: {self.dimension})")
4563
except Exception as e:
@@ -61,14 +79,16 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
6179
# Add instruction prefix for better retrieval (recommended by BGE)
6280
texts_with_prefix = [f"passage: {text}" for text in texts]
6381

64-
# Use very small batch size for CPU to minimize blocking time
65-
# batch_size=2 processes 2 texts at a time, reducing memory and blocking
82+
# Optimized settings for CPU inference
83+
# batch_size=16 is optimal for CPU (balances speed vs memory)
84+
# convert_to_numpy=True returns a NumPy array, which is converted to lists below
6685
embeddings = self.model.encode(
6786
texts_with_prefix,
6887
show_progress_bar=False,
69-
batch_size=2,
88+
batch_size=16,
7089
convert_to_numpy=True,
71-
normalize_embeddings=False
90+
normalize_embeddings=False,
91+
device='cpu' # Explicitly use CPU
7292
)
7393
return embeddings.tolist()
7494

@@ -131,6 +151,11 @@ def __init__(
131151
cur.execute("SELECT version();")
132152
version = cur.fetchone()[0]
133153
logger.info(f"✓ Database connected successfully")
154+
155+
# Log connection details for debugging
156+
cur.execute("SELECT current_database(), current_schema();")
157+
db, schema = cur.fetchone()
158+
logger.info(f"Connected to database: {db}, schema: {schema}")
134159
finally:
135160
self.pool.putconn(conn)
136161

@@ -225,14 +250,19 @@ async def index_documents(self, documents: List[Dict[str, Any]]) -> None:
225250

226251
logger.info(f"Starting indexing: {len(documents)} documents")
227252

228-
# Process documents one at a time to minimize memory and allow health checks
229-
batch_size = 1
230-
total_batches = len(documents)
253+
# Process documents in small batches with optimized embedding
254+
# With faster embeddings, we can process 2 documents at once (see batch_size)
255+
batch_size = 2
256+
total_batches = (len(documents) + batch_size - 1) // batch_size
231257

232258
for i in range(0, len(documents), batch_size):
233259
batch = documents[i:i + batch_size]
234-
batch_num = i + 1
235-
logger.info(f"Processing document {batch_num}/{total_batches}: {batch[0].get('title', 'Unknown')}")
260+
batch_num = (i // batch_size) + 1
261+
262+
# Log document titles being processed
263+
titles = [doc.get('title', 'Unknown') for doc in batch]
264+
logger.info(f"Processing batch {batch_num}/{total_batches}: {', '.join(titles[:2])}")
265+
236266
await self._index_batch(batch)
237267

238268
# Yield control to event loop to allow health checks to respond
@@ -269,9 +299,9 @@ async def _index_batch(self, documents: List[Dict[str, Any]]) -> None:
269299

270300
logger.info(f"Processing {len(rows)} chunks from {len(documents)} document(s)")
271301

272-
# Process chunks in very small sub-batches to avoid blocking health checks
273-
# Reduced to 5 chunks at a time (~10-15 seconds per sub-batch)
274-
chunk_batch_size = 5
302+
# Process chunks in optimized sub-batches
303+
# With optimizations: 10 chunks takes ~5-8 seconds (much faster!)
304+
chunk_batch_size = 10
275305
total_chunks = len(rows)
276306

277307
conn = self.pool.getconn()
@@ -293,46 +323,72 @@ async def _index_batch(self, documents: List[Dict[str, Any]]) -> None:
293323
)
294324

295325
# Insert into database
296-
with conn.cursor() as cur:
297-
# Prepare data for batch insert
298-
values = [
299-
(
300-
chunk_batch[i]['id'],
301-
chunk_batch[i]['title'],
302-
chunk_batch[i]['source'],
303-
chunk_batch[i]['header'],
304-
chunk_batch[i]['content'],
305-
chunk_batch[i]['chunk_index'],
306-
embeddings[i]
326+
try:
327+
with conn.cursor() as cur:
328+
# Prepare data for batch insert
329+
values = [
330+
(
331+
chunk_batch[i]['id'],
332+
chunk_batch[i]['title'],
333+
chunk_batch[i]['source'],
334+
chunk_batch[i]['header'],
335+
chunk_batch[i]['content'],
336+
chunk_batch[i]['chunk_index'],
337+
embeddings[i]
338+
)
339+
for i in range(len(chunk_batch))
340+
]
341+
342+
# Batch insert
343+
execute_values(
344+
cur,
345+
"""
346+
INSERT INTO documents
347+
(id, title, source, header, content, chunk_index, embedding)
348+
VALUES %s
349+
ON CONFLICT (id) DO UPDATE SET
350+
title = EXCLUDED.title,
351+
source = EXCLUDED.source,
352+
header = EXCLUDED.header,
353+
content = EXCLUDED.content,
354+
chunk_index = EXCLUDED.chunk_index,
355+
embedding = EXCLUDED.embedding,
356+
updated_at = CURRENT_TIMESTAMP
357+
""",
358+
values
307359
)
308-
for i in range(len(chunk_batch))
309-
]
310-
311-
# Batch insert
312-
execute_values(
313-
cur,
314-
"""
315-
INSERT INTO documents
316-
(id, title, source, header, content, chunk_index, embedding)
317-
VALUES %s
318-
ON CONFLICT (id) DO UPDATE SET
319-
title = EXCLUDED.title,
320-
source = EXCLUDED.source,
321-
header = EXCLUDED.header,
322-
content = EXCLUDED.content,
323-
chunk_index = EXCLUDED.chunk_index,
324-
embedding = EXCLUDED.embedding,
325-
updated_at = CURRENT_TIMESTAMP
326-
""",
327-
values
328-
)
360+
361+
# Commit outside cursor context to ensure it's not rolled back
329362
conn.commit()
330-
logger.info(f" ✓ Stored {len(chunk_batch)} chunks")
363+
364+
# Verify insertion immediately after commit
365+
with conn.cursor() as cur:
366+
# Check if the chunks were actually inserted
367+
chunk_ids = [chunk_batch[i]['id'] for i in range(len(chunk_batch))]
368+
cur.execute(
369+
"SELECT COUNT(*) FROM documents WHERE id = ANY(%s);",
370+
(chunk_ids,)
371+
)
372+
verified_count = cur.fetchone()[0]
373+
374+
if verified_count != len(chunk_batch):
375+
logger.error(f" ✗ Verification failed: Expected {len(chunk_batch)}, found {verified_count}")
376+
raise Exception(f"Data insertion verification failed")
377+
378+
logger.info(f" ✓ Stored and verified {len(chunk_batch)} chunks")
379+
380+
except Exception as e:
381+
logger.error(f" ✗ Failed to store chunks: {str(e)}", exc_info=True)
382+
conn.rollback()
383+
raise
331384

332385
# Yield control to event loop to allow health checks
333386
await asyncio.sleep(0.1)
334387

335388
logger.info(f"✓ Document complete: {total_chunks} chunks indexed")
389+
except Exception as e:
390+
logger.error(f"Error indexing batch: {str(e)}", exc_info=True)
391+
raise
336392
finally:
337393
self.pool.putconn(conn)
338394

0 commit comments

Comments
 (0)