Skip to content

Commit f9289a2

Browse files
committed
Refactoring code for readability.
Former-commit-id: e62f4ee [formerly e62f4ee [formerly 16ac32e]] Former-commit-id: 3dc300030e7d4db27f3dc5c7c4453bf43984a9af Former-commit-id: 1353b1e
1 parent b800124 commit f9289a2

File tree

6 files changed

+198
-624
lines changed

6 files changed

+198
-624
lines changed

python/example1-lgss.py

Lines changed: 21 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,51 @@
1-
##############################################################################
2-
#
3-
# Example of state estimation in a LGSS model
4-
# using particle filters and Kalman filters
5-
#
6-
# Copyright (C) 2017 Johan Dahlin < liu (at) johandahlin.com.nospam >
7-
#
8-
# This program is free software; you can redistribute it and/or modify
9-
# it under the terms of the GNU General Public License as published by
10-
# the Free Software Foundation; either version 2 of the License, or
11-
# (at your option) any later version.
12-
#
13-
# This program is distributed in the hope that it will be useful,
14-
# but WITHOUT ANY WARRANTY; without even the implied warranty of
15-
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16-
# GNU General Public License for more details.
17-
#
18-
# You should have received a copy of the GNU General Public License along
19-
# with this program; if not, write to the Free Software Foundation, Inc.,
20-
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21-
#
22-
##############################################################################
1+
# State estimation in a LGSS model using particle and Kalman filters
232

243
from __future__ import print_function, division
25-
26-
# Import libraries
274
import matplotlib.pylab as plt
285
import numpy as np
296

30-
# Import helpers
317
from helpers.dataGeneration import generateData
328
from helpers.stateEstimation import particleFilter, kalmanFilter
339

3410
# Set the random seed to replicate results in tutorial
3511
np.random.seed(10)
3612

37-
38-
#=============================================================================
3913
# Define the model
40-
#=============================================================================
41-
42-
# Here, we use the following model
43-
#
44-
# x[t + 1] = phi * x[t] + sigmav * v[t]
45-
# y[t] = x[t] + sigmae * e[t]
46-
#
47-
# where v[t] ~ N(0, 1) and e[t] ~ N(0, 1)
48-
49-
# Set the parameters of the model (phi, sigmav, sigmae)
50-
theta = np.zeros(3)
51-
theta[0] = 0.75
52-
theta[1] = 1.00
53-
theta[2] = 0.10
54-
55-
# Set the number of time steps to simulate
56-
T = 250
57-
58-
# Set the initial state
14+
# x[t + 1] = phi * x[t] + sigmav * v[t], v[t] ~ N(0, 1)
15+
# y[t] = x[t] + sigmae * e[t], e[t] ~ N(0, 1)
16+
17+
# Set the parameters of the model theta=(phi, sigmav, sigmae), T, x_0
18+
parameters = np.zeros(3)
19+
parameters[0] = 0.75
20+
parameters[1] = 1.00
21+
parameters[2] = 0.10
22+
noObservations = 250
5923
initialState = 0
6024

61-
62-
#=============================================================================
6325
# Generate data
64-
#=============================================================================
65-
66-
x, y = generateData(theta, T, initialState)
26+
state, observations = generateData(parameters, noObservations, initialState)
6727

68-
# Plot the measurement
28+
# Plot data
6929
plt.subplot(3, 1, 1)
70-
plt.plot(y, color='#1B9E77', linewidth=1.5)
30+
plt.plot(observations, color='#1B9E77', linewidth=1.5)
7131
plt.xlabel("time")
7232
plt.ylabel("measurement")
7333

74-
# Plot the states
7534
plt.subplot(3, 1, 2)
76-
plt.plot(x, color='#D95F02', linewidth=1.5)
35+
plt.plot(state, color='#D95F02', linewidth=1.5)
7736
plt.xlabel("time")
7837
plt.ylabel("latent state")
7938

39+
# State estimation using particle filter with 100 particles
40+
xHatPF, _ = particleFilter(observations, parameters, 100, initialState)
8041

81-
#=============================================================================
82-
# State estimation
83-
#=============================================================================
84-
85-
# Using N = 100 particles and plot the estimate of the latent state
86-
xHatFilteredParticleFilter, _ = particleFilter(y, theta, 100, initialState)
87-
88-
# Using the Kalman filter
89-
xHatFilteredKalmanFilter = kalmanFilter(y, theta, initialState, 0.01)
42+
# State estimation using the Kalman filter
43+
xHatKF = kalmanFilter(observations, parameters, initialState, 0.01)
9044

45+
# Plot state estimate
9146
plt.subplot(3, 1, 3)
92-
plt.plot(xHatFilteredKalmanFilter[1:T] - xHatFilteredParticleFilter[0:T-1], color='#7570B3', linewidth=1.5)
47+
plt.plot(xHatKF[1:noObservations] - xHatPF[0:noObservations-1], color='#7570B3', linewidth=1.5)
9348
plt.xlabel("time")
9449
plt.ylabel("difference in estimate")
95-
plt.show()
9650

97-
##############################################################################
98-
# End of file
99-
##############################################################################
51+
plt.show()

python/example2-lgss.py

Lines changed: 30 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,130 +1,72 @@
1-
##############################################################################
2-
#
3-
# Example of particle Metropolis-Hastings in a LGSS model.
4-
#
5-
# Copyright (C) 2017 Johan Dahlin < liu (at) johandahlin.com.nospam >
6-
#
7-
# This program is free software; you can redistribute it and/or modify
8-
# it under the terms of the GNU General Public License as published by
9-
# the Free Software Foundation; either version 2 of the License, or
10-
# (at your option) any later version.
11-
#
12-
# This program is distributed in the hope that it will be useful,
13-
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14-
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15-
# GNU General Public License for more details.
16-
#
17-
# You should have received a copy of the GNU General Public License along
18-
# with this program; if not, write to the Free Software Foundation, Inc.,
19-
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
20-
#
21-
##############################################################################
1+
# Parameter estimation using particle Metropolis-Hastings in a LGSS model.
222

