Skip to content

Commit a2165d7

Browse files
committed
Add getCategoricalMask DML builtin
Adds a new builtin that, given a transform-encode metadata frame and the encoding JSON spec, returns a 1xN matrix mask marking which output columns are categorical (1) versus continuous (0). Useful when callers need to know the category boundary in transformed output without re-deriving it from the spec. - Register GET_CATEGORICAL_MASK in Builtins, Opcodes, Types (OpOp2), Builtin (functionobject) - Validate it as a frame+scalar binary in BuiltinFunctionExpression (new checkFrameParam helper) and lower it to a BinaryOp in DMLTranslator - Force CP execution for the new op in BinaryOp.optFindExecType - Implement runtime in BinaryFrameScalarCPInstruction and route FRAME+SCALAR binary instructions to it in BinaryCPInstruction - Add writeTestScalar(String, String) overload to TestUtils - Cover recode, dummycode, hash, and hybrid specs in GetCategoricalMaskTest (note: hash variants depend on the decoder/encoder hash-column changes in a separate branch)
1 parent 9a4e2a3 commit a2165d7

12 files changed

Lines changed: 385 additions & 2 deletions

File tree

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ public enum Builtins {
154154
GARCH("garch", true),
155155
GAUSSIAN_CLASSIFIER("gaussianClassifier", true),
156156
GET_ACCURACY("getAccuracy", true),
157+
GET_CATEGORICAL_MASK("getCategoricalMask", false),
157158
GLM("glm", true),
158159
GLM_PREDICT("glmPredict", true),
159160
GLOVE("glove", true),

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ public enum Opcodes {
215215
TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin),
216216
TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin),
217217

218+
GET_CATEGORICAL_MASK("get_categorical_mask", InstructionType.Binary),
219+
218220
//Ternary instruction opcodes
219221
PM("+*", InstructionType.Ternary),
220222
MINUSMULT("-*", InstructionType.Ternary),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ public enum OpOp2 {
639639
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
640640
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
641641
MINUS1_MULT(false), //1-X*Y
642+
GET_CATEGORICAL_MASK(false), // get transformation mask
642643
QUANTIZE_COMPRESS(false), //quantization-fused compression
643644
UNION_DISTINCT(false);
644645

src/main/java/org/apache/sysds/hops/BinaryOp.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,10 @@ else if( (op == OpOp2.CBIND && getDataType().isList())
853853
|| (op == OpOp2.RBIND && getDataType().isList())) {
854854
_etype = ExecType.CP;
855855
}
856-
856+
857+
if( op == OpOp2.GET_CATEGORICAL_MASK)
858+
_etype = ExecType.CP;
859+
857860
//mark for recompile (forever)
858861
setRequiresRecompileIfNecessary();
859862

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,6 +2018,15 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
20182018
else
20192019
raiseValidateError("The compress or decompress instruction is not allowed in dml scripts");
20202020
break;
2021+
case GET_CATEGORICAL_MASK:
2022+
checkNumParameters(2);
2023+
checkFrameParam(getFirstExpr());
2024+
checkScalarParam(getSecondExpr());
2025+
output.setDataType(DataType.MATRIX);
2026+
output.setDimensions(1, -1);
2027+
output.setBlocksize( id.getBlocksize());
2028+
output.setValueType(ValueType.FP64);
2029+
break;
20212030
case QUANTIZE_COMPRESS:
20222031
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) {
20232032
checkNumParameters(2);
@@ -2383,6 +2392,13 @@ protected void checkMatrixFrameParam(Expression e) { //always unconditional
23832392
raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
23842393
}
23852394
}
2395+
2396+
protected void checkFrameParam(Expression e) {
2397+
if(e.getOutput().getDataType() != DataType.FRAME) {
2398+
raiseValidateError("Expecting frame parameter for function " + getOpCode(), false,
2399+
LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
2400+
}
2401+
}
23862402

23872403
protected void checkMatrixScalarParam(Expression e) { //always unconditional
23882404
if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.SCALAR) {

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2821,6 +2821,9 @@ else if ( in.length == 2 )
28212821
DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr);
28222822
break;
28232823

2824+
case GET_CATEGORICAL_MASK:
2825+
currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, ValueType.FP64, OpOp2.GET_CATEGORICAL_MASK, expr, expr2);
2826+
break;
28242827
default:
28252828
throw new ParseException("Unsupported builtin function type: "+source.getOpCode());
28262829
}

src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS,
5454
MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
5555
STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
5656
TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE,
57-
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE,
57+
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, GET_CATEGORICAL_MASK,
5858
MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE}
5959

6060
private static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;
@@ -120,6 +120,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS,
120120
String2BuiltinCode.put( "_map", BuiltinCode.MAP);
121121
String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP);
122122
String2BuiltinCode.put( "applySchema", BuiltinCode.APPLY_SCHEMA);
123+
String2BuiltinCode.put( "get_categorical_mask", BuiltinCode.GET_CATEGORICAL_MASK);
123124
}
124125

