Skip to content

Commit a134482

Browse files
authored
Fix OR expression with KNN producing syntax error (#787)
When combining an OR expression like (A) | (B) with a KNN expression, the generated query was invalid because the KNN suffix was only applied to the second term: (A)| (B)=>[KNN ...] The fix ensures the entire filter expression is wrapped in parentheses before appending the KNN syntax: ((A)| (B))=>[KNN ...] Fixes #557
1 parent 57b673c commit a134482

2 files changed

Lines changed: 95 additions & 5 deletions

File tree

aredis_om/model/model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -906,11 +906,13 @@ def query(self):
906906
return self._query
907907
self._query = self._resolve_redisearch_query(self.expression)
908908
if self.knn:
909-
self._query = (
910-
self._query
911-
if self._query.startswith("(") or self._query == "*"
912-
else f"({self._query})"
913-
) + f"=>[{self.knn}]"
909+
# Always wrap the filter expression in parentheses when combining with KNN,
910+
# unless it's the wildcard "*". This ensures OR expressions like
911+
# "(A)| (B)" become "((A)| (B))=>[KNN ...]" instead of the invalid
912+
# "(A)| (B)=>[KNN ...]" where KNN only applies to the second term.
913+
if self._query != "*":
914+
self._query = f"({self._query})"
915+
self._query += f"=>[{self.knn}]"
914916
# RETURN clause should be added to args, not to the query string
915917
return self._query
916918

tests/test_knn_expression.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,91 @@ async def test_nested_vector_field(n: Type[JsonModel]):
113113

114114
assert len(members) == 1
115115
assert members[0].embeddings_score is not None
116+
117+
118+
119+
@pytest_asyncio.fixture
120+
async def album_model(key_prefix, redis):
121+
"""Fixture for testing OR expressions with KNN."""
122+
class BaseJsonModel(JsonModel, abc.ABC):
123+
class Meta:
124+
global_key_prefix = key_prefix
125+
database = redis
126+
127+
vector_options = VectorFieldOptions.flat(
128+
type=VectorFieldOptions.TYPE.FLOAT32,
129+
dimension=2,
130+
distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE,
131+
)
132+
133+
class Album(BaseJsonModel, index=True):
134+
title: str = Field(primary_key=True)
135+
tags: str = Field(index=True)
136+
title_embeddings: list[float] = Field(
137+
[], index=True, vector_options=vector_options
138+
)
139+
embeddings_score: Optional[float] = None
140+
141+
await Migrator(conn=redis).run()
142+
143+
return Album
144+
145+
146+
@py_test_mark_asyncio
147+
async def test_or_expression_with_knn(album_model):
148+
"""Test that OR expressions work correctly with KNN.
149+
150+
Regression test for GitHub issue #557: Using an OR expression with a
151+
KNN expression raises ResponseError with syntax error.
152+
"""
153+
Album = album_model
154+
155+
# Create test data
156+
albums = [
157+
Album(
158+
title="Rumours",
159+
tags="Genre:rock|Decade:70s",
160+
title_embeddings=[0.7, 0.3],
161+
),
162+
Album(
163+
title="Abbey Road",
164+
tags="Genre:rock|Decade:60s",
165+
title_embeddings=[0.6, 0.4],
166+
),
167+
Album(
168+
title="The Dark Side Of The Moon",
169+
tags="Genre:prog-rock|Decade:70s",
170+
title_embeddings=[0.5, 0.5],
171+
),
172+
]
173+
for album in albums:
174+
await album.save()
175+
176+
# Create OR expression
177+
or_expr = (Album.tags == "Genre:rock|Decade:70s") | (
178+
Album.tags == "Genre:rock|Decade:60s"
179+
)
180+
181+
# Create KNN expression
182+
knn = KNNExpression(
183+
k=3,
184+
vector_field=Album.title_embeddings,
185+
score_field=Album.embeddings_score,
186+
reference_vector=to_bytes([0.65, 0.35]),
187+
)
188+
189+
# Query with just OR expression (should work)
190+
or_results = await Album.find(or_expr).all()
191+
assert len(or_results) == 2
192+
193+
# Query with just KNN (should work)
194+
knn_results = await Album.find(knn=knn).all()
195+
assert len(knn_results) == 3
196+
197+
# Query with OR expression AND KNN (this was failing before the fix)
198+
combined_results = await Album.find(or_expr, knn=knn).all()
199+
# Should return only the 2 albums matching the OR expression
200+
assert len(combined_results) == 2
201+
# All results should have an embeddings score from KNN
202+
for result in combined_results:
203+
assert result.embeddings_score is not None

0 commit comments

Comments
 (0)