diff --git a/sdk/python/feast/infra/key_encoding_utils.py b/sdk/python/feast/infra/key_encoding_utils.py index 1f9ffeef140..0966d13d770 100644 --- a/sdk/python/feast/infra/key_encoding_utils.py +++ b/sdk/python/feast/infra/key_encoding_utils.py @@ -19,6 +19,12 @@ def _serialize_val( if 0 <= entity_key_serialization_version <= 1: return struct.pack(" ValueProto: return ValueProto(string_val=value) elif value_type == ValueType.BYTES: return ValueProto(bytes_val=value_bytes) + elif value_type == ValueType.UNIX_TIMESTAMP: + value = struct.unpack(" Dict[str, str]: + self.container.start() + log_string_to_wait_for = "Server initialized" + wait_for_logs( + container=self.container, predicate=log_string_to_wait_for, timeout=120 + ) + host = self.container.get_container_host_ip() + exposed_port = int(self.container.get_exposed_port(self.container.port)) + connection_string = f"{host}:{exposed_port}" + print(f"connection_string: {connection_string}") + return { + "connection_string": connection_string, + } + + def teardown(self): + self.container.stop() diff --git a/sdk/python/tests/unit/infra/online_store/test_redis.py b/sdk/python/tests/unit/infra/online_store/test_redis.py index c26c2f25c5f..3ec2c196f95 100644 --- a/sdk/python/tests/unit/infra/online_store/test_redis.py +++ b/sdk/python/tests/unit/infra/online_store/test_redis.py @@ -1,11 +1,24 @@ +from datetime import datetime, timedelta + import pytest from google.protobuf.timestamp_pb2 import Timestamp +from redis import Redis -from feast import Entity, FeatureView, Field, FileSource, RepoConfig -from feast.infra.online_stores.redis import RedisOnlineStore +from feast import Entity, FeatureView, Field, FileSource, RepoConfig, ValueType +from feast.infra.online_stores.helpers import _mmh3, _redis_key +from feast.infra.online_stores.redis import RedisOnlineStore, RedisOnlineStoreConfig +from feast.protos.feast.core.SortedFeatureView_pb2 import SortOrder from feast.protos.feast.types.EntityKey_pb2 import EntityKey as 
# NOTE(review): this span was a whitespace-collapsed patch of test_redis.py.
# It is rendered here as the post-patch Python. The first import below was
# split across the preceding line of the collapsed text and is restated whole.
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.sorted_feature_view import SortedFeatureView, SortKey
from feast.types import (
    Float32,
    Int32,
    UnixTimestamp,
)
from tests.unit.infra.online_store.redis_online_store_creator import (
    RedisOnlineStoreCreator,
)


@pytest.fixture
def redis_online_store() -> RedisOnlineStore:
    return RedisOnlineStore()


@pytest.fixture(scope="session")
def redis_online_store_config():
    """Start one Redis container for the whole session and yield its config.

    Yields:
        The dict returned by ``RedisOnlineStoreCreator.create_online_store``;
        the fixtures below read its ``"connection_string"`` entry.
    """
    creator = RedisOnlineStoreCreator("redis_project")
    try:
        # Creating the store inside the ``try`` means a failure while waiting
        # for the server to come up still reaches ``teardown``.
        yield creator.create_online_store()
    finally:
        creator.teardown()


@pytest.fixture
def repo_config(redis_online_store_config):
    """Build a local ``RepoConfig`` wired to the session's Redis instance."""
    return RepoConfig(
        provider="local",
        project="test",
        online_store=RedisOnlineStoreConfig(
            connection_string=redis_online_store_config["connection_string"],
        ),
        entity_key_serialization_version=2,
        registry="dummy_registry.db",
    )


# NOTE(review): the patch elides the unchanged middle of the module at this
# point (it resumes at the last three assertions of
# ``test_get_features_for_entity``). That code is not visible here and is not
# reproduced; only the code added after it follows.


def test_redis_online_write_batch_with_timestamp_as_sortkey(
    repo_config: RepoConfig,
    redis_online_store: RedisOnlineStore,
):
    """Write rows sorted by ``event_timestamp`` and read back the newest three
    trips per driver straight from Redis."""
    feature_view, data = _create_sorted_feature_view_with_timestamp_as_sortkey()

    redis_online_store.online_write_batch(
        config=repo_config,
        table=feature_view,
        data=data,
        progress=None,
    )

    client = _redis_client(repo_config)
    try:
        zset_keys = [
            _driver_zset_key(repo_config, feature_view, driver_id)
            for driver_id in (1, 2)
        ]

        # _make_rows assigns five rows to each driver.
        for zset_key in zset_keys:
            assert len(client.zrange(zset_key, 0, -1, withscores=True)) == 5

        # Get last 3 trips for both drivers from the respective sorted sets
        # (highest score first), driver 1 followed by driver 2.
        latest_members = [
            member
            for zset_key in zset_keys
            for member in client.zrevrangebyscore(
                zset_key, "+inf", "-inf", start=0, num=3
            )
        ]

        trip_ids = _fetch_trip_ids(client, feature_view, latest_members)
        assert trip_ids == [4, 3, 2, 9, 8, 7]
    finally:
        client.close()


def test_redis_online_write_batch_with_float_as_sortkey(
    repo_config: RepoConfig,
    redis_online_store: RedisOnlineStore,
):
    """Write rows sorted by ``rating`` and read back trips by score range."""
    feature_view, data = _create_sorted_feature_view_with_float_as_sortkey()

    redis_online_store.online_write_batch(
        config=repo_config,
        table=feature_view,
        data=data,
        progress=None,
    )

    client = _redis_client(repo_config)
    try:
        zset_key_driver_1 = _driver_zset_key(repo_config, feature_view, 1)
        zset_key_driver_2 = _driver_zset_key(repo_config, feature_view, 2)

        # _make_rows assigns five rows to each driver.
        for zset_key in (zset_key_driver_1, zset_key_driver_2):
            assert len(client.zrange(zset_key, 0, -1, withscores=True)) == 5

        # Get trips for driver 1 where ratings between 2.5 and 4.5
        # Get trips for driver 2 where ratings between 7.5 and 9.5
        rated_members = [
            *client.zrangebyscore(zset_key_driver_1, 2.5, 4.5),
            *client.zrangebyscore(zset_key_driver_2, 7.5, 9.5),
        ]

        trip_ids = _fetch_trip_ids(client, feature_view, rated_members)
        assert trip_ids == [2, 3, 4, 7, 8, 9]
    finally:
        client.close()


def _redis_client(repo_config: RepoConfig) -> Redis:
    """Open a direct Redis connection from the repo's ``host:port`` string.

    The caller owns the returned client and must ``close()`` it.
    """
    # assumes a plain "host:port" connection string, which is what the
    # session fixture's creator builds -- TODO confirm if options are appended
    host, _, port = repo_config.online_store.connection_string.partition(":")
    return Redis(host=host, port=int(port))


def _driver_zset_key(
    repo_config: RepoConfig, feature_view: SortedFeatureView, driver_id: int
) -> str:
    """Return the sorted-set key these tests look up for one ``driver_id``.

    Layout: ``<project>:<view name>:<first sort key name>:<entity redis key>``.
    """
    entity_key = EntityKeyProto(
        join_keys=["driver_id"],
        entity_values=[ValueProto(int32_val=driver_id)],
    )
    redis_key_bin = _redis_key(
        repo_config.project,
        entity_key,
        entity_key_serialization_version=repo_config.entity_key_serialization_version,
    )
    # ``redis_key_bin`` is interpolated as-is (its ``str()`` form), exactly as
    # the original tests did; presumably this matches what the store writes.
    return (
        f"{repo_config.project}:{feature_view.name}:"
        f"{feature_view.sort_keys[0].name}:{redis_key_bin}"
    )


def _fetch_trip_ids(client: Redis, feature_view: SortedFeatureView, member_keys):
    """Return the ``trip_id`` stored in the hash behind each sorted-set member.

    Args:
        client: Open Redis client.
        feature_view: View whose name prefixes the hashed feature field.
        member_keys: Sorted-set members; each is used as the key of a hash.

    Returns:
        A list of ``int32`` trip ids, in the order of ``member_keys``.
    """
    # One round trip: queue an HGETALL per member, then execute together.
    with client.pipeline(transaction=True) as pipe:
        for member_key in member_keys:
            pipe.hgetall(member_key)
        feature_hashes = pipe.execute()

    trip_id_feature_name = _mmh3(f"{feature_view.name}:trip_id")
    trip_ids = []
    for feature_hash in feature_hashes:
        val = ValueProto()
        val.ParseFromString(feature_hash[trip_id_feature_name])
        trip_ids.append(val.int32_val)
    return trip_ids


def _driver_stats_view(sort_key: SortKey) -> SortedFeatureView:
    """Build the ``driver_stats`` sorted feature view with one sort key."""
    return SortedFeatureView(
        name="driver_stats",
        source=FileSource(
            name="my_file_source",
            path="test.parquet",
            timestamp_field="event_timestamp",
        ),
        entities=[Entity(name="driver_id")],
        ttl=timedelta(seconds=10),
        sort_keys=[sort_key],
        schema=[
            Field(name="driver_id", dtype=Int32),
            Field(name="event_timestamp", dtype=UnixTimestamp),
            Field(name="trip_id", dtype=Int32),
            Field(name="rating", dtype=Float32),
        ],
    )


def _create_sorted_feature_view_with_timestamp_as_sortkey(n=10):
    """Return ``(feature_view, rows)`` sorted by ``event_timestamp`` (DESC).

    Args:
        n: Number of rows to generate (see ``_make_rows``).
    """
    fv = _driver_stats_view(
        SortKey(
            name="event_timestamp",
            value_type=ValueType.UNIX_TIMESTAMP,
            default_sort_order=SortOrder.DESC,
        )
    )
    return fv, _make_rows(n)


def _create_sorted_feature_view_with_float_as_sortkey(n=10):
    """Return ``(feature_view, rows)`` sorted by ``rating`` (DESC).

    Args:
        n: Number of rows to generate (see ``_make_rows``).
    """
    fv = _driver_stats_view(
        SortKey(
            name="rating",
            value_type=ValueType.FLOAT,
            default_sort_order=SortOrder.DESC,
        )
    )
    # ``n`` was previously accepted but never forwarded.
    return fv, _make_rows(n)


def _make_rows(n=10):
    """Generate ``n`` rows for ``online_write_batch``.

    Row ``i`` (0-based) belongs to ``driver_id`` 1 when ``i <= 4`` and to
    ``driver_id`` 2 otherwise. It has ``trip_id = i``, ``rating = i + 0.5``
    and an ``event_timestamp`` of ``(now - 15 minutes) + i minutes``.

    Returns:
        A list of ``(entity_key, features, timestamp, None)`` tuples.
    """
    # Local import so this helper does not depend on the module's import
    # block, which lies outside the edited span.
    from datetime import timezone

    # Read the clock once, timezone-aware. ``.timestamp()`` on a naive
    # ``utcnow()`` value is interpreted as local time and skews the epoch
    # seconds on hosts that are not running in UTC.
    now = datetime.now(tz=timezone.utc)
    window_start = now - timedelta(minutes=15)
    # The per-row write timestamp stays naive UTC, the same value
    # ``datetime.utcnow()`` produced before.
    written_at = now.replace(tzinfo=None)

    return [
        (
            EntityKeyProto(
                join_keys=["driver_id"],
                entity_values=[ValueProto(int32_val=1 if i <= 4 else 2)],
            ),
            {
                "trip_id": ValueProto(int32_val=i),
                "rating": ValueProto(float_val=i + 0.5),
                "event_timestamp": ValueProto(
                    unix_timestamp_val=int(
                        (window_start + timedelta(minutes=i)).timestamp()
                    )
                ),
            },
            written_at,
            None,
        )
        for i in range(n)
    ]