diff options
author | Alexey Bataev <a.bataev@hotmail.com> | 2017-09-08 13:49:36 +0000 |
---|---|---|
committer | Alexey Bataev <a.bataev@hotmail.com> | 2017-09-08 13:49:36 +0000 |
commit | 4fcc7e8528d8d6020a5299c9e7eccaa565413c9e (patch) | |
tree | 4012f4f1e8051afee7c0f114c0736faff44a5bcd /lib/Analysis | |
parent | 977c908e78cb2aa2ec8a9077602bf25d497712cc (diff) |
[SLP] Support for horizontal min/max reduction.
The SLP vectorizer supports horizontal reductions for Add/FAdd binary
operations. This patch adds support for horizontal min/max reductions.
The function getReductionCost() is split into getArithmeticReductionCost() for
binary operation reductions and getMinMaxReductionCost() for min/max
reductions.
Patch fixes PR26956.
Differential revision: https://reviews.llvm.org/D27846
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@312791 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Analysis')
-rw-r--r-- | lib/Analysis/CostModel.cpp | 155 | ||||
-rw-r--r-- | lib/Analysis/TargetTransformInfo.cpp | 9 |
2 files changed, 115 insertions, 49 deletions
diff --git a/lib/Analysis/CostModel.cpp b/lib/Analysis/CostModel.cpp index 071e23e90ff..47513f3c387 100644 --- a/lib/Analysis/CostModel.cpp +++ b/lib/Analysis/CostModel.cpp @@ -186,26 +186,56 @@ static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft, } namespace { +/// Kind of the reduction data. +enum ReductionKind { + RK_None, /// Not a reduction. + RK_Arithmetic, /// Binary reduction data. + RK_MinMax, /// Min/max reduction data. + RK_UnsignedMinMax, /// Unsigned min/max reduction data. +}; /// Contains opcode + LHS/RHS parts of the reduction operations. struct ReductionData { - explicit ReductionData() = default; - ReductionData(unsigned Opcode, Value *LHS, Value *RHS) - : Opcode(Opcode), LHS(LHS), RHS(RHS) {} + ReductionData() = delete; + ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS) + : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) { + assert(Kind != RK_None && "expected binary or min/max reduction only."); + } unsigned Opcode = 0; Value *LHS = nullptr; Value *RHS = nullptr; + ReductionKind Kind = RK_None; + bool hasSameData(ReductionData &RD) const { + return Kind == RD.Kind && Opcode == RD.Opcode; + } }; } // namespace static Optional<ReductionData> getReductionData(Instruction *I) { Value *L, *R; if (m_BinOp(m_Value(L), m_Value(R)).match(I)) - return ReductionData(I->getOpcode(), L, R); + return ReductionData(RK_Arithmetic, I->getOpcode(), L, R); + if (auto *SI = dyn_cast<SelectInst>(I)) { + if (m_SMin(m_Value(L), m_Value(R)).match(SI) || + m_SMax(m_Value(L), m_Value(R)).match(SI) || + m_OrdFMin(m_Value(L), m_Value(R)).match(SI) || + m_OrdFMax(m_Value(L), m_Value(R)).match(SI) || + m_UnordFMin(m_Value(L), m_Value(R)).match(SI) || + m_UnordFMax(m_Value(L), m_Value(R)).match(SI)) { + auto *CI = cast<CmpInst>(SI->getCondition()); + return ReductionData(RK_MinMax, CI->getOpcode(), L, R); + } + if (m_UMin(m_Value(L), m_Value(R)).match(SI) || + m_UMax(m_Value(L), m_Value(R)).match(SI)) { + auto *CI = 
cast<CmpInst>(SI->getCondition()); + return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R); + } + } return llvm::None; } -static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, - unsigned NumLevels) { +static ReductionKind matchPairwiseReductionAtLevel(Instruction *I, + unsigned Level, + unsigned NumLevels) { // Match one level of pairwise operations. // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> @@ -213,24 +243,24 @@ static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 if (!I) - return false; + return RK_None; assert(I->getType()->isVectorTy() && "Expecting a vector type"); Optional<ReductionData> RD = getReductionData(I); if (!RD) - return false; + return RK_None; ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS); if (!LS && Level) - return false; + return RK_None; ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS); if (!RS && Level) - return false; + return RK_None; // On level 0 we can omit one shufflevector instruction. if (!Level && !RS && !LS) - return false; + return RK_None; // Shuffle inputs must match. Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr; @@ -239,7 +269,7 @@ static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, if (NextLevelOpR && NextLevelOpL) { // If we have two shuffles their operands must match. 
if (NextLevelOpL != NextLevelOpR) - return false; + return RK_None; NextLevelOp = NextLevelOpL; } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) { @@ -250,45 +280,47 @@ static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, // %NextLevelOpL = shufflevector %R, <1, undef ...> // %BinOp = fadd %NextLevelOpL, %R if (NextLevelOpL && NextLevelOpL != RD->RHS) - return false; + return RK_None; else if (NextLevelOpR && NextLevelOpR != RD->LHS) - return false; + return RK_None; NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS; - } else - return false; + } else { + return RK_None; + } // Check that the next levels binary operation exists and matches with the // current one. if (Level + 1 != NumLevels) { Optional<ReductionData> NextLevelRD = getReductionData(cast<Instruction>(NextLevelOp)); - if (!NextLevelRD || RD->Opcode != NextLevelRD->Opcode) - return false; + if (!NextLevelRD || !RD->hasSameData(*NextLevelRD)) + return RK_None; } // Shuffle mask for pairwise operation must match. if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) { if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level)) - return false; + return RK_None; } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) { if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level)) - return false; - } else - return false; + return RK_None; + } else { + return RK_None; + } if (++Level == NumLevels) - return true; + return RD->Kind; // Match next level. return matchPairwiseReductionAtLevel(cast<Instruction>(NextLevelOp), Level, NumLevels); } -static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, - unsigned &Opcode, Type *&Ty) { +static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot, + unsigned &Opcode, Type *&Ty) { if (!EnableReduxCost) - return false; + return RK_None; // Need to extract the first element. 
ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); @@ -296,19 +328,19 @@ static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, if (CI) Idx = CI->getZExtValue(); if (Idx != 0) - return false; + return RK_None; auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0)); if (!RdxStart) - return false; + return RK_None; Optional<ReductionData> RD = getReductionData(RdxStart); if (!RD) - return false; + return RK_None; Type *VecTy = RdxStart->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); if (!isPowerOf2_32(NumVecElems)) - return false; + return RK_None; // We look for a sequence of shuffle,shuffle,add triples like the following // that builds a pairwise reduction tree. @@ -328,13 +360,14 @@ static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1 // %r = extractelement <4 x float> %bin.rdx8, i32 0 - if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems))) - return false; + if (matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)) == + RK_None) + return RK_None; Opcode = RD->Opcode; Ty = VecTy; - return true; + return RD->Kind; } static std::pair<Value *, ShuffleVectorInst *> @@ -348,10 +381,11 @@ getShuffleAndOtherOprd(Value *L, Value *R) { return std::make_pair(L, S); } -static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, - unsigned &Opcode, Type *&Ty) { +static ReductionKind +matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, + unsigned &Opcode, Type *&Ty) { if (!EnableReduxCost) - return false; + return RK_None; // Need to extract the first element. 
ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); @@ -359,19 +393,19 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, if (CI) Idx = CI->getZExtValue(); if (Idx != 0) - return false; + return RK_None; auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0)); if (!RdxStart) - return false; + return RK_None; Optional<ReductionData> RD = getReductionData(RdxStart); if (!RD) - return false; + return RK_None; Type *VecTy = ReduxRoot->getOperand(0)->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); if (!isPowerOf2_32(NumVecElems)) - return false; + return RK_None; // We look for a sequence of shuffles and adds like the following matching one // fadd, shuffle vector pair at a time. @@ -391,10 +425,10 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, while (NumVecElemsRemain - 1) { // Check for the right reduction operation. if (!RdxOp) - return false; + return RK_None; Optional<ReductionData> RDLevel = getReductionData(RdxOp); - if (!RDLevel || RDLevel->Opcode != RD->Opcode) - return false; + if (!RDLevel || !RDLevel->hasSameData(*RD)) + return RK_None; Value *NextRdxOp; ShuffleVectorInst *Shuffle; @@ -403,9 +437,9 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, // Check the current reduction operation and the shuffle use the same value. if (Shuffle == nullptr) - return false; + return RK_None; if (Shuffle->getOperand(0) != NextRdxOp) - return false; + return RK_None; // Check that shuffle masks matches. 
for (unsigned j = 0; j != MaskStart; ++j) @@ -415,7 +449,7 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, SmallVector<int, 16> Mask = Shuffle->getShuffleMask(); if (ShuffleMask != Mask) - return false; + return RK_None; RdxOp = dyn_cast<Instruction>(NextRdxOp); NumVecElemsRemain /= 2; @@ -424,7 +458,7 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, Opcode = RD->Opcode; Ty = VecTy; - return true; + return RD->Kind; } unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { @@ -519,13 +553,36 @@ unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { unsigned ReduxOpCode; Type *ReduxType; - if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { + switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { + case RK_Arithmetic: return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, /*IsPairwiseForm=*/false); + case RK_MinMax: + return TTI->getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/false, /*IsUnsigned=*/false); + case RK_UnsignedMinMax: + return TTI->getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/false, /*IsUnsigned=*/true); + case RK_None: + break; } - if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + + switch (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + case RK_Arithmetic: return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, /*IsPairwiseForm=*/true); + case RK_MinMax: + return TTI->getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/true, /*IsUnsigned=*/false); + case RK_UnsignedMinMax: + return TTI->getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/true, /*IsUnsigned=*/true); + case RK_None: + break; } return TTI->getVectorInstrCost(I->getOpcode(), diff --git a/lib/Analysis/TargetTransformInfo.cpp 
b/lib/Analysis/TargetTransformInfo.cpp index e09138168c9..8673b1b55d9 100644 --- a/lib/Analysis/TargetTransformInfo.cpp +++ b/lib/Analysis/TargetTransformInfo.cpp @@ -484,6 +484,15 @@ int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode, Type *Ty, return Cost; } +int TargetTransformInfo::getMinMaxReductionCost(Type *Ty, Type *CondTy, + bool IsPairwiseForm, + bool IsUnsigned) const { + int Cost = + TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsPairwiseForm, IsUnsigned); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + unsigned TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const { return TTIImpl->getCostOfKeepingLiveOverCall(Tys); |