Skip to content

Commit e14d775

Browse files
committed
extend to wdivmm
1 parent 6429df2 commit e14d775

2 files changed

Lines changed: 140 additions & 48 deletions

File tree

src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java

Lines changed: 85 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,8 @@
1919

2020
package org.apache.sysds.runtime.instructions.ooc;
2121

22-
2322
import org.apache.sysds.common.Opcodes;
24-
import org.apache.sysds.common.Types.DataType;
2523
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
26-
import org.apache.sysds.runtime.DMLRuntimeException;
2724
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
2825
import org.apache.sysds.runtime.functionobjects.Multiply;
2926
import org.apache.sysds.runtime.functionobjects.Plus;
@@ -41,8 +38,7 @@
4138
import java.util.function.Function;
4239

4340

44-
public class WDivMMOOCInstruction extends QuaternaryOOCInstruction
45-
{
41+
public class WDivMMOOCInstruction extends QuaternaryOOCInstruction {
4642

4743
protected WDivMMOOCInstruction(QuaternaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
4844
CPOperand out, String opcode, String istr) {
@@ -56,103 +52,144 @@ public static WDivMMOOCInstruction parseInstruction(QuaternaryOOCInstruction ins
5652
instr.input3, instr.input4, instr.output, opcode, instrStr);
5753
}
5854

59-
6055
@Override
6156
public void processInstruction(ExecutionContext ec) {
62-
QuaternaryOperator _qop = ((QuaternaryOperator)_optr);
57+
QuaternaryOperator _qop = ((QuaternaryOperator) _optr);
6358
final WDivMMType wt = _qop.wtype3;
6459

65-
if(!(wt.hasFourInputs()&&wt.hasScalar()) || wt.isBasic() || wt.isMult() || wt.isMinus()) throw new DMLRuntimeException("Not implemented: only pnmf supported yet");
66-
6760
CachingStream X = new CachingStream(ec.getMatrixObject(input1).getStreamHandle());
6861
CachingStream U = new CachingStream(ec.getMatrixObject(input2).getStreamHandle());
6962
CachingStream V = new CachingStream(ec.getMatrixObject(input3).getStreamHandle());
7063

71-
double eps = 0.0;
72-
if(_qop.hasFourInputs()) {
73-
if (input4.getDataType() == DataType.SCALAR)
74-
eps = ec.getScalarInput(input4).getDoubleValue();
75-
}
64+
boolean basic = wt.isBasic();
65+
boolean left = wt.isLeft();
66+
boolean mult = wt.isMult();
67+
boolean minus = wt.isMinus();
68+
boolean four = wt.hasFourInputs();
69+
boolean scalar = wt.hasScalar();
7670

77-
OOCStream<IndexedMatrixValue> mmt = matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(), V.getDataCharacteristics(), false, true);
78-
OOCStream<IndexedMatrixValue> plus = elemPlusOOC(mmt, eps);
79-
OOCStream<IndexedMatrixValue> inter = elemDivOOC(X.getReadStream(), plus);
71+
OOCStream<IndexedMatrixValue> mmt = matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(),
72+
V.getDataCharacteristics(), false, true);
73+
OOCStream<IndexedMatrixValue> inter;
8074
OOCStream<IndexedMatrixValue> out;
8175

82-
if(wt.isLeft())
76+
if(basic) {
77+
out = elemMultOOC(X.getReadStream(), mmt);
78+
ec.getMatrixObject(output).setStreamHandle(out);
79+
return;
80+
}
81+
else if(four) {
82+
if(scalar) {
83+
double eps = ec.getScalarInput(input4).getDoubleValue();
84+
inter = elemDivOOC(X.getReadStream(), elemPlusOOC(mmt, eps));
85+
}
86+
else {
87+
CachingStream W = new CachingStream(ec.getMatrixObject(input4).getStreamHandle());
88+
inter = elemMultOOC(X.getReadStream(), elemMinusOOC(mmt, W.getReadStream()));
89+
}
90+
}
91+
else {
92+
if(minus)
93+
inter = elemMinusOOC(mmt, X.getReadStream());
94+
else {
95+
if(mult)
96+
inter = elemMultOOC(X.getReadStream(), mmt);
97+
else
98+
inter = elemDivOOC(X.getReadStream(), mmt);
99+
}
100+
}
101+
102+
if(left)
83103
out = matMultOOC(inter, U.getReadStream(), X.getDataCharacteristics(), U.getDataCharacteristics(), true, false);
84104
else
85105
out = matMultOOC(inter, V.getReadStream(), X.getDataCharacteristics(), V.getDataCharacteristics(), false, false);
86106

87107
ec.getMatrixObject(output).setStreamHandle(out);
88108
}
89109

