Skip to content

Commit 13e18f8

Browse files
[SLP] Improve cost model for i1 select-as-or/and patterns
Model `select i1 %c, i1 true, i1 %d` as `or` and `select i1 %c, i1 %d, i1 false` as `and` in the SLP cost model, since these are the operations the backend will lower them to. The previous select cost overestimated the vector cost of these patterns, preventing profitable vectorization of i1 condition chains. Reviewers: hiraditya, RKSimon, bababuck Pull Request: #188572
1 parent c758592 commit 13e18f8

2 files changed

Lines changed: 104 additions & 58 deletions

File tree

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16562,27 +16562,51 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1656216562
CmpPredicate CurrentPred = ScalarTy->isFloatingPointTy()
1656316563
? CmpInst::BAD_FCMP_PREDICATE
1656416564
: CmpInst::BAD_ICMP_PREDICATE;
16565+
Value *LHS = nullptr, *RHS = nullptr;
1656516566
auto MatchCmp = m_Cmp(CurrentPred, m_Value(), m_Value());
16566-
if ((!match(VI, m_Select(MatchCmp, m_Value(), m_Value())) &&
16567-
!match(VI, MatchCmp)) ||
16567+
bool IsSelect =
16568+
ShuffleOrOp == Instruction::Select &&
16569+
(match(VI, m_Select(MatchCmp, m_Value(LHS), m_Value(RHS))) ||
16570+
match(VI, m_Select(m_Value(), m_Value(LHS), m_Value(RHS))));
16571+
if ((!IsSelect && !match(VI, MatchCmp)) ||
1656816572
(CurrentPred != static_cast<CmpInst::Predicate>(VecPred) &&
1656916573
CurrentPred != static_cast<CmpInst::Predicate>(SwappedVecPred)))
1657016574
VecPred = SwappedVecPred = ScalarTy->isFloatingPointTy()
1657116575
? CmpInst::BAD_FCMP_PREDICATE
1657216576
: CmpInst::BAD_ICMP_PREDICATE;
1657316577

16574-
// For selects, the "condition type" arg is the condition operand's
16575-
// type; for standalone compares, it is the result type (i1).
16576-
InstructionCost ScalarCost = TTI->getCmpSelInstrCost(
16577-
E->getOpcode(), OrigScalarTy,
16578-
ShuffleOrOp == Instruction::Select ? VL0->getOperand(0)->getType()
16579-
: VL0->getType(),
16580-
CurrentPred, CostKind,
16581-
getOperandInfo(
16582-
VI->getOperand(ShuffleOrOp == Instruction::Select ? 1 : 0)),
16583-
getOperandInfo(
16584-
VI->getOperand(ShuffleOrOp == Instruction::Select ? 2 : 1)),
16585-
VI);
16578+
// Check if operands are of i1 types, like a condition expression.
16579+
// TODO: consider implementing this in TTI.
16580+
InstructionCost ScalarCost = InstructionCost::getInvalid();
16581+
if (IsSelect && LHS->getType() == VI->getOperand(0)->getType()) {
16582+
assert(LHS->getType() == RHS->getType() &&
16583+
"Expected same type for LHS/RHS");
16584+
// select i1 v, i1 true, i1 b -> or i1 v, i1 b
16585+
if (match(LHS, m_AllOnes())) {
16586+
ScalarCost = TTI->getArithmeticInstrCost(
16587+
Instruction::Or, LHS->getType(), CostKind,
16588+
getOperandInfo(VI->getOperand(0)), getOperandInfo(RHS));
16589+
} else if (match(RHS, m_Zero())) {
16590+
// select i1 v, i1 b, i1 false -> and i1 v, i1 b
16591+
ScalarCost = TTI->getArithmeticInstrCost(
16592+
Instruction::And, LHS->getType(), CostKind,
16593+
getOperandInfo(VI->getOperand(0)), getOperandInfo(LHS));
16594+
}
16595+
}
16596+
if (!ScalarCost.isValid()) {
16597+
// For selects, the "condition type" arg is the condition operand's
16598+
// type; for standalone compares, it is the result type (i1).
16599+
ScalarCost = TTI->getCmpSelInstrCost(
16600+
E->getOpcode(), OrigScalarTy,
16601+
ShuffleOrOp == Instruction::Select ? VL0->getOperand(0)->getType()
16602+
: VL0->getType(),
16603+
CurrentPred, CostKind,
16604+
getOperandInfo(
16605+
VI->getOperand(ShuffleOrOp == Instruction::Select ? 1 : 0)),
16606+
getOperandInfo(
16607+
VI->getOperand(ShuffleOrOp == Instruction::Select ? 2 : 1)),
16608+
VI);
16609+
}
1658616610
InstructionCost IntrinsicCost = GetMinMaxCost(OrigScalarTy, VI);
1658716611
if (IntrinsicCost.isValid())
1658816612
ScalarCost = IntrinsicCost;
@@ -16599,26 +16623,52 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1659916623
: VL0->getType(),
1660016624
VL.size());
1660116625

