Skip to content

Add NVFP4 per-token quantization recipe#3045

Open
cael-ling wants to merge 19 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-per-token-recipe
Open

Add NVFP4 per-token quantization recipe#3045
cael-ling wants to merge 19 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-per-token-recipe

Conversation

@cael-ling

@cael-ling cael-ling commented May 26, 2026

Copy link
Copy Markdown
Contributor

Description

This PR adds an NVFP4 per-token quantization recipe for model pre-training. The default NVFP4BlockScaling recipe computes a single per-tensor outer amax (s_global) per tensor. The per-token variant instead computes a per-row outer amax (length M) for rowwise data and a per-col outer amax (length K) for columnwise data, giving each token/row its own global scale.

Changes

  • Per-token cast kernels: vector-amax + encode/swizzle producing NVFP4 tensors whose _amax_rowwise / _amax_columnwise are per-row/per-col vectors.
  • CUTLASS GEMM (nvfp4_cutlass_per_token_gemm) that rescales with the per-row/per-col outer-amax vectors inside the epilogue;
  • Forward + backward coverage (dgrad NN / wgrad NT layouts).
  • NVFP4PerTokenBlockScaling recipe (re-exported from transformer_engine.pytorch.fp8), plus an equivalent NVTE_NVFP4_PER_TOKEN=1 env-var switch on a plain NVFP4BlockScaling so frameworks that only build a default recipe (e.g. Megatron-Core) can opt in with no code change.
  • Opt-in RHT / SR (per_token_rht / per_token_sr) — off by default on the per-token path.
  • Opt-in 2D weight quantization (per_token_weight_2d): transposition-invariant 16×16 cast emitted in per-token layout.
  • Docs: API reference entry, NVTE_NVFP4_* env-var docs, and a "Per-token NVFP4" feature section with Megatron-Core launch instructions.
  • Example: examples/pytorch/nvfp4_per_token_megatron — single-GPU MoE example comparing per-token vs per-tensor vs BF16 with identical model/data/seed.

Ongoing work

The per-token recipe currently targets accuracy evaluation, not optimized production deployment:

  • Requires NVTE_NORM_FWD_USE_CUDNN=1 (unfused norm forward); the fused norm+amax path rejects per-token quantizers.
  • fuse_wgrad_accumulation=True is unsupported → launch with --no-gradient-accumulation-fusion in Mcore.
  • Forward/backward output quantization, communication/bulk overlap, and CUDA graphs are not yet supported/validated.
  • Kernels are functional but not perf-tuned; use for numerical comparison, not perf benchmarking.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 26, 2026
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast
pair and ships pytest correctness + sweep benches against the per-tensor
RHT+SR production baseline.

  * common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused
    grouped kernel, reusing the single-tensor 4-stage TMA pipeline.
  * common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale
    kernel for the cuBLASLT NVFP4 dequantize step (maybe updated due
    to 2d quant of W).
  * pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++
    grouped bulk binding and per-token GEMM entry; thin pybind layer.
  * pytorch/custom_recipes/{gemm_nvfp4_per_token,
    quantization_nvfp4_per_token_group}.py: Python wrappers.
  * tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal
    cast tests + bf16-close GEMM tests.
  * tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep
    over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA
    Graphs columns, ratio against per-tensor RHT+SR baseline.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling force-pushed the feature/nvfp4-per-token-recipe branch from 6f17fe4 to 928ab1c Compare May 27, 2026 13:09
pre-commit-ci Bot and others added 9 commits May 27, 2026 13:10
…uped)

Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax)
and K2 (encode) kernels for both single-tensor and grouped paths.
with_rht=False is byte-equal to the pre-RHT code path; when true,
applies a 16-pt RHT on the columnwise direction in both K1 and K2
(rowwise stays raw) with outer amax + inner SF self-consistent.

Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32
sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into
block_amax / block_scale (bit-exact).

Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and
byte-equality regressions. Benches gain a --rht flag (2-way default,
3-way under --rht).

Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K:
* single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT)
* grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT)

Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning NVIDIA#177-D).

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Add an optional fused-swizzle path to the NVFP4 per-token K2 encode
kernel: when with_swizzle=True the rowwise scale_inv is emitted directly
in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream
nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the
compact M-major layout (rowwise-only fusion for now).

