Skip to content

Commit ec53d11

Browse files
authored
Refactor and optimize binary writing type collection (#2478)
Create a new ParallelFunctionAnalysis helper, which lets us run in parallel on all functions and collect info from them, without manually handling locks etc. Use that in the binary writing code's type collection logic, avoiding a lock for each type increment. Also add Signature printing which was useful to debug this.
1 parent 7665f70 commit ec53d11

6 files changed

Lines changed: 136 additions & 81 deletions

File tree

src/ir/module-utils.h

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -262,33 +262,20 @@ template<typename T> inline void iterDefinedEvents(Module& wasm, T visitor) {
262262
}
263263
}
264264

265-
// Helper class for analyzing the call graph.
266-
//
267-
// Provides hooks for running some initial calculation on each function (which
268-
// is done in parallel), writing to a FunctionInfo structure for each function.
269-
// Then you can call propagateBack() to propagate a property of interest to the
270-
// calling functions, transitively.
271-
//
272-
// For example, if some functions are known to call an import "foo", then you
273-
// can use this to find which functions call something that might eventually
274-
// reach foo, by initially marking the direct callers as "calling foo" and
275-
// propagating that backwards.
276-
template<typename T> struct CallGraphPropertyAnalysis {
265+
// Helper class for performing an operation on all the functions in the module,
266+
// in parallel, with an Info object for each one that can contain results of
267+
// some computation that the operation performs.
268+
// The operation performed should not modify the wasm module in any way.
269+
// TODO: enforce this
270+
template<typename T> struct ParallelFunctionAnalysis {
277271
Module& wasm;
278272

279-
// The basic information for each function about whom it calls and who is
280-
// called by it.
281-
struct FunctionInfo {
282-
std::set<Function*> callsTo;
283-
std::set<Function*> calledBy;
284-
};
285-
286273
typedef std::map<Function*, T> Map;
287274
Map map;
288275

289276
typedef std::function<void(Function*, T&)> Func;
290277

291-
CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) {
278+
ParallelFunctionAnalysis(Module& wasm, Func work) : wasm(wasm) {
292279
// Fill in map, as we operate on it in parallel (each function to its own
293280
// entry).
294281
for (auto& func : wasm.functions) {
@@ -304,30 +291,78 @@ template<typename T> struct CallGraphPropertyAnalysis {
304291

305292
struct Mapper : public WalkerPass<PostWalker<Mapper>> {
306293
bool isFunctionParallel() override { return true; }
294+
bool modifiesBinaryenIR() override { return false; }
307295

308-
Mapper(Module* module, Map* map, Func work)
296+
Mapper(Module& module, Map& map, Func work)
309297
: module(module), map(map), work(work) {}
310298

311299
Mapper* create() override { return new Mapper(module, map, work); }
312300

313-
void visitCall(Call* curr) {
314-
(*map)[this->getFunction()].callsTo.insert(
315-
module->getFunction(curr->target));
316-
}
317-
318-
void visitFunction(Function* curr) {
319-
assert((*map).count(curr));
320-
work(curr, (*map)[curr]);
301+
void doWalkFunction(Function* curr) {
302+
assert(map.count(curr));
303+
work(curr, map[curr]);
321304
}
322305

323306
private:
324-
Module* module;
325-
Map* map;
307+
Module& module;
308+
Map& map;
326309
Func work;
327310
};
328311

329312
PassRunner runner(&wasm);
330-
Mapper(&wasm, &map, work).run(&runner, &wasm);
313+
Mapper(wasm, map, work).run(&runner, &wasm);
314+
}
315+
};
316+
317+
// Helper class for analyzing the call graph.
318+
//
319+
// Provides hooks for running some initial calculation on each function (which
320+
// is done in parallel), writing to a FunctionInfo structure for each function.
321+
// Then you can call propagateBack() to propagate a property of interest to the
322+
// calling functions, transitively.
323+
//
324+
// For example, if some functions are known to call an import "foo", then you
325+
// can use this to find which functions call something that might eventually
326+
// reach foo, by initially marking the direct callers as "calling foo" and
327+
// propagating that backwards.
328+
template<typename T> struct CallGraphPropertyAnalysis {
329+
Module& wasm;
330+
331+
// The basic information for each function about whom it calls and who is
332+
// called by it.
333+
struct FunctionInfo {
334+
std::set<Function*> callsTo;
335+
std::set<Function*> calledBy;
336+
};
337+
338+
typedef std::map<Function*, T> Map;
339+
Map map;
340+
341+
typedef std::function<void(Function*, T&)> Func;
342+
343+
CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) {
344+
ParallelFunctionAnalysis<T> analysis(wasm, [&](Function* func, T& info) {
345+
work(func, info);
346+
if (func->imported()) {
347+
return;
348+
}
349+
struct Mapper : public PostWalker<Mapper> {
350+
Mapper(Module* module, T& info, Func work)
351+
: module(module), info(info), work(work) {}
352+
353+
void visitCall(Call* curr) {
354+
info.callsTo.insert(module->getFunction(curr->target));
355+
}
356+
357+
private:
358+
Module* module;
359+
T& info;
360+
Func work;
361+
} mapper(&wasm, info, work);
362+
mapper.walk(func->body);
363+
});
364+
365+
map.swap(analysis.map);
331366

332367
// Find what is called by what.
333368
for (auto& pair : map) {

src/wasm-type.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,6 @@ struct ResultType {
9696
std::string toString() const;
9797
};
9898

99-
std::ostream& operator<<(std::ostream& os, Type t);
100-
std::ostream& operator<<(std::ostream& os, ParamType t);
101-
std::ostream& operator<<(std::ostream& os, ResultType t);
102-
10399
struct Signature {
104100
Type params;
105101
Type results;
@@ -112,6 +108,11 @@ struct Signature {
112108
bool operator<(const Signature& other) const;
113109
};
114110

111+
std::ostream& operator<<(std::ostream& os, Type t);
112+
std::ostream& operator<<(std::ostream& os, ParamType t);
113+
std::ostream& operator<<(std::ostream& os, ResultType t);
114+
std::ostream& operator<<(std::ostream& os, Signature t);
115+
115116
constexpr Type none = Type::none;
116117
constexpr Type i32 = Type::i32;
117118
constexpr Type i64 = Type::i64;

src/wasm/wasm-binary.cpp

Lines changed: 29 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,68 +16,52 @@
1616

1717
#include <algorithm>
1818
#include <fstream>
19-
#include <shared_mutex>
2019

20+
#include "ir/module-utils.h"
2121
#include "support/bits.h"
2222
#include "wasm-binary.h"
2323
#include "wasm-stack.h"
2424

2525
namespace wasm {
2626

2727
void WasmBinaryWriter::prepare() {
28-
// Collect function types and their frequencies
29-
using Counts = std::unordered_map<Signature, size_t>;
30-
using AtomicCounts = std::unordered_map<Signature, std::atomic_size_t>;
28+
// Collect function types and their frequencies. Collect information in each
29+
// function in parallel, then merge.
30+
typedef std::unordered_map<Signature, size_t> Counts;
31+
ModuleUtils::ParallelFunctionAnalysis<Counts> analysis(
32+
*wasm, [&](Function* func, Counts& counts) {
33+
if (func->imported()) {
34+
return;
35+
}
36+
struct TypeCounter : PostWalker<TypeCounter> {
37+
Module& wasm;
38+
Counts& counts;
39+
40+
TypeCounter(Module& wasm, Counts& counts)
41+
: wasm(wasm), counts(counts) {}
42+
43+
void visitCallIndirect(CallIndirect* curr) {
44+
auto* type = wasm.getFunctionType(curr->fullType);
45+
Signature sig(Type(type->params), type->result);
46+
counts[sig]++;
47+
}
48+
};
49+
TypeCounter(*wasm, counts).walk(func->body);
50+
});
51+
// Collect all the counts.
3152
Counts counts;
3253
for (auto& curr : wasm->functions) {
3354
counts[Signature(Type(curr->params), curr->result)]++;
3455
}
3556
for (auto& curr : wasm->events) {
3657
counts[curr->sig]++;
3758
}
38-
39-
// Parallelize collection of call_indirect type counts
40-
struct TypeCounter : WalkerPass<PostWalker<TypeCounter>> {
41-
AtomicCounts& counts;
42-
std::shared_timed_mutex& mutex;
43-
TypeCounter(AtomicCounts& counts, std::shared_timed_mutex& mutex)
44-
: counts(counts), mutex(mutex) {}
45-
bool isFunctionParallel() override { return true; }
46-
bool modifiesBinaryenIR() override { return false; }
47-
void visitCallIndirect(CallIndirect* curr) {
48-
auto* type = getModule()->getFunctionType(curr->fullType);
49-
Signature sig(Type(type->params), type->result);
50-
{
51-
std::shared_lock<std::shared_timed_mutex> lock(mutex);
52-
auto it = counts.find(sig);
53-
if (it != counts.end()) {
54-
it->second++;
55-
return;
56-
}
57-
}
58-
{
59-
std::lock_guard<std::shared_timed_mutex> lock(mutex);
60-
counts[sig]++;
61-
}
59+
for (auto& pair : analysis.map) {
60+
Counts& functionCounts = pair.second;
61+
for (auto& innerPair : functionCounts) {
62+
counts[innerPair.first] += innerPair.second;
6263
}
63-
Pass* create() override { return new TypeCounter(counts, mutex); }
64-
};
65-
66-
std::shared_timed_mutex mutex;
67-
AtomicCounts parallelCounts;
68-
for (auto& kv : counts) {
69-
parallelCounts[kv.first] = 0;
7064
}
71-
72-
TypeCounter counter(parallelCounts, mutex);
73-
PassRunner runner(wasm);
74-
runner.setIsNested(true);
75-
counter.run(&runner, wasm);
76-
77-
for (auto& kv : parallelCounts) {
78-
counts[kv.first] += kv.second;
79-
}
80-
8165
std::vector<std::pair<Signature, size_t>> sorted(counts.begin(),
8266
counts.end());
8367
std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {

src/wasm/wasm-type.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,10 @@ std::ostream& operator<<(std::ostream& os, ResultType param) {
222222
return printPrefixedTypes(os, "result", param.type);
223223
}
224224

225+
std::ostream& operator<<(std::ostream& os, Signature sig) {
226+
return os << "Signature(" << sig.params << " => " << sig.results << ")";
227+
}
228+
225229
std::string Type::toString() const { return genericToString(*this); }
226230

227231
std::string ParamType::toString() const { return genericToString(*this); }

test/unit/input/stack_ir.wast

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
(module
2+
(import "env" "bar" (func $bar (param i32) (result i32)))
3+
(func "foo1" (result i32)
4+
(local $x i32)
5+
(local.set $x (call $bar (i32.const 0)))
6+
(drop
7+
(call $bar (i32.const 1))
8+
)
9+
(local.get $x) ;; local2stack can help here
10+
)
11+
)
12+

test/unit/test_stack_ir.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import os
2+
from scripts.test import shared
3+
from . import utils
4+
5+
6+
class StackIRTest(utils.BinaryenTestCase):
7+
# test that stack IR opts make a difference.
8+
def test_stack_ir_opts(self):
9+
path = self.input_path('stack_ir.wast')
10+
opt = shared.run_process(shared.WASM_OPT + [path, '-O', '--generate-stack-ir', '--optimize-stack-ir', '--print-stack-ir', '-o', 'a.wasm'], capture_output=True).stdout
11+
nonopt = shared.run_process(shared.WASM_OPT + [path, '-O', '--generate-stack-ir', '--print-stack-ir', '-o', 'b.wasm'], capture_output=True).stdout
12+
# see a difference in the printed stack IR (the optimizations let us
13+
# remove a pair of local.set/get)
14+
self.assertNotEqual(opt, nonopt)
15+
self.assertLess(len(opt), len(nonopt))
16+
# see a difference in the actual emitted wasm binary.
17+
opt_size = os.path.getsize('a.wasm')
18+
nonopt_size = os.path.getsize('b.wasm')
19+
self.assertLess(opt_size, nonopt_size)

0 commit comments

Comments
 (0)