perf: speed up top-p sampling for large vocabularies #1232

Open
NorbertKlockiewicz wants to merge 4 commits into main from @nk/optimize-sampling

NorbertKlockiewicz commented Jun 15, 2026

Contributor

Description

Optimizes token sampling for large-vocabulary models (e.g. Gemma 4 E2B, 262k vocab), where the previous full-vocabulary sort in top-p dominated per-token latency.

Two changes in sampler.cpp:

  • mask_topp: replaces the O(n log n) sort over all logits with a logit-space histogram (kBins=2048) that locates the nucleus threshold in two O(n) passes — no sort, no per-token vocab-sized allocation. Binning in logit space (rather than probability space) keeps uniform resolution for both peaked and flat distributions.
  • softmax: skips exp() on logits already masked to lowest() by top-k/top-p. The result underflows to zero anyway, and the call is slow on device.
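
The softmax change in the second bullet can be sketched roughly as follows. This is a minimal illustration, not the code in sampler.cpp: the function name `softmax_skip_masked` and the use of `std::vector` are assumptions made for the example, and accumulation is done in `float` for simplicity.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper (not the PR's implementation): in-place softmax that
// skips exp() for entries already masked to lowest() by top-k/top-p.
template <typename T>
void softmax_skip_masked(std::vector<T>& logits) {
  if (logits.empty()) return;
  const T masked = std::numeric_limits<T>::lowest();

  // Max scan for numerical stability.
  T max_val = logits[0];
  for (T v : logits) {
    if (v > max_val) max_val = v;
  }

  // Exponentiate only unmasked entries; masked ones become exactly zero.
  float sum = 0.0f;
  for (T& v : logits) {
    if (v == masked) {
      v = T(0);
      continue;  // exp() would underflow to zero anyway
    }
    const float e =
        std::exp(static_cast<float>(v) - static_cast<float>(max_val));
    v = static_cast<T>(e);
    sum += e;
  }

  // Normalize. If every entry was masked, sum is 0 and all stay zero.
  if (sum > 0.0f) {
    for (T& v : logits) {
      v = static_cast<T>(static_cast<float>(v) / sum);
    }
  }
}
```

The saving scales with how many entries the earlier mask removed: with a tight nucleus over a 262k vocabulary, most entries take the early `continue` and never reach `exp()`.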

On an iPhone 17 Pro with Gemma 4 E2B (int4), per-token sampling drops from ~45 ms to ~10 ms. The histogram approximates the exact sort-based nucleus, and the resulting sampled distribution is statistically equivalent: the kept-mass fraction stays within 1% of the exact nucleus across peaked, flat, and sharp distributions.

Introduces a breaking change?

  • Yes
  • No

Type of change

  • Bug fix (change which fixes an issue)
  • New feature (change which adds functionality)
  • Documentation update (improves or adds clarity to existing documentation)
  • Other (chores, tests, code style improvements etc.)

Tested on

  • iOS
  • Android

Testing instructions

  1. Run an LLM with a large vocabulary and a non-zero temperature with topP set (e.g. Gemma 4 E2B with temperature: 0.8, topP: 0.9).
  2. Generate a long response and observe tokens/sec.
  3. Confirm output remains coherent and sampling is unchanged in character (still stochastic, not greedy).

Greedy decoding (temperature: 0) is unaffected — it bypasses this path entirely.

Screenshots

Related issues

Checklist

  • I have performed a self-review of my code
  • I have commented my code, particularly in hard-to-understand areas
  • I have updated the documentation accordingly
  • My changes generate no new warnings

Additional notes

The histogram is an approximation bounded by bin granularity (kBins=2048 over a kRange=40 logit span). This is intentional: exact top-p over a 262k vocab where the nucleus can exceed 100k tokens is inherently expensive, and the sampling outcome is statistically indistinguishable from the exact version.
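
For readers who want a picture of the approach, here is a minimal sketch of a logit-space histogram threshold. Only `kBins = 2048` and `kRange = 40` come from this PR; the function name `mask_topp_histogram`, the pass structure, the catch-all last bin, and the keep-whole-bins policy are assumptions for illustration and may differ from the actual `mask_topp`.

```cpp
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

constexpr int32_t kBins = 2048;  // from the PR notes
constexpr float kRange = 40.0f;  // logit span below the max, from the PR notes

// Hypothetical sketch: mask logits outside an approximate top-p nucleus.
inline void mask_topp_histogram(std::vector<float>& logits, float topp) {
  if (logits.empty() || topp >= 1.0f) return;
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  const float scale = kBins / kRange;  // bins per logit unit

  // Bin 0 holds the highest logits; anything further than kRange below
  // the max lands in the last (catch-all) bin.
  auto bin_of = [&](float v) -> int32_t {
    const float pos = (max_logit - v) * scale;
    return static_cast<int32_t>(std::min(pos, static_cast<float>(kBins - 1)));
  };

  // Pass 1: accumulate unnormalized probability mass per bin.
  std::array<float, kBins> mass{};
  float total = 0.0f;
  for (float v : logits) {
    const float e = std::exp(v - max_logit);
    mass[bin_of(v)] += e;
    total += e;
  }

  // Walk bins from the top until the target mass is covered.
  const float target = topp * total;
  float cum = 0.0f;
  int32_t cut = kBins - 1;
  for (int32_t b = 0; b < kBins; ++b) {
    cum += mass[b];
    if (cum >= target) {
      cut = b;
      break;
    }
  }

  // Pass 2: mask every logit that falls in a bin below the cut.
  const float masked = std::numeric_limits<float>::lowest();
  for (float& v : logits) {
    if (bin_of(v) > cut) v = masked;
  }
}
```

With these constants each bin spans 40 / 2048 ≈ 0.0195 logits, so probabilities inside a single bin differ by at most a factor of about e^0.0195 ≈ 1.02. Because this sketch keeps whole bins, the kept mass can only overshoot the requested `topp`, and only by the mass of the boundary bin.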

Replace the full-vocabulary sort in mask_topp with a logit-space
histogram that finds the nucleus threshold in two O(n) passes, and skip
exp() on masked logits in softmax. For Gemma's 262k vocab this cuts
per-token sampling from ~45ms to ~10ms on device. The histogram
approximates the exact sort-based nucleus; the sampled distribution is
statistically equivalent.

Authored with Claude.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment threads on packages/react-native-executorch/common/runner/sampler.cpp (8, all outdated)
Use std::ranges::max_element over a std::span for the max scan, std::clamp
for bin clamping, and explicit int32_t for bin indices.

Authored with Claude.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
NorbertKlockiewicz and others added 2 commits June 15, 2026 17:14
float is sufficient for the exp() accumulation over the vocab (error well
below the histogram bin granularity); double bought no real precision.
Accumulating in T directly is still avoided since T may be bf16, which
saturates when summing many small terms.

Authored with Claude.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Use acc_t = double when T is double, float otherwise. Every logit
conversion is then a widening (or no-op), never a narrowing, regardless
of which logit dtype instantiates the sampler. Accumulating in T itself
stays avoided because bf16 saturates when summing exp() over the vocab.

Authored with Claude.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
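
The accumulator choice described in that commit could be expressed along these lines. This is a sketch under assumptions: the alias spelling and the `sum_exp` helper are illustrative, not copied from sampler.cpp.

```cpp
#include <cmath>
#include <cstddef>
#include <type_traits>

// double when T is double, float otherwise, so converting a logit to the
// accumulator type is always a widening or a no-op.
template <typename T>
using acc_t = std::conditional_t<std::is_same_v<T, double>, double, float>;

// Hypothetical helper: sum of exp(logit - max_val) accumulated in acc_t<T>
// rather than in T itself.
template <typename T>
acc_t<T> sum_exp(const T* logits, std::size_t n, T max_val) {
  acc_t<T> sum = 0;
  for (std::size_t i = 0; i < n; ++i) {
    sum += std::exp(static_cast<acc_t<T>>(logits[i]) -
                    static_cast<acc_t<T>>(max_val));
  }
  return sum;
}
```

Any 16-bit logit type that instantiates the template falls into the `float` branch, which matches the commit's goal of never accumulating in a narrow type.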
@msluszniak

Member

Code is OK; testing now.

@msluszniak msluszniak left a comment

Member

I ran it on my Android device with different topP values and the other parameters described in the testing instructions, and it worked.
