Skip to content

Commit 08573c6

Browse files
committed
add tests
Signed-off-by: Attila Toth <attila.toth@scylladb.com>
1 parent b42809e commit 08573c6

1 file changed

Lines changed: 209 additions & 6 deletions

File tree

sdk/python/tests/integration/online_store/test_scylladb_online_store.py

Lines changed: 209 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
"""
1515

1616
import os
17+
import time
18+
from datetime import datetime, timedelta, timezone
1719
from typing import List
1820

1921
import pytest
@@ -28,7 +30,7 @@
2830
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
2931
from feast.protos.feast.types.Value_pb2 import FloatList
3032
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
31-
from feast.types import Array, Float32
33+
from feast.types import Array, Float32, Int64
3234
from feast.utils import _utc_now
3335
from feast.value_type import ValueType
3436
from tests.universal.feature_repos.universal.online_store.scylladb import (
@@ -85,7 +87,7 @@ def _make_entity_key(val: str) -> EntityKeyProto:
8587
)
8688

8789

88-
def _make_feature_view(name: str, with_vector: bool = False) -> FeatureView:
90+
def _make_feature_view(name: str, with_vector: bool = False, ttl: timedelta = None) -> FeatureView:
8991
source = FileSource(path="dummy.parquet", timestamp_field="event_timestamp")
9092
schema: List[Field] = [Field(name="score", dtype=Array(Float32))]
9193
if with_vector:
@@ -108,6 +110,7 @@ def _make_feature_view(name: str, with_vector: bool = False) -> FeatureView:
108110
schema=schema,
109111
online=True,
110112
source=source,
113+
ttl=ttl,
111114
)
112115

113116

@@ -142,6 +145,7 @@ def test_write_and_read(docker_config):
142145
ts, feats = results[0]
143146
assert feats is not None
144147
assert "score" in feats
148+
assert list(feats["score"].float_list_val.val) == pytest.approx([0.9])
145149
finally:
146150
store.teardown(cfg, [fv], [])
147151

@@ -176,14 +180,155 @@ def test_missing_key_returns_none(docker_config):
176180
store.teardown(cfg, [fv], [])
177181

178182

183+
def test_multiple_features_roundtrip(docker_config):
184+
"""Multiple features of different types all round-trip with the correct value.
185+
"""
186+
store = ScyllaDBOnlineStore()
187+
cfg = docker_config
188+
source = FileSource(path="dummy.parquet", timestamp_field="event_timestamp")
189+
fv = FeatureView(
190+
name="test_multi_features",
191+
entities=[Entity(name="item_id", join_keys=["item_id"], value_type=ValueType.STRING)],
192+
schema=[
193+
Field(name="score", dtype=Array(Float32)),
194+
Field(name="priority", dtype=Int64),
195+
],
196+
online=True,
197+
source=source,
198+
)
199+
200+
store.update(cfg, [], [fv], [], [], partial=False)
201+
try:
202+
ek = _make_entity_key("item_mf")
203+
store.online_write_batch(
204+
cfg,
205+
fv,
206+
[
207+
(
208+
ek,
209+
{
210+
"score": ValueProto(float_list_val=FloatList(val=[0.85, 0.15])),
211+
"priority": ValueProto(int64_val=42),
212+
},
213+
_utc_now(),
214+
None,
215+
)
216+
],
217+
None,
218+
)
219+
results = store.online_read(cfg, fv, [ek])
220+
assert len(results) == 1
221+
ts, feats = results[0]
222+
assert feats is not None
223+
assert list(feats["score"].float_list_val.val) == pytest.approx([0.85, 0.15])
224+
assert feats["priority"].int64_val == 42
225+
finally:
226+
store.teardown(cfg, [fv], [])
227+
228+
229+
def test_multiple_entities(docker_config):
230+
"""Multiple entity keys can be read in a single online_read call with correct values.
231+
232+
Mirrors the multi-entity behaviour verified by the universal online store
233+
suite (``test_online_retrieval_with_event_timestamps``) which Cassandra runs.
234+
"""
235+
store = ScyllaDBOnlineStore()
236+
cfg = docker_config
237+
fv = _make_feature_view("test_multi_entities")
238+
239+
store.update(cfg, [], [fv], [], [], partial=False)
240+
try:
241+
keys = [_make_entity_key(f"entity_{i}") for i in range(3)]
242+
batch = [
243+
(
244+
ek,
245+
{"score": ValueProto(float_list_val=FloatList(val=[float(i) * 0.1 + 0.1]))},
246+
_utc_now(),
247+
None,
248+
)
249+
for i, ek in enumerate(keys)
250+
]
251+
store.online_write_batch(cfg, fv, batch, None)
252+
results = store.online_read(cfg, fv, keys)
253+
assert len(results) == 3
254+
for i, (ts, feats) in enumerate(results):
255+
assert feats is not None, f"Entity {i} returned None features"
256+
assert list(feats["score"].float_list_val.val) == pytest.approx(
257+
[float(i) * 0.1 + 0.1]
258+
)
259+
finally:
260+
store.teardown(cfg, [fv], [])
261+
262+
263+
def test_overwrite_uses_latest_value(docker_config):
264+
"""Writing the same entity key twice keeps the most-recently-written value."""
265+
store = ScyllaDBOnlineStore()
266+
cfg = docker_config
267+
fv = _make_feature_view("test_overwrite")
268+
269+
store.update(cfg, [], [fv], [], [], partial=False)
270+
try:
271+
ek = _make_entity_key("overwrite_item")
272+
store.online_write_batch(
273+
cfg,
274+
fv,
275+
[(ek, {"score": ValueProto(float_list_val=FloatList(val=[0.1]))}, _utc_now(), None)],
276+
None,
277+
)
278+
store.online_write_batch(
279+
cfg,
280+
fv,
281+
[(ek, {"score": ValueProto(float_list_val=FloatList(val=[0.9]))}, _utc_now(), None)],
282+
None,
283+
)
284+
results = store.online_read(cfg, fv, [ek])
285+
assert len(results) == 1
286+
ts, feats = results[0]
287+
assert feats is not None
288+
assert list(feats["score"].float_list_val.val) == pytest.approx([0.9])
289+
finally:
290+
store.teardown(cfg, [fv], [])
291+
292+
293+
def test_event_timestamp_returned(docker_config):
294+
"""The event timestamp written with a row is returned correctly by online_read.
295+
296+
Mirrors ``test_online_retrieval_with_event_timestamps`` from the universal
297+
suite which verifies per-entity timestamps for all online store types.
298+
"""
299+
store = ScyllaDBOnlineStore()
300+
cfg = docker_config
301+
fv = _make_feature_view("test_event_ts")
302+
303+
store.update(cfg, [], [fv], [], [], partial=False)
304+
try:
305+
ek = _make_entity_key("ts_item")
306+
# Use a second-boundary timestamp — CQL ``timestamp`` has ms precision.
307+
write_ts = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
308+
store.online_write_batch(
309+
cfg,
310+
fv,
311+
[(ek, {"score": ValueProto(float_list_val=FloatList(val=[0.5]))}, write_ts, None)],
312+
None,
313+
)
314+
results = store.online_read(cfg, fv, [ek])
315+
assert len(results) == 1
316+
ts, feats = results[0]
317+
assert ts is not None
318+
assert isinstance(ts, datetime)
319+
# Normalise both to UTC before comparing.
320+
ts_utc = ts if ts.tzinfo is not None else ts.replace(tzinfo=timezone.utc)
321+
assert ts_utc == write_ts
322+
finally:
323+
store.teardown(cfg, [fv], [])
324+
325+
179326
# ---------------------------------------------------------------------------
180327
# Tests — vector search (local Docker stack via docker_config)
181328
# ---------------------------------------------------------------------------
182329

183330

184331
def test_vector_search(docker_config):
185-
import time
186-
187332
store = ScyllaDBOnlineStore()
188333
cfg = docker_config
189334
fv = _make_feature_view("test_vector_search", with_vector=True)
@@ -236,14 +381,28 @@ def test_vector_search(docker_config):
236381
assert len(results) == 2
237382

238383
# Extract entity IDs from the returned entity key protos.
384+
# Also verify the shape and types of every field in each result tuple.
385+
# The expected embeddings keyed by entity ID for value verification.
386+
expected_embeddings = {
387+
"vec_a": [1.0, 0.0, 0.0, 0.0],
388+
"vec_b": [0.0, 1.0, 0.0, 0.0],
389+
"vec_c": [1.0, 0.1, 0.0, 0.0],
390+
}
239391
returned_ids = []
240392
for ts, ek_proto, feats in results:
241-
assert ts is not None
393+
assert isinstance(ts, datetime), f"Expected datetime, got {type(ts)}"
242394
assert feats is not None
243395
assert "score" in feats
244396
assert "embedding" in feats
397+
# score field: single-element float list written as 0.5
398+
assert list(feats["score"].float_list_val.val) == pytest.approx([0.5])
399+
# embedding field: values must match what was written
400+
entity_id = ek_proto.entity_values[0].string_val
401+
assert list(feats["embedding"].float_list_val.val) == pytest.approx(
402+
expected_embeddings[entity_id]
403+
)
245404
assert ek_proto is not None
246-
returned_ids.append(ek_proto.entity_values[0].string_val)
405+
returned_ids.append(entity_id)
247406

248407
# Query is [1,0,0,0]; vec_a=[1,0,0,0] (exact match) and
249408
# vec_c=[1,0.1,0,0] are the two nearest neighbours by cosine similarity.
@@ -257,3 +416,47 @@ def test_vector_search(docker_config):
257416
)
258417
finally:
259418
store.teardown(cfg, [fv], [])
419+
420+
421+
def test_ttl_expiry(docker_config):
422+
"""Rows written with a TTL should be gone after ScyllaDB expires them."""
423+
store = ScyllaDBOnlineStore()
424+
cfg = docker_config
425+
# TTL of 2 seconds, short enough for a test, long enough to write first.
426+
fv = _make_feature_view("test_ttl_expiry", ttl=timedelta(seconds=5))
427+
428+
store.update(cfg, [], [fv], [], [], partial=False)
429+
try:
430+
ek = _make_entity_key("ttl_item")
431+
store.online_write_batch(
432+
cfg,
433+
fv,
434+
[
435+
(
436+
ek,
437+
{"score": ValueProto(float_list_val=FloatList(val=[0.7]))},
438+
_utc_now(),
439+
None,
440+
)
441+
],
442+
None,
443+
)
444+
445+
# Confirm the row is readable immediately after writing.
446+
results = store.online_read(cfg, fv, [ek])
447+
assert len(results) == 1
448+
ts, feats = results[0]
449+
assert feats is not None, "Row should be present right after write"
450+
451+
# Wait for ScyllaDB to expire the row (TTL = 5s, wait 8s to be safe).
452+
time.sleep(8)
453+
454+
results = store.online_read(cfg, fv, [ek])
455+
assert len(results) == 1
456+
ts_after, feats_after = results[0]
457+
assert feats_after is None, (
458+
"Row should have expired and return None, "
459+
f"but got features: {feats_after}"
460+
)
461+
finally:
462+
store.teardown(cfg, [fv], [])

0 commit comments

Comments
 (0)