The new code path is gated by a kWithSwizzle template parameter on
per_token_encode_kernel. The scatter epilogue uses thread mapping
b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store,
and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict
instead of 4-way). Pre-existing code path is byte-equal.

with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode},
their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python
recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled
flags so the caller opts into the fast path per operand (mirrors prod
NVFP4 GEMM's per-operand swizzle).

Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around
nvte_swizzle_scaling_factors that does one standalone per-operand
swizzle launch. Bench-only; lets --qs attribute swizzle cost separately
from K1+K2 and from cuBLAS LT GEMM.

Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone
swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one
prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token
(fuse) column for the K2-fused path). The existing --swizzle end-to-end
mode also gains the fused-swizzle column. --pair / --fuse auto-imply
--qs to avoid silent fall-through to the default --composite table.

Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle
rowwise SF vs a pure-Python permutation reference, byte-equality of all
other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False,
and numerical equivalence of the end-to-end GEMM via both code paths.

Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than
the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster
than per-tensor at small M.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
The per-token cuBLASLt NVFP4 path needs a trailing post-scale kernel
(D *= alpha_a[i] * alpha_b[j]) that is HBM-bound on the M*N output. This
patch ships a forked-CUTLASS NVFP4 GEMM whose EVT epilogue folds the
per-row * per-col rescale into the in-TMEM accumulator -- a single launch
with no separate post-scale, no M*N HBM round-trip.

New C-API entry points (transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu):
  - nvte_nvfp4_cutlass_gemm: scalar (alpha, beta) NVFP4xNVFP4 -> BF16 GEMM
    (CUTLASS analog of the cuBLASLt per-tensor path; used as test ground truth).
  - nvte_nvfp4_cutlass_per_token_gemm: same mainloop, EVT epilogue
    D[i,j] = bf16(NVFP4_DEQUANT_K * alpha_a[i] * alpha_b[j] * acc).
    The outer 1/2688^2 factor (NVFP4 spec) is baked into the EVT explicitly,
    matching the value cuBLASLt auto-folds via its amax slot.

Python bindings (tex.nvfp4_cutlass_gemm / tex.nvfp4_cutlass_per_token_gemm)
plus a/b_sf_swizzled flags for apples-to-apples --gemm-only benching.

Numerical correctness (tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py):
  - fused EVT == cuBLASLt per-token within bf16 ULP (rtol=2e-2), across
    M,N,K = 256..1024.
  - fused EVT with unity alphas == nvfp4_cutlass_gemm(alpha=1/2688^2) BIT-EXACT
    (sanity check that the EVT tree and the baked constant are both correct).

Bench (tests/pytorch/nvfp4/bench_nvfp4_per_token.py --gemm-only) streamlined
to the only comparison that matters for shipping: ct_fused (per-token CUTLASS
fused) vs pten_gemm (prod per-tensor cuBLASLt), with the cf/pten ratio.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Extends tests/pytorch/nvfp4/{bench,test}_nvfp4_cutlass_per_token_gemm
with end-to-end forward and backward coverage that aligns the prod
baseline with NVFP4BlockScaling real-ship defaults (input RHT-1D,
weight 2D no-RHT, grad RHT-cols + SR), so per-token (no RHT/SR) is
measured against an actually-shippable prod recipe rather than a
toy quantizer.

bench_nvfp4_per_token.py:
* --e2e-fwd: per-token quant (with_swizzle=True) + fused-EVT CUTLASS
  GEMM vs NVFP4Quantizer + general_gemm (the real nn.Linear fwd
  dispatch). Quant + GEMM inside the timing loop, N = K. Function
  docstring carries an ASCII kernel-pipeline diagram for both paths
  (per-call launch budget: per-token ~5 vs prod ~10).
* --e2e-bwd: real prod nn.Linear.bwd lifecycle. Timing loop = 1 x dY
  quant + dgrad GEMM + wgrad GEMM; X and W are pre-quantized OUTSIDE
  the loop (mirrors prod's reuse of fwd-saved QuantizedTensorStorage,
  bwd never re-quantizes). pten side uses RHT-cols + SR grad
  quantizer + general_gemm NN (dgrad) / NT (wgrad). Function docstring
  carries an ASCII kernel-pipeline diagram (per-step launch budget:
  per-token ~4 vs prod ~12).
