diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
index bede2a6f44c..3fc675ea402 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -33,6 +33,7 @@
 from pyspark.sql import SparkSession
 
 from feast import FeatureView, OnDemandFeatureView
+from feast.batch_feature_view import BatchFeatureView
 from feast.data_source import DataSource
 from feast.dataframe import DataFrameEngine, FeastDataFrame
 from feast.errors import EntitySQLEmptyResults, InvalidEntityType
@@ -260,6 +261,10 @@ def get_historical_features(
         entity_df_event_timestamp_range,
     )
 
+    query_context = _apply_bfv_transformations(
+        spark_session, feature_views, query_context
+    )
+
     spark_query_context = [
         SparkFeatureViewQueryContext(
             **asdict(context),
@@ -713,6 +718,62 @@ def _entity_schema_keys_from(
     )
 
+def _apply_bfv_transformations(
+    spark_session: SparkSession,
+    feature_views: List[FeatureView],
+    query_contexts: List[offline_utils.FeatureViewQueryContext],
+) -> List[offline_utils.FeatureViewQueryContext]:
+    """
+    For BatchFeatureViews with a UDF, read the raw source into a Spark DataFrame,
+    invoke the transformation, register the result as a temp view, and replace the
+    table_subquery in the query context so the PIT join reads transformed data.
+    """
+    from dataclasses import replace
+
+    from feast.feature_view_utils import (
+        get_transformation_function,
+        has_transformation,
+        resolve_feature_view_source_with_fallback,
+    )
+
+    fv_by_name = {fv.projection.name_to_use(): fv for fv in feature_views}
+
+    updated_contexts = []
+    for ctx in query_contexts:
+        fv = fv_by_name.get(ctx.name)
+        if (
+            fv is not None
+            and isinstance(fv, BatchFeatureView)
+            and has_transformation(fv)
+        ):
+            udf = get_transformation_function(fv)
+            if udf is not None:
+                source_info = resolve_feature_view_source_with_fallback(fv)
+                source_query = source_info.data_source.get_table_query_string()
+
+                timestamp_filter = get_timestamp_filter_sql(
+                    start_date=ctx.min_event_timestamp,
+                    end_date=ctx.max_event_timestamp,
+                    timestamp_field=ctx.timestamp_field,
+                    tz=timezone.utc,
+                    quote_fields=False,
+                )
+                source_df = spark_session.sql(
+                    f"SELECT * FROM {source_query} WHERE {timestamp_filter}"
+                )
+
+                transformed_df = udf(source_df)
+
+                tmp_view_name = "feast_bfv_" + uuid.uuid4().hex
+                transformed_df.createOrReplaceTempView(tmp_view_name)
+
+                ctx = replace(ctx, table_subquery=tmp_view_name)
+
+        updated_contexts.append(ctx)
+
+    return updated_contexts
+
 
 def _get_entity_df_event_timestamp_range(
     entity_df: Union[pd.DataFrame, str],
     entity_df_event_timestamp_col: str,
diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_bfv_compute_on_read.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_bfv_compute_on_read.py
new file mode 100644
index 00000000000..0bcc282ae83
--- /dev/null
+++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_bfv_compute_on_read.py
@@ -0,0 +1,222 @@
+"""
+Unit tests for BFV compute-on-read in SparkOfflineStore.get_historical_features().
+
+Verifies that BatchFeatureViews with a UDF have their transformation applied
+during get_historical_features(), with the transformed DataFrame registered as
+a temp view that replaces the raw table_subquery in the PIT join.
+"""
+
+from dataclasses import replace
+from unittest.mock import MagicMock
+
+import pytest
+
+from feast.batch_feature_view import BatchFeatureView
+from feast.feature_view import FeatureView
+from feast.infra.offline_stores import offline_utils
+from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
+    _apply_bfv_transformations,
+)
+from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
+    SparkSource,
+)
+from feast.transformation.base import Transformation
+
+
+@pytest.fixture()
+def spark_session():
+    mock = MagicMock()
+    mock.sql.return_value = MagicMock(name="source_df")
+    return mock
+
+
+@pytest.fixture()
+def spark_source():
+    source = MagicMock(spec=SparkSource)
+    source.get_table_query_string.return_value = "`raw_events`"
+    return source
+
+
+@pytest.fixture()
+def base_query_context():
+    return offline_utils.FeatureViewQueryContext(
+        name="my_bfv",
+        ttl=3600,
+        entities=["user_id"],
+        features=["avg_rating"],
+        field_mapping={},
+        timestamp_field="event_timestamp",
+        created_timestamp_column=None,
+        table_subquery="`raw_events`",
+        entity_selections=["user_id AS user_id"],
+        min_event_timestamp="2023-01-01T00:00:00",
+        max_event_timestamp="2024-01-01T00:00:00",
+        date_partition_column=None,
+    )
+
+
+def _make_bfv(name: str, spark_source, has_udf: bool = True):
+    """Create a mock BatchFeatureView with optional UDF."""
+    fv = MagicMock(spec=BatchFeatureView)
+    fv.name = name
+    fv.projection = MagicMock()
+    fv.projection.name_to_use.return_value = name
+    fv.batch_source = spark_source
+    fv.source_views = []
+
+    if has_udf:
+        transformation = MagicMock(spec=Transformation)
+        transformed_df = MagicMock(name="transformed_df")
+        transformation.udf = MagicMock(return_value=transformed_df)
+        fv.feature_transformation = transformation
+        fv.udf = transformation.udf
+    else:
+        fv.feature_transformation = None
+        fv.udf = None
+
+    return fv
+
+
+def _make_plain_fv(name: str, spark_source):
+    """Create a mock plain FeatureView (not a BatchFeatureView)."""
+    fv = MagicMock(spec=FeatureView)
+    fv.name = name
+    fv.projection = MagicMock()
+    fv.projection.name_to_use.return_value = name
+    fv.batch_source = spark_source
+    fv.feature_transformation = None
+    fv.udf = None
+    return fv
+
+
+class TestApplyBfvTransformations:
+    def test_bfv_with_udf_replaces_table_subquery(
+        self, spark_session, spark_source, base_query_context
+    ):
+        """BFV with a UDF should have its table_subquery replaced with a temp view."""
+        bfv = _make_bfv("my_bfv", spark_source)
+        contexts = [base_query_context]
+
+        result = _apply_bfv_transformations(spark_session, [bfv], contexts)
+
+        assert len(result) == 1
+        assert result[0].table_subquery != "`raw_events`"
+        assert result[0].table_subquery.startswith("feast_bfv_")
+
+    def test_bfv_udf_is_invoked_with_source_df(
+        self, spark_session, spark_source, base_query_context
+    ):
+        """The UDF should be called with the DataFrame read from the raw source."""
+        bfv = _make_bfv("my_bfv", spark_source)
+        contexts = [base_query_context]
+
+        _apply_bfv_transformations(spark_session, [bfv], contexts)
+
+        sql_arg = spark_session.sql.call_args[0][0]
+        assert "SELECT * FROM `raw_events`" in sql_arg
+        assert "WHERE" in sql_arg
+        source_df = spark_session.sql.return_value
+        bfv.feature_transformation.udf.assert_called_once_with(source_df)
+
+    def test_transformed_df_registered_as_temp_view(
+        self, spark_session, spark_source, base_query_context
+    ):
+        """The transformed DataFrame should be registered as a temp view."""
+        bfv = _make_bfv("my_bfv", spark_source)
+        transformed_df = bfv.feature_transformation.udf.return_value
+        contexts = [base_query_context]
+
+        result = _apply_bfv_transformations(spark_session, [bfv], contexts)
+
+        transformed_df.createOrReplaceTempView.assert_called_once()
+        view_name = transformed_df.createOrReplaceTempView.call_args[0][0]
+        assert view_name == result[0].table_subquery
+
+    def test_plain_feature_view_unchanged(
+        self, spark_session, spark_source, base_query_context
+    ):
+        """Plain FeatureViews (not BFV) should pass through without modification."""
+        fv = _make_plain_fv("my_bfv", spark_source)
+        contexts = [base_query_context]
+
+        result = _apply_bfv_transformations(spark_session, [fv], contexts)
+
+        assert result[0].table_subquery == "`raw_events`"
+        spark_session.sql.assert_not_called()
+
+    def test_bfv_without_udf_unchanged(
+        self, spark_session, spark_source, base_query_context
+    ):
+        """BFV without a UDF should pass through without modification."""
+        bfv = _make_bfv("my_bfv", spark_source, has_udf=False)
+        contexts = [base_query_context]
+
+        result = _apply_bfv_transformations(spark_session, [bfv], contexts)
+
+        assert result[0].table_subquery == "`raw_events`"
+        spark_session.sql.assert_not_called()
+
+    def test_mixed_views_only_transforms_bfvs(
+        self, spark_session, spark_source, base_query_context
+    ):
+        """With mixed BFV and plain FVs, only BFVs with UDFs get transformed."""
+        bfv = _make_bfv("my_bfv", spark_source)
+        plain_fv = _make_plain_fv("plain_fv", spark_source)
+
+        ctx_bfv = base_query_context
+        ctx_plain = replace(
+            base_query_context,
+            name="plain_fv",
+            features=["some_feature"],
+        )
+
+        result = _apply_bfv_transformations(
+            spark_session, [bfv, plain_fv], [ctx_bfv, ctx_plain]
+        )
+
+        assert result[0].table_subquery.startswith("feast_bfv_")
+        assert result[1].table_subquery == "`raw_events`"
+
+    def test_time_range_filter_applied(
+        self, spark_session, spark_source, base_query_context
+    ):
+        """Source query should include time bounds from the context."""
+        bfv = _make_bfv("my_bfv", spark_source)
+        contexts = [base_query_context]
+
+        _apply_bfv_transformations(spark_session, [bfv], contexts)
+
+        sql_arg = spark_session.sql.call_args[0][0]
+        assert "2023-01-01" in sql_arg
+        assert "2024-01-01" in sql_arg
+        assert "event_timestamp" in sql_arg
+
+    def test_time_range_filter_with_none_min_timestamp(
+        self, spark_session, spark_source, base_query_context
+    ):
+        """When min_event_timestamp is None (no TTL), query should still work."""
+        bfv = _make_bfv("my_bfv", spark_source)
+        ctx = replace(base_query_context, min_event_timestamp=None)
+
+        result = _apply_bfv_transformations(spark_session, [bfv], [ctx])
+
+        assert result[0].table_subquery.startswith("feast_bfv_")
+        sql_arg = spark_session.sql.call_args[0][0]
+        assert "2024-01-01" in sql_arg
+
+    def test_other_context_fields_preserved(
+        self, spark_session, spark_source, base_query_context
+    ):
+        """All fields besides table_subquery should remain unchanged."""
+        bfv = _make_bfv("my_bfv", spark_source)
+        contexts = [base_query_context]
+
+        result = _apply_bfv_transformations(spark_session, [bfv], contexts)
+
+        assert result[0].name == base_query_context.name
+        assert result[0].ttl == base_query_context.ttl
+        assert result[0].entities == base_query_context.entities
+        assert result[0].features == base_query_context.features
+        assert result[0].timestamp_field == base_query_context.timestamp_field
+        assert result[0].min_event_timestamp == base_query_context.min_event_timestamp
+        assert result[0].max_event_timestamp == base_query_context.max_event_timestamp