Skip to content

Commit 059e6e3

Browse files
authored
Fix atomics refinalization (we were missing some glue) (#1241)
And add a visitor which must override all its elements, so this never happens again
1 parent de4b36f commit 059e6e3

2 files changed

Lines changed: 118 additions & 2 deletions

File tree

src/ast_utils.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ struct ExpressionAnalyzer {
7676
// vs
7777
// (block (unreachable))
7878
// This converts to the latter form.
79-
struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
79+
struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<ReFinalize>>> {
8080
bool isFunctionParallel() override { return true; }
8181

8282
Pass* create() override { return new ReFinalize; }
@@ -154,6 +154,8 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
154154
void visitStore(Store *curr) { curr->finalize(); }
155155
void visitAtomicRMW(AtomicRMW *curr) { curr->finalize(); }
156156
void visitAtomicCmpxchg(AtomicCmpxchg *curr) { curr->finalize(); }
157+
void visitAtomicWait(AtomicWait* curr) { curr->finalize(); }
158+
void visitAtomicWake(AtomicWake* curr) { curr->finalize(); }
157159
void visitConst(Const *curr) { curr->finalize(); }
158160
void visitUnary(Unary *curr) { curr->finalize(); }
159161
void visitBinary(Binary *curr) { curr->finalize(); }
@@ -173,6 +175,14 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
173175
}
174176
}
175177

178+
void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); }
179+
void visitImport(Import* curr) { WASM_UNREACHABLE(); }
180+
void visitExport(Export* curr) { WASM_UNREACHABLE(); }
181+
void visitGlobal(Global* curr) { WASM_UNREACHABLE(); }
182+
void visitTable(Table* curr) { WASM_UNREACHABLE(); }
183+
void visitMemory(Memory* curr) { WASM_UNREACHABLE(); }
184+
void visitModule(Module* curr) { WASM_UNREACHABLE(); }
185+
176186
WasmType getValueType(Expression* value) {
177187
return value ? value->type : none;
178188
}
@@ -186,7 +196,7 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
186196