* --gemm-only: 3-way table adds an lt_post column (cuBLASLt NVFP4 +
  bf16 per-row*per-col post-scale, "Route 1") next to the existing
  ct_fused fused-EVT path ("Route 2") and the prod pten_gemm
  baseline. Headline ratio lp/cf decides whether to dispatch
  per-token through cuBLASLt + post_scale or fused EVT; current
  data shows ct_fused wins or ties at every shape we care about.

test_nvfp4_cutlass_per_token_gemm.py:
* Layer 2 fwd: per-token quant + fused-EVT GEMM vs BF16 fp32 ground
  truth (rel_l2 < 0.30, robust to per-shape noise).
* Layer 3 fwd: dual-SNR table comparing per-token vs prod, both
  measured against BF16 ground truth, with a per-token-vs-prod ratio.
* Layer 3 bwd: same dual-SNR pattern for dgrad and wgrad. Prod side
  uses real-ship NVFP4BlockScaling grad quantizer (RHT cols + SR);
  per-token side has no RHT/SR (numerical-floor comparison).
* Sanity micro-test for weight 2D quant plumbing through general_gemm
  (catches breakage cheaper than the broader Layer 3 test).

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad

dim3 grid(static_cast<unsigned>(K / CHUNK_DIM_X), static_cast<unsigned>(M / CHUNK_DIM_Y), 1);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use DIVUP here to handle the remainder case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fast path has a hard precondition that M and K are exact multiples of CHUNK_DIM (128): validate() does NVTE_CHECK(M % CHUNK_DIM_Y == 0) / NVTE_CHECK(K % CHUNK_DIM_X == 0), and is_supported() returns false unless both hold — so any non-multiple shape is rejected / routed to the generic per-token fallback before it ever reaches this launcher.

// After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot.
//
// kWithRht=true: col-wise amax over RHT-rotated 16-row strips (per-thread
// FHT with random_sign_mask_t). Row direction never sees RHT.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: Row direction never sees RHT -> Row direction never uses RHT

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

}
}
#else
NVTE_DEVICE_ERROR("Per-token amax kernel requires SM 10.0+ (Blackwell).");

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these quantization kernel, TMA only require SM 9.0+ only. Is there any other constraints that limit to sm 10.0+?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CUDA_ARCH >= 1000 guard is intentional but not because of a hardware op in this kernel. Two reasons:

  1. The shared TE PTX wrappers it calls — cp_async_bulk_tensor_2d_global_to_shared and mbarrier_wait_parity_acquire_cta_shared_cta in util/ptx.cuh — are themselves guarded to >= 1000 and emit NVTE_DEVICE_ERROR below that. They were authored/validated only for the Blackwell path.
  2. The whole NVFP4 quantize path is host-gated to SM100 anyway (NVTE_ERROR("NVFP4 requires SM100 ...")), since NVFP4 is a Blackwell datatype and the downstream FP4 GEMM that consumes these scales only exists on SM100. So the amax kernel is never launched off <SM100; the per-arch guard just yields a clean error instead of an undefined symbol.

cael-ling and others added 8 commits June 2, 2026 01:28
Add NN/NT GEMM layout dispatch so the per-token NVFP4 path covers dgrad
and wgrad, and let per-token opt into RHT via
NVFP4PerTokenBlockScaling(per_token_rht=...) while SR/2D stay disabled
(kernels unimplemented at this commit). Extends the per-token CUTLASS
GEMM, the torch NVFP4Quantizer, and the NVFP4Tensor plumbing, plus
dgrad/wgrad numerical tests and a fwd+bwd module smoke test.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Thread a Philox rng_state and a kWithSr template flag through the
per-token encode kernel (rowwise + colwise) and the
nvte_nvfp4_per_token_encode/quantize C-API, mirroring the per-tensor
SR path. Drop the SR mutex check in the torch NVFP4Quantizer and build
the rng_state when stochastic rounding is requested. Add a per_token_sr
recipe flag on NVFP4PerTokenBlockScaling wired through the quantizer
factory, plus statistical tests (SR unbiasedness -- lower RMSE than RN
when averaged -- and RN-determinism / SR-nondeterminism) folded into
test_nvfp4_per_token.py.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Wire with_sr + rng_state through the grouped per-token C-API and cast
dispatch, implement the SR FP4 cast in the grouped kernel, and drop the
"per-token does not support SR" guard. Also fix two comment typos
(sees -> uses) in quantize_nvfp4_per_token.cu per review.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Introduce NVTE_NVFP4_PER_TOKEN_WEIGHT_2D (recipe.per_token_weight_2d),
default off so the per-token path stays byte-equal. When enabled, only the
forward WEIGHT switches to the per-tensor 2D cast (16x16 inner tile + scalar
outer amax) re-dressed in per-token tensor layout: the scalar outer amax is
broadcast across the per-row/col alpha vectors and the inner SF is the same
16-row-replicated 2D tile, so the existing per-token CUTLASS GEMM consumes it
unchanged with no kernel modification. Activation/gradient casts stay
per-token 1D.

Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Document the user-facing surface of the NVFP4 per-token recipe and add a
runnable single-GPU example so the recipe can be exercised end to end.

