Skip to content

Commit c31f8ad

Browse files
committed
Working on EncogModel & time series
1 parent 2de830c commit c31f8ad

7 files changed

Lines changed: 83 additions & 21 deletions

File tree

src/main/java/org/encog/ml/data/versatile/CSVDataSource.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package org.encog.ml.data.versatile;
22

33
import java.io.File;
4+
import java.util.HashMap;
5+
import java.util.Map;
46

57
import org.encog.EncogError;
68
import org.encog.util.csv.CSVFormat;
@@ -12,6 +14,7 @@ public class CSVDataSource implements VersatileDataSource {
1214
private final File file;
1315
private final boolean headers;
1416
private final CSVFormat format;
17+
private final Map<String,Integer> headerIndex = new HashMap<String,Integer>();
1518

1619
/**
1720
* Construct a CSV reader from a filename. The format parameter specifies
@@ -72,6 +75,20 @@ public String[] readLine() {
7275
@Override
7376
public void rewind() {
7477
this.reader = new ReadCSV(this.file,this.headers,this.format);
78+
if( this.headerIndex.size()==0 ) {
79+
for(int i=0;i<this.reader.getColumnNames().size();i++) {
80+
this.headerIndex.put(this.reader.getColumnNames().get(i), i);
81+
}
82+
}
83+
}
84+
85+
@Override
86+
public int columnIndex(String name) {
87+
String name2 = name.toLowerCase();
88+
if(!this.headerIndex.containsKey(name2)) {
89+
return -1;
90+
}
91+
return this.headerIndex.get(name2);
7592
}
7693

7794
}

src/main/java/org/encog/ml/data/versatile/ColumnDefinition.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ public class ColumnDefinition {
1414
private double mean;
1515
private double sd;
1616
private int count;
17+
private int index;
1718
private final List<String> classes = new ArrayList<String>();
1819
private NormalizationHelper owner;
1920

@@ -212,4 +213,20 @@ public void defineClass(String[] str) {
212213
public void setOwner(NormalizationHelper theOwner) {
213214
this.owner = theOwner;
214215
}
216+
217+
/**
218+
* @return the index
219+
*/
220+
public int getIndex() {
221+
return index;
222+
}
223+
224+
/**
225+
* @param index the index to set
226+
*/
227+
public void setIndex(int index) {
228+
this.index = index;
229+
}
230+
231+
215232
}

src/main/java/org/encog/ml/data/versatile/MatrixMLDataSet.java

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,7 @@ public final MLDataPair next() {
6060
return null;
6161
}
6262

63-
BasicMLData input = new BasicMLData(
64-
MatrixMLDataSet.this.calculatedInputSize);
65-
BasicMLData ideal = new BasicMLData(
66-
MatrixMLDataSet.this.calculatedIdealSize);
67-
MLDataPair pair = new BasicMLDataPair(input, ideal);
68-
69-
MatrixMLDataSet.this.getRecord(this.currentIndex, pair);
70-
71-
this.currentIndex++;
72-
73-
return pair;
63+
return MatrixMLDataSet.this.get(this.currentIndex++);
7464
}
7565

