Skip to content

Commit bcb29e5

Browse files
authored
Add IR, parsing, printing, and binary for atomic cmpxchg (#1083)
1 parent 4995132 commit bcb29e5

12 files changed

Lines changed: 287 additions & 14 deletions

src/passes/Print.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -350,22 +350,25 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
350350
printFullLine(curr->value);
351351
decIndent();
352352
}
353-
void visitAtomicRMW(AtomicRMW* curr) {
354-
o << '(';
355-
prepareColor(o) << printWasmType(curr->type) << ".atomic.rmw";
356-
if (curr->bytes != getWasmTypeSize(curr->type)) {
357-
if (curr->bytes == 1) {
353+
static void printRMWSize(std::ostream& o, WasmType type, uint8_t bytes) {
354+
prepareColor(o) << printWasmType(type) << ".atomic.rmw";
355+
if (bytes != getWasmTypeSize(type)) {
356+
if (bytes == 1) {
358357
o << '8';
359-
} else if (curr->bytes == 2) {
358+
} else if (bytes == 2) {
360359
o << "16";
361-
} else if (curr->bytes == 4) {
360+
} else if (bytes == 4) {
362361
o << "32";
363362
} else {
364363
WASM_UNREACHABLE();
365364
}
366365
o << "_u";
367366
}
368367
o << '.';
368+
}
369+
void visitAtomicRMW(AtomicRMW* curr) {
370+
o << '(';
371+
printRMWSize(o, curr->type, curr->bytes);
369372
switch (curr->op) {
370373
case Add: o << "add"; break;
371374
case Sub: o << "sub"; break;
@@ -383,6 +386,20 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
383386
printFullLine(curr->value);
384387
decIndent();
385388
}
389+
void visitAtomicCmpxchg(AtomicCmpxchg* curr) {
390+
o << '(';
391+
printRMWSize(o, curr->type, curr->bytes);
392+
o << "cmpxchg";
393+
restoreNormalColor(o);
394+
if (curr->offset) {
395+
o << " offset=" << curr->offset;
396+
}
397+
incIndent();
398+
printFullLine(curr->ptr);
399+
printFullLine(curr->expected);
400+
printFullLine(curr->replacement);
401+
decIndent();
402+
}
386403
void visitConst(Const *curr) {
387404
o << curr->value;
388405
}

src/wasm-binary.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,16 @@ enum AtomicOpcodes {
571571
I64AtomicRMWXchg16U = 0x46,
572572
I64AtomicRMWXchg32U = 0x47,
573573
AtomicRMWOps_End = 0x47,
574+
575+
AtomicCmpxchgOps_Begin = 0x48,
576+
I32AtomicCmpxchg = 0x48,
577+
I64AtomicCmpxchg = 0x49,
578+
I32AtomicCmpxchg8U = 0x4a,
579+
I32AtomicCmpxchg16U = 0x4b,
580+
I64AtomicCmpxchg8U = 0x4c,
581+
I64AtomicCmpxchg16U = 0x4d,
582+
I64AtomicCmpxchg32U = 0x4e,
583+
AtomicCmpxchgOps_End = 0x4e
574584
};
575585

576586

@@ -723,6 +733,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
723733
void visitLoad(Load *curr);
724734
void visitStore(Store *curr);
725735
void visitAtomicRMW(AtomicRMW *curr);
736+
void visitAtomicCmpxchg(AtomicCmpxchg *curr);
726737
void visitConst(Const *curr);
727738
void visitUnary(Unary *curr);
728739
void visitBinary(Binary *curr);
@@ -881,6 +892,7 @@ class WasmBinaryBuilder {
881892
bool maybeVisitLoad(Expression*& out, uint8_t code, bool isAtomic);
882893
bool maybeVisitStore(Expression*& out, uint8_t code, bool isAtomic);
883894
bool maybeVisitAtomicRMW(Expression*& out, uint8_t code);
895+
bool maybeVisitAtomicCmpxchg(Expression*& out, uint8_t code);
884896
bool maybeVisitConst(Expression*& out, uint8_t code);
885897
bool maybeVisitUnary(Expression*& out, uint8_t code);
886898
bool maybeVisitBinary(Expression*& out, uint8_t code);

src/wasm-s-parser.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,9 @@ class SExpressionWasmBuilder {
177177
Expression* makeConst(Element& s, WasmType type);
178178
Expression* makeLoad(Element& s, WasmType type, bool isAtomic);
179179
Expression* makeStore(Element& s, WasmType type, bool isAtomic);
180-
Expression* makeAtomicRMW(Element& s, WasmType type);
180+
Expression* makeAtomicRMWOrCmpxchg(Element& s, WasmType type);
181+
Expression* makeAtomicRMW(Element& s, WasmType type, uint8_t bytes, const char* extra);
182+
Expression* makeAtomicCmpxchg(Element& s, WasmType type, uint8_t bytes, const char* extra);
181183
Expression* makeIf(Element& s);
182184
Expression* makeMaybeBlock(Element& s, size_t i, WasmType type);
183185
Expression* makeLoop(Element& s);

src/wasm-traversal.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ struct Visitor {
5050
ReturnType visitLoad(Load* curr) {}
5151
ReturnType visitStore(Store* curr) {}
5252
ReturnType visitAtomicRMW(AtomicRMW* curr) {return ReturnType();} //Stub impl so not every pass has to implement this yet.
53+
ReturnType visitAtomicCmpxchg(AtomicCmpxchg* curr) {return ReturnType();} //Stub impl so not every pass has to implement this yet.
5354
ReturnType visitConst(Const* curr) {}
5455
ReturnType visitUnary(Unary* curr) {}
5556
ReturnType visitBinary(Binary* curr) {}
@@ -92,6 +93,7 @@ struct Visitor {
9293
case Expression::Id::LoadId: DELEGATE(Load);
9394
case Expression::Id::StoreId: DELEGATE(Store);
9495
case Expression::Id::AtomicRMWId: DELEGATE(AtomicRMW);
96+
case Expression::Id::AtomicCmpxchgId: DELEGATE(AtomicCmpxchg);
9597
case Expression::Id::ConstId: DELEGATE(Const);
9698
case Expression::Id::UnaryId: DELEGATE(Unary);
9799
case Expression::Id::BinaryId: DELEGATE(Binary);
@@ -133,6 +135,7 @@ struct UnifiedExpressionVisitor : public Visitor<SubType> {
133135
ReturnType visitLoad(Load* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
134136
ReturnType visitStore(Store* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
135137
ReturnType visitAtomicRMW(AtomicRMW* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
138+
ReturnType visitAtomicCmpxchg(AtomicCmpxchg* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
136139
ReturnType visitConst(Const* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
137140
ReturnType visitUnary(Unary* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
138141
ReturnType visitBinary(Binary* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
@@ -310,6 +313,7 @@ struct Walker : public VisitorType {
310313
static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); }
311314
static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); }
312315
static void doVisitAtomicRMW(SubType* self, Expression** currp) { self->visitAtomicRMW((*currp)->cast<AtomicRMW>()); }
316+
static void doVisitAtomicCmpxchg(SubType* self, Expression** currp){ self->visitAtomicCmpxchg((*currp)->cast<AtomicCmpxchg>()); }
313317
static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); }
314318
static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); }
315319
static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); }
@@ -438,6 +442,13 @@ struct PostWalker : public Walker<SubType, VisitorType> {
438442
self->pushTask(SubType::scan, &curr->cast<AtomicRMW>()->ptr);
439443
break;
440444
}
445+
case Expression::Id::AtomicCmpxchgId: {
446+
self->pushTask(SubType::doVisitAtomicCmpxchg, currp);
447+
self->pushTask(SubType::scan, &curr->cast<AtomicCmpxchg>()->replacement);
448+
self->pushTask(SubType::scan, &curr->cast<AtomicCmpxchg>()->expected);
449+
self->pushTask(SubType::scan, &curr->cast<AtomicCmpxchg>()->ptr);
450+
break;
451+
}
441452
case Expression::Id::ConstId: {
442453
self->pushTask(SubType::doVisitConst, currp);
443454
break;

src/wasm.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ enum HostOp {
131131
};
132132

133133
enum AtomicRMWOp {
134-
Add, Sub, And, Or, Xor, Xchg,
134+
Add, Sub, And, Or, Xor, Xchg
135135
};
136136

137137
//
@@ -445,6 +445,15 @@ class AtomicRMW : public SpecificExpression<Expression::AtomicRMWId> {
445445
class AtomicCmpxchg : public SpecificExpression<Expression::AtomicCmpxchgId> {
446446
public:
447447
AtomicCmpxchg() = default;
448+
AtomicCmpxchg(MixedArena& allocator) : AtomicCmpxchg() {}
449+
450+
uint8_t bytes;
451+
Address offset;
452+
Expression* ptr;
453+
Expression* expected;
454+
Expression* replacement;
455+
456+
void finalize();
448457
};
449458

450459
class Const : public SpecificExpression<Expression::ConstId> {

src/wasm/wasm-binary.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,37 @@ void WasmBinaryWriter::visitAtomicRMW(AtomicRMW *curr) {
877877
emitMemoryAccess(curr->bytes, curr->bytes, curr->offset);
878878
}
879879

880+
void WasmBinaryWriter::visitAtomicCmpxchg(AtomicCmpxchg *curr) {
881+
if (debug) std::cerr << "zz node: AtomicCmpxchg" << std::endl;
882+
recurse(curr->ptr);
883+
recurse(curr->expected);
884+
recurse(curr->replacement);
885+
886+
o << int8_t(BinaryConsts::AtomicPrefix);
887+
switch (curr->type) {
888+
case i32:
889+
switch (curr->bytes) {
890+
case 1: o << int8_t(BinaryConsts::I32AtomicCmpxchg8U); break;
891+
case 2: o << int8_t(BinaryConsts::I32AtomicCmpxchg16U); break;
892+
case 4: o << int8_t(BinaryConsts::I32AtomicCmpxchg); break;
893+
default: WASM_UNREACHABLE();
894+
}
895+
break;
896+
case i64:
897+
switch (curr->bytes) {
898+
case 1: o << int8_t(BinaryConsts::I64AtomicCmpxchg8U); break;
899+
case 2: o << int8_t(BinaryConsts::I64AtomicCmpxchg16U); break;
900+
case 4: o << int8_t(BinaryConsts::I64AtomicCmpxchg32U); break;
901+
case 8: o << int8_t(BinaryConsts::I64AtomicCmpxchg); break;
902+
default: WASM_UNREACHABLE();
903+
}
904+
break;
905+
default: WASM_UNREACHABLE();
906+
}
907+
emitMemoryAccess(curr->bytes, curr->bytes, curr->offset);
908+
}
909+
910+
880911
void WasmBinaryWriter::visitConst(Const *curr) {
881912
if (debug) std::cerr << "zz node: Const" << curr << " : " << curr->type << std::endl;
882913
switch (curr->type) {
@@ -1980,6 +2011,7 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) {
19802011
if (maybeVisitLoad(curr, code, /*isAtomic=*/true)) break;
19812012
if (maybeVisitStore(curr, code, /*isAtomic=*/true)) break;
19822013
if (maybeVisitAtomicRMW(curr, code)) break;
2014+
if (maybeVisitAtomicCmpxchg(curr, code)) break;
19832015
throw ParseException("invalid code after atomic prefix: " + std::to_string(code));
19842016
}
19852017
default: {
@@ -2372,6 +2404,38 @@ bool WasmBinaryBuilder::maybeVisitAtomicRMW(Expression*& out, uint8_t code) {
23722404
return true;
23732405
}
23742406

2407+
bool WasmBinaryBuilder::maybeVisitAtomicCmpxchg(Expression*& out, uint8_t code) {
2408+
if (code < BinaryConsts::AtomicCmpxchgOps_Begin || code > BinaryConsts::AtomicCmpxchgOps_End) return false;
2409+
auto* curr = allocator.alloc<AtomicCmpxchg>();
2410+
2411+
// Set curr to the given type and size.
2412+
#define SET(optype, size) \
2413+
curr->type = optype; \
2414+
curr->bytes = size
2415+
2416+
switch (code) {
2417+
case BinaryConsts::I32AtomicCmpxchg: SET(i32, 4); break;
2418+
case BinaryConsts::I64AtomicCmpxchg: SET(i64, 8); break;
2419+
case BinaryConsts::I32AtomicCmpxchg8U: SET(i32, 1); break;
2420+
case BinaryConsts::I32AtomicCmpxchg16U: SET(i32, 2); break;
2421+
case BinaryConsts::I64AtomicCmpxchg8U: SET(i64, 1); break;
2422+
case BinaryConsts::I64AtomicCmpxchg16U: SET(i64, 2); break;
2423+
case BinaryConsts::I64AtomicCmpxchg32U: SET(i64, 4); break;
2424+
default: WASM_UNREACHABLE();
2425+
}
2426+
2427+
if (debug) std::cerr << "zz node: AtomicCmpxchg" << std::endl;
2428+
Address readAlign;
2429+
readMemoryAccess(readAlign, curr->bytes, curr->offset);
2430+
if (readAlign != curr->bytes) throw ParseException("Align of AtomicCpxchg must match size");
2431+
curr->replacement = popNonVoidExpression();
2432+
curr->expected = popNonVoidExpression();
2433+
curr->ptr = popNonVoidExpression();
2434+
curr->finalize();
2435+
out = curr;
2436+
return true;
2437+
}
2438+
23752439
bool WasmBinaryBuilder::maybeVisitConst(Expression*& out, uint8_t code) {
23762440
Const* curr;
23772441
if (debug) std::cerr << "zz node: Const, code " << code << std::endl;

src/wasm/wasm-s-parser.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ Expression* SExpressionWasmBuilder::makeExpression(Element& s) {
666666
if (op[1] == 't' && !strncmp(op, "atomic.", strlen("atomic."))) {
667667
if (op[7] == 'l') return makeLoad(s, type, /*isAtomic=*/true);
668668
if (op[7] == 's') return makeStore(s, type, /*isAtomic=*/true);
669-
if (op[7] == 'r') return makeAtomicRMW(s, type);
669+
if (op[7] == 'r') return makeAtomicRMWOrCmpxchg(s, type);
670670
}
671671
abort_on(op);
672672
}
@@ -1197,14 +1197,20 @@ Expression* SExpressionWasmBuilder::makeStore(Element& s, WasmType type, bool is
11971197
return ret;
11981198
}
11991199

1200-
Expression* SExpressionWasmBuilder::makeAtomicRMW(Element& s, WasmType type) {
1200+
Expression* SExpressionWasmBuilder::makeAtomicRMWOrCmpxchg(Element& s, WasmType type) {
12011201
const char* extra = strchr(s[0]->c_str(), '.') + 11; // afer "type.atomic.rmw"
1202-
auto ret = allocator.alloc<AtomicRMW>();
1203-
ret->type = type;
1204-
ret->bytes = parseMemBytes(&extra, getWasmTypeSize(type));
1202+
auto bytes = parseMemBytes(&extra, getWasmTypeSize(type));
12051203
extra = strchr(extra, '.'); // after the optional '_u' and before the opcode
12061204
if (!extra) throw ParseException("malformed atomic rmw instruction");
12071205
extra++; // after the '.'
1206+
if (!strncmp(extra, "cmpxchg", 7)) return makeAtomicCmpxchg(s, type, bytes, extra);
1207+
return makeAtomicRMW(s, type, bytes, extra);
1208+
}
1209+
1210+
Expression* SExpressionWasmBuilder::makeAtomicRMW(Element& s, WasmType type, uint8_t bytes, const char* extra) {
1211+
auto ret = allocator.alloc<AtomicRMW>();
1212+
ret->type = type;
1213+
ret->bytes = bytes;
12081214
if (!strncmp(extra, "add", 3)) ret->op = Add;
12091215
else if (!strncmp(extra, "and", 3)) ret->op = And;
12101216
else if (!strncmp(extra, "or", 2)) ret->op = Or;
@@ -1221,6 +1227,20 @@ Expression* SExpressionWasmBuilder::makeAtomicRMW(Element& s, WasmType type) {
12211227
return ret;
12221228
}
12231229

1230+
Expression* SExpressionWasmBuilder::makeAtomicCmpxchg(Element& s, WasmType type, uint8_t bytes, const char* extra) {
1231+
auto ret = allocator.alloc<AtomicCmpxchg>();
1232+
ret->type = type;
1233+
ret->bytes = bytes;
1234+
Address align;
1235+
size_t i = parseMemAttributes(s, &ret->offset, &align, ret->bytes);
1236+
if (align != ret->bytes) throw ParseException("Align of Atomic Cmpxchg must match size");
1237+
ret->ptr = parseExpression(s[i]);
1238+
ret->expected = parseExpression(s[i+1]);
1239+
ret->replacement = parseExpression(s[i+2]);
1240+
ret->finalize();
1241+
return ret;
1242+
}
1243+
12241244
Expression* SExpressionWasmBuilder::makeIf(Element& s) {
12251245
auto ret = allocator.alloc<If>();
12261246
Index i = 1;

src/wasm/wasm.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,12 @@ void AtomicRMW::finalize() {
359359
}
360360
}
361361

362+
void AtomicCmpxchg::finalize() {
363+
if (ptr->type == unreachable || expected->type == unreachable || replacement->type == unreachable) {
364+
type = unreachable;
365+
}
366+
}
367+
362368
Const* Const::set(Literal value_) {
363369
value = value_;
364370
type = value.type;

test/atomics.wast

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,36 @@
102102
)
103103
)
104104
)
105+
(func $atomic-cmpxchg (type $0)
106+
(local $0 i32)
107+
(local $1 i32)
108+
(drop
109+
(i32.atomic.rmw.cmpxchg offset=4
110+
(get_local $0)
111+
(get_local $0)
112+
(get_local $0)
113+
)
114+
)
115+
(drop
116+
(i32.atomic.rmw8_u.cmpxchg
117+
(get_local $0)
118+
(get_local $0)
119+
(get_local $0)
120+
)
121+
)
122+
(drop
123+
(i64.atomic.rmw.cmpxchg offset=4
124+
(get_local $0)
125+
(get_local $0)
126+
(get_local $0)
127+
)
128+
)
129+
(drop
130+
(i64.atomic.rmw32_u.cmpxchg align=4
131+
(get_local $0)
132+
(get_local $0)
133+
(get_local $0)
134+
)
135+
)
136+
)
105137
)

test/atomics.wast.from-wast

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,36 @@
102102
)
103103
)
104104
)
105+
(func $atomic-cmpxchg (type $0)
106+
(local $0 i32)
107+
(local $1 i32)
108+
(drop
109+
(i32.atomic.rmw.cmpxchg offset=4
110+
(get_local $0)
111+
(get_local $0)
112+
(get_local $0)
113+
)
114+
)
115+
(drop
116+
(i32.atomic.rmw8_u.cmpxchg
117+
(get_local $0)
118+
(get_local $0)
119+
(get_local $0)
120+
)
121+
)
122+
(drop
123+
(i64.atomic.rmw.cmpxchg offset=4
124+
(get_local $0)
125+
(get_local $0)
126+
(get_local $0)
127+
)
128+
)
129+
(drop
130+
(i64.atomic.rmw32_u.cmpxchg
131+
(get_local $0)
132+
(get_local $0)
133+
(get_local $0)
134+
)
135+
)
136+
)
105137
)

0 commit comments

Comments
 (0)