Skip to content

MLflow for Darts implementation#3022

Draft
jakubchlapek wants to merge 69 commits into
unit8co:masterfrom
jakubchlapek:feat/mlflow-base
Draft

MLflow for Darts implementation#3022
jakubchlapek wants to merge 69 commits into
unit8co:masterfrom
jakubchlapek:feat/mlflow-base

Conversation

@jakubchlapek
Copy link
Copy Markdown
Collaborator

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Addresses #2092 .

Summary

Provides a custom MLflow flavor for Darts on Darts' side. Supports autologging, logging, saving and loading of the models.
This PR focuses on the base MLflow integration, leaving serving of the models to be discussed in the future.

Included an example quickstart for the integration, however consider all of this a draft :)
Find example code in the .ipynb, however also providing a code snippet here as a quick reproducible example:

import mlflow
import tempfile
import os
from darts.metrics.metrics import smape
from darts.utils.mlflow import load_model, autolog
from darts.models import NBEATSModel, LinearRegressionModel
from darts.datasets import AirPassengersDataset
from torchmetrics import MeanAbsoluteError

# temp file setup
tmpdir = tempfile.mkdtemp()
mlflow_db = os.path.join(tmpdir, "mlflow.db")
mlflow.set_tracking_uri(f"sqlite:///{mlflow_db}")
mlflow.set_experiment("darts-forecasting")

train, val = AirPassengersDataset().load().astype("float32").split_before(0.7)

# autologging - patches .fit() on all ForecastingModel subclasses.
# for PyTorch-based models, inject_per_epoch_callbacks injects a Lightning callback
# that logs train/val loss or/and  user-specified torch metrics at the end of each epoch automatically.
autolog(
    log_models=True,
    log_params=True,
    log_training_metrics=True,
    log_validation_metrics=True,   # requires val_series in .fit()
    inject_per_epoch_callbacks=True, 
    extra_metrics=[smape],         # optional extra darts metric functions
)

with mlflow.start_run(run_name="nbeats") as run:
    model = NBEATSModel(
        input_chunk_length=24, 
        output_chunk_length=12,
        torch_metrics=MeanAbsoluteError())
    # val_series is forwarded to Lightning's val_dataloaders;
    # autolog captures per-epoch val metrics via the injected callback
    model.fit(train, val_series=val, epochs=10)
    run_id = run.info.run_id


# regression/sklearn models work identically
with mlflow.start_run(run_name="linreg"):
    model = LinearRegressionModel(lags=12)
    model.fit(train)  # logs params + in-sample metrics

# load back from MLflow
loaded = load_model(f"runs:/{run_id}/model")
preds = loaded.predict(12, series=train) # need to specify series as we save with clean=True in save_model

# import shutil
# shutil.rmtree(tmpdir)

@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@jakubchlapek
Copy link
Copy Markdown
Collaborator Author

Hey @daidahao, adding this draft PR in the meantime so you and @dennisbader can have a look at what I have currently regarding the integration. There are still some decisions I am not too thrilled about and decisions to be made about the overall direction, but I'm happy to talk more about it during the meeting. Thanks for being so active for the library, really nice to be working together :)

daidahao and others added 9 commits March 5, 2026 19:30
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
@daidahao
Copy link
Copy Markdown
Contributor

daidahao commented Mar 7, 2026

@mizeller @jakubchlapek @dennisbader

Greetings! I've addressed most of the comments here, except for a few discussion points.

I've left a TODO note on post-fitting metrics which, IMHO, are HARD to implement at this point due to how MLflow manages active runs in autolog context. In short, we would need to keep a mapping between MLflow run ids, fitted models, model predictions, and metrics, to ensure the metrics are logged under the right run id (see mlflow.sklearn).

Sincere apologies for suggesting post-fitting metrics in the first place! I didn't realise the complexity involved.

My suggestion is to skip post-fitting metrics for now or settle for compromises such as non-terminated active runs (at the risk of cross-logging).

Other than that, I am truly proud of what we have achieved here and will hand this over to @mizeller for runups and more great work.

@daidahao
Copy link
Copy Markdown
Contributor

@mizeller @jakubchlapek @dennisbader

Just thought of an easy way to track TimeSeries provenance without a mlflow.sklearn-style _AutologgingMetricsManager for metric auto-logging:

