Skip to content

Commit 144e2d9

Browse files
committed
Merge pull request #15 from GrayCodeAI/feat/cocoindex-adoption
feat: model-aware embedding caching, asymmetric retrieval, brace chunking
2 parents 84d402e + 2dc32d4 commit 144e2d9

16 files changed

Lines changed: 666 additions & 68 deletions

embeddings/memo.go

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,55 @@ import (
44
"context"
55
"crypto/sha256"
66
"encoding/hex"
7+
"strconv"
78
"sync"
89
)
910

10-
// EmbeddingMemo caches embeddings by content hash to skip re-embedding unchanged content.
11+
// EmbeddingMemo caches embeddings to skip re-embedding unchanged content.
12+
//
13+
// The cache key is namespace + mode + sha256(content), NOT content alone. The
14+
// namespace identifies the embedding model (provider Name()), so swapping models
15+
// or model versions no longer serves stale, incomparable vectors from the old
16+
// model. The mode dimension keeps document- and query-mode vectors separate for
17+
// asymmetric retrieval models (e.g. Cohere v3 search_document vs search_query).
1118
type EmbeddingMemo struct {
12-
mu sync.RWMutex
13-
cache map[string][]float32 // sha256(content) -> embedding
14-
order []string // insertion order for LRU eviction
15-
max int
19+
mu sync.RWMutex
20+
namespace string // model identity, prefixed into every key
21+
cache map[string][]float32 // namespace|mode|sha256(content) -> embedding
22+
order []string // insertion order for LRU eviction
23+
max int
1624
}
1725

18-
// NewEmbeddingMemo creates a memo cache with the given max entry count.
26+
// NewEmbeddingMemo creates a memo cache with the given max entry count. The
27+
// namespace is empty; prefer NewEmbeddingMemoNS so the cache is keyed by model
28+
// identity. Pass "" only for tests or single-model callers that never switch.
1929
func NewEmbeddingMemo(maxEntries int) *EmbeddingMemo {
30+
return NewEmbeddingMemoNS("", maxEntries)
31+
}
32+
33+
// NewEmbeddingMemoNS creates a memo cache namespaced by model identity, so that
34+
// changing the embedding model invalidates the cache instead of serving stale
35+
// vectors from the previous model.
36+
func NewEmbeddingMemoNS(namespace string, maxEntries int) *EmbeddingMemo {
2037
if maxEntries <= 0 {
2138
maxEntries = 1024
2239
}
2340
return &EmbeddingMemo{
24-
cache: make(map[string][]float32, maxEntries),
25-
order: make([]string, 0, maxEntries),
26-
max: maxEntries,
41+
namespace: namespace,
42+
cache: make(map[string][]float32, maxEntries),
43+
order: make([]string, 0, maxEntries),
44+
max: maxEntries,
2745
}
2846
}
2947

