summary | refs | log | tree | commit | diff
path: root/lib/Analysis
diff options
context:
space:
mode:
author: Alexey Bataev <a.bataev@hotmail.com>, 2017-09-08 13:49:36 +0000
committer: Alexey Bataev <a.bataev@hotmail.com>, 2017-09-08 13:49:36 +0000
commit: 4fcc7e8528d8d6020a5299c9e7eccaa565413c9e (patch)
tree: 4012f4f1e8051afee7c0f114c0736faff44a5bcd /lib/Analysis
parent: 977c908e78cb2aa2ec8a9077602bf25d497712cc (diff)
[SLP] Support for horizontal min/max reduction.
The SLP vectorizer already supports horizontal reductions for Add/FAdd binary operations. This patch adds support for horizontal min/max reductions. The function getReductionCost() is split into getArithmeticReductionCost() for binary-operation reductions and getMinMaxReductionCost() for min/max reductions. This patch fixes PR26956.
Differential revision: https://reviews.llvm.org/D27846
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@312791 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Analysis')
-rw-r--r-- lib/Analysis/CostModel.cpp | 155
-rw-r--r-- lib/Analysis/TargetTransformInfo.cpp | 9
2 files changed, 115 insertions, 49 deletions
diff --git a/lib/Analysis/CostModel.cpp b/lib/Analysis/CostModel.cpp
index 071e23e90ff..47513f3c387 100644
--- a/lib/Analysis/CostModel.cpp
+++ b/lib/Analysis/CostModel.cpp
@@ -186,26 +186,56 @@ static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft,
}
namespace {
+/// Kind of the reduction data.
+enum ReductionKind {
+ RK_None, /// Not a reduction.
+ RK_Arithmetic, /// Binary reduction data.
+ RK_MinMax, /// Min/max reduction data.
+ RK_UnsignedMinMax, /// Unsigned min/max reduction data.
+};
/// Contains opcode + LHS/RHS parts of the reduction operations.
struct ReductionData {
- explicit ReductionData() = default;
- ReductionData(unsigned Opcode, Value *LHS, Value *RHS)
- : Opcode(Opcode), LHS(LHS), RHS(RHS) {}
+ ReductionData() = delete;
+ ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS)
+ : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
+ assert(Kind != RK_None && "expected binary or min/max reduction only.");
+ }
unsigned Opcode = 0;
Value *LHS = nullptr;
Value *RHS = nullptr;
+ ReductionKind Kind = RK_None;
+ bool hasSameData(ReductionData &RD) const {
+ return Kind == RD.Kind && Opcode == RD.Opcode;
+ }
};
} // namespace
static Optional<ReductionData> getReductionData(Instruction *I) {
Value *L, *R;
if (m_BinOp(m_Value(L), m_Value(R)).match(I))
- return ReductionData(I->getOpcode(), L, R);
+ return ReductionData(RK_Arithmetic, I->getOpcode(), L, R);
+ if (auto *SI = dyn_cast<SelectInst>(I)) {
+ if (m_SMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_SMax(m_Value(L), m_Value(R)).match(SI) ||
+ m_OrdFMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_OrdFMax(m_Value(L), m_Value(R)).match(SI) ||
+ m_UnordFMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_UnordFMax(m_Value(L), m_Value(R)).match(SI)) {
+ auto *CI = cast<CmpInst>(SI->getCondition());
+ return ReductionData(RK_MinMax, CI->getOpcode(), L, R);
+ }
+ if (m_UMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_UMax(m_Value(L), m_Value(R)).match(SI)) {
+ auto *CI = cast<CmpInst>(SI->getCondition());
+ return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R);
+ }
+ }
return llvm::None;
}
-static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level,
- unsigned NumLevels) {
+static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
+ unsigned Level,
+ unsigned NumLevels) {
// Match one level of pairwise operations.
// %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
// <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
@@ -213,24 +243,24 @@ static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level,
// <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
// %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
if (!I)
- return false;
+ return RK_None;
assert(I->getType()->isVectorTy() && "Expecting a vector type");
Optional<ReductionData> RD = getReductionData(I);
if (!RD)
- return false;
+ return RK_None;
ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS);
if (!LS && Level)
- return false;
+ return RK_None;
ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS);
if (!RS && Level)
- return false;
+ return RK_None;
// On level 0 we can omit one shufflevector instruction.
if (!Level && !RS && !LS)
- return false;
+ return RK_None;
// Shuffle inputs must match.
Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
@@ -239,7 +269,7 @@ static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level,
if (NextLevelOpR && NextLevelOpL) {
// If we have two shuffles their operands must match.
if (NextLevelOpL != NextLevelOpR)
- return false;
+ return RK_None;
NextLevelOp = NextLevelOpL;
} else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
@@ -250,45 +280,47 @@ static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level,
// %NextLevelOpL = shufflevector %R, <1, undef ...>
// %BinOp = fadd %NextLevelOpL, %R
if (NextLevelOpL && NextLevelOpL != RD->RHS)
- return false;
+ return RK_None;
else if (NextLevelOpR && NextLevelOpR != RD->LHS)
- return false;
+ return RK_None;
NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS;
- } else
- return false;
+ } else {
+ return RK_None;
+ }
// Check that the next levels binary operation exists and matches with the
// current one.
if (Level + 1 != NumLevels) {
Optional<ReductionData> NextLevelRD =
getReductionData(cast<Instruction>(NextLevelOp));
- if (!NextLevelRD || RD->Opcode != NextLevelRD->Opcode)
- return false;
+ if (!NextLevelRD || !RD->hasSameData(*NextLevelRD))
+ return RK_None;
}
// Shuffle mask for pairwise operation must match.
if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) {
if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level))
- return false;
+ return RK_None;
} else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) {
if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level))
- return false;
- } else
- return false;
+ return RK_None;
+ } else {
+ return RK_None;
+ }
if (++Level == NumLevels)
- return true;
+ return RD->Kind;
// Match next level.
return matchPairwiseReductionAtLevel(cast<Instruction>(NextLevelOp), Level,
NumLevels);
}
-static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
- unsigned &Opcode, Type *&Ty) {
+static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
+ unsigned &Opcode, Type *&Ty) {
if (!EnableReduxCost)
- return false;
+ return RK_None;
// Need to extract the first element.
ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
@@ -296,19 +328,19 @@ static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
if (CI)
Idx = CI->getZExtValue();
if (Idx != 0)
- return false;
+ return RK_None;
auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
if (!RdxStart)
- return false;
+ return RK_None;
Optional<ReductionData> RD = getReductionData(RdxStart);
if (!RD)
- return false;
+ return RK_None;
Type *VecTy = RdxStart->getType();
unsigned NumVecElems = VecTy->getVectorNumElements();
if (!isPowerOf2_32(NumVecElems))
- return false;
+ return RK_None;
// We look for a sequence of shuffle,shuffle,add triples like the following
// that builds a pairwise reduction tree.
@@ -328,13 +360,14 @@ static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
// <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
// %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
// %r = extractelement <4 x float> %bin.rdx8, i32 0
- if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)))
- return false;
+ if (matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)) ==
+ RK_None)
+ return RK_None;
Opcode = RD->Opcode;
Ty = VecTy;
- return true;
+ return RD->Kind;
}
static std::pair<Value *, ShuffleVectorInst *>
@@ -348,10 +381,11 @@ getShuffleAndOtherOprd(Value *L, Value *R) {
return std::make_pair(L, S);
}
-static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
- unsigned &Opcode, Type *&Ty) {
+static ReductionKind
+matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
+ unsigned &Opcode, Type *&Ty) {
if (!EnableReduxCost)
- return false;
+ return RK_None;
// Need to extract the first element.
ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
@@ -359,19 +393,19 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
if (CI)
Idx = CI->getZExtValue();
if (Idx != 0)
- return false;
+ return RK_None;
auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
if (!RdxStart)
- return false;
+ return RK_None;
Optional<ReductionData> RD = getReductionData(RdxStart);
if (!RD)
- return false;
+ return RK_None;
Type *VecTy = ReduxRoot->getOperand(0)->getType();
unsigned NumVecElems = VecTy->getVectorNumElements();
if (!isPowerOf2_32(NumVecElems))
- return false;
+ return RK_None;
// We look for a sequence of shuffles and adds like the following matching one
// fadd, shuffle vector pair at a time.
@@ -391,10 +425,10 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
while (NumVecElemsRemain - 1) {
// Check for the right reduction operation.
if (!RdxOp)
- return false;
+ return RK_None;
Optional<ReductionData> RDLevel = getReductionData(RdxOp);
- if (!RDLevel || RDLevel->Opcode != RD->Opcode)
- return false;
+ if (!RDLevel || !RDLevel->hasSameData(*RD))
+ return RK_None;
Value *NextRdxOp;
ShuffleVectorInst *Shuffle;
@@ -403,9 +437,9 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
// Check the current reduction operation and the shuffle use the same value.
if (Shuffle == nullptr)
- return false;
+ return RK_None;
if (Shuffle->getOperand(0) != NextRdxOp)
- return false;
+ return RK_None;
// Check that shuffle masks matches.
for (unsigned j = 0; j != MaskStart; ++j)
@@ -415,7 +449,7 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
if (ShuffleMask != Mask)
- return false;
+ return RK_None;
RdxOp = dyn_cast<Instruction>(NextRdxOp);
NumVecElemsRemain /= 2;
@@ -424,7 +458,7 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
Opcode = RD->Opcode;
Ty = VecTy;
- return true;
+ return RD->Kind;
}
unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
@@ -519,13 +553,36 @@ unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
unsigned ReduxOpCode;
Type *ReduxType;
- if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) {
+ switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) {
+ case RK_Arithmetic:
return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
/*IsPairwiseForm=*/false);
+ case RK_MinMax:
+ return TTI->getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/false, /*IsUnsigned=*/false);
+ case RK_UnsignedMinMax:
+ return TTI->getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/false, /*IsUnsigned=*/true);
+ case RK_None:
+ break;
}
- if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
+
+ switch (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
+ case RK_Arithmetic:
return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
/*IsPairwiseForm=*/true);
+ case RK_MinMax:
+ return TTI->getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/true, /*IsUnsigned=*/false);
+ case RK_UnsignedMinMax:
+ return TTI->getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/true, /*IsUnsigned=*/true);
+ case RK_None:
+ break;
}
return TTI->getVectorInstrCost(I->getOpcode(),
diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp
index e09138168c9..8673b1b55d9 100644
--- a/lib/Analysis/TargetTransformInfo.cpp
+++ b/lib/Analysis/TargetTransformInfo.cpp
@@ -484,6 +484,15 @@ int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode, Type *Ty,
return Cost;
}
+int TargetTransformInfo::getMinMaxReductionCost(Type *Ty, Type *CondTy,
+ bool IsPairwiseForm,
+ bool IsUnsigned) const {
+ int Cost =
+ TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsPairwiseForm, IsUnsigned);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
unsigned
TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const {
return TTIImpl->getCostOfKeepingLiveOverCall(Tys);