4141 */
4242public class EncogModel {
4343
44+ /**
45+ * The dataset to use.
46+ */
4447 private final VersatileMLDataSet dataset ;
48+
49+ /**
50+ * The input features.
51+ */
4552 private final List <ColumnDefinition > inputFeatures = new ArrayList <ColumnDefinition >();
53+
54+ /**
55+ * The predicted features.
56+ */
4657 private final List <ColumnDefinition > predictedFeatures = new ArrayList <ColumnDefinition >();
58+
59+ /**
60+ * The training dataset.
61+ */
4762 private MatrixMLDataSet trainingDataset ;
63+
64+ /**
65+ * The validation dataset.
66+ */
4867 private MatrixMLDataSet validationDataset ;
68+
69+ /**
70+ * The standard configurations for each method type.
71+ */
4972 private final Map <String , MethodConfig > methodConfigurations = new HashMap <String , MethodConfig >();
73+
74+ /**
75+ * The current method configuration, determined by the selected model.
76+ */
5077 private MethodConfig config ;
78+
79+ /**
80+ * The selected method type.
81+ */
5182 private String methodType ;
83+
84+ /**
85+ * The method arguments for the selected method.
86+ */
5287 private String methodArgs ;
88+
89+ /**
90+ * The selected training type.
91+ */
5392 private String trainingType ;
93+
94+ /**
95+ * The training arguments for the selected training type.
96+ */
5497 private String trainingArgs ;
98+
99+ /**
100+ * The object that receives status reports; defaults to a NullStatusReportable.
101+ */
55102 private StatusReportable report = new NullStatusReportable ();
56103
104+ /**
105+ * Construct a model for the specified dataset.
106+ * @param theDataset The dataset.
107+ */
57108 public EncogModel (VersatileMLDataSet theDataset ) {
58109 this .dataset = theDataset ;
59110 this .methodConfigurations .put (MLMethodFactory .TYPE_FEEDFORWARD ,
@@ -89,6 +140,12 @@ public List<ColumnDefinition> getPredictedFeatures() {
89140 return predictedFeatures ;
90141 }
91142
143+ /**
144+ * Specify a validation set to hold back.
145+ * @param validationPercent The percent to use for validation.
146+ * @param shuffle True to shuffle.
147+ * @param seed The seed for random generation.
148+ */
92149 public void holdBackValidation (double validationPercent , boolean shuffle ,
93150 int seed ) {
94151 List <DataDivision > dataDivisionList = new ArrayList <DataDivision >();
@@ -100,7 +157,13 @@ public void holdBackValidation(double validationPercent, boolean shuffle,
100157 this .validationDataset = dataDivisionList .get (1 ).getDataset ();
101158 }
102159
103- public void fitFold (int k , int foldNum , DataFold fold ) {
160+ /**
161+ * Fit the model on a single cross-validation fold.
162+ * @param k The total number of folds.
163+ * @param foldNum The number of the current fold.
164+ * @param fold The data for the current fold.
165+ */
166+ private void fitFold (int k , int foldNum , DataFold fold ) {
104167 MLMethod method = this .createMethod ();
105168 MLTrain train = this .createTrainer (method , fold .getTraining ());
106169
@@ -143,6 +206,12 @@ public void fitFold(int k, int foldNum, DataFold fold) {
143206 }
144207 }
145208
209+ /**
210+ * Calculate the error for the given method and dataset.
211+ * @param method The method to use.
212+ * @param data The data to use.
213+ * @return The error.
214+ */
146215 public double calculateError (MLMethod method , MLDataSet data ) {
147216 if (this .dataset .getNormHelper ().getOutputColumns ().size () == 1 ) {
148217 ColumnDefinition cd = this .dataset .getNormHelper ()
@@ -157,6 +226,12 @@ public double calculateError(MLMethod method, MLDataSet data) {
157226 data );
158227 }
159228
229+ /**
230+ * Create a trainer.
231+ * @param method The method to train.
232+ * @param dataset The dataset.
233+ * @return The trainer.
234+ */
160235 private MLTrain createTrainer (MLMethod method , MLDataSet dataset ) {
161236
162237 if (this .trainingType == null ) {
@@ -169,6 +244,12 @@ private MLTrain createTrainer(MLMethod method, MLDataSet dataset) {
169244 return train ;
170245 }
171246
247+ /**
248+ * Crossvalidate and fit.
249+ * @param k The number of folds.
250+ * @param shuffle True if we should shuffle.
251+ * @return The trained method.
252+ */
172253 public MLMethod crossvalidate (int k , boolean shuffle ) {
173254 KFoldCrossvalidation cross = new KFoldCrossvalidation (
174255 this .trainingDataset , k );
@@ -226,6 +307,14 @@ public void setValidationDataset(MatrixMLDataSet validationDataset) {
226307 this .validationDataset = validationDataset ;
227308 }
228309
310+ /**
311+ * Select the method to use.
312+ * @param dataset The dataset.
313+ * @param methodType The type of method.
314+ * @param methodArgs The method arguments.
315+ * @param trainingType The training type.
316+ * @param trainingArgs The training arguments.
317+ */
229318 public void selectMethod (VersatileMLDataSet dataset , String methodType ,
230319 String methodArgs , String trainingType , String trainingArgs ) {
231320
@@ -241,6 +330,10 @@ public void selectMethod(VersatileMLDataSet dataset, String methodType,
241330
242331 }
243332
333+ /**
334+ * Create the selected method.
335+ * @return The created method.
336+ */
244337 public MLMethod createMethod () {
245338 if (this .methodType == null ) {
246339 throw new EncogError (
@@ -253,6 +346,11 @@ public MLMethod createMethod() {
253346 return method ;
254347 }
255348
349+ /**
350+ * Select the method to create.
351+ * @param dataset The dataset.
352+ * @param methodType The method type.
353+ */
256354 public void selectMethod (VersatileMLDataSet dataset , String methodType ) {
257355 if (!this .methodConfigurations .containsKey (methodType )) {
258356 throw new EncogError ("Don't know how to autoconfig method: "
@@ -267,6 +365,10 @@ public void selectMethod(VersatileMLDataSet dataset, String methodType) {
267365
268366 }
269367
368+ /**
369+ * Select the training type.
370+ * @param dataset The dataset.
371+ */
270372 public void selectTrainingType (VersatileMLDataSet dataset ) {
271373 if (this .methodType == null ) {
272374 throw new EncogError (
@@ -277,6 +379,12 @@ public void selectTrainingType(VersatileMLDataSet dataset) {
277379 config .suggestTrainingArgs (trainingType ));
278380 }
279381
382+ /**
383+ * Select the training to use.
384+ * @param dataset The dataset.
385+ * @param trainingType The type of training.
386+ * @param trainingArgs The training arguments.
387+ */
280388 public void selectTraining (VersatileMLDataSet dataset , String trainingType ,
281389 String trainingArgs ) {
282390 if (this .methodType == null ) {
0 commit comments