Skip to content

Commit ab8dbae

Browse files
authored
Optimizer support for atomic instructions (#1094)
* Teach EffectAnalyzer not to reorder atomics with respect to other memory operations.
* Teach EffectAnalyzer not to reorder host operations with memory operations.
* Teach various passes about the operands of AtomicRMW and AtomicCmpxchg.
* Factor out some functions in DeadCodeElimination and MergeBlocks.
1 parent da680fd commit ab8dbae

13 files changed

Lines changed: 414 additions & 77 deletions

src/ast/ExpressionManipulator.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,27 @@ Expression* flexibleCopy(Expression* original, Module& wasm, CustomCopier custom
9696
return builder.makeSetGlobal(curr->name, copy(curr->value));
9797
}
9898
Expression* visitLoad(Load *curr) {
99+
if (curr->isAtomic) {
100+
return builder.makeAtomicLoad(curr->bytes, curr->signed_, curr->offset,
101+
copy(curr->ptr), curr->type);
102+
}
99103
return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type);
100104
}
101105
Expression* visitStore(Store *curr) {
106+
if (curr->isAtomic) {
107+
return builder.makeAtomicStore(curr->bytes, curr->offset, copy(curr->ptr), copy(curr->value), curr->valueType);
108+
}
102109
return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType);
103110
}
111+
Expression* visitAtomicRMW(AtomicRMW* curr) {
112+
return builder.makeAtomicRMW(curr->op, curr->bytes, curr->offset,
113+
copy(curr->ptr), copy(curr->value), curr->type);
114+
}
115+
Expression* visitAtomicCmpxchg(AtomicCmpxchg* curr) {
116+
return builder.makeAtomicCmpxchg(curr->bytes, curr->offset,
117+
copy(curr->ptr), copy(curr->expected), copy(curr->replacement),
118+
curr->type);
119+
}
104120
Expression* visitConst(Const *curr) {
105121
return builder.makeConst(curr->value);
106122
}

