Skip to content

Commit 11bb5ac

Browse files
committed
Fixed issue with TemporalMLData set's index.
1 parent 14ab58b commit 11bb5ac

1 file changed

Lines changed: 14 additions & 7 deletions

File tree

src/main/java/org/encog/ml/data/temporal/TemporalMLDataSet.java

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.encog.ml.data.MLDataPair;
3434
import org.encog.ml.data.basic.BasicMLData;
3535
import org.encog.ml.data.basic.BasicMLDataPair;
36+
import org.encog.ml.data.temporal.TemporalDataDescription.Type;
3637
import org.encog.neural.data.basic.BasicNeuralData;
3738
import org.encog.neural.data.basic.BasicNeuralDataSet;
3839
import org.encog.util.time.TimeSpan;
@@ -329,7 +330,14 @@ public TemporalPoint createPoint(final int sequence) {
329330
private double formatData(final TemporalDataDescription desc,
330331
final int index) {
331332
final double[] result = new double[1];
332-
333+
334+
if( desc.getType()==Type.DELTA_CHANGE || desc.getType()==Type.PERCENT_CHANGE ) {
335+
if (index + this.inputWindowSize > this.points.size()) {
336+
throw new TemporalError("Can't generate input temporal data "
337+
+ "beyond the end of provided data.");
338+
}
339+
}
340+
333341
switch (desc.getType()) {
334342
case DELTA_CHANGE:
335343
result[0] = getDataDeltaChange(desc, index);
@@ -356,6 +364,7 @@ private double formatData(final TemporalDataDescription desc,
356364
*/
357365
public void generate() {
358366
sortPoints();
367+
// add one to the start index so we are "one ahead", needed to calculate DELTA, if that encoding is chosen.
359368
final int start = calculateStartIndex() + 1;
360369
final int setSize = calculateActualSetSize();
361370
final int range = start + setSize - this.predictWindowSize
@@ -379,11 +388,6 @@ public void generate() {
379388
* @return The input neural data generated.
380389
*/
381390
public BasicNeuralData generateInputNeuralData(final int index) {
382-
if (index + this.inputWindowSize > this.points.size()) {
383-
throw new TemporalError("Can't generate input temporal data "
384-
+ "beyond the end of provided data.");
385-
}
386-
387391
final BasicNeuralData result = new BasicNeuralData(
388392
this.inputNeuronCount);
389393
int resultIndex = 0;
@@ -487,7 +491,10 @@ private double getDataPercentChange(final TemporalDataDescription desc,
487491
*/
488492
private double getDataRAW(final TemporalDataDescription desc,
489493
final int index) {
490-
final TemporalPoint point = this.points.get(index);
494+
// Note: The reason that we subtract 1 from the index is because we are always one ahead.
495+
// This allows the DELTA change formatter to work. DELTA change requires two timeslices,
496+
// so we have to be one ahead. RAW only requires one, so we shift backwards.
497+
final TemporalPoint point = this.points.get(index-1);
491498
return point.getData(desc.getIndex());
492499
}
493500

0 commit comments

Comments
 (0)