-
-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathembedding_services.py
More file actions
201 lines (176 loc) · 6.65 KB
/
embedding_services.py
File metadata and controls
201 lines (176 loc) · 6.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import time
import logging
from statistics import median
# Use Q objects to express OR conditions in Django queries
from django.db.models import Q
from pgvector.django import L2Distance
from .sentencetTransformer_model import TransformerModel
from ..models.model_embeddings import Embeddings
from ..models.model_search_usage import SemanticSearchUsage
logger = logging.getLogger(__name__)
def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10):
    """
    Construct a lazy QuerySet of embeddings nearest to ``embedding_vector``.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    embedding_vector : array-like
        Pre-computed embedding vector to compare against
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    QuerySet
        Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results
    """
    # Visibility rules: authenticated users search their own uploads plus
    # anything uploaded by a superuser; anonymous users only see superuser uploads.
    if user.is_authenticated:
        visibility = Q(upload_file__uploaded_by=user) | Q(
            upload_file__uploaded_by__is_superuser=True
        )
    else:
        visibility = Q(upload_file__uploaded_by__is_superuser=True)

    # QuerySets stay lazy until iterated; nothing here touches the database.
    nearest = (
        Embeddings.objects.filter(visibility)
        .annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector))
        .order_by("distance")
    )

    # GUID filter wins over document-name filter when both are supplied.
    if guid:
        nearest = nearest.filter(upload_file__guid=guid)
    elif document_name:
        nearest = nearest.filter(name=document_name)

    # Slicing translates to SQL LIMIT.
    return nearest[:num_results]
def evaluate_query(queryset):
    """
    Evaluate a QuerySet and return a list of result dicts.

    Parameters
    ----------
    queryset : iterable
        Iterable of Embeddings objects (or any objects with the expected attributes)

    Returns
    -------
    list[dict]
        List of dicts with keys: name, text, page_number, chunk_number, distance, file_id
    """

    def row_to_dict(row):
        # Flatten one Embeddings row into the public result shape; the source
        # attribute is page_num but the API key is page_number.
        upload = row.upload_file
        return {
            "name": row.name,
            "text": row.text,
            "page_number": row.page_num,
            "chunk_number": row.chunk_number,
            "distance": row.distance,
            "file_id": upload.guid if upload else None,
        }

    # Iterating forces evaluation of a lazy QuerySet (this is where the DB is hit).
    # TODO: Research improving the query evaluation performance
    return [row_to_dict(row) for row in queryset]
def log_usage(
    results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
):
    """
    Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.

    Parameters
    ----------
    results : list[dict]
        The search results, each containing a "distance" key
    message_data : str
        The original search query text
    user : User
        The user who performed the search
    guid : str or None
        Document GUID filter used in the search
    document_name : str or None
        Document name filter used in the search
    num_results : int
        Number of results requested
    encoding_time : float
        Time in seconds to encode the query
    db_query_time : float
        Time in seconds for the database query
    """
    try:
        # Compute the distance statistics up front so a single create() call
        # serves both the hit and no-hit cases (previously the call was
        # duplicated wholesale in each branch).
        if results:
            distances = [r["distance"] for r in results]
            max_distance = max(distances)
            median_distance = median(distances)
            min_distance = min(distances)
        else:
            logger.warning("Semantic search returned no results")
            max_distance = median_distance = min_distance = None

        SemanticSearchUsage.objects.create(
            query_text=message_data,
            # Anonymous searches are recorded with a null user.
            user=user if (user and user.is_authenticated) else None,
            document_guid=guid,
            document_name=document_name,
            num_results_requested=num_results,
            encoding_time=encoding_time,
            db_query_time=db_query_time,
            num_results_returned=len(results) if results else 0,
            max_distance=max_distance,
            median_distance=median_distance,
            min_distance=min_distance,
        )
    except Exception:
        # Deliberate broad catch: usage logging must never break the search itself.
        logger.exception("Failed to create semantic search usage database record")
def get_closest_embeddings(
    user, message_data, document_name=None, guid=None, num_results=10
):
    """
    Find the closest embeddings to a given message for a specific user.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    message_data : str
        The input message to find similar embeddings for
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    list[dict]
        List of dictionaries containing embedding results with keys:
        - name: document name
        - text: embedded text content
        - page_number: page number in source document
        - chunk_number: chunk number within the document
        - distance: L2 distance from query embedding
        - file_id: GUID of the source file

    Notes
    -----
    Creates a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.
    """
    # Use perf_counter (monotonic) rather than time.time (wall clock) for
    # interval timing: wall-clock time can jump under NTP/clock adjustments,
    # producing negative or wildly wrong durations.
    encoding_start = time.perf_counter()
    model = TransformerModel.get_instance().model
    embedding_vector = model.encode(message_data)
    encoding_time = time.perf_counter() - encoding_start

    db_query_start = time.perf_counter()
    queryset = build_query(user, embedding_vector, document_name, guid, num_results)
    # evaluate_query iterates the QuerySet, so the DB round-trip is timed here.
    results = evaluate_query(queryset)
    db_query_time = time.perf_counter() - db_query_start

    log_usage(
        results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
    )
    return results