Skip to content

Commit ff72d4f

Browse files
PR updates
1 parent d764f59 commit ff72d4f

1 file changed

Lines changed: 55 additions & 37 deletions

File tree

src/passes/GlobalEffects.cpp

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
// PassOptions structure; see more details there.
2020
//
2121

22+
#include <ranges>
23+
2224
#include "ir/effects.h"
2325
#include "ir/module-utils.h"
2426
#include "ir/subtypes.h"
@@ -94,7 +96,7 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
9496
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
9597
type = callIndirect->heapType;
9698
} else {
97-
Fatal() << "Unexpected call type";
99+
WASM_UNREACHABLE("Unexpected call type");
98100
}
99101

100102
funcInfo.indirectCalledTypes.insert(type);
@@ -123,7 +125,8 @@ using CallGraphNode = std::variant<Function*, HeapType>;
123125
using CallGraph =
124126
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
125127

126-
/* Build a call graph for indirect and direct calls.
128+
/*
129+
Build a call graph for indirect and direct calls.
127130
128131
key (caller) -> value (callee)
129132
Name -> Name : direct call
@@ -137,38 +140,46 @@ using CallGraph =
137140
138141
If we're running in an open world, we only include Name -> Name edges.
139142
*/
140-
CallGraph buildCallGraph(Module& module,
143+
CallGraph buildCallGraph(const Module& module,
141144
const std::map<Function*, FuncInfo>& funcInfos,
142145
bool closedWorld) {
143146
CallGraph callGraph;
144147

145148
std::unordered_set<HeapType> allFunctionTypes;
146149
for (const auto& [caller, callerInfo] : funcInfos) {
147150
auto& callees = callGraph[caller];
151+
152+
// Name -> Name
148153
for (Name calleeFunction : callerInfo.calledFunctions) {
149154
callees.insert(module.getFunction(calleeFunction));
150155
}
151156

152-
// In open world, just connect functions. Indirect calls are already handled
153-
// by giving such functions unknown effects.
154157
if (!closedWorld) {
155158
continue;
156159
}
157160

161+
// Name -> Type
158162
allFunctionTypes.insert(caller->type.getHeapType());
159163
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
160164
callees.insert(calleeType);
161165
allFunctionTypes.insert(calleeType);
162166
}
167+
168+
// Type -> Name
163169
callGraph[caller->type.getHeapType()].insert(caller);
164170
}
165171

166-
SubTypes subtypes(module);
172+
// Type -> Type
167173
for (HeapType type : allFunctionTypes) {
168-
subtypes.iterSubTypes(type, [&callGraph, type](HeapType sub, auto _) {
169-
callGraph[type].insert(sub);
170-
return true;
171-
});
174+
// Not needed but during lookup we expect the key to exist.
175+
callGraph[type];
176+
177+
for (auto super = type.getDeclaredSuperType(); super;
178+
super = super->getDeclaredSuperType()) {
179+
if (allFunctionTypes.contains(*super)) {
180+
callGraph[*super].insert(type);
181+
}
182+
}
172183
}
173184

174185
return callGraph;
@@ -187,6 +198,31 @@ void mergeMaybeEffects(std::optional<EffectAnalyzer>& dest,
187198
dest->mergeIn(*src);
188199
}
189200

201+
template<std::ranges::common_range Range>
202+
requires std::same_as<std::ranges::range_value_t<Range>, CallGraphNode>
203+
struct CallGraphSCCs
204+
: SCCs<std::ranges::iterator_t<Range>, CallGraphSCCs<Range>> {
205+
const std::map<Function*, FuncInfo>& funcInfos;
206+
const CallGraph& callGraph;
207+
const Module& module;
208+
209+
CallGraphSCCs(
210+
Range&& nodes,
211+
const std::map<Function*, FuncInfo>& funcInfos,
212+
const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
213+
callGraph,
214+
const Module& module)
215+
: SCCs<std::ranges::iterator_t<Range>, CallGraphSCCs<Range>>(
216+
std::ranges::begin(nodes), std::ranges::end(nodes)),
217+
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
218+
219+
void pushChildren(CallGraphNode node) {
220+
for (CallGraphNode callee : callGraph.at(node)) {
221+
this->push(callee);
222+
}
223+
}
224+
};
225+
190226
// Propagate effects from callees to callers transitively
191227
// e.g. if A -> B -> C (A calls B which calls C)
192228
// Then B inherits effects from C and A inherits effects from both B and C.
@@ -200,29 +236,6 @@ void propagateEffects(const Module& module,
200236
const PassOptions& passOptions,
201237
std::map<Function*, FuncInfo>& funcInfos,
202238
const CallGraph& callGraph) {
203-
struct CallGraphSCCs
204-
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs> {
205-
const std::map<Function*, FuncInfo>& funcInfos;
206-
const CallGraph& callGraph;
207-
const Module& module;
208-
209-
CallGraphSCCs(
210-
const std::vector<CallGraphNode>& nodes,
211-
const std::map<Function*, FuncInfo>& funcInfos,
212-
const std::unordered_map<CallGraphNode,
213-
std::unordered_set<CallGraphNode>>& callGraph,
214-
const Module& module)
215-
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs>(
216-
nodes.begin(), nodes.end()),
217-
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
218-
219-
void pushChildren(CallGraphNode node) {
220-
for (CallGraphNode callee : callGraph.at(node)) {
221-
push(callee);
222-
}
223-
}
224-
};
225-
226239
// We only care about Functions that are roots, not types
227240
// A type would be a root if a function exists with that type, but no-one
228241
// indirect calls the type.
@@ -231,11 +244,16 @@ void propagateEffects(const Module& module,
231244
allFuncs.push_back(func);
232245
}
233246

234-
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
247+
auto funcNodes = std::views::keys(callGraph) |
248+
std::views::filter([](auto node) {
249+
return std::holds_alternative<Function*>(node);
250+
}) |
251+
std::views::common;
252+
CallGraphSCCs sccs(std::move(funcNodes), funcInfos, callGraph, module);
235253

236254
std::vector<std::optional<EffectAnalyzer>> componentEffects;
237255
// Points to an index in componentEffects
238-
std::unordered_map<CallGraphNode, Index> funcComponents;
256+
std::unordered_map<CallGraphNode, Index> nodeComponents;
239257

240258
for (auto ccIterator : sccs) {
241259
std::optional<EffectAnalyzer>& ccEffects =
@@ -244,7 +262,7 @@ void propagateEffects(const Module& module,
244262

245263
std::vector<Function*> ccFuncs;
246264
for (CallGraphNode node : cc) {
247-
funcComponents.emplace(node, componentEffects.size() - 1);
265+
nodeComponents.emplace(node, componentEffects.size() - 1);
248266
if (auto** func = std::get_if<Function*>(&node)) {
249267
ccFuncs.push_back(*func);
250268
}
@@ -253,7 +271,7 @@ void propagateEffects(const Module& module,
253271
std::unordered_set<int> calleeSccs;
254272
for (CallGraphNode caller : cc) {
255273
for (CallGraphNode callee : callGraph.at(caller)) {
256-
calleeSccs.insert(funcComponents.at(callee));
274+
calleeSccs.insert(nodeComponents.at(callee));
257275
}
258276
}
259277

0 commit comments

Comments
 (0)