- docs/api/common.rst: list NVFP4PerTokenBlockScaling in the API reference.
- docs/envvars.rst: document the NVTE_NVFP4_* knobs -- per-token activation
  (NVTE_NVFP4_PER_TOKEN) plus the RHT/SR/weight-2D opt-ins, and the
  per-tensor disable flags.
- docs/features/.../nvfp4.rst: add a "Per-token NVFP4" section explaining the
  per-row/per-col outer-amax cast, its differences from the per-tensor default
  (RHT/SR off by default, forced-off knobs, unfused-norm requirement), and how
  to launch it with Megatron-Core.
- recipe/__init__.py: document the per_token_rht/per_token_sr/per_token_weight_2d
  constructor kwargs and drop the stale "stochastic rounding unsupported" note.
- pytorch/fp8.py: re-export NVFP4PerTokenBlockScaling.
- examples/pytorch/nvfp4_per_token_megatron: single-GPU MoE example (run +
  sbatch + job-chain scripts and README) comparing per-token vs per-tensor vs
  BF16 with identical model/data/seed.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling marked this pull request as ready for review June 11, 2026 07:58
@greptile-apps

greptile-apps Bot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds a new NVFP4PerTokenBlockScaling recipe that replaces the per-tensor outer amax with a per-row (M,) / per-col (K,) vector, giving each token its own global scale. It introduces per-token CUDA cast kernels (K1 vector-amax + K2 encode/swizzle), a fused CUTLASS EVT GEMM (nvfp4_cutlass_per_token_gemm) that rescales with those vectors in the epilogue, and an env-var-based activation path (NVTE_NVFP4_PER_TOKEN=1) so frameworks that only construct a plain NVFP4BlockScaling (e.g. Megatron-Core) can opt in without code changes. As noted in the PR description, this targets accuracy evaluation rather than optimized production deployment.

  • Adds per-token cast kernels, a fused-EVT CUTLASS GEMM (single + grouped), and a NVFP4PerTokenBlockScaling recipe class (re-exported from transformer_engine.pytorch.fp8) covering forward + backward (dgrad NN / wgrad NT layouts).
  • Extends NVFP4Quantizer in Python and C++ with per_token / per_token_weight_2d flags; general_gemm / general_grouped_gemm auto-dispatch to the per-token CUTLASS GEMM based on the _per_token tensor flag.
  • Documents all known limitations (unfused norm required, no wgrad-accumulation fusion, no comm-overlap/CUDA graphs) and adds an end-to-end Megatron-Core MoE example.

Confidence Score: 4/5

Safe to merge for evaluation/research use as described in the PR; the feature is explicitly gated as non-production and the known unsupported paths (fuse_wgrad_accumulation, comm-overlap, CUDA graphs) all raise clear runtime errors rather than silently corrupting results.

The implementation is large but well-structured, with comprehensive guards in both Python and C++. Fresh findings are style/robustness observations. The return-type divergence in _nvfp4_per_token_grouped_gemm (tensor vs list for single_output=True) is harmless today because all call sites discard the return value and rely on in-place writes, but could trap a future caller. The env-var recipe-state divergence path produces misleading logs in an unusual (mid-run env-var change) scenario. Both are minor concerns for a feature intentionally scoped to accuracy evaluation.