187197
// Re-finalize a single node. This is slow, if you want to refinalize
188198
// an entire ast, use ReFinalize
189-
struct ReFinalizeNode : public Visitor<ReFinalizeNode> {
199+
struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> {
190200
void visitBlock(Block *curr) { curr->finalize(); }
191201
void visitIf(If *curr) { curr->finalize(); }
192202
void visitLoop(Loop *curr) { curr->finalize(); }
@@ -201,6 +211,10 @@ struct ReFinalizeNode : public Visitor<ReFinalizeNode> {
201211
void visitSetGlobal(SetGlobal *curr) { curr->finalize(); }
202212
void visitLoad(Load *curr) { curr->finalize(); }
203213
void visitStore(Store *curr) { curr->finalize(); }
214+
void visitAtomicRMW(AtomicRMW* curr) { curr->finalize(); }
215+
void visitAtomicCmpxchg(AtomicCmpxchg* curr) { curr->finalize(); }
216+
void visitAtomicWait(AtomicWait* curr) { curr->finalize(); }
217+
void visitAtomicWake(AtomicWake* curr) { curr->finalize(); }
204218
void visitConst(Const *curr) { curr->finalize(); }
205219
void visitUnary(Unary *curr) { curr->finalize(); }
206220
void visitBinary(Binary *curr) { curr->finalize(); }
@@ -211,6 +225,14 @@ struct ReFinalizeNode : public Visitor<ReFinalizeNode> {
211225
void visitNop(Nop *curr) { curr->finalize(); }
212226
void visitUnreachable(Unreachable *curr) { curr->finalize(); }
213227

228+
void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); }
229+
void visitImport(Import* curr) { WASM_UNREACHABLE(); }
230+
void visitExport(Export* curr) { WASM_UNREACHABLE(); }
231+
void visitGlobal(Global* curr) { WASM_UNREACHABLE(); }
232+
void visitTable(Table* curr) { WASM_UNREACHABLE(); }
233+
void visitMemory(Memory* curr) { WASM_UNREACHABLE(); }
234+
void visitModule(Module* curr) { WASM_UNREACHABLE(); }
235+
214236
// given a stack of nested expressions, update them all from child to parent
215237
static void updateStack(std::vector<Expression*>& expressionStack) {
216238
for (int i = int(expressionStack.size()) - 1; i >= 0; i--) {

src/wasm-traversal.h

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
namespace wasm {
3434

35+
// A generic visitor, defaulting to doing nothing on each visit
36+
3537
template<typename SubType, typename ReturnType = void>
3638
struct Visitor {
3739
// Expression visitors
@@ -115,6 +117,98 @@ struct Visitor {
115117
}
116118
};
117119

120+
// A visitor which must be overridden for each visitor that is reached.
121+
122+
template<typename SubType, typename ReturnType = void>
123+
struct OverriddenVisitor {
124+
// Expression visitors, which must be overridden
125+
#define UNIMPLEMENTED(CLASS_TO_VISIT) \
126+
ReturnType visit##CLASS_TO_VISIT(CLASS_TO_VISIT* curr) { \
127+
static_assert(&SubType::visit##CLASS_TO_VISIT != &OverriddenVisitor<SubType, ReturnType>::visit##CLASS_TO_VISIT, "Derived class must implement visit" #CLASS_TO_VISIT); \
128+
WASM_UNREACHABLE(); \
129+
}
130+
131+
UNIMPLEMENTED(Block);
132+
UNIMPLEMENTED(If);
133+
UNIMPLEMENTED(Loop);
134+
UNIMPLEMENTED(Break);
135+
UNIMPLEMENTED(Switch);
136+
UNIMPLEMENTED(Call);
137+
UNIMPLEMENTED(CallImport);
138+
UNIMPLEMENTED(CallIndirect);
139+
UNIMPLEMENTED(GetLocal);
140+
UNIMPLEMENTED(SetLocal);
141+
UNIMPLEMENTED(GetGlobal);
142+
UNIMPLEMENTED(SetGlobal);
143+
UNIMPLEMENTED(Load);
144+
UNIMPLEMENTED(Store);
145+
UNIMPLEMENTED(AtomicRMW);
146+
UNIMPLEMENTED(AtomicCmpxchg);
147+
UNIMPLEMENTED(AtomicWait);
148+
UNIMPLEMENTED(AtomicWake);
149+
UNIMPLEMENTED(Const);
150+
UNIMPLEMENTED(Unary);
151+
UNIMPLEMENTED(Binary);
152+
UNIMPLEMENTED(Select);
153+
UNIMPLEMENTED(Drop);
154+
UNIMPLEMENTED(Return);
155+
UNIMPLEMENTED(Host);
156+
UNIMPLEMENTED(Nop);
157+
UNIMPLEMENTED(Unreachable);
158+
UNIMPLEMENTED(FunctionType);
159+
UNIMPLEMENTED(Import);
160+
UNIMPLEMENTED(Export);
161+
UNIMPLEMENTED(Global);
162+
UNIMPLEMENTED(Function);
163+
UNIMPLEMENTED(Table);
164+
UNIMPLEMENTED(Memory);
165+
UNIMPLEMENTED(Module);
166+
167+
#undef UNIMPLEMENTED
168+
169+
ReturnType visit(Expression* curr) {
170+
assert(curr);
171+
172+
#define DELEGATE(CLASS_TO_VISIT) \
173+
return static_cast<SubType*>(this)-> \
174+
visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr))
175+
176+
switch (curr->_id) {
177+
case Expression::Id::BlockId: DELEGATE(Block);
178+
case Expression::Id::IfId: DELEGATE(If);
179+
case Expression::Id::LoopId: DELEGATE(Loop);
180+
case Expression::Id::BreakId: DELEGATE(Break);
181+
case Expression::Id::SwitchId: DELEGATE(Switch);
182+
case Expression::Id::CallId: DELEGATE(Call);
183+
case Expression::Id::CallImportId: DELEGATE(CallImport);
184+
case Expression::Id::CallIndirectId: DELEGATE(CallIndirect);
185+
case Expression::Id::GetLocalId: DELEGATE(GetLocal);
186+
case Expression::Id::SetLocalId: DELEGATE(SetLocal);
187+
case Expression::Id::GetGlobalId: DELEGATE(GetGlobal);
188+
case Expression::Id::SetGlobalId: DELEGATE(SetGlobal);
189+
case Expression::Id::LoadId: DELEGATE(Load);
190+
case Expression::Id::StoreId: DELEGATE(Store);
191+
case Expression::Id::AtomicRMWId: DELEGATE(AtomicRMW);
192+
case Expression::Id::AtomicCmpxchgId: DELEGATE(AtomicCmpxchg);
193+
case Expression::Id::AtomicWaitId: DELEGATE(AtomicWait);
194+
case Expression::Id::AtomicWakeId: DELEGATE(AtomicWake);
195+
case Expression::Id::ConstId: DELEGATE(Const);
196+
case Expression::Id::UnaryId: DELEGATE(Unary);
197+
case Expression::Id::BinaryId: DELEGATE(Binary);
198+
case Expression::Id::SelectId: DELEGATE(Select);
199+
case Expression::Id::DropId: DELEGATE(Drop);
200+
case Expression::Id::ReturnId: DELEGATE(Return);
201+
case Expression::Id::HostId: DELEGATE(Host);
202+
case Expression::Id::NopId: DELEGATE(Nop);
203+
case Expression::Id::UnreachableId: DELEGATE(Unreachable);
204+
case Expression::Id::InvalidId:
205+
default: WASM_UNREACHABLE();
206+
}
207+
208+
#undef DELEGATE
209+
}
210+
};
211+
118212
// Visit with a single unified visitor, called on every node, instead of
119213
// separate visit* per node
120214

0 commit comments

Comments
 (0)