125126
protected Builtin(BuiltinCode bf) {

src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ else if (in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.T
5959
return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str);
6060
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.FRAME)
6161
return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str);
62+
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR)
63+
return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str);
6264
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX)
6365
return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str);
6466
else
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.cp;
21+
22+
import java.util.Arrays;
23+
24+
import org.apache.sysds.common.Builtins;
25+
import org.apache.sysds.common.Types.ValueType;
26+
import org.apache.sysds.runtime.DMLRuntimeException;
27+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
28+
import org.apache.sysds.runtime.frame.data.FrameBlock;
29+
import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
30+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
31+
import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
32+
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
33+
import org.apache.sysds.runtime.util.UtilFunctions;
34+
import org.apache.wink.json4j.JSONArray;
35+
import org.apache.wink.json4j.JSONObject;
36+
37+
public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction {
38+
// private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName());
39+
40+
protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out,
41+
String opcode, String istr) {
42+
super(CPType.Binary, op, in1, in2, out, opcode, istr);
43+
}
44+
45+
@Override
46+
public void processInstruction(ExecutionContext ec) {
47+
// get input frames
48+
FrameBlock inBlock1 = ec.getFrameInput(input1.getName());
49+
ScalarObject spec = ec.getScalarInput(input2.getName(), ValueType.STRING, true);
50+
if(getOpcode().equals(Builtins.GET_CATEGORICAL_MASK.toString().toLowerCase())) {
51+
processGetCategorical(ec, inBlock1, spec);
52+
}
53+
else {
54+
throw new DMLRuntimeException("Unsupported operation");
55+
}
56+
57+
// Release the memory occupied by input frames
58+
ec.releaseFrameInput(input1.getName());
59+
}
60+
61+
public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) {
62+
try {
63+
64+
// MatrixBlock ret = new MatrixBlock();
65+
int nCol = f.getNumColumns();
66+
67+
JSONObject jSpec = new JSONObject(spec.getStringValue());
68+
69+
if(!jSpec.containsKey("ids") && jSpec.getBoolean("ids")) {
70+
throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask");
71+
}
72+
73+
String recode = TfMethod.RECODE.toString();
74+
String dummycode = TfMethod.DUMMYCODE.toString();
75+
76+
int[] lengths = new int[nCol];
77+
// assume all columns encode to at least one column.
78+
Arrays.fill(lengths, 1);
79+
boolean[] categorical = new boolean[nCol];
80+
81+
if(jSpec.containsKey(recode)) {
82+
JSONArray a = jSpec.getJSONArray(recode);
83+
for(Object aa : a) {
84+
int av = (Integer) aa - 1;
85+
categorical[av] = true;
86+
}
87+
}
88+
89+
if(jSpec.containsKey(dummycode)) {
90+
JSONArray a = jSpec.getJSONArray(dummycode);
91+
for(Object aa : a) {
92+
int av = (Integer) aa - 1;
93+
ColumnMetadata d = f.getColumnMetadata()[av];
94+
String v = f.getString(0, av);
95+
int ndist;
96+
if(v.length() > 1 && v.charAt(0) == '¿') {
97+
ndist = UtilFunctions.parseToInt(v.substring(1));
98+
}
99+
else {
100+
ndist = d.isDefault() ? 0 : (int) d.getNumDistinct();
101+
}
102+
lengths[av] = ndist;
103+
categorical[av] = true;
104+
}
105+
}
106+
107+
// get total size after mapping
108+
109+
int sumLengths = 0;
110+
for(int i : lengths) {
111+
sumLengths += i;
112+
}
113+
114+
MatrixBlock ret = new MatrixBlock(1, sumLengths, false);
115+
ret.allocateDenseBlock();
116+
int off = 0;
117+
for(int i = 0; i < lengths.length; i++) {
118+
for(int j = 0; j < lengths[i]; j++) {
119+
ret.set(0, off++, categorical[i] ? 1 : 0);
120+
}
121+
}
122+
123+
ec.setMatrixOutput(output.getName(), ret);
124+
125+
}
126+
catch(Exception e) {
127+
throw new DMLRuntimeException(e);
128+
}
129+
}
130+
}

src/test/java/org/apache/sysds/test/TestUtils.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.io.FileInputStream;
3333
import java.io.FileOutputStream;
3434
import java.io.FileReader;
35+
import java.io.FileWriter;
3536
import java.io.IOException;
3637
import java.io.InputStreamReader;
3738
import java.io.OutputStreamWriter;
@@ -2941,6 +2942,25 @@ public static void writeTestScalar(String file, double value) {
29412942
}
29422943
}
29432944

2945+
2946+
/**
2947+
* Write scalar to file
2948+
*
2949+
* @param file File to write to
2950+
* @param value Value to write
2951+
*/
2952+
public static void writeTestScalar(String file, String value) {
2953+
try {
2954+
DataOutputStream out = new DataOutputStream(new FileOutputStream(file));
2955+
try(PrintWriter pw = new PrintWriter(out)) {
2956+
pw.println(value);
2957+
}
2958+
}
2959+
catch(IOException e) {
2960+
fail("unable to write test scalar (" + file + "): " + e.getMessage());
2961+
}
2962+
}
2963+
29442964
/**
29452965
* Write scalar to file
29462966
*

0 commit comments

Comments
 (0)