transformer_engine/pytorch/cpp_extensions/gemm.py (return-type consistency in grouped per-token path) and transformer_engine/common/recipe/init.py (env-var re-evaluation in nvfp4_per_token classmethod).

Important Files Changed

Filename Overview
transformer_engine/pytorch/cpp_extensions/gemm.py Adds _nvfp4_per_token_gemm and _nvfp4_per_token_grouped_gemm helpers plus dispatch logic in general_gemm / general_grouped_gemm; well-guarded but the grouped helper's return type differs from the non-per-token path when single_output=True (returns tensor vs list), though current callers ignore the return value so this is benign.
transformer_engine/pytorch/csrc/quantizer.cpp Adds per_token / per_token_weight_2d branches to NVFP4Quantizer: create_tensor, convert_and_update_tensor, and quantize_impl. The per_token_weight_2d amax-set-before-quantize ordering is correct for the reachable (rowwise != nullptr) code path; the latent null-ptr edge case remains (previously flagged).
transformer_engine/common/recipe/init.py Adds NVFP4PerTokenBlockScaling subclass and env-var activation path (NVTE_NVFP4_PER_TOKEN=1) on the base class; nvfp4_per_token() classmethod + _force_per_token_settings() keep the two activation paths consistent.
transformer_engine/pytorch/quantization.py Correctly computes per_token / per_token_weight_2d flags per tensor_type and mode; per-token feature overrides (rht, SR, 2d, 4over6) are properly gated in the NVFP4Quantizer constructor call.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Propagates per_token flag through NVFP4Tensor lifecycle (copy, reshape, view, FSDP2 metadata, reduce_ex); per_token and per_token_weight_2d field docstrings are string literals placed after the annotation (not before), so they are discarded by Python and won't appear in help() output.
transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py Public grouped-quantize entry with validation; rejects split_sections[i] <= 0 which prevents use with zero-token experts (previously flagged). Shape checks and 64-group cap are correct.
transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py Pure-PyTorch reference quantizer plus production composite/K1-only/K2-only Python wrappers; arithmetic matches the kernel (row amax floor, S_enc computation, FP8 saturating cast) and validation matches kernel constraints.
transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py Stand-alone production GEMM wrapper with clear TN/NN/NT layout dispatch table and good shape-mismatch error messages.
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py Adds _per_token class annotation, constructor parameter, copy_from_storage mismatch check, and __init_kwargs propagation; straightforward and complete.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["RecipeState.get_quantizer()"] --> B{nvfp4_per_token?}
    B -- "False\n(default NVFP4BlockScaling)" --> C["per-tensor path\n(cuBLASLt 1A/2A)"]
    B -- "True\n(NVFP4PerTokenBlockScaling\nOR NVTE_NVFP4_PER_TOKEN=1)" --> D{per_token_weight_2d?}
    D -- "False (default)" --> E["NVFP4Quantizer\n(per_token=True)"]
    D -- "True\n(NVTE_NVFP4_PER_TOKEN_WEIGHT_2D=1)" --> F["NVFP4Quantizer\n(per_token=True,\nper_token_weight_2d=True)"]
    E --> G["quantize_impl:\nnvte_nvfp4_per_token_quantize\n(K1 row/col amax + K2 FP4 encode)"]
    F --> H["quantize_impl:\nnvte_compute_amax → nvte_quantize_v2\n(2D cast) → broadcast amax vector"]
    G --> I["NVFP4Tensor\n_amax_rowwise shape: (M,)\n_amax_columnwise shape: (K,)"]
    H --> I
    I --> J["general_gemm / general_grouped_gemm:\n_is_nvfp4_per_token_tensor check"]
    J -- "True" --> K["_nvfp4_per_token_gemm\nor\n_nvfp4_per_token_grouped_gemm\n(CUTLASS EVT fused epilogue)"]
    J -- "False" --> L["tex.generic_gemm\n(cuBLASLt)"]
    K --> M["Output: BF16\nD = alpha_row x alpha_col · (A @ B^T)"]
    L --> N["Output: BF16/FP8\n(standard path)"]
Loading

Reviews (2): Last reviewed commit: "Add native CUTLASS grouped per-token NVF..." | Re-trigger Greptile

Comment on lines +2399 to +2413
NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer for per-token weight-2D.");

