Skip to content

Commit c528fa2

Browse files
committed
Handle feature-hash columns in getCategoricalMask and fix unused import
getCategoricalMask only accounted for recode and dummycode columns, so specs using feature hashing produced a mask with the wrong number of columns and the DML check failed. Parse the hash column list and bucket count K from the spec: a hashed column is categorical, and a hashed column that is also dummycoded expands to K columns rather than the recode distinct count. Also correct the ID-based spec guard (it never triggered) to actually require an ID-based spec, and remove the now-unused java.io.FileWriter import in TestUtils that broke Checkstyle.
1 parent a2165d7 commit c528fa2

2 files changed

Lines changed: 28 additions & 7 deletions

File tree

src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,34 @@ public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObjec
6666

6767
JSONObject jSpec = new JSONObject(spec.getStringValue());
6868

69-
if(!jSpec.containsKey("ids") && jSpec.getBoolean("ids")) {
69+
if(!jSpec.containsKey("ids") || !jSpec.getBoolean("ids")) {
7070
throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask");
7171
}
7272

7373
String recode = TfMethod.RECODE.toString();
7474
String dummycode = TfMethod.DUMMYCODE.toString();
75+
String hash = TfMethod.HASH.toString();
7576

7677
int[] lengths = new int[nCol];
7778
// assume all columns encode to at least one column.
7879
Arrays.fill(lengths, 1);
7980
boolean[] categorical = new boolean[nCol];
8081

82+
// feature-hashed columns map to K buckets; a plain hashed column
83+
// produces a single (categorical) bucket-id column, while a hashed
84+
// column that is additionally dummycoded expands to K columns.
85+
boolean[] hashed = new boolean[nCol];
86+
int K = 0;
87+
if(jSpec.containsKey(hash)) {
88+
K = jSpec.getInt("K");
89+
JSONArray a = jSpec.getJSONArray(hash);
90+
for(Object aa : a) {
91+
int av = (Integer) aa - 1;
92+
hashed[av] = true;
93+
categorical[av] = true;
94+
}
95+
}
96+
8197
if(jSpec.containsKey(recode)) {
8298
JSONArray a = jSpec.getJSONArray(recode);
8399
for(Object aa : a) {
@@ -90,14 +106,20 @@ public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObjec
90106
JSONArray a = jSpec.getJSONArray(dummycode);
91107
for(Object aa : a) {
92108
int av = (Integer) aa - 1;
93-
ColumnMetadata d = f.getColumnMetadata()[av];
94-
String v = f.getString(0, av);
95109
int ndist;
96-
if(v.length() > 1 && v.charAt(0) == '¿') {
97-
ndist = UtilFunctions.parseToInt(v.substring(1));
110+
if(hashed[av]) {
111+
// feature hashing followed by dummycoding yields K columns
112+
ndist = K;
98113
}
99114
else {
100-
ndist = d.isDefault() ? 0 : (int) d.getNumDistinct();
115+
ColumnMetadata d = f.getColumnMetadata()[av];
116+
String v = f.getString(0, av);
117+
if(v.length() > 1 && v.charAt(0) == '¿') {
118+
ndist = UtilFunctions.parseToInt(v.substring(1));
119+
}
120+
else {
121+
ndist = d.isDefault() ? 0 : (int) d.getNumDistinct();
122+
}
101123
}
102124
lengths[av] = ndist;
103125
categorical[av] = true;

src/test/java/org/apache/sysds/test/TestUtils.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import java.io.FileInputStream;
3333
import java.io.FileOutputStream;
3434
import java.io.FileReader;
35-
import java.io.FileWriter;
3635
import java.io.IOException;
3736
import java.io.InputStreamReader;
3837
import java.io.OutputStreamWriter;

0 commit comments

Comments
 (0)