src/ast/cost.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,16 @@ struct CostAnalyzer : public Visitor<CostAnalyzer, Index> {
7878
return 2;
7979
}
8080
Index visitLoad(Load *curr) {
81-
return 1 + visit(curr->ptr);
81+
return 1 + visit(curr->ptr) + 10 * curr->isAtomic;
8282
}
8383
Index visitStore(Store *curr) {
84-
return 2 + visit(curr->ptr) + visit(curr->value);
84+
return 2 + visit(curr->ptr) + visit(curr->value) + 10 * curr->isAtomic;
85+
}
86+
Index visitAtomicRMW(AtomicRMW *curr) {
87+
return 100;
88+
}
89+
Index visitAtomicCmpxchg(AtomicCmpxchg* curr) {
90+
return 100;
8591
}
8692
Index visitConst(Const *curr) {
8793
return 1;

src/ast/effects.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,14 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer> {
5353
// (so a trap may occur later or earlier, if it is
5454
// going to occur anyhow), but we can't remove them,
5555
// they count as side effects
56+
bool isAtomic = false; // An atomic load/store/RMW/Cmpxchg or an operator that
57+
// has a defined ordering wrt atomics (e.g. grow_memory)
5658

5759
bool accessesLocal() { return localsRead.size() + localsWritten.size() > 0; }
5860
bool accessesGlobal() { return globalsRead.size() + globalsWritten.size() > 0; }
5961
bool accessesMemory() { return calls || readsMemory || writesMemory; }
60-
bool hasSideEffects() { return calls || localsWritten.size() > 0 || writesMemory || branches || globalsWritten.size() > 0 || implicitTrap; }
61-
bool hasAnything() { return branches || calls || accessesLocal() || readsMemory || writesMemory || accessesGlobal() || implicitTrap; }
62+
bool hasSideEffects() { return calls || localsWritten.size() > 0 || writesMemory || branches || globalsWritten.size() > 0 || implicitTrap || isAtomic; }
63+
bool hasAnything() { return branches || calls || accessesLocal() || readsMemory || writesMemory || accessesGlobal() || implicitTrap || isAtomic; }
6264

6365
// checks if these effects would invalidate another set (e.g., if we write, we invalidate someone that reads, they can't be moved past us)
6466
bool invalidates(EffectAnalyzer& other) {
@@ -67,6 +69,12 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer> {
6769
|| (accessesMemory() && (other.writesMemory || other.calls))) {
6870
return true;
6971
}
72+
// All atomics are sequentially consistent for now, and ordered wrt other
73+
// memory references.
74+
if ((isAtomic && other.accessesMemory()) ||
75+
(other.isAtomic && accessesMemory())) {
76+
return true;
77+
}
7078
for (auto local : localsWritten) {
7179
if (other.localsWritten.count(local) || other.localsRead.count(local)) {
7280
return true;
@@ -176,10 +184,24 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer> {
176184
}
177185
void visitLoad(Load *curr) {
178186
readsMemory = true;
187+
isAtomic |= curr->isAtomic;
179188
if (!ignoreImplicitTraps) implicitTrap = true;
180189
}
181190
void visitStore(Store *curr) {
182191
writesMemory = true;
192+
isAtomic |= curr->isAtomic;
193+
if (!ignoreImplicitTraps) implicitTrap = true;
194+
}
195+
void visitAtomicRMW(AtomicRMW* curr) {
196+
readsMemory = true;
197+
writesMemory = true;
198+
isAtomic = true;
199+
if (!ignoreImplicitTraps) implicitTrap = true;
200+
}
201+
void visitAtomicCmpxchg(AtomicCmpxchg* curr) {
202+
readsMemory = true;
203+
writesMemory = true;
204+
isAtomic = true;
183205
if (!ignoreImplicitTraps) implicitTrap = true;
184206
}
185207
void visitUnary(Unary *curr) {
@@ -219,11 +241,16 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer> {
219241
}
220242
}
221243
void visitReturn(Return *curr) { branches = true; }
222-
void visitHost(Host *curr) { calls = true; }
244+
void visitHost(Host *curr) {
245+
calls = true;
246+
// grow_memory modifies the set of valid addresses, and thus can be modeled as modifying memory
247+
writesMemory = true;
248+
// Atomics are also sequentially consistent with grow_memory.
249+
isAtomic = true;
250+
}
223251
void visitUnreachable(Unreachable *curr) { branches = true; }
224252
};
225253

226254
} // namespace wasm
227255

228256
#endif // wasm_ast_effects_h
229-

src/ast_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
154154
void visitSetGlobal(SetGlobal *curr) { curr->finalize(); }
155155
void visitLoad(Load *curr) { curr->finalize(); }
156156
void visitStore(Store *curr) { curr->finalize(); }
157+
void visitAtomicRMW(AtomicRMW *curr) { curr->finalize(); }
158+
void visitAtomicCmpxchg(AtomicCmpxchg *curr) { curr->finalize(); }
157159
void visitConst(Const *curr) { curr->finalize(); }
158160
void visitUnary(Unary *curr) { curr->finalize(); }
159161
void visitBinary(Binary *curr) { curr->finalize(); }

src/passes/DeadCodeElimination.cpp

Lines changed: 38 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
// have no side effects.
2929
//
3030

31+
#include <vector>
3132
#include <wasm.h>
3233
#include <pass.h>
3334
#include <wasm-builder.h>
@@ -321,84 +322,62 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination>>
321322
}
322323
}
323324

324-
void visitSetLocal(SetLocal* curr) {
325-
if (isUnreachable(curr->value)) {
326-
replaceCurrent(curr->value);
325+
// If an operand of the current node is unreachable, replace the node with the
326+
// earlier operands (dropped) followed by that operand, in a block if any were dropped
327+
void blockifyReachableOperands(std::vector<Expression*>&& list, WasmType type) {
328+
for (size_t i = 0; i < list.size(); ++i) {
329+
auto* elem = list[i];
330+
if (isUnreachable(elem)) {
331+
auto* replacement = elem;
332+
if (i > 0) {
333+
auto* block = getModule()->allocator.alloc<Block>();
334+
for (size_t j = 0; j < i; ++j) {
335+
block->list.push_back(drop(list[j]));
336+
}
337+
block->list.push_back(list[i]);
338+
block->finalize(type);
339+
replacement = block;
340+
}
341+
replaceCurrent(replacement);
342+
return;
343+
}
327344
}
328345
}
329346

347+
void visitSetLocal(SetLocal* curr) {
348+
blockifyReachableOperands({ curr->value }, curr->type);
349+
}
350+
330351
void visitLoad(Load* curr) {
331-
if (isUnreachable(curr->ptr)) {
332-
replaceCurrent(curr->ptr);
333-
}
352+
blockifyReachableOperands({ curr->ptr}, curr->type);
334353
}
335354

336355
void visitStore(Store* curr) {
337-
if (isUnreachable(curr->ptr)) {
338-
replaceCurrent(curr->ptr);
339-
return;
340-
}
341-
if (isUnreachable(curr->value)) {
342-
auto* block = getModule()->allocator.alloc<Block>();
343-
block->list.resize(2);
344-
block->list[0] = drop(curr->ptr);
345-
block->list[1] = curr->value;
346-
block->finalize(curr->type);
347-
replaceCurrent(block);
348-
}
356+
blockifyReachableOperands({ curr->ptr, curr->value }, curr->type);
357+
}
358+
359+
void visitAtomicRMW(AtomicRMW* curr) {
360+
blockifyReachableOperands({ curr->ptr, curr->value }, curr->type);
361+
}
362+
363+
void visitAtomicCmpxchg(AtomicCmpxchg* curr) {
364+
blockifyReachableOperands({ curr->ptr, curr->expected, curr->replacement }, curr->type);
349365
}
350366

351367
void visitUnary(Unary* curr) {
352-
if (isUnreachable(curr->value)) {
353-
replaceCurrent(curr->value);
354-
}
368+
blockifyReachableOperands({ curr->value }, curr->type);
355369
}
356370

357371
void visitBinary(Binary* curr) {
358-
if (isUnreachable(curr->left)) {
359-
replaceCurrent(curr->left);
360-
return;
361-
}
362-
if (isUnreachable(curr->right)) {
363-
auto* block = getModule()->allocator.alloc<Block>();
364-
block->list.resize(2);
365-
block->list[0] = drop(curr->left);
366-
block->list[1] = curr->right;
367-
block->finalize(curr->type);
368-
replaceCurrent(block);
369-
}
372+
blockifyReachableOperands({ curr->left, curr->right}, curr->type);
370373
}
371374

372375
void visitSelect(Select* curr) {
373-
if (isUnreachable(curr->ifTrue)) {
374-
replaceCurrent(curr->ifTrue);
375-
return;
376-
}
377-
if (isUnreachable(curr->ifFalse)) {
378-
auto* block = getModule()->allocator.alloc<Block>();
379-
block->list.resize(2);
380-
block->list[0] = drop(curr->ifTrue);
381-
block->list[1] = curr->ifFalse;
382-
block->finalize(curr->type);
383-
replaceCurrent(block);
384-
return;
385-
}
386-
if (isUnreachable(curr->condition)) {
387-
auto* block = getModule()->allocator.alloc<Block>();
388-
block->list.resize(3);
389-
block->list[0] = drop(curr->ifTrue);
390-
block->list[1] = drop(curr->ifFalse);
391-
block->list[2] = curr->condition;
392-
block->finalize(curr->type);
393-
replaceCurrent(block);
394-
return;
395-
}
376+
blockifyReachableOperands({ curr->ifTrue, curr->ifFalse, curr->condition}, curr->type);
396377
}
397378

398379
void visitDrop(Drop* curr) {
399-
if (isUnreachable(curr->value)) {
400-
replaceCurrent(curr->value);
401-
}
380+
blockifyReachableOperands({ curr->value }, curr->type);
402381
}
403382

404383
void visitHost(Host* curr) {
@@ -415,4 +394,3 @@ Pass *createDeadCodeEliminationPass() {
415394
}
416395

417396
} // namespace wasm
418-

src/passes/InstrumentMemory.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ namespace wasm {
6666

6767
Name load("load");
6868
Name store("store");
69+
// TODO: Add support for atomicRMW/cmpxchg
6970

7071
struct InstrumentMemory : public WalkerPass<PostWalker<InstrumentMemory>> {
7172
void visitLoad(Load* curr) {

src/passes/MergeBlocks.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -286,17 +286,27 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> {
286286
void visitStore(Store* curr) {
287287
optimize(curr, curr->value, optimize(curr, curr->ptr), &curr->ptr);
288288
}
289-
290-
void visitSelect(Select* curr) {
289+
void visitAtomicRMW(AtomicRMW* curr) {
290+
optimize(curr, curr->value, optimize(curr, curr->ptr), &curr->ptr);
291+
}
292+
void optimizeTernary(Expression* curr,
293+
Expression*& first, Expression*& second, Expression*& third) {
291294
// TODO: for now, just stop when we see any side effect. instead, we could
292295
// check effects carefully for reordering
293296
Block* outer = nullptr;
294-
if (EffectAnalyzer(getPassOptions(), curr->ifTrue).hasSideEffects()) return;
295-
outer = optimize(curr, curr->ifTrue, outer);
296-
if (EffectAnalyzer(getPassOptions(), curr->ifFalse).hasSideEffects()) return;
297-
outer = optimize(curr, curr->ifFalse, outer);
298-
if (EffectAnalyzer(getPassOptions(), curr->condition).hasSideEffects()) return;
299-
optimize(curr, curr->condition, outer);
297+
if (EffectAnalyzer(getPassOptions(), first).hasSideEffects()) return;
298+
outer = optimize(curr, first, outer);
299+
if (EffectAnalyzer(getPassOptions(), second).hasSideEffects()) return;
300+
outer = optimize(curr, second, outer);
301+
if (EffectAnalyzer(getPassOptions(), third).hasSideEffects()) return;
302+
optimize(curr, third, outer);
303+
}
304+
void visitAtomicCmpxchg(AtomicCmpxchg* curr) {
305+
optimizeTernary(curr, curr->ptr, curr->expected, curr->replacement);
306+
}
307+
308+
void visitSelect(Select* curr) {
309+
optimizeTernary(curr, curr->ifTrue, curr->ifFalse, curr->condition);
300310
}
301311

302312
void visitDrop(Drop* curr) {
@@ -344,4 +354,3 @@ Pass *createMergeBlocksPass() {
344354
}
345355

346356
} // namespace wasm
347-

src/passes/Precompute.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ class StandaloneExpressionRunner : public ExpressionRunner<StandaloneExpressionR
6767
Flow visitStore(Store *curr) {
6868
return Flow(NONSTANDALONE_FLOW);
6969
}
70+
Flow visitAtomicRMW(AtomicRMW *curr) {
71+
return Flow(NONSTANDALONE_FLOW);
72+
}
73+
Flow visitAtomicCmpxchg(AtomicCmpxchg *curr) {
74+
return Flow(NONSTANDALONE_FLOW);
75+
}
7076
Flow visitHost(Host *curr) {
7177
return Flow(NONSTANDALONE_FLOW);
7278
}

src/wasm-builder.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ class Builder {
193193
ret->type = type;
194194
return ret;
195195
}
196+
Load* makeAtomicLoad(unsigned bytes, bool signed_, uint32_t offset, Expression* ptr, WasmType type) {
197+
Load* load = makeLoad(bytes, signed_, offset, getWasmTypeSize(type), ptr, type);
198+
load->isAtomic = true;
199+
return load;
200+
}
196201
Store* makeStore(unsigned bytes, uint32_t offset, unsigned align, Expression *ptr, Expression *value, WasmType type) {
197202
auto* ret = allocator.alloc<Store>();
198203
ret->isAtomic = false;
@@ -201,6 +206,36 @@ class Builder {
201206
assert(isConcreteWasmType(ret->value->type) ? ret->value->type == type : true);
202207
return ret;
203208
}
209+
Store* makeAtomicStore(unsigned bytes, uint32_t offset, Expression* ptr, Expression* value, WasmType type) {
210+
Store* store = makeStore(bytes, offset, getWasmTypeSize(type), ptr, value, type);
211+
store->isAtomic = true;
212+
return store;
213+
}
214+
AtomicRMW* makeAtomicRMW(AtomicRMWOp op, unsigned bytes, uint32_t offset,
215+
Expression* ptr, Expression* value, WasmType type) {
216+
auto* ret = allocator.alloc<AtomicRMW>();
217+
ret->op = op;
218+
ret->bytes = bytes;
219+
ret->offset = offset;
220+
ret->ptr = ptr;
221+
ret->value = value;
222+
ret->type = type;
223+
ret->finalize();
224+
return ret;
225+
}
226+
AtomicCmpxchg* makeAtomicCmpxchg(unsigned bytes, uint32_t offset,
227+
Expression* ptr, Expression* expected, Expression* replacement,
228+
WasmType type) {
229+
auto* ret = allocator.alloc<AtomicCmpxchg>();
230+
ret->bytes = bytes;
231+
ret->offset = offset;
232+
ret->ptr = ptr;
233+
ret->expected = expected;
234+
ret->replacement = replacement;
235+
ret->type = type;
236+
ret->finalize();
237+
return ret;
238+
}
204239
Const* makeConst(Literal value) {
205240
assert(isConcreteWasmType(value.type));
206241
auto* ret = allocator.alloc<Const>();

0 commit comments

Comments
 (0)