#!/usr/bin/env python3
"""Analyze /tmp/ci_stats.csv produced by instrumented checkInvalidations.

Usage:
  1. Apply instrumentation to SimplifyLocals.cpp (see ci_instrument.patch)
  2. Build and run: wasm-opt --simplify-locals ... -o /dev/null input.wasm
  3. Run: python3 scripts/analyze_ci_stats.py [path/to/ci_stats.csv]

Each CSV row describes one checkInvalidations call: which path it took
("fast" or anything else, treated as slow), how many sinkables existed, how
many candidates were actually checked, and 0/1 flags for the effects of the
current expression.
"""

import csv
import math
import sys
from collections import Counter

CSV_PATH = sys.argv[1] if len(sys.argv) > 1 else "/tmp/ci_stats.csv"

# (CSV flag columns, label) pairs. A label is part of a row's category key
# when any of its columns is non-zero.
_EFFECT_CATEGORIES = (
    (("readsLocal", "writesLocal"), "local"),
    (("calls",), "calls"),
    (("memory",), "mem"),
    (("gc",), "gc"),
    (("trap",), "trap"),
    (("controlFlow",), "cf"),
    (("globalState",), "gs"),
)
_FLAG_COLUMNS = tuple(
    column for columns, _ in _EFFECT_CATEGORIES for column in columns
)

# (reads a local, writes a local) -> fast-path breakdown bucket.
_LOCAL_BREAKDOWN = {
    (False, False): "no_local_effects",
    (True, False): "only_reads",
    (False, True): "only_writes",
    (True, True): "reads_and_writes",
}


def percentile(sorted_values, p):
    """Return the nearest-rank *p*-th percentile of an ascending list.

    The result is the smallest element such that at least p% of the values
    are less than or equal to it. Returns 0 for an empty list.
    """
    if not sorted_values:
        return 0
    # Nearest-rank: 1-based rank ceil(p/100 * n), clamped to a valid element.
    # (Truncating n * p / 100 and using it as a 0-based index lands one
    # element too high, e.g. the "median" of two values would be the max.)
    rank = math.ceil(len(sorted_values) * p / 100)
    index = min(max(rank, 1), len(sorted_values)) - 1
    return sorted_values[index]


def collect_stats(rows):
    """Aggregate an iterable of CSV row mappings into a statistics dict.

    Every row must provide the columns "path", "sinkables", "candidates" and
    all effect flag columns; numeric fields are parsed with int(). Raises
    KeyError for a missing column and ValueError for a non-integer field.
    """
    stats = {
        "total": 0,
        "total_work": 0,
        "fast": 0,
        "slow": 0,
        "fast_sinkables": 0,
        "slow_sinkables": 0,
        "fast_candidates": 0,
        "fast_breakdown": Counter(),
        "categories": {},  # category key -> [calls, work]
        "fast_sink_sizes": [],
        "slow_sink_sizes": [],
    }

    for row in rows:
        sinkables = int(row["sinkables"])
        candidates = int(row["candidates"])
        flags = {column: int(row[column]) for column in _FLAG_COLUMNS}

        stats["total"] += 1
        stats["total_work"] += sinkables

        # Categorize by which effects are present.
        labels = sorted(
            label
            for columns, label in _EFFECT_CATEGORIES
            if any(flags[column] for column in columns)
        )
        key = "+".join(labels) if labels else "none"
        bucket = stats["categories"].setdefault(key, [0, 0])
        bucket[0] += 1
        bucket[1] += sinkables

        if row["path"] == "fast":
            stats["fast"] += 1
            stats["fast_sinkables"] += sinkables
            stats["fast_candidates"] += candidates
            if sinkables > 0:
                stats["fast_sink_sizes"].append(sinkables)
            local_use = (bool(flags["readsLocal"]), bool(flags["writesLocal"]))
            stats["fast_breakdown"][_LOCAL_BREAKDOWN[local_use]] += 1
        else:
            stats["slow"] += 1
            stats["slow_sinkables"] += sinkables
            if sinkables > 0:
                stats["slow_sink_sizes"].append(sinkables)

    # percentile() requires ascending order.
    stats["fast_sink_sizes"].sort()
    stats["slow_sink_sizes"].sort()
    return stats


def _distribution_line(sizes):
    """Format the p50/p90/p99/max summary for a sorted, non-empty list."""
    return (
        f" Sinkable distribution: p50={percentile(sizes, 50)}, "
        f"p90={percentile(sizes, 90)}, "
        f"p99={percentile(sizes, 99)}, "
        f"max={sizes[-1]}"
    )


