Skip to content

Commit 4356c25

Browse files
committed
Simplify branch hint instrumentation
The previous branch hint instrumentation logic would introduce a scratch local to hold the condition so it could be passed into both the logging function and the original branching instruction. The de-instrumentation pass would then need to find this local and attempt to undo the data flow change. Simplify all of this by having the logging function return the condition value so it can interpose between the condition and the branch without any new locals. De-instrumentation can now just replace the call to the log function with its condition parameter. To allow further simplification, also change the order of parameters to the logging function so the condition value is the first parameter. This ensures that we don't need to introduce a scratch local even when the condition is a `pop`, because the pop will remain the leftmost leaf expression in the catch body.
1 parent f051945 commit 4356c25

File tree

8 files changed

+157
-560
lines changed

8 files changed

+157
-560
lines changed

scripts/fuzz_opt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2281,7 +2281,7 @@ def handle(self, wasm):
22812281
for line in out.splitlines():
22822282
if line.startswith(LOG_BRANCH_PREFIX):
22832283
# (1:-1 strips away the '[', ']' at the edges)
2284-
_, _, id_, hint, actual = line[1:-1].split(' ')
2284+
_, _, actual, hint, id_ = line[1:-1].split(' ')
22852285
all_ids.add(id_)
22862286
if hint != actual:
22872287
# This hint was misleading.

scripts/fuzz_shell.js

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,9 @@ var baseImports = {
400400
});
401401
},
402402

403-
'log-branch': (id, expected, actual) => {
404-
console.log(`[LoggingExternalInterface log-branch ${id} ${expected} ${actual}]`);
403+
'log-branch': (actual, expected, id) => {
404+
console.log(`[LoggingExternalInterface log-branch ${actual} ${expected} ${id}]`);
405+
return actual;
405406
},
406407
},
407408
// Emscripten support.

src/passes/InstrumentBranchHints.cpp

Lines changed: 40 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
// into
2929
//
3030
// @metadata.branch.hint B
31-
// ;; log the ID of the condition (123), the prediction (B), and the actual
32-
// ;; runtime result (temp == condition).
33-
// if (temp = condition; log(123, B, temp); temp) {
31+
// ;; log the actual runtime result (condition), the prediction (B), and the
32+
// ;; ID (123), and return the runtime result so it can feed the branch.
33+
// if (log(condition, B, 123)) {
3434
// X
3535
// } else {
3636
// Y
@@ -39,19 +39,20 @@
3939
// Concretely, we emit calls to this logging function:
4040
//
4141
// (import "fuzzing-support" "log-branch"
42-
// (func $log-branch (param i32 i32 i32)) ;; ID, prediction, actual
42+
// (func $log-branch (param i32 i32 i32) (result i32))
4343
// )
4444
//
4545
// This can be used to verify that branch hints are accurate, by implementing
4646
// the import like this for example:
4747
//
48-
// imports['fuzzing-support']['log-branch'] = (id, prediction, actual) => {
48+
// imports['fuzzing-support']['log-branch'] = (actual, expected, id) => {
4949
// // We only care about truthiness of the expected and actual values.
5050
// expected = +!!expected;
5151
// actual = +!!actual;
5252
// // Throw if the hint said this branch would be taken, but it was not, or
5353
// // vice versa.
5454
// if (expected != actual) throw `Bad branch hint! (${id})`;
55+
// return actual;
5556
// };
5657
//
5758
// A pass to delete branch hints is also provided, which finds instrumentations
@@ -63,28 +64,28 @@
6364
// would do this transformation:
6465
//
6566
// @metadata.branch.hint A
66-
// if (temp = condition; log(10, A, temp); temp) { // 10 matches one of 10,20
67+
// if (log(condition, A, 10)) { // 10 matches one of 10,20
6768
// X
6869
// }
6970
// @metadata.branch.hint B
70-
// if (temp = condition; log(99, B, temp); temp) { // 99 does not match
71+
// if (log(condition, B, 99)) { // 99 does not match
7172
// Y
7273
// }
7374
//
7475
// =>
7576
//
7677
// // Used to be a branch hint here, but it was deleted.
77-
// if (temp = condition; log(10, A, temp); temp) {
78+
// if (log(condition, A, 10)) {
7879
// X
7980
// }
8081
// @metadata.branch.hint B // this one is unmodified.
81-
// if (temp = condition; log(99, B, temp); temp) {
82+
// if (log(condition, B, 99)) {
8283
// Y
8384
// }
8485
//
8586
// A pass to undo the instrumentation is also provided, which does
8687
//
87-
// if (temp = condition; log(123, A, temp); temp) {
88+
// if (log(condition, A, 123)) {
8889
// X
8990
// }
9091
//
@@ -95,14 +96,8 @@
9596
// }
9697
//
9798

