Skip to content

Commit e4dd03b

Browse files
committed
Align unary/binary transitive Spark exec-type decision and add compile test
Make the Spark-specific decision refinement consistent between UnaryOp and BinaryOp:
- UnaryOp: restore the input-is-not-checkpoint and single-parent guards, and drop the redundant `_etype != ExecType.SPARK` clause.
- BinaryOp: use the shared hasSparkOutput() helper instead of an inline optFindExecType() == SPARK check.

Add a compilation-verification test suite under component/compile that compiles a DML script into a runtime program and inspects instruction exec types without executing it. CompilerTestBase provides the compile and plan-inspection helpers; SparkTransitiveExecTypeCompileTest verifies that a CP-by-estimate unary on a Spark-resident input is pulled into Spark only when it is the sole consumer.
1 parent b6e3900 commit e4dd03b

4 files changed

Lines changed: 251 additions & 4 deletions

File tree

src/main/java/org/apache/sysds/hops/BinaryOp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) {
828828
&& (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation
829829
&& sparkIn.getParent().size() == 1 // only one parent
830830
&& !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec
831-
&& sparkIn.optFindExecType() == ExecType.SPARK // input was spark op.
831+
&& sparkIn.hasSparkOutput() // input was spark op.
832832
&& !(sparkIn instanceof DataOp) // input is not checkpoint
833833
) {
834834
// pull operation into spark

src/main/java/org/apache/sysds/hops/UnaryOp.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -508,13 +508,12 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto
508508
//spark-specific decision refinement (execute unary w/ spark input and
509509
//single parent also in spark because it's likely cheap and reduces intermediates)
510510
if(_etype == ExecType.CP // currently CP instruction
511-
&& _etype != ExecType.SPARK /// currently not SP.
512511
&& _etypeForced != ExecType.CP // not forced as CP instruction
513512
&& getInput(0).hasSparkOutput() // input is a spark instruction
514513
&& (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame
515514
&& !isDisallowedSparkOps() // is invalid spark instruction
516-
// && !(getInput().get(0) instanceof DataOp) // input is not checkpoint
517-
// && getInput(0).getParent().size() <= 1// unary is only parent
515+
&& !(getInput(0) instanceof DataOp) // input is not checkpoint
516+
&& getInput(0).getParent().size() == 1 // unary is only parent
518517
) {
519518
//pull unary operation into spark
520519
_etype = ExecType.SPARK;
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.compile;
21+
22+
import java.util.ArrayList;
23+
import java.util.HashMap;
24+
import java.util.List;
25+
import java.util.Map;
26+
import java.util.stream.Collectors;
27+
28+
import org.apache.sysds.api.DMLScript;
29+
import org.apache.sysds.common.Types.ExecMode;
30+
import org.apache.sysds.conf.ConfigurationManager;
31+
import org.apache.sysds.conf.DMLConfig;
32+
import org.apache.sysds.hops.OptimizerUtils;
33+
import org.apache.sysds.hops.recompile.Recompiler;
34+
import org.apache.sysds.parser.DMLProgram;
35+
import org.apache.sysds.parser.DMLTranslator;
36+
import org.apache.sysds.parser.ParserFactory;
37+
import org.apache.sysds.parser.ParserWrapper;
38+
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
39+
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
40+
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
41+
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
42+
import org.apache.sysds.runtime.controlprogram.Program;
43+
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
44+
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
45+
import org.apache.sysds.runtime.instructions.Instruction;
46+
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
47+
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
48+
import org.apache.sysds.test.AutomatedTestBase;
49+
import org.apache.sysds.utils.Explain;
50+
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
51+
import org.junit.Assert;
52+
53+
/**
54+
* Base class for compilation-verification tests: compile a DML script into a runtime {@link Program} and inspect the
55+
* resulting plan (instructions and their exec types) without ever executing it.
56+
*/
57+
public abstract class CompilerTestBase extends AutomatedTestBase {
58+
59+
/** A small default local memory budget (8 MB) that forces large operations onto Spark in HYBRID mode. */
60+
public static final long SMALL_MEM_BUDGET = 8L * 1024 * 1024;
61+
62+
@Override
63+
public void setUp() {
64+
// no test-configuration setup needed; scripts are compiled from in-memory strings
65+
}
66+
67+
/**
68+
* Compile a DML script string into a runtime {@link Program} without executing it.
69+
*
70+
* @param dmlScript the DML source
71+
* @param args named command-line arguments ($name -&gt; value), may be null
72+
* @param mode the global execution mode (e.g. {@link ExecMode#HYBRID})
73+
* @param localMaxMem the local memory budget in bytes used for memory-based exec-type decisions
74+
* @return the compiled runtime program
75+
*/
76+
protected Program compile(String dmlScript, Map<String, String> args, ExecMode mode, long localMaxMem) {
77+
final ExecMode oldMode = DMLScript.getGlobalExecMode();
78+
final long oldMem = InfrastructureAnalyzer.getLocalMaxMemory();
79+
final DMLConfig oldConfig = ConfigurationManager.getDMLConfig();
80+
try {
81+
ConfigurationManager.setGlobalConfig(new DMLConfig());
82+
DMLScript.setGlobalExecMode(mode);
83+
InfrastructureAnalyzer.setLocalMaxMemory(localMaxMem);
84+
OptimizerUtils.resetDefaultSize();
85+
86+
Map<String, String> argVals = (args == null) ? new HashMap<>() : new HashMap<>(args);
87+
ParserWrapper parser = ParserFactory.createParser();
88+
DMLProgram prog = parser.parse(null, dmlScript, argVals);
89+
DMLTranslator dmlt = new DMLTranslator(prog);
90+
dmlt.liveVariableAnalysis(prog);
91+
dmlt.validateParseTree(prog);
92+
dmlt.constructHops(prog);
93+
dmlt.rewriteHopsDAG(prog);
94+
dmlt.constructLops(prog);
95+
dmlt.rewriteLopDAG(prog);
96+
return dmlt.getRuntimeProgram(prog, ConfigurationManager.getDMLConfig());
97+
}
98+
catch(Exception e) {
99+
throw new RuntimeException("Failed to compile DML script:\n" + dmlScript, e);
100+
}
101+
finally {
102+
DMLScript.setGlobalExecMode(oldMode);
103+
InfrastructureAnalyzer.setLocalMaxMemory(oldMem);
104+
ConfigurationManager.setGlobalConfig(oldConfig);
105+
Recompiler.reinitRecompiler();
106+
}
107+
}
108+
109+
/** Recursively collect every instruction in the program, including control-flow predicates and function bodies. */
110+
protected List<Instruction> getInstructions(Program prog) {
111+
List<Instruction> out = new ArrayList<>();
112+
for(ProgramBlock pb : prog.getProgramBlocks())
113+
collect(pb, out);
114+
for(FunctionProgramBlock fpb : prog.getFunctionProgramBlocks(false).values())
115+
collect(fpb, out);
116+
return out;
117+
}
118+
119+
private void collect(ProgramBlock pb, List<Instruction> out) {
120+
if(pb instanceof BasicProgramBlock) {
121+
out.addAll(((BasicProgramBlock) pb).getInstructions());
122+
}
123+
else if(pb instanceof IfProgramBlock) {
124+
IfProgramBlock ipb = (IfProgramBlock) pb;
125+
out.addAll(ipb.getPredicate());
126+
ipb.getChildBlocksIfBody().forEach(c -> collect(c, out));
127+
ipb.getChildBlocksElseBody().forEach(c -> collect(c, out));
128+
}
129+
else if(pb instanceof WhileProgramBlock) {
130+
WhileProgramBlock wpb = (WhileProgramBlock) pb;
131+
out.addAll(wpb.getPredicate());
132+
wpb.getChildBlocks().forEach(c -> collect(c, out));
133+
}
134+
else if(pb instanceof ForProgramBlock) { // incl. ParForProgramBlock
135+
ForProgramBlock fpb = (ForProgramBlock) pb;
136+
out.addAll(fpb.getFromInstructions());
137+
out.addAll(fpb.getToInstructions());
138+
out.addAll(fpb.getIncrementInstructions());
139+
fpb.getChildBlocks().forEach(c -> collect(c, out));
140+
}
141+
else if(pb instanceof FunctionProgramBlock) {
142+
((FunctionProgramBlock) pb).getChildBlocks().forEach(c -> collect(c, out));
143+
}
144+
}
145+
146+
/** All instructions whose opcode equals {@code opcode} (exact match). */
147+
protected List<Instruction> getByOpcode(Program prog, String opcode) {
148+
return getInstructions(prog).stream().filter(i -> opcode.equals(i.getOpcode()))
149+
.collect(Collectors.toList());
150+
}
151+
152+
protected static boolean isSpark(Instruction inst) {
153+
return inst instanceof SPInstruction;
154+
}
155+
156+
protected static boolean isCP(Instruction inst) {
157+
return inst instanceof CPInstruction;
158+
}
159+
160+
/** Assert that at least one instruction with the given opcode exists and that all such instructions are Spark. */
161+
protected void assertSpark(Program prog, String opcode) {
162+
assertExecType(prog, opcode, true);
163+
}
164+
165+
/** Assert that at least one instruction with the given opcode exists and that none of them is a Spark instruction (the check is "not an SPInstruction" rather than "is a CPInstruction"). */
166+
protected void assertCP(Program prog, String opcode) {
167+
assertExecType(prog, opcode, false);
168+
}
169+
170+
private void assertExecType(Program prog, String opcode, boolean expectSpark) {
171+
List<Instruction> matches = getByOpcode(prog, opcode);
172+
Assert.assertFalse("Expected at least one '" + opcode + "' instruction but found none.\n"
173+
+ Explain.explain(prog), matches.isEmpty());
174+
for(Instruction inst : matches) {
175+
boolean spark = isSpark(inst);
176+
Assert.assertEquals("Instruction '" + opcode + "' expected exec type "
177+
+ (expectSpark ? "SPARK" : "CP") + " but was " + (spark ? "SPARK" : "CP") + ".\n"
178+
+ Explain.explain(prog), expectSpark, spark);
179+
}
180+
}
181+
182+
protected long countSpark(Program prog) {
183+
return getInstructions(prog).stream().filter(CompilerTestBase::isSpark).count();
184+
}
185+
186+
protected String explain(Program prog) {
187+
return Explain.explain(prog);
188+
}
189+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.compile;
21+
22+
import org.apache.sysds.common.Types.ExecMode;
23+
import org.apache.sysds.runtime.controlprogram.Program;
24+
import org.junit.Test;
25+
26+
/**
27+
* Verifies the transitive Spark exec-type refinement in {@link org.apache.sysds.hops.UnaryOp}: a CP-by-estimate unary on
28+
* a Spark-resident input is pulled into Spark only when it is the sole consumer ({@code getParent().size() == 1}).
29+
*/
30+
public class SparkTransitiveExecTypeCompileTest extends CompilerTestBase {
31+
32+
private static final String DML_HEADER =
33+
"X = rand(rows=20000000, cols=8, seed=1);\n" + // ~1.2GB -> rand and colSums run on Spark
34+
"v = colSums(X);\n"; // 1x8 Spark-resident vector (opcode uack+)
35+
36+
@Test
37+
public void singleConsumerUnaryPulledIntoSpark() {
38+
String dml = DML_HEADER +
39+
"r = round(v);\n" + // sole consumer of the Spark-resident vector -> pulled into Spark
40+
"print(sum(r));\n";
41+
Program prog = compile(dml, null, ExecMode.HYBRID, SMALL_MEM_BUDGET);
42+
43+
assertSpark(prog, "uack+"); // input genuinely has a Spark output
44+
assertSpark(prog, "round"); // unary pulled into Spark (CP by mem estimate, single consumer)
45+
}
46+
47+
@Test
48+
public void multiConsumerUnaryStaysCP() {
49+
String dml = DML_HEADER +
50+
"a = round(v);\n" + // v now has two consumers (round + abs) ...
51+
"b = abs(v);\n" +
52+
"print(sum(a) + sum(b));\n";
53+
Program prog = compile(dml, null, ExecMode.HYBRID, SMALL_MEM_BUDGET);
54+
55+
assertSpark(prog, "uack+"); // input still has a Spark output ...
56+
assertCP(prog, "round"); // ... but the multi-parent guard keeps both unaries in CP
57+
assertCP(prog, "abs");
58+
}
59+
}

0 commit comments

Comments
 (0)