Skip to content

Commit fea7091

Browse files
authored
Merge pull request #582 from LaansDole/main
Fix: hardcoded embedding dimension
2 parents 51ca75c + f3d6a2f commit fea7091

6 files changed

Lines changed: 220 additions & 8 deletions

File tree

Makefile

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,23 @@ validate-yamls: ## Validate all YAML configuration files
4141

4242
.PHONY: help
4343
help: ## Display this help message
44-
@python -c "import re; \
44+
@uv run python -c "import re; \
4545
p=r'$(firstword $(MAKEFILE_LIST))'.strip(); \
4646
[print(f'{m[0]:<20} {m[1]}') for m in re.findall(r'^([a-zA-Z_-]+):.*?## (.*)$$', open(p, encoding='utf-8').read(), re.M)]" | sort
47+
48+
# ==============================================================================
49+
# Quality Checks
50+
# ==============================================================================
51+
52+
.PHONY: check-backend
53+
check-backend: ## Run backend quality checks (tests + linting)
54+
@$(MAKE) backend-tests
55+
@$(MAKE) backend-lint
56+
57+
.PHONY: backend-tests
58+
backend-tests: ## Run backend tests
59+
@uv run pytest -v
60+
61+
.PHONY: backend-lint
62+
backend-lint: ## Run backend linting
63+
@uvx ruff check .

pyproject.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,19 @@ dependencies = [
4545
requires = ["hatchling"]
4646
build-backend = "hatchling.build"
4747

48+
[tool.pytest.ini_options]
49+
pythonpath = ["."]
50+
testpaths = ["tests"]
51+
python_files = ["test_*.py"]
52+
python_classes = ["Test*"]
53+
python_functions = ["test_*"]
54+
addopts = "-v --tb=short"
55+
filterwarnings = [
56+
# Upstream SWIG issue in faiss-cpu on Python 3.12; awaiting SWIG 4.4 fix.
57+
"ignore:builtin type Swig.*:DeprecationWarning",
58+
"ignore:builtin type swigvarlink.*:DeprecationWarning",
59+
]
60+
61+
4862
[tool.uv]
4963
package = false

runtime/node/agent/memory/embedding.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(self, embedding_config: EmbeddingConfig):
8686
self.max_length = embedding_config.params.get('max_length', 8191)
8787
self.use_chunking = embedding_config.params.get('use_chunking', False)
8888
self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average')
89+
self._fallback_dim = 1536 # Default; updated after first successful call
8990

9091
if self.base_url:
9192
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
@@ -99,7 +100,7 @@ def get_embedding(self, text):
99100

100101
if not processed_text:
101102
logger.warning("Empty text after preprocessing")
102-
return [0.0] * 1536 # Return a zero vector
103+
return [0.0] * self._fallback_dim
103104

104105
# Handle long text via chunking
105106
if self.use_chunking and len(processed_text) > self.max_length:
@@ -115,17 +116,18 @@ def get_embedding(self, text):
115116
encoding_format="float"
116117
)
117118
embedding = response.data[0].embedding
119+
self._fallback_dim = len(embedding)
118120
return embedding
119121
except Exception as e:
120122
logger.error(f"Error getting embedding: {e}")
121-
return [0.0] * 1536 # Return zero vector as fallback
123+
return [0.0] * self._fallback_dim
122124

