Commit 3000aef

hombit and claude authored
feat(embed): add MOMENT-1 embedding models (small/base/large) (#795)
* feat(embed): add MOMENT-1 embedding models (small/base/large)

  Add `light_curve.embed.Moment1`, an ONNX-backed univariate (magnitude-only) wrapper for the MOMENT-1 time-series foundation model (Goswami et al. 2024, MIT), exposing `mean` / `sequence` outputs in small/base/large sizes (512/768/1024-dim). Like the Chronos models it discards timestamps and treats observations as sequentially ordered, but uses a fixed 512-observation context (64 patches of 8) rather than a dynamic sequence axis. `size` is a required `from_hf` argument.

  Because the context length is fixed, all reduction windows are batched into a single ONNX call; the `sequence` output is restricted to single-window reductions.

  Also bump the prep-models test submodule to the commit that adds the MOMENT export + reference test data, and add tests, API/docs pages, and a CHANGELOG entry.

  The Python pipeline was validated locally against the prep-models reference parquet (cosine similarity 1.000000 for all three sizes).

  Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
  Claude-Session: https://claude.ai/code/session_01KEKNjm8oSYzw276okfggzc

* test(embed): bump prep-models to merged MOMENT commit, full reference rows

  The MOMENT prep-models PR merged to main and the reference test-data parquets were regenerated with the full 10 samples. Bump the prep-models submodule 10eda06→a35f0d1 (was the unmerged feat/moment1 commit) and run the reference test over all 10 rows, matching the Chronos tests.

  Verified locally: all MOMENT tests pass against the now-published HuggingFace models (light-curve/moment1-{small,base,large}); cosine similarity 1.000000.

  Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
  Claude-Session: https://claude.ai/code/session_01KEKNjm8oSYzw276okfggzc

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 125ed0c commit 3000aef
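
The commit message reports a cosine similarity of 1.000000 between the Python pipeline and the reference embeddings. A minimal sketch of what such a check might look like is given here, assuming the embeddings are plain NumPy arrays. The helper name `cosine_similarity` and the synthetic vectors are illustrative assumptions; they are not taken from the repository's tests or from the reference parquet.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two embeddings, flattened to 1-D."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Synthetic stand-ins for a computed embedding and its reference
# (hypothetical data, not real model output).
rng = np.random.default_rng(0)
reference = rng.normal(size=768)
computed = reference + rng.normal(scale=1e-6, size=768)  # tiny numerical noise

similarity = cosine_similarity(computed, reference)
assert similarity > 0.999999
```

Comparing with a threshold rather than exact equality leaves room for small floating-point differences between runtimes.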

7 files changed

Lines changed: 413 additions & 2 deletions

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -9,7 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

---
+- `light_curve.embed.Moment1`: ONNX-backed univariate (magnitude-only) MOMENT-1 time-series
+  foundation model (Goswami et al. 2024, MIT license), exposing `mean` / `sequence` outputs.
+  Available in small/base/large sizes (512/768/1024-dim) via `Moment1.from_hf(size=...)`; uses a
+  fixed 512-observation context (64 patches of 8).

 ### Changed

docs/embed/api.md

Lines changed: 5 additions & 0 deletions
@@ -42,6 +42,11 @@
       inherited_members: true
       members:
         - from_hf
+::: light_curve.embed.Moment1
+    options:
+      inherited_members: true
+      members:
+        - from_hf

 ## Reduction strategies

docs/embed/index.md

Lines changed: 24 additions & 0 deletions
@@ -31,6 +31,7 @@ If you already have the ONNX model file locally, `huggingface_hub` is not requir
 | `ATCAT` | 6 (ugrizY jointly) | time, flux, flux\_err, band index | 384 | ELAsTiCC |
 | `Chronos2` | single (magnitude only) | mag | 768 | time-series corpus |
 | `ChronosBolt` | single (magnitude only) | mag | 256–768 | time-series corpus |
+| `Moment1` | single (magnitude only) | mag | 512–1024 | Time-series Pile |

 ## Single-band: Astromer2

@@ -210,6 +211,29 @@ Series longer than the native context (8192 for Chronos 2, 2048 for
 Chronos-Bolt) are reduced first; the default `reduction="end"` keeps the most
 recent observations.

+## Single-band: MOMENT-1
+
+[MOMENT](https://huggingface.co/AutonLab/MOMENT-1-base) is a T5-based time-series
+foundation model. Like Chronos it embeds a **magnitude sequence only** with
+timestamps discarded. It comes in three sizes (`small`/`base`/`large`
+512/768/1024-dim) and uses a **fixed** 512-observation context (64 patches of 8):
+
+```python
+import numpy as np
+from light_curve.embed import Moment1
+
+rng = np.random.default_rng(7)
+mag = rng.normal(18.0, 0.3, 150).astype(np.float32)  # chronological order
+
+model = Moment1.from_hf(size="base", output="mean")
+embedding = model(mag)
+print(embedding.shape)  # (1, 1, 1, 768)
+```
+
+Light curves longer than 512 observations are reduced first; the default
+`reduction="end"` keeps the most recent 512. The `"sequence"` output always has
+64 patches and supports only single-window reductions.
+
 ## GPU and alternative runtimes

 Pass `ort_session_kwargs` to select an execution provider:

light-curve/light_curve/embed/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .atcat import ATCAT
 from .chronos import Chronos2, ChronosBolt
 from .model import EmbeddingSession, SingleBandModel
+from .moment import Moment1
 from .reduction import (
     Beginning,
     End,
@@ -27,6 +28,7 @@
     "EmbeddingSession",
     "End",
     "Middle",
+    "Moment1",
     "MultipleReductions",
     "NonOverlappingWindows",
     "RandomSubsample",
Lines changed: 260 additions & 0 deletions
@@ -0,0 +1,260 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Literal
+
+import numpy as np
+from numpy.typing import ArrayLike
+
+from light_curve.embed.input_tensors import InputTensors
+from light_curve.embed.model import (
+    SingleBandModel,
+    _hf_hub_download_cached,
+    create_onnx_session,
+)
+from light_curve.embed.reduction import Reduction
+
+if TYPE_CHECKING:
+    from typing import Self
+
+    import onnxruntime as ort
+
+# MOMENT has a fixed 512-step context split into 64 non-overlapping patches of 8.
+_SEQ_LEN = 512
+_PATCH_SIZE = 8
+
+
+@dataclass
+class MomentInputs(InputTensors):
+    """Input tensors for MOMENT-1 models.
+
+    Attributes
+    ----------
+    mag : ndarray, shape ``(n_subsamples, seq_size)``
+        Per-subsample magnitudes, zero-padded to the reduction's ``seq_size``.
+        The actual model context (left NaN-padded to the fixed 512-step window)
+        is built per subsample at inference time from the valid entries.
+    bool_mask : ndarray, shape ``(n_subsamples, seq_size)``
+        Boolean validity — ``True`` for real observations, ``False`` for padding.
+    """
+
+    mag: np.ndarray = field(kw_only=True)
+
+
+class Moment1(SingleBandModel):
+    """MOMENT-1 univariate light-curve embedding model.
+
+    A T5-based time-series foundation model (Goswami et al. 2024) pretrained with
+    a masked-reconstruction objective on the Time-series Pile. It embeds a single
+    univariate magnitude series: timestamps are discarded and observations are
+    treated as sequentially ordered (the same convention used for the Chronos
+    models). The series is capped to the most recent 512 observations and
+    left-padded with NaN to that fixed window; reversible instance normalisation
+    (RevIN) is applied internally by the model.
+
+    The model comes in three sizes with different embedding dimensions: ``small``
+    (512), ``base`` (768), and ``large`` (1024). Unlike Chronos, the context
+    length is fixed at 512 observations (64 patches of 8), not a dynamic axis.
+
+    The ONNX models are hosted on HuggingFace at
+    ``https://huggingface.co/light-curve/moment1-<size>``.
+
+    Use :meth:`from_hf` (with ``size=``) to download and load the model.
+
+    Model license
+    -------------
+    MIT (upstream AutonLab/MOMENT-1 license).
+
+    References
+    ----------
+    Goswami et al. (2024), *MOMENT: A Family of Open Time-series Foundation
+    Models*, ICML 2024. https://huggingface.co/AutonLab/MOMENT-1-base
+
+    Parameters
+    ----------
+    session :
+        ONNX inference session for the MOMENT-1 model file.
+    size : {"small", "base", "large"}
+        Which model size this session corresponds to (sets ``embed_dim``).
+    output : str, optional
+        ``"mean"`` (default) or ``"sequence"``.
+    reduction : str, list of str, or Reduction, optional
+        Observation-selection strategy for light curves longer than 512.
+        Defaults to ``"end"``.
+    reduction_kwargs : dict, optional
+        Extra keyword arguments forwarded to :func:`reduction_from_str`.
+    """
+
+    patch_size: int = _PATCH_SIZE
+    seq_len: int = _SEQ_LEN
+    max_obs: int = _SEQ_LEN
+    model_outputs: frozenset[str] = frozenset({"mean", "sequence"})
+    _EMBED_DIMS: dict[str, int] = {"small": 512, "base": 768, "large": 1024}
+
+    def __init__(
+        self,
+        session: ort.InferenceSession,
+        *,
+        size: Literal["small", "base", "large"],
+        output: Literal["mean", "sequence"] = "mean",
+        reduction: str | list[str] | Reduction = "end",
+        reduction_kwargs: dict[str, object] | None = None,
+    ) -> None:
+        if size not in self._EMBED_DIMS:
+            raise ValueError(f"Unknown size '{size}'. Must be one of: {', '.join(sorted(self._EMBED_DIMS))}")
+        self.size = size
+        self.embed_dim = self._EMBED_DIMS[size]
+        self.hf_repo = f"light-curve/moment1-{size}"
+        self.hf_filename = f"moment1-{size}.onnx"
+        super().__init__(
+            session,
+            bands=None,
+            reduction=reduction,
+            reduction_kwargs=reduction_kwargs,
+        )
+        if output not in self.model_outputs:
+            raise ValueError(f"Unknown output '{output}'. Must be one of: {', '.join(sorted(self.model_outputs))}")
+        self.output = output
+
+    @classmethod
+    def from_hf(
+        cls,
+        size: str,
+        output: str = "mean",
+        *,
+        reduction: str | list[str] | Reduction = "end",
+        reduction_kwargs: dict[str, object] | None = None,
+        ort_session_kwargs: dict[str, object] | None = None,
+    ) -> Self:
+        """Load a MOMENT-1 model of the given ``size`` from the HuggingFace Hub.
+
+        Downloads (and caches) the ONNX model file, creates an
+        ``onnxruntime.InferenceSession``, and returns a ready-to-use instance.
+
+        Parameters
+        ----------
+        size : {"small", "base", "large"}
+            Model size to load. Required: the sizes have different embedding
+            dimensions, so there is no meaningful default.
+        output : str, optional
+            Named ONNX output to return: ``"mean"`` (default, masked mean pool
+            over valid patches → ``(..., 1, embed_dim)``) or ``"sequence"``
+            (per-patch encoder states → ``(..., 64, embed_dim)``).
+        reduction : str, list of str, or Reduction, optional
+            Observation-selection strategy for light curves longer than 512.
+            Defaults to ``"end"`` (the most recent 512 observations, matching the
+            model's native right-aligned context).
+        reduction_kwargs : dict or None, optional
+            Extra keyword arguments forwarded to :func:`reduction_from_str`.
+        ort_session_kwargs : dict or None, optional
+            Keyword arguments forwarded to ``onnxruntime.InferenceSession``.
+
+        Returns
+        -------
+        Moment1
+            Instance with a live ONNX inference session.
+
+        Raises
+        ------
+        ValueError
+            If ``size`` or ``output`` is not recognised.
+        ImportError
+            If ``huggingface_hub`` or an ``onnxruntime`` variant is missing.
+        """
+        if size not in cls._EMBED_DIMS:
+            raise ValueError(f"Unknown size '{size}'. Must be one of: {', '.join(sorted(cls._EMBED_DIMS))}")
+        model_path = _hf_hub_download_cached(f"light-curve/moment1-{size}", f"moment1-{size}.onnx")
+        session = create_onnx_session(model_path, **(ort_session_kwargs or {}))
+        return cls(
+            session=session,
+            size=size,
+            output=output,
+            reduction=reduction,
+            reduction_kwargs=reduction_kwargs,
+        )
+
+    def __call__(self, mag: ArrayLike) -> np.ndarray:
+        """Embed a magnitude series.
+
+        Parameters
+        ----------
+        mag : array-like, shape ``(n,)``
+            Magnitudes in chronological order. Timestamps are not used by the
+            model, which treats observations as sequentially ordered.
+
+        Returns
+        -------
+        np.ndarray, shape ``(1, n_subsamples, seq_size, embed_dim)``
+            Embedding tensor. ``seq_size`` is 1 for ``"mean"`` and 64 (the
+            number of patches) for ``"sequence"``.
+        """
+        return super().__call__(mag)
+
+    def preprocess_lc(self, mag: ArrayLike) -> MomentInputs:
+        """Select observations per the reduction; padding to the fixed window is deferred.
+
+        Parameters
+        ----------
+        mag : array-like, shape ``(n,)``
+            Magnitudes in chronological order.
+
+        Returns
+        -------
+        MomentInputs
+        """
+        mag = np.asarray(mag, dtype=np.float32)
+        mag_win, bool_mask = self.reduction.preprocess_lc(mag, seq_size=self.max_obs)
+        return MomentInputs(bool_mask=bool_mask, mag=mag_win.astype(np.float32))
+
+    def _context(self, mag: np.ndarray) -> np.ndarray:
+        """Left-pad valid magnitudes with NaN to the fixed 512-step window."""
+        mag = mag[-self.seq_len :]
+        n = mag.shape[0]
+        context = np.full((1, self.seq_len), np.nan, dtype=np.float32)
+        context[0, self.seq_len - n :] = mag
+        return context
+
+    def predict_tensors(self, tensors: MomentInputs) -> np.ndarray:
+        """Run the ONNX model per subsample and return reduced embeddings.
+
+        Because MOMENT's context length is fixed (512), all subsamples share the
+        same shape and are batched into a single ONNX call.
+
+        Parameters
+        ----------
+        tensors : MomentInputs
+            As returned by :meth:`preprocess_lc`.
+
+        Returns
+        -------
+        np.ndarray, shape ``(n_subsamples, seq_size, embed_dim)``
+            Embeddings after applying the reduction's aggregation. ``seq_size``
+            is 1 for ``"mean"`` and 64 for ``"sequence"``.
+
+        Raises
+        ------
+        ValueError
+            For the ``"sequence"`` output with a multi-window reduction: the
+            reduction's per-window aggregation operates in observation space,
+            which does not align with the fixed 64-patch sequence.
+        """
+        n_subsamples = tensors.bool_mask.shape[0]
+        if self.output == "sequence" and n_subsamples != 1:
+            raise ValueError(
+                "The 'sequence' output supports only single-subsample reductions for MOMENT "
+                "(per-window aggregation operates in observation space, which does not align "
+                "with the fixed 64-patch sequence)."
+            )
+
+        contexts = np.concatenate(
+            [self._context(tensors.mag[i][tensors.bool_mask[i]]) for i in range(n_subsamples)],
+            axis=0,
+        )  # (n_subsamples, 512)
+        (raw,) = self.session.run([self.output], {"context": contexts})
+        # mean: (n_subsamples, embed_dim); sequence: (n_subsamples, 64, embed_dim)
+
+        if self.output == "mean":
+            embeddings = raw[:, np.newaxis, :]  # (n_subsamples, 1, embed_dim)
+        else:
+            embeddings = raw  # (1, 64, embed_dim)
+        return self.reduction.reduce_embeddings(embeddings, tensors, output=self.output)
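
The left-padding done by `_context` and the batching done in `predict_tensors` can be tried without an ONNX session. The following is a standalone sketch that assumes only NumPy; `left_pad_context` is a hypothetical free-function counterpart of `_context`, and the two input series are synthetic. It shows why windows of different lengths can be stacked into one `(n_subsamples, 512)` array.

```python
import numpy as np

SEQ_LEN = 512   # fixed MOMENT context length, as stated in the commit
PATCH_SIZE = 8  # 512 / 8 gives 64 patches


def left_pad_context(mag: np.ndarray, seq_len: int = SEQ_LEN) -> np.ndarray:
    """Keep the last `seq_len` values and left-pad with NaN to shape (1, seq_len)."""
    mag = np.asarray(mag, dtype=np.float32)[-seq_len:]
    context = np.full((1, seq_len), np.nan, dtype=np.float32)
    context[0, seq_len - mag.shape[0]:] = mag
    return context


# Two synthetic windows of different lengths become one fixed-shape batch.
short = np.linspace(18.0, 18.5, 150)
long = np.linspace(17.0, 19.0, 600)
batch = np.concatenate([left_pad_context(short), left_pad_context(long)], axis=0)

n_valid = np.sum(~np.isnan(batch), axis=1)
print(batch.shape)            # (2, 512)
print(n_valid.tolist())       # [150, 512]
print(SEQ_LEN // PATCH_SIZE)  # 64
```

The short series keeps all 150 values at the right edge of its row, while the long series is truncated to its most recent 512 values, so both rows share one shape.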