Since TimeSeries here is a custom container rather than generic numpy arrays in sklearn, we could write run_id into its metadata dict whenever they are generated from forecasts. There are different ways of doing that, either via patching predict(), historical_forecasts() etc., OR changing _build_forecast_series(), etc.

Logging metrics then becomes a lot easier. When patching darts.metrics, we only need to identify run_id in TimeSeries.metadata and log metrics to those runs accordingly.

What do you think?

@jakubchlapek
Copy link
Copy Markdown
Collaborator Author

Hi @daidahao, it's been a minute haha, Michel will be continuing on this, but I've found some time to checkout the changes so far. Thanks a lot for the review and your updates, they look good and make sense to me :).

Regarding the post-fitting metrics, I agree that the issue is more complicated than originally envisioned. I think the TimeSeries.metadata patching is a smart idea that would work, but I'm not a fan of mutating the TimeSeries. In my mind I'd like the objects to not be impacted by the logging (for example, that later we don't export the series with the run_id still attached). I would second leaving the solution as is i.e. logging to the current non-terminated active run, with the behavior documented well. I think the autolog even in the context block where it works correctly
e.g.

with mlflow.start_run(run_name="whatever"):
    # metrics should log here nicely

could bring nice value and be useful. Important that we document the risk of crosslogging for multiple active runs though.

Another issue that I'd like to focus on first to have in the merged PR would be the support for historical_forecasts(retrain=True # or int) on autologging.
In the current implementation, since we patch fit the hfc will start a run for each iteration, along with saving models and any artifacts. The two ideas that come to mind are to detect if the .fit was called from hfc, ignore any fits called from within the method, and get only the result, OR patch historical_forecasts directly to supress the autologging for the iterations however I haven't yet investigated deeper, so just brainstorming :).
Secondly, also stemming from the retrain issue, is that the current covariate saving methods are based on the model's past_covariate_series attribute. Since in historical_forecasts we currently train new internal models for each iteration and fit them, we also don't pass the correct flags to the final model, since it's not modfied.
code snippet from forecasting_model.py L1180-1190

if apply_retrain:
    # fit a new instance of the model
    model = model.untrained_model()
    model._fit_wrapper(
        series=train_series_tf,
        past_covariates=past_covariates_tf,
        future_covariates=future_covariates_tf,
        sample_weight=sample_weight_tf,
        val_series=val_series_tf,
        **fit_kwargs,
    )

This leads to the covariates not saving to covariates.json later. Not sure what the best course of action here is, but I think this could be easily solved by patching historical_forecasts and logging directly. I would prefer to avoid patching so many functions, but if it makes sense then it's fine.
Thanks a lot for all the work so far, I think we can soon have this done properly :)

@daidahao
Copy link
Copy Markdown
Contributor

@jakubchlapek

Thank you for the response here. For metric logging, I agree with your suggestion on leaving it as is at some risks of cross-logging models as long as those risks are clearly documented. In the long term, I would prefer a more robust solution using either a _AutologgingMetricsManager or metadata, while we continue exploring other options. As far as I could see, the key concerns of mutating TimeSeries would be the overheads, which can be kept to minimal by disabling copying, etc., while the run_id metadata is not overtly intrusive given the MLflow context. Is there anything else that I missed?