7666
/**
@@ -141,6 +131,10 @@ public boolean isSupervised() {
141131

142132
@Override
143133
public long getRecordCount() {
134+
if( this.data==null ) {
135+
throw new EncogError("You must normalize the dataset before using it.");
136+
}
137+
144138
if (this.mask == null) {
145139
return this.data.length
146140
- (this.lagWindowSize + this.leadWindowSize);
@@ -159,7 +153,10 @@ private int calculateLeadCount() {
159153

160154
@Override
161155
public void getRecord(long index, MLDataPair pair) {
162-
156+
if( this.data==null ) {
157+
throw new EncogError("You must normalize the dataset before using it.");
158+
}
159+
163160
// Copy the input, account for time windows.
164161
int inputSize = calculateLagCount();
165162
for (int i = 0; i < inputSize; i++) {

src/main/java/org/encog/ml/data/versatile/NormalizationHelper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ public void addSourceColumn(ColumnDefinition def) {
7474
def.setOwner(this);
7575
}
7676

77-
public ColumnDefinition defineSourceColumn(String name, ColumnType colType) {
77+
public ColumnDefinition defineSourceColumn(String name, int index, ColumnType colType) {
7878
ColumnDefinition result = new ColumnDefinition(name,colType);
79+
result.setIndex(index);
7980
addSourceColumn(result);
8081
return result;
8182
}

src/main/java/org/encog/ml/data/versatile/VersatileDataSource.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
public interface VersatileDataSource {
44
String[] readLine();
55
void rewind();
6+
int columnIndex(String name);
67
}

src/main/java/org/encog/ml/data/versatile/VersatileMLDataSet.java

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,21 @@ public class VersatileMLDataSet extends MatrixMLDataSet {
1515
public VersatileMLDataSet(VersatileDataSource theSource) {
1616
this.source = theSource;
1717
}
18+
19+
private int findIndex(ColumnDefinition colDef) {
20+
if( colDef.getIndex()!=-1 ) {
21+
return colDef.getIndex();
22+
}
23+
24+
int index = this.source.columnIndex(colDef.getName());
25+
colDef.setIndex(index);
26+
27+
if( index==-1 ) {
28+
throw new EncogError("Can't find column");
29+
}
30+
31+
return index;
32+
}
1833

1934
public void analyze() {
2035
String[] line;
@@ -26,7 +41,8 @@ public void analyze() {
2641
c++;
2742
for (int i = 0; i < this.helper.getSourceColumns().size(); i++) {
2843
ColumnDefinition colDef = this.helper.getSourceColumns().get(i);
29-
String value = line[i];
44+
int index = findIndex(colDef);
45+
String value = line[index];
3046
colDef.analyze(value);
3147
}
3248
}
@@ -91,14 +107,14 @@ public void normalize() {
91107
while ((line = this.source.readLine()) != null) {
92108
int column = 0;
93109
for (ColumnDefinition colDef : this.helper.getInputColumns()) {
94-
int index = this.helper.getSourceColumns().indexOf(colDef);
110+
int index = findIndex(colDef);
95111
String value = line[index];
96112

97113
column = this.helper.normalizeToVector(colDef, column, getData()[row], true, value);
98114
}
99115

100116
for (ColumnDefinition colDef : this.helper.getOutputColumns()) {
101-
int index = this.helper.getSourceColumns().indexOf(colDef);
117+
int index = findIndex(colDef);
102118
String value = line[index];
103119

104120
column = this.helper.normalizeToVector(colDef, column, getData()[row], false, value);
@@ -107,8 +123,8 @@ public void normalize() {
107123
}
108124
}
109125

110-
public ColumnDefinition defineSourceColumn(String name, ColumnType colType) {
111-
return this.helper.defineSourceColumn(name, colType);
126+
public ColumnDefinition defineSourceColumn(String name, int index, ColumnType colType) {
127+
return this.helper.defineSourceColumn(name, index, colType);
112128
}
113129

114130
/**
@@ -137,17 +153,30 @@ public void divide(List<DataDivision> dataDivisionList, boolean shuffle,
137153
getCalculatedIdealSize());
138154

139155
}
156+
157+
public void defineOutput(ColumnDefinition col) {
158+
this.helper.getOutputColumns().add(col);
159+
}
160+
161+
public void defineInput(ColumnDefinition col) {
162+
this.helper.getInputColumns().add(col);
163+
}
140164

141165
public void defineSingleOutputOthersInput(ColumnDefinition outputColumn) {
142166
this.helper.clearInputOutput();
143167

144168
for (ColumnDefinition colDef : this.helper.getSourceColumns()) {
145169
if (colDef == outputColumn) {
146-
this.helper.getOutputColumns().add(colDef);
170+
defineOutput(colDef);
147171
} else if(colDef.getDataType()!=ColumnType.ignore) {
148-
this.helper.getInputColumns().add(colDef);
172+
defineInput(colDef);
149173
}
150174
}
151175
}
152176

177+
public ColumnDefinition defineSourceColumn(String name,
178+
ColumnType colType) {
179+
return this.helper.defineSourceColumn(name, -1, colType);
180+
}
181+
153182
}

src/main/java/org/encog/ml/train/strategy/end/SimpleEarlyStoppingStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ public void postIteration() {
110110

111111
double currentValidationError = this.calc.calculateError(this.validationSet);
112112

113-
if( currentValidationError>this.lastValidationError ) {
113+
if( currentValidationError>=this.lastValidationError ) {
114114
stop = true;
115115
}
116116

0 commit comments

Comments
 (0)