Skip to content

Commit 25cbf64

Browse files
authored
Refactor validation failure and printing, validate atomic memory (#1090)
1 parent dec4529 commit 25cbf64

2 files changed

Lines changed: 76 additions & 33 deletions

File tree

src/wasm-validator.h

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,26 @@
3838
#define wasm_wasm_validator_h
3939

4040
#include <set>
41+
#include <sstream>
4142

4243
#include "wasm.h"
4344
#include "wasm-printing.h"
4445

4546
namespace wasm {
4647

48+
// Print anything that can be streamed to an ostream
49+
template <typename T>
50+
inline std::ostream& printModuleComponent(T curr, std::ostream& stream) {
51+
stream << curr << std::endl;
52+
return stream;
53+
}
54+
// Specialization for Expressions to print type info too
55+
template <>
56+
inline std::ostream& printModuleComponent(Expression* curr, std::ostream& stream) {
57+
WasmPrinter::printExpression(curr, stream, false, true) << std::endl;
58+
return stream;
59+
}
60+
4761
struct WasmValidator : public PostWalker<WasmValidator> {
4862
bool valid = true;
4963

@@ -123,6 +137,8 @@ struct WasmValidator : public PostWalker<WasmValidator> {
123137
void visitSetLocal(SetLocal *curr);
124138
void visitLoad(Load *curr);
125139
void visitStore(Store *curr);
140+
void visitAtomicRMW(AtomicRMW *curr);
141+
void visitAtomicCmpxchg(AtomicCmpxchg *curr);
126142
void visitBinary(Binary *curr);
127143
void visitUnary(Unary *curr);
128144
void visitSelect(Select* curr);
@@ -144,21 +160,22 @@ struct WasmValidator : public PostWalker<WasmValidator> {
144160

145161
// helpers
146162
private:
147-
std::ostream& fail();
163+
template <typename T, typename S>
164+
std::ostream& fail(T curr, S text);
165+
std::ostream& printFailureHeader();
166+
148167
template<typename T>
149168
bool shouldBeTrue(bool result, T curr, const char* text) {
150169
if (!result) {
151-
fail() << "unexpected false: " << text << ", on \n" << curr << std::endl;
152-
valid = false;
170+
fail(curr, "unexpected false: " + std::string(text));
153171
return false;
154172
}
155173
return result;
156174
}
157175
template<typename T>
158176
bool shouldBeFalse(bool result, T curr, const char* text) {
159177
if (result) {
160-
fail() << "unexpected true: " << text << ", on \n" << curr << std::endl;
161-
valid = false;
178+
fail(curr, "unexpected true: " + std::string(text));
162179
return false;
163180
}
164181
return result;
@@ -167,18 +184,9 @@ struct WasmValidator : public PostWalker<WasmValidator> {
167184
template<typename T, typename S>
168185
bool shouldBeEqual(S left, S right, T curr, const char* text) {
169186
if (left != right) {
170-
fail() << "" << left << " != " << right << ": " << text << ", on \n";
171-
WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
172-
valid = false;
173-
return false;
174-
}
175-
return true;
176-
}
177-
template<typename T, typename S, typename U>
178-
bool shouldBeEqual(S left, S right, T curr, U other, const char* text) {
179-
if (left != right) {
180-
fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << " / " << other << std::endl;
181-
valid = false;
187+
std::ostringstream ss;
188+
ss << left << " != " << right << ": " << text;
189+
fail(curr, ss.str());
182190
return false;
183191
}
184192
return true;
@@ -187,9 +195,9 @@ struct WasmValidator : public PostWalker<WasmValidator> {
187195
template<typename T, typename S>
188196
bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) {
189197
if (left != unreachable && left != right) {
190-
fail() << "" << left << " != " << right << ": " << text << ", on \n";
191-
WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
192-
valid = false;
198+
std::ostringstream ss;
199+
ss << left << " != " << right << ": " << text;
200+
fail(curr, ss.str());
193201
return false;
194202
}
195203
return true;
@@ -198,14 +206,17 @@ struct WasmValidator : public PostWalker<WasmValidator> {
198206
template<typename T, typename S>
199207
bool shouldBeUnequal(S left, S right, T curr, const char* text) {
200208
if (left == right) {
201-
fail() << "" << left << " == " << right << ": " << text << ", on \n" << curr << std::endl;
202-
valid = false;
209+
std::ostringstream ss;
210+
ss << left << " == " << right << ": " << text;
211+
fail(curr, ss.str());
203212
return false;
204213
}
205214
return true;
206215
}
207216

208-
void validateAlignment(size_t align, WasmType type, Index bytes);
217+
void validateAlignment(size_t align, WasmType type, Index bytes, bool isAtomic,
218+
Expression* curr);
219+
void validateMemBytes(uint8_t bytes, WasmType ty, Expression* curr);
209220
void validateBinaryenIR(Module& wasm);
210221
};
211222

src/wasm/wasm-validator.cpp

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,15 +219,36 @@ void WasmValidator::visitSetLocal(SetLocal *curr) {
219219
}
220220
}
221221
void WasmValidator::visitLoad(Load *curr) {
222-
validateAlignment(curr->align, curr->type, curr->bytes);
222+
validateMemBytes(curr->bytes, curr->type, curr);
223+
validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr);
223224
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32");
224225
}
225226
void WasmValidator::visitStore(Store *curr) {
226-
validateAlignment(curr->align, curr->type, curr->bytes);
227+
validateMemBytes(curr->bytes, curr->valueType, curr);
228+
validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr);
227229
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
228230
shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
229231
shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match");
230232
}
233+
void WasmValidator::visitAtomicRMW(AtomicRMW* curr) {
234+
validateMemBytes(curr->bytes, curr->type, curr);
235+
}
236+
void WasmValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) {
237+
validateMemBytes(curr->bytes, curr->type, curr);
238+
}
239+
void WasmValidator::validateMemBytes(uint8_t bytes, WasmType ty, Expression* curr) {
240+
switch (bytes) {
241+
case 1:
242+
case 2:
243+
case 4:
244+
break;
245+
case 8: {
246+
shouldBeEqual(getWasmTypeSize(ty), 8U, curr, "8-byte mem operations are only allowed with 8-byte wasm types");
247+
break;
248+
}
249+
default: fail("Memory operations must be 1,2,4, or 8 bytes", curr);
250+
}
251+
}
231252
void WasmValidator::visitBinary(Binary *curr) {
232253
if (curr->left->type != unreachable && curr->right->type != unreachable) {
233254
shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal");
@@ -566,28 +587,32 @@ void WasmValidator::visitModule(Module *curr) {
566587
}
567588
}
568589

569-
void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes) {
590+
void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes,
591+
bool isAtomic, Expression* curr) {
592+
if (isAtomic) {
593+
shouldBeEqual(align, (size_t)bytes, curr, "atomic accesses must have natural alignment");
594+
return;
595+
}
570596
switch (align) {
571597
case 1:
572598
case 2:
573599
case 4:
574600
case 8: break;
575601
default:{
576-
fail() << "bad alignment: " << align << std::endl;
577-
valid = false;
602+
fail("bad alignment: " + std::to_string(align), curr);
578603
break;
579604
}
580605
}
581-
shouldBeTrue(align <= bytes, align, "alignment must not exceed natural");
606+
shouldBeTrue(align <= bytes, curr, "alignment must not exceed natural");
582607
switch (type) {
583608
case i32:
584609
case f32: {
585-
shouldBeTrue(align <= 4, align, "alignment must not exceed natural");
610+
shouldBeTrue(align <= 4, curr, "alignment must not exceed natural");
586611
break;
587612
}
588613
case i64:
589614
case f64: {
590-
shouldBeTrue(align <= 8, align, "alignment must not exceed natural");
615+
shouldBeTrue(align <= 8, curr, "alignment must not exceed natural");
591616
break;
592617
}
593618
default: {}
@@ -614,7 +639,7 @@ void WasmValidator::validateBinaryenIR(Module& wasm) {
614639
// The block has an added type, not derived from the ast itself, so it is
615640
// ok for it to be either i32 or unreachable.
616641
if (!(isConcreteWasmType(oldType) && newType == unreachable)) {
617-
parent.fail() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
642+
parent.printFailureHeader() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
618643
parent.valid = false;
619644
}
620645
curr->type = oldType;
@@ -625,7 +650,14 @@ void WasmValidator::validateBinaryenIR(Module& wasm) {
625650
binaryenIRValidator.walkModule(&wasm);
626651
}
627652

628-
std::ostream& WasmValidator::fail() {
653+
template <typename T, typename S>
654+
std::ostream& WasmValidator::fail(T curr, S text) {
655+
valid = false;
656+
auto& ret = printFailureHeader() << text << ", on \n";
657+
return printModuleComponent(curr, ret);
658+
}
659+
660+
std::ostream& WasmValidator::printFailureHeader() {
629661
Colors::red(std::cerr);
630662
if (getFunction()) {
631663
std::cerr << "[wasm-validator error in function ";

0 commit comments

Comments
 (0)