Skip to content

Commit 6429df2

Browse files
committed
add pnmf
1 parent 780d790 commit 6429df2

9 files changed

Lines changed: 257 additions & 13 deletions

File tree

scripts/builtin/pnmf.dml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@
4242
# H List of amplitude matrices, one for each repetition.
4343
# ------------------------------------------------------------------------------------
4444

45-
m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE)
45+
m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE, Integer seed=-1)
4646
return (Matrix[Double] W, Matrix[Double] H)
4747
{
4848
#initialize W and H
49-
W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025);
50-
H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025);
49+
W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025, seed=seed);
50+
H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025, seed=seed);
5151

5252
i = 0;
5353
while(i < maxi) {

src/main/java/org/apache/sysds/hops/QuaternaryOp.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ else if( et == ExecType.SPARK )
211211
constructCPLopsWeightedDivMM(wtype);
212212
else if( et == ExecType.SPARK )
213213
constructSparkLopsWeightedDivMM(wtype);
214+
else if( et == ExecType.OOC )
215+
constructOOCLopsWeightedDivMM(wtype);
214216
else
215217
throw new HopsException("Unsupported quaternaryop-wdivmm exec type: "+et);
216218
break;
@@ -462,6 +464,20 @@ private void constructSparkLopsWeightedDivMM( WDivMMType wtype )
462464
}
463465
}
464466

467+
private void constructOOCLopsWeightedDivMM(WDivMMType wtype)
468+
{
469+
WeightedDivMM wdiv = new WeightedDivMM(
470+
getInput().get(0).constructLops(),
471+
getInput().get(1).constructLops(),
472+
getInput().get(2).constructLops(),
473+
getInput().get(3).constructLops(),
474+
getDataType(), getValueType(), wtype, ExecType.OOC);
475+
476+
setOutputDimensions(wdiv);
477+
setLineNumbers(wdiv);
478+
setLops(wdiv);
479+
}
480+
465481
private void constructCPLopsWeightedCeMM(WCeMMType wtype)
466482
{
467483
WeightedCrossEntropy wcemm = new WeightedCrossEntropy(

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
4444
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
4545
import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction;
46+
import org.apache.sysds.runtime.instructions.ooc.QuaternaryOOCInstruction;
4647

4748
public class OOCInstructionParser extends InstructionParser {
4849
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -111,6 +112,8 @@ else if(parts.length == 4)
111112
return DataGenOOCInstruction.parseInstruction(str);
112113
case Append:
113114
return AppendOOCInstruction.parseInstruction(str);
115+
case Quaternary:
116+
return QuaternaryOOCInstruction.parseInstruction(str);
114117

115118
default:
116119
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);

src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
public abstract class ComputationOOCInstruction extends OOCInstruction {
2626
public CPOperand output;
27-
public CPOperand input1, input2, input3;
27+
public CPOperand input1, input2, input3, input4;
2828

2929
protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out, String opcode, String istr) {
3030
super(type, op, opcode, istr);
@@ -50,6 +50,15 @@ protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CP
5050
output = out;
5151
}
5252

53+
protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr) {
54+
super(type, op, opcode, istr);
55+
input1 = in1;
56+
input2 = in2;
57+
input3 = in3;
58+
input4 = in4;
59+
output = out;
60+
}
61+
5362
public String getOutputVariableName() {
5463
return output.getName();
5564
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public abstract class OOCInstruction extends Instruction {
8080

8181
public enum OOCType {
8282
Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ,
83-
MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append
83+
MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append, Quaternary
8484
}
8585

8686
protected final OOCInstruction.OOCType _ooctype;
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
23+
import org.apache.sysds.lops.WeightedDivMM;
24+
import org.apache.sysds.runtime.DMLRuntimeException;
25+
import org.apache.sysds.runtime.instructions.InstructionUtils;
26+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
27+
import org.apache.sysds.runtime.matrix.operators.Operator;
28+
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
29+
30+
public abstract class QuaternaryOOCInstruction extends ComputationOOCInstruction {
31+
32+
protected QuaternaryOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
33+
CPOperand out, String opcode, String istr) {
34+
super(OOCType.Quaternary, op, in1, in2, in3, in4, out, opcode, istr);
35+
}
36+
37+
public static QuaternaryOOCInstruction parseInstruction(String str) {
38+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
39+
String opcode = parts[0];
40+
41+
if(opcode.contains(WeightedDivMM.OPCODE)) {
42+
InstructionUtils.checkNumFields(parts, 6);
43+
CPOperand in1 = new CPOperand(parts[1]);
44+
CPOperand in2 = new CPOperand(parts[2]);
45+
CPOperand in3 = new CPOperand(parts[3]);
46+
CPOperand in4 = new CPOperand(parts[4]);
47+
CPOperand out = new CPOperand(parts[5]);
48+
QuaternaryOperator qop = new QuaternaryOperator(WeightedDivMM.WDivMMType.valueOf(parts[6]));
49+
return new WDivMMOOCInstruction(qop, in1, in2, in3, in4, out, opcode, str);
50+
}
51+
throw new DMLRuntimeException("Not implemented yet opcode " + opcode);
52+
}
53+
}
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
23+
import org.apache.sysds.common.Opcodes;
24+
import org.apache.sysds.common.Types.DataType;
25+
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
26+
import org.apache.sysds.runtime.DMLRuntimeException;
27+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
28+
import org.apache.sysds.runtime.functionobjects.Multiply;
29+
import org.apache.sysds.runtime.functionobjects.Plus;
30+
import org.apache.sysds.runtime.instructions.InstructionUtils;
31+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
32+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
33+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
34+
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
35+
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
36+
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
37+
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
38+
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
39+
import org.apache.sysds.runtime.meta.DataCharacteristics;
40+
41+
import java.util.function.Function;
42+
43+
44+
public class WDivMMOOCInstruction extends QuaternaryOOCInstruction
45+
{
46+
47+
protected WDivMMOOCInstruction(QuaternaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
48+
CPOperand out, String opcode, String istr) {
49+
super(op, in1, in2, in3, in4, out, opcode, istr);
50+
}
51+
52+
public static WDivMMOOCInstruction parseInstruction(QuaternaryOOCInstruction instr) {
53+
String instrStr = instr.getInstructionString();
54+
String opcode = InstructionUtils.getInstructionPartsWithValueType(instr.getInstructionString())[0];
55+
return new WDivMMOOCInstruction((QuaternaryOperator) instr.getOperator(), instr.input1, instr.input2,
56+
instr.input3, instr.input4, instr.output, opcode, instrStr);
57+
}
58+
59+
60+
@Override
61+
public void processInstruction(ExecutionContext ec) {
62+
QuaternaryOperator _qop = ((QuaternaryOperator)_optr);
63+
final WDivMMType wt = _qop.wtype3;
64+
65+
if(!(wt.hasFourInputs()&&wt.hasScalar()) || wt.isBasic() || wt.isMult() || wt.isMinus()) throw new DMLRuntimeException("Not implemented: only pnmf supported yet");
66+
67+
CachingStream X = new CachingStream(ec.getMatrixObject(input1).getStreamHandle());
68+
CachingStream U = new CachingStream(ec.getMatrixObject(input2).getStreamHandle());
69+
CachingStream V = new CachingStream(ec.getMatrixObject(input3).getStreamHandle());
70+
71+
double eps = 0.0;
72+
if(_qop.hasFourInputs()) {
73+
if (input4.getDataType() == DataType.SCALAR)
74+
eps = ec.getScalarInput(input4).getDoubleValue();
75+
}
76+
77+
OOCStream<IndexedMatrixValue> mmt = matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(), V.getDataCharacteristics(), false, true);
78+
OOCStream<IndexedMatrixValue> plus = elemPlusOOC(mmt, eps);
79+
OOCStream<IndexedMatrixValue> inter = elemDivOOC(X.getReadStream(), plus);
80+
OOCStream<IndexedMatrixValue> out;
81+
82+
if(wt.isLeft())
83+
out = matMultOOC(inter, U.getReadStream(), X.getDataCharacteristics(), U.getDataCharacteristics(), true, false);
84+
else
85+
out = matMultOOC(inter, V.getReadStream(), X.getDataCharacteristics(), V.getDataCharacteristics(), false, false);
86+
87+
ec.getMatrixObject(output).setStreamHandle(out);
88+
}
89+
90+
private OOCStream<IndexedMatrixValue> matMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2, DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose){
91+
92+
int emitLeftThreshold = rightTranspose? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks();
93+
int emitRightThreshold = leftTranspose? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks();
94+
95+
OOCStream<IndexedMatrixValue> intermediateStream = createWritableStream();
96+
OOCStream<IndexedMatrixValue> out = createWritableStream();
97+
98+
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
99+
AggregateBinaryOperator op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
100+
101+
joinManyOOC(m1, m2, intermediateStream,
102+
(left, right) -> {
103+
MatrixBlock leftBlock = (MatrixBlock) left.getValue();
104+
MatrixBlock rightBlock = (MatrixBlock) right.getValue();
105+
if(leftTranspose) leftBlock = leftBlock.transpose();
106+
if(rightTranspose) rightBlock = rightBlock.transpose();
107+
108+
MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock,
109+
new MatrixBlock(), op);
110+
int lidx = (int) (leftTranspose? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex());
111+
int ridx = (int) (rightTranspose? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex());
112+
return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult);
113+
},
114+
tmp -> leftTranspose? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(),
115+
tmp -> rightTranspose? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(),
116+
emitLeftThreshold, emitRightThreshold);
117+
118+
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
119+
int emitAggThreshold = leftTranspose? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks();
120+
121+
groupedReduceOOC(intermediateStream, out, (left, right) -> {
122+
MatrixBlock mb = ((MatrixBlock)left.getValue()).binaryOperationsInPlace(plus, right.getValue());
123+
left.setValue(mb);
124+
return left;
125+
}, emitAggThreshold);
126+
127+
return out;
128+
}
129+
130+
private OOCStream<IndexedMatrixValue> elemDivOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2){
131+
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
132+
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString());
133+
Function<IndexedMatrixValue, MatrixIndexes> key = imv -> new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex());
134+
135+
joinOOC(m1, m2, out, (left, right) -> {
136+
MatrixBlock lb = (MatrixBlock) left.getValue();
137+
MatrixBlock rb = (MatrixBlock) right.getValue();
138+
MatrixBlock combined = lb.binaryOperations(div, rb);
139+
return new IndexedMatrixValue(
140+
new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
141+
}, key);
142+
143+
return out;
144+
}
145+
146+
private OOCStream<IndexedMatrixValue> elemPlusOOC(OOCStream<IndexedMatrixValue> m1, double eps){
147+
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
148+
mapOOC(m1, out, blk -> new IndexedMatrixValue(
149+
new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()), plusDouble((MatrixBlock) blk.getValue(), eps)));
150+
return out;
151+
}
152+
153+
private MatrixBlock plusDouble(MatrixBlock blk, double eps){
154+
for(int i=0; i<blk.getNumRows(); i++){
155+
for(int j=0; j<blk.getNumColumns(); j++){
156+
blk.set(i, j, blk.get(i, j) + eps);
157+
}
158+
}
159+
return blk;
160+
}
161+
}

src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.sysds.test.AutomatedTestBase;
2828
import org.apache.sysds.test.TestConfiguration;
2929
import org.apache.sysds.test.TestUtils;
30+
import org.junit.Test;
3031

3132
public class PNMFTest extends AutomatedTestBase {
3233
private static final String TEST_NAME = "PNMF";
@@ -44,6 +45,7 @@ public class PNMFTest extends AutomatedTestBase {
4445
private static final int RANK = 20;
4546
private static final int MAX_ITER = 10;
4647
private static final int BLOCK_SIZE = 1000;
48+
private static final int SEED = 7;
4749

4850
private static final double SPARSITY = 0.7;
4951
private static final double EPS = 1e-6;
@@ -54,7 +56,7 @@ public void setUp() {
5456
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
5557
}
5658

57-
//@Test
59+
@Test
5860
public void testPNMFOOCVsCP() {
5961
runPNMFTest();
6062
}
@@ -71,13 +73,13 @@ private void runPNMFTest() {
7173
double[][] xData = getRandomMatrix(ROWS, COLS, 1, 10, SPARSITY, 7);
7274
writeBinaryWithMTD(INPUT_X, DataConverter.convertToMatrixBlock(xData));
7375

74-
programArgs = new String[] {"-explain", "-stats", "-seed", "7", "-ooc", "-args",
75-
input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER),
76+
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args",
77+
input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER), String.valueOf(SEED),
7678
output(OUTPUT_W_OOC), output(OUTPUT_H_OOC)};
7779
runTest(true, false, null, -1);
7880

79-
programArgs = new String[] {"-explain", "-stats", "-seed", "7", "-args",
80-
input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER),
81+
programArgs = new String[] {"-explain", "-stats", "-args",
82+
input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER), String.valueOf(SEED),
8183
output(OUTPUT_W_CP), output(OUTPUT_H_CP)};
8284
runTest(true, false, null, -1);
8385

src/test/scripts/functions/ooc/PNMF.dml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#-------------------------------------------------------------
2121

2222
X = read($1);
23-
[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
23+
[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE, seed=$4);
2424

25-
write(W, $4, format="binary");
26-
write(H, $5, format="binary");
25+
write(W, $5, format="binary");
26+
write(H, $6, format="binary");

0 commit comments

Comments
 (0)