Skip to content

Commit 3de7cbe

Browse files
jessicapriebejanniklinde
authored andcommitted
Add OOC WDivMM
Closes #2464.
1 parent 6644ce3 commit 3de7cbe

21 files changed

Lines changed: 815 additions & 13 deletions

scripts/builtin/pnmf.dml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@
4242
# H List of amplitude matrices, one for each repetition.
4343
# ------------------------------------------------------------------------------------
4444

45-
m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE)
45+
m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE, Integer seed=-1)
4646
return (Matrix[Double] W, Matrix[Double] H)
4747
{
4848
#initialize W and H
49-
W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025);
50-
H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025);
49+
W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025, seed=seed);
50+
H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025, seed=seed);
5151

5252
i = 0;
5353
while(i < maxi) {

src/main/java/org/apache/sysds/hops/QuaternaryOp.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ else if( et == ExecType.SPARK )
211211
constructCPLopsWeightedDivMM(wtype);
212212
else if( et == ExecType.SPARK )
213213
constructSparkLopsWeightedDivMM(wtype);
214+
else if( et == ExecType.OOC )
215+
constructOOCLopsWeightedDivMM(wtype);
214216
else
215217
throw new HopsException("Unsupported quaternaryop-wdivmm exec type: "+et);
216218
break;
@@ -462,6 +464,20 @@ private void constructSparkLopsWeightedDivMM( WDivMMType wtype )
462464
}
463465
}
464466

