Skip to content

Commit a2741b3

Browse files
authored
Finalize tail call support (#2246)
Adds tail call support to fuzzer and makes small changes to handle return calls in multiple utilities and passes. Makes larger changes to DAE and inlining passes to properly handle tail calls.
1 parent 0beba8a commit a2741b3

28 files changed

Lines changed: 758 additions & 139 deletions

scripts/fuzz_opt.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
# simd: known issues with d8
3535
# atomics, bulk memory: doesn't work in wasm2js
3636
# truncsat: https://github.com/WebAssembly/binaryen/issues/2198
37-
# tail-call: WIP
3837
CONSTANT_FEATURE_OPTS = ['--all-features']
3938

4039
# possible feature options that are sometimes passed to the tools.
@@ -298,7 +297,7 @@ def run(self, wasm):
298297
return out
299298

300299
def can_run_on_feature_opts(self, feature_opts):
301-
return all([x in feature_opts for x in ['--disable-exception-handling', '--disable-simd', '--disable-threads', '--disable-bulk-memory', '--disable-nontrapping-float-to-int']])
300+
return all([x in feature_opts for x in ['--disable-exception-handling', '--disable-simd', '--disable-threads', '--disable-bulk-memory', '--disable-nontrapping-float-to-int', '--disable-tail-call']])
302301

303302

304303
class Asyncify(TestCaseHandler):
@@ -343,7 +342,7 @@ def do_asyncify(wasm):
343342
compare(before, after_asyncify, 'Asyncify (before/after_asyncify)')
344343

345344
def can_run_on_feature_opts(self, feature_opts):
346-
return all([x in feature_opts for x in ['--disable-exception-handling', '--disable-simd']])
345+
return all([x in feature_opts for x in ['--disable-exception-handling', '--disable-simd', '--disable-tail-call']])
347346

348347

349348
# The global list of all test case handlers

src/ir/ExpressionAnalyzer.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,13 @@ template<typename T> void visitImmediates(Expression* curr, T& visitor) {
132132
}
133133
visitor.visitScopeName(curr->default_);
134134
}
135-
void visitCall(Call* curr) { visitor.visitNonScopeName(curr->target); }
135+
void visitCall(Call* curr) {
136+
visitor.visitNonScopeName(curr->target);
137+
visitor.visitInt(curr->isReturn);
138+
}
136139
void visitCallIndirect(CallIndirect* curr) {
137140
visitor.visitNonScopeName(curr->fullType);
141+
visitor.visitInt(curr->isReturn);
138142
}
139143
void visitLocalGet(LocalGet* curr) { visitor.visitIndex(curr->index); }
140144
void visitLocalSet(LocalSet* curr) { visitor.visitIndex(curr->index); }

src/ir/ExpressionManipulator.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,16 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) {
7171
copy(curr->value));
7272
}
7373
Expression* visitCall(Call* curr) {
74-
auto* ret = builder.makeCall(curr->target, {}, curr->type);
74+
auto* ret =
75+
builder.makeCall(curr->target, {}, curr->type, curr->isReturn);
7576
for (Index i = 0; i < curr->operands.size(); i++) {
7677
ret->operands.push_back(copy(curr->operands[i]));
7778
}
7879
return ret;
7980
}
8081
Expression* visitCallIndirect(CallIndirect* curr) {
8182
auto* ret = builder.makeCallIndirect(
82-
curr->fullType, copy(curr->target), {}, curr->type);
83+
curr->fullType, copy(curr->target), {}, curr->type, curr->isReturn);
8384
for (Index i = 0; i < curr->operands.size(); i++) {
8485
ret->operands.push_back(copy(curr->operands[i]));
8586
}

src/ir/effects.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,22 @@ struct EffectAnalyzer
223223

224224
void visitCall(Call* curr) {
225225
calls = true;
226+
if (curr->isReturn) {
227+
branches = true;
228+
}
226229
if (debugInfo) {
227230
// debugInfo call imports must be preserved very strongly, do not
228231
// move code around them
229232
// FIXME: we could check if the call is to an import
230233
branches = true;
231234
}
232235
}
233-
void visitCallIndirect(CallIndirect* curr) { calls = true; }
236+
void visitCallIndirect(CallIndirect* curr) {
237+
calls = true;
238+
if (curr->isReturn) {
239+
branches = true;
240+
}
241+
}
234242
void visitLocalGet(LocalGet* curr) { localsRead.insert(curr->index); }
235243
void visitLocalSet(LocalSet* curr) { localsWritten.insert(curr->index); }
236244
void visitGlobalGet(GlobalGet* curr) { globalsRead.insert(curr->name); }

