|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from dataclasses import dataclass, field |
| 4 | +from typing import TYPE_CHECKING, Literal |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +from numpy.typing import ArrayLike |
| 8 | + |
| 9 | +from light_curve.embed.input_tensors import InputTensors |
| 10 | +from light_curve.embed.model import ( |
| 11 | + SingleBandModel, |
| 12 | + _hf_hub_download_cached, |
| 13 | + create_onnx_session, |
| 14 | +) |
| 15 | +from light_curve.embed.reduction import Reduction |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from typing import Self |
| 19 | + |
| 20 | + import onnxruntime as ort |
| 21 | + |
# MOMENT operates on a fixed-length context: 64 non-overlapping patches of
# 8 steps each, i.e. 512 steps in total.
_PATCH_SIZE = 8
_SEQ_LEN = 512
| 25 | + |
| 26 | + |
@dataclass
class MomentInputs(InputTensors):
    """Input tensors for MOMENT-1 models.

    Built by ``Moment1.preprocess_lc`` and consumed by
    ``Moment1.predict_tensors``.

    Attributes
    ----------
    mag : ndarray, shape ``(n_subsamples, seq_size)``
        Per-subsample magnitudes, zero-padded to the reduction's ``seq_size``.
        The actual model context (left NaN-padded to the fixed 512-step window)
        is built per subsample at inference time from the valid entries.
    bool_mask : ndarray, shape ``(n_subsamples, seq_size)``
        Boolean validity — ``True`` for real observations, ``False`` for padding.
        NOTE(review): not declared in this class; presumably inherited from
        ``InputTensors`` — confirm.
    """

    # Keyword-only field: callers must pass it as ``mag=...``.
    mag: np.ndarray = field(kw_only=True)