467+
private void constructOOCLopsWeightedDivMM(WDivMMType wtype)
468+
{
469+
WeightedDivMM wdiv = new WeightedDivMM(
470+
getInput().get(0).constructLops(),
471+
getInput().get(1).constructLops(),
472+
getInput().get(2).constructLops(),
473+
getInput().get(3).constructLops(),
474+
getDataType(), getValueType(), wtype, ExecType.OOC);
475+
476+
setOutputDimensions(wdiv);
477+
setLineNumbers(wdiv);
478+
setLops(wdiv);
479+
}
480+
465481
private void constructCPLopsWeightedCeMM(WCeMMType wtype)
466482
{
467483
WeightedCrossEntropy wcemm = new WeightedCrossEntropy(

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
4444
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
4545
import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction;
46+
import org.apache.sysds.runtime.instructions.ooc.QuaternaryOOCInstruction;
4647

4748
public class OOCInstructionParser extends InstructionParser {
4849
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -111,6 +112,8 @@ else if(parts.length == 4)
111112
return DataGenOOCInstruction.parseInstruction(str);
112113
case Append:
113114
return AppendOOCInstruction.parseInstruction(str);
115+
case Quaternary:
116+
return QuaternaryOOCInstruction.parseInstruction(str);
114117

115118
default:
116119
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);

src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
public abstract class ComputationOOCInstruction extends OOCInstruction {
2626
public CPOperand output;
27-
public CPOperand input1, input2, input3;
27+
public CPOperand input1, input2, input3, input4;
2828

2929
protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out, String opcode, String istr) {
3030
super(type, op, opcode, istr);
@@ -50,6 +50,15 @@ protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CP
5050
output = out;
5151
}
5252

53+
protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr) {
54+
super(type, op, opcode, istr);
55+
input1 = in1;
56+
input2 = in2;
57+
input3 = in3;
58+
input4 = in4;
59+
output = out;
60+
}
61+
5362
public String getOutputVariableName() {
5463
return output.getName();
5564
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public abstract class OOCInstruction extends Instruction {
8080

8181
public enum OOCType {
8282
Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ,
83-
MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append
83+
MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append, Quaternary
8484
}
8585

8686
protected final OOCInstruction.OOCType _ooctype;
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
23+
import org.apache.sysds.common.Opcodes;
24+
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
25+
import org.apache.sysds.runtime.DMLRuntimeException;
26+
import org.apache.sysds.runtime.instructions.InstructionUtils;
27+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
28+
import org.apache.sysds.runtime.matrix.operators.Operator;
29+
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
30+
31+
public abstract class QuaternaryOOCInstruction extends ComputationOOCInstruction {
32+
33+
protected QuaternaryOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
34+
CPOperand out, String opcode, String istr) {
35+
super(OOCType.Quaternary, op, in1, in2, in3, in4, out, opcode, istr);
36+
}
37+
38+
public static QuaternaryOOCInstruction parseInstruction(String str) {
39+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
40+
String opcode = parts[0];
41+
42+
if(opcode.contains(Opcodes.WEIGHTEDDIVMM.toString())) {
43+
InstructionUtils.checkNumFields(parts, 6);
44+
CPOperand in1 = new CPOperand(parts[1]);
45+
CPOperand in2 = new CPOperand(parts[2]);
46+
CPOperand in3 = new CPOperand(parts[3]);
47+
CPOperand in4 = new CPOperand(parts[4]);
48+
CPOperand out = new CPOperand(parts[5]);
49+
QuaternaryOperator qop = new QuaternaryOperator(WDivMMType.valueOf(parts[6]));
50+
return new WDivMMOOCInstruction(qop, in1, in2, in3, in4, out, opcode, str);
51+
}
52+
throw new DMLRuntimeException("Not implemented yet opcode " + opcode);
53+
}
54+
}
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
24+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
25+
import org.apache.sysds.runtime.functionobjects.Multiply;
26+
import org.apache.sysds.runtime.functionobjects.Plus;
27+
import org.apache.sysds.runtime.instructions.InstructionUtils;
28+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
29+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
30+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
31+
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
32+
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
33+
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
34+
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
35+
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
36+
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
37+
import org.apache.sysds.runtime.meta.DataCharacteristics;
38+
39+
import java.util.function.Function;
40+
41+
public class WDivMMOOCInstruction extends QuaternaryOOCInstruction {
42+
43+
protected WDivMMOOCInstruction(QuaternaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
44+
CPOperand out, String opcode, String istr) {
45+
super(op, in1, in2, in3, in4, out, opcode, istr);
46+
}
47+
48+
public static WDivMMOOCInstruction parseInstruction(QuaternaryOOCInstruction instr) {
49+
String instrStr = instr.getInstructionString();
50+
String opcode = InstructionUtils.getInstructionPartsWithValueType(instr.getInstructionString())[0];
51+
return new WDivMMOOCInstruction((QuaternaryOperator) instr.getOperator(), instr.input1, instr.input2,
52+
instr.input3, instr.input4, instr.output, opcode, instrStr);
53+
}
54+
55+
@Override
56+
public void processInstruction(ExecutionContext ec) {
57+
QuaternaryOperator qop = ((QuaternaryOperator) _optr);
58+
final WDivMMType wt = qop.wtype3;
59+
60+
CachingStream X = new CachingStream(ec.getMatrixObject(input1).getStreamHandle());
61+
CachingStream U = new CachingStream(ec.getMatrixObject(input2).getStreamHandle());
62+
CachingStream V = new CachingStream(ec.getMatrixObject(input3).getStreamHandle());
63+
64+
boolean basic = wt.isBasic();
65+
boolean left = wt.isLeft();
66+
boolean mult = wt.isMult();
67+
boolean minus = wt.isMinus();
68+
boolean four = wt.hasFourInputs();
69+
boolean scalar = wt.hasScalar();
70+
71+
OOCStream<IndexedMatrixValue> mmt = matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(),
72+
V.getDataCharacteristics(), false, true);
73+
OOCStream<IndexedMatrixValue> inter;
74+
OOCStream<IndexedMatrixValue> out;
75+
76+
if(basic) {
77+
out = elemMultOOC(X.getReadStream(), mmt);
78+
ec.getMatrixObject(output).setStreamHandle(out);
79+
return;
80+
}
81+
else if(four) {
82+
if(scalar) {
83+
double eps = ec.getScalarInput(input4).getDoubleValue();
84+
inter = elemDivOOC(X.getReadStream(), elemPlusOOC(mmt, eps));
85+
}
86+
else {
87+
CachingStream W = new CachingStream(ec.getMatrixObject(input4).getStreamHandle());
88+
inter = elemMultOOC(X.getReadStream(), elemMinusOOC(mmt, W.getReadStream()));
89+
}
90+
}
91+
else {
92+
if(minus)
93+
inter = maskOOC(X.getReadStream(), elemMinusOOC(mmt, X.getReadStream()));
94+
else {
95+
if(mult)
96+
inter = elemMultOOC(X.getReadStream(), mmt);
97+
else
98+
inter = elemDivOOC(X.getReadStream(), mmt);
99+
}
100+
}
101+
102+
if(left)
103+
out = matMultOOC(inter, U.getReadStream(), X.getDataCharacteristics(), U.getDataCharacteristics(),
104+
true, false);
105+
else
106+
out = matMultOOC(inter, V.getReadStream(), X.getDataCharacteristics(), V.getDataCharacteristics(),
107+
false, false);
108+
109+
ec.getMatrixObject(output).setStreamHandle(out);
110+
}
111+
112+
private OOCStream<IndexedMatrixValue> matMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2,
113+
DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose) {
114+
115+
int emitLeftThreshold = rightTranspose ? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks();
116+
int emitRightThreshold = leftTranspose ? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks();
117+
118+
OOCStream<IndexedMatrixValue> intermediateStream = createWritableStream();
119+
OOCStream<IndexedMatrixValue> out = createWritableStream();
120+
121+
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
122+
AggregateBinaryOperator op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
123+
124+
joinManyOOC(m1, m2, intermediateStream, (left, right) -> {
125+
MatrixBlock leftBlock = (MatrixBlock) left.getValue();
126+
MatrixBlock rightBlock = (MatrixBlock) right.getValue();
127+
if(leftTranspose)
128+
leftBlock = leftBlock.transpose();
129+
if(rightTranspose)
130+
rightBlock = rightBlock.transpose();
131+
132+
MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, new MatrixBlock(), op);
133+
int lidx = (int) (leftTranspose ? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex());
134+
int ridx = (int) (rightTranspose ? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex());
135+
return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult);
136+
}, tmp -> leftTranspose ? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(),
137+
tmp -> rightTranspose ? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(),
138+
emitLeftThreshold, emitRightThreshold);
139+
140+
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
141+
int emitAggThreshold = leftTranspose ? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks();
142+
143+
groupedReduceOOC(intermediateStream, out, (left, right) -> {
144+
MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue());
145+
left.setValue(mb);
146+
return left;
147+
}, emitAggThreshold);
148+
149+
return out;
150+
}
151+
152+
private OOCStream<IndexedMatrixValue> elemOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2, BinaryOperator bop) {
153+
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
154+
Function<IndexedMatrixValue, MatrixIndexes> key = imv ->
155+
new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex());
156+
157+
joinOOC(m1, m2, out, (left, right) -> {
158+
MatrixBlock lb = (MatrixBlock) left.getValue();
159+
MatrixBlock rb = (MatrixBlock) right.getValue();
160+
MatrixBlock combined = lb.binaryOperations(bop, rb);
161+
return new IndexedMatrixValue(
162+
new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
163+
}, key);
164+
165+
return out;
166+
}
167+
168+
private OOCStream<IndexedMatrixValue> elemDivOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
169+
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString());
170+
return elemOOC(m1, m2, div);
171+
}
172+
173+
private OOCStream<IndexedMatrixValue> elemMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
174+
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MULT.toString());
175+
return elemOOC(m1, m2, div);
176+
}
177+
178+
private OOCStream<IndexedMatrixValue> elemMinusOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
179+
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MINUS.toString());
180+
return elemOOC(m1, m2, div);
181+
}
182+
183+
private OOCStream<IndexedMatrixValue> elemPlusOOC(OOCStream<IndexedMatrixValue> m1, double eps) {
184+
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
185+
mapOOC(m1, out, blk -> {
186+
MatrixBlock res = ((MatrixBlock) blk.getValue())
187+
.scalarOperations(new RightScalarOperator(Plus.getPlusFnObject(), eps), null);
188+
return new IndexedMatrixValue(
189+
new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()), res);
190+
});
191+
return out;
192+
}
193+
194+
private OOCStream<IndexedMatrixValue> maskOOC(OOCStream<IndexedMatrixValue> mask, OOCStream<IndexedMatrixValue> m1) {
195+
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
196+
Function<IndexedMatrixValue, MatrixIndexes> key = imv ->
197+
new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex());
198+
199+
joinOOC(mask, m1, out, (left, right) -> {
200+
MatrixBlock lb = (MatrixBlock) left.getValue();
201+
MatrixBlock rb = (MatrixBlock) right.getValue();
202+
MatrixBlock combined = mask(lb, rb);
203+
return new IndexedMatrixValue(
204+
new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
205+
}, key);
206+
207+
return out;
208+
}
209+
210+
private MatrixBlock mask(MatrixBlock mask, MatrixBlock blk) {
211+
for(int i = 0; i < blk.getNumRows(); i++) {
212+
for(int j = 0; j < blk.getNumColumns(); j++) {
213+
if(mask.get(i,j) ==0) blk.set(i, j, 0);
214+
}
215+
}
216+
return blk;
217+
}
218+
}

0 commit comments

Comments
 (0)