Skip to content

Commit 2de830c

Browse files
committed
debugged versatile dataset lead/lag, added unit test.
1 parent 74f5aa8 commit 2de830c

2 files changed

Lines changed: 173 additions & 18 deletions

File tree

src/main/java/org/encog/ml/data/versatile/MatrixMLDataSet.java

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,22 @@ public long getRecordCount() {
147147
}
148148
return this.mask.length - (this.lagWindowSize + this.leadWindowSize);
149149
}
150+
151+
private int calculateLagCount() {
152+
return (MatrixMLDataSet.this.lagWindowSize <= 0) ? 1: (this.lagWindowSize+1);
153+
}
154+
155+
private int calculateLeadCount() {
156+
return (this.leadWindowSize <= 1) ? 1
157+
: this.leadWindowSize;
158+
}
150159

151160
@Override
152161
public void getRecord(long index, MLDataPair pair) {
153162

154163
// Copy the input, account for time windows.
155-
for (int i = 0; i < MatrixMLDataSet.this.lagWindowSize; i++) {
164+
int inputSize = calculateLagCount();
165+
for (int i = 0; i < inputSize; i++) {
156166
double[] dataRow = lookupDataRow((int) (index + i));
157167

158168
EngineArray.arrayCopy(dataRow, 0, pair.getInput().getData(), i
@@ -161,24 +171,14 @@ public void getRecord(long index, MLDataPair pair) {
161171
}
162172

163173
// Copy the output, account for time windows.
164-
int start = (MatrixMLDataSet.this.leadWindowSize > 0) ? 1 : 0;
165-
int size = (MatrixMLDataSet.this.leadWindowSize <= 1) ? 1
166-
: MatrixMLDataSet.this.leadWindowSize;
167-
for (int i = start; i < size; i++) {
168-
double[] dataRow = lookupDataRow((int) (index + i));
169-
EngineArray.arrayCopy(dataRow, 0, pair.getIdealArray(), i
174+
int outputStart = (this.leadWindowSize > 0) ? 1 : 0;
175+
int outputSize = calculateLeadCount();
176+
for (int i = 0; i < outputSize; i++) {
177+
double[] dataRow = lookupDataRow((int) (index + i+outputStart));
178+
EngineArray.arrayCopy(dataRow, this.calculatedInputSize, pair.getIdealArray(), i
170179
* MatrixMLDataSet.this.calculatedIdealSize,
171180
MatrixMLDataSet.this.calculatedIdealSize);
172181
}
173-
174-
double[] dataRow = lookupDataRow((int) index);
175-
176-
EngineArray.arrayCopy(dataRow, 0, pair.getInputArray(), 0,
177-
MatrixMLDataSet.this.calculatedInputSize);
178-
EngineArray.arrayCopy(dataRow,
179-
MatrixMLDataSet.this.calculatedInputSize, pair.getIdealArray(),
180-
0, MatrixMLDataSet.this.calculatedIdealSize);
181-
182182
}
183183

184184
private double[] lookupDataRow(int index) {
@@ -229,8 +229,19 @@ public int size() {
229229

230230
@Override
231231
public MLDataPair get(int index) {
232-
// TODO Auto-generated method stub
233-
return null;
232+
if (index>size()) {
233+
return null;
234+
}
235+
236+
BasicMLData input = new BasicMLData(
237+
MatrixMLDataSet.this.calculatedInputSize*calculateLagCount());
238+
BasicMLData ideal = new BasicMLData(
239+
MatrixMLDataSet.this.calculatedIdealSize*calculateLeadCount());
240+
MLDataPair pair = new BasicMLDataPair(input, ideal);
241+
242+
MatrixMLDataSet.this.getRecord(index, pair);
243+
244+
return pair;
234245
}
235246

236247
/**
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package org.encog.ml.data.versatile;
2+
3+
import junit.framework.Assert;
4+
5+
import org.encog.Encog;
6+
import org.encog.ml.data.MLDataPair;
7+
import org.junit.Test;
8+
9+
public class TestMatrixMLDataSet {
10+
11+
public static final double[][] DATA1 = {
12+
{ 1.0, 10.0 },
13+
{ 2.0, 20.0 },
14+
{ 3.0, 30.0 },
15+
{ 4.0, 40.0 },
16+
{ 5.0, 50.0 },
17+
{ 6.0, 60.0 },
18+
{ 7.0, 70.0 },
19+
{ 8.0, 80.0 },
20+
{ 9.0, 90.0 },
21+
{ 10.0, 100.0 }
22+
};
23+
24+
@Test
25+
public void testTimeSeriesLead1Lag0() {
26+
MatrixMLDataSet dset = new MatrixMLDataSet(DATA1,1,1);
27+
dset.setLeadWindowSize(1);
28+
29+
Assert.assertEquals(9, dset.size());
30+
31+
MLDataPair p1 = dset.get(0);
32+
Assert.assertEquals(1.0, p1.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
33+
Assert.assertEquals(20.0, p1.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
34+
35+
MLDataPair p2 = dset.get(1);
36+
Assert.assertEquals(2.0, p2.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
37+
Assert.assertEquals(30.0, p2.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
38+
39+
MLDataPair p3 = dset.get(2);
40+
Assert.assertEquals(3.0, p3.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
41+
Assert.assertEquals(40.0, p3.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
42+
}
43+
44+
@Test
45+
public void testTimeSeriesLead0Lag1() {
46+
MatrixMLDataSet dset = new MatrixMLDataSet(DATA1,1,1);
47+
dset.setLagWindowSize(1);
48+
49+
Assert.assertEquals(9, dset.size());
50+
51+
MLDataPair p1 = dset.get(0);
52+
Assert.assertEquals(1.0, p1.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
53+
Assert.assertEquals(2.0, p1.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
54+
Assert.assertEquals(10.0, p1.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
55+
56+
MLDataPair p2 = dset.get(1);
57+
Assert.assertEquals(2.0, p2.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
58+
Assert.assertEquals(3.0, p2.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
59+
Assert.assertEquals(20.0, p2.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
60+
61+
MLDataPair p3 = dset.get(2);
62+
Assert.assertEquals(3.0, p3.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
63+
Assert.assertEquals(4.0, p3.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
64+
Assert.assertEquals(30.0, p3.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
65+
}
66+
67+
@Test
68+
public void testTimeSeriesLead1Lag1() {
69+
MatrixMLDataSet dset = new MatrixMLDataSet(DATA1,1,1);
70+
dset.setLeadWindowSize(1);
71+
dset.setLagWindowSize(1);
72+
73+
Assert.assertEquals(8, dset.size());
74+
75+
MLDataPair p1 = dset.get(0);
76+
Assert.assertEquals(1.0, p1.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
77+
Assert.assertEquals(2.0, p1.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
78+
Assert.assertEquals(20.0, p1.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
79+
80+
MLDataPair p2 = dset.get(1);
81+
Assert.assertEquals(2.0, p2.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
82+
Assert.assertEquals(3.0, p2.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
83+
Assert.assertEquals(30.0, p2.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
84+
85+
MLDataPair p3 = dset.get(2);
86+
Assert.assertEquals(3.0, p3.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
87+
Assert.assertEquals(4.0, p3.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
88+
Assert.assertEquals(40.0, p3.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
89+
}
90+
91+
@Test
92+
public void testTimeSeriesLead2Lag1() {
93+
MatrixMLDataSet dset = new MatrixMLDataSet(DATA1,1,1);
94+
dset.setLeadWindowSize(2);
95+
dset.setLagWindowSize(1);
96+
97+
Assert.assertEquals(7, dset.size());
98+
99+
MLDataPair p1 = dset.get(0);
100+
Assert.assertEquals(1.0, p1.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
101+
Assert.assertEquals(2.0, p1.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
102+
Assert.assertEquals(20.0, p1.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
103+
Assert.assertEquals(30.0, p1.getIdeal().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
104+
105+
MLDataPair p2 = dset.get(1);
106+
Assert.assertEquals(2.0, p2.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
107+
Assert.assertEquals(3.0, p2.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
108+
Assert.assertEquals(30.0, p2.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
109+
Assert.assertEquals(40.0, p2.getIdeal().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
110+
111+
MLDataPair p3 = dset.get(2);
112+
Assert.assertEquals(3.0, p3.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
113+
Assert.assertEquals(4.0, p3.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
114+
Assert.assertEquals(40.0, p3.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
115+
Assert.assertEquals(50.0, p3.getIdeal().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
116+
}
117+
118+
@Test
119+
public void testTimeSeriesLead1Lag2() {
120+
MatrixMLDataSet dset = new MatrixMLDataSet(DATA1,1,1);
121+
dset.setLeadWindowSize(1);
122+
dset.setLagWindowSize(2);
123+
124+
Assert.assertEquals(7, dset.size());
125+
126+
MLDataPair p1 = dset.get(0);
127+
Assert.assertEquals(1.0, p1.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
128+
Assert.assertEquals(2.0, p1.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
129+
Assert.assertEquals(3.0, p1.getInput().getData(2),Encog.DEFAULT_DOUBLE_EQUAL);
130+
Assert.assertEquals(20.0, p1.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
131+
132+
MLDataPair p2 = dset.get(1);
133+
Assert.assertEquals(2.0, p2.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
134+
Assert.assertEquals(3.0, p2.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
135+
Assert.assertEquals(4.0, p2.getInput().getData(2),Encog.DEFAULT_DOUBLE_EQUAL);
136+
Assert.assertEquals(30.0, p2.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
137+
138+
MLDataPair p3 = dset.get(2);
139+
Assert.assertEquals(3.0, p3.getInput().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
140+
Assert.assertEquals(4.0, p3.getInput().getData(1),Encog.DEFAULT_DOUBLE_EQUAL);
141+
Assert.assertEquals(5.0, p3.getInput().getData(2),Encog.DEFAULT_DOUBLE_EQUAL);
142+
Assert.assertEquals(40.0, p3.getIdeal().getData(0),Encog.DEFAULT_DOUBLE_EQUAL);
143+
}
144+
}

0 commit comments

Comments
 (0)