98-
#include "ir/drop.h"
9999
#include "ir/effects.h"
100-
#include "ir/eh-utils.h"
101-
#include "ir/find_all.h"
102-
#include "ir/local-graph.h"
103100
#include "ir/names.h"
104-
#include "ir/parents.h"
105-
#include "ir/properties.h"
106101
#include "ir/utils.h"
107102
#include "pass.h"
108103
#include "support/string.h"
@@ -133,8 +128,6 @@ int branchId = 1;
133128
struct InstrumentBranchHints
134129
: public WalkerPass<PostWalker<InstrumentBranchHints>> {
135130

136-
using Super = WalkerPass<PostWalker<InstrumentBranchHints>>;
137-
138131
// The internal name of our import.
139132
Name logBranch;
140133

@@ -148,8 +141,6 @@ struct InstrumentBranchHints
148141

149142
// TODO: BrOn, but the condition there is not an i32
150143

151-
bool addedInstrumentation = false;
152-
153144
template<typename T> void processCondition(T* curr) {
154145
if (curr->condition->type == Type::unreachable) {
155146
// This branch is not even reached.
@@ -167,41 +158,32 @@ struct InstrumentBranchHints
167158
int id = branchId++;
168159

169160
// Instrument the condition.
170-
auto tempLocal = builder.addVar(getFunction(), Type::i32);
171-
auto* set = builder.makeLocalSet(tempLocal, curr->condition);
172161
auto* idConst = builder.makeConst(Literal(int32_t(id)));
173162
auto* guess = builder.makeConst(Literal(int32_t(*likely)));
174-
auto* get1 = builder.makeLocalGet(tempLocal, Type::i32);
175-
auto* log = builder.makeCall(logBranch, {idConst, guess, get1}, Type::none);
176-
auto* get2 = builder.makeLocalGet(tempLocal, Type::i32);
177-
curr->condition = builder.makeBlock({set, log, get2});
178-
addedInstrumentation = true;
179-
}
180-
181-
void doWalkFunction(Function* func) {
182-
Super::doWalkFunction(func);
183163

184-
// Our added blocks may have caused nested pops.
185-
if (addedInstrumentation) {
186-
EHUtils::handleBlockNestedPops(func, *getModule());
187-
addedInstrumentation = false;
188-
}
164+
curr->condition =
165+
builder.makeCall(logBranch, {curr->condition, guess, idConst}, Type::i32);
189166
}
190167

191168
void doWalkModule(Module* module) {
192169
if (auto existing = getLogBranchImport(module)) {
193170
// This file already has our import. We nop it out, as whatever the
194171
// current code does may be dangerous (it may log incorrect hints).
195172
auto* func = module->getFunction(existing);
196-
func->body = Builder(*module).makeNop();
173+
Builder builder(*module);
174+
if (func->getSig().results == Type::none) {
175+
func->body = builder.makeNop();
176+
} else {
177+
func->body = builder.makeUnreachable();
178+
}
197179
func->module = func->base = Name();
198180
func->type = func->type.with(Exact);
199181
}
200182

201183
// Add our import.
202184
auto* func = module->addFunction(Builder::makeFunction(
203185
Names::getValidFunctionName(*module, BASE),
204-
Type(Signature({Type::i32, Type::i32, Type::i32}, Type::none),
186+
Type(Signature({Type::i32, Type::i32, Type::i32}, Type::i32),
205187
NonNullable,
206188
Inexact),
207189
{}));
@@ -210,7 +192,7 @@ struct InstrumentBranchHints
210192
logBranch = func->name;
211193

212194
// Walk normally, using logBranch as we go.
213-
Super::doWalkModule(module);
195+
PostWalker<InstrumentBranchHints>::doWalkModule(module);
214196

215197
// Update ref.func type changes.
216198
ReFinalize().run(getPassRunner(), module);
@@ -228,12 +210,6 @@ struct InstrumentationProcessor : public WalkerPass<PostWalker<Sub>> {
228210
// The internal name of our import.
229211
Name logBranch;
230212

231-
// A LocalGraph, so we can identify the pattern.
232-
std::unique_ptr<LocalGraph> localGraph;
233-
234-
// A map of expressions to their parents, so we can identify the pattern.
235-
std::unique_ptr<Parents> parents;
236-
237213
Sub* self() { return static_cast<Sub*>(this); }
238214

239215
void visitIf(If* curr) { self()->processCondition(curr); }
@@ -246,15 +222,6 @@ struct InstrumentationProcessor : public WalkerPass<PostWalker<Sub>> {
246222

247223
// TODO: BrOn, but the condition there is not an i32
248224

249-
void doWalkFunction(Function* func) {
250-
localGraph = std::make_unique<LocalGraph>(func, this->getModule());
251-
localGraph->computeSetInfluences();
252-
253-
parents = std::make_unique<Parents>(func->body);
254-
255-
Super::doWalkFunction(func);
256-
}
257-
258225
void doWalkModule(Module* module) {
259226
logBranch = getLogBranchImport(module);
260227
if (!logBranch) {
@@ -267,73 +234,14 @@ struct InstrumentationProcessor : public WalkerPass<PostWalker<Sub>> {
267234

268235
// Helpers
269236

270-
// Instrumentation info for a chunk of code that is the result of the
271-
// instrumentation pass.
272-
struct Instrumentation {
273-
// The condition before the instrumentation (a pointer to it, so we can
274-
// replace it).
275-
Expression** originalCondition;
276-
// The local that the original condition is stored in temporarily.
277-
Index tempLocal;
278-
// The call to the logging that the instrumentation added.
279-
Call* call;
280-
};
281-
282-
// Check if an expression's condition is an instrumentation, and return the
283-
// info if so.
284-
std::optional<Instrumentation> getInstrumentation(Expression* condition) {
285-
// We must identify this pattern:
286-
//
287-
// (br_if
288-
// (block
289-
// (local.set $temp (condition))
290-
// (call $log (id, prediction, (local.get $temp)))
291-
// (local.get $temp)
292-
// )
293-
//
294-
// The block may vanish during roundtrip though, so we just follow back from
295-
// the last local.get, which appears in the condition:
296-
//
297-
// (local.set $temp (condition))
298-
// (call $log (id, prediction, (local.get $temp)))
299-
// (br_if
300-
// (local.get $temp)
301-
//
302-
auto* fallthrough = Properties::getFallthrough(
303-
condition, this->getPassOptions(), *this->getModule());
304-
auto* get = fallthrough->template dynCast<LocalGet>();
305-
if (!get) {
306-
return {};
307-
}
308-
auto& sets = localGraph->getSets(get);
309-
if (sets.size() != 1) {
310-
return {};
311-
}
312-
auto* set = *sets.begin();
313-
if (!set) {
314-
return {};
315-
}
316-
auto& gets = localGraph->getSetInfluences(set);
317-
if (gets.size() != 2) {
318-
return {};
319-
}
320-
// The set has two gets: the get in the condition we began at, and
321-
// another.
322-
LocalGet* otherGet = nullptr;
323-
for (auto* get2 : gets) {
324-
if (get2 != get) {
325-
otherGet = get2;
326-
}
327-
}
328-
assert(otherGet);
329-
// See if that other get is used in a logging. The parent should be a
330-
// logging call.
331-
auto* call = parents->getParent(otherGet)->template dynCast<Call>();
237+
// Check if an expression's condition is instrumented, and return the
238+
// instrumentation call if so. Otherwise return null.
239+
Call* getInstrumentation(Expression* condition) {
240+
auto* call = condition->dynCast<Call>();
332241
if (!call || call->target != logBranch) {
333-
return {};
242+
return nullptr;
334243
}
335-
// Great, this is indeed a prior instrumentation.
336-
return Instrumentation{&set->value, set->index, call};
244+
return call;
337245
}
338246
};
339247

@@ -344,8 +252,8 @@ struct DeleteBranchHints : public InstrumentationProcessor<DeleteBranchHints> {
344252
std::unordered_set<Index> idsToDelete;
345253

346254
template<typename T> void processCondition(T* curr) {
347-
if (auto info = getInstrumentation(curr->condition)) {
348-
if (auto* c = info->call->operands[0]->template dynCast<Const>()) {
255+
if (auto* call = getInstrumentation(curr->condition)) {
256+
if (auto* c = call->operands[2]->template dynCast<Const>()) {
349257
auto id = c->value.geti32();
350258
if (idsToDelete.contains(id)) {
351259
// Remove the branch hint.
@@ -368,78 +276,31 @@ struct DeleteBranchHints : public InstrumentationProcessor<DeleteBranchHints> {
368276
};
369277

370278
struct DeInstrumentBranchHints
371-
: public InstrumentationProcessor<DeInstrumentBranchHints> {
279+
: public WalkerPass<PostWalker<DeInstrumentBranchHints>> {
372280

373-
template<typename T> void processCondition(T* curr) {
374-
if (auto info = getInstrumentation(curr->condition)) {
375-
// Replace the instrumented condition with the original one (swap so that
376-
// the IR remains valid: we cannot use the same expression twice in our
377-
// IR, and the original condition is still used in another place, until
378-
// we remove the logging calls; since we will remove the calls anyhow, we
379-
// just need some valid IR there).
380-
//
381-
// Check for dangerous effects in the condition we are about to replace,
382-
// to avoid a situation where the condition looks like this:
383-
//
384-
// (set $temp (original condition))
385-
// ..effects..
386-
// (local.get $temp)
387-
//
388-
// We cannot replace all this with the original condition, as it would
389-
// remove the effects. (Even in that case we will remove the actual call
390-
// to log the branch hint, below, so this just prevents some cleanup that
391-
// is normally safe - the cleanup is mainly useful to allow inspection of
392-
// testcases for debugging.)
393-
EffectAnalyzer effects(getPassOptions(), *getModule(), curr->condition);
394-
// The only condition we allow is a write to the temp local from the
395-
// instrumentation, which getInstrumentation() verified has no other uses
396-
// than us.
397-
effects.localsWritten.erase(info->tempLocal);
398-
if (!effects.hasUnremovableSideEffects()) {
399-
std::swap(curr->condition, *info->originalCondition);
400-
}
401-
}
402-
}
281+
// The internal name of our import.
282+
Name logBranch;
403283

404-
void visitFunction(Function* func) {
405-
if (func->imported()) {
406-
return;
407-
}
408-
// At the very end, remove all logging calls (we use them during the main
409-
// walk to identify instrumentation).
410-
for (auto** callp : FindAllPointers<Call>(func->body).list) {
411-
auto* call = (*callp)->cast<Call>();
412-
if (call->target == logBranch) {
413-
Builder builder(*getModule());
414-
Expression* last;
415-
if (call->type == Type::none) {
416-
last = builder.makeNop();
417-
} else {
418-
last = builder.makeUnreachable();
419-
}
420-
*callp = getDroppedChildrenAndAppend(call,
421-
*getModule(),
422-
getPassOptions(),
423-
last,
424-
// We know the call is removable.
425-
DropMode::IgnoreParentEffects);
426-
}
284+
void visitCall(Call* curr) {
285+
if (curr->target == logBranch) {
286+
// Replace the call with its first operand (the original condition).
287+
replaceCurrent(curr->operands[0]);
427288
}
428289
}
429290

430291
void doWalkModule(Module* module) {
431-
auto logBranchImport = getLogBranchImport(module);
432-
if (!logBranchImport) {
292+
logBranch = getLogBranchImport(module);
293+
if (!logBranch) {
433294
Fatal()
434295
<< "No branch hint logging import found. Was this code instrumented?";
435296
}
436297

437298
// Mark the log-branch import as having no side effects - we are removing it
438299
// entirely here, and its effect should not stop us when we compute effects.
439-
module->getFunction(logBranchImport)->effects =
300+
module->getFunction(logBranch)->effects =
440301
std::make_shared<EffectAnalyzer>(getPassOptions(), *module);
441302

442-
InstrumentationProcessor<DeInstrumentBranchHints>::doWalkModule(module);
303+
WalkerPass<PostWalker<DeInstrumentBranchHints>>::doWalkModule(module);
443304
}
444305
};
445306

0 commit comments

Comments
 (0)