Skip to content

Commit 34ee082

Browse files
committed
Refactor getCategoricalMask into staged builder and expand tests
Extract the spec validation into a validate helper and move the per-column mask accumulation into a CategoricalMask static nested class with one method per stage (hash, recode, dummycode, sumLengths, toMatrixBlock). Allocate the lengths/categorical/hashed arrays lazily so a column that no method touches keeps the default of a single non-categorical output column, and short-circuit sumLengths to nCol when no expansion occurred. Add instruction-level tests covering output-column offset mapping across interleaved encodings: pass-through-only specs, recode/dummycode interleaved with continuous columns, varying dummycode widths, a column listed in both recode and dummycode, hash-only columns, and a mixed hash/dummycode/recode row. Also cover the JSONException wrapping path in validate.
1 parent 26c5942 commit 34ee082

2 files changed

Lines changed: 237 additions & 82 deletions

File tree

src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java

Lines changed: 123 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,15 @@
3131
import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
3232
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
3333
import org.apache.sysds.runtime.util.UtilFunctions;
34-
import org.apache.wink.json4j.JSONArray;
34+
import org.apache.wink.json4j.JSONException;
3535
import org.apache.wink.json4j.JSONObject;
3636

3737
public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction {
3838
// private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName());
3939

40+
private static final TfMethod[] UNSUPPORTED_MASK_METHODS = new TfMethod[] {TfMethod.BIN,
41+
TfMethod.WORD_EMBEDDING, TfMethod.BAG_OF_WORDS, TfMethod.UDF};
42+
4043
protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out,
4144
String opcode, String istr) {
4245
super(CPType.Binary, op, in1, in2, out, opcode, istr);
@@ -58,108 +61,146 @@ public void processInstruction(ExecutionContext ec) {
5861
ec.releaseFrameInput(input1.getName());
5962
}
6063

61-
public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) {
64+
private static void validate(JSONObject jSpec) {
6265
try {
66+
if(!jSpec.containsKey("ids") || !jSpec.getBoolean("ids"))
67+
throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask");
6368

64-
// MatrixBlock ret = new MatrixBlock();
65-
int nCol = f.getNumColumns();
69+
for(TfMethod m : UNSUPPORTED_MASK_METHODS)
70+
if(jSpec.containsKey(m.toString()))
71+
throw new DMLRuntimeException("unsupported transform method '" + m + "' for get_categorical_mask");
72+
}
73+
catch(JSONException e) {
74+
throw new DMLRuntimeException(e);
75+
}
76+
}
6677

78+
public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) {
79+
try {
80+
// 1. extract the spec, 2. validate it
6781
JSONObject jSpec = new JSONObject(spec.getStringValue());
82+
validate(jSpec);
6883

69-
if(!jSpec.containsKey("ids") || !jSpec.getBoolean("ids")) {
70-
throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask");
71-
}
84+
// 3.-5. fold each supported transform method into the per-column mask state
85+
CategoricalMask mask = new CategoricalMask(f, jSpec);
86+
mask.hash();
87+
mask.recode();
88+
mask.dummycode();
7289

73-
// get_categorical_mask only models the column expansion of recode/dummycode/hash.
74-
// Methods that change the output arity (bin expands under dummycode, word_embedding and
75-
// bag_of_words map to many columns) or are user-defined (udf) would produce a mask with
76-
// the wrong number of columns, so reject them explicitly instead of emitting a silently
77-
// incorrect result. impute and omit are intentionally allowed: they do not alter the
78-
// output column count or the categorical flag of a column.
79-
for(TfMethod m : new TfMethod[] {TfMethod.BIN, TfMethod.WORD_EMBEDDING, TfMethod.BAG_OF_WORDS,
80-
TfMethod.UDF}) {
81-
if(jSpec.containsKey(m.toString()))
82-
throw new DMLRuntimeException(
83-
"unsupported transform method '" + m + "' for get_categorical_mask");
84-
}
90+
// 6.-7. size and materialize the output mask
91+
ec.setMatrixOutput(output.getName(), mask.toMatrixBlock());
92+
}
93+
catch(Exception e) {
94+
throw new DMLRuntimeException(e);
95+
}
96+
}
8597

