@@ -16562,27 +16562,51 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1656216562 CmpPredicate CurrentPred = ScalarTy->isFloatingPointTy()
1656316563 ? CmpInst::BAD_FCMP_PREDICATE
1656416564 : CmpInst::BAD_ICMP_PREDICATE;
16565+ Value *LHS = nullptr, *RHS = nullptr;
1656516566 auto MatchCmp = m_Cmp(CurrentPred, m_Value(), m_Value());
16566- if ((!match(VI, m_Select(MatchCmp, m_Value(), m_Value())) &&
16567- !match(VI, MatchCmp)) ||
16567+ bool IsSelect =
16568+ ShuffleOrOp == Instruction::Select &&
16569+ (match(VI, m_Select(MatchCmp, m_Value(LHS), m_Value(RHS))) ||
16570+ match(VI, m_Select(m_Value(), m_Value(LHS), m_Value(RHS))));
16571+ if ((!IsSelect && !match(VI, MatchCmp)) ||
1656816572 (CurrentPred != static_cast<CmpInst::Predicate>(VecPred) &&
1656916573 CurrentPred != static_cast<CmpInst::Predicate>(SwappedVecPred)))
1657016574 VecPred = SwappedVecPred = ScalarTy->isFloatingPointTy()
1657116575 ? CmpInst::BAD_FCMP_PREDICATE
1657216576 : CmpInst::BAD_ICMP_PREDICATE;
1657316577
16574- // For selects, the "condition type" arg is the condition operand's
16575- // type; for standalone compares, it is the result type (i1).
16576- InstructionCost ScalarCost = TTI->getCmpSelInstrCost(
16577- E->getOpcode(), OrigScalarTy,
16578- ShuffleOrOp == Instruction::Select ? VL0->getOperand(0)->getType()
16579- : VL0->getType(),
16580- CurrentPred, CostKind,
16581- getOperandInfo(
16582- VI->getOperand(ShuffleOrOp == Instruction::Select ? 1 : 0)),
16583- getOperandInfo(
16584- VI->getOperand(ShuffleOrOp == Instruction::Select ? 2 : 1)),
16585- VI);
16578+ // Check if operands are of i1 types, like a condition expression.
16579+ // TODO: consider implementing this in TTI.
16580+ InstructionCost ScalarCost = InstructionCost::getInvalid();
16581+ if (IsSelect && LHS->getType() == VI->getOperand(0)->getType()) {
16582+ assert(LHS->getType() == RHS->getType() &&
16583+ "Expected same type for LHS/RHS");
16584+ // select i1 v, i1 true, i1 b -> or i1 v, i1 b
16585+ if (match(LHS, m_AllOnes())) {
16586+ ScalarCost = TTI->getArithmeticInstrCost(
16587+ Instruction::Or, LHS->getType(), CostKind,
16588+ getOperandInfo(VI->getOperand(0)), getOperandInfo(RHS));
16589+ } else if (match(RHS, m_Zero())) {
16590+ // select i1 v, i1 b, i1 false -> and i1 v, i1 b
16591+ ScalarCost = TTI->getArithmeticInstrCost(
16592+ Instruction::And, LHS->getType(), CostKind,
16593+ getOperandInfo(VI->getOperand(0)), getOperandInfo(LHS));
16594+ }
16595+ }
16596+ if (!ScalarCost.isValid()) {
16597+ // For selects, the "condition type" arg is the condition operand's
16598+ // type; for standalone compares, it is the result type (i1).
16599+ ScalarCost = TTI->getCmpSelInstrCost(
16600+ E->getOpcode(), OrigScalarTy,
16601+ ShuffleOrOp == Instruction::Select ? VL0->getOperand(0)->getType()
16602+ : VL0->getType(),
16603+ CurrentPred, CostKind,
16604+ getOperandInfo(
16605+ VI->getOperand(ShuffleOrOp == Instruction::Select ? 1 : 0)),
16606+ getOperandInfo(
16607+ VI->getOperand(ShuffleOrOp == Instruction::Select ? 2 : 1)),
16608+ VI);
16609+ }
1658616610 InstructionCost IntrinsicCost = GetMinMaxCost(OrigScalarTy, VI);
1658716611 if (IntrinsicCost.isValid())
1658816612 ScalarCost = IntrinsicCost;
@@ -16599,26 +16623,52 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1659916623 : VL0->getType(),
1660016624 VL.size());
1660116625
16602- InstructionCost VecCost = TTI->getCmpSelInstrCost(
16603- E->getOpcode(), VecTy, MaskTy, VecPred, CostKind,
16604- getOperandInfo(
16605- E->getOperand(ShuffleOrOp == Instruction::Select ? 1 : 0)),
16606- getOperandInfo(
16607- E->getOperand(ShuffleOrOp == Instruction::Select ? 2 : 1)),
16608- VL0);
16609- if (isa<SelectInst>(VL0)) {
16610- unsigned CondNumElements = getNumElements(MaskTy);
16611- unsigned VecTyNumElements = getNumElements(VecTy);
16612- assert(VecTyNumElements >= CondNumElements &&
16613- VecTyNumElements % CondNumElements == 0 &&
16614- "Cannot vectorize Instruction::Select");
16615- if (CondNumElements != VecTyNumElements) {
16616- // When the return type is i1 but the source is fixed vector type, we
16617- // need to duplicate the condition value.
16618- VecCost += ::getShuffleCost(
16619- *TTI, TTI::SK_PermuteSingleSrc, MaskTy,
16620- createReplicatedMask(VecTyNumElements / CondNumElements,
16621- CondNumElements));
16626+ InstructionCost VecCost = InstructionCost::getInvalid();
16627+ if (ShuffleOrOp == Instruction::Select) {
16628+ ArrayRef<Value *> Cond = E->getOperand(0);
16629+ ArrayRef<Value *> LHS = E->getOperand(1);
16630+ ArrayRef<Value *> RHS = E->getOperand(2);
16631+ // select <VF x i1>, <VF x i1>, <VF x i1>?
16632+ // TODO: consider implementing this in TTI.
16633+ if (Cond.front()->getType() == LHS.front()->getType()) {
16634+ // select <VF x i1> v, <VF x i1> true, <VF x i1> b -> or <VF x i1> v,
16635+ // <VF x i1> b
16636+ if (all_of(LHS, [&](Value *V) { return match(V, m_AllOnes()); })) {
16637+ VecCost = TTI->getArithmeticInstrCost(
16638+ Instruction::Or, VecTy, CostKind, getOperandInfo(Cond),
16639+ getOperandInfo(RHS));
16640+ } else if (all_of(RHS,
16641+ [&](Value *V) { return match(V, m_Zero()); })) {
16642+ // select <VF x i1> v, <VF x i1> b, <VF x i1> false -> and <VF x i1>
16643+ // v, <VF x i1> b
16644+ VecCost = TTI->getArithmeticInstrCost(
16645+ Instruction::And, VecTy, CostKind, getOperandInfo(Cond),
16646+ getOperandInfo(LHS));
16647+ }
16648+ }
16649+ }
16650+ if (!VecCost.isValid()) {
16651+ VecCost = TTI->getCmpSelInstrCost(
16652+ E->getOpcode(), VecTy, MaskTy, VecPred, CostKind,
16653+ getOperandInfo(
16654+ E->getOperand(ShuffleOrOp == Instruction::Select ? 1 : 0)),
16655+ getOperandInfo(
16656+ E->getOperand(ShuffleOrOp == Instruction::Select ? 2 : 1)),
16657+ VL0);
16658+ if (isa<SelectInst>(VL0)) {
16659+ unsigned CondNumElements = getNumElements(MaskTy);
16660+ unsigned VecTyNumElements = getNumElements(VecTy);
16661+ assert(VecTyNumElements >= CondNumElements &&
16662+ VecTyNumElements % CondNumElements == 0 &&
16663+ "Cannot vectorize Instruction::Select");
16664+ if (CondNumElements != VecTyNumElements) {
16665+ // When the return type is i1 but the source is fixed vector type,
16666+ // we need to duplicate the condition value.
16667+ VecCost += ::getShuffleCost(
16668+ *TTI, TTI::SK_PermuteSingleSrc, MaskTy,
16669+ createReplicatedMask(VecTyNumElements / CondNumElements,
16670+ CondNumElements));
16671+ }
1662216672 }
1662316673 }
1662416674 return VecCost + CommonCost;
0 commit comments