90-
private OOCStream<IndexedMatrixValue> matMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2, DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose){
110+
private OOCStream<IndexedMatrixValue> matMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2,
111+
DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose) {
91112

92-
int emitLeftThreshold = rightTranspose? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks();
93-
int emitRightThreshold = leftTranspose? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks();
113+
int emitLeftThreshold = rightTranspose ? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks();
114+
int emitRightThreshold = leftTranspose ? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks();
94115

95116
OOCStream<IndexedMatrixValue> intermediateStream = createWritableStream();
96117
OOCStream<IndexedMatrixValue> out = createWritableStream();
97118

98119
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
99120
AggregateBinaryOperator op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
100121

101-
joinManyOOC(m1, m2, intermediateStream,
102-
(left, right) -> {
103-
MatrixBlock leftBlock = (MatrixBlock) left.getValue();
104-
MatrixBlock rightBlock = (MatrixBlock) right.getValue();
105-
if(leftTranspose) leftBlock = leftBlock.transpose();
106-
if(rightTranspose) rightBlock = rightBlock.transpose();
107-
108-
MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock,
109-
new MatrixBlock(), op);
110-
int lidx = (int) (leftTranspose? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex());
111-
int ridx = (int) (rightTranspose? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex());
112-
return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult);
113-
},
114-
tmp -> leftTranspose? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(),
115-
tmp -> rightTranspose? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(),
122+
joinManyOOC(m1, m2, intermediateStream, (left, right) -> {
123+
MatrixBlock leftBlock = (MatrixBlock) left.getValue();
124+
MatrixBlock rightBlock = (MatrixBlock) right.getValue();
125+
if(leftTranspose)
126+
leftBlock = leftBlock.transpose();
127+
if(rightTranspose)
128+
rightBlock = rightBlock.transpose();
129+
130+
MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, new MatrixBlock(), op);
131+
int lidx = (int) (leftTranspose ? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex());
132+
int ridx = (int) (rightTranspose ? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex());
133+
return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult);
134+
}, tmp -> leftTranspose ? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(),
135+
tmp -> rightTranspose ? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(),
116136
emitLeftThreshold, emitRightThreshold);
117137

118138
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
119-
int emitAggThreshold = leftTranspose? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks();
139+
int emitAggThreshold = leftTranspose ? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks();
120140

121141
groupedReduceOOC(intermediateStream, out, (left, right) -> {
122-
MatrixBlock mb = ((MatrixBlock)left.getValue()).binaryOperationsInPlace(plus, right.getValue());
142+
MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue());
123143
left.setValue(mb);
124144
return left;
125145
}, emitAggThreshold);
126146

127147
return out;
128148
}
129149

130-
private OOCStream<IndexedMatrixValue> elemDivOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2){
150+
private OOCStream<IndexedMatrixValue> elemOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2, BinaryOperator bop) {
131151
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
132-
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString());
133-
Function<IndexedMatrixValue, MatrixIndexes> key = imv -> new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex());
152+
Function<IndexedMatrixValue, MatrixIndexes> key = imv ->
153+
new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex());
134154

135155
joinOOC(m1, m2, out, (left, right) -> {
136156
MatrixBlock lb = (MatrixBlock) left.getValue();
137157
MatrixBlock rb = (MatrixBlock) right.getValue();
138-
MatrixBlock combined = lb.binaryOperations(div, rb);
158+
MatrixBlock combined = lb.binaryOperations(bop, rb);
139159
return new IndexedMatrixValue(
140160
new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
141161
}, key);
142162

143163
return out;
144164
}
145165

146-
private OOCStream<IndexedMatrixValue> elemPlusOOC(OOCStream<IndexedMatrixValue> m1, double eps){
166+
private OOCStream<IndexedMatrixValue> elemDivOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
167+
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString());
168+
return elemOOC(m1, m2, div);
169+
}
170+
171+
private OOCStream<IndexedMatrixValue> elemMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
172+
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MULT.toString());
173+
return elemOOC(m1, m2, div);
174+
}
175+
176+
private OOCStream<IndexedMatrixValue> elemMinusOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
177+
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MINUS.toString());
178+
return elemOOC(m1, m2, div);
179+
}
180+
181+
private OOCStream<IndexedMatrixValue> elemPlusOOC(OOCStream<IndexedMatrixValue> m1, double eps) {
147182
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
148-
mapOOC(m1, out, blk -> new IndexedMatrixValue(
149-
new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()), plusDouble((MatrixBlock) blk.getValue(), eps)));
183+
mapOOC(m1, out,
184+
blk -> new IndexedMatrixValue(
185+
new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()),
186+
elemPlus((MatrixBlock) blk.getValue(), eps)));
150187
return out;
151188
}
152189

