@@ -4,32 +4,55 @@ import (
44 "context"
55 "crypto/sha256"
66 "encoding/hex"
7+ "strconv"
78 "sync"
89)
910
// EmbeddingMemo is an in-memory cache of embedding vectors, used to avoid
// re-embedding content that has not changed.
//
// Entries are keyed by namespace + mode + sha256(content) rather than by content
// alone. The namespace carries the embedding model's identity (the provider's
// Name()), so vectors produced by one model are never returned for another. The
// mode component keeps document-mode and query-mode vectors apart, which matters
// for asymmetric retrieval models (e.g. Cohere v3 search_document vs
// search_query).
type EmbeddingMemo struct {
	// Configuration assigned by the constructor.
	namespace string // model identity, prefixed into every cache key
	max       int    // entry limit for the cache

	// mu is held while cache and order are accessed.
	mu    sync.RWMutex
	cache map[string][]float32 // namespace|mode|sha256(content) -> embedding
	order []string             // cache keys in eviction order, oldest first
}
1725
18- // NewEmbeddingMemo creates a memo cache with the given max entry count.
26+ // NewEmbeddingMemo creates a memo cache with the given max entry count. The
27+ // namespace is empty; prefer NewEmbeddingMemoNS so the cache is keyed by model
28+ // identity. Pass "" only for tests or single-model callers that never switch.
1929func NewEmbeddingMemo (maxEntries int ) * EmbeddingMemo {
30+ return NewEmbeddingMemoNS ("" , maxEntries )
31+ }
32+
33+ // NewEmbeddingMemoNS creates a memo cache namespaced by model identity, so that
34+ // changing the embedding model invalidates the cache instead of serving stale
35+ // vectors from the previous model.
36+ func NewEmbeddingMemoNS (namespace string , maxEntries int ) * EmbeddingMemo {
2037 if maxEntries <= 0 {
2138 maxEntries = 1024
2239 }
2340 return & EmbeddingMemo {
24- cache : make (map [string ][]float32 , maxEntries ),
25- order : make ([]string , 0 , maxEntries ),
26- max : maxEntries ,
41+ namespace : namespace ,
42+ cache : make (map [string ][]float32 , maxEntries ),
43+ order : make ([]string , 0 , maxEntries ),
44+ max : maxEntries ,
2745 }
2846}
2947
30- // Get returns a cached embedding for the content, if present.
48+ // Get returns a cached embedding for the content in ModeDocument , if present.
3149func (m * EmbeddingMemo ) Get (content string ) ([]float32 , bool ) {
32- key := contentHash (content )
50+ return m .GetMode (content , ModeDocument )
51+ }
52+
53+ // GetMode returns a cached embedding for the content under the given mode.
54+ func (m * EmbeddingMemo ) GetMode (content string , mode EmbedMode ) ([]float32 , bool ) {
55+ key := m .key (content , mode )
3356 m .mu .Lock ()
3457 vec , ok := m .cache [key ]
3558 if ok {
@@ -39,9 +62,15 @@ func (m *EmbeddingMemo) Get(content string) ([]float32, bool) {
3962 return vec , ok
4063}
4164
42- // Put stores an embedding for the given content, evicting the oldest entry if at capacity .
65+ // Put stores an embedding for the given content in ModeDocument .
4366func (m * EmbeddingMemo ) Put (content string , embedding []float32 ) {
44- key := contentHash (content )
67+ m .PutMode (content , ModeDocument , embedding )
68+ }
69+
70+ // PutMode stores an embedding for the given content under the given mode,
71+ // evicting the oldest entry if at capacity.
72+ func (m * EmbeddingMemo ) PutMode (content string , mode EmbedMode , embedding []float32 ) {
73+ key := m .key (content , mode )
4574 m .mu .Lock ()
4675 defer m .mu .Unlock ()
4776 if _ , exists := m .cache [key ]; exists {
@@ -75,6 +104,12 @@ func (m *EmbeddingMemo) promote(key string) {
75104 }
76105}
77106
107+ // key builds the cache key from model namespace, embedding mode, and content
108+ // hash, so that a model change or a mode change never collides with stale state.
109+ func (m * EmbeddingMemo ) key (content string , mode EmbedMode ) string {
110+ return m .namespace + "|" + strconv .Itoa (int (mode )) + "|" + contentHash (content )
111+ }
112+
78113func contentHash (s string ) string {
79114 h := sha256 .Sum256 ([]byte (s ))
80115 return hex .EncodeToString (h [:])
@@ -86,11 +121,13 @@ type MemoizedProvider struct {
86121 memo * EmbeddingMemo
87122}
88123
89- // NewMemoizedProvider wraps an existing Provider with a memo cache.
124+ // NewMemoizedProvider wraps an existing Provider with a memo cache. The cache is
125+ // namespaced by the inner provider's Name() (which encodes the model), so a model
126+ // swap can never serve stale vectors from the previous model.
90127func NewMemoizedProvider (inner Provider , maxEntries int ) * MemoizedProvider {
91128 return & MemoizedProvider {
92129 inner : inner ,
93- memo : NewEmbeddingMemo ( maxEntries ),
130+ memo : NewEmbeddingMemoNS ( inner . Name (), maxEntries ),
94131 }
95132}
96133
@@ -137,8 +174,17 @@ func (p *MemoizedProvider) EmbedBatch(ctx context.Context, texts []string) ([][]
137174}
138175
139176func (p * MemoizedProvider ) EmbedWithMode (ctx context.Context , text string , mode EmbedMode ) ([]float32 , error ) {
140- // Mode-aware calls bypass memo since same content may produce different vectors per mode.
141- return p .inner .EmbedWithMode (ctx , text , mode )
177+ // Memoized per-mode: the key includes the mode, so document- and query-mode
178+ // vectors for the same text never collide.
179+ if vec , ok := p .memo .GetMode (text , mode ); ok {
180+ return vec , nil
181+ }
182+ vec , err := p .inner .EmbedWithMode (ctx , text , mode )
183+ if err != nil {
184+ return nil , err
185+ }
186+ p .memo .PutMode (text , mode , vec )
187+ return vec , nil
142188}
143189
144190// Memo returns the underlying cache for inspection/testing.
0 commit comments