Skip to content

Commit 66f0da7

Browse files
authored
[Extractor] Use function return for the one and only output (#191824)
Currently, the code extractor passes outputs through parameters, and alloca/store/load instructions are used to retrieve the output value in the parent function. When the extracted code has only one output (one of the most common cases), returning that output as the function's return value can facilitate other transformations (e.g., tail call optimization). This change makes the extracted function return the output value when the extracted region has exactly one output.
1 parent 6e7f08d commit 66f0da7

16 files changed

Lines changed: 152 additions & 53 deletions

llvm/include/llvm/Transforms/Utils/CodeExtractor.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ class LLVM_ABI CodeExtractor {
132132
// space.
133133
bool ArgsInZeroAddressSpace;
134134

135+
// If true, the outlined function always returns void even when there is only
136+
// one output.
137+
bool VoidReturnWithSingleOutput;
138+
139+
// If set, the return value of the outlined function.
140+
Value *FuncRetVal = nullptr;
141+
135142
public:
136143
/// Create a code extractor for a sequence of blocks.
137144
///
@@ -147,13 +154,16 @@ class LLVM_ABI CodeExtractor {
147154
/// which case it will be placed in the entry block of the function from which
148155
/// the code is being extracted. If ArgsInZeroAddressSpace param is set to
149156
/// true, then the aggregate param pointer of the outlined function is
150-
/// declared in zero address space.
157+
/// declared in zero address space. If VoidReturnWithSingleOutput is set to
158+
/// true, then the return type of the outlined function is set to void even if
159+
/// there is only one output.
151160
CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
152161
bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
153162
BranchProbabilityInfo *BPI = nullptr,
154163
AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
155164
bool AllowAlloca = false, BasicBlock *AllocationBlock = nullptr,
156-
std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
165+
std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
166+
bool VoidReturnWithSingleOutput = true);
157167

158168
/// Perform the extraction, returning the new function.
159169
///
@@ -201,7 +211,7 @@ class LLVM_ABI CodeExtractor {
201211
/// on the cost however.
202212
void findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
203213
const ValueSet &Allocas,
204-
bool CollectGlobalInputs = false) const;
214+
bool CollectGlobalInputs = false);
205215

206216
/// Check if life time marker nodes can be hoisted/sunk into the outline
207217
/// region.
@@ -300,7 +310,7 @@ class LLVM_ABI CodeExtractor {
300310
/// into the original function's control flow.
301311
void
302312
insertReplacerCall(Function *oldFunction, BasicBlock *header,
303-
BasicBlock *codeReplacer, const ValueSet &outputs,
313+
CallInst *ReplacerCall, const ValueSet &outputs,
304314
ArrayRef<Value *> Reloads,
305315
const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights);
306316
};

llvm/lib/Transforms/IPO/HotColdSplitting.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,9 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
721721
SubRegion, &*DT, /* AggregateArgs */ false, /* BFI */ nullptr,
722722
/* BPI */ nullptr, AC, /* AllowVarArgs */ false,
723723
/* AllowAlloca */ false, /* AllocaBlock */ nullptr,
724-
/* Suffix */ "cold." + std::to_string(OutlinedFunctionID));
724+
/* Suffix */ "cold." + std::to_string(OutlinedFunctionID),
725+
/* ArgsInZeroAddressSpace */ false,
726+
/* VoidReturnWithSingleOutput */ false);
725727

726728
if (CE.isEligible() && isSplittingBeneficial(CE, SubRegion, TTI) &&
727729
// If this outlining region intersects with another, drop the new

llvm/lib/Transforms/IPO/PartialInlining.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,7 +1104,10 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() {
11041104
CodeExtractor CE(RegionInfo.Region, &DT, /*AggregateArgs*/ false,
11051105
ClonedFuncBFI.get(), &BPI,
11061106
LookupAC(*RegionInfo.EntryBlock->getParent()),
1107-
/* AllowVarargs */ false);
1107+
/* AllowVarargs */ false, /* AllowAlloca */ false,
1108+
/* AllocaBlock */ nullptr, /* Suffix */ "",
1109+
/* ArgsInZeroAddressSpace */ false,
1110+
/* VoidReturnWithSingleOutput */ false);
11081111

11091112
CE.findInputsOutputs(Inputs, Outputs, Sinks);
11101113

@@ -1185,7 +1188,10 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() {
11851188
Function *OutlinedFunc =
11861189
CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false,
11871190
ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc),
1188-
/* AllowVarargs */ true)
1191+
/* AllowVarargs */ true, /* AllowAlloca */ false,
1192+
/* AllocaBlock */ nullptr, /* Suffix */ "",
1193+
/* ArgsInZeroAddressSpace */ false,
1194+
/* VoidReturnWithSingleOutput */ false)
11891195
.extractCodeRegion(CEAC);
11901196

