Skip to content

Commit 51f2694

Browse files
committed
refactor and improve break validation. breaks names are unique, so we don't need a stack, and break targets must exist even if they are not actually taken
1 parent 18e096c commit 51f2694

4 files changed

Lines changed: 33 additions & 18 deletions

File tree

src/ast/branch-utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@ inline bool isBranchTaken(Switch* sw) {
3737
sw->condition->type != unreachable;
3838
}
3939

40+
inline bool isBranchTaken(Expression* expr) {
41+
if (auto* br = expr->dynCast<Break>()) {
42+
return isBranchTaken(br);
43+
} else if (auto* sw = expr->dynCast<Switch>()) {
44+
return isBranchTaken(sw);
45+
}
46+
WASM_UNREACHABLE();
47+
}
48+
4049
// returns the set of targets to which we branch that are
4150
// outside of a node
4251
inline std::set<Name> getExitingBranches(Expression* ast) {

src/binaryen-c.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "wasm-s-parser.h"
3131
#include "wasm-validator.h"
3232
#include "cfg/Relooper.h"
33+
#include "ast_utils.h"
3334
#include "shell-interface.h"
3435

3536
using namespace wasm;

src/wasm-validator.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ struct WasmValidator : public PostWalker<WasmValidator> {
5858
BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {}
5959
};
6060

61-
std::map<Name, std::vector<Expression*>> breakTargets; // more than one block/loop may use a label name, so stack them
61+
std::map<Name, Expression*> breakTargets;
6262
std::map<Expression*, BreakInfo> breakInfos;
63+
std::set<Name> namedBreakTargets; // even breaks not taken must not be named if they go to a place that does not exist
6364

6465
WasmType returnType = unreachable; // type used in returns
6566

@@ -91,14 +92,14 @@ struct WasmValidator : public PostWalker<WasmValidator> {
9192

9293
static void visitPreBlock(WasmValidator* self, Expression** currp) {
9394
auto* curr = (*currp)->cast<Block>();
94-
if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
95+
if (curr->name.is()) self->breakTargets[curr->name] = curr;
9596
}
9697

9798
void visitBlock(Block *curr);
9899

99100
static void visitPreLoop(WasmValidator* self, Expression** currp) {
100101
auto* curr = (*currp)->cast<Loop>();
101-
if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
102+
if (curr->name.is()) self->breakTargets[curr->name] = curr;
102103
}
103104

104105
void visitLoop(Loop *curr);

src/wasm/wasm-validator.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ void WasmValidator::visitBlock(Block *curr) {
5757
}
5858
}
5959
}
60-
breakTargets[curr->name].pop_back();
60+
breakTargets.erase(curr->name);
61+
namedBreakTargets.erase(curr->name);
6162
}
6263
if (curr->list.size() > 1) {
6364
for (Index i = 0; i < curr->list.size() - 1; i++) {
@@ -88,7 +89,8 @@ void WasmValidator::visitBlock(Block *curr) {
8889
void WasmValidator::visitLoop(Loop *curr) {
8990
if (curr->name.is()) {
9091
noteLabelName(curr->name);
91-
breakTargets[curr->name].pop_back();
92+
breakTargets.erase(curr->name);
93+
namedBreakTargets.erase(curr->name);
9294
if (breakInfos.count(curr) > 0) {
9395
auto& info = breakInfos[curr];
9496
shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value");
@@ -120,15 +122,20 @@ void WasmValidator::visitIf(If *curr) {
120122
}
121123

122124
void WasmValidator::noteBreak(Name name, Expression* value, Expression* curr) {
125+
if (!BranchUtils::isBranchTaken(curr)) {
126+
// if not actually taken, just note the name
127+
namedBreakTargets.insert(name);
128+
return;
129+
}
123130
WasmType valueType = none;
124131
Index arity = 0;
125132
if (value) {
126133
valueType = value->type;
127134
shouldBeUnequal(valueType, none, curr, "breaks must have a valid value");
128135
arity = 1;
129136
}
130-
if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return;
131-
auto* target = breakTargets[name].back();
137+
if (!shouldBeTrue(breakTargets.count(name) > 0, curr, "all break targets must be valid")) return;
138+
auto* target = breakTargets[name];
132139
if (breakInfos.count(target) == 0) {
133140
breakInfos[target] = BreakInfo(valueType, arity);
134141
} else {
@@ -146,23 +153,17 @@ void WasmValidator::noteBreak(Name name, Expression* value, Expression* curr) {
146153
}
147154
}
148155
void WasmValidator::visitBreak(Break *curr) {
149-
// note breaks (that are actually taken)
150-
if (BranchUtils::isBranchTaken(curr)) {
151-
noteBreak(curr->name, curr->value, curr);
152-
}
156+
noteBreak(curr->name, curr->value, curr);
153157
if (curr->condition) {
154158
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32");
155159
}
156160
}
157161

158162
void WasmValidator::visitSwitch(Switch *curr) {
159-
// note breaks (that are actually taken)
160-
if (BranchUtils::isBranchTaken(curr)) {
161-
for (auto& target : curr->targets) {
162-
noteBreak(target, curr->value, curr);
163-
}
164-
noteBreak(curr->default_, curr->value, curr);
163+
for (auto& target : curr->targets) {
164+
noteBreak(target, curr->value, curr);
165165
}
166+
noteBreak(curr->default_, curr->value, curr);
166167
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
167168
}
168169
void WasmValidator::visitCall(Call *curr) {
@@ -467,7 +468,7 @@ void WasmValidator::visitGlobal(Global* curr) {
467468
shouldBeTrue(curr->init != nullptr, curr->name, "global init must be non-null");
468469
shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid");
469470
if (!shouldBeEqual(curr->type, curr->init->type, curr->init, "global init must have correct type")) {
470-
std::cerr << "(on global " << curr->name << '\n';
471+
std::cerr << "(on global " << curr->name << ")\n";
471472
}
472473
}
473474

@@ -480,6 +481,9 @@ void WasmValidator::visitFunction(Function *curr) {
480481
if (returnType != unreachable) {
481482
shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function has returns");
482483
}
484+
if (!shouldBeTrue(namedBreakTargets.empty(), curr->body, "all named break targets must exist (even if not taken)")) {
485+
std::cerr << "(on label " << *namedBreakTargets.begin() << ")\n";
486+
}
483487
returnType = unreachable;
484488
labelNames.clear();
485489
}

0 commit comments

Comments
 (0)