Skip to content

Commit e4f0987

Browse files
authored
[BWARE] Add getCategoricalMask DML builtin (#2482)
* Add getCategoricalMask DML builtin: adds a new builtin that, given a transform-encode metadata frame and the encoding JSON spec, returns a 1xN matrix mask marking which output columns are categorical (1) versus continuous (0). This is useful when callers need to know the category boundaries in the transformed output without re-deriving them from the spec.
1 parent 65e734e commit e4f0987

13 files changed

Lines changed: 766 additions & 2 deletions

File tree

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ public enum Builtins {
154154
GARCH("garch", true),
155155
GAUSSIAN_CLASSIFIER("gaussianClassifier", true),
156156
GET_ACCURACY("getAccuracy", true),
157+
GET_CATEGORICAL_MASK("getCategoricalMask", false),
157158
GLM("glm", true),
158159
GLM_PREDICT("glmPredict", true),
159160
GLOVE("glove", true),

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ public enum Opcodes {
215215
TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin),
216216
TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin),
217217

218+
GET_CATEGORICAL_MASK("get_categorical_mask", InstructionType.Binary),
219+
218220
//Ternary instruction opcodes
219221
PM("+*", InstructionType.Ternary),
220222
MINUSMULT("-*", InstructionType.Ternary),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ public enum OpOp2 {
639639
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
640640
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
641641
MINUS1_MULT(false), //1-X*Y
642+
GET_CATEGORICAL_MASK(false), // get transformation mask
642643
QUANTIZE_COMPRESS(false), //quantization-fused compression
643644
UNION_DISTINCT(false);
644645

src/main/java/org/apache/sysds/hops/BinaryOp.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,10 @@ else if( (op == OpOp2.CBIND && getDataType().isList())
853853
|| (op == OpOp2.RBIND && getDataType().isList())) {
854854
_etype = ExecType.CP;
855855
}
856-
856+
857+
if( op == OpOp2.GET_CATEGORICAL_MASK)
858+
_etype = ExecType.CP;
859+
857860
//mark for recompile (forever)
858861
setRequiresRecompileIfNecessary();
859862

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,6 +2018,15 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
20182018
else
20192019
raiseValidateError("The compress or decompress instruction is not allowed in dml scripts");
20202020
break;
2021+
case GET_CATEGORICAL_MASK:
2022+
checkNumParameters(2);
2023+
checkFrameParam(getFirstExpr());
2024+
checkScalarParam(getSecondExpr());
2025+
output.setDataType(DataType.MATRIX);
2026+
output.setDimensions(1, -1);
2027+
output.setBlocksize( id.getBlocksize());
2028+
output.setValueType(ValueType.FP64);
2029+
break;
20212030
case QUANTIZE_COMPRESS:
20222031
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) {
20232032
checkNumParameters(2);
@@ -2383,6 +2392,13 @@ protected void checkMatrixFrameParam(Expression e) { //always unconditional
23832392
raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
23842393
}
23852394
}
2395+
2396+
protected void checkFrameParam(Expression e) {
2397+
if(e.getOutput().getDataType() != DataType.FRAME) {
2398+
raiseValidateError("Expecting frame parameter for function " + getOpCode(), false,
2399+
LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
2400+
}
2401+
}
23862402

23872403
protected void checkMatrixScalarParam(Expression e) { //always unconditional
23882404
if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.SCALAR) {

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2821,6 +2821,9 @@ else if ( in.length == 2 )
28212821
DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr);
28222822
break;
28232823

2824+
case GET_CATEGORICAL_MASK:
2825+
currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, ValueType.FP64, OpOp2.GET_CATEGORICAL_MASK, expr, expr2);
2826+
break;
28242827
default:
28252828
throw new ParseException("Unsupported builtin function type: "+source.getOpCode());
28262829
}

src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS,
5454
MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
5555
STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
5656
TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE,
57-
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE,
57+
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, GET_CATEGORICAL_MASK,
5858
MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE}
5959