11911197
if (OutlinedFunc) {

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,14 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
264264
BranchProbabilityInfo *BPI, AssumptionCache *AC,
265265
bool AllowVarArgs, bool AllowAlloca,
266266
BasicBlock *AllocationBlock, std::string Suffix,
267-
bool ArgsInZeroAddressSpace)
267+
bool ArgsInZeroAddressSpace,
268+
bool VoidReturnWithSingleOutput)
268269
: DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
269270
BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
270271
AllowVarArgs(AllowVarArgs),
271272
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
272-
Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
273+
Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
274+
VoidReturnWithSingleOutput(VoidReturnWithSingleOutput) {}
273275

274276
/// definedInRegion - Return true if the specified value is defined in the
275277
/// extracted region.
@@ -662,7 +664,7 @@ bool CodeExtractor::isEligible() const {
662664

663665
void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
664666
const ValueSet &SinkCands,
665-
bool CollectGlobalInputs) const {
667+
bool CollectGlobalInputs) {
666668
for (BasicBlock *BB : Blocks) {
667669
// If a used value is defined outside the region, it's an input. If an
668670
// instruction is used outside the region, it's an output.
@@ -682,6 +684,12 @@ void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
682684
}
683685
}
684686
}
687+
688+
if (!VoidReturnWithSingleOutput && !AggregateArgs && Outputs.size() == 1 &&
689+
getCommonExitBlock(Blocks)) {
690+
FuncRetVal = Outputs[0];
691+
Outputs.clear();
692+
}
685693
}
686694

687695
/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
@@ -877,7 +885,7 @@ Function *CodeExtractor::constructFunctionDeclaration(
877885
M->getContext(), ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
878886
}
879887

880-
Type *RetTy = getSwitchType();
888+
Type *RetTy = FuncRetVal ? FuncRetVal->getType() : getSwitchType();
881889
LLVM_DEBUG({
882890
dbgs() << "Function type: " << *RetTy << " f(";
883891
for (Type *i : ParamTy)
@@ -1519,8 +1527,8 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
15191527
inputs, outputs, StructValues, newFunction, StructTy, oldFunction, ReplIP,
15201528
EntryFreq, LifetimesStart.getArrayRef(), Reloads);
15211529

1522-
insertReplacerCall(oldFunction, header, TheCall->getParent(), outputs,
1523-
Reloads, ExitWeights);
1530+
insertReplacerCall(oldFunction, header, TheCall, outputs, Reloads,
1531+
ExitWeights);
15241532

15251533
fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall, inputs,
15261534
NewValues);
@@ -1699,14 +1707,17 @@ void CodeExtractor::emitFunctionBody(
16991707
ExitBlockMap[OldTarget] = NewTarget;
17001708

17011709
Value *brVal = nullptr;
1702-
Type *RetTy = getSwitchType();
1710+
Type *RetTy = FuncRetVal ? FuncRetVal->getType() : getSwitchType();
17031711
assert(ExtractedFuncRetVals.size() < 0xffff &&
17041712
"too many exit blocks for switch");
17051713
switch (ExtractedFuncRetVals.size()) {
17061714
case 0:
1707-
case 1:
17081715
// No value needed.
17091716
break;
1717+
case 1:
1718+
if (FuncRetVal)
1719+
brVal = FuncRetVal;
1720+
break;
17101721
case 2: // Conditional branch, return a bool
17111722
brVal = ConstantInt::get(RetTy, !SuccNum);
17121723
break;
@@ -2019,13 +2030,14 @@ CallInst *CodeExtractor::emitReplacerCall(
20192030
}
20202031

20212032
void CodeExtractor::insertReplacerCall(
2022-
Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer,
2033+
Function *oldFunction, BasicBlock *header, CallInst *ReplacerCall,
20232034
const ValueSet &outputs, ArrayRef<Value *> Reloads,
20242035
const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights) {
20252036

20262037
// Rewrite branches to basic blocks outside of the loop to new dummy blocks
20272038
// within the new function. This must be done before we lose track of which
20282039
// blocks were originally in the code region.
2040+
BasicBlock *codeReplacer = ReplacerCall->getParent();
20292041
std::vector<User *> Users(header->user_begin(), header->user_end());
20302042
for (auto &U : Users)
20312043
// The BasicBlock which contains the branch is not in the region
@@ -2067,6 +2079,13 @@ void CodeExtractor::insertReplacerCall(
20672079
}
20682080
}
20692081

2082+
if (FuncRetVal)
2083+
for (User *U : FuncRetVal->users()) {
2084+
Instruction *inst = cast<Instruction>(U);
2085+
if (inst->getParent()->getParent() == oldFunction)
2086+
inst->replaceUsesOfWith(FuncRetVal, ReplacerCall);
2087+
}
2088+
20702089
// Update the branch weights for the exit block.
20712090
if (BFI && ExtractedFuncRetVals.size() > 1)
20722091
calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);

llvm/test/Transforms/CodeExtractor/PartialInlineAnd.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ bb:
4343
; LIMIT-LABEL: @dummy_caller
4444
; LIMIT: br i1
4545
; LIMIT-NOT: br
46-
; LIMIT: call void @bar.1.
46+
; LIMIT: call i32 @bar.1.
4747
%tmp = tail call i32 @bar(i32 %arg)
4848
ret i32 %tmp
4949
}