86-
String recode = TfMethod.RECODE.toString();
87-
String dummycode = TfMethod.DUMMYCODE.toString();
88-
String hash = TfMethod.HASH.toString();
98+
/**
99+
* Accumulates, per input column, how many output columns it expands to (lengths) and whether those
100+
* output columns are categorical (categorical). The arrays are allocated lazily: a column that no
101+
* method touches keeps the implicit default of a single, non-categorical output column.
102+
*/
103+
private static final class CategoricalMask {
104+
private final FrameBlock f;
105+
private final JSONObject jSpec;
106+
private final int nCol;
107+
108+
private int[] lengths = null;
109+
private boolean[] categorical = null;
110+
111+
// feature-hashed columns map to K buckets; a plain hashed column produces a single
112+
// (categorical) bucket-id column, while a hashed column that is additionally dummycoded
113+
// expands to K columns.
114+
private boolean[] hashed = null;
115+
private int K = 0;
116+
117+
private CategoricalMask(FrameBlock f, JSONObject jSpec) {
118+
this.f = f;
119+
this.jSpec = jSpec;
120+
this.nCol = f.getNumColumns();
121+
}
89122

90-
int[] lengths = new int[nCol];
91-
// assume all columns encode to at least one column.
92-
Arrays.fill(lengths, 1);
93-
boolean[] categorical = new boolean[nCol];
94-
95-
// feature-hashed columns map to K buckets; a plain hashed column
96-
// produces a single (categorical) bucket-id column, while a hashed
97-
// column that is additionally dummycoded expands to K columns.
98-
boolean[] hashed = new boolean[nCol];
99-
int K = 0;
100-
if(jSpec.containsKey(hash)) {
101-
K = jSpec.getInt("K");
102-
JSONArray a = jSpec.getJSONArray(hash);
103-
for(Object aa : a) {
104-
int av = (Integer) aa - 1;
105-
hashed[av] = true;
106-
categorical[av] = true;
107-
}
123+
private void hash() throws JSONException {
124+
String hash = TfMethod.HASH.toString();
125+
if(!jSpec.containsKey(hash))
126+
return;
127+
K = jSpec.getInt("K");
128+
hashed = new boolean[nCol];
129+
ensureCategorical();
130+
for(Object aa : jSpec.getJSONArray(hash)) {
131+
int av = (Integer) aa - 1;
132+
hashed[av] = true;
133+
categorical[av] = true;
108134
}
135+
}
109136

110-
if(jSpec.containsKey(recode)) {
111-
JSONArray a = jSpec.getJSONArray(recode);
112-
for(Object aa : a) {
113-
int av = (Integer) aa - 1;
114-
categorical[av] = true;
115-
}
137+
private void recode() throws JSONException {
138+
String recode = TfMethod.RECODE.toString();
139+
if(!jSpec.containsKey(recode))
140+
return;
141+
ensureCategorical();
142+
for(Object aa : jSpec.getJSONArray(recode)) {
143+
int av = (Integer) aa - 1;
144+
categorical[av] = true;
116145
}
146+
}
117147

118-
if(jSpec.containsKey(dummycode)) {
119-
JSONArray a = jSpec.getJSONArray(dummycode);
120-
for(Object aa : a) {
121-
int av = (Integer) aa - 1;
122-
int ndist;
123-
if(hashed[av]) {
124-
// feature hashing followed by dummycoding yields K columns
125-
ndist = K;
126-
}
127-
else {
128-
ColumnMetadata d = f.getColumnMetadata()[av];
129-
String v = f.getString(0, av);
130-
if(v.length() > 1 && v.charAt(0) == '¿') {
131-
ndist = UtilFunctions.parseToInt(v.substring(1));
132-
}
133-
else {
134-
ndist = d.isDefault() ? 0 : (int) d.getNumDistinct();
135-
}
136-
}
137-
lengths[av] = ndist;
138-
categorical[av] = true;
139-
}
148+
private void dummycode() throws JSONException {
149+
String dummycode = TfMethod.DUMMYCODE.toString();
150+
if(!jSpec.containsKey(dummycode))
151+
return;
152+
ensureCategorical();
153+
ensureLengths();
154+
for(Object aa : jSpec.getJSONArray(dummycode)) {
155+
int av = (Integer) aa - 1;
156+
lengths[av] = distinctCount(av);
157+
categorical[av] = true;
140158
}
159+
}
141160

142-
// get total size after mapping
161+
private int distinctCount(int av) {
162+
if(hashed != null && hashed[av])
163+
// feature hashing followed by dummycoding yields K columns
164+
return K;
165+
ColumnMetadata d = f.getColumnMetadata()[av];
166+
String v = f.getString(0, av);
167+
if(v.length() > 1 && v.charAt(0) == '¿')
168+
return UtilFunctions.parseToInt(v.substring(1));
169+
return d.isDefault() ? 0 : (int) d.getNumDistinct();
170+
}
143171

144-
int sumLengths = 0;
145-
for(int i : lengths) {
146-
sumLengths += i;
147-
}
172+
private int sumLengths() {
173+
if(lengths == null)
174+
return nCol;
175+
int sum = 0;
176+
for(int i = 0; i < nCol; i++)
177+
sum += lengths[i];
178+
return sum;
179+
}
148180

149-
MatrixBlock ret = new MatrixBlock(1, sumLengths, false);
181+
private MatrixBlock toMatrixBlock() {
182+
MatrixBlock ret = new MatrixBlock(1, sumLengths(), false);
150183
ret.allocateDenseBlock();
151184
int off = 0;
152-
for(int i = 0; i < lengths.length; i++) {
153-
for(int j = 0; j < lengths[i]; j++) {
154-
ret.set(0, off++, categorical[i] ? 1 : 0);
155-
}
185+
for(int i = 0; i < nCol; i++) {
186+
int len = (lengths == null) ? 1 : lengths[i];
187+
double val = (categorical != null && categorical[i]) ? 1 : 0;
188+
for(int j = 0; j < len; j++)
189+
ret.set(0, off++, val);
156190
}
191+
return ret;
192+
}
157193

158-
ec.setMatrixOutput(output.getName(), ret);
159-
194+
private void ensureCategorical() {
195+
if(categorical == null)
196+
categorical = new boolean[nCol];
160197
}
161-
catch(Exception e) {
162-
throw new DMLRuntimeException(e);
198+
199+
private void ensureLengths() {
200+
if(lengths == null) {
201+
lengths = new int[nCol];
202+
Arrays.fill(lengths, 1);
203+
}
163204
}
164205
}
165206
}

src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
3636
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
3737
import org.apache.sysds.runtime.frame.data.FrameBlock;
38+
import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
3839
import org.apache.sysds.runtime.instructions.InstructionUtils;
3940
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
4041
import org.apache.sysds.runtime.instructions.cp.BinaryFrameScalarCPInstruction;
@@ -84,6 +85,80 @@ public void dummycodeDefaultMetadataContributesNoColumns() {
8485
assertEquals(0.0, res.get(0, 0), 0.0);
8586
}
8687

88+
@Test
89+
public void noMethodAllColumnsPassThrough() {
90+
// a spec with only "ids" touches no column: every column is a single, non-categorical output
91+
FrameBlock meta = metaWithDistinct(3, new int[] {0, 0, 0});
92+
MatrixBlock res = run(meta, "{\"ids\": true}");
93+
94+
assertMask(res, new double[] {0, 0, 0});
95+
}
96+
97+
@Test
98+
public void recodeInterleavedWithPassThrough() {
99+
// categorical (recode, 1 col each) interleaved with continuous pass-through columns
100+
FrameBlock meta = metaWithDistinct(5, new int[] {0, 0, 0, 0, 0});
101+
MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [1, 4]}");
102+
103+
assertMask(res, new double[] {1, 0, 0, 1, 0});
104+
}
105+
106+
@Test
107+
public void leadingPassThroughThenDummycodeOffsets() {
108+
// the dummycode expansion must start at the correct offset after three continuous columns
109+
FrameBlock meta = metaWithDistinct(4, new int[] {0, 0, 0, 3});
110+
MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [4]}");
111+
112+
assertMask(res, new double[] {0, 0, 0, 1, 1, 1});
113+
}
114+
115+
@Test
116+
public void multipleDummycodeVaryingDistinctCounts() {
117+
// several dummycoded columns of different widths, all categorical, no pass-through
118+
FrameBlock meta = metaWithDistinct(3, new int[] {2, 4, 1});
119+
MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [1, 2, 3]}");
120+
121+
assertMask(res, new double[] {1, 1, 1, 1, 1, 1, 1});
122+
}
123+
124+
@Test
125+
public void dummycodeAndPassThroughAndRecodeInterleaved() {
126+
// dummycode(3) | pass-through | recode | dummycode(2): exercises every offset transition
127+
FrameBlock meta = metaWithDistinct(4, new int[] {3, 0, 0, 2});
128+
MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [3], \"dummycode\": [1, 4]}");
129+
130+
assertMask(res, new double[] {1, 1, 1, 0, 1, 1, 1});
131+
}
132+
133+
@Test
134+
public void recodeAndDummycodeOnSameColumnExpands() {
135+
// a column listed in both recode and dummycode must expand to its dummycode width, not collapse
136+
FrameBlock meta = metaWithDistinct(2, new int[] {4, 0});
137+
MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [1], \"dummycode\": [1]}");
138+
139+
assertMask(res, new double[] {1, 1, 1, 1, 0});
140+
}
141+
142+
@Test
143+
public void hashOnlyColumnStaysSingleCategorical() {
144+
// a hashed-but-not-dummycoded column is a single categorical column; K must not widen it
145+
FrameBlock meta = metaWithDistinct(3, new int[] {0, 0, 0});
146+
MatrixBlock res = run(meta, "{\"ids\": true, \"hash\": [2], \"K\": 5}");
147+
148+
assertMask(res, new double[] {0, 1, 0});
149+
}
150+
151+
@Test
152+
public void hashDummycodeRecodePassThroughMixed() {
153+
// col1: hash+dummycode -> K=3 (metadata ignored); col2: pass-through; col3: dummycode(9);
154+
// col4: pass-through; col5: recode. Verifies hashed columns use K while plain dummycode uses
155+
// the metadata distinct count, with correct offsets across the whole row.
156+
FrameBlock meta = metaWithDistinct(5, new int[] {0, 0, 9, 0, 0});
157+
MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [5], \"dummycode\": [1, 3], \"hash\": [1], \"K\": 3}");
158+
159+
assertMask(res, new double[] {1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1});
160+
}
161+
87162
@Test
88163
public void nonIdSpecMissingIdsKeyThrows() {
89164
// a spec without the "ids" key must be rejected, not silently mis-interpreted
@@ -142,6 +217,14 @@ public void imputeAndOmitAreAccepted() {
142217
assertEquals(1.0, res.get(0, 0), 0.0);
143218
}
144219

220+
@Test
221+
public void malformedSpecWrapsJsonException() {
222+
// "ids" present but not a boolean makes spec parsing throw a JSONException, which must be
223+
// wrapped as a DMLRuntimeException rather than propagating raw
224+
FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}});
225+
assertThrowsMessage("was not a boolean", () -> run(meta, "{\"ids\": 5, \"recode\": [1]}"));
226+
}
227+
145228
@Test
146229
public void unsupportedOpcodeThrows() {
147230
// any frame-scalar binary opcode other than get_categorical_mask must be rejected
@@ -167,6 +250,37 @@ private static void assertThrowsMessage(String expected, Runnable action) {
167250
}
168251
}
169252

