3131import org .apache .sysds .runtime .matrix .operators .MultiThreadedOperator ;
3232import org .apache .sysds .runtime .transform .TfUtils .TfMethod ;
3333import org .apache .sysds .runtime .util .UtilFunctions ;
34- import org .apache .wink .json4j .JSONArray ;
34+ import org .apache .wink .json4j .JSONException ;
3535import org .apache .wink .json4j .JSONObject ;
3636
3737public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction {
3838 // private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName());
3939
40+ private static final TfMethod [] UNSUPPORTED_MASK_METHODS = new TfMethod [] {TfMethod .BIN ,
41+ TfMethod .WORD_EMBEDDING , TfMethod .BAG_OF_WORDS , TfMethod .UDF };
42+
4043 protected BinaryFrameScalarCPInstruction (MultiThreadedOperator op , CPOperand in1 , CPOperand in2 , CPOperand out ,
4144 String opcode , String istr ) {
4245 super (CPType .Binary , op , in1 , in2 , out , opcode , istr );
@@ -58,108 +61,146 @@ public void processInstruction(ExecutionContext ec) {
5861 ec .releaseFrameInput (input1 .getName ());
5962 }
6063
61- public void processGetCategorical ( ExecutionContext ec , FrameBlock f , ScalarObject spec ) {
64+ private static void validate ( JSONObject jSpec ) {
6265 try {
66+ if (!jSpec .containsKey ("ids" ) || !jSpec .getBoolean ("ids" ))
67+ throw new DMLRuntimeException ("not supported non ID based spec for get_categorical_mask" );
6368
64- // MatrixBlock ret = new MatrixBlock();
65- int nCol = f .getNumColumns ();
69+ for (TfMethod m : UNSUPPORTED_MASK_METHODS )
70+ if (jSpec .containsKey (m .toString ()))
71+ throw new DMLRuntimeException ("unsupported transform method '" + m + "' for get_categorical_mask" );
72+ }
73+ catch (JSONException e ) {
74+ throw new DMLRuntimeException (e );
75+ }
76+ }
6677
78+ public void processGetCategorical (ExecutionContext ec , FrameBlock f , ScalarObject spec ) {
79+ try {
80+ // 1. extract the spec, 2. validate it
6781 JSONObject jSpec = new JSONObject (spec .getStringValue ());
82+ validate (jSpec );
6883
69- if (!jSpec .containsKey ("ids" ) || !jSpec .getBoolean ("ids" )) {
70- throw new DMLRuntimeException ("not supported non ID based spec for get_categorical_mask" );
71- }
84+ // 3.-5. fold each supported transform method into the per-column mask state
85+ CategoricalMask mask = new CategoricalMask (f , jSpec );
86+ mask .hash ();
87+ mask .recode ();
88+ mask .dummycode ();
7289
73- // get_categorical_mask only models the column expansion of recode/dummycode/hash.
74- // Methods that change the output arity (bin expands under dummycode, word_embedding and
75- // bag_of_words map to many columns) or are user-defined (udf) would produce a mask with
76- // the wrong number of columns, so reject them explicitly instead of emitting a silently
77- // incorrect result. impute and omit are intentionally allowed: they do not alter the
78- // output column count or the categorical flag of a column.
79- for (TfMethod m : new TfMethod [] {TfMethod .BIN , TfMethod .WORD_EMBEDDING , TfMethod .BAG_OF_WORDS ,
80- TfMethod .UDF }) {
81- if (jSpec .containsKey (m .toString ()))
82- throw new DMLRuntimeException (
83- "unsupported transform method '" + m + "' for get_categorical_mask" );
84- }
90+ // 6.-7. size and materialize the output mask
91+ ec .setMatrixOutput (output .getName (), mask .toMatrixBlock ());
92+ }
93+ catch (Exception e ) {
94+ throw new DMLRuntimeException (e );
95+ }
96+ }
8597
86- String recode = TfMethod .RECODE .toString ();
87- String dummycode = TfMethod .DUMMYCODE .toString ();
88- String hash = TfMethod .HASH .toString ();
98+ /**
99+ * Accumulates, per input column, how many output columns it expands to (lengths) and whether those
100+ * output columns are categorical (categorical). The arrays are allocated lazily: a column that no
101+ * method touches keeps the implicit default of a single, non-categorical output column.
102+ */
103+ private static final class CategoricalMask {
104+ private final FrameBlock f ;
105+ private final JSONObject jSpec ;
106+ private final int nCol ;
107+
108+ private int [] lengths = null ;
109+ private boolean [] categorical = null ;
110+
111+ // feature-hashed columns map to K buckets; a plain hashed column produces a single
112+ // (categorical) bucket-id column, while a hashed column that is additionally dummycoded
113+ // expands to K columns.
114+ private boolean [] hashed = null ;
115+ private int K = 0 ;
116+
117+ private CategoricalMask (FrameBlock f , JSONObject jSpec ) {
118+ this .f = f ;
119+ this .jSpec = jSpec ;
120+ this .nCol = f .getNumColumns ();
121+ }
89122
90- int [] lengths = new int [nCol ];
91- // assume all columns encode to at least one column.
92- Arrays .fill (lengths , 1 );
93- boolean [] categorical = new boolean [nCol ];
94-
95- // feature-hashed columns map to K buckets; a plain hashed column
96- // produces a single (categorical) bucket-id column, while a hashed
97- // column that is additionally dummycoded expands to K columns.
98- boolean [] hashed = new boolean [nCol ];
99- int K = 0 ;
100- if (jSpec .containsKey (hash )) {
101- K = jSpec .getInt ("K" );
102- JSONArray a = jSpec .getJSONArray (hash );
103- for (Object aa : a ) {
104- int av = (Integer ) aa - 1 ;
105- hashed [av ] = true ;
106- categorical [av ] = true ;
107- }
123+ private void hash () throws JSONException {
124+ String hash = TfMethod .HASH .toString ();
125+ if (!jSpec .containsKey (hash ))
126+ return ;
127+ K = jSpec .getInt ("K" );
128+ hashed = new boolean [nCol ];
129+ ensureCategorical ();
130+ for (Object aa : jSpec .getJSONArray (hash )) {
131+ int av = (Integer ) aa - 1 ;
132+ hashed [av ] = true ;
133+ categorical [av ] = true ;
108134 }
135+ }
109136
110- if (jSpec .containsKey (recode )) {
111- JSONArray a = jSpec .getJSONArray (recode );
112- for (Object aa : a ) {
113- int av = (Integer ) aa - 1 ;
114- categorical [av ] = true ;
115- }
137+ private void recode () throws JSONException {
138+ String recode = TfMethod .RECODE .toString ();
139+ if (!jSpec .containsKey (recode ))
140+ return ;
141+ ensureCategorical ();
142+ for (Object aa : jSpec .getJSONArray (recode )) {
143+ int av = (Integer ) aa - 1 ;
144+ categorical [av ] = true ;
116145 }
146+ }
117147
118- if (jSpec .containsKey (dummycode )) {
119- JSONArray a = jSpec .getJSONArray (dummycode );
120- for (Object aa : a ) {
121- int av = (Integer ) aa - 1 ;
122- int ndist ;
123- if (hashed [av ]) {
124- // feature hashing followed by dummycoding yields K columns
125- ndist = K ;
126- }
127- else {
128- ColumnMetadata d = f .getColumnMetadata ()[av ];
129- String v = f .getString (0 , av );
130- if (v .length () > 1 && v .charAt (0 ) == '¿' ) {
131- ndist = UtilFunctions .parseToInt (v .substring (1 ));
132- }
133- else {
134- ndist = d .isDefault () ? 0 : (int ) d .getNumDistinct ();
135- }
136- }
137- lengths [av ] = ndist ;
138- categorical [av ] = true ;
139- }
148+ private void dummycode () throws JSONException {
149+ String dummycode = TfMethod .DUMMYCODE .toString ();
150+ if (!jSpec .containsKey (dummycode ))
151+ return ;
152+ ensureCategorical ();
153+ ensureLengths ();
154+ for (Object aa : jSpec .getJSONArray (dummycode )) {
155+ int av = (Integer ) aa - 1 ;
156+ lengths [av ] = distinctCount (av );
157+ categorical [av ] = true ;
140158 }
159+ }
141160
142- // get total size after mapping
161+ private int distinctCount (int av ) {
162+ if (hashed != null && hashed [av ])
163+ // feature hashing followed by dummycoding yields K columns
164+ return K ;
165+ ColumnMetadata d = f .getColumnMetadata ()[av ];
166+ String v = f .getString (0 , av );
167+ if (v .length () > 1 && v .charAt (0 ) == '¿' )
168+ return UtilFunctions .parseToInt (v .substring (1 ));
169+ return d .isDefault () ? 0 : (int ) d .getNumDistinct ();
170+ }
143171
144- int sumLengths = 0 ;
145- for (int i : lengths ) {
146- sumLengths += i ;
147- }
172+ private int sumLengths () {
173+ if (lengths == null )
174+ return nCol ;
175+ int sum = 0 ;
176+ for (int i = 0 ; i < nCol ; i ++)
177+ sum += lengths [i ];
178+ return sum ;
179+ }
148180
149- MatrixBlock ret = new MatrixBlock (1 , sumLengths , false );
181+ private MatrixBlock toMatrixBlock () {
182+ MatrixBlock ret = new MatrixBlock (1 , sumLengths (), false );
150183 ret .allocateDenseBlock ();
151184 int off = 0 ;
152- for (int i = 0 ; i < lengths .length ; i ++) {
153- for (int j = 0 ; j < lengths [i ]; j ++) {
154- ret .set (0 , off ++, categorical [i ] ? 1 : 0 );
155- }
185+ for (int i = 0 ; i < nCol ; i ++) {
186+ int len = (lengths == null ) ? 1 : lengths [i ];
187+ double val = (categorical != null && categorical [i ]) ? 1 : 0 ;
188+ for (int j = 0 ; j < len ; j ++)
189+ ret .set (0 , off ++, val );
156190 }
191+ return ret ;
192+ }
157193
158- ec .setMatrixOutput (output .getName (), ret );
159-
194+ private void ensureCategorical () {
195+ if (categorical == null )
196+ categorical = new boolean [nCol ];
160197 }
161- catch (Exception e ) {
162- throw new DMLRuntimeException (e );
198+
199+ private void ensureLengths () {
200+ if (lengths == null ) {
201+ lengths = new int [nCol ];
202+ Arrays .fill (lengths , 1 );
203+ }
163204 }
164205 }
165206}
0 commit comments