233
from __future__ import print_function, division
24-
25-
# Import libraries
264
import matplotlib.pylab as plt
275
import numpy as np
286

29-
# Import helpers
307
from helpers.dataGeneration import generateData
318
from helpers.stateEstimation import particleFilter, kalmanFilter
329
from helpers.parameterEstimation import particleMetropolisHastings
3310

34-
# Set the random seed
11+
# Set the random seed to replicate results in tutorial
3512
np.random.seed(10)
3613

37-
38-
#=============================================================================
3914
# Define the model
40-
#=============================================================================
41-
42-
# Here, we use the following model
43-
#
44-
# x[t + 1] = phi * x[t] + sigmav * v[t]
45-
# y[t] = x[t] + sigmae * e[t]
46-
#
47-
# where v[t] ~ N(0, 1) and e[t] ~ N(0, 1)
48-
49-
# Set the parameters of the model (phi, sigmav, sigmae)
50-
theta = np.zeros(3)
51-
theta[0] = 0.75
52-
theta[1] = 1.00
53-
theta[2] = 0.10
54-
55-
# Set the number of time steps to simulate
56-
T = 250
57-
58-
# Set the initial state
15+
# x[t + 1] = phi * x[t] + sigmav * v[t], v[t] ~ N(0, 1)
16+
# y[t] = x[t] + sigmae * e[t], e[t] ~ N(0, 1)
17+
18+
# Set the parameters of the model theta = (phi, sigmav, sigmae), T, x_0
19+
parameters = np.zeros(3)
20+
parameters[0] = 0.75
21+
parameters[1] = 1.00
22+
parameters[2] = 0.10
23+
noObservations = 250
5924
initialState = 0
6025

61-
62-
#=============================================================================
63-
# Generate data
64-
#=============================================================================
65-
66-
x, y = generateData(theta, T, initialState)
67-
68-
69-
#=============================================================================
70-
# Parameter estimation using PMH
71-
#=============================================================================
72-
73-
# The inital guess of the parameter
26+
# Settings for PMH
7427
initialPhi = 0.50
75-
76-
# No. particles in the particle filter ( choose noParticles ~ T )
77-
noParticles = 500
78-
79-
# The length of the burn-in and the no. iterations of PMH
80-
# ( noBurnInIterations < noIterations )
28+
noParticles = 500 # Use noParticles ~ noObservations
8129
noBurnInIterations = 1000
8230
noIterations = 5000
83-
84-
# The standard deviation in the random walk proposal
8531
stepSize = 0.10
8632

87-
# Run the PMH algorithm
88-
res = particleMetropolisHastings(y, initialPhi, theta, noParticles, initialState,
89-
particleFilter, noIterations, stepSize)
33+
# Generate data
34+
state, observations = generateData(parameters, noObservations, initialState)
9035

36+
# Run the PMH algorithm
37+
phiTrace = particleMetropolisHastings(
38+
observations, initialPhi, parameters, noParticles,
39+
initialState, particleFilter, noIterations, stepSize)
9140

92-
#=============================================================================
9341
# Plot the results
94-
#=============================================================================
95-
9642
noBins = int(np.floor(np.sqrt(noIterations - noBurnInIterations)))
9743
grid = np.arange(noBurnInIterations, noIterations, 1)
98-
resPhi = res[noBurnInIterations:noIterations]
44+
phiTrace = phiTrace[noBurnInIterations:noIterations]
9945

100-
# Plot the parameter posterior estimate
101-
# Solid black line indicate posterior mean
46+
# Plot the parameter posterior estimate (solid black line = posterior mean)
10247
plt.subplot(3, 1, 1)
103-
plt.hist(resPhi, noBins, normed=1, facecolor='#7570B3')
48+
plt.hist(phiTrace, noBins, normed=1, facecolor='#7570B3')
10449
plt.xlabel("phi")
10550
plt.ylabel("posterior density estimate")
106-
plt.axvline(np.mean(resPhi), color='k')
51+
plt.axvline(np.mean(phiTrace), color='k')
10752

108-
# Plot the trace of the Markov chain after burn-in
109-
# Solid black line indicate posterior mean
53+
# Plot the trace of the Markov chain after burn-in (solid black line = posterior mean)
11054
plt.subplot(3, 1, 2)
111-
plt.plot(grid, resPhi, color='#7570B3')
55+
plt.plot(grid, phiTrace, color='#7570B3')
11256
plt.xlabel("iteration")
11357
plt.ylabel("phi")
114-
plt.axhline(np.mean(resPhi), color='k')
58+
plt.axhline(np.mean(phiTrace), color='k')
11559

11660
# Plot the autocorrelation function
11761
plt.subplot(3, 1, 3)
118-
macf = np.correlate(resPhi - np.mean(resPhi), resPhi - np.mean(resPhi), mode='full')
119-
macf = macf[macf.size/2:]
62+
macf = np.correlate(phiTrace - np.mean(phiTrace), phiTrace - np.mean(phiTrace), mode='full')
63+
idx = int(macf.size/2)
64+
macf = macf[idx:]
12065
macf = macf[0:100]
12166
macf /= macf[0]
12267
grid = range(len(macf))
12368
plt.plot(grid, macf, color='#7570B3')
12469
plt.xlabel("lag")
12570
plt.ylabel("ACF of phi")
12671

127-
128-
##############################################################################
129-
# End of file
130-
##############################################################################
72+
plt.show()

0 commit comments

Comments (0)