16602-
InstructionCost VecCost = TTI->getCmpSelInstrCost(
16603-
E->getOpcode(), VecTy, MaskTy, VecPred, CostKind,
16604-
getOperandInfo(
16605-
E->getOperand(ShuffleOrOp == Instruction::Select ? 1 : 0)),
16606-
getOperandInfo(
16607-
E->getOperand(ShuffleOrOp == Instruction::Select ? 2 : 1)),
16608-
VL0);
16609-
if (isa<SelectInst>(VL0)) {
16610-
unsigned CondNumElements = getNumElements(MaskTy);
16611-
unsigned VecTyNumElements = getNumElements(VecTy);
16612-
assert(VecTyNumElements >= CondNumElements &&
16613-
VecTyNumElements % CondNumElements == 0 &&
16614-
"Cannot vectorize Instruction::Select");
16615-
if (CondNumElements != VecTyNumElements) {
16616-
// When the return type is i1 but the source is fixed vector type, we
16617-
// need to duplicate the condition value.
16618-
VecCost += ::getShuffleCost(
16619-
*TTI, TTI::SK_PermuteSingleSrc, MaskTy,
16620-
createReplicatedMask(VecTyNumElements / CondNumElements,
16621-
CondNumElements));
16626+
InstructionCost VecCost = InstructionCost::getInvalid();
16627+
if (ShuffleOrOp == Instruction::Select) {
16628+
ArrayRef<Value *> Cond = E->getOperand(0);
16629+
ArrayRef<Value *> LHS = E->getOperand(1);
16630+
ArrayRef<Value *> RHS = E->getOperand(2);
16631+
// select <VF x i1>, <VF x i1>, <VF x i1>?
16632+
// TODO: consider implementing this in TTI.
16633+
if (Cond.front()->getType() == LHS.front()->getType()) {
16634+
// select <VF x i1> v, <VF x i1> true, <VF x i1> b -> or <VF x i1> v,
16635+
// <VF x i1> b
16636+
if (all_of(LHS, [&](Value *V) { return match(V, m_AllOnes()); })) {
16637+
VecCost = TTI->getArithmeticInstrCost(
16638+
Instruction::Or, VecTy, CostKind, getOperandInfo(Cond),
16639+
getOperandInfo(RHS));
16640+
} else if (all_of(RHS,
16641+
[&](Value *V) { return match(V, m_Zero()); })) {
16642+
// select <VF x i1> v, <VF x i1> b, <VF x i1> false -> and <VF x i1>
16643+
// v, <VF x i1> b
16644+
VecCost = TTI->getArithmeticInstrCost(
16645+
Instruction::And, VecTy, CostKind, getOperandInfo(Cond),
16646+
getOperandInfo(LHS));
16647+
}
16648+
}
16649+
}
16650+
if (!VecCost.isValid()) {
16651+
VecCost = TTI->getCmpSelInstrCost(
16652+
E->getOpcode(), VecTy, MaskTy, VecPred, CostKind,
16653+
getOperandInfo(
16654+
E->getOperand(ShuffleOrOp == Instruction::Select ? 1 : 0)),
16655+
getOperandInfo(
16656+
E->getOperand(ShuffleOrOp == Instruction::Select ? 2 : 1)),
16657+
VL0);
16658+
if (isa<SelectInst>(VL0)) {
16659+
unsigned CondNumElements = getNumElements(MaskTy);
16660+
unsigned VecTyNumElements = getNumElements(VecTy);
16661+
assert(VecTyNumElements >= CondNumElements &&
16662+
VecTyNumElements % CondNumElements == 0 &&
16663+
"Cannot vectorize Instruction::Select");
16664+
if (CondNumElements != VecTyNumElements) {
16665+
// When the return type is i1 but the source is fixed vector type,
16666+
// we need to duplicate the condition value.
16667+
VecCost += ::getShuffleCost(
16668+
*TTI, TTI::SK_PermuteSingleSrc, MaskTy,
16669+
createReplicatedMask(VecTyNumElements / CondNumElements,
16670+
CondNumElements));
16671+
}
1662216672
}
1662316673
}
1662416674
return VecCost + CommonCost;

llvm/test/Transforms/SLPVectorizer/X86/select-logical-or-and-i1-vector.ll

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,16 @@ define void @select_logical_or_i1(ptr %dst,
1212
; CHECK-LABEL: define void @select_logical_or_i1(
1313
; CHECK-SAME: ptr [[DST:%.*]], float [[D0:%.*]], float [[D1:%.*]], float [[D2:%.*]], float [[D3:%.*]], float [[THRESHOLD:%.*]], float [[HPHB_VAL:%.*]], i1 [[SCALAR_COND:%.*]], float [[Y0:%.*]], float [[Y1:%.*]], float [[Y2:%.*]], float [[Y3:%.*]], float [[E0:%.*]], float [[E1:%.*]], float [[E2:%.*]], float [[E3:%.*]]) #[[ATTR0:[0-9]+]] {
1414
; CHECK-NEXT: [[ENTRY:.*:]]
15-
; CHECK-NEXT: [[CMP0:%.*]] = fcmp fast uge float [[D0]], [[THRESHOLD]]
16-
; CHECK-NEXT: [[CMP1:%.*]] = fcmp fast uge float [[D1]], [[THRESHOLD]]
17-
; CHECK-NEXT: [[CMP2:%.*]] = fcmp fast uge float [[D2]], [[THRESHOLD]]
18-
; CHECK-NEXT: [[CMP3:%.*]] = fcmp fast uge float [[D3]], [[THRESHOLD]]
19-
; CHECK-NEXT: [[OR3:%.*]] = select i1 [[CMP3]], i1 true, i1 [[SCALAR_COND]]
20-
; CHECK-NEXT: [[OR2:%.*]] = select i1 [[CMP2]], i1 true, i1 [[SCALAR_COND]]
21-
; CHECK-NEXT: [[OR1:%.*]] = select i1 [[CMP1]], i1 true, i1 [[SCALAR_COND]]
22-
; CHECK-NEXT: [[OR0:%.*]] = select i1 [[CMP0]], i1 true, i1 [[SCALAR_COND]]
23-
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x i1> poison, i1 [[OR0]], i32 0
24-
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i1> [[TMP0]], i1 [[OR1]], i32 1
25-
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i1> [[TMP1]], i1 [[OR2]], i32 2
26-
; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x i1> [[TMP2]], i1 [[OR3]], i32 3
15+
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x float> poison, float [[D0]], i32 0
16+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x float> [[TMP0]], float [[D1]], i32 1
17+
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x float> [[TMP1]], float [[D2]], i32 2
18+
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x float> [[TMP2]], float [[D3]], i32 3
19+
; CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x float> poison, float [[THRESHOLD]], i32 0
20+
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x float> [[TMP4]], <4 x float> poison, <4 x i32> zeroinitializer
21+
; CHECK-NEXT: [[TMP6:%.*]] = fcmp fast uge <4 x float> [[TMP3]], [[TMP5]]
22+
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i1> poison, i1 [[SCALAR_COND]], i32 0
23+
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i1> [[TMP7]], <4 x i1> poison, <4 x i32> zeroinitializer
24+
; CHECK-NEXT: [[TMP9:%.*]] = select <4 x i1> [[TMP6]], <4 x i1> splat (i1 true), <4 x i1> [[TMP8]]
2725
; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x float> poison, float [[HPHB_VAL]], i32 0
2826
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x float> [[TMP10]], <4 x float> poison, <4 x i32> zeroinitializer
2927
; CHECK-NEXT: [[TMP12:%.*]] = select <4 x i1> [[TMP9]], <4 x float> zeroinitializer, <4 x float> [[TMP11]]
@@ -86,18 +84,16 @@ define void @select_logical_and_i1(ptr %dst,
8684
; CHECK-LABEL: define void @select_logical_and_i1(
8785
; CHECK-SAME: ptr [[DST:%.*]], float [[D0:%.*]], float [[D1:%.*]], float [[D2:%.*]], float [[D3:%.*]], float [[THRESHOLD:%.*]], float [[HPHB_VAL:%.*]], i1 [[SCALAR_COND:%.*]], float [[Y0:%.*]], float [[Y1:%.*]], float [[Y2:%.*]], float [[Y3:%.*]], float [[E0:%.*]], float [[E1:%.*]], float [[E2:%.*]], float [[E3:%.*]]) #[[ATTR0]] {
8886
; CHECK-NEXT: [[ENTRY:.*:]]
89-
; CHECK-NEXT: [[CMP0:%.*]] = fcmp fast uge float [[D0]], [[THRESHOLD]]
90-
; CHECK-NEXT: [[CMP1:%.*]] = fcmp fast uge float [[D1]], [[THRESHOLD]]
91-
; CHECK-NEXT: [[CMP2:%.*]] = fcmp fast uge float [[D2]], [[THRESHOLD]]
92-
; CHECK-NEXT: [[CMP3:%.*]] = fcmp fast uge float [[D3]], [[THRESHOLD]]
93-
; CHECK-NEXT: [[AND3:%.*]] = select i1 [[CMP3]], i1 [[SCALAR_COND]], i1 false
94-
; CHECK-NEXT: [[AND2:%.*]] = select i1 [[CMP2]], i1 [[SCALAR_COND]], i1 false
95-
; CHECK-NEXT: [[AND1:%.*]] = select i1 [[CMP1]], i1 [[SCALAR_COND]], i1 false
96-
; CHECK-NEXT: [[AND0:%.*]] = select i1 [[CMP0]], i1 [[SCALAR_COND]], i1 false
97-
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x i1> poison, i1 [[AND0]], i32 0
98-
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i1> [[TMP0]], i1 [[AND1]], i32 1
99-
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i1> [[TMP1]], i1 [[AND2]], i32 2
100-
; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x i1> [[TMP2]], i1 [[AND3]], i32 3
87+
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x float> poison, float [[D0]], i32 0
88+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x float> [[TMP0]], float [[D1]], i32 1
89+
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x float> [[TMP1]], float [[D2]], i32 2
90+
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x float> [[TMP2]], float [[D3]], i32 3
91+
; CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x float> poison, float [[THRESHOLD]], i32 0
92+
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x float> [[TMP4]], <4 x float> poison, <4 x i32> zeroinitializer
93+
; CHECK-NEXT: [[TMP6:%.*]] = fcmp fast uge <4 x float> [[TMP3]], [[TMP5]]
94+
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i1> poison, i1 [[SCALAR_COND]], i32 0
95+
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i1> [[TMP7]], <4 x i1> poison, <4 x i32> zeroinitializer
96+
; CHECK-NEXT: [[TMP9:%.*]] = select <4 x i1> [[TMP6]], <4 x i1> [[TMP8]], <4 x i1> zeroinitializer
10197
; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x float> poison, float [[HPHB_VAL]], i32 0
10298
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x float> [[TMP10]], <4 x float> poison, <4 x i32> zeroinitializer
10399
; CHECK-NEXT: [[TMP12:%.*]] = select <4 x i1> [[TMP9]], <4 x float> zeroinitializer, <4 x float> [[TMP11]]

0 commit comments

Comments
 (0)