6060
private static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;
@@ -120,6 +120,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS,
120120
String2BuiltinCode.put( "_map", BuiltinCode.MAP);
121121
String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP);
122122
String2BuiltinCode.put( "applySchema", BuiltinCode.APPLY_SCHEMA);
123+
String2BuiltinCode.put( "get_categorical_mask", BuiltinCode.GET_CATEGORICAL_MASK);
123124
}
124125

125126
protected Builtin(BuiltinCode bf) {

src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ else if (in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.T
5959
return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str);
6060
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.FRAME)
6161
return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str);
62+
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR)
63+
return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str);
6264
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX)
6365
return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str);
6466
else
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.cp;
21+
22+
import java.util.Arrays;
23+
24+
import org.apache.sysds.common.Builtins;
25+
import org.apache.sysds.common.Types.ValueType;
26+
import org.apache.sysds.runtime.DMLRuntimeException;
27+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
28+
import org.apache.sysds.runtime.frame.data.FrameBlock;
29+
import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
30+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
31+
import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
32+
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
33+
import org.apache.sysds.runtime.util.UtilFunctions;
34+
import org.apache.wink.json4j.JSONException;
35+
import org.apache.wink.json4j.JSONObject;
36+
37+
public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction {
38+
// private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName());
39+
40+
private static final TfMethod[] UNSUPPORTED_MASK_METHODS = new TfMethod[] {TfMethod.BIN,
41+
TfMethod.WORD_EMBEDDING, TfMethod.BAG_OF_WORDS, TfMethod.UDF};
42+
43+
protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out,
44+
String opcode, String istr) {
45+
super(CPType.Binary, op, in1, in2, out, opcode, istr);
46+
}
47+
48+
@Override
49+
public void processInstruction(ExecutionContext ec) {
50+
// get input frames
51+
FrameBlock inBlock1 = ec.getFrameInput(input1.getName());
52+
ScalarObject spec = ec.getScalarInput(input2.getName(), ValueType.STRING, true);
53+
if(getOpcode().equals(Builtins.GET_CATEGORICAL_MASK.toString().toLowerCase())) {
54+
processGetCategorical(ec, inBlock1, spec);
55+
}
56+
else {
57+
throw new DMLRuntimeException("Unsupported operation");
58+
}
59+
60+
// Release the memory occupied by input frames
61+
ec.releaseFrameInput(input1.getName());
62+
}
63+
64+
private static void validate(JSONObject jSpec) {
65+
try {
66+
if(!jSpec.containsKey("ids") || !jSpec.getBoolean("ids"))
67+
throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask");
68+
69+
for(TfMethod m : UNSUPPORTED_MASK_METHODS)
70+
if(jSpec.containsKey(m.toString()))
71+
throw new DMLRuntimeException("unsupported transform method '" + m + "' for get_categorical_mask");
72+
}
73+
catch(JSONException e) {
74+
throw new DMLRuntimeException(e);
75+
}
76+
}
77+
78+
public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) {
79+
try {
80+
// 1. extract the spec, 2. validate it
81+
JSONObject jSpec = new JSONObject(spec.getStringValue());
82+
validate(jSpec);
83+
84+
// 3.-5. fold each supported transform method into the per-column mask state
85+
CategoricalMask mask = new CategoricalMask(f, jSpec);
86+
mask.hash();
87+
mask.recode();
88+
mask.dummycode();
89+
90+
// 6.-7. size and materialize the output mask
91+
ec.setMatrixOutput(output.getName(), mask.toMatrixBlock());
92+
}
93+
catch(Exception e) {
94+
throw new DMLRuntimeException(e);
95+
}
96+
}
97+
98+
/**
99+
* Accumulates, per input column, how many output columns it expands to (lengths) and whether those
100+
* output columns are categorical (categorical). The arrays are allocated lazily: a column that no
101+
* method touches keeps the implicit default of a single, non-categorical output column.
102+
*/
103+
private static final class CategoricalMask {
104+
private final FrameBlock f;
105+
private final JSONObject jSpec;
106+
private final int nCol;
107+
108+
private int[] lengths = null;
109+
private boolean[] categorical = null;
110+
111+
// feature-hashed columns map to K buckets; a plain hashed column produces a single
112+
// (categorical) bucket-id column, while a hashed column that is additionally dummycoded
113+
// expands to K columns.
114+
private boolean[] hashed = null;
115+
private int K = 0;
116+
117+
private CategoricalMask(FrameBlock f, JSONObject jSpec) {
118+
this.f = f;
119+
this.jSpec = jSpec;
120+
this.nCol = f.getNumColumns();
121+
}
122+
123+
private void hash() throws JSONException {
124+
String hash = TfMethod.HASH.toString();
125+
if(!jSpec.containsKey(hash))
126+
return;
127+
K = jSpec.getInt("K");
128+
hashed = new boolean[nCol];
129+
ensureCategorical();
130+
for(Object aa : jSpec.getJSONArray(hash)) {
131+
int av = (Integer) aa - 1;
132+
hashed[av] = true;
133+
categorical[av] = true;
134+
}
135+
}
136+
137+
private void recode() throws JSONException {
138+
String recode = TfMethod.RECODE.toString();
139+
if(!jSpec.containsKey(recode))
140+
return;
141+
ensureCategorical();
142+
for(Object aa : jSpec.getJSONArray(recode)) {
143+
int av = (Integer) aa - 1;
144+
categorical[av] = true;
145+
}
146+
}
147+
148+
private void dummycode() throws JSONException {
149+
String dummycode = TfMethod.DUMMYCODE.toString();
150+
if(!jSpec.containsKey(dummycode))
151+
return;
152+
ensureCategorical();
153+
ensureLengths();
154+
for(Object aa : jSpec.getJSONArray(dummycode)) {
155+
int av = (Integer) aa - 1;
156+
lengths[av] = distinctCount(av);
157+
categorical[av] = true;
158+
}
159+
}
160+
161+
private int distinctCount(int av) {
162+
if(hashed != null && hashed[av])
163+
// feature hashing followed by dummycoding yields K columns
164+
return K;
165+
ColumnMetadata d = f.getColumnMetadata()[av];
166+
String v = f.getString(0, av);
167+
if(v.length() > 1 && v.charAt(0) == '¿')
168+
return UtilFunctions.parseToInt(v.substring(1));
169+
return d.isDefault() ? 0 : (int) d.getNumDistinct();
170+
}
171+
172+
private int sumLengths() {
173+
if(lengths == null)
174+
return nCol;
175+
int sum = 0;
176+
for(int i = 0; i < nCol; i++)
177+
sum += lengths[i];
178+
return sum;
179+
}
180+
181+
private MatrixBlock toMatrixBlock() {
182+
MatrixBlock ret = new MatrixBlock(1, sumLengths(), false);
183+
ret.allocateDenseBlock();
184+
int off = 0;
185+
for(int i = 0; i < nCol; i++) {
186+
int len = (lengths == null) ? 1 : lengths[i];
187+
double val = (categorical != null && categorical[i]) ? 1 : 0;
188+
for(int j = 0; j < len; j++)
189+
ret.set(0, off++, val);
190+
}
191+
return ret;
192+
}
193+
194+
private void ensureCategorical() {
195+
if(categorical == null)
196+
categorical = new boolean[nCol];
197+
}
198+
199+
private void ensureLengths() {
200+
if(lengths == null) {
201+
lengths = new int[nCol];
202+
Arrays.fill(lengths, 1);
203+
}
204+
}
205+
}
206+
}

src/test/java/org/apache/sysds/test/TestUtils.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2941,6 +2941,25 @@ public static void writeTestScalar(String file, double value) {
29412941
}
29422942
}
29432943

2944+
2945+
/**
2946+
* Write scalar to file
2947+
*
2948+
* @param file File to write to
2949+
* @param value Value to write
2950+
*/
2951+
public static void writeTestScalar(String file, String value) {
2952+
try {
2953+
DataOutputStream out = new DataOutputStream(new FileOutputStream(file));
2954+
try(PrintWriter pw = new PrintWriter(out)) {
2955+
pw.println(value);
2956+
}
2957+
}
2958+
catch(IOException e) {
2959+
fail("unable to write test scalar (" + file + "): " + e.getMessage());
2960+
}
2961+
}
2962+
29442963
/**
29452964
* Write scalar to file
29462965
*

0 commit comments

Comments
 (0)