30-
// Get returns a cached embedding for the content, if present.
48+
// Get returns a cached embedding for the content in ModeDocument, if present.
3149
func (m *EmbeddingMemo) Get(content string) ([]float32, bool) {
32-
key := contentHash(content)
50+
return m.GetMode(content, ModeDocument)
51+
}
52+
53+
// GetMode returns a cached embedding for the content under the given mode.
54+
func (m *EmbeddingMemo) GetMode(content string, mode EmbedMode) ([]float32, bool) {
55+
key := m.key(content, mode)
3356
m.mu.Lock()
3457
vec, ok := m.cache[key]
3558
if ok {
@@ -39,9 +62,15 @@ func (m *EmbeddingMemo) Get(content string) ([]float32, bool) {
3962
return vec, ok
4063
}
4164

42-
// Put stores an embedding for the given content, evicting the oldest entry if at capacity.
65+
// Put stores an embedding for the given content in ModeDocument.
4366
func (m *EmbeddingMemo) Put(content string, embedding []float32) {
44-
key := contentHash(content)
67+
m.PutMode(content, ModeDocument, embedding)
68+
}
69+
70+
// PutMode stores an embedding for the given content under the given mode,
71+
// evicting the oldest entry if at capacity.
72+
func (m *EmbeddingMemo) PutMode(content string, mode EmbedMode, embedding []float32) {
73+
key := m.key(content, mode)
4574
m.mu.Lock()
4675
defer m.mu.Unlock()
4776
if _, exists := m.cache[key]; exists {
@@ -75,6 +104,12 @@ func (m *EmbeddingMemo) promote(key string) {
75104
}
76105
}
77106

107+
// key builds the cache key from model namespace, embedding mode, and content
108+
// hash, so that a model change or a mode change never collides with stale state.
109+
func (m *EmbeddingMemo) key(content string, mode EmbedMode) string {
110+
return m.namespace + "|" + strconv.Itoa(int(mode)) + "|" + contentHash(content)
111+
}
112+
78113
func contentHash(s string) string {
79114
h := sha256.Sum256([]byte(s))
80115
return hex.EncodeToString(h[:])
@@ -86,11 +121,13 @@ type MemoizedProvider struct {
86121
memo *EmbeddingMemo
87122
}
88123

89-
// NewMemoizedProvider wraps an existing Provider with a memo cache.
124+
// NewMemoizedProvider wraps an existing Provider with a memo cache. The cache is
125+
// namespaced by the inner provider's Name() (which encodes the model), so a model
126+
// swap can never serve stale vectors from the previous model.
90127
func NewMemoizedProvider(inner Provider, maxEntries int) *MemoizedProvider {
91128
return &MemoizedProvider{
92129
inner: inner,
93-
memo: NewEmbeddingMemo(maxEntries),
130+
memo: NewEmbeddingMemoNS(inner.Name(), maxEntries),
94131
}
95132
}
96133

@@ -137,8 +174,17 @@ func (p *MemoizedProvider) EmbedBatch(ctx context.Context, texts []string) ([][]
137174
}
138175

139176
func (p *MemoizedProvider) EmbedWithMode(ctx context.Context, text string, mode EmbedMode) ([]float32, error) {
140-
// Mode-aware calls bypass memo since same content may produce different vectors per mode.
141-
return p.inner.EmbedWithMode(ctx, text, mode)
177+
// Memoized per-mode: the key includes the mode, so document- and query-mode
178+
// vectors for the same text never collide.
179+
if vec, ok := p.memo.GetMode(text, mode); ok {
180+
return vec, nil
181+
}
182+
vec, err := p.inner.EmbedWithMode(ctx, text, mode)
183+
if err != nil {
184+
return nil, err
185+
}
186+
p.memo.PutMode(text, mode, vec)
187+
return vec, nil
142188
}
143189

144190
// Memo returns the underlying cache for inspection/testing.

embeddings/memo_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ type countingProvider struct {
4040
localStub
4141
embedCalls int64
4242
batchCalls int64
43+
modeCalls int64
4344
}
4445

4546
func (c *countingProvider) Embed(ctx context.Context, text string) ([]float32, error) {
@@ -52,6 +53,11 @@ func (c *countingProvider) EmbedBatch(ctx context.Context, texts []string) ([][]
5253
return c.localStub.EmbedBatch(ctx, texts)
5354
}
5455

56+
func (c *countingProvider) EmbedWithMode(ctx context.Context, text string, mode EmbedMode) ([]float32, error) {
57+
atomic.AddInt64(&c.modeCalls, 1)
58+
return c.localStub.EmbedWithMode(ctx, text, mode)
59+
}
60+
5561
func TestMemoizedProvider_Embed(t *testing.T) {
5662
inner := &countingProvider{}
5763
p := NewMemoizedProvider(inner, 100)
@@ -97,3 +103,73 @@ func TestMemoizedProvider_EmbedBatch(t *testing.T) {
97103
t.Fatalf("expected still 1 batch call, got %d", inner.batchCalls)
98104
}
99105
}
106+
107+
// namedProvider lets a test override the reported model name.
108+
type namedProvider struct {
109+
localStub
110+
name string
111+
}
112+
113+
func (p *namedProvider) Name() string { return p.name }
114+
115+
// TestMemoizedProvider_ModelSwapInvalidates pins the core #1 fix: a memo built
116+
// for one model must not serve its vectors to a different model. Since the memo
117+
// is namespaced by Name(), each model has its own key space.
118+
func TestMemoizedProvider_ModelSwapInvalidates(t *testing.T) {
119+
ctx := context.Background()
120+
old := NewMemoizedProvider(&namedProvider{name: "modelA"}, 100)
121+
v1, _ := old.Embed(ctx, "foo")
122+
123+
// Same content, a different model identity → different namespace → miss.
124+
newer := NewMemoizedProvider(&namedProvider{name: "modelB"}, 100)
125+
if _, ok := newer.Memo().Get("foo"); ok {
126+
t.Fatal("modelB memo should not contain modelA's content")
127+
}
128+
// Within the same model, content still hits.
129+
if _, ok := old.Memo().Get("foo"); !ok {
130+
t.Fatal("modelA memo should still contain its own content")
131+
}
132+
_ = v1
133+
}
134+
135+
// TestEmbeddingMemo_ModeIsolation pins the #2-supporting behavior: document- and
136+
// query-mode vectors for identical text occupy distinct keys.
137+
func TestEmbeddingMemo_ModeIsolation(t *testing.T) {
138+
m := NewEmbeddingMemoNS("model", 100)
139+
m.PutMode("q", ModeDocument, []float32{1, 0})
140+
m.PutMode("q", ModeQuery, []float32{0, 1})
141+
142+
doc, ok := m.GetMode("q", ModeDocument)
143+
if !ok || doc[0] != 1 {
144+
t.Fatalf("document-mode vector wrong: %v ok=%v", doc, ok)
145+
}
146+
qry, ok := m.GetMode("q", ModeQuery)
147+
if !ok || qry[1] != 1 {
148+
t.Fatalf("query-mode vector wrong: %v ok=%v", qry, ok)
149+
}
150+
}
151+
152+
// TestMemoizedProvider_EmbedWithModeMemoizes pins that mode-aware calls are now
153+
// cached (previously they bypassed the memo entirely).
154+
func TestMemoizedProvider_EmbedWithModeMemoizes(t *testing.T) {
155+
inner := &countingProvider{}
156+
p := NewMemoizedProvider(inner, 100)
157+
ctx := context.Background()
158+
159+
if _, err := p.EmbedWithMode(ctx, "x", ModeQuery); err != nil {
160+
t.Fatalf("first mode embed failed: %v", err)
161+
}
162+
if _, err := p.EmbedWithMode(ctx, "x", ModeQuery); err != nil {
163+
t.Fatalf("second mode embed failed: %v", err)
164+
}
165+
if inner.modeCalls != 1 {
166+
t.Fatalf("expected 1 inner mode call, got %d", inner.modeCalls)
167+
}
168+
// A different mode for the same text must miss and call inner again.
169+
if _, err := p.EmbedWithMode(ctx, "x", ModeDocument); err != nil {
170+
t.Fatalf("doc-mode embed failed: %v", err)
171+
}
172+
if inner.modeCalls != 2 {
173+
t.Fatalf("expected 2 inner mode calls after mode switch, got %d", inner.modeCalls)
174+
}
175+
}

engine/engine_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ func (m *mockStorage) DeleteEmbedding(ctx context.Context, nodeID string) error
555555
return nil
556556
}
557557

558-
func (m *mockStorage) AllEmbeddings(ctx context.Context) (map[string][]float32, error) {
558+
func (m *mockStorage) AllEmbeddings(ctx context.Context, _ string) (map[string][]float32, error) {
559559
if err := ctx.Err(); err != nil {
560560
return nil, err
561561
}
@@ -581,8 +581,8 @@ func (m *mockStorage) GetEmbedding(ctx context.Context, nodeID string) ([]float3
581581
return nil, "", nil
582582
}
583583

584-
func (m *mockStorage) GetEmbeddingsBatch(ctx context.Context, offset, limit int) (map[string][]float32, error) {
585-
return m.AllEmbeddings(ctx)
584+
func (m *mockStorage) GetEmbeddingsBatch(ctx context.Context, model string, offset, limit int) (map[string][]float32, error) {
585+
return m.AllEmbeddings(ctx, model)
586586
}
587587

588588
func (m *mockStorage) AddFileWatch(ctx context.Context, filePath, nodeID, gitHash string) error {

engine/fused_recall.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"sync/atomic"
99
"time"
1010

11+
"github.com/GrayCodeAI/yaad/embeddings"
1112
"github.com/GrayCodeAI/yaad/intent"
1213
"github.com/GrayCodeAI/yaad/internal/telemetry"
1314
"github.com/GrayCodeAI/yaad/storage"
@@ -471,7 +472,9 @@ func fusedMergeKeys(maps ...map[string]int) []string {
471472
// checks the result before searching.
472473
func (e *Engine) queryVector(ctx context.Context, query, proxySeedID string) []float32 {
473474
if e.embedder != nil {
474-
if vec, err := e.embedder.Embed(ctx, query); err == nil && len(vec) > 0 {
475+
// Query mode: asymmetric models embed queries (search_query) differently
476+
// from stored documents (search_document).
477+
if vec, err := e.embedder.EmbedWithMode(ctx, query, embeddings.ModeQuery); err == nil && len(vec) > 0 {
475478
return vec
476479
} else if err != nil {
477480
atomic.AddInt64(&e.metrics.Errors, 1)

engine/search.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ func (h *HybridSearch) Search(ctx context.Context, query string, opts RecallOpts
6161
// Path 2: Vector seed nodes (if provider available)
6262
vectorRanks := map[string]int{}
6363
if h.provider != nil {
64-
vec, err := h.provider.Embed(ctx, query)
64+
// Embed in query mode: asymmetric models (e.g. Cohere v3) expect
65+
// search_query for queries vs search_document for stored content.
66+
vec, err := h.provider.EmbedWithMode(ctx, query, embeddings.ModeQuery)
6567
if err == nil {
6668
vectorRanks = h.vectorSearch(ctx, vec, opts.Limit*2)
6769
}
@@ -155,10 +157,18 @@ func (h *HybridSearch) vectorSearch(ctx context.Context, queryVec []float32, lim
155157
}
156158
var pairs []pair
157159

160+
// Scope the scan to the active provider's model: stored vectors from a
161+
// different model occupy an incompatible space and would produce meaningless
162+
// cosine scores. The ingest path records model = provider.Name() (rest.go).
163+
var model string
164+
if h.provider != nil {
165+
model = h.provider.Name()
166+
}
167+
158168
const batchSize = 500
159169
offset := 0
160170
for {
161-
batch, err := h.store.GetEmbeddingsBatch(ctx, offset, batchSize)
171+
batch, err := h.store.GetEmbeddingsBatch(ctx, model, offset, batchSize)
162172
if err != nil || len(batch) == 0 {
163173
break
164174
}

go.sum

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)