diff options
author | David Bolvansky <david.bolvansky@gmail.com> | 2018-07-30 16:50:00 +0000 |
---|---|---|
committer | David Bolvansky <david.bolvansky@gmail.com> | 2018-07-30 16:50:00 +0000 |
commit | 44dc58d645c32f4f55c3c48f0066311469e74eb6 (patch) | |
tree | c44946370ac3a8f08fa047e2177423f6ab14e3a2 /lib/CodeGen | |
parent | 005ad0240a6bca58b55f5500c62d3b15ed81364a (diff) |
[DAGCombiner] Bug 31275- Extract a shift from a constant mul or udiv if a rotate can be formed
Summary:
Attempt to extract a shrl from a udiv or a shl from a mul if this allows a rotate to be formed. This targets cases where the input to a rotate pattern was a mul or udiv by a constant and InstCombine merged one of the shifts with the op.
Patch by: sameconrad (Sam Conrad)
Reviewers: RKSimon, craig.topper, spatel, lebedev.ri, javed.absar
Reviewed By: lebedev.ri
Subscribers: efriedma, kparzysz, llvm-commits
Differential Revision: https://reviews.llvm.org/D47681
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@338270 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/CodeGen')
-rw-r--r-- | lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 173 |
1 files changed, 156 insertions, 17 deletions
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 2b037e465c8..6385fc6d415 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -483,9 +483,6 @@ namespace { /// returns false. bool findBetterNeighborChains(StoreSDNode *St); - /// Match "(X shl/srl V1) & V2" where V2 may not be present. - bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask); - /// Holds a pointer to an LSBaseSDNode as well as information on where it /// is located in a sequence of memory operations connected by a chain. struct MemOpLink { @@ -5148,25 +5145,140 @@ SDValue DAGCombiner::visitOR(SDNode *N) { return SDValue(); } -/// Match "(X shl/srl V1) & V2" where V2 may not be present. -bool DAGCombiner::MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) { - if (Op.getOpcode() == ISD::AND) { - if (DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) { - Mask = Op.getOperand(1); - Op = Op.getOperand(0); - } else { - return false; - } +static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) { + if (Op.getOpcode() == ISD::AND && + DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) { + Mask = Op.getOperand(1); + return Op.getOperand(0); } + return Op; +} +/// Match "(X shl/srl V1) & V2" where V2 may not be present. +static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift, + SDValue &Mask) { + Op = stripConstantMask(DAG, Op, Mask); if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) { Shift = Op; return true; } - return false; } +/// Helper function for visitOR to extract the needed side of a rotate idiom +/// from a shl/srl/mul/udiv. This is meant to handle cases where +/// InstCombine merged some outside op with one of the shifts from +/// the rotate pattern. +/// \returns An empty \c SDValue if the needed shift couldn't be extracted. +/// Otherwise, returns an expansion of \p ExtractFrom based on the following +/// patterns: +/// +/// (or (mul v c0) (shrl (mul v c1) c2)): +/// expands (mul v c0) -> (shl (mul v c1) c3) +/// +/// (or (udiv v c0) (shl (udiv v c1) c2)): +/// expands (udiv v c0) -> (shrl (udiv v c1) c3) +/// +/// (or (shl v c0) (shrl (shl v c1) c2)): +/// expands (shl v c0) -> (shl (shl v c1) c3) +/// +/// (or (shrl v c0) (shl (shrl v c1) c2)): +/// expands (shrl v c0) -> (shrl (shrl v c1) c3) +/// +/// Such that in all cases, c3+c2==bitwidth(op v c1). +static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift, + SDValue ExtractFrom, SDValue &Mask, + const SDLoc &DL) { + assert(OppShift && ExtractFrom && "Empty SDValue"); + assert( + (OppShift.getOpcode() == ISD::SHL || OppShift.getOpcode() == ISD::SRL) && + "Existing shift must be valid as a rotate half"); + + ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask); + // Preconditions: + // (or (op0 v c0) (shiftl/r (op0 v c1) c2)) + // + // Find opcode of the needed shift to be extracted from (op0 v c0). + unsigned Opcode = ISD::DELETED_NODE; + bool IsMulOrDiv = false; + // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift + // opcode or its arithmetic (mul or udiv) variant. + auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) { + IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant; + if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift) + return false; + Opcode = NeededShift; + return true; + }; + // op0 must be either the needed shift opcode or the mul/udiv equivalent + // that the needed shift can be extracted from. + if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) && + (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV))) + return SDValue(); + + // op0 must be the same opcode on both sides, have the same LHS argument, + // and produce the same value type. + SDValue OppShiftLHS = OppShift.getOperand(0); + EVT ShiftedVT = OppShiftLHS.getValueType(); + if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() || + OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) || + ShiftedVT != ExtractFrom.getValueType()) + return SDValue(); + + // Amount of the existing shift. + ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1)); + // Constant mul/udiv/shift amount from the RHS of the shift's LHS op. + ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1)); + // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op. + ConstantSDNode *ExtractFromCst = + isConstOrConstSplat(ExtractFrom.getOperand(1)); + // TODO: We should be able to handle non-uniform constant vectors for these values + // Check that we have constant values. + if (!OppShiftCst || !OppShiftCst->getAPIntValue() || + !OppLHSCst || !OppLHSCst->getAPIntValue() || + !ExtractFromCst || !ExtractFromCst->getAPIntValue()) + return SDValue(); + + // Compute the shift amount we need to extract to complete the rotate. + const unsigned VTWidth = ShiftedVT.getScalarSizeInBits(); + APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue(); + if (NeededShiftAmt.isNegative()) + return SDValue(); + // Normalize the bitwidth of the two mul/udiv/shift constant operands. + APInt ExtractFromAmt = ExtractFromCst->getAPIntValue(); + APInt OppLHSAmt = OppLHSCst->getAPIntValue(); + zeroExtendToMatch(ExtractFromAmt, OppLHSAmt); + + // Now try extract the needed shift from the ExtractFrom op and see if the + // result matches up with the existing shift's LHS op. + if (IsMulOrDiv) { + // Op to extract from is a mul or udiv by a constant. + // Check: + // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0 + // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0 + const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(), + NeededShiftAmt.getZExtValue()); + APInt ResultAmt; + APInt Rem; + APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem); + if (Rem != 0 || ResultAmt != OppLHSAmt) + return SDValue(); + } else { + // Op to extract from is a shift by a constant. + // Check: + // c2 - (bitwidth(op0 v c0) - c1) == c0 + if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc( + ExtractFromAmt.getBitWidth())) + return SDValue(); + } + + // Return the expanded shift op that should allow a rotate to be formed. + EVT ShiftVT = OppShift.getOperand(1).getValueType(); + EVT ResVT = ExtractFrom.getValueType(); + SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT); + return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode); +} + // Return true if we can prove that, whenever Neg and Pos are both in the // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that // for two opposing shifts shift1 and shift2 and a value X with OpBits bits: @@ -5333,14 +5445,41 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { // Match "(X shl/srl V1) & V2" where V2 may not be present. SDValue LHSShift; // The shift. SDValue LHSMask; // AND value if any. - if (!MatchRotateHalf(LHS, LHSShift, LHSMask)) - return nullptr; // Not part of a rotate. + matchRotateHalf(DAG, LHS, LHSShift, LHSMask); SDValue RHSShift; // The shift. SDValue RHSMask; // AND value if any. - if (!MatchRotateHalf(RHS, RHSShift, RHSMask)) - return nullptr; // Not part of a rotate. + matchRotateHalf(DAG, RHS, RHSShift, RHSMask); + + // If neither side matched a rotate half, bail + if (!LHSShift && !RHSShift) + return nullptr; + + // InstCombine may have combined a constant shl, srl, mul, or udiv with one + // side of the rotate, so try to handle that here. In all cases we need to + // pass the matched shift from the opposite side to compute the opcode and + // needed shift amount to extract. We still want to do this if both sides + // matched a rotate half because one half may be a potential overshift that + // can be broken down (ie if InstCombine merged two shl or srl ops into a + // single one). + + // Have LHS side of the rotate, try to extract the needed shift from the RHS. + if (LHSShift) + if (SDValue NewRHSShift = + extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL)) + RHSShift = NewRHSShift; + // Have RHS side of the rotate, try to extract the needed shift from the LHS. + if (RHSShift) + if (SDValue NewLHSShift = + extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL)) + LHSShift = NewLHSShift; + + // If a side is still missing, nothing else we can do. + if (!RHSShift || !LHSShift) + return nullptr; + // At this point we've matched or extracted a shift op on each side. + if (LHSShift.getOperand(0) != RHSShift.getOperand(0)) return nullptr; // Not shifting the same value. |