Skip to content

Commit 6a898ce

Browse files
committed
fix: stored memory embeddings had mixed dim
1 parent 75e889d commit 6a898ce

4 files changed

Lines changed: 199 additions & 7 deletions

File tree

runtime/node/agent/memory/embedding.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(self, embedding_config: EmbeddingConfig):
8686
self.max_length = embedding_config.params.get('max_length', 8191)
8787
self.use_chunking = embedding_config.params.get('use_chunking', False)
8888
self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average')
89+
self._fallback_dim = 1536 # Default; updated after first successful call
8990

9091
if self.base_url:
9192
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
@@ -99,7 +100,7 @@ def get_embedding(self, text):
99100

100101
if not processed_text:
101102
logger.warning("Empty text after preprocessing")
102-
return [0.0] * 1536 # Return a zero vector
103+
return [0.0] * self._fallback_dim
103104

104105
# Handle long text via chunking
105106
if self.use_chunking and len(processed_text) > self.max_length:
@@ -115,17 +116,18 @@ def get_embedding(self, text):
115116
encoding_format="float"
116117
)
117118
embedding = response.data[0].embedding
119+
self._fallback_dim = len(embedding)
118120
return embedding
119121
except Exception as e:
120122
logger.error(f"Error getting embedding: {e}")
121-
return [0.0] * 1536 # Return zero vector as fallback
123+
return [0.0] * self._fallback_dim
122124

