Skip to content

Commit ee69e51

Browse files
authored
[BWARE] Track more compressed-friendly ops in FederatedWorkloadAnalyzer (#2481)
* Track more compressed-friendly ops in FederatedWorkloadAnalyzer Extends the federated workload counter so that compression decisions account for additional instruction shapes beyond AggregateBinary. - Pass the right-hand column count to incOverlappingDecompressions so the cost model reflects the actual decompression size rather than counting a single column - Count MMChainCPInstruction as one LMM and one RMM contribution per invocation - Count AggregateUnaryCPInstruction: when reducing columns with a sum/mean operator, treat it as a dict-op (compression-friendly); otherwise count it as a decompression - Minor formatting cleanup in compressRun * Add unit tests for FederatedWorkloadAnalyzer workload tracking Cover the instruction-shape branches in incrementWorkload that drive federated compression decisions, which previously had no direct tests: - AggregateBinary: RMM/LMM counting, overlapping-decompress sizing by the right-hand column count, and the validSize row/column guards - MMChain: one LMM and one RMM contribution per invocation - AggregateUnary: dict-op vs decompression classification across ReduceAll/ReduceRow/ReduceCol with sum, mean, product, and max operators - Instance-level dispatch and compressRun threshold behavior, asserting async compression materializes when the cost model would compress
1 parent 3b564bb commit ee69e51

2 files changed

Lines changed: 388 additions & 5 deletions

File tree

src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,18 @@
2727
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
2828
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
2929
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
30+
import org.apache.sysds.runtime.functionobjects.IndexFunction;
31+
import org.apache.sysds.runtime.functionobjects.KahanPlus;
32+
import org.apache.sysds.runtime.functionobjects.Mean;
33+
import org.apache.sysds.runtime.functionobjects.Plus;
34+
import org.apache.sysds.runtime.functionobjects.ReduceCol;
3035
import org.apache.sysds.runtime.instructions.Instruction;
3136
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
37+
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
3238
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
39+
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
40+
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
41+
import org.apache.sysds.runtime.matrix.operators.Operator;
3342

3443
public class FederatedWorkloadAnalyzer {
3544
protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzer.class.getName());
@@ -55,7 +64,7 @@ public void incrementWorkload(ExecutionContext ec, long tid, Instruction ins) {
5564
}
5665

5766
public void compressRun(ExecutionContext ec, long tid) {
58-
if(counter >= compressRunFrequency ){
67+
if(counter >= compressRunFrequency) {
5968
counter = 0;
6069
get(tid).forEach((K, V) -> CompressedMatrixBlockFactory.compressAsync(ec, Long.toString(K), V));
6170
}
@@ -68,6 +77,7 @@ private void incrementWorkload(ExecutionContext ec, long tid, ComputationCPInstr
6877
public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap<Long, InstructionTypeCounter> mm,
6978
ComputationCPInstruction cpIns) {
7079
// TODO: Count transitive closure via lineage
80+
// TODO: add more operations
7181
if(cpIns instanceof AggregateBinaryCPInstruction) {
7282
final String n1 = cpIns.input1.getName();
7383
MatrixObject d1 = (MatrixObject) ec.getCacheableData(n1);
@@ -81,15 +91,45 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap<Long, Instr
8191
if(validSize(r1, c1)) {
8292
getOrMakeCounter(mm, Long.parseLong(n1)).incRMM(c2);
8393
// safety add overlapping decompress for RMM
84-
getOrMakeCounter(mm, Long.parseLong(n1)).incOverlappingDecompressions();
94+
getOrMakeCounter(mm, Long.parseLong(n1)).incOverlappingDecompressions(c2);
8595
counter++;
8696
}
8797
if(validSize(r2, c2)) {
8898
getOrMakeCounter(mm, Long.parseLong(n2)).incLMM(r1);
8999
counter++;
90100
}
91-
92101
}
102+
else if(cpIns instanceof MMChainCPInstruction) {
103+
final String n1 = cpIns.input1.getName();
104+
getOrMakeCounter(mm, Long.parseLong(n1)).incRMM(1);
105+
getOrMakeCounter(mm, Long.parseLong(n1)).incLMM(1);
106+
counter++;
107+
}
108+
else if(cpIns instanceof AggregateUnaryCPInstruction) {
109+
Operator op = cpIns.getOperator();
110+
final String n1 = cpIns.input1.getName();
111+
long id = Long.parseLong(n1);
112+
if(op instanceof AggregateUnaryOperator) {
113+
AggregateUnaryOperator aop = (AggregateUnaryOperator) op;
114+
IndexFunction idxF = aop.indexFn;
115+
getOrMakeCounter(mm, id).incDictOps();
116+
if(idxF instanceof ReduceCol) {
117+
if((aop.aggOp.increOp.fn instanceof KahanPlus //
118+
|| aop.aggOp.increOp.fn instanceof Plus //
119+
|| aop.aggOp.increOp.fn instanceof Mean)) {
120+
getOrMakeCounter(mm, id).incDictOps();
121+
}
122+
else {
123+
// increment decompression if row reduce.
124+
getOrMakeCounter(mm, id).incDecompressions();
125+
}
126+
}
127+
else {
128+
getOrMakeCounter(mm, id).incDictOps();
129+
}
130+
}
131+
}
132+
93133
}
94134

95135
private static InstructionTypeCounter getOrMakeCounter(ConcurrentHashMap<Long, InstructionTypeCounter> mm, long id) {
@@ -117,8 +157,8 @@ private static boolean validSize(int nRow, int nCol) {
117157
return nRow > 90 && nRow >= nCol;
118158
}
119159

120-
@Override
121-
public String toString(){
160+
@Override
161+
public String toString() {
122162
StringBuilder sb = new StringBuilder();
123163
sb.append(this.getClass().getSimpleName());
124164
sb.append(" Counter: ");

0 commit comments

Comments
 (0)