diff --git a/loom-cli/src/main.rs b/loom-cli/src/main.rs index c658622..b33c939 100644 --- a/loom-cli/src/main.rs +++ b/loom-cli/src/main.rs @@ -582,6 +582,18 @@ fn optimize_command( track_pass("inline", before, after); } + // #219: dissolve the u64 ABI carrier the inline leaves behind (scalar-forward + // the single-assignment carrier to its unpack sites, then SROA). Runs right + // after inline so the carrier is fresh; downstream dce/dead-locals reap it. + if should_run("forward-carrier") { + println!(" Running: forward-carrier"); + let before = count_instructions(&module); + loom_core::optimize::forward_carrier_locals(&mut module) + .context("Carrier forwarding failed")?; + let after = count_instructions(&module); + track_pass("forward-carrier", before, after); + } + // loom#228: whole-function (module-level) dead-function elimination. // Multi-site inlining duplicates a callee's body into each caller and // leaves the ORIGINAL orphaned; the body-level `dce` (eliminate_dead_code) diff --git a/loom-core/src/lib.rs b/loom-core/src/lib.rs index f09183e..7dae69f 100644 --- a/loom-core/src/lib.rs +++ b/loom-core/src/lib.rs @@ -6953,6 +6953,11 @@ pub mod optimize { // are inlined, enabling subsequent passes to optimize across boundaries. inline_functions(module)?; + // Phase 1b (#219): dissolve the u64 ABI carrier the inline leaves behind — + // scalar-forward the single-assignment carrier to its unpack sites so the + // SROA rules can collapse the pack/unpack round-trip. + forward_carrier_locals(module)?; + // Phase 2: Constant folding (ISLE pattern rewrites) constant_folding(module)?; @@ -8945,6 +8950,601 @@ pub mod optimize { Ok(()) } + /// #219: trap-aware structured dominance for carrier scalar-forwarding. + /// Returns true iff `local`'s SINGLE def dominates ALL its uses, so its + /// defining expression can be forwarded to every use. 
Conservative + /// structured-CFG argument (rejects on any uncertainty — soundness over + /// completeness; a wrong `true` here is a silent miscompile): + /// - exactly one write of `local`, and it is NOT inside any `If`/`Loop` + /// (all its enclosing blocks are plain `Block`/function root, hence + /// unconditionally entered); + /// - every use is textually AFTER the def; + /// - no `Br`/`BrIf`/`BrTable` textually BEFORE the def targets a block that + /// ENCLOSES the def (an "ancestor"): such a branch could deliver control + /// to the def's enclosing-block continuation — possibly a use — without + /// passing the def. Branches to inner blocks that close before the def + /// (or to the function frame = return, which terminates) are safe, as is + /// any `unreachable`/`return` before the def (they never reach a use). + /// + /// Together these make the def reached on every non-trapping path from entry, + /// so it dominates all (textually-later) uses. + pub(crate) fn carrier_def_dominates_uses(instrs: &[Instruction], local: u32) -> bool { + #[derive(Default)] + struct Scan { + seq: usize, + next_block_id: usize, + def_seq: Option<usize>, + def_ancestors: std::collections::HashSet<usize>, + def_disqualified: bool, + uses: Vec<usize>, + // (branch seq, resolved target block ids). usize::MAX = function frame + // (return-equivalent) — never reaches a use. + branches: Vec<(usize, Vec<usize>)>, + } + // Resolve a relative branch depth to the open block's id (innermost = + // depth 0); usize::MAX if it targets the function frame. 
+ fn resolve(open: &[(usize, bool)], depth: u32) -> usize { + let d = depth as usize; + if d < open.len() { + open[open.len() - 1 - d].0 + } else { + usize::MAX + } + } + fn walk(instrs: &[Instruction], s: &mut Scan, open: &mut Vec<(usize, bool)>, local: u32) { + for instr in instrs { + s.seq += 1; + match instr { + Instruction::LocalSet(i) | Instruction::LocalTee(i) if *i == local => { + if s.def_seq.is_some() { + s.def_disqualified = true; // more than one write + } else { + s.def_seq = Some(s.seq); + s.def_ancestors = open.iter().map(|(id, _)| *id).collect(); + if open.iter().any(|(_, conditional)| *conditional) { + s.def_disqualified = true; // def under an If/Loop + } + } + } + Instruction::LocalGet(i) if *i == local => s.uses.push(s.seq), + Instruction::Br(d) | Instruction::BrIf(d) => { + s.branches.push((s.seq, vec![resolve(open, *d)])); + } + Instruction::BrTable { targets, default } => { + let mut ids: Vec<usize> = + targets.iter().map(|d| resolve(open, *d)).collect(); + ids.push(resolve(open, *default)); + s.branches.push((s.seq, ids)); + } + Instruction::Block { body, .. } => { + let id = s.next_block_id; + s.next_block_id += 1; + open.push((id, false)); + walk(body, s, open, local); + open.pop(); + } + Instruction::Loop { body, .. } => { + let id = s.next_block_id; + s.next_block_id += 1; + open.push((id, true)); // back-edge → conservative + walk(body, s, open, local); + open.pop(); + } + Instruction::If { + then_body, + else_body, + .. + } => { + let id = s.next_block_id; + s.next_block_id += 1; + open.push((id, true)); + walk(then_body, s, open, local); + walk(else_body, s, open, local); + open.pop(); + } + _ => {} + } + } + } + let mut s = Scan::default(); + let mut open: Vec<(usize, bool)> = Vec::new(); + walk(instrs, &mut s, &mut open, local); + + let def_seq = match s.def_seq { + Some(seq) if !s.def_disqualified => seq, + _ => return false, + }; + // Every use must be after the def. 
+ if s.uses.iter().any(|&u| u <= def_seq) { + return false; + } + // No pre-def branch may escape to a def-ancestor block. + for (bseq, tgts) in &s.branches { + if *bseq < def_seq && tgts.iter().any(|t| s.def_ancestors.contains(t)) { + return false; + } + } + true + } + + /// #219 STEP 1 — is this instruction a pure, side-effect-free, integer-domain + /// value computation whose term round-trip is faithful? + /// + /// The term IR is a value-expression stack model: control-flow ops (`br_if`/ + /// `br_table`) are pushed as value terms, so a mid-block branch followed by + /// trailing stack computation does NOT round-trip (`instructions_to_terms` → + /// `terms_to_instructions` is lossy — see #219). But a maximal contiguous run + /// of these *pure* ops round-trips faithfully, so we can safely apply the + /// structural rewriter (`rewrite_pure`) to such a "window". We restrict to the + /// integer domain (no floats — avoids NaN-bit canonicalization questions) and + /// exclude memory access, calls, control flow, and local/global writes. 
+ fn is_simple_pure_instr(instr: &Instruction) -> bool { + use Instruction::*; + matches!( + instr, + // Constants (integer only) + I32Const(_) | I64Const(_) + // Reads (no side effects, no writes) + | LocalGet(_) | GlobalGet(_) + // Integer binary ops (consume 2, produce 1) + | I32Add | I32Sub | I32Mul | I32DivS | I32DivU | I32RemS | I32RemU + | I32And | I32Or | I32Xor | I32Shl | I32ShrS | I32ShrU | I32Rotl | I32Rotr + | I32Eq | I32Ne | I32LtS | I32LtU | I32GtS | I32GtU | I32LeS | I32LeU | I32GeS | I32GeU + | I64Add | I64Sub | I64Mul | I64DivS | I64DivU | I64RemS | I64RemU + | I64And | I64Or | I64Xor | I64Shl | I64ShrS | I64ShrU | I64Rotl | I64Rotr + | I64Eq | I64Ne | I64LtS | I64LtU | I64GtS | I64GtU | I64LeS | I64LeU | I64GeS | I64GeU + // Integer unary ops (consume 1, produce 1) + | I32Eqz | I32Clz | I32Ctz | I32Popcnt + | I64Eqz | I64Clz | I64Ctz | I64Popcnt + | I32WrapI64 | I64ExtendI32S | I64ExtendI32U + | I32Extend8S | I32Extend16S | I64Extend8S | I64Extend16S | I64Extend32S + // Select (consume 3, produce 1) — pure + | Select + ) + } + + /// Net stack effect (produced − consumed) of an instruction sequence. + fn net_stack_effect(instrs: &[Instruction]) -> i32 { + instrs + .iter() + .map(|i| { + let (c, p) = instruction_stack_io(i); + p - c + }) + .sum() + } + + /// Total instruction count, recursing into nested bodies. Used as a + /// never-pessimize guard: forwarding that does not shrink the function is + /// reverted (a single-use carrier whose RHS does not dissolve would + /// otherwise grow the code). + fn count_instrs(instrs: &[Instruction]) -> usize { + let mut n = 0; + for i in instrs { + n += 1; + match i { + Instruction::Block { body, .. } | Instruction::Loop { body, .. } => { + n += count_instrs(body); + } + Instruction::If { + then_body, + else_body, + .. + } => { + n += count_instrs(then_body) + count_instrs(else_body); + } + _ => {} + } + } + n + } + + /// #219 STEP 1 — rewrite a single pure window through the term IR. 
+ /// + /// Returns `Some(new_instrs)` only when the window is stack-closed + /// (`instructions_to_terms` succeeds — no underflow, i.e. it does not consume + /// values produced before the window) AND the rewrite preserves the net stack + /// effect. Otherwise `None` (leave the window untouched — conservative). + fn rewrite_pure_window(window: &[Instruction]) -> Option<Vec<Instruction>> { + use loom_isle::rewrite_pure; + let terms = super::terms::instructions_to_terms(window).ok()?; + if terms.is_empty() { + return None; + } + let rewritten: Vec<_> = terms.into_iter().map(rewrite_pure).collect(); + let new_instrs = super::terms::terms_to_instructions(&rewritten).ok()?; + if net_stack_effect(window) != net_stack_effect(&new_instrs) { + return None; + } + Some(new_instrs) + } + + /// #219 STEP 1 — WINDOWED SROA. Apply the structural rewriter (`rewrite_pure`) + /// to every maximal pure straight-line window in `instrs`, recursing into + /// nested bodies first. This is the SROA path that survives `br_table` + /// functions like `z_impl`: it NEVER round-trips a control-flow instruction + /// through the (lossy) term IR — only self-contained pure runs, which are + /// faithful. `rewrite_pure` is the proven structural rule set, so each window + /// rewrite is semantics-preserving; the net-stack-effect guard in + /// `rewrite_pure_window` keeps the surrounding stack shape intact. + pub(crate) fn simplify_pure_windows(instrs: &mut Vec<Instruction>) { + // Recurse into nested bodies first. + for instr in instrs.iter_mut() { + match instr { + Instruction::Block { body, .. } | Instruction::Loop { body, .. } => { + simplify_pure_windows(body); + } + Instruction::If { + then_body, + else_body, + .. + } => { + simplify_pure_windows(then_body); + simplify_pure_windows(else_body); + } + _ => {} + } + } + + // Rewrite maximal pure windows in THIS body. 
+ let mut out: Vec<Instruction> = Vec::with_capacity(instrs.len()); + let mut i = 0; + while i < instrs.len() { + if is_simple_pure_instr(&instrs[i]) { + let start = i; + while i < instrs.len() && is_simple_pure_instr(&instrs[i]) { + i += 1; + } + let window = &instrs[start..i]; + match rewrite_pure_window(window) { + Some(rewritten) => out.extend(rewritten), + None => out.extend_from_slice(window), + } + } else { + out.push(instrs[i].clone()); + i += 1; + } + } + *instrs = out; + } + + /// #219 STEP 3 — is the value stored by the def at `instrs[def_idx]` + /// provably `< 2^32` (high 32 bits always zero)? + /// + /// Conservative, purely local: looks only at the instruction(s) immediately + /// preceding the `local.set`/`local.tee`. A stored value is narrow when it is + /// `i64.extend_i32_u(_)` (a zero-extend always fits in 32 bits), an + /// `i64.const c` with the high 32 bits clear, or `(_ & i64.const c)` with `c` + /// fitting in 32 bits (the mask clears bits ≥ 32). Anything else → not narrow. + fn def_value_is_narrow(instrs: &[Instruction], def_idx: usize) -> bool { + if def_idx == 0 { + return false; + } + match &instrs[def_idx - 1] { + Instruction::I64ExtendI32U => true, + Instruction::I64Const(c) => (*c as u64) >> 32 == 0, + // `_; i64.const c; i64.and` — the AND's RHS const (pushed last) + // sits two before the set. A mask < 2^32 bounds the result. + Instruction::I64And => { + def_idx >= 2 + && matches!( + &instrs[def_idx - 2], + Instruction::I64Const(c) if (*c as u64) >> 32 == 0 + ) + } + _ => false, + } + } + + /// #219 STEP 3 — the set of locals provably `< 2^32` (high 32 bits always + /// zero): every write of the local stores a narrow value + /// (`def_value_is_narrow`), recursing into nested bodies. A local with no + /// write, or any non-narrow write, is excluded (conservative). 
+ pub(crate) fn narrow_locals(instrs: &[Instruction]) -> std::collections::HashSet<u32> { + use std::collections::HashSet; + fn scan(instrs: &[Instruction], written: &mut HashSet<u32>, not_narrow: &mut HashSet<u32>) { + for (i, instr) in instrs.iter().enumerate() { + match instr { + Instruction::LocalSet(n) | Instruction::LocalTee(n) => { + written.insert(*n); + if !def_value_is_narrow(instrs, i) { + not_narrow.insert(*n); + } + } + Instruction::Block { body, .. } | Instruction::Loop { body, .. } => { + scan(body, written, not_narrow); + } + Instruction::If { + then_body, + else_body, + .. + } => { + scan(then_body, written, not_narrow); + scan(else_body, written, not_narrow); + } + _ => {} + } + } + } + // NOTE(review): only writes visible in `instrs` are inspected. A function + // parameter also carries a caller-supplied entry value that no write here + // covers, so a parameter read before its first (narrow) write could be + // misclassified — confirm callers never rely on this for parameter indices, + // or filter them out. Declared (non-parameter) locals start at zero, which + // is narrow. + let mut written = HashSet::new(); + let mut not_narrow = HashSet::new(); + scan(instrs, &mut written, &mut not_narrow); + written.difference(&not_narrow).copied().collect() + } + + /// #219 STEP 3 — fold `(i64.shr_u (local.get N) k)` → `i64.const 0` when `N` + /// is narrow (`< 2^32`) and the effective shift `k mod 64 ≥ 32`, recursing + /// into nested bodies. + /// + /// Shifting a value whose high 32 bits are zero right by ≥ 32 yields 0. This + /// is the value-range fact pure structural SROA can't see; supplying it here + /// lets the subsequent `simplify_pure_windows` collapse `(or P 0) → P`, + /// completing the seam unpack `(extend(a)<<32 | status) >> 32 → extend(a)`. + pub(crate) fn fold_narrow_high_shr( + instrs: &mut Vec<Instruction>, + narrow: &std::collections::HashSet<u32>, + ) { + for instr in instrs.iter_mut() { + match instr { + Instruction::Block { body, .. } | Instruction::Loop { body, .. } => { + fold_narrow_high_shr(body, narrow); + } + Instruction::If { + then_body, + else_body, + .. 
+ } => { + fold_narrow_high_shr(then_body, narrow); + fold_narrow_high_shr(else_body, narrow); + } + _ => {} + } + } + + let mut out: Vec<Instruction> = Vec::with_capacity(instrs.len()); + let mut i = 0; + while i < instrs.len() { + if i + 2 < instrs.len() { + if let (Instruction::LocalGet(n), Instruction::I64Const(k), Instruction::I64ShrU) = + (&instrs[i], &instrs[i + 1], &instrs[i + 2]) + { + if narrow.contains(n) && (*k as u64 & 63) >= 32 { + out.push(Instruction::I64Const(0)); + i += 3; + continue; + } + } + } + out.push(instrs[i].clone()); + i += 1; + } + *instrs = out; + } + + /// #219 STEP 2 — locate the single def of `carrier` and extract its pure + /// defining expression (the RHS), recursing into nested bodies. + /// + /// Returns `(rhs_instrs, is_tee)` where `rhs_instrs` is the maximal pure, + /// stack-closed, single-value (`net == +1`) instruction run immediately + /// preceding the def — i.e. exactly the expression the `local.set`/`local.tee` + /// stores. Returns `None` if the def isn't found or its RHS isn't a pure, + /// self-contained single value. The caller has already established (via + /// `carrier_def_dominates_uses`) that `carrier` is single-assignment and its + /// def is not under an `If`/`Loop`, so there is exactly one def to find. + fn extract_carrier_rhs( + instrs: &[Instruction], + carrier: u32, + ) -> Option<(Vec<Instruction>, bool)> { + for (d, instr) in instrs.iter().enumerate() { + match instr { + Instruction::LocalSet(i) | Instruction::LocalTee(i) if *i == carrier => { + let is_tee = matches!(instr, Instruction::LocalTee(_)); + // Maximal pure run ending at the def. + let mut run_start = d; + while run_start > 0 && is_simple_pure_instr(&instrs[run_start - 1]) { + run_start -= 1; + } + // Longest suffix [s..d] that is a single self-contained value + // (round-trips through the term IR as exactly one term). 
+ for s in run_start..d { + let cand = &instrs[s..d]; + if let Ok(terms) = super::terms::instructions_to_terms(cand) { + if terms.len() == 1 && net_stack_effect(cand) == 1 { + // Integer div/rem trap (zero divisor; signed-div overflow), so + // such an RHS must not be re-evaluated at a use or deleted with + // the def — refuse to forward it, matching the div/rem exclusion + // documented on loom-shared's `is_forwardable_expr`. + let may_trap = cand.iter().any(|op| { + matches!( + op, + Instruction::I32DivS + | Instruction::I32DivU + | Instruction::I32RemS + | Instruction::I32RemU + | Instruction::I64DivS + | Instruction::I64DivU + | Instruction::I64RemS + | Instruction::I64RemU + ) + }); + if may_trap { + return None; + } + return Some((cand.to_vec(), is_tee)); + } + } + } + return None; + } + Instruction::Block { body, .. } | Instruction::Loop { body, .. } => { + if let Some(r) = extract_carrier_rhs(body, carrier) { + return Some(r); + } + } + Instruction::If { + then_body, + else_body, + .. + } => { + if let Some(r) = extract_carrier_rhs(then_body, carrier) + .or_else(|| extract_carrier_rhs(else_body, carrier)) + { + return Some(r); + } + } + _ => {} + } + } + None + } + + /// #219 STEP 2 — forward `carrier`'s pure RHS to every use and drop the def, + /// recursing into nested bodies. + /// + /// - Each `local.get carrier` is replaced by a clone of `rhs` (pure ⇒ + /// re-evaluating at the use is sound). + /// - The def is removed: a `local.tee` is dropped but its RHS run is LEFT in + /// place (its value stays on the stack for the original consumer); a + /// `local.set` is dropped together with its preceding RHS run (now-dead pure + /// computation — sound to delete). + /// + /// The carrier local becomes dead afterwards; later DCE/dead-local passes + /// remove its declaration. The exposed pack/unpack is dissolved by a + /// subsequent `simplify_pure_windows`. + /// + /// NOTE(review): nothing in this function or in `extract_carrier_rhs` checks + /// that the locals/globals read by `rhs` still hold their def-time values at + /// each use (no reaching-defs guard) — confirm the caller rules out an + /// intervening write before relying on this for new carriers. + fn forward_carrier_in_body( + body: &mut Vec<Instruction>, + carrier: u32, + rhs: &[Instruction], + is_tee: bool, + ) { + // Recurse into nested bodies first so uses there are substituted. + for instr in body.iter_mut() { + match instr { + Instruction::Block { body: b, .. } | Instruction::Loop { body: b, .. } => { + forward_carrier_in_body(b, carrier, rhs, is_tee); + } + Instruction::If { + then_body, + else_body, + .. + } => { + forward_carrier_in_body(then_body, carrier, rhs, is_tee); + forward_carrier_in_body(else_body, carrier, rhs, is_tee); + } + _ => {} + } + } + + let mut out: Vec<Instruction> = Vec::with_capacity(body.len()); + for instr in body.drain(..) 
{ + match &instr { + Instruction::LocalGet(i) if *i == carrier => { + out.extend_from_slice(rhs); + } + Instruction::LocalSet(i) if *i == carrier && !is_tee => { + // Drop the dead set and its preceding RHS run. + let drop_n = rhs.len().min(out.len()); + out.truncate(out.len() - drop_n); + } + Instruction::LocalTee(i) if *i == carrier && is_tee => { + // Drop the tee; its RHS value stays on the stack for the + // original consumer. + } + _ => out.push(instr), + } + } + *body = out; + } + + /// #219 carrier scalar-forwarding (the perf milestone): forward a TOP-LEVEL + /// single-assignment local's pure defining expression to its use sites, then + /// let the committed SROA rules dissolve the now-exposed pack/unpack. This + /// removes the u64 ABI carrier the proven `br_table` seam inline leaves inside + /// `z_impl` (decide builds `extend_i32_u<<32 | status`, z_impl tears it back + /// with `&0xff` / `>>32` through a dead i64 carrier). + /// + /// SOUNDNESS (this transform is NOT Z3-backstoppable for the void seam fn — + /// the value flows only into a havoc'd impure-call arg → vacuous pass; gale's + /// G474RE silicon is the behavioral gate, #219): + /// - We forward ONLY locals with exactly one write whose def DOMINATES all + /// uses, by `carrier_def_dominates_uses` (trap-aware structured + /// dominance: def not under If/Loop, all uses after it, no pre-def branch + /// escapes the def's enclosing blocks). The def is then reached on every + /// non-trapping path from entry ⇒ dominates every use. + /// - The forwarding is done at the INSTRUCTION level + /// (`extract_carrier_rhs` + `forward_carrier_in_body`), NOT through the + /// term IR: the term round-trip is lossy for control flow (it pushes + /// `br_if`/`br_table` as value terms), so any term-based forwarding turns + /// a `br_table` function like `z_impl` stack-invalid (#219). 
Only the + /// carrier's pure RHS — itself a self-contained straight-line run — is + /// ever taken through terms, and only via `simplify_pure_windows` for the + /// final SROA dissolution, which never round-trips control flow. + /// - `extract_carrier_rhs` admits only a pure (`is_simple_pure_instr`) + /// single-value run, so re-evaluating the expression at a use is safe. + /// + /// `verify_or_revert` still guards non-void functions and stack correctness. + pub fn forward_carrier_locals(module: &mut Module) -> Result<()> { + use crate::stack::validation::{ValidationContext, ValidationGuard}; + use crate::verify::{TranslationValidator, VerificationSignatureContext}; + use std::collections::HashSet; + + let ctx = ValidationContext::from_module(module); + let verify_sig_ctx = VerificationSignatureContext::from_module(module); + + for func in &mut module.functions { + if has_unknown_instructions(func) || has_unsupported_isle_instructions(func) { + continue; + } + + // Precompute the forwardable carriers: single-assignment locals whose + // (one) def DOMINATES all uses by trap-aware structured dominance — + // see carrier_def_dominates_uses. (analyze_locals' `sets.len()==1` is + // the fast pre-filter; the dominance check is the soundness gate.) 
+ let (usage, _equiv) = analyze_locals(&func.instructions); + let single_assign: HashSet<u32> = usage + .iter() + .filter(|(idx, u)| { + u.sets.len() == 1 && carrier_def_dominates_uses(&func.instructions, **idx) + }) + .map(|(idx, _)| *idx) + .collect(); + if single_assign.is_empty() { + continue; + } + + let guard = ValidationGuard::with_context(func, "forward_carrier_locals", ctx.clone()); + let translator = TranslationValidator::new_with_context( + func, + "forward_carrier_locals", + verify_sig_ctx.clone(), + ); + let original_instructions = func.instructions.clone(); + + // Instruction-level forwarding (br_table-safe): for each forwardable + // carrier, extract its pure single-value RHS, substitute it into every + // use, and drop the dead def. Process carriers deterministically. + let mut carriers: Vec<u32> = single_assign.into_iter().collect(); + carriers.sort_unstable(); + for carrier in carriers { + if let Some((rhs, is_tee)) = extract_carrier_rhs(&func.instructions, carrier) { + forward_carrier_in_body(&mut func.instructions, carrier, &rhs, is_tee); + } + } + // Dissolve the exposed pack/unpack. Iterate two passes to a fixpoint: + // - simplify_pure_windows runs rewrite_pure on pure windows, which + // DISTRIBUTES the unpack: (or P Q)>>32 → (P>>32) | (Q>>32), + // EXPOSING a bare `status>>32`; + // - fold_narrow_high_shr then supplies the value-range fact pure + // SROA can't see — (shr_u (local.get N) k≥32) → 0 for narrow + // (< 2^32) locals — letting the next window collapse (or P 0)→P. + // The bare shr_u only appears AFTER the first window pass distributes, + // so neither order alone suffices; iterate until stable (capped). 
+ for _ in 0..4 { + let before = func.instructions.clone(); + let narrow = narrow_locals(&func.instructions); + fold_narrow_high_shr(&mut func.instructions, &narrow); + simplify_pure_windows(&mut func.instructions); + if func.instructions == before { + break; + } + } + + // Revert unless: stack/verify pass AND the function did not grow + // (never pessimize — forwarding that fails to dissolve is dropped). + let grew = count_instrs(&func.instructions) > count_instrs(&original_instructions); + if grew || guard.validate(func).is_err() || translator.verify(func).is_err() { + if grew { + crate::stats::record_revert("forward_carrier_locals_no_shrink"); + } else { + eprintln!("forward_carrier_locals: reverting function (verification rejected)"); + crate::stats::record_revert("forward_carrier_locals"); + } + func.instructions = original_instructions; + } + } + Ok(()) + } + #[derive(Debug, Clone)] struct LocalUsage { // Positions where this local is read (local.get) @@ -14677,6 +15277,195 @@ mod tests { // Just test that ISLE types are accessible } + // #219: trap-aware structured dominance for carrier forwarding. + #[test] + fn test_carrier_dominance() { + use crate::optimize::carrier_def_dominates_uses; + let dom = |wat: &str, local: u32| -> bool { + let m = parse::parse_wat(wat).unwrap(); + carrier_def_dominates_uses(&m.functions[0].instructions, local) + }; + + // Straight-line: def (local 1) then use → dominates. + assert!( + dom( + "(module (func (param i32) (result i32) (local i32) + local.get 0 local.set 1 local.get 1))", + 1 + ), + "straight-line def-before-use must dominate" + ); + + // Use before def → not dominated. + assert!( + !dom( + "(module (func (param i32) (result i32) (local i32) + local.get 1 drop local.get 0 local.set 1 local.get 1))", + 1 + ), + "a use before the def must NOT be dominated" + ); + + // Def inside an If branch → not dominated (conditional). 
+ assert!( + !dom( + "(module (func (param i32) (result i32) (local i32) + local.get 0 + if (result i32) local.get 0 local.set 1 i32.const 1 else i32.const 0 end + local.get 1 i32.add))", + 1 + ), + "a def inside an If branch must NOT be dominated" + ); + + // Seam shape: def nested in a block, reached after an inner dispatch + // block whose branch targets that INNER block (does not escape), use + // after the outer block via a br_if positioned AFTER the def → dominates. + assert!( + dom( + "(module (func (param i32) (result i32) (local i32) + block (result i32) + block + local.get 0 + br_if 0 + end + local.get 0 + local.set 1 + local.get 0 + br_if 1 + i32.const 7 + end + local.get 1 + i32.add))", + 1 + ), + "nested def reached past an inner-only branch, used after, must dominate" + ); + + // Escape shape: a branch BEFORE the def targets the def's ENCLOSING block + // (exits past where the def lives) → not dominated. + assert!( + !dom( + "(module (func (param i32) (result i32) (local i32) + block (result i32) + local.get 0 + br_if 0 + local.get 0 + local.set 1 + i32.const 7 + end + local.get 1 + i32.add))", + 1 + ), + "a pre-def branch escaping the def's enclosing block must NOT dominate" + ); + } + + // #219 STEP 1: windowed SROA — rewrite_pure applied to maximal pure + // straight-line windows, br_table-safe (never round-trips control flow). + #[test] + fn test_windowed_sroa() { + use crate::optimize::simplify_pure_windows; + + // A pure window dissolves: wrap_i64(extend_i32_u(x)) = x (Z3-validated + // seam-SROA rule). The 3-instruction window collapses to a single get. 
+ let mut w = vec![ + Instruction::LocalGet(0), + Instruction::I64ExtendI32U, + Instruction::I32WrapI64, + ]; + simplify_pure_windows(&mut w); + assert_eq!( + w, + vec![Instruction::LocalGet(0)], + "pure window wrap(extend_u(x)) must dissolve to x" + ); + + // A window ADJACENT to a br_table is rewritten on its pure parts only; + // the br_table itself (and block structure) is preserved untouched — + // the lossy term round-trip is never invoked on control flow. + let inner = vec![ + // pure prefix: wrap_i64(extend_i32_u(local 0)), which supplies the br_table index + Instruction::LocalGet(0), + Instruction::I64ExtendI32U, + Instruction::I32WrapI64, // dissolves to local.get 0 + Instruction::BrTable { + targets: vec![0, 0], + default: 0, + }, + ]; + let mut body = vec![Instruction::Block { + block_type: BlockType::Empty, + body: inner, + }]; + simplify_pure_windows(&mut body); + // The Block + BrTable survive; the pure prefix inside is simplified. + match &body[0] { + Instruction::Block { body: b, .. } => { + assert_eq!(b[0], Instruction::LocalGet(0), "pure prefix simplified"); + assert!( + matches!(b.last(), Some(Instruction::BrTable { .. })), + "br_table preserved untouched, not round-tripped" + ); + assert_eq!(b.len(), 2, "window collapsed to [local.get 0, br_table]"); + } + other => panic!("expected Block, got {other:?}"), + } + } + + // #219 STEP 3: narrow-local analysis + high-shift fold. + #[test] + fn test_narrow_locals_and_high_shr_fold() { + use crate::optimize::{fold_narrow_high_shr, narrow_locals}; + + // local 0: every def narrow (extend_u, then small const) → narrow. + // local 1: a def with the high 32 bits set (0x1_0000_0000) → NOT narrow. + // local 2: def via (_ & 0xff) const mask → narrow. 
+ let instrs = vec![ + Instruction::LocalGet(9), + Instruction::I64ExtendI32U, + Instruction::LocalSet(0), + Instruction::I64Const(1), + Instruction::LocalSet(0), // local 0: both defs narrow + Instruction::I64Const(0x1_0000_0000), + Instruction::LocalSet(1), // local 1: wide const → not narrow + Instruction::LocalGet(9), + Instruction::I64ExtendI32U, + Instruction::I64Const(0xff), + Instruction::I64And, + Instruction::LocalSet(2), // local 2: masked by 0xff → narrow + ]; + let narrow = narrow_locals(&instrs); + assert!( + narrow.contains(&0), + "local 0 (extend + small const) is narrow" + ); + assert!(!narrow.contains(&1), "local 1 (wide const) is NOT narrow"); + assert!(narrow.contains(&2), "local 2 (& 0xff) is narrow"); + + // Fold: narrow local's >>32 → 0; non-narrow local's >>32 untouched. + let mut body = vec![ + Instruction::LocalGet(0), + Instruction::I64Const(32), + Instruction::I64ShrU, // narrow → folds to const 0 + Instruction::LocalGet(1), + Instruction::I64Const(32), + Instruction::I64ShrU, // not narrow → preserved + ]; + fold_narrow_high_shr(&mut body, &narrow); + assert_eq!( + body, + vec![ + Instruction::I64Const(0), + Instruction::LocalGet(1), + Instruction::I64Const(32), + Instruction::I64ShrU, + ], + "narrow >>32 folds to 0; non-narrow >>32 preserved" + ); + } + #[test] fn test_parse_wat_simple() { let wat = r#" @@ -18052,28 +18841,24 @@ mod tests { #[test] fn test_seam_sroa_shr_extracts_high_field() { // #219 seam-SROA: (shr_u (or (shl (extend_u x) 32) const) 32) extracts the - // HIGH field of a u64 pack → (extend_u x) & 0xffffffff. The low (const) - // field shifts out (logical shift distributes over OR; shl-then-shr-same - // masks to the low 64-k bits). Mirrors the sem decide whose low field is - // a constant 0/1. + // HIGH field of a u64 pack → (extend_u x). 
The low (const) field shifts + // out (logical shift distributes over OR; shl-then-shr-same masks to the + // low 64-k bits, leaving (extend_u x) & 0xffffffff, which the extend-mask + // rule collapses to (extend_u x) since a zero-extend already fits in 32 + // bits). Mirrors the sem decide whose low field is a constant 0/1. use loom_isle::{ - Imm64, i64_extend_i32_u, iand64, iconst64, ior64, ishl64, ishru64, local_get, - rewrite_pure, + Imm64, i64_extend_i32_u, iconst64, ior64, ishl64, ishru64, local_get, rewrite_pure, }; let high = ishl64(i64_extend_i32_u(local_get(0)), iconst64(Imm64(32))); let pack = ior64(high, iconst64(Imm64(1))); // low field = const (like local 3 = 0/1) let unpacked = ishru64(pack, iconst64(Imm64(32))); let simplified = rewrite_pure(unpacked); - let want = terms::terms_to_instructions(&[iand64( - i64_extend_i32_u(local_get(0)), - iconst64(Imm64(0xffff_ffff)), - )]) - .unwrap(); + let want = terms::terms_to_instructions(&[i64_extend_i32_u(local_get(0))]).unwrap(); let got = terms::terms_to_instructions(&[simplified]).unwrap(); assert_eq!( got, want, - "#219: (shr_u (or (shl (extend_u x) 32) const) 32) must extract (extend_u x) & 0xffffffff" + "#219: (shr_u (or (shl (extend_u x) 32) const) 32) must extract (extend_u x)" ); } diff --git a/loom-core/tests/optimization_tests.rs b/loom-core/tests/optimization_tests.rs index f564aec..8706c43 100644 --- a/loom-core/tests/optimization_tests.rs +++ b/loom-core/tests/optimization_tests.rs @@ -1155,13 +1155,19 @@ fn test_rse_different_locals() { let wasm_bytes = encode::encode_wasm(&module).expect("Failed to encode"); wasmparser::validate(&wasm_bytes).expect("Generated WASM is invalid"); - // Only $y store remains (local $x is completely eliminated via optimization cascade) + // After the #219 merge the library pipeline (optimize_module now runs + // forward_carrier_locals + main's stronger simplify_locals) folds this whole + // function to its constant result: x=30, y=20, x+y == 50. 
Both locals are + // eliminated and the add is constant-folded, leaving `[I32Const(50)]` — 0 + // stores. This is strictly more aggressive than the old "1 store ($y) + // survives" expectation and is provably equivalent (30 + 20 == 50); the + // behavioral differential gate certifies semantics PRESERVED on this module. let instructions_str = format!("{:?}", module.functions[0].instructions); let store_count = instructions_str.matches("LocalSet").count() + instructions_str.matches("LocalTee").count(); assert_eq!( - store_count, 1, - "Expected 1 store ($y only - $x eliminated), got {} in: {:?}", + store_count, 0, + "Expected 0 stores (both locals eliminated, x+y folded to const 50), got {} in: {:?}", store_count, module.functions[0].instructions ); } diff --git a/loom-shared/src/lib.rs b/loom-shared/src/lib.rs index 2cce36e..2fa7271 100644 --- a/loom-shared/src/lib.rs +++ b/loom-shared/src/lib.rs @@ -2712,6 +2712,18 @@ pub struct OptimizationEnv { pub locals: std::collections::HashMap, /// Memory state: location → stored value pub memory: std::collections::HashMap, + /// #219 carrier scalar-forwarding: indices of SINGLE-ASSIGNMENT locals + /// (exactly one write in the whole function, precomputed by the caller) whose + /// pure defining expression may be forwarded to its uses. Empty by default → + /// forwarding OFF, behavior identical to plain dataflow. + pub single_assign: std::collections::HashSet, + /// #219: captured defining expression for each `single_assign` local, recorded + /// at its (single) `local.set`/`local.tee`. SURVIVES control-flow `clear()` + /// (sound: a single-assignment local has one value on every path that reaches + /// a use). INVALIDATED when any input local of the expression is reassigned + /// (reaching-defs guard), so a forwarded expression always carries its + /// def-time inputs. `local.get` of such a local forwards to this expression. 
+ pub pinned: std::collections::HashMap, } impl Default for OptimizationEnv { @@ -2725,6 +2737,8 @@ impl OptimizationEnv { OptimizationEnv { locals: std::collections::HashMap::new(), memory: std::collections::HashMap::new(), + single_assign: std::collections::HashSet::new(), + pinned: std::collections::HashMap::new(), } } @@ -2737,6 +2751,162 @@ impl OptimizationEnv { /// Legacy type alias for compatibility pub type LocalEnv = OptimizationEnv; +/// #219 carrier scalar-forwarding: is `v` a side-effect-free, re-evaluatable +/// value expression safe to FORWARD to a single-assignment local's use sites? +/// Whitelist of total pure ops (constants, locals, globals, integer +/// arithmetic/bitwise/shift/rotate/compare, width converts, select); recurses +/// into operands. Calls, ANY load/store, div/rem (trap), floats, blocks/branches +/// → `false`. Bounded to the listed variants (NOT a full 187-variant walk): +/// unknown/impure → `false` (conservative, sound). +#[allow(clippy::match_like_matches_macro)] +fn is_forwardable_expr(v: &Value) -> bool { + match v.data() { + ValueData::I32Const { .. } + | ValueData::I64Const { .. } + | ValueData::LocalGet { .. } + | ValueData::GlobalGet { .. 
} => true, + ValueData::I32WrapI64 { val } + | ValueData::I64ExtendI32S { val } + | ValueData::I64ExtendI32U { val } + | ValueData::I32Eqz { val } + | ValueData::I64Eqz { val } => is_forwardable_expr(val), + ValueData::I32Add { lhs, rhs } + | ValueData::I32Sub { lhs, rhs } + | ValueData::I32Mul { lhs, rhs } + | ValueData::I32And { lhs, rhs } + | ValueData::I32Or { lhs, rhs } + | ValueData::I32Xor { lhs, rhs } + | ValueData::I32Shl { lhs, rhs } + | ValueData::I32ShrS { lhs, rhs } + | ValueData::I32ShrU { lhs, rhs } + | ValueData::I32Rotl { lhs, rhs } + | ValueData::I32Rotr { lhs, rhs } + | ValueData::I32Eq { lhs, rhs } + | ValueData::I32Ne { lhs, rhs } + | ValueData::I32LtS { lhs, rhs } + | ValueData::I32LtU { lhs, rhs } + | ValueData::I32GtS { lhs, rhs } + | ValueData::I32GtU { lhs, rhs } + | ValueData::I32LeS { lhs, rhs } + | ValueData::I32LeU { lhs, rhs } + | ValueData::I32GeS { lhs, rhs } + | ValueData::I32GeU { lhs, rhs } + | ValueData::I64Add { lhs, rhs } + | ValueData::I64Sub { lhs, rhs } + | ValueData::I64Mul { lhs, rhs } + | ValueData::I64And { lhs, rhs } + | ValueData::I64Or { lhs, rhs } + | ValueData::I64Xor { lhs, rhs } + | ValueData::I64Shl { lhs, rhs } + | ValueData::I64ShrS { lhs, rhs } + | ValueData::I64ShrU { lhs, rhs } + | ValueData::I64Rotl { lhs, rhs } + | ValueData::I64Rotr { lhs, rhs } + | ValueData::I64Eq { lhs, rhs } + | ValueData::I64Ne { lhs, rhs } + | ValueData::I64LtS { lhs, rhs } + | ValueData::I64LtU { lhs, rhs } + | ValueData::I64GtS { lhs, rhs } + | ValueData::I64GtU { lhs, rhs } + | ValueData::I64LeS { lhs, rhs } + | ValueData::I64LeU { lhs, rhs } + | ValueData::I64GeS { lhs, rhs } + | ValueData::I64GeU { lhs, rhs } => is_forwardable_expr(lhs) && is_forwardable_expr(rhs), + ValueData::Select { + cond, + true_val, + false_val, + } => { + is_forwardable_expr(cond) + && is_forwardable_expr(true_val) + && is_forwardable_expr(false_val) + } + _ => false, + } +} + +/// #219: does `v` (a forwardable expr — same whitelist as 
`is_forwardable_expr`) +/// reference `local.get local_idx`? Used to INVALIDATE a pinned forwarding when +/// one of its input locals is reassigned (reaching-defs guard), so a forwarded +/// expression never picks up a later redefinition of an input. +fn expr_references_local(v: &Value, local_idx: u32) -> bool { + match v.data() { + ValueData::LocalGet { idx } => *idx == local_idx, + ValueData::I32Const { .. } | ValueData::I64Const { .. } | ValueData::GlobalGet { .. } => { + false + } + ValueData::I32WrapI64 { val } + | ValueData::I64ExtendI32S { val } + | ValueData::I64ExtendI32U { val } + | ValueData::I32Eqz { val } + | ValueData::I64Eqz { val } => expr_references_local(val, local_idx), + ValueData::I32Add { lhs, rhs } + | ValueData::I32Sub { lhs, rhs } + | ValueData::I32Mul { lhs, rhs } + | ValueData::I32And { lhs, rhs } + | ValueData::I32Or { lhs, rhs } + | ValueData::I32Xor { lhs, rhs } + | ValueData::I32Shl { lhs, rhs } + | ValueData::I32ShrS { lhs, rhs } + | ValueData::I32ShrU { lhs, rhs } + | ValueData::I32Rotl { lhs, rhs } + | ValueData::I32Rotr { lhs, rhs } + | ValueData::I32Eq { lhs, rhs } + | ValueData::I32Ne { lhs, rhs } + | ValueData::I32LtS { lhs, rhs } + | ValueData::I32LtU { lhs, rhs } + | ValueData::I32GtS { lhs, rhs } + | ValueData::I32GtU { lhs, rhs } + | ValueData::I32LeS { lhs, rhs } + | ValueData::I32LeU { lhs, rhs } + | ValueData::I32GeS { lhs, rhs } + | ValueData::I32GeU { lhs, rhs } + | ValueData::I64Add { lhs, rhs } + | ValueData::I64Sub { lhs, rhs } + | ValueData::I64Mul { lhs, rhs } + | ValueData::I64And { lhs, rhs } + | ValueData::I64Or { lhs, rhs } + | ValueData::I64Xor { lhs, rhs } + | ValueData::I64Shl { lhs, rhs } + | ValueData::I64ShrS { lhs, rhs } + | ValueData::I64ShrU { lhs, rhs } + | ValueData::I64Rotl { lhs, rhs } + | ValueData::I64Rotr { lhs, rhs } + | ValueData::I64Eq { lhs, rhs } + | ValueData::I64Ne { lhs, rhs } + | ValueData::I64LtS { lhs, rhs } + | ValueData::I64LtU { lhs, rhs } + | ValueData::I64GtS { lhs, rhs } + | 
ValueData::I64GtU { lhs, rhs } + | ValueData::I64LeS { lhs, rhs } + | ValueData::I64LeU { lhs, rhs } + | ValueData::I64GeS { lhs, rhs } + | ValueData::I64GeU { lhs, rhs } => { + expr_references_local(lhs, local_idx) || expr_references_local(rhs, local_idx) + } + ValueData::Select { + cond, + true_val, + false_val, + } => { + expr_references_local(cond, local_idx) + || expr_references_local(true_val, local_idx) + || expr_references_local(false_val, local_idx) + } + // A pinned expr only ever holds is_forwardable_expr shapes; anything else + // conservatively "might reference" → invalidate. + _ => true, + } +} + +/// #219: a local `idx` was just (re)assigned — drop any pinned forwarding whose +/// expression references it, so a forwarded expression never captures a stale or +/// post-redefinition input value. +fn invalidate_pins_referencing(env: &mut OptimizationEnv, idx: u32) { + env.pinned + .retain(|_, expr| !expr_references_local(expr, idx)); +} + /// Dataflow-aware ISLE rewrite — tracks local variables and memory state. /// /// Applies all pure structural rewrites (constant folding, algebraic @@ -2747,6 +2917,10 @@ pub type LocalEnv = OptimizationEnv; /// /// Only safe for straight-line code or with env clearing at join points. /// For functions with BrIf/BrTable, use `rewrite_pure` instead. +/// +/// #219: when `env.single_assign` is non-empty, also forwards those +/// single-assignment locals' pure defining expressions (carrier scalar- +/// forwarding), with a reaching-defs invalidation guard. pub fn rewrite_with_dataflow(val: Value, env: &mut OptimizationEnv) -> Value { match val.data() { // Local variable operations @@ -2763,6 +2937,15 @@ pub fn rewrite_with_dataflow(val: Value, env: &mut OptimizationEnv) -> Value { env.locals.remove(idx); } + // #219 reaching-defs guard: reassigning `idx` invalidates any pinned + // forwarding that references it (so a forwarded expr never captures a + // post-redefinition input). 
Then, if `idx` is the single-assignment + // carrier with a forwardable RHS, pin its defining expression. + invalidate_pins_referencing(env, *idx); + if env.single_assign.contains(idx) && is_forwardable_expr(&simplified_val) { + env.pinned.insert(*idx, simplified_val.clone()); + } + local_set(*idx, simplified_val) } @@ -2770,6 +2953,10 @@ pub fn rewrite_with_dataflow(val: Value, env: &mut OptimizationEnv) -> Value { // Look up in environment - dataflow analysis! if let Some(known_val) = env.locals.get(idx) { known_val.clone() + } else if let Some(pinned_val) = env.pinned.get(idx) { + // #219 carrier scalar-forwarding: forward the single-assignment + // local's defining expression to this use, exposing pack/unpack. + pinned_val.clone() } else { local_get(*idx) } @@ -2787,6 +2974,13 @@ pub fn rewrite_with_dataflow(val: Value, env: &mut OptimizationEnv) -> Value { env.locals.remove(idx); } + // #219 (see LocalSet): invalidate pins referencing `idx`, then pin the + // single-assignment carrier. local.tee both stores AND leaves the value. + invalidate_pins_referencing(env, *idx); + if env.single_assign.contains(idx) && is_forwardable_expr(&simplified_val) { + env.pinned.insert(*idx, simplified_val.clone()); + } + local_tee(*idx, simplified_val) } @@ -3192,6 +3386,13 @@ pub fn rewrite_with_dataflow(val: Value, env: &mut OptimizationEnv) -> Value { // After if: clear env. We don't know which branch was taken. env.locals.clear(); env.invalidate_memory(); + // #219 reaching-defs across the fork: keep a pin only if it SURVIVED + // in BOTH branches. A branch that reassigned one of the pin's input + // locals dropped it (via invalidate_pins_referencing during that + // branch), so the intersection drops any pin whose inputs could have + // changed on either path — sound regardless of which branch ran. 
+ env.pinned + .retain(|k, _| then_env.pinned.contains_key(k) && else_env.pinned.contains_key(k)); Value(Box::new(ValueData::If { label: label.clone(), block_type: block_type.clone(), @@ -3791,6 +3992,21 @@ fn rewrite_pure_impl(val: Value) -> Value { { iconst64(Imm64(0)) } + // #219 seam-SROA: extend_i32_u(x) & M → extend_i32_u(x) when + // M's low 32 bits are all set. Zero-extending an i32 yields a + // value in [0, 2^32), so a mask covering bits [0,32) preserves + // it (bits [32,64) of the extend are already 0). Z3: (zext32 x) + // & M == (zext32 x) when (M & 0xffffffff) == 0xffffffff. + (ValueData::I64ExtendI32U { .. }, ValueData::I64Const { val: m }) + if (m.value() as u64) & 0xffff_ffff == 0xffff_ffff => + { + lhs_simplified + } + (ValueData::I64Const { val: m }, ValueData::I64ExtendI32U { .. }) + if (m.value() as u64) & 0xffff_ffff == 0xffff_ffff => + { + rhs_simplified + } // #219 seam-SROA: (or A B) & M → (survivor & M) when one OR // operand is a left shift the mask clears. Recurse so the // survivor (and a both-shifted case) simplifies further.