Skip to content

Commit 2a5cac7

Browse files
committed
fix-auto-log
Signed-off-by: Vanshika Vanshika <vvanshik@redhat.com> rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED
1 parent f18297f commit 2a5cac7

4 files changed

Lines changed: 98 additions & 35 deletions

File tree

sdk/python/feast/feature_store.py

Lines changed: 89 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -311,13 +311,10 @@ def _resolve_feature_service_name(self, feature_refs: List[str]) -> Optional[str
311311
_logger.debug("Failed to resolve feature service name: %s", e)
312312
return None
313313

314-
def _auto_log_entity_df_info(self, entity_df, start_date=None, end_date=None):
315-
"""Log entity_df info to MLflow for reproducibility.
314+
def _log_entity_df_metadata(self, entity_df, start_date=None, end_date=None):
315+
"""Log lightweight entity_df metadata to MLflow (type, row count, columns, query, dates).
316316
317-
Handles three entity_df types:
318-
- pd.DataFrame: saves metadata + full parquet artifact (within configured limit)
319-
- str (SQL query): logs the query as a param
320-
- None (range-based): logs start_date/end_date
317+
Always called during historical retrieval regardless of auto_log_entity_df.
321318
"""
322319
try:
323320
import mlflow
@@ -350,15 +347,6 @@ def _auto_log_entity_df_info(self, entity_df, start_date=None, end_date=None):
350347
cols = cols[:MLFLOW_PARAM_TRUNCATION_SLICE] + "..."
351348
client.log_param(run_id, "feast.entity_df_columns", cols)
352349

353-
max_rows = mlflow_cfg.entity_df_max_rows
354-
if len(entity_df) <= max_rows:
355-
import tempfile
356-
357-
with tempfile.TemporaryDirectory() as tmp_dir:
358-
path = os.path.join(tmp_dir, "entity_df.parquet")
359-
entity_df.to_parquet(path, index=False)
360-
client.log_artifact(run_id, path)
361-
362350
elif entity_df is None and (start_date or end_date):
363351
client.set_tag(run_id, "feast.entity_df_type", "range")
364352
if start_date:
@@ -367,7 +355,38 @@ def _auto_log_entity_df_info(self, entity_df, start_date=None, end_date=None):
367355
client.log_param(run_id, "feast.end_date", str(end_date))
368356

369357
except Exception as e:
370-
_logger.debug("Failed to log entity_df info to MLflow: %s", e)
358+
_logger.debug("Failed to log entity_df metadata to MLflow: %s", e)
359+
360+
def _log_entity_df_artifact(self, entity_df):
    """Upload the entity DataFrame to the active MLflow run as a parquet artifact.

    This is a best-effort operation. Nothing is uploaded when there is no
    active MLflow run, when ``entity_df`` is not a ``pd.DataFrame``, or when
    it has more rows than ``self.config.mlflow.entity_df_max_rows``. Any
    exception raised along the way is logged at debug level and suppressed.

    Args:
        entity_df: Entity data supplied to historical retrieval. Only
            DataFrame inputs are uploaded; other values are ignored.

    Returns:
        None.
    """
    try:
        import mlflow

        # Look the run up once so the None check and the run_id read below
        # are guaranteed to refer to the same run object.
        active_run = mlflow.active_run()
        if active_run is None or not isinstance(entity_df, pd.DataFrame):
            return

        mlflow_cfg = self.config.mlflow

        # Enforce the row limit before resolving the tracking URI or
        # building a client, so oversized frames do no MLflow work at all.
        if len(entity_df) > mlflow_cfg.entity_df_max_rows:
            return

        import tempfile

        client = mlflow.MlflowClient(tracking_uri=mlflow_cfg.get_tracking_uri())

        # The temporary directory (and the parquet file in it) is removed
        # when the with-block exits, whether or not the upload succeeded.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "entity_df.parquet")
            entity_df.to_parquet(path, index=False)
            client.log_artifact(active_run.info.run_id, path)

    except Exception as e:
        _logger.debug("Failed to log entity_df artifact to MLflow: %s", e)
371390

372391
def _init_openlineage_emitter(self) -> Optional[Any]:
373392
"""Initialize OpenLineage emitter if configured and enabled."""
@@ -1750,10 +1769,12 @@ def get_historical_features(
17501769
tracking_uri=self.config.mlflow.get_tracking_uri(),
17511770
)
17521771

1772+
self._log_entity_df_metadata(
1773+
entity_df, start_date=start_date, end_date=end_date
1774+
)
1775+
17531776
if self.config.mlflow.auto_log_entity_df:
1754-
self._auto_log_entity_df_info(
1755-
entity_df, start_date=start_date, end_date=end_date
1756-
)
1777+
self._log_entity_df_artifact(entity_df)
17571778
except Exception as e:
17581779
_logger.debug("MLflow auto-log failed for historical retrieval: %s", e)
17591780

@@ -3036,7 +3057,9 @@ async def get_online_features_async(
30363057
"""
30373058
provider = self._get_provider()
30383059

3039-
return await provider.get_online_features_async(
3060+
_retrieval_start = time.monotonic()
3061+
3062+
response = await provider.get_online_features_async(
30403063
config=self.config,
30413064
features=features,
30423065
entity_rows=entity_rows,
@@ -3046,6 +3069,52 @@ async def get_online_features_async(
30463069
include_feature_view_version_metadata=include_feature_view_version_metadata,
30473070
)
30483071

3072+
try:
3073+
if (
3074+
self.config.mlflow is not None
3075+
and self.config.mlflow.enabled
3076+
and self.config.mlflow.auto_log
3077+
):
3078+
_log_fn = _get_mlflow_log_fn()
3079+
if _log_fn is not None:
3080+
_duration = time.monotonic() - _retrieval_start
3081+
_feature_refs = utils._get_features(
3082+
self.registry, self.project, features, allow_cache=True
3083+
)
3084+
if isinstance(entity_rows, list):
3085+
_entity_count = len(entity_rows)
3086+
elif isinstance(entity_rows, Mapping):
3087+
try:
3088+
_first_col = next(iter(entity_rows.values()))
3089+
if isinstance(_first_col, RepeatedValue):
3090+
_entity_count = len(_first_col.val)
3091+
else:
3092+
_entity_count = len(_first_col)
3093+
except Exception:
3094+
_entity_count = 0
3095+
else:
3096+
_entity_count = 0
3097+
_fs = features if isinstance(features, FeatureService) else None
3098+
_fs_name = (
3099+
features.name
3100+
if isinstance(features, FeatureService)
3101+
else self._resolve_feature_service_name(_feature_refs)
3102+
)
3103+
_log_fn(
3104+
feature_refs=_feature_refs,
3105+
entity_count=_entity_count,
3106+
duration_seconds=_duration,
3107+
retrieval_type="online",
3108+
feature_service=_fs,
3109+
feature_service_name=_fs_name,
3110+
project=self.project,
3111+
tracking_uri=self.config.mlflow.get_tracking_uri(),
3112+
)
3113+
except Exception as e:
3114+
_logger.debug("MLflow auto-log failed for online retrieval: %s", e)
3115+
3116+
return response
3117+
30493118
def retrieve_online_documents(
30503119
self,
30513120
query: Union[str, List[float]],

sdk/python/feast/mlflow_integration/client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,7 @@ def _log_required_features(self):
144144
path = os.path.join(tmp_dir, "required_features.json")
145145
with open(path, "w") as f:
146146
json.dump(features, f)
147-
self._client.log_artifact(
148-
run.info.run_id, path, artifact_path=""
149-
)
147+
self._client.log_artifact(run.info.run_id, path, artifact_path="")
150148
except Exception as e:
151149
_logger.debug("Failed to log required_features.json: %s", e)
152150

sdk/python/feast/mlflow_integration/logger.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,7 @@ def log_apply_to_mlflow(
207207
elif isinstance(obj, Entity) and obj.name != "__dummy":
208208
entity_names.append(obj.name)
209209

210-
run = client.create_run(
211-
experiment_id, run_name=f"apply_{project}"
212-
)
210+
run = client.create_run(experiment_id, run_name=f"apply_{project}")
213211
run_id = run.info.run_id
214212
try:
215213
client.set_tag(run_id, "feast.operation", "apply")
@@ -236,9 +234,7 @@ def log_apply_to_mlflow(
236234
client.log_metric(
237235
run_id, "feast.apply.feature_services_count", len(fs_names)
238236
)
239-
client.log_metric(
240-
run_id, "feast.apply.entities_count", len(entity_names)
241-
)
237+
client.log_metric(run_id, "feast.apply.entities_count", len(entity_names))
242238
finally:
243239
client.set_terminated(run_id)
244240

@@ -276,9 +272,7 @@ def log_materialize_to_mlflow(
276272
experiment_id = _get_or_create_experiment(client, experiment_name)
277273

278274
op_type = "materialize_incremental" if incremental else "materialize"
279-
run = client.create_run(
280-
experiment_id, run_name=f"{op_type}_{project}"
281-
)
275+
run = client.create_run(experiment_id, run_name=f"{op_type}_{project}")
282276
run_id = run.info.run_id
283277
try:
284278
client.set_tag(run_id, "feast.operation", op_type)

sdk/python/tests/integration/test_mlflow_integration.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,11 +548,13 @@ def test_auto_log_entity_df_false_skips_artifact(
548548
entity_df=entity_df,
549549
).to_df()
550550

551+
run_data = client.get_run(run.info.run_id).data
551552
artifacts = [a.path for a in client.list_artifacts(run.info.run_id)]
552553
assert "entity_df.parquet" not in artifacts
553-
assert "feast.entity_df_rows" not in client.get_run(run.info.run_id).data.params
554-
tags = client.get_run(run.info.run_id).data.tags
555-
assert tags["feast.feature_service"] == "driver_activity_v1"
554+
555+
assert "feast.entity_df_rows" in run_data.params
556+
assert run_data.tags["feast.entity_df_type"] == "dataframe"
557+
assert run_data.tags["feast.feature_service"] == "driver_activity_v1"
556558

557559

558560
class TestEntityDfBuilder:

0 commit comments

Comments
 (0)