Skip to content

Commit 4995132

Browse files
authored
Add IR, parsing and binary support for AtomicRMW instructions from wasm threads proposal (#1082)
Also leave a stub (but valid) visitAtomicRMW in the visitor template so that not all visitors need to implement this function yet.
1 parent cd2a431 commit 4995132

12 files changed

Lines changed: 427 additions & 73 deletions

src/passes/Print.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,39 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
350350
printFullLine(curr->value);
351351
decIndent();
352352
}
353+
void visitAtomicRMW(AtomicRMW* curr) {
354+
o << '(';
355+
prepareColor(o) << printWasmType(curr->type) << ".atomic.rmw";
356+
if (curr->bytes != getWasmTypeSize(curr->type)) {
357+
if (curr->bytes == 1) {
358+
o << '8';
359+
} else if (curr->bytes == 2) {
360+
o << "16";
361+
} else if (curr->bytes == 4) {
362+
o << "32";
363+
} else {
364+
WASM_UNREACHABLE();
365+
}
366+
o << "_u";
367+
}
368+
o << '.';
369+
switch (curr->op) {
370+
case Add: o << "add"; break;
371+
case Sub: o << "sub"; break;
372+
case And: o << "and"; break;
373+
case Or: o << "or"; break;
374+
case Xor: o << "xor"; break;
375+
case Xchg: o << "xchg"; break;
376+
}
377+
restoreNormalColor(o);
378+
if (curr->offset) {
379+
o << " offset=" << curr->offset;
380+
}
381+
incIndent();
382+
printFullLine(curr->ptr);
383+
printFullLine(curr->value);
384+
decIndent();
385+
}
353386
void visitConst(Const *curr) {
354387
o << curr->value;
355388
}

src/wasm-binary.h

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,9 +525,55 @@ enum AtomicOpcodes {
525525
I32AtomicStore16 = 0x1a,
526526
I64AtomicStore8 = 0x1b,
527527
I64AtomicStore16 = 0x1c,
528-
I64AtomicStore32 = 0x1d
528+
I64AtomicStore32 = 0x1d,
529+
530+
AtomicRMWOps_Begin = 0x1e,
531+
I32AtomicRMWAdd = 0x1e,
532+
I64AtomicRMWAdd = 0x1f,
533+
I32AtomicRMWAdd8U = 0x20,
534+
I32AtomicRMWAdd16U = 0x21,
535+
I64AtomicRMWAdd8U = 0x22,
536+
I64AtomicRMWAdd16U = 0x23,
537+
I64AtomicRMWAdd32U = 0x24,
538+
I32AtomicRMWSub = 0x25,
539+
I64AtomicRMWSub = 0x26,
540+
I32AtomicRMWSub8U = 0x27,
541+
I32AtomicRMWSub16U = 0x28,
542+
I64AtomicRMWSub8U = 0x29,
543+
I64AtomicRMWSub16U = 0x2a,
544+
I64AtomicRMWSub32U = 0x2b,
545+
I32AtomicRMWAnd = 0x2c,
546+
I64AtomicRMWAnd = 0x2d,
547+
I32AtomicRMWAnd8U = 0x2e,
548+
I32AtomicRMWAnd16U = 0x2f,
549+
I64AtomicRMWAnd8U = 0x30,
550+
I64AtomicRMWAnd16U = 0x31,
551+
I64AtomicRMWAnd32U = 0x32,
552+
I32AtomicRMWOr = 0x33,
553+
I64AtomicRMWOr = 0x34,
554+
I32AtomicRMWOr8U = 0x35,
555+
I32AtomicRMWOr16U = 0x36,
556+
I64AtomicRMWOr8U = 0x37,
557+
I64AtomicRMWOr16U = 0x38,
558+
I64AtomicRMWOr32U = 0x39,
559+
I32AtomicRMWXor = 0x3a,
560+
I64AtomicRMWXor = 0x3b,
561+
I32AtomicRMWXor8U = 0x3c,
562+
I32AtomicRMWXor16U = 0x3d,
563+
I64AtomicRMWXor8U = 0x3e,
564+
I64AtomicRMWXor16U = 0x3f,
565+
I64AtomicRMWXor32U = 0x40,
566+
I32AtomicRMWXchg = 0x41,
567+
I64AtomicRMWXchg = 0x42,
568+
I32AtomicRMWXchg8U = 0x43,
569+
I32AtomicRMWXchg16U = 0x44,
570+
I64AtomicRMWXchg8U = 0x45,
571+
I64AtomicRMWXchg16U = 0x46,
572+
I64AtomicRMWXchg32U = 0x47,
573+
AtomicRMWOps_End = 0x47,
529574
};
530575