123125
def _get_chunked_embedding(self, text: str) -> List[float]:
124126
"""Chunk long text, embed each chunk, then aggregate."""
125127
chunks = self._chunk_text(text, self.max_length // 2) # Halve the chunk length
126128

127129
if not chunks:
128-
return [0.0] * 1536
130+
return [0.0] * self._fallback_dim
129131

130132
chunk_embeddings = []
131133
for chunk in chunks:
@@ -141,7 +143,7 @@ def _get_chunked_embedding(self, text: str) -> List[float]:
141143
continue
142144

143145
if not chunk_embeddings:
144-
return [0.0] * 1536
146+
return [0.0] * self._fallback_dim
145147

146148
# Aggregation strategy
147149
if self.chunk_strategy == 'average':
@@ -163,6 +165,7 @@ def __init__(self, embedding_config: EmbeddingConfig):
163165
super().__init__(embedding_config)
164166
self.model_path = embedding_config.params.get('model_path')
165167
self.device = embedding_config.params.get('device', 'cpu')
168+
self._fallback_dim = 768 # Default; updated after first successful call
166169

167170
if not self.model_path:
168171
raise ValueError("LocalEmbedding requires model_path parameter")
@@ -179,11 +182,13 @@ def get_embedding(self, text):
179182
processed_text = self._preprocess_text(text)
180183

181184
if not processed_text:
182-
return [0.0] * 768 # Return zero vector
185+
return [0.0] * self._fallback_dim
183186

184187
try:
185188
embedding = self.model.encode(processed_text, convert_to_tensor=False)
186-
return embedding.tolist()
189+
result = embedding.tolist()
190+
self._fallback_dim = len(result)
191+
return result
187192
except Exception as e:
188193
logger.error(f"Error getting local embedding: {e}")
189-
return [0.0] * 768 # Return zero vector as fallback
194+
return [0.0] * self._fallback_dim

runtime/node/agent/memory/file_memory.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,19 @@ def retrieve(
123123
query_embedding = query_embedding.reshape(1, -1)
124124
faiss.normalize_L2(query_embedding)
125125

126+
expected_dim = query_embedding.shape[1]
127+
126128
# Collect embeddings from memory items
127129
memory_embeddings = []
128130
valid_items = []
129131
for item in self.contents:
130132
if item.embedding is not None:
133+
if len(item.embedding) != expected_dim:
134+
logger.warning(
135+
"Skipping memory item %s: embedding dim %d != expected %d",
136+
item.id, len(item.embedding), expected_dim,
137+
)
138+
continue
131139
memory_embeddings.append(item.embedding)
132140
valid_items.append(item)
133141

runtime/node/agent/memory/simple_memory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import hashlib
22
import json
3+
import logging
34
import os
45
import re
56
import time
@@ -16,6 +17,8 @@
1617
import faiss
1718
import numpy as np
1819

20+
logger = logging.getLogger(__name__)
21+
1922
class SimpleMemory(MemoryBase):
2023
def __init__(self, store: MemoryStoreConfig):
2124
config = store.as_config(SimpleMemoryConfig)
@@ -107,10 +110,18 @@ def retrieve(
107110
inputs_embedding = inputs_embedding.reshape(1, -1)
108111
faiss.normalize_L2(inputs_embedding)
109112

113+
expected_dim = inputs_embedding.shape[1]
114+
110115
memory_embeddings = []
111116
valid_items = []
112117
for item in self.contents:
113118
if item.embedding is not None:
119+
if len(item.embedding) != expected_dim:
120+
logger.warning(
121+
"Skipping memory item %s: embedding dim %d != expected %d",
122+
item.id, len(item.embedding), expected_dim,
123+
)
124+
continue
114125
memory_embeddings.append(item.embedding)
115126
valid_items.append(item)
116127

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""Tests for memory embedding dimension consistency."""
2+
from unittest.mock import MagicMock, patch
3+
from runtime.node.agent.memory.memory_base import MemoryContentSnapshot, MemoryItem
4+
from runtime.node.agent.memory.simple_memory import SimpleMemory
5+
6+
def _make_store(memory_path=None):
7+
"""Build a minimal MemoryStoreConfig mock for SimpleMemory."""
8+
simple_cfg = MagicMock()
9+
simple_cfg.memory_path = memory_path
10+
simple_cfg.embedding = None # We'll set embedding manually
11+
12+
store = MagicMock()
13+
store.name = "test_store"
14+
store.as_config.return_value = simple_cfg
15+
return store
16+
17+
18+
def _make_embedding(dim: int):
19+
"""Create a mock EmbeddingBase that produces vectors of the given dimension."""
20+
emb = MagicMock()
21+
emb.get_embedding.return_value = [0.1] * dim
22+
return emb
23+
24+
25+
def _make_memory_item(item_id: str, dim: int):
26+
"""Create a MemoryItem with an embedding of the specified dimension."""
27+
return MemoryItem(
28+
id=item_id,
29+
content_summary=f"content for {item_id}",
30+
metadata={},
31+
embedding=[float(i) for i in range(dim)],
32+
)
33+
34+
35+
class TestSimpleMemoryRetrieveMixedDimensions:
36+
37+
def test_mixed_dimensions_does_not_crash(self):
38+
"""Retrieve with mixed-dimensional embeddings MUST not raise."""
39+
store = _make_store()
40+
memory = SimpleMemory(store)
41+
memory.embedding = _make_embedding(dim=768)
42+
43+
# 3 items with correct dim, 2 with wrong dim
44+
memory.contents = [
45+
_make_memory_item("ok_1", 768),
46+
_make_memory_item("bad_1", 1536),
47+
_make_memory_item("ok_2", 768),
48+
_make_memory_item("bad_2", 256),
49+
_make_memory_item("ok_3", 768),
50+
]
51+
52+
query = MemoryContentSnapshot(text="test query")
53+
# Should NOT raise ValueError / numpy error
54+
results = memory.retrieve(
55+
agent_role="tester",
56+
query=query,
57+
top_k=5,
58+
similarity_threshold=-1.0,
59+
)
60+
# Only the 3 correct-dimension items should be candidates
61+
assert len(results) <= 3
62+
63+
def test_all_same_dimension_returns_results(self):
64+
"""When all embeddings share the correct dimension, all are candidates."""
65+
store = _make_store()
66+
memory = SimpleMemory(store)
67+
memory.embedding = _make_embedding(dim=768)
68+
69+
memory.contents = [
70+
_make_memory_item("a", 768),
71+
_make_memory_item("b", 768),
72+
]
73+
74+
query = MemoryContentSnapshot(text="test query")
75+
results = memory.retrieve(
76+
agent_role="tester",
77+
query=query,
78+
top_k=5,
79+
similarity_threshold=-1.0,
80+
)
81+
assert len(results) == 2
82+
83+
def test_all_wrong_dimension_returns_empty(self):
84+
"""When every stored embedding has a wrong dimension, return empty."""
85+
store = _make_store()
86+
memory = SimpleMemory(store)
87+
memory.embedding = _make_embedding(dim=768)
88+
89+
memory.contents = [
90+
_make_memory_item("x", 1536),
91+
_make_memory_item("y", 1536),
92+
]
93+
94+
query = MemoryContentSnapshot(text="test query")
95+
results = memory.retrieve(
96+
agent_role="tester",
97+
query=query,
98+
top_k=5,
99+
similarity_threshold=-1.0,
100+
)
101+
assert results == []
102+
103+
104+
class TestOpenAIEmbeddingDynamicFallback:
105+
106+
def test_fallback_uses_model_dimension_after_success(self):
107+
"""After a successful call the fallback dimension MUST match the model."""
108+
from runtime.node.agent.memory.embedding import OpenAIEmbedding
109+
110+
cfg = MagicMock()
111+
cfg.base_url = "http://localhost:11434/v1"
112+
cfg.api_key = "test"
113+
cfg.model = "test-model"
114+
cfg.params = {}
115+
116+
emb = OpenAIEmbedding(cfg)
117+
assert emb._fallback_dim == 1536 # default before any call
118+
119+
# Simulate a successful 768-dim response
120+
mock_data = MagicMock()
121+
mock_data.embedding = [0.1] * 768
122+
mock_response = MagicMock()
123+
mock_response.data = [mock_data]
124+
125+
with patch.object(emb.client.embeddings, "create", return_value=mock_response):
126+
result = emb.get_embedding("hello world")
127+
128+
assert len(result) == 768
129+
assert emb._fallback_dim == 768 # updated after success
130+
131+
def test_fallback_zero_vector_matches_cached_dim(self):
132+
"""After caching dim, fallback zero-vectors MUST use that dim."""
133+
from runtime.node.agent.memory.embedding import OpenAIEmbedding
134+
135+
cfg = MagicMock()
136+
cfg.base_url = "http://localhost:11434/v1"
137+
cfg.api_key = "test"
138+
cfg.model = "test-model"
139+
cfg.params = {}
140+
141+
emb = OpenAIEmbedding(cfg)
142+
143+
# Simulate successful 512-dim call
144+
mock_data = MagicMock()
145+
mock_data.embedding = [0.1] * 512
146+
mock_response = MagicMock()
147+
mock_response.data = [mock_data]
148+
149+
with patch.object(emb.client.embeddings, "create", return_value=mock_response):
150+
emb.get_embedding("first call")
151+
152+
# Now simulate a failure — fallback should be 512-dim
153+
with patch.object(emb.client.embeddings, "create", side_effect=Exception("API down")):
154+
fallback = emb.get_embedding("failing call")
155+
156+
assert len(fallback) == 512
157+
assert all(v == 0.0 for v in fallback)

0 commit comments

Comments
 (0)