1919
2020package org .apache .sysds .runtime .instructions .ooc ;
2121
22-
2322import org .apache .sysds .common .Opcodes ;
24- import org .apache .sysds .common .Types .DataType ;
2523import org .apache .sysds .lops .WeightedDivMM .WDivMMType ;
26- import org .apache .sysds .runtime .DMLRuntimeException ;
2724import org .apache .sysds .runtime .controlprogram .context .ExecutionContext ;
2825import org .apache .sysds .runtime .functionobjects .Multiply ;
2926import org .apache .sysds .runtime .functionobjects .Plus ;
4138import java .util .function .Function ;
4239
4340
44- public class WDivMMOOCInstruction extends QuaternaryOOCInstruction
45- {
41+ public class WDivMMOOCInstruction extends QuaternaryOOCInstruction {
4642
4743 protected WDivMMOOCInstruction (QuaternaryOperator op , CPOperand in1 , CPOperand in2 , CPOperand in3 , CPOperand in4 ,
4844 CPOperand out , String opcode , String istr ) {
@@ -56,103 +52,144 @@ public static WDivMMOOCInstruction parseInstruction(QuaternaryOOCInstruction ins
5652 instr .input3 , instr .input4 , instr .output , opcode , instrStr );
5753 }
5854
59-
6055 @ Override
6156 public void processInstruction (ExecutionContext ec ) {
62- QuaternaryOperator _qop = ((QuaternaryOperator )_optr );
57+ QuaternaryOperator _qop = ((QuaternaryOperator ) _optr );
6358 final WDivMMType wt = _qop .wtype3 ;
6459
65- if (!(wt .hasFourInputs ()&&wt .hasScalar ()) || wt .isBasic () || wt .isMult () || wt .isMinus ()) throw new DMLRuntimeException ("Not implemented: only pnmf supported yet" );
66-
6760 CachingStream X = new CachingStream (ec .getMatrixObject (input1 ).getStreamHandle ());
6861 CachingStream U = new CachingStream (ec .getMatrixObject (input2 ).getStreamHandle ());
6962 CachingStream V = new CachingStream (ec .getMatrixObject (input3 ).getStreamHandle ());
7063
71- double eps = 0.0 ;
72- if (_qop .hasFourInputs ()) {
73- if (input4 .getDataType () == DataType .SCALAR )
74- eps = ec .getScalarInput (input4 ).getDoubleValue ();
75- }
64+ boolean basic = wt .isBasic ();
65+ boolean left = wt .isLeft ();
66+ boolean mult = wt .isMult ();
67+ boolean minus = wt .isMinus ();
68+ boolean four = wt .hasFourInputs ();
69+ boolean scalar = wt .hasScalar ();
7670
77- OOCStream <IndexedMatrixValue > mmt = matMultOOC (U .getReadStream (), V .getReadStream (), U .getDataCharacteristics (), V . getDataCharacteristics (), false , true );
78- OOCStream < IndexedMatrixValue > plus = elemPlusOOC ( mmt , eps );
79- OOCStream <IndexedMatrixValue > inter = elemDivOOC ( X . getReadStream (), plus ) ;
71+ OOCStream <IndexedMatrixValue > mmt = matMultOOC (U .getReadStream (), V .getReadStream (), U .getDataCharacteristics (),
72+ V . getDataCharacteristics (), false , true );
73+ OOCStream <IndexedMatrixValue > inter ;
8074 OOCStream <IndexedMatrixValue > out ;
8175
82- if (wt .isLeft ())
76+ if (basic ) {
77+ out = elemMultOOC (X .getReadStream (), mmt );
78+ ec .getMatrixObject (output ).setStreamHandle (out );
79+ return ;
80+ }
81+ else if (four ) {
82+ if (scalar ) {
83+ double eps = ec .getScalarInput (input4 ).getDoubleValue ();
84+ inter = elemDivOOC (X .getReadStream (), elemPlusOOC (mmt , eps ));
85+ }
86+ else {
87+ CachingStream W = new CachingStream (ec .getMatrixObject (input4 ).getStreamHandle ());
88+ inter = elemMultOOC (X .getReadStream (), elemMinusOOC (mmt , W .getReadStream ()));
89+ }
90+ }
91+ else {
92+ if (minus )
93+ inter = elemMinusOOC (mmt , X .getReadStream ());
94+ else {
95+ if (mult )
96+ inter = elemMultOOC (X .getReadStream (), mmt );
97+ else
98+ inter = elemDivOOC (X .getReadStream (), mmt );
99+ }
100+ }
101+
102+ if (left )
83103 out = matMultOOC (inter , U .getReadStream (), X .getDataCharacteristics (), U .getDataCharacteristics (), true , false );
84104 else
85105 out = matMultOOC (inter , V .getReadStream (), X .getDataCharacteristics (), V .getDataCharacteristics (), false , false );
86106
87107 ec .getMatrixObject (output ).setStreamHandle (out );
88108 }
89109
90- private OOCStream <IndexedMatrixValue > matMultOOC (OOCStream <IndexedMatrixValue > m1 , OOCStream <IndexedMatrixValue > m2 , DataCharacteristics dc1 , DataCharacteristics dc2 , boolean leftTranspose , boolean rightTranspose ){
110+ private OOCStream <IndexedMatrixValue > matMultOOC (OOCStream <IndexedMatrixValue > m1 , OOCStream <IndexedMatrixValue > m2 ,
111+ DataCharacteristics dc1 , DataCharacteristics dc2 , boolean leftTranspose , boolean rightTranspose ) {
91112
92- int emitLeftThreshold = rightTranspose ? (int ) dc2 .getNumRowBlocks () : (int ) dc2 .getNumColBlocks ();
93- int emitRightThreshold = leftTranspose ? (int ) dc1 .getNumColBlocks () : (int ) dc1 .getNumRowBlocks ();
113+ int emitLeftThreshold = rightTranspose ? (int ) dc2 .getNumRowBlocks () : (int ) dc2 .getNumColBlocks ();
114+ int emitRightThreshold = leftTranspose ? (int ) dc1 .getNumColBlocks () : (int ) dc1 .getNumRowBlocks ();
94115
95116 OOCStream <IndexedMatrixValue > intermediateStream = createWritableStream ();
96117 OOCStream <IndexedMatrixValue > out = createWritableStream ();
97118
98119 AggregateOperator agg = new AggregateOperator (0 , Plus .getPlusFnObject ());
99120 AggregateBinaryOperator op = new AggregateBinaryOperator (Multiply .getMultiplyFnObject (), agg );
100121
101- joinManyOOC (m1 , m2 , intermediateStream ,
102- (left , right ) -> {
103- MatrixBlock leftBlock = (MatrixBlock ) left .getValue ();
104- MatrixBlock rightBlock = (MatrixBlock ) right .getValue ();
105- if (leftTranspose ) leftBlock = leftBlock .transpose ();
106- if (rightTranspose ) rightBlock = rightBlock .transpose ();
107-
108- MatrixBlock partialResult = leftBlock .aggregateBinaryOperations (leftBlock , rightBlock ,
109- new MatrixBlock (), op );
110- int lidx = (int ) (leftTranspose ? left .getIndexes ().getColumnIndex () : left .getIndexes ().getRowIndex ());
111- int ridx = (int ) (rightTranspose ? right .getIndexes ().getRowIndex () : right .getIndexes ().getColumnIndex ());
112- return new IndexedMatrixValue (new MatrixIndexes (lidx , ridx ), partialResult );
113- },
114- tmp -> leftTranspose ? tmp .getIndexes ().getRowIndex () : tmp .getIndexes ().getColumnIndex (),
115- tmp -> rightTranspose ? tmp .getIndexes ().getColumnIndex () : tmp .getIndexes ().getRowIndex (),
122+ joinManyOOC (m1 , m2 , intermediateStream , (left , right ) -> {
123+ MatrixBlock leftBlock = (MatrixBlock ) left .getValue ();
124+ MatrixBlock rightBlock = (MatrixBlock ) right .getValue ();
125+ if (leftTranspose )
126+ leftBlock = leftBlock .transpose ();
127+ if (rightTranspose )
128+ rightBlock = rightBlock .transpose ();
129+
130+ MatrixBlock partialResult = leftBlock .aggregateBinaryOperations (leftBlock , rightBlock , new MatrixBlock (), op );
131+ int lidx = (int ) (leftTranspose ? left .getIndexes ().getColumnIndex () : left .getIndexes ().getRowIndex ());
132+ int ridx = (int ) (rightTranspose ? right .getIndexes ().getRowIndex () : right .getIndexes ().getColumnIndex ());
133+ return new IndexedMatrixValue (new MatrixIndexes (lidx , ridx ), partialResult );
134+ }, tmp -> leftTranspose ? tmp .getIndexes ().getRowIndex () : tmp .getIndexes ().getColumnIndex (),
135+ tmp -> rightTranspose ? tmp .getIndexes ().getColumnIndex () : tmp .getIndexes ().getRowIndex (),
116136 emitLeftThreshold , emitRightThreshold );
117137
118138 BinaryOperator plus = InstructionUtils .parseBinaryOperator (Opcodes .PLUS .toString ());
119- int emitAggThreshold = leftTranspose ? (int ) dc1 .getNumRowBlocks () : (int ) dc1 .getNumColBlocks ();
139+ int emitAggThreshold = leftTranspose ? (int ) dc1 .getNumRowBlocks () : (int ) dc1 .getNumColBlocks ();
120140
121141 groupedReduceOOC (intermediateStream , out , (left , right ) -> {
122- MatrixBlock mb = ((MatrixBlock )left .getValue ()).binaryOperationsInPlace (plus , right .getValue ());
142+ MatrixBlock mb = ((MatrixBlock ) left .getValue ()).binaryOperationsInPlace (plus , right .getValue ());
123143 left .setValue (mb );
124144 return left ;
125145 }, emitAggThreshold );
126146
127147 return out ;
128148 }
129149
130- private OOCStream <IndexedMatrixValue > elemDivOOC (OOCStream <IndexedMatrixValue > m1 , OOCStream <IndexedMatrixValue > m2 ) {
150+ private OOCStream <IndexedMatrixValue > elemOOC (OOCStream <IndexedMatrixValue > m1 , OOCStream <IndexedMatrixValue > m2 , BinaryOperator bop ) {
131151 SubscribableTaskQueue <IndexedMatrixValue > out = new SubscribableTaskQueue <>();
132- BinaryOperator div = InstructionUtils . parseBinaryOperator ( Opcodes . DIV . toString ());
133- Function < IndexedMatrixValue , MatrixIndexes > key = imv -> new MatrixIndexes (imv .getIndexes ().getRowIndex (), imv .getIndexes ().getColumnIndex ());
152+ Function < IndexedMatrixValue , MatrixIndexes > key = imv ->
153+ new MatrixIndexes (imv .getIndexes ().getRowIndex (), imv .getIndexes ().getColumnIndex ());
134154
135155 joinOOC (m1 , m2 , out , (left , right ) -> {
136156 MatrixBlock lb = (MatrixBlock ) left .getValue ();
137157 MatrixBlock rb = (MatrixBlock ) right .getValue ();
138- MatrixBlock combined = lb .binaryOperations (div , rb );
158+ MatrixBlock combined = lb .binaryOperations (bop , rb );
139159 return new IndexedMatrixValue (
140160 new MatrixIndexes (left .getIndexes ().getRowIndex (), left .getIndexes ().getColumnIndex ()), combined );
141161 }, key );
142162
143163 return out ;
144164 }
145165
146- private OOCStream <IndexedMatrixValue > elemPlusOOC (OOCStream <IndexedMatrixValue > m1 , double eps ){
166+ private OOCStream <IndexedMatrixValue > elemDivOOC (OOCStream <IndexedMatrixValue > m1 , OOCStream <IndexedMatrixValue > m2 ) {
167+ BinaryOperator div = InstructionUtils .parseBinaryOperator (Opcodes .DIV .toString ());
168+ return elemOOC (m1 , m2 , div );
169+ }
170+
171+ private OOCStream <IndexedMatrixValue > elemMultOOC (OOCStream <IndexedMatrixValue > m1 , OOCStream <IndexedMatrixValue > m2 ) {
172+ BinaryOperator div = InstructionUtils .parseBinaryOperator (Opcodes .MULT .toString ());
173+ return elemOOC (m1 , m2 , div );
174+ }
175+
176+ private OOCStream <IndexedMatrixValue > elemMinusOOC (OOCStream <IndexedMatrixValue > m1 , OOCStream <IndexedMatrixValue > m2 ) {
177+ BinaryOperator div = InstructionUtils .parseBinaryOperator (Opcodes .MINUS .toString ());
178+ return elemOOC (m1 , m2 , div );
179+ }
180+
181+ private OOCStream <IndexedMatrixValue > elemPlusOOC (OOCStream <IndexedMatrixValue > m1 , double eps ) {
147182 SubscribableTaskQueue <IndexedMatrixValue > out = new SubscribableTaskQueue <>();
148- mapOOC (m1 , out , blk -> new IndexedMatrixValue (
149- new MatrixIndexes (blk .getIndexes ().getRowIndex (), blk .getIndexes ().getColumnIndex ()), plusDouble ((MatrixBlock ) blk .getValue (), eps )));
183+ mapOOC (m1 , out ,
184+ blk -> new IndexedMatrixValue (
185+ new MatrixIndexes (blk .getIndexes ().getRowIndex (), blk .getIndexes ().getColumnIndex ()),
186+ elemPlus ((MatrixBlock ) blk .getValue (), eps )));
150187 return out ;
151188 }
152189
153- private MatrixBlock plusDouble (MatrixBlock blk , double eps ){
154- for (int i = 0 ; i < blk .getNumRows (); i ++){
155- for (int j = 0 ; j < blk .getNumColumns (); j ++){
190+ private MatrixBlock elemPlus (MatrixBlock blk , double eps ) {
191+ for (int i = 0 ; i < blk .getNumRows (); i ++) {
192+ for (int j = 0 ; j < blk .getNumColumns (); j ++) {
156193 blk .set (i , j , blk .get (i , j ) + eps );
157194 }
158195 }
0 commit comments