253+
/** Assert the mask is a single row equal to the expected values (which also fixes its width). */
254+
private static void assertMask(MatrixBlock res, double[] expected) {
255+
assertEquals(1, res.getNumRows());
256+
assertEquals(expected.length, res.getNumColumns());
257+
// compare per cell rather than via getDenseBlockValues(): an all-zero mask has nnz == 0 and
258+
// therefore no materialized dense block
259+
double[] actual = new double[expected.length];
260+
for(int i = 0; i < expected.length; i++)
261+
actual[i] = res.get(0, i);
262+
assertArrayEquals(expected, actual, 0.0);
263+
}
264+
265+
/**
266+
* Build a single-row metadata frame of nCol string columns. A positive distinct[i] is written to
267+
* that column's metadata as the recode distinct count (the path real transformencode uses), while
268+
* a zero leaves the column with default metadata (a continuous / non-dummycoded column).
269+
*/
270+
private static FrameBlock metaWithDistinct(int nCol, int[] distinct) {
271+
ValueType[] schema = new ValueType[nCol];
272+
String[][] data = new String[1][nCol];
273+
for(int i = 0; i < nCol; i++) {
274+
schema[i] = ValueType.STRING;
275+
data[0][i] = "v";
276+
}
277+
FrameBlock fb = new FrameBlock(schema, data);
278+
for(int i = 0; i < nCol; i++)
279+
if(distinct[i] > 0)
280+
fb.setColumnMetadata(i, new ColumnMetadata(distinct[i]));
281+
return fb;
282+
}
283+
170284
private static MatrixBlock run(FrameBlock meta, String spec) {
171285
ExecutionContext ec = ExecutionContextFactory.createContext();
172286
ec.setAutoCreateVars(true);

0 commit comments

Comments
 (0)