def format_report(stats, csv_path):
    """Render the dict returned by collect_stats() as a list of text lines."""
    total = stats["total"]
    total_work = stats["total_work"]
    fast = stats["fast"]
    slow = stats["slow"]
    fast_sinkables = stats["fast_sinkables"]
    slow_sinkables = stats["slow_sinkables"]
    fast_candidates = stats["fast_candidates"]
    fast_sink_sizes = stats["fast_sink_sizes"]
    slow_sink_sizes = stats["slow_sink_sizes"]

    lines = [
        f"=== checkInvalidations analysis ({csv_path}) ===",
        f"Total calls: {total:,}",
        f"Total orderedAfter work: {total_work:,}",
        "",
    ]

    # --- By effect category, most work first ---
    lines.append(f"{'Category':<45} {'Calls':>10} {'Work':>15} {'Work%':>7}")
    lines.append("-" * 80)
    by_work = sorted(stats["categories"].items(), key=lambda item: -item[1][1])
    for key, (count, work) in by_work:
        # Hide categories that did no work unless they are very frequent.
        if work > 0 or count > 1000:
            lines.append(
                f"{key:<45} {count:>10,} {work:>15,} "
                f"{100*work/max(total_work,1):>6.1f}%"
            )
    lines.append("")

    # --- Fast path ---
    fast_nonempty = len(fast_sink_sizes)
    saved = fast_sinkables - fast_candidates
    lines += [
        f"FAST PATH: {fast:,} ({100*fast/max(total,1):.1f}%)",
        f" With sinkables>0: {fast_nonempty:,}",
        f" Avg sinkables (when >0): {fast_sinkables/max(fast_nonempty,1):.1f}",
        f" Total sinkables (work if no fast path): {fast_sinkables:,}",
        f" Total candidates actually checked: {fast_candidates:,}",
        f" Work saved: {saved:,} ({100*saved/max(fast_sinkables,1):.1f}%)",
        " Breakdown:",
    ]
    for label, count in stats["fast_breakdown"].most_common():
        lines.append(f" {label}: {count:,} ({100*count/max(fast,1):.1f}%)")
    if fast_sink_sizes:
        lines.append(_distribution_line(fast_sink_sizes))
    lines.append("")

    # --- Slow path ---
    slow_nonempty = len(slow_sink_sizes)
    lines += [
        f"SLOW PATH: {slow:,} ({100*slow/max(total,1):.1f}%)",
        f" With sinkables>0: {slow_nonempty:,}",
        f" Avg sinkables (when >0): {slow_sinkables/max(slow_nonempty,1):.1f}",
        f" Total orderedAfter checks: {slow_sinkables:,}",
    ]
    if slow_sink_sizes:
        lines.append(_distribution_line(slow_sink_sizes))
    lines.append("")

    # --- Overall ---
    opt_work = fast_candidates + slow_sinkables
    lines += [
        "OVERALL:",
        f" Without optimization: {total_work:,} orderedAfter calls",
        f" With fast path: {opt_work:,} orderedAfter calls",
        f" Reduction: {100*(1 - opt_work/max(total_work,1)):.1f}%",
    ]
    return lines


def main(csv_path=None):
    """Read the stats CSV and print the report to stdout.

    csv_path defaults to CSV_PATH (first command-line argument, or
    /tmp/ci_stats.csv). Raises OSError if the file cannot be opened.
    """
    path = CSV_PATH if csv_path is None else csv_path
    # newline="" lets the csv module do its own newline handling.
    with open(path, newline="") as f:
        stats = collect_stats(csv.DictReader(f))
    print("\n".join(format_report(stats, path)))


if __name__ == "__main__":
    main()
