Skip to content

Commit 4f0d960

Browse files
authored
Implement --check-stack-overflow flag for wasm-emscripten-finalize (#2278)
1 parent 8d4d43f commit 4f0d960

6 files changed

Lines changed: 475 additions & 8 deletions

File tree

scripts/test/lld.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222

2323

2424
def args_for_finalize(filename):
25-
if 'shared' in filename:
26-
return ['--side-module']
27-
else:
28-
return ['--global-base=568']
25+
if 'safe_stack' in filename:
26+
return ['--check-stack-overflow', '--global-base=568']
27+
elif 'shared' in filename:
28+
return ['--side-module']
29+
else:
30+
return ['--global-base=568']
2931

3032

3133
def test_wasm_emscripten_finalize():

src/tools/wasm-emscripten-finalize.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ int main(int argc, const char* argv[]) {
4848
bool debugInfo = false;
4949
bool isSideModule = false;
5050
bool legalizeJavaScriptFFI = true;
51+
bool checkStackOverflow = false;
5152
uint64_t globalBase = INVALID_BASE;
5253
ToolOptions options("wasm-emscripten-finalize",
5354
"Performs Emscripten-specific transforms on .wasm files");
@@ -127,6 +128,13 @@ int main(int argc, const char* argv[]) {
127128
[&dataSegmentFile](Options* o, const std::string& argument) {
128129
dataSegmentFile = argument;
129130
})
131+
.add("--check-stack-overflow",
132+
"",
133+
"Check for stack overflows every time the stack is extended",
134+
Options::Arguments::Zero,
135+
[&checkStackOverflow](Options* o, const std::string&) {
136+
checkStackOverflow = true;
137+
})
130138
.add_positional("INFILE",
131139
Options::Arguments::One,
132140
[&infile](Options* o, const std::string& argument) {
@@ -200,6 +208,10 @@ int main(int argc, const char* argv[]) {
200208
}
201209
wasm.updateMaps();
202210

211+
if (checkStackOverflow && !isSideModule) {
212+
generator.enforceStackLimit();
213+
}
214+
203215
if (isSideModule) {
204216
generator.replaceStackPointerGlobal();
205217
generator.generatePostInstantiateFunction();

src/wasm-emscripten.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class EmscriptenGlueGenerator {
5454

5555
void fixInvokeFunctionNames();
5656

57+
void enforceStackLimit();
58+
5759
// Emits the data segments to a file. The file contains data from address base
5860
// onwards (we must pass in base, as we can't tell it from the wasm - the
5961
// first segment may start after a run of zeros, but we need those zeros in
@@ -71,11 +73,13 @@ class EmscriptenGlueGenerator {
7173

7274
Global* getStackPointerGlobal();
7375
Expression* generateLoadStackPointer();
74-
Expression* generateStoreStackPointer(Expression* value);
76+
Expression* generateStoreStackPointer(Function* func, Expression* value);
7577
void generateDynCallThunk(std::string sig);
7678
void generateStackSaveFunction();
7779
void generateStackAllocFunction();
7880
void generateStackRestoreFunction();
81+
void generateSetStackLimitFunction();
82+
Name importStackOverflowHandler();
7983
};
8084

8185
} // namespace wasm

src/wasm/wasm-emscripten.cpp

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@ static Name STACK_SAVE("stackSave");
3737
static Name STACK_RESTORE("stackRestore");
3838
static Name STACK_ALLOC("stackAlloc");
3939
static Name STACK_INIT("stack$init");
40+
static Name STACK_LIMIT("__stack_limit");
41+
static Name SET_STACK_LIMIT("__set_stack_limit");
4042
static Name POST_INSTANTIATE("__post_instantiate");
4143
static Name ASSIGN_GOT_ENTIRES("__assign_got_enties");
44+
static Name STACK_OVERFLOW_IMPORT("__handle_stack_overflow");
4245

4346
void addExportedFunction(Module& wasm, Function* function) {
4447
wasm.addFunction(function);
@@ -92,8 +95,32 @@ Expression* EmscriptenGlueGenerator::generateLoadStackPointer() {
9295
return builder.makeGlobalGet(stackPointer->name, i32);
9396
}
9497

98+
inline Expression* stackBoundsCheck(Builder& builder,
99+
Function* func,
100+
Expression* value,
101+
Global* stackPointer,
102+
Global* stackLimit,
103+
Name handler) {
104+
// Add a local to store the value of the expression. We need the value twice:
105+
// once to check if it has overflowed, and again to assign to store it.
106+
auto newSP = Builder::addVar(func, stackPointer->type);
107+
// (if (i32.lt_u (local.tee $newSP (...value...)) (global.get $__stack_limit))
108+
// (call $handler))
109+
auto check =
110+
builder.makeIf(builder.makeBinary(
111+
BinaryOp::LtUInt32,
112+
builder.makeLocalTee(newSP, value),
113+
builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
114+
builder.makeCall(handler, {}, none));
115+
// (global.set $__stack_pointer (local.get $newSP))
116+
auto newSet = builder.makeGlobalSet(
117+
stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
118+
return builder.blockify(check, newSet);
119+
}
120+
95121
Expression*
96-
EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) {
122+
EmscriptenGlueGenerator::generateStoreStackPointer(Function* func,
123+
Expression* value) {
97124
if (!useStackPointerGlobal) {
98125
return builder.makeStore(
99126
/* bytes =*/4,
@@ -107,6 +134,14 @@ EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) {
107134
if (!stackPointer) {
108135
Fatal() << "stack pointer global not found";
109136
}
137+
if (auto* stackLimit = wasm.getGlobalOrNull(STACK_LIMIT)) {
138+
return stackBoundsCheck(builder,
139+
func,
140+
value,
141+
stackPointer,
142+
stackLimit,
143+
importStackOverflowHandler());
144+
}
110145
return builder.makeGlobalSet(stackPointer->name, value);
111146
}
112147

@@ -132,7 +167,7 @@ void EmscriptenGlueGenerator::generateStackAllocFunction() {
132167
Const* subConst = builder.makeConst(Literal(~bitMask));
133168
Binary* maskedSub = builder.makeBinary(AndInt32, sub, subConst);
134169
LocalSet* teeStackLocal = builder.makeLocalTee(1, maskedSub);
135-
Expression* storeStack = generateStoreStackPointer(teeStackLocal);
170+
Expression* storeStack = generateStoreStackPointer(function, teeStackLocal);
136171

137172
Block* block = builder.makeBlock();
138173
block->list.push_back(storeStack);
@@ -149,7 +184,7 @@ void EmscriptenGlueGenerator::generateStackRestoreFunction() {
149184
Function* function =
150185
builder.makeFunction(STACK_RESTORE, std::move(params), none, {});
151186
LocalGet* getArg = builder.makeLocalGet(0, i32);
152-
Expression* store = generateStoreStackPointer(getArg);
187+
Expression* store = generateStoreStackPointer(function, getArg);
153188

154189
function->body = store;
155190

@@ -444,6 +479,86 @@ void EmscriptenGlueGenerator::replaceStackPointerGlobal() {
444479
wasm.removeGlobal(stackPointer->name);
445480
}
446481

482+
struct StackLimitEnforcer : public WalkerPass<PostWalker<StackLimitEnforcer>> {
483+
StackLimitEnforcer(Global* stackPointer,
484+
Global* stackLimit,
485+
Builder& builder,
486+
Name handler)
487+
: stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
488+
handler(handler) {}
489+
490+
bool isFunctionParallel() override { return true; }
491+
492+
Pass* create() override {
493+
return new StackLimitEnforcer(stackPointer, stackLimit, builder, handler);
494+
}
495+
496+
void visitGlobalSet(GlobalSet* curr) {
497+
if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
498+
replaceCurrent(stackBoundsCheck(builder,
499+
getFunction(),
500+
curr->value,
501+
stackPointer,
502+
stackLimit,
503+
handler));
504+
}
505+
}
506+
507+
private:
508+
Global* stackPointer;
509+
Global* stackLimit;
510+
Builder& builder;
511+
Name handler;
512+
};
513+
514+
void EmscriptenGlueGenerator::enforceStackLimit() {
515+
Global* stackPointer = getStackPointerGlobal();
516+
if (!stackPointer) {
517+
return;
518+
}
519+
520+
auto* stackLimit = builder.makeGlobal(STACK_LIMIT,
521+
stackPointer->type,
522+
builder.makeConst(Literal(0)),
523+
Builder::Mutable);
524+
wasm.addGlobal(stackLimit);
525+
526+
auto handler = importStackOverflowHandler();
527+
528+
StackLimitEnforcer walker(stackPointer, stackLimit, builder, handler);
529+
PassRunner runner(&wasm);
530+
walker.run(&runner, &wasm);
531+
532+
generateSetStackLimitFunction();
533+
}
534+
535+
void EmscriptenGlueGenerator::generateSetStackLimitFunction() {
536+
Function* function =
537+
builder.makeFunction(SET_STACK_LIMIT, std::vector<Type>({i32}), none, {});
538+
LocalGet* getArg = builder.makeLocalGet(0, i32);
539+
Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
540+
function->body = store;
541+
addExportedFunction(wasm, function);
542+
}
543+
544+
Name EmscriptenGlueGenerator::importStackOverflowHandler() {
545+
ImportInfo info(wasm);
546+
547+
if (auto* existing = info.getImportedFunction(ENV, STACK_OVERFLOW_IMPORT)) {
548+
return existing->name;
549+
} else {
550+
auto* import = new Function;
551+
import->name = STACK_OVERFLOW_IMPORT;
552+
import->module = ENV;
553+
import->base = STACK_OVERFLOW_IMPORT;
554+
auto* functionType = ensureFunctionType("v", &wasm);
555+
import->type = functionType->name;
556+
FunctionTypeUtils::fillFunction(import, functionType);
557+
wasm.addFunction(import);
558+
return STACK_OVERFLOW_IMPORT;
559+
}
560+
}
561+
447562
const Address UNKNOWN_OFFSET(uint32_t(-1));
448563

449564
std::vector<Address> getSegmentOffsets(Module& wasm) {

test/lld/recursive_safe_stack.wast

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
(module
2+
(type $0 (func (param i32 i32) (result i32)))
3+
(type $1 (func))
4+
(type $2 (func (result i32)))
5+
(import "env" "printf" (func $printf (param i32 i32) (result i32)))
6+
(memory $0 2)
7+
(data (i32.const 568) "%d:%d\n\00Result: %d\n\00")
8+
(table $0 1 1 funcref)
9+
(global $global$0 (mut i32) (i32.const 66128))
10+
(global $global$1 i32 (i32.const 66128))
11+
(global $global$2 i32 (i32.const 587))
12+
(export "memory" (memory $0))
13+
(export "__wasm_call_ctors" (func $__wasm_call_ctors))
14+
(export "__heap_base" (global $global$1))
15+
(export "__data_end" (global $global$2))
16+
(export "main" (func $main))
17+
(func $__wasm_call_ctors (; 1 ;) (type $1)
18+
)
19+
(func $foo (; 2 ;) (type $0) (param $0 i32) (param $1 i32) (result i32)
20+
(local $2 i32)
21+
(global.set $global$0
22+
(local.tee $2
23+
(i32.sub
24+
(global.get $global$0)
25+
(i32.const 16)
26+
)
27+
)
28+
)
29+
(i32.store offset=4
30+
(local.get $2)
31+
(local.get $1)
32+
)
33+
(i32.store
34+
(local.get $2)
35+
(local.get $0)
36+
)
37+
(drop
38+
(call $printf
39+
(i32.const 568)
40+
(local.get $2)
41+
)
42+
)
43+
(global.set $global$0
44+
(i32.add
45+
(local.get $2)
46+
(i32.const 16)
47+
)
48+
)
49+
(i32.add
50+
(local.get $1)
51+
(local.get $0)
52+
)
53+
)
54+
(func $__original_main (; 3 ;) (type $2) (result i32)
55+
(local $0 i32)
56+
(global.set $global$0
57+
(local.tee $0
58+
(i32.sub
59+
(global.get $global$0)
60+
(i32.const 16)
61+
)
62+
)
63+
)
64+
(i32.store
65+
(local.get $0)
66+
(call $foo
67+
(i32.const 1)
68+
(i32.const 2)
69+
)
70+
)
71+
(drop
72+
(call $printf
73+
(i32.const 575)
74+
(local.get $0)
75+
)
76+
)
77+
(global.set $global$0
78+
(i32.add
79+
(local.get $0)
80+
(i32.const 16)
81+
)
82+
)
83+
(i32.const 0)
84+
)
85+
(func $main (; 4 ;) (type $0) (param $0 i32) (param $1 i32) (result i32)
86+
(call $__original_main)
87+
)
88+
;; custom section "producers", size 111
89+
)
90+

0 commit comments

Comments
 (0)