summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorFiona Glaser <escha@apple.com>2017-12-12 19:18:02 +0000
committerFiona Glaser <escha@apple.com>2017-12-12 19:18:02 +0000
commit1feb97a12bbc090ce0130f2f8fad4de4aa3f6c46 (patch)
tree35bfae04fa8c4f59e6c0b6d7b1cd59c174028312 /lib
parent960dcea840cb497bc0ff4b5112dc181ecdef566d (diff)
Reassociate: add global reassociation algorithm
This algorithm (explained more in the source code) takes into account global redundancies by building a "pair map" to find common subexprs. The primary motivation of this is to handle situations like foo = (a * b) * c bar = (a * d) * c where we currently don't identify that "a * c" is redundant. Accordingly, it prioritizes the emission of a * c so that CSE can remove the redundant calculation later. Does not change the actual reassociation algorithm -- only the order in which the reassociated operand chain is reconstructed. Gives ~1.5% floating point math instruction count reduction on a large offline suite of graphics shaders. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@320515 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib')
-rw-r--r--lib/Transforms/Scalar/Reassociate.cpp112
1 files changed, 110 insertions, 2 deletions
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp
index dcaa4034081..88dcaf0f8a3 100644
--- a/lib/Transforms/Scalar/Reassociate.cpp
+++ b/lib/Transforms/Scalar/Reassociate.cpp
@@ -27,6 +27,7 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
@@ -2184,11 +2185,104 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
return;
}
+ if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) {
+ // Find the pair with the highest count in the pairmap and move it to the
+ // back of the list so that it can later be CSE'd.
+ // example:
+ // a*b*c*d*e
+ // if c*e is the most "popular" pair, we can express this as
+ // (((c*e)*d)*b)*a
+ unsigned Max = 1;
+ unsigned BestRank = 0;
+ std::pair<unsigned, unsigned> BestPair;
+ unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin;
+ for (unsigned i = 0; i < Ops.size() - 1; ++i)
+ for (unsigned j = i + 1; j < Ops.size(); ++j) {
+ unsigned Score = 0;
+ Value *Op0 = Ops[i].Op;
+ Value *Op1 = Ops[j].Op;
+ if (std::less<Value *>()(Op1, Op0))
+ std::swap(Op0, Op1);
+ auto it = PairMap[Idx].find({Op0, Op1});
+ if (it != PairMap[Idx].end())
+ Score += it->second;
+
+ unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank);
+ if (Score > Max || (Score == Max && MaxRank < BestRank)) {
+ BestPair = {i, j};
+ Max = Score;
+ BestRank = MaxRank;
+ }
+ }
+ if (Max > 1) {
+ auto Op0 = Ops[BestPair.first];
+ auto Op1 = Ops[BestPair.second];
+ Ops.erase(&Ops[BestPair.second]);
+ Ops.erase(&Ops[BestPair.first]);
+ Ops.push_back(Op0);
+ Ops.push_back(Op1);
+ }
+ }
// Now that we ordered and optimized the expressions, splat them back into
// the expression tree, removing any unneeded nodes.
RewriteExprTree(I, Ops);
}
+void
+ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) {
+ // Make a "pairmap" of how often each operand pair occurs.
+ for (BasicBlock *BI : RPOT) {
+ for (Instruction &I : *BI) {
+ if (!I.isAssociative())
+ continue;
+
+ // Ignore nodes that aren't at the root of trees.
+ if (I.hasOneUse() && I.user_back()->getOpcode() == I.getOpcode())
+ continue;
+
+ // Collect all operands in a single reassociable expression.
+ // Since Reassociate has already been run once, we can assume things
+ // are already canonical according to Reassociation's regime.
+ SmallVector<Value *, 8> Worklist = { I.getOperand(0), I.getOperand(1) };
+ SmallVector<Value *, 8> Ops;
+ while (!Worklist.empty() && Ops.size() <= GlobalReassociateLimit) {
+ Value *Op = Worklist.pop_back_val();
+ Instruction *OpI = dyn_cast<Instruction>(Op);
+ if (!OpI || OpI->getOpcode() != I.getOpcode() || !OpI->hasOneUse()) {
+ Ops.push_back(Op);
+ continue;
+ }
+ // Be paranoid about self-referencing expressions in unreachable code.
+ if (OpI->getOperand(0) != OpI)
+ Worklist.push_back(OpI->getOperand(0));
+ if (OpI->getOperand(1) != OpI)
+ Worklist.push_back(OpI->getOperand(1));
+ }
+ // Skip extremely long expressions.
+ if (Ops.size() > GlobalReassociateLimit)
+ continue;
+
+ // Add all pairwise combinations of operands to the pair map.
+ unsigned BinaryIdx = I.getOpcode() - Instruction::BinaryOpsBegin;
+ SmallSet<std::pair<Value *, Value*>, 32> Visited;
+ for (unsigned i = 0; i < Ops.size() - 1; ++i) {
+ for (unsigned j = i + 1; j < Ops.size(); ++j) {
+ // Canonicalize operand orderings.
+ Value *Op0 = Ops[i];
+ Value *Op1 = Ops[j];
+ if (std::less<Value *>()(Op1, Op0))
+ std::swap(Op0, Op1);
+ if (!Visited.insert({Op0, Op1}).second)
+ continue;
+ auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1});
+ if (!res.second)
+ ++res.first->second;
+ }
+ }
+ }
+ }
+}
+
PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) {
// Get the functions basic blocks in Reverse Post Order. This order is used by
// BuildRankMap to pre calculate ranks correctly. It also excludes dead basic
@@ -2199,8 +2293,20 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) {
// Calculate the rank map for F.
BuildRankMap(F, RPOT);
+ // Build the pair map before running reassociate.
+ // Technically this would be more accurate if we did it after one round
+ // of reassociation, but in practice it doesn't seem to help much on
+ // real-world code, so don't waste the compile time running reassociate
+ // twice.
+ // If a user wants, they could expicitly run reassociate twice in their
+ // pass pipeline for further potential gains.
+ // It might also be possible to update the pair map during runtime, but the
+ // overhead of that may be large if there's many reassociable chains.
+ BuildPairMap(RPOT);
+
MadeChange = false;
- // Traverse the same blocks that was analysed by BuildRankMap.
+
+ // Traverse the same blocks that were analysed by BuildRankMap.
for (BasicBlock *BI : RPOT) {
assert(RankMap.count(&*BI) && "BB should be ranked.");
// Optimize every instruction in the basic block.
@@ -2239,9 +2345,11 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) {
}
}
- // We are done with the rank map.
+ // We are done with the rank map and pair map.
RankMap.clear();
ValueRankMap.clear();
+ for (auto &Entry : PairMap)
+ Entry.clear();
if (MadeChange) {
PreservedAnalyses PA;