// 1. Single scalar tensor amax -> amax[0] (mirror the per-tensor no-RHT path:
// treat the buffer as length 1 for the reduction, then fan out to both
// rowwise/columnwise amax[0]).
out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax_with_config(input.data(), out.data(), w2d_config, stream); });
out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector<size_t>{1});
if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, stream));
}
if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Latent null-pointer dereference in per-token weight-2D amax restore

After nvte_compute_amax_with_config writes the global amax into amax_ptr[0], the code restores the output tensor's amax field with out.set_amax(rowwise_amax_ptr, ...). If rowwise_amax_ptr is nullptr (i.e., the quantizer was constructed with rowwise=False), this sets the output's amax descriptor to a null pointer. The immediately following nvte_quantize_v2 then tries to read amax[0] to derive S_enc and will crash.

Currently this path is unreachable because per_token_weight_2d is only set for weight quantizers, and all weight quantizers in the recipe are constructed with rowwise=True, columnwise=True. However, the guard in step 3 (if (rowwise_amax_ptr != nullptr && w2d_rows > 1)) shows the author anticipated both pointers could be null, while the critical out.set_amax call on this line does not. Using amax_ptr (the non-null pointer already validated by the NVTE_CHECK above) would be safe in all configurations: out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1}).

Comment on lines +507 to +530
# Per-token NVFP4 dispatches to fused EVT GEMM that consumes per-row
# (M,) and per-col (N,) outer-amax vectors directly. cuBLASLt cannot,
# so this MUST short-circuit before the row-scaled-or-generic fork.
if _is_nvfp4_per_token_tensor(A) or _is_nvfp4_per_token_tensor(B):
if not (_is_nvfp4_per_token_tensor(A) and _is_nvfp4_per_token_tensor(B)):
raise NotImplementedError(
"NVFP4 per-token GEMM requires both A and B to be per-token tensors. "
"Mixing per-token + prod NVFP4 in one GEMM is not supported."
)
out = _nvfp4_per_token_gemm(
A,
B,
transa=transa,
transb=transb,
out=out,
out_dtype=out_dtype,
bias=bias,
grad=grad,
accumulate=accumulate,
gelu=gelu,
quantization_params=quantization_params,
ub=ub,
extra_output=extra_output,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 alpha scalar silently ignored for per-token GEMM

general_gemm validates and stores alpha in kwargs["alpha"], but the per-token short-circuit path dispatches to _nvfp4_per_token_gemm which has no alpha parameter and never forwards the value. The C++ binding nvfp4_cutlass_per_token_gemm also lacks a global scalar alpha argument — only the per-row/per-col alpha_a/alpha_b vectors are supported. For all current TE module call sites alpha=1.0 is the invariant, so numerical output is unaffected today. If a caller ever passes alpha != 1.0 through general_gemm with per-token tensors, the result will be silently wrong instead of raising an error.

Comment on lines +47 to +51
for i, M_i in enumerate(split_sections):
if M_i <= 0:
raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}")
if M_i % _PER_TOKEN_TILE != 0:
raise ValueError(f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Public grouped-quantize API unconditionally rejects 0-token splits

split_sections[i] <= 0 raises ValueError, but in MoE training with dynamic token routing, experts commonly receive zero tokens in a given micro-batch. The general_grouped_gemm per-token loop already handles this by skipping the launch when m_splits[i] == 0, so the GEMM side is fine. If users call this Python wrapper directly (e.g., from bench scripts or custom MoE quantization pipelines), they must pre-filter empty experts. A comment or guard skipping allocation for empty splits would make the API usable in unbalanced-routing scenarios.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Replace the per-expert Python loop for the plain
D = bf16(alpha_a * alpha_b * (A @ B^T)) path with a single ptr-array
CUTLASS grouped kernel (SM100). The dispatcher in general_grouped_gemm
routes to the native kernel when no accumulate/bias/gelu/output-quant is
requested, and otherwise falls back to the per-expert loop
(NVTE_NVFP4_PER_TOKEN_GROUPED_FALLBACK=1 forces the fallback).

The launcher caches the SM count and reuses persistent device scratch +
workspace buffers across launches to avoid per-call cudaMalloc/Free and
cudaGetDeviceProperties overhead. Parity tests assert the grouped kernel
matches the dense per-token GEMM bit-exact per group.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants