1- <h1 align =" left " >
2- <img src =" public/logo.png " width =" 100 " >
3- </h1 >
41
52[ ![ PyPI] ( https://img.shields.io/pypi/v/dynaris )] ( https://pypi.org/project/dynaris/ )
63[ ![ GitHub] ( https://img.shields.io/github/license/quant-sci/dynaris )] ( https://github.com/quant-sci/dynaris/blob/main/LICENSE )
74[ ![ Documentation Status] ( https://readthedocs.org/projects/dynaris/badge/?version=latest )] ( https://dynaris.readthedocs.io/en/latest/?badge=latest )
85
9- ** dynaris** is a JAX-powered Python library for Dynamic Linear Models -- from composable DLM components to Kalman filtering, smoothing, forecasting , and parameter estimation , all with automatic differentiation.
6+ ** dynaris** is a JAX-powered Python library for state-space models -- from composable DLMs to nonlinear filters, switching systems, Bayesian estimation , and dynamic factor models , all with automatic differentiation and GPU acceleration .
107
118## Installation
129
1310``` bash
1411pip install dynaris
15- # or
16- uv add dynaris
12+
13+ # With Bayesian estimation support
14+ pip install dynaris[bayesian]
1715```
1816
1917## Documentation
2018
21- Full documentation is available at [ dynaris.readthedocs.io] ( https://dynaris.readthedocs.io ) .
19+ Full documentation at [ dynaris.readthedocs.io] ( https://dynaris.readthedocs.io ) .
2220
2321## Quickstart
2422
25- ``` python
26- from dynaris import LocalLevel, DLM
27- from dynaris.datasets import load_nile
28-
29- # Load data
30- y = load_nile()
23+ ### DLM: Trend + Seasonality
3124
32- # Build a local-level model and fit
33- dlm = DLM(LocalLevel( sigma_level = 38.0 , sigma_obs = 123.0 ))
34- dlm.fit(y).smooth()
25+ ``` python
26+ from dynaris import LocalLinearTrend, Seasonal, DLM
27+ from dynaris.datasets import load_airline
3528
36- # Forecast and plot
37- fc = dlm.forecast(steps = 10 )
38- print (dlm.summary())
29+ model = LocalLinearTrend() + Seasonal(period = 12 )
30+ dlm = DLM(model)
31+ dlm.fit(load_airline()).smooth()
32+ dlm.forecast(steps = 24 )
3933dlm.plot(kind = " panel" )
4034```
4135
42- ## Components
36+ ### Nonlinear Filtering
37+
38+ ``` python
39+ from dynaris import SSM , LorenzAttractor
40+
41+ model = LorenzAttractor(dt = 0.01 , obs_noise = 2.0 )
42+ ssm = SSM(model, filter = " ukf" ) # auto-selects UKF for nonlinear models
43+ ssm.fit(observations)
44+ ```
4345
44- Build models by combining components with ` + ` :
46+ ### Regime Switching
4547
4648``` python
47- from dynaris import LocalLinearTrend, Seasonal, Cycle
49+ from dynaris import LocalLevel, MarkovSwitchingSSM
50+ from dynaris.filters import hamilton_filter
51+ from dynaris.smoothers import kim_smooth
52+ import jax.numpy as jnp
4853
49- model = (
50- LocalLinearTrend( sigma_level = 1.0 , sigma_slope = 0.1 )
51- + Seasonal( period = 12 , sigma_seasonal = 0.5 )
52- + Cycle( period = 40 , damping = 0.95 )
54+ switching = MarkovSwitchingSSM (
55+ models = (LocalLevel( 1 , 5 ), LocalLevel( 5 , 20 )),
56+ transition_matrix = jnp.array([[ 0.95 , 0.05 ], [ 0.10 , 0.90 ]]),
57+ initial_probs = jnp.array([ 0.5 , 0.5 ]),
5358)
59+ result = hamilton_filter(switching, observations)
60+ smoothed = kim_smooth(switching, result)
61+ ```
62+
63+ ### Bayesian Estimation
64+
65+ ``` python
66+ from dynaris import LocalLevel, fit_bayesian
67+ from dynaris.estimation.priors import inverse_gamma_log_prior
68+
69+ def model_fn (params ):
70+ return LocalLevel(sigma_level = jnp.exp(params[0 ]), sigma_obs = jnp.exp(params[1 ]))
71+
72+ result = fit_bayesian(model_fn, observations, jnp.zeros(2 ),
73+ log_prior_fn = inverse_gamma_log_prior(shape = 2.0 , scale = 1.0 ))
74+ # result.samples -> (n_samples, n_params) posterior draws
5475```
5576
77+ ### Dynamic Factor Models
78+
79+ ``` python
80+ from dynaris.models import DFMModel
81+
82+ dfm = DFMModel(n_factors = 2 )
83+ dfm.fit(panel_data) # (T, m) multivariate panel
84+ print (dfm.loadings_df())
85+ print (dfm.factor_states_df())
86+ dfm.forecast(steps = 12 )
87+ ```
88+
89+ ## Components
90+
91+ Build DLMs by combining components with ` + ` :
92+
5693| Component | State dim | Description |
5794| -----------| -----------| -------------|
5895| ` LocalLevel ` | 1 | Random walk + noise |
@@ -62,22 +99,34 @@ model = (
6299| ` Autoregressive ` | order | AR(p) in companion form |
63100| ` Cycle ` | 2 | Damped stochastic sinusoid |
64101
102+ ## Filters & Smoothers
103+
104+ | Algorithm | Model type | Use case |
105+ | -----------| -----------| ----------|
106+ | Kalman filter | Linear | Exact inference for DLMs |
107+ | Extended KF (EKF) | Nonlinear | First-order linearization |
108+ | Unscented KF (UKF) | Nonlinear | Sigma-point propagation |
109+ | Particle filter (SMC) | Any | Non-Gaussian, multi-modal |
110+ | Hamilton filter | Switching | Markov regime models |
111+ | RTS smoother | Linear | Retrospective state estimation |
112+ | Kim smoother | Switching | Retrospective regime inference |
113+
65114## Parameter Estimation
66115
67- ``` python
68- import jax.numpy as jnp
69- from dynaris import LocalLevel
70- from dynaris.estimation import fit_mle
116+ | Method | Function | Description |
117+ | --------| ----------| -------------|
118+ | MLE | ` fit_mle() ` | Gradient-based via ` jax.grad ` + scipy |
119+ | EM | ` fit_em() ` | Expectation-Maximization for variances |
120+ | Bayesian | ` fit_bayesian() ` | NUTS/HMC via NumPyro |
121+ | DFM-EM | ` fit_dfm_em() ` | EM with loading matrix updates |
71122
72- def model_fn (params ):
73- return LocalLevel(
74- sigma_level = jnp.exp(params[0 ]),
75- sigma_obs = jnp.exp(params[1 ]),
76- )
123+ ## Built-in Nonlinear Models
77124
78- result = fit_mle(model_fn, y, init_params = jnp.zeros(2 ))
79- print (f " Log-likelihood: { result.log_likelihood:.2f } " )
80- ```
125+ | Model | Description |
126+ | -------| -------------|
127+ | ` StochasticVolatility ` | AR(1) log-volatility (KSC linearization) |
128+ | ` BearingsTracking ` | 2D constant-velocity target, bearing observations |
129+ | ` LorenzAttractor ` | Chaotic 3D system (Euler discretization) |
81130
82131## Datasets
83132
@@ -90,6 +139,15 @@ print(f"Log-likelihood: {result.log_likelihood:.2f}")
90139| Global temperature | ` load_temperature() ` | 144 | Annual | Climate |
91140| US GDP growth | ` load_gdp() ` | 319 | Quarterly | Economics |
92141
142+ ## Performance
143+
144+ All filters run inside ` jax.lax.scan ` with ` @jax.jit ` -- GPU/TPU acceleration is automatic. Additional features:
145+
146+ - ** Batch processing** via ` jax.vmap ` for parallel multi-series inference
147+ - ** Memory-efficient** long series via ` jax.checkpoint ` (trade compute for memory)
148+ - ** Parallel MCMC** chains via NumPyro's ` chain_method="parallel" `
149+ - ** Pure NumPy backend** for lightweight / no-GPU environments
150+
93151## License
94152
95153MIT License. See [ LICENSE] ( LICENSE ) for details.
0 commit comments