@@ -264,12 +264,14 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
264264 BranchProbabilityInfo *BPI, AssumptionCache *AC,
265265 bool AllowVarArgs, bool AllowAlloca,
266266 BasicBlock *AllocationBlock, std::string Suffix,
267- bool ArgsInZeroAddressSpace)
267+ bool ArgsInZeroAddressSpace,
268+ bool VoidReturnWithSingleOutput)
268269 : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
269270 BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
270271 AllowVarArgs(AllowVarArgs),
271272 Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
272- Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
273+ Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
274+ VoidReturnWithSingleOutput(VoidReturnWithSingleOutput) {}
273275
274276/// definedInRegion - Return true if the specified value is defined in the
275277/// extracted region.
@@ -662,7 +664,7 @@ bool CodeExtractor::isEligible() const {
662664
663665void CodeExtractor::findInputsOutputs (ValueSet &Inputs, ValueSet &Outputs,
664666 const ValueSet &SinkCands,
665- bool CollectGlobalInputs) const {
667+ bool CollectGlobalInputs) {
666668 for (BasicBlock *BB : Blocks) {
667669 // If a used value is defined outside the region, it's an input. If an
668670 // instruction is used outside the region, it's an output.
@@ -682,6 +684,12 @@ void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
682684 }
683685 }
684686 }
687+
688+ if (!VoidReturnWithSingleOutput && !AggregateArgs && Outputs.size () == 1 &&
689+ getCommonExitBlock (Blocks)) {
690+ FuncRetVal = Outputs[0 ];
691+ Outputs.clear ();
692+ }
685693}
686694
687695/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
@@ -877,7 +885,7 @@ Function *CodeExtractor::constructFunctionDeclaration(
877885 M->getContext (), ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace ()));
878886 }
879887
880- Type *RetTy = getSwitchType ();
888+ Type *RetTy = FuncRetVal ? FuncRetVal-> getType () : getSwitchType ();
881889 LLVM_DEBUG ({
882890 dbgs () << " Function type: " << *RetTy << " f(" ;
883891 for (Type *i : ParamTy)
@@ -1519,8 +1527,8 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
15191527 inputs, outputs, StructValues, newFunction, StructTy, oldFunction, ReplIP,
15201528 EntryFreq, LifetimesStart.getArrayRef (), Reloads);
15211529
1522- insertReplacerCall (oldFunction, header, TheCall-> getParent () , outputs,
1523- Reloads, ExitWeights);
1530+ insertReplacerCall (oldFunction, header, TheCall, outputs, Reloads ,
1531+ ExitWeights);
15241532
15251533 fixupDebugInfoPostExtraction (*oldFunction, *newFunction, *TheCall, inputs,
15261534 NewValues);
@@ -1699,14 +1707,17 @@ void CodeExtractor::emitFunctionBody(
16991707 ExitBlockMap[OldTarget] = NewTarget;
17001708
17011709 Value *brVal = nullptr ;
1702- Type *RetTy = getSwitchType ();
1710+ Type *RetTy = FuncRetVal ? FuncRetVal-> getType () : getSwitchType ();
17031711 assert (ExtractedFuncRetVals.size () < 0xffff &&
17041712 " too many exit blocks for switch" );
17051713 switch (ExtractedFuncRetVals.size ()) {
17061714 case 0 :
1707- case 1 :
17081715 // No value needed.
17091716 break ;
1717+ case 1 :
1718+ if (FuncRetVal)
1719+ brVal = FuncRetVal;
1720+ break ;
17101721 case 2 : // Conditional branch, return a bool
17111722 brVal = ConstantInt::get (RetTy, !SuccNum);
17121723 break ;
@@ -2019,13 +2030,14 @@ CallInst *CodeExtractor::emitReplacerCall(
20192030}
20202031
20212032void CodeExtractor::insertReplacerCall (
2022- Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer ,
2033+ Function *oldFunction, BasicBlock *header, CallInst *ReplacerCall ,
20232034 const ValueSet &outputs, ArrayRef<Value *> Reloads,
20242035 const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights) {
20252036
20262037 // Rewrite branches to basic blocks outside of the loop to new dummy blocks
20272038 // within the new function. This must be done before we lose track of which
20282039 // blocks were originally in the code region.
2040+ BasicBlock *codeReplacer = ReplacerCall->getParent ();
20292041 std::vector<User *> Users (header->user_begin (), header->user_end ());
20302042 for (auto &U : Users)
20312043 // The BasicBlock which contains the branch is not in the region
@@ -2067,6 +2079,13 @@ void CodeExtractor::insertReplacerCall(
20672079 }
20682080 }
20692081
2082+ if (FuncRetVal)
2083+ for (User *U : FuncRetVal->users ()) {
2084+ Instruction *inst = cast<Instruction>(U);
2085+ if (inst->getParent ()->getParent () == oldFunction)
2086+ inst->replaceUsesOfWith (FuncRetVal, ReplacerCall);
2087+ }
2088+
20702089 // Update the branch weights for the exit block.
20712090 if (BFI && ExtractedFuncRetVals.size () > 1 )
20722091 calculateNewCallTerminatorWeights (codeReplacer, ExitWeights, BPI);
0 commit comments