|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# |
| 18 | + |
| 19 | +# This file contains code adapted from the MOMENT project |
| 20 | +# (https://github.com/moment-timeseries-foundation-model/moment), |
| 21 | +# originally licensed under the MIT License. |
| 22 | + |
| 23 | +from typing import Optional |
| 24 | + |
| 25 | +from transformers import PretrainedConfig |
| 26 | + |
| 27 | + |
class MomentConfig(PretrainedConfig):
    """
    Hyper-parameter container for the MOMENT time series foundation model.

    MOMENT (A Family of Open Time-series Foundation Models) comes from
    Auton Lab, Carnegie Mellon University. The model pairs a T5
    encoder-only backbone with patch-based input embedding and RevIN
    normalization, and targets several time series tasks: forecasting,
    classification, anomaly detection and imputation.

    Every constructor argument except ``d_model`` is stored unchanged on
    the instance under the same name. ``d_model`` is resolved with the
    following precedence:

    1. the explicit ``d_model`` argument, when it is not ``None``;
    2. ``t5_config["d_model"]``, when ``t5_config`` is given and has
       that key;
    3. ``1024`` (the default for MOMENT-1-large).

    Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Reference: https://arxiv.org/abs/2402.03885
    """

    model_type = "moment"

    def __init__(
        self,
        seq_len: int = 512,
        patch_len: int = 8,
        patch_stride_len: int = 8,
        d_model: Optional[int] = None,
        transformer_backbone: str = "google/flan-t5-large",
        forecast_horizon: int = 96,
        revin_affine: bool = False,
        t5_config: Optional[dict] = None,
        **kwargs,
    ):
        # Plain pass-through settings, recorded exactly as supplied.
        self.seq_len = seq_len
        self.patch_len = patch_len
        self.patch_stride_len = patch_stride_len
        self.transformer_backbone = transformer_backbone
        self.forecast_horizon = forecast_horizon
        self.revin_affine = revin_affine
        self.t5_config = t5_config

        # Only fall back when the caller gave no explicit width: first
        # consult the backbone config, then the MOMENT-1-large default.
        if d_model is None:
            backbone_declares_width = (
                t5_config is not None and "d_model" in t5_config
            )
            d_model = t5_config["d_model"] if backbone_declares_width else 1024
        self.d_model = d_model

        # Base-class initialisation runs last and consumes the leftover kwargs.
        super().__init__(**kwargs)
0 commit comments