-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexample.py
More file actions
53 lines (42 loc) · 1.64 KB
/
example.py
File metadata and controls
53 lines (42 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from cwt_learner.wavelet_feature_engineering import CWT_learner
from signal_data_base import SignalDB
import matplotlib.pyplot as plt
from plot_generator import plotResult_colorbars
"""
Usage: Create a CWT_learner.
Add training data using the add_training_data member function.
Call the train member function to train with the training data.
Test with arbitrary data using the fit member function.
"""
if __name__ == "__main__":
    # Example driver: load labeled signals, train a CWT_learner on the first
    # eight samples, label one held-out sample, and plot both sets.
    sdb = SignalDB('JLego', path='./sample_data/')
    all_labeled_data = sdb.get_labeleddata()

    cwt_learn = CWT_learner(signal_indices=[0, 1, 2, 3])

    # First eight samples are used for training, the next two are held out.
    training_data = all_labeled_data[0:8]
    testing_data = all_labeled_data[8:10]

    for sample in training_data:
        # Keep only the first space-separated token of every label.
        sample_labels = [label.split(' ')[0] for label in sample.labels]
        cwt_learn.add_training_data(sample.signal_bundle.signals, sample_labels)
    cwt_learn.train()

    # Only the first held-out sample is labeled and plotted below.
    predicted_labels = cwt_learn.fit(testing_data[0].signal_bundle.signals)

    # Plotting: each training sample takes two stacked rows, the first signal
    # of the bundle on top and its label color bar underneath.
    n_rows = 2 * len(training_data)
    plt.figure()
    plt.subplot(n_rows, 1, 1)
    plt.title("Training Data")
    for idx, sample in enumerate(training_data):
        ax = plt.subplot(n_rows, 1, 2 * idx + 1)
        plt.plot(sample.signal_bundle.signals[0])
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        plt.subplot(n_rows, 1, 2 * idx + 2)
        plotResult_colorbars(sample.labels, range(len(sample.labels)))

    # Second figure: the held-out signal above the labels returned by fit().
    plt.figure()
    ax = plt.subplot(2, 1, 1)
    plt.title("Test Example")
    plt.plot(testing_data[0].signal_bundle.signals[0])
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
    plt.subplot(2, 1, 2)
    plotResult_colorbars(predicted_labels, range(len(predicted_labels)))
    plt.show()