153-
private MatrixBlock plusDouble(MatrixBlock blk, double eps){
154-
for(int i=0; i<blk.getNumRows(); i++){
155-
for(int j=0; j<blk.getNumColumns(); j++){
190+
private MatrixBlock elemPlus(MatrixBlock blk, double eps) {
191+
for(int i = 0; i < blk.getNumRows(); i++) {
192+
for(int j = 0; j < blk.getNumColumns(); j++) {
156193
blk.set(i, j, blk.get(i, j) + eps);
157194
}
158195
}

src/test/java/org/apache/sysds/test/functions/quaternary/WeightedDivMatrixMultTest.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,16 @@ public void testWeightedDivMMRightDenseSPRep() {
164164
runWeightedDivMMTest(TEST_NAME2, false, true, true, ExecType.SPARK);
165165
}
166166

167+
@Test
168+
public void testWeightedDivMMLeftDenseOOC() {
169+
runWeightedDivMMTest(TEST_NAME1, false, true, false, ExecType.OOC);
170+
}
171+
172+
@Test
173+
public void testWeightedDivMMRightDenseOOC() {
174+
runWeightedDivMMTest(TEST_NAME2, false, true, false, ExecType.OOC);
175+
}
176+
167177
//b) testcases for wdivmm w/ MULTIPLY BASIC/LEFT/RIGHT
168178

169179
@Test
@@ -341,6 +351,41 @@ public void testWeightedDivMM4MultMinusRightSparseSP() {
341351
public void testWeightedDivMM4MultMinusRightDenseSPRep() {
342352
runWeightedDivMMTest(TEST_NAME9, false, true, true, ExecType.SPARK);
343353
}
354+
355+
@Test
356+
public void testWeightedDivMMMultBasicDenseOOC() {
357+
runWeightedDivMMTest(TEST_NAME3, false, true, false, ExecType.OOC);
358+
}
359+
360+
@Test
361+
public void testWeightedDivMMMultLeftDenseOOC() {
362+
runWeightedDivMMTest(TEST_NAME4, false, true, false, ExecType.OOC);
363+
}
364+
365+
@Test
366+
public void testWeightedDivMMMultRightDenseOOC() {
367+
runWeightedDivMMTest(TEST_NAME5, false, true, false, ExecType.OOC);
368+
}
369+
370+
@Test
371+
public void testWeightedDivMMMultMinusLeftDenseOOC() {
372+
runWeightedDivMMTest(TEST_NAME6, false, true, false, ExecType.OOC);
373+
}
374+
375+
@Test
376+
public void testWeightedDivMMMultMinusRightDenseOOC() {
377+
runWeightedDivMMTest(TEST_NAME7, false, true, false, ExecType.OOC);
378+
}
379+
380+
@Test
381+
public void testWeightedDivMM4MultMinusLeftDenseOOC() {
382+
runWeightedDivMMTest(TEST_NAME8, false, true, false, ExecType.OOC);
383+
}
384+
385+
@Test
386+
public void testWeightedDivMM4MultMinusRightDenseOOC() {
387+
runWeightedDivMMTest(TEST_NAME9, false, true, false, ExecType.OOC);
388+
}
344389

345390
//c) testcases for wdivmm w/ DIVIDE LEFT/RIGHT with Epsilon
346391

@@ -394,6 +439,16 @@ public void testWeightedDivMMRightEpsDenseSPRep() {
394439
runWeightedDivMMTest(TEST_NAME11, false, true, true, ExecType.SPARK);
395440
}
396441

442+
@Test
443+
public void testWeightedDivMMLeftEpsDenseOOC() {
444+
runWeightedDivMMTest(TEST_NAME10, false, true, false, ExecType.OOC);
445+
}
446+
447+
@Test
448+
public void testWeightedDivMMRightEpsDenseOOC() {
449+
runWeightedDivMMTest(TEST_NAME11, false, true, false, ExecType.OOC);
450+
}
451+
397452
//d) testcases for wdivmm w/ DIVIDE LEFT/RIGHT with Epsilon
398453

399454
@Test

0 commit comments

Comments
 (0)