Skip to content

Commit e9b4b4c

Browse files
[NFC] cache repeated tree walks to avoid O(N^2) in optimizeTerminatingTails in CodeFolding (#8602)
Cache the result of getBranchTargets(getFunction()->body) in optimizeTerminatingTails so that recursive calls share the same computed set rather than each re-walking the entire function body. This avoids O(N²) behavior where N is the size of the function body, since the recursive calls previously each performed an O(N) tree walk. The cached targets are computed lazily on first need and passed through to the canMove overload that accepts pre-computed branch targets. ## Benmark data For the test case in #7319 (comment) Main head: ```shell time ./build/bin/wasm-opt --code-folding --enable-bulk-memory --enable-multivalue --enable-reference-types --enable-gc --enable-tail-call --enable-exception-handling -o /dev/null ./test3.wasm real 5m45.996s user 6m6.267s sys 0m3.798s ``` This PR: ```shell time ./build/bin/wasm-opt --code-folding --enable-bulk-memory --enable-multivalue --enable-reference-types --enable-gc --enable-tail-call --enable-exception-handling -o /dev/null ./test3.wasm real 2m2.380s user 2m25.700s sys 0m2.449s ``` ## Benchmark regression test Test case: https://jetbrains.github.io/kotlinconf-app/73cbe24d7cf5a54d37ad.wasm On main ```shell Performance counter stats for 'build/bin/wasm-opt 73cbe24d7cf5a54d37ad.wasm -all --code-folding -o /dev/null' (10 runs): 4837936912 task-clock # 1.445 CPUs utilized ( +- 0.51% ) 114 context-switches # 23.564 /sec ( +- 7.58% ) 7 cpu-migrations # 1.447 /sec ( +- 16.88% ) 46271 page-faults # 9.564 K/sec ( +- 0.00% ) 13431328103 instructions # 1.21 insn per cycle ( +- 0.01% ) 11125222873 cycles # 2.300 GHz ( +- 0.51% ) 64641504 branch-misses ( +- 1.26% ) 3.3484 +- 0.0221 seconds time elapsed ( +- 0.66% ) ``` On current PR ```shell Performance counter stats for 'build/bin/wasm-opt 73cbe24d7cf5a54d37ad.wasm -all --code-folding -o /dev/null' (10 runs): 4802304211 task-clock # 1.437 CPUs utilized ( +- 0.47% ) 125 context-switches # 26.029 /sec ( +- 6.50% ) 8 cpu-migrations # 1.666 /sec ( +- 14.20% ) 46272 page-faults # 9.635 K/sec ( +- 0.00% ) 13391520427 instructions # 1.21 insn per cycle ( +- 0.01% ) 11043221889 cycles # 2.300 GHz ( +- 0.47% ) 59021679 branch-misses ( +- 1.24% ) 3.3427 +- 0.0207 seconds time elapsed ( +- 0.62% ) ```
1 parent 3180c6f commit e9b4b4c

1 file changed

Lines changed: 29 additions & 6 deletions

File tree

src/passes/CodeFolding.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,14 @@ struct CodeFolding
398398
// if one of the items has a branch to something inside outOf that is not
399399
// inside that item
400400
bool canMove(const std::vector<Expression*>& items, Expression* outOf) {
401-
auto allTargets = BranchUtils::getBranchTargets(outOf);
401+
return canMove(items, outOf, BranchUtils::getBranchTargets(outOf));
402+
}
403+
404+
// Overload that accepts pre-computed branch targets to avoid redundant
405+
// O(N) getBranchTargets calls.
406+
bool canMove(const std::vector<Expression*>& items,
407+
Expression* outOf,
408+
const BranchUtils::NameSet& allTargets) {
402409
for (auto* item : items) {
403410
auto exiting = BranchUtils::getExitingBranches(item);
404411
std::vector<Name> intersection;
@@ -632,11 +639,18 @@ struct CodeFolding
632639
// we are just starting; num > 0 means that tails is guaranteed to be
633640
// equal in the last num items, so we can merge there, but we look for
634641
// deeper merges first.
642+
// bodyTargets is lazily computed on first need and then passed to recursive
643+
// calls to avoid repeated O(N) getBranchTargets walks over the function body.
635644
// returns whether we optimized something.
636-
bool optimizeTerminatingTails(std::vector<Tail>& tails, Index num = 0) {
645+
bool optimizeTerminatingTails(std::vector<Tail>& tails,
646+
Index num = 0,
647+
BranchUtils::NameSet* bodyTargets = nullptr) {
637648
if (tails.size() < 2) {
638649
return false;
639650
}
651+
// Storage for body branch targets, declared here so it outlives the
652+
// pointer stored in bodyTargets.
653+
BranchUtils::NameSet localBodyTargets;
640654
// remove things that are untoward and cannot be optimized
641655
tails.erase(
642656
std::remove_if(tails.begin(),
@@ -697,9 +711,11 @@ struct CodeFolding
697711
// can be removed, though
698712
cost += WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH;
699713
// if we cannot merge to the end, then we definitely need 2 blocks,
700-
// and a branch
701-
// TODO: efficiency, entire body
702-
if (!canMove(items, getFunction()->body)) {
714+
// and a branch. Use the pre-computed bodyTargets to avoid repeated
715+
// O(N) getBranchTargets calls.
716+
assert(bodyTargets);
717+
bool canMoveItems = canMove(items, getFunction()->body, *bodyTargets);
718+
if (!canMoveItems) {
703719
cost += 1 + WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH;
704720
// TODO: to do this, we need to maintain a map of element=>parent,
705721
// so that we can insert the new blocks in the right place
@@ -795,7 +811,14 @@ struct CodeFolding
795811
// as the changes may influence us. we leave further opts to further
796812
// passes (as this is rare in practice, it's generally not a perf
797813
// issue, but TODO optimize)
798-
if (optimizeTerminatingTails(explore, num + 1)) {
814+
// Compute body branch targets once and share across recursive
815+
// calls to avoid repeated O(N) tree walks.
816+
if (!bodyTargets) {
817+
localBodyTargets =
818+
BranchUtils::getBranchTargets(getFunction()->body);
819+
bodyTargets = &localBodyTargets;
820+
}
821+
if (optimizeTerminatingTails(explore, num + 1, bodyTargets)) {
799822
return true;
800823
}
801824
}

0 commit comments

Comments
 (0)