Skip to content

Commit b7a5876

Browse files
committed
Javadoc
1 parent 71b272e commit b7a5876

8 files changed

Lines changed: 240 additions & 6 deletions

File tree

src/main/java/org/encog/ml/model/EncogModel.java

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,70 @@
4141
*/
4242
public class EncogModel {
4343

44+
/**
45+
* The dataset to use.
46+
*/
4447
private final VersatileMLDataSet dataset;
48+
49+
/**
50+
* The input features.
51+
*/
4552
private final List<ColumnDefinition> inputFeatures = new ArrayList<ColumnDefinition>();
53+
54+
/**
55+
* The predicted features.
56+
*/
4657
private final List<ColumnDefinition> predictedFeatures = new ArrayList<ColumnDefinition>();
58+
59+
/**
60+
* The training dataset.
61+
*/
4762
private MatrixMLDataSet trainingDataset;
63+
64+
/**
65+
* The validation dataset.
66+
*/
4867
private MatrixMLDataSet validationDataset;
68+
69+
/**
70+
* The standard configrations for each method type.
71+
*/
4972
private final Map<String, MethodConfig> methodConfigurations = new HashMap<String, MethodConfig>();
73+
74+
/**
75+
* The current method configuration, determined by the selected model.
76+
*/
5077
private MethodConfig config;
78+
79+
/**
80+
* The selected method type.
81+
*/
5182
private String methodType;
83+
84+
/**
85+
* The method arguments for the selected method.
86+
*/
5287
private String methodArgs;
88+
89+
/**
90+
* The selected training type.
91+
*/
5392
private String trainingType;
93+
94+
/**
95+
* The training arguments for the selected training type.
96+
*/
5497
private String trainingArgs;
98+
99+
/**
100+
* The report.
101+
*/
55102
private StatusReportable report = new NullStatusReportable();
56103

104+
/**
105+
* Construct a model for the specified dataset.
106+
* @param theDataset The dataset.
107+
*/
57108
public EncogModel(VersatileMLDataSet theDataset) {
58109
this.dataset = theDataset;
59110
this.methodConfigurations.put(MLMethodFactory.TYPE_FEEDFORWARD,
@@ -89,6 +140,12 @@ public List<ColumnDefinition> getPredictedFeatures() {
89140
return predictedFeatures;
90141
}
91142

143+
/**
144+
* Specify a validation set to hold back.
145+
* @param validationPercent The percent to use for validation.
146+
* @param shuffle True to shuffle.
147+
* @param seed The seed for random generation.
148+
*/
92149
public void holdBackValidation(double validationPercent, boolean shuffle,
93150
int seed) {
94151
List<DataDivision> dataDivisionList = new ArrayList<DataDivision>();
@@ -100,7 +157,13 @@ public void holdBackValidation(double validationPercent, boolean shuffle,
100157
this.validationDataset = dataDivisionList.get(1).getDataset();
101158
}
102159

103-
public void fitFold(int k, int foldNum, DataFold fold) {
160+
/**
161+
* Fit the model using cross validation.
162+
* @param k The number of folds total.
163+
* @param foldNum The current fold.
164+
* @param fold The current fold.
165+
*/
166+
private void fitFold(int k, int foldNum, DataFold fold) {
104167
MLMethod method = this.createMethod();
105168
MLTrain train = this.createTrainer(method, fold.getTraining());
106169

@@ -143,6 +206,12 @@ public void fitFold(int k, int foldNum, DataFold fold) {
143206
}
144207
}
145208

209+
/**
210+
* Calculate the error for the given method and dataset.
211+
* @param method The method to use.
212+
* @param data The data to use.
213+
* @return The error.
214+
*/
146215
public double calculateError(MLMethod method, MLDataSet data) {
147216
if (this.dataset.getNormHelper().getOutputColumns().size() == 1) {
148217
ColumnDefinition cd = this.dataset.getNormHelper()
@@ -157,6 +226,12 @@ public double calculateError(MLMethod method, MLDataSet data) {
157226
data);
158227
}
159228

229+
/**
230+
* Create a trainer.
231+
* @param method The method to train.
232+
* @param dataset The dataset.
233+
* @return The trainer.
234+
*/
160235
private MLTrain createTrainer(MLMethod method, MLDataSet dataset) {
161236

162237
if (this.trainingType == null) {
@@ -169,6 +244,12 @@ private MLTrain createTrainer(MLMethod method, MLDataSet dataset) {
169244
return train;
170245
}
171246

247+
/**
248+
* Crossvalidate and fit.
249+
* @param k The number of folds.
250+
* @param shuffle True if we should shuffle.
251+
* @return The trained method.
252+
*/
172253
public MLMethod crossvalidate(int k, boolean shuffle) {
173254
KFoldCrossvalidation cross = new KFoldCrossvalidation(
174255
this.trainingDataset, k);
@@ -226,6 +307,14 @@ public void setValidationDataset(MatrixMLDataSet validationDataset) {
226307
this.validationDataset = validationDataset;
227308
}
228309

310+
/**
311+
* Select the method to use.
312+
* @param dataset The dataset.
313+
* @param methodType The type of method.
314+
* @param methodArgs The method arguments.
315+
* @param trainingType The training type.
316+
* @param trainingArgs The training arguments.
317+
*/
229318
public void selectMethod(VersatileMLDataSet dataset, String methodType,
230319
String methodArgs, String trainingType, String trainingArgs) {
231320

@@ -241,6 +330,10 @@ public void selectMethod(VersatileMLDataSet dataset, String methodType,
241330

242331
}
243332

333+
/**
334+
* Create the selected method.
335+
* @return The created method.
336+
*/
244337
public MLMethod createMethod() {
245338
if (this.methodType == null) {
246339
throw new EncogError(
@@ -253,6 +346,11 @@ public MLMethod createMethod() {
253346
return method;
254347
}
255348

349+
/**
350+
* Select the method to create.
351+
* @param dataset The dataset.
352+
* @param methodType The method type.
353+
*/
256354
public void selectMethod(VersatileMLDataSet dataset, String methodType) {
257355
if (!this.methodConfigurations.containsKey(methodType)) {
258356
throw new EncogError("Don't know how to autoconfig method: "
@@ -267,6 +365,10 @@ public void selectMethod(VersatileMLDataSet dataset, String methodType) {
267365

268366
}
269367

368+
/**
369+
* Select the training type.
370+
* @param dataset The dataset.
371+
*/
270372
public void selectTrainingType(VersatileMLDataSet dataset) {
271373
if (this.methodType == null) {
272374
throw new EncogError(
@@ -277,6 +379,12 @@ public void selectTrainingType(VersatileMLDataSet dataset) {
277379
config.suggestTrainingArgs(trainingType));
278380
}
279381

382+
/**
383+
* Select the training to use.
384+
* @param dataset The dataset.
385+
* @param trainingType The type of training.
386+
* @param trainingArgs The training arguments.
387+
*/
280388
public void selectTraining(VersatileMLDataSet dataset, String trainingType,
281389
String trainingArgs) {
282390
if (this.methodType == null) {

src/main/java/org/encog/ml/model/config/FeedforwardConfig.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,22 @@
88
import org.encog.ml.factory.MLMethodFactory;
99
import org.encog.neural.networks.BasicNetwork;
1010

11+
/**
12+
* Config class for EncogModel to use a feedforward neural network.
13+
*/
1114
public class FeedforwardConfig implements MethodConfig {
1215

16+
/**
17+
* {@inheritDoc}
18+
*/
1319
@Override
1420
public String getMethodName() {
1521
return MLMethodFactory.TYPE_FEEDFORWARD;
1622
}
1723

24+
/**
25+
* {@inheritDoc}
26+
*/
1827
@Override
1928
public String suggestModelArchitecture(VersatileMLDataSet dataset) {
2029
int inputColumns = dataset.getNormHelper().getInputColumns().size();
@@ -27,6 +36,9 @@ public String suggestModelArchitecture(VersatileMLDataSet dataset) {
2736
return result.toString();
2837
}
2938

39+
/**
40+
* {@inheritDoc}
41+
*/
3042
@Override
3143
public NormalizationStrategy suggestNormalizationStrategy(VersatileMLDataSet dataset, String architecture) {
3244
double inputLow = -1;
@@ -60,17 +72,26 @@ public NormalizationStrategy suggestNormalizationStrategy(VersatileMLDataSet dat
6072
}
6173

6274

75+
/**
76+
* {@inheritDoc}
77+
*/
6378
@Override
6479
public String suggestTrainingType() {
6580
return "rprop";
6681
}
6782

6883

84+
/**
85+
* {@inheritDoc}
86+
*/
6987
@Override
7088
public String suggestTrainingArgs(String trainingType) {
7189
return "";
7290
}
7391

92+
/**
93+
* {@inheritDoc}
94+
*/
7495
@Override
7596
public int determineOutputCount(VersatileMLDataSet dataset) {
7697
return dataset.getNormHelper().calculateNormalizedOutputCount();

src/main/java/org/encog/ml/model/config/MethodConfig.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,49 @@
33
import org.encog.ml.data.versatile.VersatileMLDataSet;
44
import org.encog.ml.data.versatile.normalizers.strategies.NormalizationStrategy;
55

6+
/**
7+
* Define normalization for a specific method.
8+
*/
69
public interface MethodConfig {
710

11+
/**
12+
* @return The method name.
13+
*/
814
String getMethodName();
915

16+
/**
17+
* Suggest a model architecture, based on a dataset.
18+
* @param dataset The dataset.
19+
* @return The model architecture.
20+
*/
1021
String suggestModelArchitecture(VersatileMLDataSet dataset);
1122

23+
/**
24+
* Suggest a normalization strategy based on a dataset.
25+
* @param dataset The dataset.
26+
* @param architecture The architecture.
27+
* @return The strategy.
28+
*/
1229
NormalizationStrategy suggestNormalizationStrategy(VersatileMLDataSet dataset, String architecture);
1330

31+
/**
32+
* Suggest a training type.
33+
* @return The training type.
34+
*/
1435
String suggestTrainingType();
1536

37+
/**
38+
* Suggest training arguments.
39+
* @param trainingType The training type.
40+
* @return The training arguments.
41+
*/
1642
String suggestTrainingArgs(String trainingType);
1743

44+
/**
45+
* Determine the needed output count.
46+
* @param dataset The dataset.
47+
* @return The needed output count.
48+
*/
1849
int determineOutputCount(VersatileMLDataSet dataset);
1950

2051
}

src/main/java/org/encog/ml/model/config/NEATConfig.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,30 @@
88
import org.encog.ml.data.versatile.normalizers.strategies.NormalizationStrategy;
99
import org.encog.ml.factory.MLMethodFactory;
1010

11+
/**
12+
* Config class for EncogModel to use a NEAT neural network.
13+
*/
1114
public class NEATConfig implements MethodConfig {
1215

16+
/**
17+
* {@inheritDoc}
18+
*/
1319
@Override
1420
public String getMethodName() {
1521
return MLMethodFactory.TYPE_NEAT;
1622
}
1723

24+
/**
25+
* {@inheritDoc}
26+
*/
1827
@Override
1928
public String suggestModelArchitecture(VersatileMLDataSet dataset) {
2029
return("cycles=4");
2130
}
2231

32+
/**
33+
* {@inheritDoc}
34+
*/
2335
@Override
2436
public NormalizationStrategy suggestNormalizationStrategy(VersatileMLDataSet dataset, String architecture) {
2537
BasicNormalizationStrategy result = new BasicNormalizationStrategy();
@@ -34,17 +46,26 @@ public NormalizationStrategy suggestNormalizationStrategy(VersatileMLDataSet dat
3446
}
3547

3648

49+
/**
50+
* {@inheritDoc}
51+
*/
3752
@Override
3853
public String suggestTrainingType() {
3954
return "neat-ga";
4055
}
4156

4257

58+
/**
59+
* {@inheritDoc}
60+
*/
4361
@Override
4462
public String suggestTrainingArgs(String trainingType) {
4563
return "";
4664
}
4765

66+
/**
67+
* {@inheritDoc}
68+
*/
4869
@Override
4970
public int determineOutputCount(VersatileMLDataSet dataset) {
5071
return dataset.getNormHelper().calculateNormalizedOutputCount();

0 commit comments

Comments
 (0)