src/passes/Asyncify.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,9 @@ class ModuleAnalyzer {
349349
}
350350
struct Walker : PostWalker<Walker> {
351351
void visitCall(Call* curr) {
352+
if (curr->isReturn) {
353+
Fatal() << "tail calls not yet supported in aysncify";
354+
}
352355
auto* target = module->getFunction(curr->target);
353356
if (target->imported() && target->module == ASYNCIFY) {
354357
// Redirect the imports to the functions we'll add later.
@@ -375,6 +378,9 @@ class ModuleAnalyzer {
375378
info->callsTo.insert(target);
376379
}
377380
void visitCallIndirect(CallIndirect* curr) {
381+
if (curr->isReturn) {
382+
Fatal() << "tail calls not yet supported in aysncify";
383+
}
378384
if (canIndirectChangeState) {
379385
info->canChangeState = true;
380386
}

src/passes/DeadArgumentElimination.cpp

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ struct DAEFunctionInfo {
5757
// Map of all calls that are dropped, to their drops' locations (so that
5858
// if we can optimize out the drop, we can replace the drop there).
5959
std::unordered_map<Call*, Expression**> droppedCalls;
60+
// Whether this function contains any tail calls (including indirect tail
61+
// calls) and the set of functions this function tail calls. Tail-callers and
62+
// tail-callees cannot have their dropped returns removed because of the
63+
// constraint that tail-callees must have the same return type as
64+
// tail-callers. Indirectly tail called functions are already not optimized
65+
// because being in a table inhibits DAE. TODO: Allow the removal of dropped
66+
// returns from tail-callers if their tail-callees can have their returns
67+
// removed as well.
68+
bool hasTailCalls = false;
69+
std::unordered_set<Name> tailCallees;
6070
// Whether the function can be called from places that
6171
// affect what we can do. For now, any call we don't
6272
// see inhibits our optimizations, but TODO: an export
@@ -117,6 +127,16 @@ struct DAEScanner
117127
if (!getModule()->getFunction(curr->target)->imported()) {
118128
info->calls[curr->target].push_back(curr);
119129
}
130+
if (curr->isReturn) {
131+
info->hasTailCalls = true;
132+
info->tailCallees.insert(curr->target);
133+
}
134+
}
135+
136+
void visitCallIndirect(CallIndirect* curr) {
137+
if (curr->isReturn) {
138+
info->hasTailCalls = true;
139+
}
120140
}
121141

122142
void visitDrop(Drop* curr) {
@@ -239,6 +259,7 @@ struct DAE : public Pass {
239259
DAEScanner(&infoMap).run(runner, module);
240260
// Combine all the info.
241261
std::unordered_map<Name, std::vector<Call*>> allCalls;
262+
std::unordered_set<Name> tailCallees;
242263
for (auto& pair : infoMap) {
243264
auto& info = pair.second;
244265
for (auto& pair : info.calls) {
@@ -247,6 +268,9 @@ struct DAE : public Pass {
247268
auto& allCallsToName = allCalls[name];
248269
allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
249270
}
271+
for (auto& callee : info.tailCallees) {
272+
tailCallees.insert(callee);
273+
}
250274
for (auto& pair : info.droppedCalls) {
251275
allDroppedCalls[pair.first] = pair.second;
252276
}
@@ -314,14 +338,11 @@ struct DAE : public Pass {
314338
// Great, it's not used. Check if none of the calls has a param with
315339
// side effects, as that would prevent us removing them (flattening
316340
// should have been done earlier).
317-
bool canRemove = true;
318-
for (auto* call : calls) {
319-
auto* operand = call->operands[i];
320-
if (EffectAnalyzer(runner->options, operand).hasSideEffects()) {
321-
canRemove = false;
322-
break;
323-
}
324-
}
341+
bool canRemove =
342+
std::none_of(calls.begin(), calls.end(), [&](Call* call) {
343+
auto* operand = call->operands[i];
344+
return EffectAnalyzer(runner->options, operand).hasSideEffects();
345+
});
325346
if (canRemove) {
326347
// Wonderful, nothing stands in our way! Do it.
327348
// TODO: parallelize this?
@@ -348,18 +369,21 @@ struct DAE : public Pass {
348369
if (infoMap[name].hasUnseenCalls) {
349370
continue;
350371
}
372+
if (infoMap[name].hasTailCalls) {
373+
continue;
374+
}
375+
if (tailCallees.find(name) != tailCallees.end()) {
376+
continue;
377+
}
351378
auto iter = allCalls.find(name);
352379
if (iter == allCalls.end()) {
353380
continue;
354381
}
355382
auto& calls = iter->second;
356-
bool allDropped = true;
357-
for (auto* call : calls) {
358-
if (!allDroppedCalls.count(call)) {
359-
allDropped = false;
360-
break;
361-
}
362-
}
383+
bool allDropped =
384+
std::all_of(calls.begin(), calls.end(), [&](Call* call) {
385+
return allDroppedCalls.count(call);
386+
});
363387
if (!allDropped) {
364388
continue;
365389
}

src/passes/DeadCodeElimination.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,12 @@ struct DeadCodeElimination
365365
return curr;
366366
}
367367

368-
void visitCall(Call* curr) { handleCall(curr); }
368+
void visitCall(Call* curr) {
369+
handleCall(curr);
370+
if (curr->isReturn) {
371+
reachable = false;
372+
}
373+
}
369374

370375
void visitCallIndirect(CallIndirect* curr) {
371376
if (handleCall(curr) != curr) {
@@ -380,6 +385,9 @@ struct DeadCodeElimination
380385
block->finalize(curr->type);
381386
replaceCurrent(block);
382387
}
388+
if (curr->isReturn) {
389+
reachable = false;
390+
}
383391
}
384392

385393
// Append the reachable operands of the current node to a block, and replace

src/passes/Directize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
6565
}
6666
// Everything looks good!
6767
replaceCurrent(
68-
Builder(*getModule()).makeCall(name, curr->operands, curr->type));
68+
Builder(*getModule())
69+
.makeCall(name, curr->operands, curr->type, curr->isReturn));
6970
}
7071
}
7172

src/passes/I64ToI32Lowering.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,14 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
262262
return call;
263263
}
264264
void visitCall(Call* curr) {
265+
if (curr->isReturn &&
266+
getModule()->getFunction(curr->target)->result == i64) {
267+
Fatal()
268+
<< "i64 to i32 lowering of return_call values not yet implemented";
269+
}
265270
auto* fixedCall = visitGenericCall<Call>(
266271
curr, [&](std::vector<Expression*>& args, Type ty) {
267-
return builder->makeCall(curr->target, args, ty);
272+
return builder->makeCall(curr->target, args, ty, curr->isReturn);
268273
});
269274
// If this was to an import, we need to call the legal version. This assumes
270275
// that legalize-js-interface has been run before.
@@ -275,10 +280,15 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
275280
}
276281

277282
void visitCallIndirect(CallIndirect* curr) {
283+
if (curr->isReturn &&
284+
getModule()->getFunctionType(curr->fullType)->result == i64) {
285+
Fatal()
286+
<< "i64 to i32 lowering of return_call values not yet implemented";
287+
}
278288
visitGenericCall<CallIndirect>(
279289
curr, [&](std::vector<Expression*>& args, Type ty) {
280290
return builder->makeCallIndirect(
281-
curr->fullType, curr->target, args, ty);
291+
curr->fullType, curr->target, args, ty, curr->isReturn);
282292
});
283293
}
284294

src/passes/Inlining.cpp

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,17 @@ struct Planner : public WalkerPass<PostWalker<Planner>> {
146146
// plan to inline if we know this is valid to inline, and if the call is
147147
// actually performed - if it is dead code, it's pointless to inline.
148148
// we also cannot inline ourselves.
149-
if (state->worthInlining.count(curr->target) && curr->type != unreachable &&
149+
bool isUnreachable;
150+
if (curr->isReturn) {
151+
// Tail calls are only actually unreachable if an argument is
152+
isUnreachable =
153+
std::any_of(curr->operands.begin(),
154+
curr->operands.end(),
155+
[](Expression* op) { return op->type == unreachable; });
156+
} else {
157+
isUnreachable = curr->type == unreachable;
158+
}
159+
if (state->worthInlining.count(curr->target) && !isUnreachable &&
150160
curr->target != getFunction()->name) {
151161
// nest the call in a block. that way the location of the pointer to the
152162
// call will not change even if we inline multiple times into the same
@@ -164,32 +174,69 @@ struct Planner : public WalkerPass<PostWalker<Planner>> {
164174
InliningState* state;
165175
};
166176

177+
struct Updater : public PostWalker<Updater> {
178+
Module* module;
179+
std::map<Index, Index> localMapping;
180+
Name returnName;
181+
Builder* builder;
182+
void visitReturn(Return* curr) {
183+
replaceCurrent(builder->makeBreak(returnName, curr->value));
184+
}
185+
// Return calls in inlined functions should only break out of the scope of
186+
// the inlined code, not the entire function they are being inlined into. To
187+
// achieve this, make the call a non-return call and add a break. This does
188+
// not cause unbounded stack growth because inlining and return calling both
189+
// avoid creating a new stack frame.
190+
template<typename T> void handleReturnCall(T* curr, Type targetType) {
191+
curr->isReturn = false;
192+
curr->type = targetType;
193+
if (isConcreteType(targetType)) {
194+
replaceCurrent(builder->makeBreak(returnName, curr));
195+
} else {
196+
replaceCurrent(builder->blockify(curr, builder->makeBreak(returnName)));
197+
}
198+
}
199+
void visitCall(Call* curr) {
200+
if (curr->isReturn) {
201+
handleReturnCall(curr, module->getFunction(curr->target)->result);
202+
}
203+
}
204+
void visitCallIndirect(CallIndirect* curr) {
205+
if (curr->isReturn) {
206+
handleReturnCall(curr, module->getFunctionType(curr->fullType)->result);
207+
}
208+
}
209+
void visitLocalGet(LocalGet* curr) {
210+
curr->index = localMapping[curr->index];
211+
}
212+
void visitLocalSet(LocalSet* curr) {
213+
curr->index = localMapping[curr->index];
214+
}
215+
};
216+
167217
// Core inlining logic. Modifies the outside function (adding locals as
168218
// needed), and returns the inlined code.
169219
static Expression*
170220
doInlining(Module* module, Function* into, InliningAction& action) {
171221
Function* from = action.contents;
172222
auto* call = (*action.callSite)->cast<Call>();
223+
// Works for return_call, too
224+
Type retType = module->getFunction(call->target)->result;
173225
Builder builder(*module);
174-
auto* block = Builder(*module).makeBlock();
226+
auto* block = builder.makeBlock();
175227
block->name = Name(std::string("__inlined_func$") + from->name.str);
176-
*action.callSite = block;
177-
// Prepare to update the inlined code's locals and other things.
178-
struct Updater : public PostWalker<Updater> {
179-
std::map<Index, Index> localMapping;
180-
Name returnName;
181-
Builder* builder;
182-
183-
void visitReturn(Return* curr) {
184-
replaceCurrent(builder->makeBreak(returnName, curr->value));
228+
if (call->isReturn) {
229+
if (isConcreteType(retType)) {
230+
*action.callSite = builder.makeReturn(block);
231+
} else {
232+
*action.callSite = builder.makeSequence(block, builder.makeReturn());
185233
}
186-
void visitLocalGet(LocalGet* curr) {
187-
curr->index = localMapping[curr->index];
188-
}
189-
void visitLocalSet(LocalSet* curr) {
190-
curr->index = localMapping[curr->index];
191-
}
192-
} updater;
234+
} else {
235+
*action.callSite = block;
236+
}
237+
// Prepare to update the inlined code's locals and other things.
238+
Updater updater;
239+
updater.module = module;
193240
updater.returnName = block->name;
194241
updater.builder = &builder;
195242
// Set up a locals mapping
@@ -215,12 +262,12 @@ doInlining(Module* module, Function* into, InliningAction& action) {
215262
}
216263
updater.walk(contents);
217264
block->list.push_back(contents);
218-
block->type = call->type;
265+
block->type = retType;
219266
// If the function returned a value, we just set the block containing the
220267
// inlined code to have that type. or, if the function was void and
221268
// contained void, that is fine too. a bad case is a void function in which
222269
// we have unreachable code, so we would be replacing a void call with an
223-
// unreachable; we need to handle
270+
// unreachable.
224271
if (contents->type == unreachable && block->type == none) {
225272
// Make the block reachable by adding a break to it
226273
block->list.push_back(builder.makeBreak(block->name));

0 commit comments

Comments
 (0)