-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_chroma.py
More file actions
176 lines (138 loc) · 4.86 KB
/
Copy pathtest_chroma.py
File metadata and controls
176 lines (138 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#!/usr/bin/env python3
"""
ChromaDB Inspection & RAG Search Test Script
Run this to check if documents are vectorized and test search queries.
"""
import chromadb
from chromadb.config import Settings
import sys
import os
from langchain_ollama import OllamaEmbeddings
# Configuration (the first three values may be overridden via environment variables)
# Directory holding ChromaDB's on-disk persistent store.
PERSIST_DIR = os.environ.get("CHROMA_PERSIST_DIR", "/app/chroma_db")
# Ollama server used to compute query embeddings.
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://ollama:11434")
# Embedding model name; expected to match the model the app used when indexing.
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "nomic-embed-text")
# Collection inspected by main(); fixed, not read from the environment.
COLLECTION_NAME = "documents"
def get_embedding_function():
    """Return an Ollama embedding client for the configured model and server.

    Intended to match the embedding setup of the main app, so that query
    vectors are comparable with the stored document vectors.
    """
    client_kwargs = dict(model=EMBEDDING_MODEL, base_url=OLLAMA_BASE_URL)
    return OllamaEmbeddings(**client_kwargs)
def connect_to_chroma():
    """Open the persistent ChromaDB store at PERSIST_DIR.

    Returns:
        The ``chromadb.PersistentClient`` for PERSIST_DIR.

    Exits the process with status 1, after reporting the error on stderr,
    if the client cannot be created.
    """
    # A PersistentClient pointed at a path that does not exist starts a new,
    # empty store there. A mistyped CHROMA_PERSIST_DIR would then look like
    # "no documents uploaded", so flag that case up front.
    if not os.path.isdir(PERSIST_DIR):
        print(
            f"Warning: {PERSIST_DIR} does not exist; ChromaDB will start an empty store there.",
            file=sys.stderr,
        )
    try:
        client = chromadb.PersistentClient(path=PERSIST_DIR)
    except Exception as e:
        # Top-level boundary of a CLI script: report on stderr and stop.
        print(f"Failed to connect to ChromaDB: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Connected to ChromaDB at {PERSIST_DIR}")
    return client
def list_collections(client):
    """Print the names of all collections stored in ChromaDB.

    Args:
        client: A ChromaDB client.

    Returns:
        The value returned by ``client.list_collections()``, or ``None``
        when it is empty.
    """
    collections = client.list_collections()
    print(f"\n{'='*50}")
    print("COLLECTIONS IN CHROMADB")
    print(f"{'='*50}")
    if not collections:
        print("No collections found. Upload some documents first!")
        return None
    for col in collections:
        # chromadb >= 0.6 returns plain name strings here, while older
        # releases return Collection objects with a `.name`. Accept both.
        print(f" - {getattr(col, 'name', col)}")
    return collections
def inspect_collection(client, collection_name):
    """Print a summary of one collection and a preview of its first chunks.

    Args:
        client: A ChromaDB client.
        collection_name: Name of the collection to inspect.

    Returns:
        The collection object, or ``None`` if it could not be fetched or it
        holds no chunks.
    """
    from itertools import zip_longest

    try:
        collection = client.get_collection(collection_name)
    except Exception as e:
        # Any failure (typically a missing collection) is reported and the
        # collection is treated as absent.
        print(f"Collection '{collection_name}' not found: {e}")
        return None

    count = collection.count()
    print(f"\n{'='*50}")
    print(f"COLLECTION: {collection_name}")
    print(f"{'='*50}")
    print(f"Document chunks: {count}")
    if count == 0:
        print("No documents in this collection. Upload documents via the Streamlit app.")
        return None

    # Peek at sample documents
    print("\n--- Sample Documents (first 3) ---")
    results = collection.peek(limit=3)
    # A field may be None instead of a list. `or []` plus zip_longest keep
    # every returned id even if a parallel list is missing or shorter.
    rows = zip_longest(
        results.get("ids") or [],
        results.get("documents") or [],
        results.get("metadatas") or [],
    )
    for i, (doc_id, doc_content, metadata) in enumerate(rows):
        print(f"\n[{i+1}] ID: {doc_id}")
        print(f" Metadata: {metadata}")
        if not doc_content:
            preview = "N/A"
        elif len(doc_content) > 200:
            # Only mark truncation when text was actually cut off.
            preview = doc_content[:200] + "..."
        else:
            preview = doc_content
        print(f" Content: {preview}")
    return collection