llvm/test/Transforms/CodeExtractor/PartialInlineAttributes.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ if.end:
5555
ret i32 %add
5656
}
5757
; CHECK-LABEL: @caller
58-
; CHECK: call void @callee_most.2.if.then(i32 %v
58+
; CHECK: call i32 @callee_most.2.if.then(i32 %v
5959
; CHECK: call i32 @callee_noinline(i32 %v)
60-
; CHECK: call void @callee_writeonly.1.if.then(i32 %v
60+
; CHECK: call i32 @callee_writeonly.1.if.then(i32 %v
6161
define i32 @caller(i32 %v) ssp {
6262
entry:
6363
%c1 = call i32 @callee_most(i32 %v)
@@ -66,8 +66,8 @@ entry:
6666
ret i32 %c3
6767
}
6868

69-
; CHECK: define internal void @callee_writeonly.1.if.then(i32 %v, ptr %sub.out) [[FN_ATTRS0:#[0-9]+]]
70-
; CHECK: define internal void @callee_most.2.if.then(i32 %v, ptr %sub.out) [[FN_ATTRS:#[0-9]+]]
69+
; CHECK: define internal i32 @callee_writeonly.1.if.then(i32 %v) [[FN_ATTRS0:#[0-9]+]]
70+
; CHECK: define internal i32 @callee_most.2.if.then(i32 %v) [[FN_ATTRS:#[0-9]+]]
7171

7272
; attributes to preserve
7373
attributes #0 = {

llvm/test/Transforms/CodeExtractor/PartialInlineDebug.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ if.end: ; preds = %if.then, %entry
2424
; CHECK-LABEL: @caller
2525
; CHECK: codeRepl.i:
2626
; CHECK-NOT: br label
27-
; CHECK: call void @callee.2.if.then(i32 %v, ptr %mul.loc.i), !dbg ![[DBG2:[0-9]+]]
27+
; CHECK: call i32 @callee.2.if.then(i32 %v), !dbg ![[DBG2:[0-9]+]]
2828
define i32 @caller(i32 %v) !dbg !8 {
2929
entry:
3030
%call = call i32 @callee(i32 %v), !dbg !14
@@ -55,17 +55,17 @@ if.end:
5555
; CHECK-LABEL: @caller2
5656
; CHECK: codeRepl.i:
5757
; CHECK-NOT: br label
58-
; CHECK: call void @callee2.1.if.then(i32 %v, ptr %sub.loc.i), !dbg ![[DBG4:[0-9]+]]
58+
; CHECK: call i32 @callee2.1.if.then(i32 %v), !dbg ![[DBG4:[0-9]+]]
5959
define i32 @caller2(i32 %v) !dbg !21 {
6060
entry:
6161
%call = call i32 @callee2(i32 %v), !dbg !22
6262
ret i32 %call
6363
}
6464

65-
; CHECK-LABEL: define internal void @callee2.1.if.then
65+
; CHECK-LABEL: define internal i32 @callee2.1.if.then
6666
; CHECK: br label %if.then, !dbg ![[DBG5:[0-9]+]]
6767

68-
; CHECK-LABEL: define internal void @callee.2.if.then
68+
; CHECK-LABEL: define internal i32 @callee.2.if.then
6969
; CHECK: br label %if.then, !dbg ![[DBG6:[0-9]+]]
7070

7171
; CHECK: ![[DBG1]] = !DILocation(line: 10, column: 7,

llvm/test/Transforms/CodeExtractor/PartialInlineInvokeProducesOutVal.ll

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,22 @@ bb5: ; preds = %bb4, %bb1, %bb
2424

2525
; CHECK-LABEL: @dummy_caller
2626
; CHECK-LABEL: bb:
27-
; CHECK-NEXT: [[CALL26LOC:%.*]] = alloca ptr
2827
; CHECK-LABEL: codeRepl.i:
29-
; CHECK-NEXT: call void @llvm.lifetime.start.p0(ptr [[CALL26LOC]])
30-
; CHECK-NEXT: call void @bar.1.bb1(ptr [[CALL26LOC]])
31-
; CHECK-NEXT: %call26.reload.i = load ptr, ptr [[CALL26LOC]]
32-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(ptr [[CALL26LOC]])
28+
; CHECK-NEXT: call ptr @bar.1.bb1()
3329
define ptr @dummy_caller(i32 %arg) {
3430
bb:
3531
%tmp = tail call ptr @bar(i32 %arg)
3632
ret ptr %tmp
3733
}
3834

39-
; CHECK-LABEL: define internal void @bar.1.bb1
35+
; CHECK-LABEL: define internal ptr @bar.1.bb1
4036
; CHECK-LABEL: bb1:
4137
; CHECK-NEXT: %call26 = invoke ptr @invoke_callee()
4238
; CHECK-NEXT: to label %cont unwind label %lpad
4339
; CHECK-LABEL: cont:
44-
; CHECK-NEXT: store ptr %call26, ptr %call26.out
4540
; CHECK-NEXT: br label %bb5.exitStub
41+
; CHECK-LABEL: bb5.exitStub:
42+
; CHECK-NEXT: ret ptr %call26
4643

4744
; Function Attrs: nobuiltin
4845
declare dso_local noalias nonnull ptr @invoke_callee() local_unnamed_addr #1

llvm/test/Transforms/CodeExtractor/PartialInlineOrAnd.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ bb:
5454
; LIMIT3: br i1
5555
; LIMIT3: br i1
5656
; LIMIT3-NOT: br i1
57-
; LIMIT3: call void @bar.1.
57+
; LIMIT3: call i32 @bar.1.
5858
; LIMIT2-LABEL: @dummy_caller
5959
; LIMIT2-NOT: br i1
6060
; LIMIT2: call i32 @bar(

llvm/test/Transforms/CodeExtractor/PartialInlineVarArgsDebug.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ if.end: ; preds = %if.then, %entry
2020
; CHECK-LABEL: @caller
2121
; CHECK: codeRepl.i:
2222
; CHECK-NOT: br label
23-
; CHECK: call void (i32, ptr, ...) @callee.1.if.then(i32 %v, ptr %mul.loc.i, i32 99), !dbg ![[DBG2:[0-9]+]]
23+
; CHECK: call i32 (i32, ...) @callee.1.if.then(i32 %v, i32 99), !dbg ![[DBG2:[0-9]+]]
2424
define i32 @caller(i32 %v) !dbg !8 {
2525
entry:
2626
%call = call i32 (i32, ...) @callee(i32 %v, i32 99), !dbg !14
2727
ret i32 %call, !dbg !15
2828
}
2929

30-
; CHECK-LABEL: define internal void @callee.1.if.then
30+
; CHECK-LABEL: define internal i32 @callee.1.if.then
3131
; CHECK: br label %if.then, !dbg ![[DBG3:[0-9]+]]
3232

3333
; CHECK: ![[DBG1]] = !DILocation(line: 10, column: 7,

0 commit comments

Comments
 (0)