Skip to content

Commit be9c9dd

Browse files
committed
Add pytest + syrupy snapshot tests for MMR reranking
9 test functions covering: - Cosine diversity (baseline vs lambda=1.0, 0.5, 0.0) - L2 distance metric compatibility - Int8 vector element type - Cluster monopoly breaking - Composition with distance constraints - Composition with partition keys - Edge cases (k=1, k=0) - Error handling (invalid lambda range) - Insert guard for hidden column
1 parent 43bd226 commit be9c9dd

File tree

2 files changed

+492
-0
lines changed

2 files changed

+492
-0
lines changed
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# serializer version: 1
2+
# name: test_mmr_clustering
3+
OrderedDict({
4+
'sql': 'select rowid, distance from v where embedding match ? and k = ?',
5+
'rows': list([
6+
OrderedDict({
7+
'rowid': 1,
8+
'distance': 0.0,
9+
}),
10+
OrderedDict({
11+
'rowid': 2,
12+
'distance': 0.005062814336270094,
13+
}),
14+
OrderedDict({
15+
'rowid': 3,
16+
'distance': 0.011511977761983871,
17+
}),
18+
OrderedDict({
19+
'rowid': 4,
20+
'distance': 0.020601648837327957,
21+
}),
22+
OrderedDict({
23+
'rowid': 5,
24+
'distance': 0.03227578103542328,
25+
}),
26+
]),
27+
})
28+
# ---
29+
# name: test_mmr_clustering.1
30+
OrderedDict({
31+
'sql': 'select rowid, distance from v where embedding match ? and k = ? and mmr_lambda = ?',
32+
'rows': list([
33+
OrderedDict({
34+
'rowid': 1,
35+
'distance': 0.0,
36+
}),
37+
OrderedDict({
38+
'rowid': 2,
39+
'distance': 0.005062814336270094,
40+
}),
41+
OrderedDict({
42+
'rowid': 8,
43+
'distance': 1.0,
44+
}),
45+
OrderedDict({
46+
'rowid': 3,
47+
'distance': 0.011511977761983871,
48+
}),
49+
OrderedDict({
50+
'rowid': 4,
51+
'distance': 0.020601648837327957,
52+
}),
53+
]),
54+
})
55+
# ---
56+
# name: test_mmr_cosine_diversity
57+
OrderedDict({
58+
'sql': 'select rowid, distance from v where embedding match ? and k = ?',
59+
'rows': list([
60+
OrderedDict({
61+
'rowid': 1,
62+
'distance': 0.0,
63+
}),
64+
OrderedDict({
65+
'rowid': 2,
66+
'distance': 0.005062814336270094,
67+
}),
68+
OrderedDict({
69+
'rowid': 3,
70+
'distance': 0.02019595541059971,
71+
}),
72+
]),
73+
})
74+
# ---
75+
# name: test_mmr_cosine_diversity.1
76+
OrderedDict({
77+
'sql': 'select rowid, distance from v where embedding match ? and k = ? and mmr_lambda = ?',
78+
'rows': list([
79+
OrderedDict({
80+
'rowid': 1,
81+
'distance': 0.0,
82+
}),
83+
OrderedDict({
84+
'rowid': 2,
85+
'distance': 0.005062814336270094,
86+
}),
87+
OrderedDict({
88+
'rowid': 3,
89+
'distance': 0.02019595541059971,
90+
}),
91+
]),
92+
})
93+
# ---
94+
# name: test_mmr_cosine_diversity.2
95+
OrderedDict({
96+
'sql': 'select rowid, distance from v where embedding match ? and k = ? and mmr_lambda = ?',
97+
'rows': list([
98+
OrderedDict({
99+
'rowid': 1,
100+
'distance': 0.0,
101+
}),
102+
OrderedDict({
103+
'rowid': 2,
104+
'distance': 0.005062814336270094,
105+
}),
106+
OrderedDict({
107+
'rowid': 5,
108+
'distance': 1.0,
109+
}),
110+
]),
111+
})
112+
# ---
113+
# name: test_mmr_cosine_diversity.3
114+
OrderedDict({
115+
'sql': 'select rowid, distance from v where embedding match ? and k = ? and mmr_lambda = ?',
116+
'rows': list([
117+
OrderedDict({
118+
'rowid': 1,
119+
'distance': 0.0,
120+
}),
121+
OrderedDict({
122+
'rowid': 5,
123+
'distance': 1.0,
124+
}),
125+
OrderedDict({
126+
'rowid': 4,
127+
'distance': 1.0,
128+
}),
129+
]),
130+
})
131+
# ---
132+
# name: test_mmr_edge_cases
133+
OrderedDict({
134+
'sql': 'select rowid, distance from v where embedding match ? and k = ? and mmr_lambda = ?',
135+
'rows': list([
136+
OrderedDict({
137+
'rowid': 1,
138+
'distance': 0.0,
139+
}),
140+
]),
141+
})
142+
# ---
143+
# name: test_mmr_edge_cases.1
144+
OrderedDict({
145+
'sql': 'select rowid, distance from v where embedding match ? and k = ? and mmr_lambda = ?',
146+
'rows': list([
147+
]),
148+
})
149+
# ---
150+
# name: test_mmr_error_invalid_lambda
151+
dict({
152+
'error': 'OperationalError',
153+
'message': 'mmr_lambda value in knn query must be between 0.0 and 1.0, provided 1.500000',
154+
})
155+
# ---
156+
# name: test_mmr_error_invalid_lambda.1
157+
dict({
158+
'error': 'OperationalError',
159+
'message': 'mmr_lambda value in knn query must be between 0.0 and 1.0, provided -0.100000',
160+
})
161+
# ---
162+
# name: test_mmr_insert_guard
163+
dict({
164+
'error': 'OperationalError',
165+
'message': 'A value was provided for the hidden "mmr_lambda" column.',
166+
})
167+
# ---
168+
# name: test_mmr_int8_vectors
169+
OrderedDict({
170+
'sql': 'select rowid, distance from v where embedding match vec_int8(?) and k = ? and mmr_lambda = ?',
171+
'rows': list([
172+
OrderedDict({
173+
'rowid': 1,
174+
'distance': 0.0,
175+
}),
176+
OrderedDict({
177+
'rowid': 2,
178+
'distance': 5.101129863760434e-05,
179+
}),
180+
OrderedDict({
181+
'rowid': 5,
182+
'distance': 1.0,
183+
}),
184+
]),
185+
})
186+
# ---
187+
# name: test_mmr_l2_metric
188+
OrderedDict({
189+
'sql': 'select rowid, distance from v where embedding match ? and k = ? and mmr_lambda = ?',
190+
'rows': list([
191+
OrderedDict({
192+
'rowid': 1,
193+
'distance': 0.0,
194+
}),
195+
OrderedDict({
196+
'rowid': 3,
197+
'distance': 0.028284257277846336,
198+
}),
199+
OrderedDict({
200+
'rowid': 2,
201+
'distance': 0.014142128638923168,
202+
}),
203+
]),
204+
})
205+
# ---
206+
# name: test_mmr_with_distance_constraint
207+
OrderedDict({
208+
'sql': 'select rowid, distance from v where embedding match ? and k = ? and mmr_lambda = ? and distance > 0.001',
209+
'rows': list([
210+
OrderedDict({
211+
'rowid': 2,
212+
'distance': 0.005062814336270094,
213+
}),
214+
OrderedDict({
215+
'rowid': 5,
216+
'distance': 1.0,
217+
}),
218+
OrderedDict({
219+
'rowid': 3,
220+
'distance': 0.02019595541059971,
221+
}),
222+
]),
223+
})
224+
# ---
225+
# name: test_mmr_with_partition_key
226+
OrderedDict({
227+
'sql': 'select rowid, distance from v where embedding match ? and k = ? and category = ? and mmr_lambda = ?',
228+
'rows': list([
229+
OrderedDict({
230+
'rowid': 1,
231+
'distance': 0.0,
232+
}),
233+
OrderedDict({
234+
'rowid': 2,
235+
'distance': 0.005062814336270094,
236+
}),
237+
OrderedDict({
238+
'rowid': 5,
239+
'distance': 1.0,
240+
}),
241+
]),
242+
})
243+
# ---

0 commit comments

Comments
 (0)