| 42 | + |
| 43 | + |
class Moment1(SingleBandModel):
    """MOMENT-1 univariate light-curve embedding model.

    A T5-based time-series foundation model (Goswami et al. 2024) pretrained with
    a masked-reconstruction objective on the Time-series Pile. It embeds a single
    univariate magnitude series: timestamps are discarded and observations are
    treated as sequentially ordered (the same convention used for the Chronos
    models). The series is capped to the most recent 512 observations and
    left-padded with NaN to that fixed window; reversible instance normalisation
    (RevIN) is applied internally by the model.

    The model comes in three sizes with different embedding dimensions: ``small``
    (512), ``base`` (768), and ``large`` (1024). Unlike Chronos, the context
    length is fixed at 512 observations (64 patches of 8), not a dynamic axis.

    The ONNX models are hosted on HuggingFace at
    ``https://huggingface.co/light-curve/moment1-<size>``.

    Use :meth:`from_hf` (with ``size=``) to download and load the model.

    Model license
    -------------
    MIT (upstream AutonLab/MOMENT-1 license).

    References
    ----------
    Goswami et al. (2024), *MOMENT: A Family of Open Time-series Foundation
    Models*, ICML 2024. https://huggingface.co/AutonLab/MOMENT-1-base

    Parameters
    ----------
    session :
        ONNX inference session for the MOMENT-1 model file.
    size : {"small", "base", "large"}
        Which model size this session corresponds to (sets ``embed_dim``).
    output : str, optional
        ``"mean"`` (default) or ``"sequence"``.
    reduction : str, list of str, or Reduction, optional
        Observation-selection strategy for light curves longer than 512.
        Defaults to ``"end"``.
    reduction_kwargs : dict, optional
        Extra keyword arguments forwarded to :func:`reduction_from_str`.

    Raises
    ------
    ValueError
        If ``size`` or ``output`` is not recognised.
    """

    patch_size: int = _PATCH_SIZE
    seq_len: int = _SEQ_LEN
    max_obs: int = _SEQ_LEN
    model_outputs: frozenset[str] = frozenset({"mean", "sequence"})
    _EMBED_DIMS: dict[str, int] = {"small": 512, "base": 768, "large": 1024}

    @classmethod
    def _require_known_size(cls, size: str) -> None:
        """Raise ``ValueError`` unless ``size`` is a key of ``_EMBED_DIMS``."""
        if size not in cls._EMBED_DIMS:
            raise ValueError(f"Unknown size '{size}'. Must be one of: {', '.join(sorted(cls._EMBED_DIMS))}")

    @classmethod
    def _require_known_output(cls, output: str) -> None:
        """Raise ``ValueError`` unless ``output`` is one of ``model_outputs``."""
        if output not in cls.model_outputs:
            raise ValueError(f"Unknown output '{output}'. Must be one of: {', '.join(sorted(cls.model_outputs))}")

    @staticmethod
    def _hf_location(size: str) -> tuple[str, str]:
        """Return ``(repo_id, filename)`` of the ONNX file for ``size`` on the HuggingFace Hub."""
        return f"light-curve/moment1-{size}", f"moment1-{size}.onnx"

    def __init__(
        self,
        session: ort.InferenceSession,
        *,
        size: Literal["small", "base", "large"],
        output: Literal["mean", "sequence"] = "mean",
        reduction: str | list[str] | Reduction = "end",
        reduction_kwargs: dict[str, object] | None = None,
    ) -> None:
        # Validate both enumerated arguments up front so that a bad value fails
        # before any instance state is set or base-class setup runs.
        self._require_known_size(size)
        self._require_known_output(output)

        self.size = size
        self.embed_dim = self._EMBED_DIMS[size]
        self.hf_repo, self.hf_filename = self._hf_location(size)
        super().__init__(
            session,
            bands=None,
            reduction=reduction,
            reduction_kwargs=reduction_kwargs,
        )
        # Assigned after the base initialiser, as in the original ordering.
        self.output = output

    @classmethod
    def from_hf(
        cls,
        size: str,
        output: str = "mean",
        *,
        reduction: str | list[str] | Reduction = "end",
        reduction_kwargs: dict[str, object] | None = None,
        ort_session_kwargs: dict[str, object] | None = None,
    ) -> Self:
        """Load a MOMENT-1 model of the given ``size`` from the HuggingFace Hub.

        Downloads (and caches) the ONNX model file, creates an
        ``onnxruntime.InferenceSession``, and returns a ready-to-use instance.
        Both ``size`` and ``output`` are validated before anything is
        downloaded, so invalid arguments never trigger network or session work.

        Parameters
        ----------
        size : {"small", "base", "large"}
            Model size to load. Required: the sizes have different embedding
            dimensions, so there is no meaningful default.
        output : str, optional
            Named ONNX output to return: ``"mean"`` (default, masked mean pool
            over valid patches → ``(..., 1, embed_dim)``) or ``"sequence"``
            (per-patch encoder states → ``(..., 64, embed_dim)``).
        reduction : str, list of str, or Reduction, optional
            Observation-selection strategy for light curves longer than 512.
            Defaults to ``"end"`` (the most recent 512 observations, matching the
            model's native right-aligned context).
        reduction_kwargs : dict or None, optional
            Extra keyword arguments forwarded to :func:`reduction_from_str`.
        ort_session_kwargs : dict or None, optional
            Keyword arguments forwarded to ``onnxruntime.InferenceSession``.

        Returns
        -------
        Moment1
            Instance with a live ONNX inference session.

        Raises
        ------
        ValueError
            If ``size`` or ``output`` is not recognised.
        ImportError
            If ``huggingface_hub`` or an ``onnxruntime`` variant is missing.
        """
        # Fail fast: the model files are large, so reject bad arguments before
        # touching the network or building an inference session.
        cls._require_known_size(size)
        cls._require_known_output(output)

        repo_id, filename = cls._hf_location(size)
        model_path = _hf_hub_download_cached(repo_id, filename)
        session = create_onnx_session(model_path, **(ort_session_kwargs or {}))
        return cls(
            session=session,
            size=size,
            output=output,
            reduction=reduction,
            reduction_kwargs=reduction_kwargs,
        )

    def __call__(self, mag: ArrayLike) -> np.ndarray:
        """Embed a magnitude series.

        Parameters
        ----------
        mag : array-like, shape ``(n,)``
            Magnitudes in chronological order. Timestamps are not used by the
            model, which treats observations as sequentially ordered.

        Returns
        -------
        np.ndarray, shape ``(1, n_subsamples, seq_size, embed_dim)``
            Embedding tensor. ``seq_size`` is 1 for ``"mean"`` and 64 (the
            number of patches) for ``"sequence"``.
        """
        return super().__call__(mag)

    def preprocess_lc(self, mag: ArrayLike) -> MomentInputs:
        """Select observations per the reduction; padding to the fixed window is deferred.

        Parameters
        ----------
        mag : array-like, shape ``(n,)``
            Magnitudes in chronological order.

        Returns
        -------
        MomentInputs
            The reduction's selected magnitudes (cast to ``float32``) together
            with the validity mask it returned.
        """
        mag = np.asarray(mag, dtype=np.float32)
        mag_win, bool_mask = self.reduction.preprocess_lc(mag, seq_size=self.max_obs)
        return MomentInputs(bool_mask=bool_mask, mag=mag_win.astype(np.float32))

    def _context(self, mag: np.ndarray) -> np.ndarray:
        """Left-pad valid magnitudes with NaN to the fixed 512-step window.

        Parameters
        ----------
        mag : np.ndarray, shape ``(n,)``
            Valid (unpadded) magnitudes of one subsample. Only the last
            ``seq_len`` entries are kept when ``n > seq_len``.

        Returns
        -------
        np.ndarray, shape ``(1, seq_len)``, dtype ``float32``
            Right-aligned context; leading unused slots are NaN. An empty
            ``mag`` yields an all-NaN window.
        """
        tail = mag[-self.seq_len :]
        context = np.full((1, self.seq_len), np.nan, dtype=np.float32)
        # Right-align: the newest observation lands in the last slot.
        context[0, self.seq_len - tail.shape[0] :] = tail
        return context

    def predict_tensors(self, tensors: MomentInputs) -> np.ndarray:
        """Run the ONNX model on all subsamples and return reduced embeddings.

        A fixed-length context is built for each subsample from its valid
        entries. Because MOMENT's context length is fixed (512), all subsamples
        share the same shape and are batched into a single ONNX call.

        Parameters
        ----------
        tensors : MomentInputs
            As returned by :meth:`preprocess_lc`.

        Returns
        -------
        np.ndarray, shape ``(n_subsamples, seq_size, embed_dim)``
            Embeddings after applying the reduction's aggregation. ``seq_size``
            is 1 for ``"mean"`` and 64 for ``"sequence"``.

        Raises
        ------
        ValueError
            For the ``"sequence"`` output with a multi-window reduction: the
            reduction's per-window aggregation operates in observation space,
            which does not align with the fixed 64-patch sequence. Also raised
            if ``tensors.mag`` and ``tensors.bool_mask`` disagree on the number
            of subsamples.
        """
        n_subsamples = tensors.bool_mask.shape[0]
        if self.output == "sequence" and n_subsamples != 1:
            raise ValueError(
                "The 'sequence' output supports only single-subsample reductions for MOMENT "
                "(per-window aggregation operates in observation space, which does not align "
                "with the fixed 64-patch sequence)."
            )

        # Keep only the valid observations of each subsample, then right-align
        # them in the fixed window; stack along the batch axis.
        contexts = np.concatenate(
            [self._context(row[valid]) for row, valid in zip(tensors.mag, tensors.bool_mask, strict=True)],
            axis=0,
        )  # (n_subsamples, 512)
        (raw,) = self.session.run([self.output], {"context": contexts})
        # mean: (n_subsamples, embed_dim); sequence: (n_subsamples, 64, embed_dim)

        if self.output == "mean":
            # Insert a length-1 sequence axis so both outputs are rank 3.
            embeddings = raw[:, np.newaxis, :]  # (n_subsamples, 1, embed_dim)
        else:
            # n_subsamples == 1 is guaranteed by the check above.
            embeddings = raw  # (1, 64, embed_dim)
        return self.reduction.reduce_embeddings(embeddings, tensors, output=self.output)
0 commit comments