123125
def _get_chunked_embedding(self, text: str) -> List[float]:
124126
"""Chunk long text, embed each chunk, then aggregate."""
125127
chunks = self._chunk_text(text, self.max_length // 2) # Halve the chunk length
126128

127129
if not chunks:
128-
return [0.0] * 1536
130+
return [0.0] * self._fallback_dim
129131

130132
chunk_embeddings = []
131133
for chunk in chunks:
@@ -141,7 +143,7 @@ def _get_chunked_embedding(self, text: str) -> List[float]:
141143
continue
142144

143145
if not chunk_embeddings:
144-
return [0.0] * 1536
146+
return [0.0] * self._fallback_dim
145147

146148
# Aggregation strategy
147149
if self.chunk_strategy == 'average':
@@ -163,6 +165,7 @@ def __init__(self, embedding_config: EmbeddingConfig):
163165
super().__init__(embedding_config)
164166
self.model_path = embedding_config.params.get('model_path')
165167
self.device = embedding_config.params.get('device', 'cpu')
168+
self._fallback_dim = 768 # Default; updated after first successful call
166169

167170
if not self.model_path:
168171
raise ValueError("LocalEmbedding requires model_path parameter")
@@ -179,11 +182,13 @@ def get_embedding(self, text):
179182
processed_text = self._preprocess_text(text)
180183

181184
if not processed_text:
182-
return [0.0] * 768 # Return zero vector
185+
return [0.0] * self._fallback_dim
183186

184187
try:
185188
embedding = self.model.encode(processed_text, convert_to_tensor=False)
186-
return embedding.tolist()
189+
result = embedding.tolist()
190+
self._fallback_dim = len(result)
191+
return result
187192
except Exception as e:
188193
logger.error(f"Error getting local embedding: {e}")
189-
return [0.0] * 768 # Return zero vector as fallback
194+
return [0.0] * self._fallback_dim

runtime/node/agent/memory/file_memory.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,19 @@ def retrieve(
123123
query_embedding = query_embedding.reshape(1, -1)
124124
faiss.normalize_L2(query_embedding)
125125

126+
expected_dim = query_embedding.shape[1]
127+
126128
# Collect embeddings from memory items
127129
memory_embeddings = []
128130
valid_items = []
129131
for item in self.contents:
130132
if item.embedding is not None:
133+
if len(item.embedding) != expected_dim:
134+
logger.warning(
135+
"Skipping memory item %s: embedding dim %d != expected %d",
136+
item.id, len(item.embedding), expected_dim,
137+
)
138+
continue
131139
memory_embeddings.append(item.embedding)
132140
valid_items.append(item)
133141

runtime/node/agent/memory/simple_memory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import hashlib
22
import json
3+
import logging
34
import os
45
import re
56
import time
@@ -16,6 +17,8 @@
1617
import faiss
1718
import numpy as np
1819

20+
logger = logging.getLogger(__name__)
21+
1922
class SimpleMemory(MemoryBase):
2023
def __init__(self, store: MemoryStoreConfig):
2124
config = store.as_config(SimpleMemoryConfig)
@@ -107,10 +110,18 @@ def retrieve(
107110
inputs_embedding = inputs_embedding.reshape(1, -1)
108111
faiss.normalize_L2(inputs_embedding)
109112

113+
expected_dim = inputs_embedding.shape[1]
114+
110115
memory_embeddings = []
111116
valid_items = []
112117
for item in self.contents:
113118
if item.embedding is not None:
119+
if len(item.embedding) != expected_dim:
120+
logger.warning(
121+
"Skipping memory item %s: embedding dim %d != expected %d",
122+
item.id, len(item.embedding), expected_dim,
123+
)
124+
continue
114125
memory_embeddings.append(item.embedding)
115126
valid_items.append(item)
116127

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Tests for memory embedding dimension consistency."""
2+
from unittest.mock import MagicMock, patch
3+
from runtime.node.agent.memory.memory_base import MemoryContentSnapshot, MemoryItem
4+
from runtime.node.agent.memory.simple_memory import SimpleMemory
5+
6+
7+
# ---------------------------------------------------------------------------
8+
# Helpers
9+
# ---------------------------------------------------------------------------
10+
11+
def _make_store(memory_path=None):
12+
"""Build a minimal MemoryStoreConfig mock for SimpleMemory."""
13+
simple_cfg = MagicMock()
14+
simple_cfg.memory_path = memory_path
15+
simple_cfg.embedding = None # We'll set embedding manually
16+
17+
store = MagicMock()
18+
store.name = "test_store"
19+
store.as_config.return_value = simple_cfg
20+
return store
21+
22+
23+
def _make_embedding(dim: int):
24+
"""Create a mock EmbeddingBase that produces vectors of the given dimension."""
25+
emb = MagicMock()
26+
emb.get_embedding.return_value = [0.1] * dim
27+
return emb
28+
29+
30+
def _make_memory_item(item_id: str, dim: int):
    """Create a MemoryItem with an embedding of the specified dimension.

    Args:
        item_id: Identifier for the item; also interpolated into its
            ``content_summary``.
        dim: Length of the embedding vector.

    Returns:
        A ``MemoryItem`` with empty metadata and the embedding
        ``[0.0, 1.0, ..., dim - 1]``.
    """
    return MemoryItem(
        id=item_id,
        content_summary=f"content for {item_id}",
        metadata={},
        embedding=[float(i) for i in range(dim)],
    )
38+
39+
40+
# ---------------------------------------------------------------------------
41+
# Tests
42+
# ---------------------------------------------------------------------------
43+
44+
class TestSimpleMemoryRetrieveMixedDimensions:
    """Task 2.1: verify retrieve() handles mixed-dimension embeddings.

    Each test builds a ``SimpleMemory`` on a mocked store, installs a mock
    embedding backend that yields 768-dim query vectors, and seeds
    ``memory.contents`` directly with items of chosen dimensions.
    """

    def test_mixed_dimensions_does_not_crash(self):
        """Retrieve with mixed-dimensional embeddings MUST not raise."""
        store = _make_store()
        memory = SimpleMemory(store)
        memory.embedding = _make_embedding(dim=768)

        # 3 items with correct dim, 2 with wrong dim
        memory.contents = [
            _make_memory_item("ok_1", 768),
            _make_memory_item("bad_1", 1536),
            _make_memory_item("ok_2", 768),
            _make_memory_item("bad_2", 256),
            _make_memory_item("ok_3", 768),
        ]

        query = MemoryContentSnapshot(text="test query")
        # Should NOT raise ValueError / numpy error
        results = memory.retrieve(
            agent_role="tester",
            query=query,
            top_k=5,
            similarity_threshold=-1.0,
        )
        # Only the 3 correct-dimension items should be candidates.
        # NOTE(review): `<= 3` also passes when nothing is returned; consider
        # tightening to `== 3` once retrieve()'s filtering is confirmed.
        assert len(results) <= 3

    def test_all_same_dimension_returns_results(self):
        """When all embeddings share the correct dimension, all are candidates."""
        store = _make_store()
        memory = SimpleMemory(store)
        memory.embedding = _make_embedding(dim=768)

        memory.contents = [
            _make_memory_item("a", 768),
            _make_memory_item("b", 768),
        ]

        query = MemoryContentSnapshot(text="test query")
        results = memory.retrieve(
            agent_role="tester",
            query=query,
            top_k=5,
            similarity_threshold=-1.0,
        )
        assert len(results) == 2

    def test_all_wrong_dimension_returns_empty(self):
        """When every stored embedding has a wrong dimension, return empty."""
        store = _make_store()
        memory = SimpleMemory(store)
        memory.embedding = _make_embedding(dim=768)

        memory.contents = [
            _make_memory_item("x", 1536),
            _make_memory_item("y", 1536),
        ]

        query = MemoryContentSnapshot(text="test query")
        results = memory.retrieve(
            agent_role="tester",
            query=query,
            top_k=5,
            similarity_threshold=-1.0,
        )
        assert results == []
112+
113+
114+
class TestOpenAIEmbeddingDynamicFallback:
    """Task 2.2: verify dynamic fallback dimension caching.

    Both tests build an ``OpenAIEmbedding`` from a mocked config and patch
    ``emb.client.embeddings.create`` so no network request is made by the
    embedding call itself.
    """

    def test_fallback_uses_model_dimension_after_success(self):
        """After a successful call the fallback dimension MUST match the model."""
        from runtime.node.agent.memory.embedding import OpenAIEmbedding

        cfg = MagicMock()
        cfg.base_url = "http://localhost:11434/v1"
        cfg.api_key = "test"
        cfg.model = "test-model"
        cfg.params = {}

        emb = OpenAIEmbedding(cfg)
        assert emb._fallback_dim == 1536  # default before any call

        # Simulate a successful 768-dim response
        mock_data = MagicMock()
        mock_data.embedding = [0.1] * 768
        mock_response = MagicMock()
        mock_response.data = [mock_data]

        with patch.object(emb.client.embeddings, "create", return_value=mock_response):
            result = emb.get_embedding("hello world")

        assert len(result) == 768
        assert emb._fallback_dim == 768  # updated after success

    def test_fallback_zero_vector_matches_cached_dim(self):
        """After caching dim, fallback zero-vectors MUST use that dim."""
        from runtime.node.agent.memory.embedding import OpenAIEmbedding

        cfg = MagicMock()
        cfg.base_url = "http://localhost:11434/v1"
        cfg.api_key = "test"
        cfg.model = "test-model"
        cfg.params = {}

        emb = OpenAIEmbedding(cfg)

        # Simulate successful 512-dim call
        mock_data = MagicMock()
        mock_data.embedding = [0.1] * 512
        mock_response = MagicMock()
        mock_response.data = [mock_data]

        with patch.object(emb.client.embeddings, "create", return_value=mock_response):
            emb.get_embedding("first call")

        # Now simulate a failure — fallback should be 512-dim
        with patch.object(emb.client.embeddings, "create", side_effect=Exception("API down")):
            fallback = emb.get_embedding("failing call")

        assert len(fallback) == 512
        assert all(v == 0.0 for v in fallback)

0 commit comments

Comments (0)