+namespace { +FILE* ci_logFile = nullptr; +std::mutex ci_logMutex; + +void ci_log(const char* path, + size_t sinkables, + size_t candidates, + bool readsLocal, + bool writesLocal, + bool calls, + bool memory, + bool gc, + bool trap, + bool cf, + bool globalState) { + std::lock_guard lock(ci_logMutex); + if (!ci_logFile) { + ci_logFile = fopen("/tmp/ci_stats.csv", "w"); + if (!ci_logFile) + return; + fprintf(ci_logFile, + "path,sinkables,candidates,readsLocal,writesLocal,calls,memory,gc," + "trap,controlFlow,globalState\n"); + } + fprintf(ci_logFile, + "%s,%zu,%zu,%d,%d,%d,%d,%d,%d,%d,%d\n", + path, + sinkables, + candidates, + readsLocal, + writesLocal, + calls, + memory, + gc, + trap, + cf, + globalState); +} +} // anonymous namespace + // Main class template; + using Sinkables = std::unordered_map; // locals in current linear execution trace, which we try to sink Sinkables sinkables; + // Reverse index: for each local L, tracks which sinkable keys have effects + // that read L. Used to find read-write conflicts when the current expression + // writes L. + std::unordered_map> localReadBySinkable; + + // Reverse index: for each local L, tracks which sinkable keys have effects + // that write L. A sinkable at key K always writes K, but may also write + // other locals if its value contains nested local.sets. + std::unordered_map> localWrittenBySinkable; + + // Sinkable keys whose effects include transfersControlFlow(). These must + // be invalidated whenever the current expression has side effects (including + // local writes), due to the asymmetric check in orderedBefore: + // sinkable.transfersControlFlow() && current.hasSideEffects() + // This set is usually empty since sinkables rarely transfer control flow. 
+ std::unordered_set controlFlowSinkables; + + void registerSinkable(Index key) { + auto& effects = sinkables.at(key).effects; + for (auto L : effects.localsRead) { + localReadBySinkable[L].insert(key); + } + for (auto L : effects.localsWritten) { + localWrittenBySinkable[L].insert(key); + } + if (effects.transfersControlFlow()) { + controlFlowSinkables.insert(key); + } + } + + void unregisterSinkable(Index key) { + auto it = sinkables.find(key); + if (it == sinkables.end()) { + return; + } + auto& effects = it->second.effects; + for (auto L : effects.localsRead) { + auto mapIt = localReadBySinkable.find(L); + if (mapIt != localReadBySinkable.end()) { + mapIt->second.erase(key); + if (mapIt->second.empty()) { + localReadBySinkable.erase(mapIt); + } + } + } + for (auto L : effects.localsWritten) { + auto mapIt = localWrittenBySinkable.find(L); + if (mapIt != localWrittenBySinkable.end()) { + mapIt->second.erase(key); + if (mapIt->second.empty()) { + localWrittenBySinkable.erase(mapIt); + } + } + } + controlFlowSinkables.erase(key); + } + + void clearSinkables() { + sinkables.clear(); + localReadBySinkable.clear(); + localWrittenBySinkable.clear(); + controlFlowSinkables.clear(); + } + + Sinkables takeSinkables() { + localReadBySinkable.clear(); + localWrittenBySinkable.clear(); + controlFlowSinkables.clear(); + return std::move(sinkables); + } + + void eraseSinkable(typename Sinkables::iterator it) { + unregisterSinkable(it->first); + sinkables.erase(it); + } + + void eraseSinkable(Index key) { + unregisterSinkable(key); + sinkables.erase(key); + } + + void addSinkable(Index key, Expression** currp) { + sinkables.emplace(std::pair{ + key, SinkableInfo(currp, this->getPassOptions(), *this->getModule())}); + registerSinkable(key); + } + // Information about an exit from a block: the break, and the // sinkables. For the final exit from a block (falling off) // exitter is null. 
@@ -135,8 +265,7 @@ struct SimplifyLocals // value means the block already has a return value self->unoptimizableBlocks.insert(br->name); } else { - self->blockBreaks[br->name].push_back( - {currp, std::move(self->sinkables)}); + self->blockBreaks[br->name].push_back({currp, self->takeSinkables()}); } } else if (curr->is()) { return; // handled in visitBlock @@ -153,7 +282,7 @@ struct SimplifyLocals } // TODO: we could use this info to stop gathering data on these blocks } - self->sinkables.clear(); + self->clearSinkables(); } static void doNoteIfCondition( @@ -161,7 +290,7 @@ struct SimplifyLocals Expression** currp) { // we processed the condition of this if-else, and now control flow branches // into either the true or the false sides - self->sinkables.clear(); + self->clearSinkables(); } static void @@ -170,13 +299,13 @@ struct SimplifyLocals auto* iff = (*currp)->cast(); if (iff->ifFalse) { // We processed the ifTrue side of this if-else, save it on the stack. - self->ifStack.push_back(std::move(self->sinkables)); + self->ifStack.push_back(self->takeSinkables()); } else { // This is an if without an else. 
if (allowStructure) { self->optimizeIfReturn(iff, currp); } - self->sinkables.clear(); + self->clearSinkables(); } } @@ -191,7 +320,7 @@ struct SimplifyLocals self->optimizeIfElseReturn(iff, currp, self->ifStack.back()); } self->ifStack.pop_back(); - self->sinkables.clear(); + self->clearSinkables(); } void visitBlock(Block* curr) { @@ -204,13 +333,13 @@ struct SimplifyLocals // post-block cleanups if (curr->name.is()) { if (unoptimizableBlocks.contains(curr->name)) { - sinkables.clear(); + clearSinkables(); unoptimizableBlocks.erase(curr->name); } if (hasBreaks) { // more than one path to here, so nonlinear - sinkables.clear(); + clearSinkables(); blockBreaks.erase(curr->name); } } @@ -284,7 +413,7 @@ struct SimplifyLocals // reuse the local.get that is dying *found->second.item = curr; ExpressionManipulator::nop(curr); - sinkables.erase(found); + eraseSinkable(found); anotherCycle = true; } } @@ -300,7 +429,96 @@ struct SimplifyLocals } void checkInvalidations(EffectAnalyzer& effects) { - // TODO: this is O(bad) + bool readsLocal = !effects.localsRead.empty(); + bool writesLocal = !effects.localsWritten.empty(); + bool hasCalls = effects.calls; + bool hasMemory = effects.readsMemory || effects.writesMemory; + bool hasGC = effects.readsMutableStruct || effects.writesStruct || + effects.readsMutableArray || effects.writesArray; + bool hasTrap = effects.trap; + bool hasCF = effects.transfersControlFlow(); + bool hasGlobalState = + effects.writesGlobalState() || effects.readsMutableGlobalState(); + + // Fast path: if the current expression only accesses locals (no memory, + // calls, globals, traps, control flow, etc.), we can use reverse indices + // to find conflicting sinkables in O(|locals touched|) instead of + // iterating all sinkables. + // + // Each condition below corresponds to a non-local conflict category in + // EffectAnalyzer::orderedBefore. 
When all are false, the only remaining + // conflict paths are through local variable read/write pairs, PLUS the + // asymmetric check: sinkable.transfersControlFlow() && + // current.hasSideEffects(). The latter is handled via + // controlFlowSinkables below. + if (!hasCF && !hasGlobalState && !effects.danglingPop && !hasTrap && + !effects.hasSynchronization() && !effects.mayNotReturn) { + std::unordered_set candidates; + // When the current expression reads local L, any sinkable that writes L + // has a write-read conflict. + for (auto L : effects.localsRead) { + auto it = localWrittenBySinkable.find(L); + if (it != localWrittenBySinkable.end()) { + candidates.insert(it->second.begin(), it->second.end()); + } + } + // When the current expression writes local L, any sinkable that reads L + // (read-write conflict) or writes L (write-write conflict) is a + // candidate. + for (auto L : effects.localsWritten) { + auto it = localReadBySinkable.find(L); + if (it != localReadBySinkable.end()) { + candidates.insert(it->second.begin(), it->second.end()); + } + auto it2 = localWrittenBySinkable.find(L); + if (it2 != localWrittenBySinkable.end()) { + candidates.insert(it2->second.begin(), it2->second.end()); + } + } + // Handle the asymmetric orderedBefore check: a sinkable that transfers + // control flow conflicts with any expression that has side effects + // (which includes local writes). This set is usually empty. 
+ if (effects.hasSideEffects() && !controlFlowSinkables.empty()) { + candidates.insert(controlFlowSinkables.begin(), + controlFlowSinkables.end()); + } + ci_log("fast", + sinkables.size(), + candidates.size(), + readsLocal, + writesLocal, + hasCalls, + hasMemory, + hasGC, + hasTrap, + hasCF, + hasGlobalState); + std::vector invalidated; + for (auto key : candidates) { + auto it = sinkables.find(key); + if (it != sinkables.end() && effects.orderedAfter(it->second.effects)) { + invalidated.push_back(key); + } + } + for (auto key : invalidated) { + eraseSinkable(key); + } + return; + } + + // Slow path: the expression has non-local effects, so we must check all + // sinkables. + ci_log("slow", + sinkables.size(), + sinkables.size(), + readsLocal, + writesLocal, + hasCalls, + hasMemory, + hasGC, + hasTrap, + hasCF, + hasGlobalState); std::vector invalidated; for (auto& [index, info] : sinkables) { if (effects.orderedAfter(info.effects)) { @@ -308,7 +526,7 @@ struct SimplifyLocals } } for (auto index : invalidated) { - sinkables.erase(index); + eraseSinkable(index); } } @@ -334,7 +552,7 @@ struct SimplifyLocals } } for (auto index : invalidated) { - self->sinkables.erase(index); + self->eraseSinkable(index); } } @@ -419,7 +637,7 @@ struct SimplifyLocals Drop* drop = ExpressionManipulator::convert(previous); drop->value = previousValue; drop->finalize(); - self->sinkables.erase(found); + self->eraseSinkable(found); self->anotherCycle = true; } } @@ -432,9 +650,7 @@ struct SimplifyLocals if (set && self->canSink(set)) { Index index = set->index; assert(!self->sinkables.contains(index)); - self->sinkables.emplace(std::pair{ - index, - SinkableInfo(currp, self->getPassOptions(), *self->getModule())}); + self->addSinkable(index, currp); } if (!allowNesting) { @@ -476,7 +692,13 @@ struct SimplifyLocals if (sinkables.empty()) { return; } - Index goodIndex = sinkables.begin()->first; + // Pick the lowest-index sinkable for deterministic output. 
+ Index goodIndex = std::min_element(sinkables.begin(), + sinkables.end(), + [](const auto& a, const auto& b) { + return a.first < b.first; + }) + ->first; // Ensure we have a place to write the return values for, if not, we // need another cycle. auto* block = loop->body->dynCast(); @@ -498,7 +720,7 @@ struct SimplifyLocals this->replaceCurrent(set); // We moved things around, clear all tracking; we'll do another cycle // anyhow. - sinkables.clear(); + clearSinkables(); anotherCycle = true; } @@ -515,7 +737,8 @@ struct SimplifyLocals // block does not already have a return value (if one break has one, they // all do) assert(!(*breaks[0].brp)->template cast()->value); - // look for a local.set that is present in them all + // look for a local.set that is present in them all. + // Pick the lowest index for deterministic output. bool found = false; Index sharedIndex = -1; for (auto& [index, _] : sinkables) { @@ -526,10 +749,9 @@ struct SimplifyLocals break; } } - if (inAll) { + if (inAll && (!found || index < sharedIndex)) { sharedIndex = index; found = true; - break; } } if (!found) { @@ -624,7 +846,7 @@ struct SimplifyLocals auto* newLocalSet = Builder(*this->getModule()).makeLocalSet(sharedIndex, block); this->replaceCurrent(newLocalSet); - sinkables.clear(); + clearSinkables(); anotherCycle = true; block->finalize(); } @@ -656,27 +878,35 @@ struct SimplifyLocals Sinkables& ifFalse = sinkables; Index goodIndex = -1; bool found = false; + auto pickLowest = [](Sinkables& s) { + return std::min_element( + s.begin(), + s.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }) + ->first; + }; if (iff->ifTrue->type == Type::unreachable) { // since the if type is none assert(iff->ifFalse->type != Type::unreachable); if (!ifFalse.empty()) { - goodIndex = ifFalse.begin()->first; + goodIndex = pickLowest(ifFalse); found = true; } } else if (iff->ifFalse->type == Type::unreachable) { // since the if type is none assert(iff->ifTrue->type != 
Type::unreachable); if (!ifTrue.empty()) { - goodIndex = ifTrue.begin()->first; + goodIndex = pickLowest(ifTrue); found = true; } } else { - // Look for a shared index. + // Look for a shared index (pick the lowest for determinism). for (auto& [index, _] : ifTrue) { if (ifFalse.contains(index)) { - goodIndex = index; - found = true; - break; + if (!found || index < goodIndex) { + goodIndex = index; + found = true; + } } } } @@ -799,7 +1029,13 @@ struct SimplifyLocals // element). // // TODO investigate more - Index goodIndex = sinkables.begin()->first; + // Pick the lowest-index sinkable for deterministic output. + Index goodIndex = std::min_element(sinkables.begin(), + sinkables.end(), + [](const auto& a, const auto& b) { + return a.first < b.first; + }) + ->first; auto localType = this->getFunction()->getLocalType(goodIndex); if (!localType.isDefaultable()) { return; @@ -973,7 +1209,7 @@ struct SimplifyLocals anotherCycle = true; } // clean up - sinkables.clear(); + clearSinkables(); blockBreaks.clear(); unoptimizableBlocks.clear(); return anotherCycle;