def run_similarity_search(collection, query, n_results=3):
    """Embed *query* with the app's Ollama model and print the closest chunks.

    Args:
        collection: ChromaDB collection to query.
        query: Free-text search string.
        n_results: Maximum number of chunks to request.

    How each hit's distance is shown depends on the collection's distance
    function, read from its ``hnsw:space`` metadata (Chroma's default is
    ``l2``):

    - ``cosine`` and ``ip``: Chroma reports ``1 - score``, so
      ``1 - distance`` recovers the similarity and is printed with it.
    - ``l2``: the distance is a squared Euclidean distance, where
      ``1 - distance`` is not a similarity and can be negative, so only the
      raw distance (lower is closer) is printed.
    """
    print(f"\n{'='*50}")
    print(f"QUERY: \"{query}\"")
    print(f"{'='*50}")

    # Generate embedding using the same model as the app
    embeddings = get_embedding_function()
    query_embedding = embeddings.embed_query(query)
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results,
        include=["documents", "metadatas", "distances"],
    )

    # One query embedding was sent, so each result field holds one inner list.
    documents = results["documents"][0]
    if not documents:
        print("No results found.")
        return

    # NOTE(review): newer chromadb releases can also set the space through the
    # collection `configuration` API rather than metadata. Such collections
    # fall back to "l2" here and simply show raw distances.
    space = (getattr(collection, "metadata", None) or {}).get("hnsw:space", "l2")
    distance_is_one_minus_score = space in ("cosine", "ip")

    print(f"Found {len(documents)} results:\n")
    for i, (doc, metadata, distance) in enumerate(zip(
        documents,
        results["metadatas"][0],
        results["distances"][0],
    )):
        if distance_is_one_minus_score:
            similarity = 1 - distance
            print(f"[{i+1}] Similarity: {similarity:.4f} (distance: {distance:.4f})")
        else:
            print(f"[{i+1}] Distance ({space}, lower is closer): {distance:.4f}")
        print(f" Metadata: {metadata}")
        if not doc:
            preview = "N/A"
        elif len(doc) > 300:
            # Only mark truncation when text was actually cut off.
            preview = doc[:300] + "..."
        else:
            preview = doc
        print(f" Content: {preview}")
        print()
def interactive_search(collection):
    """Read queries from stdin and run each one as a similarity search.

    The loop ends on 'quit', 'exit' or 'q' (any case), on end of input
    (Ctrl-D or closed stdin), or on Ctrl-C at the prompt. If a query fails,
    for example because the embedding server cannot be reached, the error
    is reported and the loop continues.
    """
    print(f"\n{'='*50}")
    print("INTERACTIVE SEARCH MODE")
    print("Type your queries (or 'quit' to exit)")
    print(f"{'='*50}")
    while True:
        try:
            query = input("\nEnter query: ").strip()
        except (EOFError, KeyboardInterrupt):
            # input() raises these on Ctrl-D / closed stdin / Ctrl-C. Leave
            # the loop cleanly instead of dumping a traceback.
            print()
            break
        if query.lower() in ("quit", "exit", "q"):
            break
        if not query:
            continue
        try:
            run_similarity_search(collection, query)
        except Exception as e:
            # Interactive boundary: one failing query should not end the session.
            print(f"Search failed: {e}")
def main():
    """Run the checks end to end.

    Connects to the store, lists collections, inspects COLLECTION_NAME,
    runs a few canned queries against it, then offers interactive search.
    """
    print("ChromaDB RAG Search Tester")
    print("=" * 50)

    client = connect_to_chroma()
    list_collections(client)

    # inspect_collection returns None both when the collection is missing
    # and when it holds no chunks. A non-None result therefore already
    # implies count > 0, so no second count() round trip is needed.
    collection = inspect_collection(client, COLLECTION_NAME)
    if collection is not None:
        test_queries = [
            "What is this document about?",
            "main topic",
            "summary",
        ]
        print(f"\n{'#'*50}")
        print("RUNNING TEST QUERIES")
        print(f"{'#'*50}")
        for query in test_queries:
            run_similarity_search(collection, query)

        # Offer interactive mode
        try:
            response = input("\nEnter interactive search mode? (y/n): ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            # Non-interactive runs (no TTY, closed stdin) or Ctrl-C: treat as "no".
            print()
            response = ""
        if response == 'y':
            interactive_search(collection)

    print("\nDone!")
if __name__ == "__main__":
    # Run the checks only when executed as a script, not when imported.
    main()