Skip to content

Commit 59b40f0

Browse files
committed
REFACTOR Pull apart get_closest_embeddings to make testing easier
1 parent c64cc48 commit 59b40f0

1 file changed

Lines changed: 112 additions & 49 deletions

File tree

server/api/services/embedding_services.py

Lines changed: 112 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,17 @@
1111

1212
logger = logging.getLogger(__name__)
1313

14-
def get_closest_embeddings(
15-
user, message_data, document_name=None, guid=None, num_results=10
16-
):
14+
15+
def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10):
1716
"""
18-
Find the closest embeddings to a given message for a specific user.
17+
Build an unevaluated QuerySet for the closest embeddings.
1918
2019
Parameters
2120
----------
2221
user : User
2322
The user whose uploaded documents will be searched
24-
message_data : str
25-
The input message to find similar embeddings for
23+
embedding_vector : array-like
24+
Pre-computed embedding vector to compare against
2625
document_name : str, optional
2726
Filter results to a specific document name
2827
guid : str, optional
@@ -32,59 +31,52 @@ def get_closest_embeddings(
3231
3332
Returns
3433
-------
35-
list[dict]
36-
List of dictionaries containing embedding results with keys:
37-
- name: document name
38-
- text: embedded text content
39-
- page_number: page number in source document
40-
- chunk_number: chunk number within the document
41-
- distance: L2 distance from query embedding
42-
- file_id: GUID of the source file
34+
QuerySet
35+
Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results
4336
"""
44-
45-
encoding_start = time.time()
46-
transformerModel = TransformerModel.get_instance().model
47-
embedding_message = transformerModel.encode(message_data)
48-
encoding_time = time.time() - encoding_start
49-
50-
db_query_start = time.time()
51-
5237
# Django QuerySets are lazily evaluated
5338
if user.is_authenticated:
5439
# User sees their own files + files uploaded by superusers
55-
closest_embeddings_query = (
56-
Embeddings.objects.filter(
57-
Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
58-
)
59-
.annotate(
60-
distance=L2Distance("embedding_sentence_transformers", embedding_message)
61-
)
62-
.order_by("distance")
40+
queryset = Embeddings.objects.filter(
41+
Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
6342
)
6443
else:
6544
# Unauthenticated users only see superuser-uploaded files
66-
closest_embeddings_query = (
67-
Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)
68-
.annotate(
69-
distance=L2Distance("embedding_sentence_transformers", embedding_message)
70-
)
71-
.order_by("distance")
72-
)
45+
queryset = Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)
46+
47+
queryset = (
48+
queryset
49+
.annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector))
50+
.order_by("distance")
51+
)
7352

7453
# Filtering to a document GUID takes precedence over a document name
7554
if guid:
76-
closest_embeddings_query = closest_embeddings_query.filter(
77-
upload_file__guid=guid
78-
)
55+
queryset = queryset.filter(upload_file__guid=guid)
7956
elif document_name:
80-
closest_embeddings_query = closest_embeddings_query.filter(name=document_name)
57+
queryset = queryset.filter(name=document_name)
8158

8259
# Slicing is equivalent to SQL's LIMIT clause
83-
closest_embeddings_query = closest_embeddings_query[:num_results]
60+
return queryset[:num_results]
61+
62+
63+
def format_results(queryset):
64+
"""
65+
Evaluate a QuerySet and return a list of result dicts.
66+
67+
Parameters
68+
----------
69+
queryset : iterable
70+
Iterable of Embeddings objects (or any objects with the expected attributes)
8471
72+
Returns
73+
-------
74+
list[dict]
75+
List of dicts with keys: name, text, page_number, chunk_number, distance, file_id
76+
"""
8577
# Iterating evaluates the QuerySet and hits the database
8678
# TODO: Research improving the query evaluation performance
87-
results = [
79+
return [
8880
{
8981
"name": obj.name,
9082
"text": obj.text,
@@ -93,13 +85,36 @@ def get_closest_embeddings(
9385
"distance": obj.distance,
9486
"file_id": obj.upload_file.guid if obj.upload_file else None,
9587
}
96-
for obj in closest_embeddings_query
88+
for obj in queryset
9789
]
9890

99-
db_query_time = time.time() - db_query_start
10091

92+
def log_search_usage(
93+
results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
94+
):
95+
"""
96+
Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.
97+
98+
Parameters
99+
----------
100+
results : list[dict]
101+
The search results, each containing a "distance" key
102+
message_data : str
103+
The original search query text
104+
user : User
105+
The user who performed the search
106+
guid : str or None
107+
Document GUID filter used in the search
108+
document_name : str or None
109+
Document name filter used in the search
110+
num_results : int
111+
Number of results requested
112+
encoding_time : float
113+
Time in seconds to encode the query
114+
db_query_time : float
115+
Time in seconds for the database query
116+
"""
101117
try:
102-
# Handle user having no uploaded docs or doc filtering returning no matches
103118
if results:
104119
distances = [r["distance"] for r in results]
105120
SemanticSearchUsage.objects.create(
@@ -113,11 +128,10 @@ def get_closest_embeddings(
113128
num_results_returned=len(results),
114129
max_distance=max(distances),
115130
median_distance=median(distances),
116-
min_distance=min(distances)
131+
min_distance=min(distances),
117132
)
118133
else:
119134
logger.warning("Semantic search returned no results")
120-
121135
SemanticSearchUsage.objects.create(
122136
query_text=message_data,
123137
user=user if (user and user.is_authenticated) else None,
@@ -129,9 +143,58 @@ def get_closest_embeddings(
129143
num_results_returned=0,
130144
max_distance=None,
131145
median_distance=None,
132-
min_distance=None
146+
min_distance=None,
133147
)
134148
except Exception as e:
135149
logger.error(f"Failed to create semantic search usage database record: {e}")
136150

151+
152+
def get_closest_embeddings(
153+
user, message_data, document_name=None, guid=None, num_results=10
154+
):
155+
"""
156+
Find the closest embeddings to a given message for a specific user.
157+
158+
Parameters
159+
----------
160+
user : User
161+
The user whose uploaded documents will be searched
162+
message_data : str
163+
The input message to find similar embeddings for
164+
document_name : str, optional
165+
Filter results to a specific document name
166+
guid : str, optional
167+
Filter results to a specific document GUID (takes precedence over document_name)
168+
num_results : int, default 10
169+
Maximum number of results to return
170+
171+
Returns
172+
-------
173+
list[dict]
174+
List of dictionaries containing embedding results with keys:
175+
- name: document name
176+
- text: embedded text content
177+
- page_number: page number in source document
178+
- chunk_number: chunk number within the document
179+
- distance: L2 distance from query embedding
180+
- file_id: GUID of the source file
181+
182+
Notes
183+
-----
184+
Creates a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.
185+
"""
186+
encoding_start = time.time()
187+
model = TransformerModel.get_instance().model
188+
embedding_vector = model.encode(message_data)
189+
encoding_time = time.time() - encoding_start
190+
191+
db_query_start = time.time()
192+
queryset = build_query(user, embedding_vector, document_name, guid, num_results)
193+
results = format_results(queryset)
194+
db_query_time = time.time() - db_query_start
195+
196+
log_search_usage(
197+
results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
198+
)
199+
137200
return results

0 commit comments

Comments
 (0)