576+
531577
enum MemoryAccess {
532578
Offset = 0x10, // bit 4
533579
Alignment = 0x80, // bit 7
@@ -676,6 +722,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
676722
void emitMemoryAccess(size_t alignment, size_t bytes, uint32_t offset);
677723
void visitLoad(Load *curr);
678724
void visitStore(Store *curr);
725+
void visitAtomicRMW(AtomicRMW *curr);
679726
void visitConst(Const *curr);
680727
void visitUnary(Unary *curr);
681728
void visitBinary(Binary *curr);
@@ -833,6 +880,7 @@ class WasmBinaryBuilder {
833880
void readMemoryAccess(Address& alignment, size_t bytes, Address& offset);
834881
bool maybeVisitLoad(Expression*& out, uint8_t code, bool isAtomic);
835882
bool maybeVisitStore(Expression*& out, uint8_t code, bool isAtomic);
883+
bool maybeVisitAtomicRMW(Expression*& out, uint8_t code);
836884
bool maybeVisitConst(Expression*& out, uint8_t code);
837885
bool maybeVisitUnary(Expression*& out, uint8_t code);
838886
bool maybeVisitBinary(Expression*& out, uint8_t code);

src/wasm-s-parser.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ class SExpressionWasmBuilder {
177177
Expression* makeConst(Element& s, WasmType type);
178178
Expression* makeLoad(Element& s, WasmType type, bool isAtomic);
179179
Expression* makeStore(Element& s, WasmType type, bool isAtomic);
180+
Expression* makeAtomicRMW(Element& s, WasmType type);
180181
Expression* makeIf(Element& s);
181182
Expression* makeMaybeBlock(Element& s, size_t i, WasmType type);
182183
Expression* makeLoop(Element& s);

src/wasm-traversal.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ struct Visitor {
4949
ReturnType visitSetGlobal(SetGlobal* curr) {}
5050
ReturnType visitLoad(Load* curr) {}
5151
ReturnType visitStore(Store* curr) {}
52+
ReturnType visitAtomicRMW(AtomicRMW* curr) {return ReturnType();} //Stub impl so not every pass has to implement this yet.
5253
ReturnType visitConst(Const* curr) {}
5354
ReturnType visitUnary(Unary* curr) {}
5455
ReturnType visitBinary(Binary* curr) {}
@@ -90,6 +91,7 @@ struct Visitor {
9091
case Expression::Id::SetGlobalId: DELEGATE(SetGlobal);
9192
case Expression::Id::LoadId: DELEGATE(Load);
9293
case Expression::Id::StoreId: DELEGATE(Store);
94+
case Expression::Id::AtomicRMWId: DELEGATE(AtomicRMW);
9395
case Expression::Id::ConstId: DELEGATE(Const);
9496
case Expression::Id::UnaryId: DELEGATE(Unary);
9597
case Expression::Id::BinaryId: DELEGATE(Binary);
@@ -130,6 +132,7 @@ struct UnifiedExpressionVisitor : public Visitor<SubType> {
130132
ReturnType visitSetGlobal(SetGlobal* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
131133
ReturnType visitLoad(Load* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
132134
ReturnType visitStore(Store* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
135+
ReturnType visitAtomicRMW(AtomicRMW* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
133136
ReturnType visitConst(Const* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
134137
ReturnType visitUnary(Unary* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
135138
ReturnType visitBinary(Binary* curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
@@ -306,6 +309,7 @@ struct Walker : public VisitorType {
306309
static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); }
307310
static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); }
308311
static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); }
312+
static void doVisitAtomicRMW(SubType* self, Expression** currp) { self->visitAtomicRMW((*currp)->cast<AtomicRMW>()); }
309313
static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); }
310314
static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); }
311315
static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); }
@@ -428,6 +432,12 @@ struct PostWalker : public Walker<SubType, VisitorType> {
428432
self->pushTask(SubType::scan, &curr->cast<Store>()->ptr);
429433
break;
430434
}
435+
case Expression::Id::AtomicRMWId: {
436+
self->pushTask(SubType::doVisitAtomicRMW, currp);
437+
self->pushTask(SubType::scan, &curr->cast<AtomicRMW>()->value);
438+
self->pushTask(SubType::scan, &curr->cast<AtomicRMW>()->ptr);
439+
break;
440+
}
431441
case Expression::Id::ConstId: {
432442
self->pushTask(SubType::doVisitConst, currp);
433443
break;

src/wasm.h

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ enum HostOp {
130130
PageSize, CurrentMemory, GrowMemory, HasFeature
131131
};
132132

133+
enum AtomicRMWOp {
134+
Add, Sub, And, Or, Xor, Xchg,
135+
};
136+
133137
//
134138
// Expressions
135139
//
@@ -177,6 +181,7 @@ class Expression {
177181
HostId,
178182
NopId,
179183
UnreachableId,
184+
AtomicCmpxchgId,
180185
AtomicRMWId,
181186
NumExpressionIds
182187
};
@@ -423,6 +428,25 @@ class Store : public SpecificExpression<Expression::StoreId> {
423428
void finalize();
424429
};
425430

431+
class AtomicRMW : public SpecificExpression<Expression::AtomicRMWId> {
432+
public:
433+
AtomicRMW() = default;
434+
AtomicRMW(MixedArena& allocator) : AtomicRMW() {}
435+
436+
AtomicRMWOp op;
437+
uint8_t bytes;
438+
Address offset;
439+
Expression* ptr;
440+
Expression* value;
441+
442+
void finalize();
443+
};
444+
445+
class AtomicCmpxchg : public SpecificExpression<Expression::AtomicCmpxchgId> {
446+
public:
447+
AtomicCmpxchg() = default;
448+
};
449+
426450
class Const : public SpecificExpression<Expression::ConstId> {
427451
public:
428452
Const() {}
@@ -514,13 +538,6 @@ class Unreachable : public SpecificExpression<Expression::UnreachableId> {
514538
Unreachable(MixedArena& allocator) : Unreachable() {}
515539
};
516540

517-
class AtomicRMW : public SpecificExpression<Expression::AtomicRMWId> {
518-
public:
519-
AtomicRMW() {}
520-
AtomicRMW(MixedArena& allocator) : AtomicRMW() {}
521-
bool finalize();
522-
};
523-
524541
// Globals
525542

526543
class Function {

src/wasm/wasm-binary.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,51 @@ void WasmBinaryWriter::visitStore(Store *curr) {
832832
emitMemoryAccess(curr->align, curr->bytes, curr->offset);
833833
}
834834

835+
void WasmBinaryWriter::visitAtomicRMW(AtomicRMW *curr) {
836+
if (debug) std::cerr << "zz node: AtomicRMW" << std::endl;
837+
recurse(curr->ptr);
838+
recurse(curr->value);
839+
840+
o << int8_t(BinaryConsts::AtomicPrefix);
841+
842+
#define CASE_FOR_OP(Op) \
843+
case Op: \
844+
switch (curr->type) { \
845+
case i32: \
846+
switch (curr->bytes) { \
847+
case 1: o << int8_t(BinaryConsts::I32AtomicRMW##Op##8U); break; \
848+
case 2: o << int8_t(BinaryConsts::I32AtomicRMW##Op##16U); break; \
849+
case 4: o << int8_t(BinaryConsts::I32AtomicRMW##Op); break; \
850+
default: WASM_UNREACHABLE(); \
851+
} \
852+
break; \
853+
case i64: \
854+
switch (curr->bytes) { \
855+
case 1: o << int8_t(BinaryConsts::I64AtomicRMW##Op##8U); break; \
856+
case 2: o << int8_t(BinaryConsts::I64AtomicRMW##Op##16U); break; \
857+
case 4: o << int8_t(BinaryConsts::I64AtomicRMW##Op##32U); break; \
858+
case 8: o << int8_t(BinaryConsts::I64AtomicRMW##Op); break; \
859+
default: WASM_UNREACHABLE(); \
860+
} \
861+
break; \
862+
default: WASM_UNREACHABLE(); \
863+
} \
864+
break
865+
866+
switch(curr->op) {
867+
CASE_FOR_OP(Add);
868+
CASE_FOR_OP(Sub);
869+
CASE_FOR_OP(And);
870+
CASE_FOR_OP(Or);
871+
CASE_FOR_OP(Xor);
872+
CASE_FOR_OP(Xchg);
873+
default: WASM_UNREACHABLE();
874+
}
875+
#undef CASE_FOR_OP
876+
877+
emitMemoryAccess(curr->bytes, curr->bytes, curr->offset);
878+
}
879+
835880
void WasmBinaryWriter::visitConst(Const *curr) {
836881
if (debug) std::cerr << "zz node: Const" << curr << " : " << curr->type << std::endl;
837882
switch (curr->type) {
@@ -1934,6 +1979,7 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) {
19341979
code = getInt8();
19351980
if (maybeVisitLoad(curr, code, /*isAtomic=*/true)) break;
19361981
if (maybeVisitStore(curr, code, /*isAtomic=*/true)) break;
1982+
if (maybeVisitAtomicRMW(curr, code)) break;
19371983
throw ParseException("invalid code after atomic prefix: " + std::to_string(code));
19381984
}
19391985
default: {
@@ -2282,6 +2328,50 @@ bool WasmBinaryBuilder::maybeVisitStore(Expression*& out, uint8_t code, bool isA
22822328
return true;
22832329
}
22842330

2331+
2332+
bool WasmBinaryBuilder::maybeVisitAtomicRMW(Expression*& out, uint8_t code) {
2333+
if (code < BinaryConsts::AtomicRMWOps_Begin || code > BinaryConsts::AtomicRMWOps_End) return false;
2334+
auto* curr = allocator.alloc<AtomicRMW>();
2335+
2336+
// Set curr to the given opcode, type and size.
2337+
#define SET(opcode, optype, size) \
2338+
curr->op = opcode; \
2339+
curr->type = optype; \
2340+
curr->bytes = size
2341+
2342+
// Handle the cases for all the valid types for a particular opcode
2343+
#define SET_FOR_OP(Op) \
2344+
case BinaryConsts::I32AtomicRMW##Op: SET(Op, i32, 4); break; \
2345+
case BinaryConsts::I32AtomicRMW##Op##8U: SET(Op, i32, 1); break; \
2346+
case BinaryConsts::I32AtomicRMW##Op##16U: SET(Op, i32, 2); break; \
2347+
case BinaryConsts::I64AtomicRMW##Op: SET(Op, i64, 8); break; \
2348+
case BinaryConsts::I64AtomicRMW##Op##8U: SET(Op, i64, 1); break; \
2349+
case BinaryConsts::I64AtomicRMW##Op##16U: SET(Op, i64, 2); break; \
2350+
case BinaryConsts::I64AtomicRMW##Op##32U: SET(Op, i64, 4); break;
2351+
2352+
switch(code) {
2353+
SET_FOR_OP(Add);
2354+
SET_FOR_OP(Sub);
2355+
SET_FOR_OP(And);
2356+
SET_FOR_OP(Or);
2357+
SET_FOR_OP(Xor);
2358+
SET_FOR_OP(Xchg);
2359+
default: WASM_UNREACHABLE();
2360+
}
2361+
#undef SET_FOR_OP
2362+
#undef SET
2363+
2364+
if (debug) std::cerr << "zz node: AtomicRMW" << std::endl;
2365+
Address readAlign;
2366+
readMemoryAccess(readAlign, curr->bytes, curr->offset);
2367+
if (readAlign != curr->bytes) throw ParseException("Align of AtomicRMW must match size");
2368+
curr->value = popNonVoidExpression();
2369+
curr->ptr = popNonVoidExpression();
2370+
curr->finalize();
2371+
out = curr;
2372+
return true;
2373+
}
2374+
22852375
bool WasmBinaryBuilder::maybeVisitConst(Expression*& out, uint8_t code) {
22862376
Const* curr;
22872377
if (debug) std::cerr << "zz node: Const, code " << code << std::endl;

0 commit comments

Comments
 (0)