@@ -37,8 +37,11 @@ static Name STACK_SAVE("stackSave");
3737static Name STACK_RESTORE (" stackRestore" );
3838static Name STACK_ALLOC (" stackAlloc" );
3939static Name STACK_INIT (" stack$init" );
40+ static Name STACK_LIMIT (" __stack_limit" );
41+ static Name SET_STACK_LIMIT (" __set_stack_limit" );
4042static Name POST_INSTANTIATE (" __post_instantiate" );
4143static Name ASSIGN_GOT_ENTIRES (" __assign_got_enties" );
44+ static Name STACK_OVERFLOW_IMPORT (" __handle_stack_overflow" );
4245
4346void addExportedFunction (Module& wasm, Function* function) {
4447 wasm.addFunction (function);
@@ -92,8 +95,32 @@ Expression* EmscriptenGlueGenerator::generateLoadStackPointer() {
9295 return builder.makeGlobalGet (stackPointer->name , i32 );
9396}
9497
98+ inline Expression* stackBoundsCheck (Builder& builder,
99+ Function* func,
100+ Expression* value,
101+ Global* stackPointer,
102+ Global* stackLimit,
103+ Name handler) {
104+ // Add a local to store the value of the expression. We need the value twice:
105+ // once to check if it has overflowed, and again to assign to store it.
106+ auto newSP = Builder::addVar (func, stackPointer->type );
107+ // (if (i32.lt_u (local.tee $newSP (...value...)) (global.get $__stack_limit))
108+ // (call $handler))
109+ auto check =
110+ builder.makeIf (builder.makeBinary (
111+ BinaryOp::LtUInt32,
112+ builder.makeLocalTee (newSP, value),
113+ builder.makeGlobalGet (stackLimit->name , stackLimit->type )),
114+ builder.makeCall (handler, {}, none));
115+ // (global.set $__stack_pointer (local.get $newSP))
116+ auto newSet = builder.makeGlobalSet (
117+ stackPointer->name , builder.makeLocalGet (newSP, stackPointer->type ));
118+ return builder.blockify (check, newSet);
119+ }
120+
95121Expression*
96- EmscriptenGlueGenerator::generateStoreStackPointer (Expression* value) {
122+ EmscriptenGlueGenerator::generateStoreStackPointer (Function* func,
123+ Expression* value) {
97124 if (!useStackPointerGlobal) {
98125 return builder.makeStore (
99126 /* bytes =*/ 4 ,
@@ -107,6 +134,14 @@ EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) {
107134 if (!stackPointer) {
108135 Fatal () << " stack pointer global not found" ;
109136 }
137+ if (auto * stackLimit = wasm.getGlobalOrNull (STACK_LIMIT)) {
138+ return stackBoundsCheck (builder,
139+ func,
140+ value,
141+ stackPointer,
142+ stackLimit,
143+ importStackOverflowHandler ());
144+ }
110145 return builder.makeGlobalSet (stackPointer->name , value);
111146}
112147
@@ -132,7 +167,7 @@ void EmscriptenGlueGenerator::generateStackAllocFunction() {
132167 Const* subConst = builder.makeConst (Literal (~bitMask));
133168 Binary* maskedSub = builder.makeBinary (AndInt32, sub, subConst);
134169 LocalSet* teeStackLocal = builder.makeLocalTee (1 , maskedSub);
135- Expression* storeStack = generateStoreStackPointer (teeStackLocal);
170+ Expression* storeStack = generateStoreStackPointer (function, teeStackLocal);
136171
137172 Block* block = builder.makeBlock ();
138173 block->list .push_back (storeStack);
@@ -149,7 +184,7 @@ void EmscriptenGlueGenerator::generateStackRestoreFunction() {
149184 Function* function =
150185 builder.makeFunction (STACK_RESTORE, std::move (params), none, {});
151186 LocalGet* getArg = builder.makeLocalGet (0 , i32 );
152- Expression* store = generateStoreStackPointer (getArg);
187+ Expression* store = generateStoreStackPointer (function, getArg);
153188
154189 function->body = store;
155190
@@ -444,6 +479,86 @@ void EmscriptenGlueGenerator::replaceStackPointerGlobal() {
444479 wasm.removeGlobal (stackPointer->name );
445480}
446481
482+ struct StackLimitEnforcer : public WalkerPass <PostWalker<StackLimitEnforcer>> {
483+ StackLimitEnforcer (Global* stackPointer,
484+ Global* stackLimit,
485+ Builder& builder,
486+ Name handler)
487+ : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
488+ handler (handler) {}
489+
490+ bool isFunctionParallel () override { return true ; }
491+
492+ Pass* create () override {
493+ return new StackLimitEnforcer (stackPointer, stackLimit, builder, handler);
494+ }
495+
496+ void visitGlobalSet (GlobalSet* curr) {
497+ if (getModule ()->getGlobalOrNull (curr->name ) == stackPointer) {
498+ replaceCurrent (stackBoundsCheck (builder,
499+ getFunction (),
500+ curr->value ,
501+ stackPointer,
502+ stackLimit,
503+ handler));
504+ }
505+ }
506+
507+ private:
508+ Global* stackPointer;
509+ Global* stackLimit;
510+ Builder& builder;
511+ Name handler;
512+ };
513+
514+ void EmscriptenGlueGenerator::enforceStackLimit () {
515+ Global* stackPointer = getStackPointerGlobal ();
516+ if (!stackPointer) {
517+ return ;
518+ }
519+
520+ auto * stackLimit = builder.makeGlobal (STACK_LIMIT,
521+ stackPointer->type ,
522+ builder.makeConst (Literal (0 )),
523+ Builder::Mutable);
524+ wasm.addGlobal (stackLimit);
525+
526+ auto handler = importStackOverflowHandler ();
527+
528+ StackLimitEnforcer walker (stackPointer, stackLimit, builder, handler);
529+ PassRunner runner (&wasm);
530+ walker.run (&runner, &wasm);
531+
532+ generateSetStackLimitFunction ();
533+ }
534+
535+ void EmscriptenGlueGenerator::generateSetStackLimitFunction () {
536+ Function* function =
537+ builder.makeFunction (SET_STACK_LIMIT, std::vector<Type>({i32 }), none, {});
538+ LocalGet* getArg = builder.makeLocalGet (0 , i32 );
539+ Expression* store = builder.makeGlobalSet (STACK_LIMIT, getArg);
540+ function->body = store;
541+ addExportedFunction (wasm, function);
542+ }
543+
544+ Name EmscriptenGlueGenerator::importStackOverflowHandler () {
545+ ImportInfo info (wasm);
546+
547+ if (auto * existing = info.getImportedFunction (ENV, STACK_OVERFLOW_IMPORT)) {
548+ return existing->name ;
549+ } else {
550+ auto * import = new Function;
551+ import ->name = STACK_OVERFLOW_IMPORT;
552+ import ->module = ENV;
553+ import ->base = STACK_OVERFLOW_IMPORT;
554+ auto * functionType = ensureFunctionType (" v" , &wasm);
555+ import ->type = functionType->name ;
556+ FunctionTypeUtils::fillFunction (import , functionType);
557+ wasm.addFunction (import );
558+ return STACK_OVERFLOW_IMPORT;
559+ }
560+ }
561+
447562const Address UNKNOWN_OFFSET (uint32_t (-1 ));
448563
449564std::vector<Address> getSegmentOffsets (Module& wasm) {
0 commit comments