Skip to content

Commit 7bf0544

Browse files
committed
add a demo for of swt variance partitioning when norm=True
1 parent a6da076 commit 7bf0544

2 files changed

Lines changed: 63 additions & 1 deletion

File tree

demo/dwt_swt_show_coeffs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def plot_coeffs(data, w, title, use_dwt=True):
7474
use_dwt)
7575
plot_coeffs(ecg, 'sym5', "DWT: Ecg sample - Symmlets5", use_dwt)
7676

77-
# Show DWT coefficients
77+
# Show SWT coefficients
7878
use_dwt = False
7979
plot_coeffs(data1, 'db1', "SWT: Signal irregularity detection - Haar wavelet",
8080
use_dwt)

demo/swt_variance.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#!/usr/bin/env python
2+
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
6+
import pywt
7+
import pywt.data
8+
9+
10+
ecg = pywt.data.ecg()
11+
12+
# set trim_approx to avoid keeping approximation coefficients for all levels
13+
14+
# set norm=True to rescale the wavelets so that the transform partitions the
15+
# variance of the input signal among the various coefficient arrays.
16+
17+
coeffs = pywt.swt(ecg, wavelet='sym4', trim_approx=True, norm=True)
18+
ca = coeffs[0]
19+
details = coeffs[1:]
20+
21+
print("Variance of the ecg signal = {}".format(np.var(ecg, ddof=1)))
22+
23+
24+
variances = [np.var(c, ddof=1) for c in coeffs]
25+
detail_variances = variances[1:]
26+
print("Sum of variance across all SWT coefficients = {}".format(
27+
np.sum(variances)))
28+
29+
30+
# Create a plot using the same y axis limits for all coefficient arrays to
31+
# illustrate the preservation of amplitude scale across levels when norm=True.
32+
ylim = [ecg.min(), ecg.max()]
33+
34+
fig, axes = plt.subplots(len(coeffs) + 1)
35+
axes[0].set_title("normalized SWT decomposition")
36+
axes[0].plot(ecg)
37+
axes[0].set_ylabel('ECG Signal')
38+
axes[0].set_xlim(0, len(ecg) - 1)
39+
axes[0].set_ylim(ylim[0], ylim[1])
40+
41+
for i, x in enumerate(coeffs):
42+
ax = axes[-i - 1]
43+
ax.plot(coeffs[i], 'g')
44+
if i == 0:
45+
ax.set_ylabel("A%d" % (len(coeffs) - 1))
46+
else:
47+
ax.set_ylabel("D%d" % (len(coeffs) - i))
48+
# Scale axes
49+
ax.set_xlim(0, len(ecg) - 1)
50+
ax.set_ylim(ylim[0], ylim[1])
51+
52+
53+
# reorder from first to last level of coefficients
54+
level = np.arange(1, len(detail_variances) + 1)
55+
56+
# create a plot of the variance as a function of level
57+
plt.figure(figsize=(8, 6))
58+
fontdict = dict(fontsize=16, fontweight='bold')
59+
plt.plot(level, detail_variances[::-1], 'k.')
60+
plt.xlabel("Decomposition level", fontdict=fontdict)
61+
plt.ylabel("Variance", fontdict=fontdict)
62+
plt.title("Variances of detail coefficients", fontdict=fontdict)

0 commit comments

Comments
 (0)