Skip to content

Commit 8cb0c50

Browse files
authored
refactored the interface for Classifier to deprecate the Map based method of providing features values and allow for multiple samples to be provided and produce multiple predictions in a single API call (#4)
1 parent 83d81d0 commit 8cb0c50

20 files changed

Lines changed: 631 additions & 233 deletions

.idea/codeStyles/codeStyleConfig.xml

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
This project aims to used text exported ML models generated by sci-kit learn and make them usable in Java.
44

5+
[![javadoc](https://javadoc.io/badge2/rocks.vilaverde/scikit-learn-2-java/javadoc.svg)](https://javadoc.io/doc/rocks.vilaverde/scikit-learn-2-java)
6+
57
## Support
68
* The tree.DecisionTreeClassifier is supported
79
* Supports `predict()`,
@@ -48,24 +50,25 @@ As an example, a DecisionTreeClassifier model trained on the Iris dataset and ex
4850
```
4951

5052
The exported text can then be executed in Java. Note that when calling `export_text` it is
51-
recommended that `max_depth` be set to sys.maxsize so that the tree isn't truncated.
53+
recommended that `max_depth` be set to `sys.maxsize` so that the tree isn't truncated.
5254

5355
### Java Example
54-
In this example the iris model exported using `export_tree` is parsed, features are created as a Java Map
56+
In this example the iris model exported using `export_text` is parsed, features are created as a Java Map
5557
and the decision tree is asked to predict the class.
5658

5759
```
5860
Reader tree = getTrainedModel("iris.model");
5961
final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree,
6062
PredictionFactory.INTEGER);
6163
62-
Map<String, Double> features = new HashMap<>();
63-
features.put("sepal length (cm)", 3.0);
64-
features.put("sepal width (cm)", 5.0);
65-
features.put("petal length (cm)", 4.0);
66-
features.put("petal width (cm)", 2.0);
64+
Features features = Features.of("sepal length (cm)",
65+
"sepal width (cm)",
66+
"petal length (cm)",
67+
"petal width (cm)");
68+
FeatureVector fv = features.newSample();
69+
fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);
6770
68-
Integer prediction = decisionTree.predict(features);
71+
Integer prediction = decisionTree.predict(fv);
6972
System.out.println(prediction.toString());
7073
```
7174

@@ -107,7 +110,7 @@ Then you can use the RandomForestClassifier class to parse the TAR archive.
107110
...
108111
109112
TarArchiveInputStream tree = getArchive("iris.tgz");
110-
final Classifier<Integer> decisionTree = RandomForestClassifier.parse(tree,
113+
final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree,
111114
PredictionFactory.DOUBLE);
112115
```
113116

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>rocks.vilaverde</groupId>
88
<artifactId>scikit-learn-2-java</artifactId>
9-
<version>1.0.3-SNAPSHOT</version>
9+
<version>1.1.0-SNAPSHOT</version>
1010

1111
<name>${project.groupId}:${project.artifactId}</name>
1212
<description>A sklearn exported_text models parser for executing in the Java runtime.</description>
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package rocks.vilaverde.classifier;
2+
3+
import rocks.vilaverde.classifier.dt.TreeClassifier;
4+
5+
import java.util.Map;
6+
7+
/**
8+
* Abstract base class for Tree classifiers.
9+
*/
10+
public abstract class AbstractTreeClassifier<T> implements TreeClassifier<T> {
11+
12+
/**
13+
* Predict class or regression value for features.
14+
*
15+
* @param samples input samples
16+
* @return class probabilities of the input sample
17+
*/
18+
@Override
19+
public T predict(Map<String, Double> samples) {
20+
FeatureVector fv = toFeatureVector(samples);
21+
return predict(fv).get(0);
22+
}
23+
24+
/**
25+
* Predict class probabilities of the input samples features.
26+
* The predicted class probability is the fraction of samples of the same class in a leaf.
27+
*
28+
* @param samples the input samples
29+
* @return the class probabilities of the input sample
30+
*/
31+
@Override
32+
public double[] predict_proba(Map<String, Double> samples) {
33+
FeatureVector fv = toFeatureVector(samples);
34+
return predict_proba(fv)[0];
35+
}
36+
37+
/**
38+
* Convert a Map of features to a {@link FeatureVector}.
39+
* @param samples a KV map of feature name to value
40+
* @return FeatureVector
41+
*/
42+
private FeatureVector toFeatureVector(Map<String, Double> samples) {
43+
Features features = Features.fromSet(samples.keySet());
44+
FeatureVector fv = features.newSample();
45+
for (Map.Entry<String, Double> entry : samples.entrySet()) {
46+
fv.add(entry.getKey(), entry.getValue());
47+
}
48+
return fv;
49+
}
50+
}

src/main/java/rocks/vilaverde/classifier/BooleanFeature.java

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/main/java/rocks/vilaverde/classifier/Classifier.java

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,44 @@
11
package rocks.vilaverde.classifier;
22

3+
import java.util.List;
34
import java.util.Map;
45
import java.util.Set;
56

67
public interface Classifier<T> {
78

9+
/**
10+
* Predict class or regression value for samples. Predictions will be
11+
* returned at the same index of the sample provided.
12+
* @param samples input samples
13+
* @return class probabilities of the input sample
14+
*/
15+
List<T> predict(FeatureVector ... samples);
16+
17+
/**
18+
* Predict class probabilities of the input samples. Probabilities will be
19+
* returned at the same index of the sample provided.
20+
* The predicted class probability is the fraction of samples of the same class in a leaf.
21+
* @param samples the input samples
22+
* @return the class probabilities of the input sample
23+
*/
24+
double[][] predict_proba(FeatureVector ... samples);
25+
826
/**
927
* Predict class or regression value for features.
10-
* @param features input samples
28+
* @param samples input samples
1129
* @return class probabilities of the input sample
1230
*/
13-
T predict(Map<String, Double> features);
31+
@Deprecated
32+
T predict(Map<String, Double> samples);
1433

1534
/**
1635
* Predict class probabilities of the input samples features.
1736
* The predicted class probability is the fraction of samples of the same class in a leaf.
18-
* @param features the input samples
37+
* @param samples the input samples
1938
* @return the class probabilities of the input sample
2039
*/
21-
double[] predict_proba(Map<String, Double> features);
40+
@Deprecated
41+
double[] predict_proba(Map<String, Double> samples);
2242

2343
/**
2444
* Get the names of all the features in the model.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package rocks.vilaverde.classifier;
2+
3+
/**
4+
* A container for the values for each feature of a sample that will be predicted.
5+
*/
6+
public class FeatureVector {
7+
8+
private final Features features;
9+
private final double[] vector;
10+
11+
public FeatureVector(Features features) {
12+
this.features = features;
13+
this.vector = new double[features.getLength()];
14+
}
15+
16+
public FeatureVector add(String feature, boolean value) {
17+
add(feature, value ? 1.0 : 0.0);
18+
return this;
19+
}
20+
21+
public FeatureVector add(int index, boolean value) {
22+
add(index, value ? 1.0 : 0.0);
23+
return this;
24+
}
25+
26+
public FeatureVector add(int index, double value) {
27+
this.vector[index] = value;
28+
return this;
29+
}
30+
31+
public FeatureVector add(String feature, double value) {
32+
int index = this.features.getFeatureIndex(feature);
33+
add(index, value);
34+
return this;
35+
}
36+
37+
public double get(int index) {
38+
if (index >= vector.length) {
39+
throw new IllegalArgumentException(String.format("index must be less than %d", index));
40+
}
41+
42+
return vector[index];
43+
}
44+
45+
public double get(String feature) {
46+
int index = features.getFeatureIndex(feature);
47+
return get(index);
48+
}
49+
50+
public boolean hasFeature(String feature) {
51+
return this.features.getFeatureNames().contains(feature);
52+
}
53+
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package rocks.vilaverde.classifier;
2+
3+
import java.util.Arrays;
4+
import java.util.Collections;
5+
import java.util.HashMap;
6+
import java.util.HashSet;
7+
import java.util.Map;
8+
import java.util.OptionalInt;
9+
import java.util.Set;
10+
11+
/**
12+
* Container class for the set of named features that will be provided for each sample.
13+
* Call {@link Features#newSample()} to create a {@link FeatureVector} to provide the
14+
* values for the sample.
15+
*/
16+
public class Features {
17+
18+
/* Map of feature name to index */
19+
private final Map<String, Integer> features = new HashMap<>();
20+
/* Feature can be added as long as no samples have yet to be created,
21+
at that point this is immutable */
22+
private boolean allowFeatureAdd = true;
23+
24+
/**
25+
* Convienence creation method.
26+
* @param features the set of features.
27+
* @return Features
28+
*/
29+
public static Features of(String ... features) {
30+
31+
// make sure all the features are unique
32+
Set<String> featureSet = new HashSet<>(Arrays.asList(features));
33+
if (featureSet.size() != features.length) {
34+
throw new IllegalArgumentException("features names are not unique");
35+
}
36+
37+
return new Features(features);
38+
}
39+
40+
public static Features fromSet(Set<String> features) {
41+
return new Features(features.toArray(new String[0]));
42+
}
43+
44+
/**
45+
* Constructor
46+
* @param features order list of features.
47+
*/
48+
private Features(String ... features) {
49+
for (int i = 0; i < features.length; i++) {
50+
this.features.put(features[i], i);
51+
}
52+
}
53+
54+
FeatureVector newSample() {
55+
allowFeatureAdd = false;
56+
return new FeatureVector(this);
57+
}
58+
59+
public void addFeature(String feature) {
60+
if (!allowFeatureAdd) {
61+
throw new IllegalStateException("features are immutable");
62+
}
63+
64+
if (!this.features.containsKey(feature)) {
65+
int next = 0;
66+
OptionalInt optionalInt = features.values().stream().mapToInt(integer -> integer).max();
67+
if (optionalInt.isPresent()) {
68+
next = optionalInt.getAsInt() + 1;
69+
}
70+
71+
this.features.put(feature, next);
72+
}
73+
}
74+
75+
public int getLength() {
76+
return this.features.size();
77+
}
78+
79+
public int getFeatureIndex(String feature) {
80+
Integer index = this.features.get(feature);
81+
if (index == null) {
82+
throw new IllegalArgumentException(String.format("feature %s does not exist", feature));
83+
}
84+
85+
return index;
86+
}
87+
88+
public Set<String> getFeatureNames() {
89+
return Collections.unmodifiableSet(this.features.keySet());
90+
}
91+
}

src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* Represents a Choice in the decision tree, where when the expression is evaluated,
77
* if true will result in the child node of the choice being selected.
88
*/
9-
class ChoiceNode extends TreeNode {
9+
public class ChoiceNode extends TreeNode {
1010
private final Operator op;
1111
private final Double value;
1212

@@ -38,7 +38,7 @@ public String toString() {
3838
return String.format("%s %s", op.toString(), value.toString());
3939
}
4040

41-
public boolean eval(Double featureValue) {
41+
public boolean eval(double featureValue) {
4242
return op.apply(featureValue, value);
4343
}
4444
}

src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
package rocks.vilaverde.classifier.dt;
22

3-
import java.util.ArrayList;
4-
import java.util.List;
5-
63
/**
74
* Represents a decision in the DecisionTreeClassifier. The decision will have
85
* a left and right hand {@link ChoiceNode} to be evaluated.
96
* A {@link ChoiceNode} may have nested {@link DecisionNode} or {@link EndNode}.
107
*/
11-
class DecisionNode extends TreeNode {
8+
public class DecisionNode extends TreeNode {
129

1310
private final String featureName;
1411

@@ -26,7 +23,7 @@ public static DecisionNode create(String feature) {
2623

2724
/**
2825
* Private Constructor.
29-
* @param featureName
26+
* @param featureName the name of the feature used in this decision
3027
*/
3128
private DecisionNode(String featureName) {
3229
this.featureName = featureName.intern();

0 commit comments

Comments
 (0)