Skip to content

Commit 39f6278

Browse files
authored
Merge pull request #28 from GilesStrong/feat/reembed_items
Feat/reembed items
2 parents f287606 + 5a41795 commit 39f6278

5 files changed

Lines changed: 272 additions & 2 deletions

File tree

app/appai/management/__init__.py

Whitespace-only changes.

app/appai/management/commands/__init__.py

Whitespace-only changes.
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
# Copyright 2026 Giles Strong
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
from concurrent.futures import ThreadPoolExecutor, as_completed
17+
from threading import Semaphore
18+
from typing import Any, Callable
19+
20+
from appcards.constants.storage import CARD_COLLECTION_NAME, THEME_COLLECTION_NAME
21+
from appcards.models.card import Card
22+
from appcards.models.deck import DailyDeckTheme
23+
from appcards.modules.card_to_qm_pointstruct import card_to_qm_pointstruct
24+
from appcore.modules.beartype import beartype
25+
from appsearch.services.qdrant.client import QDRANT_CLIENT
26+
from appsearch.services.qdrant.upsert import create_collection_if_not_exists, upsert_documents
27+
from django.core.management.base import BaseCommand
28+
from qdrant_client.http import models as qm
29+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
30+
31+
from appai.constants.storage import MEMORY_COLLECTION_NAME
32+
from appai.models.memory import Memory
33+
from appai.modules.dense_embedding import dense_embed
34+
35+
36+
def _re_embed_items(
37+
items: list[Any],
38+
item_to_point: Callable[[Any], qm.PointStruct],
39+
item_label: Callable[[Any], str],
40+
collection_name: str,
41+
batchsize: int,
42+
max_workers: int,
43+
) -> None:
44+
"""Re-embed a list of items into a Qdrant collection.
45+
46+
Deletes and recreates the collection, then embeds all items in batches
47+
using a thread pool. Failed embeddings are logged and skipped.
48+
49+
Args:
50+
items: Items to embed.
51+
item_to_point: Converts a single item to a Qdrant PointStruct.
52+
item_label: Returns a display label for an item (used in log output).
53+
collection_name: Target Qdrant collection name.
54+
batchsize: Number of items per upsert batch.
55+
max_workers: Maximum concurrent embedding threads.
56+
"""
57+
58+
@retry(
59+
stop=stop_after_attempt(3),
60+
wait=wait_exponential(multiplier=1, min=4, max=10),
61+
retry=retry_if_exception_type(Exception),
62+
)
63+
def embed_item(item: Any, semaphore: Semaphore) -> qm.PointStruct:
64+
"""Embed a single item, acquiring the semaphore before calling the API.
65+
66+
Args:
67+
item: The item to embed.
68+
semaphore: Semaphore limiting concurrent API calls.
69+
70+
Returns:
71+
A Qdrant PointStruct for the item.
72+
"""
73+
with semaphore:
74+
point = item_to_point(item)
75+
print(f"✓ Generated embedding for: {item_label(item)}")
76+
return point
77+
78+
def _embed_batch(batch: list[Any]) -> None:
79+
"""Embed a batch of items concurrently and upsert results to Qdrant.
80+
81+
Args:
82+
batch: Items to embed in this batch.
83+
"""
84+
semaphore = Semaphore(max_workers)
85+
embedding_results: list[qm.PointStruct] = []
86+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
87+
futures = [executor.submit(embed_item, item, semaphore) for item in batch]
88+
for future in as_completed(futures):
89+
try:
90+
embedding_results.append(future.result())
91+
except Exception as e:
92+
print(f"✗ Failed to generate embedding: {e}")
93+
if embedding_results:
94+
upsert_documents(collection_name=collection_name, points=embedding_results)
95+
96+
def get_un_embedded(all_items: list[Any]) -> list[Any]:
97+
"""Return items that do not yet have a vector in the collection.
98+
99+
Args:
100+
all_items: Full list of items to check.
101+
102+
Returns:
103+
Items whose IDs are absent from the Qdrant collection.
104+
"""
105+
existing_points = QDRANT_CLIENT.retrieve(
106+
collection_name=collection_name,
107+
ids=[str(item.id) for item in all_items],
108+
with_payload=True,
109+
with_vectors=False,
110+
)
111+
existing_ids = {str(p.id) for p in existing_points}
112+
return [item for item in all_items if str(item.id) not in existing_ids]
113+
114+
batchsize = max(1, batchsize)
115+
if collection_name in [c.name for c in QDRANT_CLIENT.get_collections().collections]:
116+
QDRANT_CLIENT.delete_collection(collection_name=collection_name)
117+
assert collection_name not in [c.name for c in QDRANT_CLIENT.get_collections().collections]
118+
create_collection_if_not_exists(collection_name)
119+
120+
n_items = len(items)
121+
print(f"Total items to process: {n_items}")
122+
for idx in range(0, n_items, batchsize):
123+
batch = items[idx : idx + batchsize]
124+
print(f"Processing batch {idx // batchsize + 1} of {((n_items - 1) // batchsize) + 1}.")
125+
_embed_batch(batch)
126+
127+
n_remaining = len(get_un_embedded(items))
128+
print(f"Finished generating embeddings. {n_remaining} items remaining without embeddings.")
129+
130+
131+
@beartype
132+
def re_embed_cards(batchsize: int = 64, max_workers: int = 5) -> None:
133+
"""Re-embed all cards into the Qdrant card collection.
134+
135+
Args:
136+
batchsize: Number of cards per upsert batch.
137+
max_workers: Maximum concurrent embedding threads.
138+
"""
139+
cards = list(Card.objects.prefetch_related("printings").all())
140+
_re_embed_items(
141+
items=cards,
142+
item_to_point=card_to_qm_pointstruct,
143+
item_label=lambda c: c.name,
144+
collection_name=CARD_COLLECTION_NAME,
145+
batchsize=batchsize,
146+
max_workers=max_workers,
147+
)
148+
149+
150+
@beartype
151+
def re_embed_memories(batchsize: int = 64, max_workers: int = 5) -> None:
152+
"""Re-embed all memories into the Qdrant memory collection.
153+
154+
Args:
155+
batchsize: Number of memories per upsert batch.
156+
max_workers: Maximum concurrent embedding threads.
157+
"""
158+
159+
def memory_to_point(memory: Memory) -> qm.PointStruct:
160+
"""Convert a Memory instance to a Qdrant PointStruct.
161+
162+
Args:
163+
memory: The memory to embed.
164+
165+
Returns:
166+
A Qdrant PointStruct with a dense embedding and associated payload.
167+
"""
168+
embedding = dense_embed(memory.text)
169+
str_related_card_uuids = sorted(str(card.id) for card in memory.related_cards.all())
170+
return qm.PointStruct(
171+
id=str(memory.id),
172+
vector={'dense': embedding},
173+
payload={
174+
"name": memory.name,
175+
"text": memory.text,
176+
"related_card_uuids": str_related_card_uuids,
177+
"created_at": memory.created_at.isoformat(),
178+
},
179+
)
180+
181+
memories = list(Memory.objects.prefetch_related("related_cards").all())
182+
_re_embed_items(
183+
items=memories,
184+
item_to_point=memory_to_point,
185+
item_label=lambda m: m.name,
186+
collection_name=MEMORY_COLLECTION_NAME,
187+
batchsize=batchsize,
188+
max_workers=max_workers,
189+
)
190+
191+
192+
@beartype
193+
def re_embed_themes(batchsize: int = 64, max_workers: int = 5) -> None:
194+
"""Re-embed all daily deck themes into the Qdrant theme collection.
195+
196+
Args:
197+
batchsize: Number of themes per upsert batch.
198+
max_workers: Maximum concurrent embedding threads.
199+
"""
200+
201+
def theme_to_point(theme: DailyDeckTheme) -> qm.PointStruct:
202+
"""Convert a DailyDeckTheme to a Qdrant PointStruct.
203+
204+
Args:
205+
theme: The theme to embed.
206+
207+
Returns:
208+
A Qdrant PointStruct with a dense embedding and associated payload.
209+
"""
210+
embedding = dense_embed(theme.theme)
211+
return qm.PointStruct(
212+
id=str(theme.id),
213+
vector={'dense': embedding},
214+
payload={
215+
"description": theme.theme,
216+
"date": theme.date.isoformat(),
217+
},
218+
)
219+
220+
themes = list(DailyDeckTheme.objects.all())
221+
_re_embed_items(
222+
items=themes,
223+
item_to_point=theme_to_point,
224+
item_label=lambda t: t.theme,
225+
collection_name=THEME_COLLECTION_NAME,
226+
batchsize=batchsize,
227+
max_workers=max_workers,
228+
)
229+
230+
231+
class Command(BaseCommand):
232+
help = 'Run re-embedding of qdrant items.'
233+
234+
def add_arguments(self, parser: argparse.ArgumentParser) -> None:
235+
parser.add_argument(
236+
'--item-type',
237+
type=str,
238+
choices=['cards', 'memories', 'themes'],
239+
help='Type of items to re-embed',
240+
required=True,
241+
)
242+
parser.add_argument('--batchsize', type=int, default=64, help='Upsert batch size (default: 64)')
243+
parser.add_argument('--max-workers', type=int, default=50, help='Maximum number of concurrent workers')
244+
245+
def handle(self, *args: Any, **options: Any) -> None:
246+
if options['item_type'] == 'cards':
247+
re_embed_cards(
248+
batchsize=options.get('batchsize', 64),
249+
max_workers=options.get('max_workers', 50),
250+
)
251+
elif options['item_type'] == 'memories':
252+
re_embed_memories(
253+
batchsize=options.get('batchsize', 64),
254+
max_workers=options.get('max_workers', 50),
255+
)
256+
elif options['item_type'] == 'themes':
257+
re_embed_themes(
258+
batchsize=options.get('batchsize', 64),
259+
max_workers=options.get('max_workers', 50),
260+
)
261+
else:
262+
print(f"Unknown item type: {options['item_type']}")

