Skip to content

Commit ae4d420

Browse files
committed
update readme and docs
1 parent a588eac commit ae4d420

3 files changed

Lines changed: 96 additions & 39 deletions

File tree

README.md

Lines changed: 96 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,95 @@
1-
<h1 align="left">
2-
<img src="public/logo.png" width="100">
3-
</h1>
41

52
[![PyPI](https://img.shields.io/pypi/v/dynaris)](https://pypi.org/project/dynaris/)
63
[![GitHub](https://img.shields.io/github/license/quant-sci/dynaris)](https://github.com/quant-sci/dynaris/blob/main/LICENSE)
74
[![Documentation Status](https://readthedocs.org/projects/dynaris/badge/?version=latest)](https://dynaris.readthedocs.io/en/latest/?badge=latest)
85

9-
**dynaris** is a JAX-powered Python library for Dynamic Linear Models -- from composable DLM components to Kalman filtering, smoothing, forecasting, and parameter estimation, all with automatic differentiation.
6+
**dynaris** is a JAX-powered Python library for state-space models -- from composable DLMs to nonlinear filters, switching systems, Bayesian estimation, and dynamic factor models, all with automatic differentiation and GPU acceleration.
107

118
## Installation
129

1310
```bash
1411
pip install dynaris
15-
# or
16-
uv add dynaris
12+
13+
# With Bayesian estimation support
14+
pip install dynaris[bayesian]
1715
```
1816

1917
## Documentation
2018

21-
Full documentation is available at [dynaris.readthedocs.io](https://dynaris.readthedocs.io).
19+
Full documentation at [dynaris.readthedocs.io](https://dynaris.readthedocs.io).
2220

2321
## Quickstart
2422

25-
```python
26-
from dynaris import LocalLevel, DLM
27-
from dynaris.datasets import load_nile
28-
29-
# Load data
30-
y = load_nile()
23+
### DLM: Trend + Seasonality
3124

32-
# Build a local-level model and fit
33-
dlm = DLM(LocalLevel(sigma_level=38.0, sigma_obs=123.0))
34-
dlm.fit(y).smooth()
25+
```python
26+
from dynaris import LocalLinearTrend, Seasonal, DLM
27+
from dynaris.datasets import load_airline
3528

36-
# Forecast and plot
37-
fc = dlm.forecast(steps=10)
38-
print(dlm.summary())
29+
model = LocalLinearTrend() + Seasonal(period=12)
30+
dlm = DLM(model)
31+
dlm.fit(load_airline()).smooth()
32+
dlm.forecast(steps=24)
3933
dlm.plot(kind="panel")
4034
```
4135

42-
## Components
36+
### Nonlinear Filtering
37+
38+
```python
39+
from dynaris import SSM, LorenzAttractor
40+
41+
model = LorenzAttractor(dt=0.01, obs_noise=2.0)
42+
ssm = SSM(model, filter="ukf") # auto-selects UKF for nonlinear models
43+
ssm.fit(observations)
44+
```
4345

44-
Build models by combining components with `+`:
46+
### Regime Switching
4547

4648
```python
47-
from dynaris import LocalLinearTrend, Seasonal, Cycle
49+
from dynaris import LocalLevel, MarkovSwitchingSSM
50+
from dynaris.filters import hamilton_filter
51+
from dynaris.smoothers import kim_smooth
52+
import jax.numpy as jnp
4853

49-
model = (
50-
LocalLinearTrend(sigma_level=1.0, sigma_slope=0.1)
51-
+ Seasonal(period=12, sigma_seasonal=0.5)
52-
+ Cycle(period=40, damping=0.95)
54+
switching = MarkovSwitchingSSM(
55+
models=(LocalLevel(1, 5), LocalLevel(5, 20)),
56+
transition_matrix=jnp.array([[0.95, 0.05], [0.10, 0.90]]),
57+
initial_probs=jnp.array([0.5, 0.5]),
5358
)
59+
result = hamilton_filter(switching, observations)
60+
smoothed = kim_smooth(switching, result)
61+
```
62+
63+
### Bayesian Estimation
64+
65+
```python
66+
from dynaris import LocalLevel, fit_bayesian
67+
from dynaris.estimation.priors import inverse_gamma_log_prior
68+
69+
def model_fn(params):
70+
return LocalLevel(sigma_level=jnp.exp(params[0]), sigma_obs=jnp.exp(params[1]))
71+
72+
result = fit_bayesian(model_fn, observations, jnp.zeros(2),
73+
log_prior_fn=inverse_gamma_log_prior(shape=2.0, scale=1.0))
74+
# result.samples -> (n_samples, n_params) posterior draws
5475
```
5576

77+
### Dynamic Factor Models
78+
79+
```python
80+
from dynaris.models import DFMModel
81+
82+
dfm = DFMModel(n_factors=2)
83+
dfm.fit(panel_data) # (T, m) multivariate panel
84+
print(dfm.loadings_df())
85+
print(dfm.factor_states_df())
86+
dfm.forecast(steps=12)
87+
```
88+
89+
## Components
90+
91+
Build DLMs by combining components with `+`:
92+
5693
| Component | State dim | Description |
5794
|-----------|-----------|-------------|
5895
| `LocalLevel` | 1 | Random walk + noise |
@@ -62,22 +99,34 @@ model = (
6299
| `Autoregressive` | order | AR(p) in companion form |
63100
| `Cycle` | 2 | Damped stochastic sinusoid |
64101

102+
## Filters & Smoothers
103+
104+
| Algorithm | Model type | Use case |
105+
|-----------|-----------|----------|
106+
| Kalman filter | Linear | Exact inference for DLMs |
107+
| Extended KF (EKF) | Nonlinear | First-order linearization |
108+
| Unscented KF (UKF) | Nonlinear | Sigma-point propagation |
109+
| Particle filter (SMC) | Any | Non-Gaussian, multi-modal |
110+
| Hamilton filter | Switching | Markov regime models |
111+
| RTS smoother | Linear | Retrospective state estimation |
112+
| Kim smoother | Switching | Retrospective regime inference |
113+
65114
## Parameter Estimation
66115

67-
```python
68-
import jax.numpy as jnp
69-
from dynaris import LocalLevel
70-
from dynaris.estimation import fit_mle
116+
| Method | Function | Description |
117+
|--------|----------|-------------|
118+
| MLE | `fit_mle()` | Gradient-based via `jax.grad` + scipy |
119+
| EM | `fit_em()` | Expectation-Maximization for variances |
120+
| Bayesian | `fit_bayesian()` | NUTS/HMC via NumPyro |
121+
| DFM-EM | `fit_dfm_em()` | EM with loading matrix updates |
71122

72-
def model_fn(params):
73-
return LocalLevel(
74-
sigma_level=jnp.exp(params[0]),
75-
sigma_obs=jnp.exp(params[1]),
76-
)
123+
## Built-in Nonlinear Models
77124

78-
result = fit_mle(model_fn, y, init_params=jnp.zeros(2))
79-
print(f"Log-likelihood: {result.log_likelihood:.2f}")
80-
```
125+
| Model | Description |
126+
|-------|-------------|
127+
| `StochasticVolatility` | AR(1) log-volatility (KSC linearization) |
128+
| `BearingsTracking` | 2D constant-velocity target, bearing observations |
129+
| `LorenzAttractor` | Chaotic 3D system (Euler discretization) |
81130

82131
## Datasets
83132

@@ -90,6 +139,15 @@ print(f"Log-likelihood: {result.log_likelihood:.2f}")
90139
| Global temperature | `load_temperature()` | 144 | Annual | Climate |
91140
| US GDP growth | `load_gdp()` | 319 | Quarterly | Economics |
92141

142+
## Performance
143+
144+
All filters run inside `jax.lax.scan` with `@jax.jit` -- GPU/TPU acceleration is automatic. Additional features:
145+
146+
- **Batch processing** via `jax.vmap` for parallel multi-series inference
147+
- **Memory-efficient** long series via `jax.checkpoint` (trade compute for memory)
148+
- **Parallel MCMC** chains via NumPyro's `chain_method="parallel"`
149+
- **Pure NumPy backend** for lightweight / no-GPU environments
150+
93151
## License
94152

95153
MIT License. See [LICENSE](LICENSE) for details.

docs/_static/logo.png

-62.6 KB
Binary file not shown.

docs/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232

3333
html_theme = "furo"
3434
html_title = "dynaris"
35-
html_logo = "_static/logo.png"
3635
html_theme_options = {
3736
"source_repository": "https://github.com/quant-sci/dynaris",
3837
"source_branch": "main",

0 commit comments

Comments
 (0)