PS. Speaking of exporting, AFAIK, there is no official import/export functionality and file format for TimeSeries. I do wish there is a safe/easy way to do so without relying on pickle--Darts v0.34.0 last year broke the old TimeSeries in pickle format and we had to rerun all experiments at one point. :(

For the historical_forecasts() support, I also agree with your solutions here. However, if time is a concern here for @mizeller, I would not mind implementing the support in another PR and targeting a future release. So long as the current limitations are clearly documented and integration is marked as beta. It would be helpful to gather early feedback from Darts users, while basic integration (fit, predict, save, load, etc.) is covered.

@mizeller
Copy link
Copy Markdown
Contributor

mizeller commented Apr 2, 2026

@daidahao I planned to work on the historical_forecast support next week.

If you're available before, feel free:) I unfortunately didnt have any capacity the last few weeks - sorry about that!

@dennisbader
Copy link
Copy Markdown
Collaborator

Thanks everyone for all the work and the recent pushes to this PR 🚀
@mizeller could you give a quick summary of the current state and what is still missing before the PR can be finalized?

@mizeller
Copy link
Copy Markdown
Contributor

Off the top of my head, the status on the MLFlow PR:

  • historical forecasts / backtesting is patched now. i.e. metrics are logged correctly in both cases. tested 
    • w/ local/global forecasting, torch models
    • w/ all backtest(reduction=XXX) flag
  • deprecated the managed_run flag in (autolog())
    reason: following discussion w/ @dennisbader we decided to enforce a "desired" way of using MLFlow x Darts (& make our lifes easier in the process)

TODO

  • so far I've always worked with only one timeseries object. the following cases should be handled in a user-friendly manner:
series = AirPassengersDataset().load().astype(np.float32)
series_multiple = [series, series / 3.]
series_multivariate = series.stack(series / 3.)
series_multiple_multivariate = [series.stack(series / 3.), series.stack(series / 10.)]
  • there's a problem ("bug") w/ metrics logging. currently, a metric's name is used in MLFlow, which is generally fine. but i.e. for mase + different kwargs, it is only logged once (same key). solution: when passing metrics_kwargs augment the metric name used on MLFlow, i.e.:
    model.backtest(
        series=series,
        historical_forecasts=hfc,
        last_points_only=False,
        metric=[darts_metrics.mape, darts_metrics.rmse, darts_metrics.ape, darts_metrics.mase, darts_metrics.mase],
        metric_kwargs=[{}, {}, {}, {"m": 1}, {"m": 2}],
        reduction=None,
    )
  • ensure in the multiple series case, the results are usable; the plots will probably explode with very long lists of timeseries.
  • (@dennisbader I think we talked about more todos/permutation of input params but I can't recall exactly which ones)

Also, I believe the TODO regarding metrics/kwargs was implemented in the most recent commit by @jakubchlapek - very cool! :)

Copy link
Copy Markdown
Collaborator Author

@jakubchlapek jakubchlapek left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, looks nice @mizeller, just a few comments on the historical forecasts. The hfcs solution is nice.

Comment thread darts/utils/mlflow.py
Comment on lines +620 to +627
if metric is None:
try:
sig = inspect.signature(original)
bound = sig.bind(self, *args, **kwargs)
bound.apply_defaults()
metric = bound.arguments.get("metric")
except Exception:
pass
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say we can remove this, I don't believe anyone will pass in metrics positionally and it adds unnecessary complexity to the code (default mape will then still be covered by else branch)

Comment thread darts/utils/mlflow.py
Comment on lines +662 to +663
# 2-D and higher: skip to keep MVP simple

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'd prefer to include this in the PR if possible :)

Comment thread darts/utils/mlflow.py
Comment on lines +666 to +696
if isinstance(metric, (list | tuple)) and isinstance(result, list):
# multiple metrics → result is list[scalar_or_array], one per metric
for name, r in zip(names, result):
_log(f"backtest_{name}", r)
elif (
isinstance(metric, list | tuple)
and result_arr is not None
and result_arr.ndim == 1
and len(result_arr) == len(names)
):
# multiple metrics with scalar reduction returned as a 1-D ndarray
# (e.g. np.mean/median/percentile) — log each as a separate scalar
for name, r in zip(names, result_arr):
autologging_client.log_metrics(
run_id=run_id, metrics={f"backtest_{name}": float(r)}
)
elif result_arr is not None and result_arr.ndim == 2:
# (N_windows, N_metrics) ndarray — multi-metric + reduction=None
for col_i, name in enumerate(names[: result_arr.shape[1]]):
for step, val in enumerate(result_arr[:, col_i]):
autologging_client.log_metrics(
run_id=run_id,
metrics={f"backtest_{name}": float(val)},
step=step,
)
elif isinstance(result, list):
# single metric, multiple series → result is list[scalar_or_array]
for s_i, r in enumerate(result):
_log(f"backtest_{names[0]}_{s_i}", r)
else:
_log(f"backtest_{names[0]}", result)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally we would like to also support multivariate series where we can log per component if no reduction (e.g. maybe [backtest_MAE_x, backtest_MAE_y]). I worry that this approach can then get a bit complex with all the branches. Maybe we can think about normalizing the result to a dataframe first which could simplify logging? Let me know what you think here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants