Skip to content

Commit 8025025

Browse files
committed
Add Support For Grouped Tiles
Add Stream Split and Merge Primitives + Bugfixes / Additional Tests Preliminary Bugfixes and Performance Improvements Bugfix Cache Deletion Add Support for Output Write Partitioning
1 parent 3f4868d commit 8025025

37 files changed

Lines changed: 2656 additions & 226 deletions
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.hops.ipa;
21+
22+
import org.apache.sysds.api.DMLScript;
23+
import org.apache.sysds.hops.HopsException;
24+
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
25+
import org.apache.sysds.hops.rewrite.ProgramRewriter;
26+
import org.apache.sysds.hops.rewrite.RewriteInjectOOCTee;
27+
import org.apache.sysds.parser.DMLProgram;
28+
import org.apache.sysds.parser.LanguageException;
29+
30+
/**
31+
* Applies OOC tee injection after static/dynamic rewrites in IPA.
32+
*/
33+
public class IPAPassInjectOOCTee extends IPAPass {
34+
@Override
35+
public boolean isApplicable(FunctionCallGraph fgraph) {
36+
return DMLScript.USE_OOC;
37+
}
38+
39+
@Override
40+
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
41+
try {
42+
ProgramRewriter rewriter = new ProgramRewriter(new RewriteInjectOOCTee());
43+
ProgramRewriteStatus status = new ProgramRewriteStatus();
44+
rewriter.rewriteProgramHopDAGs(prog, true, status);
45+
return false;
46+
}
47+
catch(LanguageException ex) {
48+
throw new HopsException(ex);
49+
}
50+
}
51+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.hops.ipa;
21+
22+
import java.util.HashSet;
23+
import java.util.List;
24+
import java.util.Set;
25+
26+
import org.apache.sysds.api.DMLScript;
27+
import org.apache.sysds.hops.Hop;
28+
import org.apache.sysds.parser.DMLProgram;
29+
import org.apache.sysds.parser.ForStatement;
30+
import org.apache.sysds.parser.ForStatementBlock;
31+
import org.apache.sysds.parser.FunctionStatement;
32+
import org.apache.sysds.parser.FunctionStatementBlock;
33+
import org.apache.sysds.parser.IfStatement;
34+
import org.apache.sysds.parser.IfStatementBlock;
35+
import org.apache.sysds.parser.StatementBlock;
36+
import org.apache.sysds.parser.WhileStatement;
37+
import org.apache.sysds.parser.WhileStatementBlock;
38+
39+
/**
40+
* Prune stale parent links by keeping only parent references reachable from the statement block roots/predicates.
41+
*/
42+
public class IPAPassPruneUnreachableHops extends IPAPass {
43+
@Override
44+
public boolean isApplicable(FunctionCallGraph fgraph) {
45+
return DMLScript.USE_OOC;
46+
}
47+
48+
@Override
49+
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
50+
pruneStatementBlocks(prog.getStatementBlocks());
51+
for(FunctionStatementBlock fsb : prog.getFunctionStatementBlocks())
52+
pruneStatementBlocks(((FunctionStatement) fsb.getStatement(0)).getBody());
53+
return false;
54+
}
55+
56+
private static void pruneStatementBlocks(List<StatementBlock> sbs) {
57+
for(StatementBlock sb : sbs) {
58+
if(sb instanceof WhileStatementBlock) {
59+
WhileStatementBlock wsb = (WhileStatementBlock) sb;
60+
WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
61+
pruneHops(wsb.getPredicateHops());
62+
pruneStatementBlocks(wstmt.getBody());
63+
}
64+
else if(sb instanceof IfStatementBlock) {
65+
IfStatementBlock isb = (IfStatementBlock) sb;
66+
IfStatement istmt = (IfStatement) sb.getStatement(0);
67+
pruneHops(isb.getPredicateHops());
68+
pruneStatementBlocks(istmt.getIfBody());
69+
if(istmt.getElseBody() != null)
70+
pruneStatementBlocks(istmt.getElseBody());
71+
}
72+
else if(sb instanceof ForStatementBlock) {
73+
ForStatementBlock fsb = (ForStatementBlock) sb;
74+
ForStatement fstmt = (ForStatement) sb.getStatement(0);
75+
pruneHops(fsb.getFromHops());
76+
pruneHops(fsb.getToHops());
77+
pruneHops(fsb.getIncrementHops());
78+
pruneStatementBlocks(fstmt.getBody());
79+
}
80+
else if(sb instanceof FunctionStatementBlock) {
81+
FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
82+
pruneStatementBlocks(fstmt.getBody());
83+
}
84+
else {
85+
pruneHops(sb.getHops());
86+
}
87+
}
88+
}
89+
90+
private static void pruneHops(Hop root) {
91+
if(root == null)
92+
return;
93+
Set<Long> reachable = new HashSet<>();
94+
collectReachable(root, reachable);
95+
pruneParents(root, reachable, new HashSet<Long>());
96+
}
97+
98+
private static void pruneHops(List<Hop> roots) {
99+
if(roots == null || roots.isEmpty())
100+
return;
101+
102+
Set<Long> reachable = new HashSet<>();
103+
for(Hop root : roots)
104+
collectReachable(root, reachable);
105+
106+
for(Hop root : roots)
107+
pruneParents(root, reachable, new HashSet<Long>());
108+
}
109+
110+
private static void collectReachable(Hop hop, Set<Long> reachable) {
111+
if(hop == null || !reachable.add(hop.getHopID()))
112+
return;
113+
for(Hop in : hop.getInput())
114+
collectReachable(in, reachable);
115+
}
116+
117+
private static void pruneParents(Hop hop, Set<Long> reachable, Set<Long> visited) {
118+
if(hop == null || !visited.add(hop.getHopID()))
119+
return;
120+
hop.getParent().removeIf(p -> !reachable.contains(p.getHopID()));
121+
for(Hop in : hop.getInput())
122+
pruneParents(in, reachable, visited);
123+
}
124+
}

src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.sysds.common.Types.DataType;
2525
import org.apache.sysds.common.Types.ValueType;
2626
import org.apache.sysds.conf.ConfigurationManager;
27+
import org.apache.sysds.api.DMLScript;
2728
import org.apache.sysds.hops.DataOp;
2829
import org.apache.sysds.hops.FunctionOp;
2930
import org.apache.sysds.hops.FunctionOp.FunctionType;
@@ -141,6 +142,10 @@ public InterProceduralAnalysis(DMLProgram dmlp) {
141142
//would require an update of the function call graph
142143
_passes.add(new IPAPassForwardFunctionCalls());
143144
_passes.add(new IPAPassApplyStaticAndDynamicHopRewrites());
145+
if (DMLScript.USE_OOC) {
146+
_passes.add(new IPAPassPruneUnreachableHops());
147+
_passes.add(new IPAPassInjectOOCTee());
148+
}
144149
}
145150

146151
public InterProceduralAnalysis(StatementBlock sb) {

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
152152
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
153153
_sbRuleSet.add( new RewriteRemoveEmptyBasicBlocks() );
154154
_sbRuleSet.add( new RewriteRemoveEmptyForLoops() );
155-
_sbRuleSet.add( new RewriteInjectOOCTee() );
156155
}
157156

158157
/**

src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
6464
import org.apache.sysds.runtime.meta.MetaData;
6565
import org.apache.sysds.runtime.meta.MetaDataFormat;
66+
import org.apache.sysds.runtime.ooc.stream.SourceOOCStreamable;
6667
import org.apache.sysds.runtime.util.HDFSTool;
6768
import org.apache.sysds.runtime.util.LocalFileUtils;
6869
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -496,7 +497,7 @@ public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
496497
}
497498

498499
public OOCStreamable<IndexedMatrixValue> getStreamable() {
499-
return _streamHandle;
500+
return _streamHandle == null ? new SourceOOCStreamable(this) : _streamHandle;
500501
}
501502

502503
/**

src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@
3737
import org.apache.sysds.runtime.DMLScriptException;
3838
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
3939
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
40+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
4041
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
4142
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
4243
import org.apache.sysds.runtime.instructions.Instruction;
4344
import org.apache.sysds.runtime.instructions.InstructionUtils;
45+
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
4446
import org.apache.sysds.runtime.io.IOUtilFunctions;
4547
import org.apache.sysds.runtime.lineage.Lineage;
4648
import org.apache.sysds.runtime.lineage.LineageCache;
@@ -172,6 +174,8 @@ public void processInstruction(ExecutionContext ec) {
172174

173175
//set input parameter
174176
functionVariables.put(currFormalParam.getName(), value);
177+
if (DMLScript.USE_OOC && value instanceof MatrixObject)
178+
TeeOOCInstruction.incrRef(((MatrixObject) value).getStreamable(), 1);
175179

176180
//map lineage to function arguments
177181
if( lineage != null ) {
@@ -227,7 +231,8 @@ public void processInstruction(ExecutionContext ec) {
227231
if( expectRetVars.contains(varName) )
228232
continue;
229233
//cleanup unexpected return values to avoid leaks
230-
fn_ec.cleanupDataObject(fn_ec.removeVariable(varName));
234+
//(including OOC reference tracking for matrix streams)
235+
VariableCPInstruction.processRmvarInstruction(fn_ec, varName);
231236
}
232237

233238
// Unpin the pinned variables
@@ -247,10 +252,12 @@ public void processInstruction(ExecutionContext ec) {
247252

248253
// remove existing data bound to output variable name
249254
Data exdata = ec.removeVariable(boundVarName);
255+
if (DMLScript.USE_OOC && exdata instanceof MatrixObject && exdata != boundValue)
256+
TeeOOCInstruction.incrRef(((MatrixObject) exdata).getStreamable(), -1);
250257
// save old data for cleanup later
251258
if (exdata != boundValue && !retVars.hasReferences(exdata))
252259
toBeCleanedUp.add(exdata);
253-
//FIXME: interferes with reuse. Removes broadcasts before materialization
260+
//FIXME: interferes with reuse. Removes broadcasts before materialization
254261

255262
//add/replace data in symbol table
256263
ec.setVariable(boundVarName, boundValue);
@@ -276,11 +283,17 @@ public void processInstruction(ExecutionContext ec) {
276283
//update lineage cache with the functions outputs
277284
if ((DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() && !fpb.isNondeterministic())
278285
|| (LineageCacheConfig.isEstimator() && !fpb.isNondeterministic())) {
279-
LineageCache.putValue(fpb.getOutputParams(), liInputs,
286+
LineageCache.putValue(fpb.getOutputParams(), liInputs,
280287
getCacheFunctionName(_functionName, fpb), fn_ec, t1-t0);
281-
//FIXME: send _boundOutputNames instead of fpb.getOutputParams as
288+
//FIXME: send _boundOutputNames instead of fpb.getOutputParams as
282289
//those are already replaced by boundoutput names in the lineage map.
283290
}
291+
292+
// cleanup declared outputs that are not bound at callsite
293+
for (int i = numOutputs; i < fpb.getOutputParams().size(); i++) {
294+
String retVarName = fpb.getOutputParams().get(i).getName();
295+
VariableCPInstruction.processRmvarInstruction(fn_ec, retVarName);
296+
}
284297
}
285298

286299
@Override

src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ public void processInstruction( ExecutionContext ec ) {
157157
else {
158158
OOCStream<MatrixBlock> qLocal = createWritableStream();
159159

160-
mapOOC(qIn, qLocal, tmp -> (MatrixBlock) ((MatrixBlock) tmp.getValue())
160+
mapOOC(qIn, qLocal, tmp -> (MatrixBlock) tmp.getValue()
161161
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()));
162162

163163
MatrixBlock ltmp;

src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) {
8282
boolean isRowBroadcast = m1.getNumRows() > 1 && m2.getNumRows() == 1;
8383

8484
if (isColBroadcast && !isRowBroadcast) {
85-
final long maxProcessesPerBroadcast = m1.getNumColumns() / m1.getBlocksize();
85+
final long maxProcessesPerBroadcast = (m1.getNumColumns() + m1.getBlocksize() - 1) / m1.getBlocksize();
8686

8787
broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
8888
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
@@ -96,7 +96,7 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) {
9696
}, tmp -> tmp.getIndexes().getRowIndex());
9797
}
9898
else if (isRowBroadcast && !isColBroadcast) {
99-
final long maxProcessesPerBroadcast = m1.getNumRows() / m1.getBlocksize();
99+
final long maxProcessesPerBroadcast = (m1.getNumRows() + m1.getBlocksize() - 1) / m1.getBlocksize();
100100

101101
broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
102102
IndexedMatrixValue tmpOut = new IndexedMatrixValue();

0 commit comments

Comments
 (0)