app/appai/modules/dense_embedding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from functools import lru_cache
1616
from typing import Any, cast
1717

18+
import numpy as np
1819
import requests
1920
from app.app_settings import APP_SETTINGS
2021
from app.utils import in_celery_task
@@ -48,7 +49,12 @@ def _dense_embed(text: str) -> list[float]:
4849
timeout=60,
4950
)
5051
response.raise_for_status()
51-
return response.json()["embedding"]
52+
vector = np.array(response.json()["embedding"])
53+
length = np.linalg.norm(vector)
54+
if length == 0:
55+
raise ValueError("Received zero-length embedding vector")
56+
vector /= np.linalg.norm(vector) # Normalize the embedding to unit length
57+
return vector.tolist()
5258

5359

5460
@beartype

app/appai/tests/test_dense_embedding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
_MODULE = "appai.modules.dense_embedding"
2424

2525
_SAMPLE_TEXT = "A red aggro deck with burn spells"
26-
_SAMPLE_EMBEDDING = [0.1, 0.2, 0.3, 0.4]
26+
# Unit vector (norm=1.0) so normalization in _dense_embed is a no-op, keeping
27+
# assertions straightforward regardless of execution path.
28+
_SAMPLE_EMBEDDING = [0.5, 0.5, 0.5, 0.5]
2729

2830

2931
class DenseEmbedInternalTests(TestCase):

0 commit comments

Comments
 (0)