Add NVFP4 per-token quantization recipe #3045
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast
pair and ships pytest correctness + sweep benches against the per-tensor
RHT+SR production baseline.
* common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused
grouped kernel, reusing the single-tensor 4-stage TMA pipeline.
* common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale
kernel for the cuBLASLt NVFP4 dequantize step (may be updated to
account for 2D quantization of W).
* pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++
grouped bulk binding and per-token GEMM entry; thin pybind layer.
* pytorch/custom_recipes/{gemm_nvfp4_per_token,
quantization_nvfp4_per_token_group}.py: Python wrappers.
* tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal
cast tests + bf16-close GEMM tests.
* tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep
over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA
Graphs columns, ratio against per-tensor RHT+SR baseline.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
6f17fe4 to 928ab1c
for more information, see https://pre-commit.ci
…uped)
Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax)
and K2 (encode) kernels for both single-tensor and grouped paths.
with_rht=False is byte-equal to the pre-RHT code path; with_rht=True
applies a 16-pt RHT on the columnwise direction in both K1 and K2
(rowwise stays raw), keeping the outer amax and inner SF self-consistent.
Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32
sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into
block_amax / block_scale (bit-exact).
Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and
byte-equality regressions. Benches gain a --rht flag (2-way default,
3-way under --rht).
Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K:
* single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT)
* grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT)
Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning #177-D).
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
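The 16-point RHT described above can be sketched in plain Python as a sign flip followed by a fast Walsh-Hadamard butterfly. This is only an illustration, not the kernel: the function name `rht16` and the bit-per-element meaning of `sign_mask` are assumptions, and the 0.25 normalization (1/sqrt(16)) is applied explicitly here rather than folded into block_amax / block_scale as the commit does.

```python
def rht16(x, sign_mask):
    """Sign-flip then 16-point Walsh-Hadamard transform, scaled by 0.25.

    Assumption: bit i of sign_mask set means element i is negated.
    """
    assert len(x) == 16
    v = [-xi if (sign_mask >> i) & 1 else xi for i, xi in enumerate(x)]
    h = 1
    while h < 16:
        # One butterfly stage: combine elements that are h apart.
        for start in range(0, 16, 2 * h):
            for i in range(start, start + h):
                a, b = v[i], v[i + h]
                v[i], v[i + h] = a + b, a - b
        h *= 2
    # 1/sqrt(16) makes the transform orthonormal.
    return [0.25 * t for t in v]


x = [float(i) for i in range(16)]
y = rht16(x, sign_mask=0xA5A5)
```

Because the scaled transform is orthonormal, it preserves the L2 norm of each 16-element strip, which is why the amax and scale computations can absorb the 0.25 factor.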
Add an optional fused-swizzle path to the NVFP4 per-token K2 encode
kernel: when with_swizzle=True the rowwise scale_inv is emitted directly
in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream
nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the
compact M-major layout (rowwise-only fusion for now).
The new code path is gated by a kWithSwizzle template parameter on
per_token_encode_kernel. The scatter epilogue uses thread mapping
b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store,
and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict
instead of 4-way). Pre-existing code path is byte-equal.
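As a rough illustration of the two index computations mentioned above, the sketch below shows the thread mapping `b=tid&3, ty=tid>>2` and one possible offset function for a 128x4 swizzled scale-factor tile. The within-tile formula is an assumption based on the block-scaling layout described in the cuBLAS documentation, not something read from this kernel, and the names `thread_coords` and `swizzled_offset` are hypothetical.

```python
def thread_coords(tid):
    """Thread mapping from the commit message: b = tid & 3, ty = tid >> 2."""
    return tid & 3, tid >> 2


def swizzled_offset(row, col, num_k_tiles):
    """Flat offset of scale factor (row, col) in a 128x4-tiled layout.

    Assumption: tiles are stored row-major over (M-tile, K-tile), 512
    entries each, with the within-tile order
    (r % 32) * 16 + (r // 32) * 4 + c.
    """
    tile_m, r = divmod(row, 128)
    tile_k, c = divmod(col, 4)
    tile_base = (tile_m * num_k_tiles + tile_k) * 512
    return tile_base + (r % 32) * 16 + (r // 32) * 4 + c


# Offsets for a 256 x 8 scale-factor matrix (two K-tiles per M-tile).
offsets = [swizzled_offset(r, c, num_k_tiles=2)
           for r in range(256) for c in range(8)]
```

A pure-Python permutation reference of this kind is the sort of check the tests below compare the fused-swizzle output against.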
with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode},
their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python
recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled
flags so the caller opts into the fast path per operand (mirrors prod
NVFP4 GEMM's per-operand swizzle).
Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around
nvte_swizzle_scaling_factors that does one standalone per-operand
swizzle launch. Bench-only; lets --qs attribute swizzle cost separately
from K1+K2 and from cuBLAS LT GEMM.
Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone
swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one
prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token
(fuse) column for the K2-fused path). The existing --swizzle end-to-end
mode also gains the fused-swizzle column. --pair / --fuse auto-imply
--qs to avoid silent fall-through to the default --composite table.
Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle
rowwise SF vs a pure-Python permutation reference, byte-equality of all
other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False,
and numerical equivalence of the end-to-end GEMM via both code paths.
Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than
the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster
than per-tensor at small M.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
The per-token cuBLASLt NVFP4 path needs a trailing post-scale kernel
(D *= alpha_a[i] * alpha_b[j]) that is HBM-bound on the M*N output. This
patch ships a forked-CUTLASS NVFP4 GEMM whose EVT epilogue folds the
per-row * per-col rescale into the in-TMEM accumulator -- a single launch
with no separate post-scale, no M*N HBM round-trip.
New C-API entry points (transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu):
- nvte_nvfp4_cutlass_gemm: scalar (alpha, beta) NVFP4xNVFP4 -> BF16 GEMM
(CUTLASS analog of the cuBLASLt per-tensor path; used as test ground truth).
- nvte_nvfp4_cutlass_per_token_gemm: same mainloop, EVT epilogue
D[i,j] = bf16(NVFP4_DEQUANT_K * alpha_a[i] * alpha_b[j] * acc).
The outer 1/2688^2 factor (NVFP4 spec) is baked into the EVT explicitly,
matching the value cuBLASLt auto-folds via its amax slot.
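The epilogue formula above can be written out as a small reference in plain Python. This is a sketch under assumptions: it reads 2688 as the product of the FP4 E2M1 maximum (6) and the FP8 E4M3 maximum (448), it omits the final bf16 rounding, and `per_token_epilogue` is a hypothetical name rather than an entry point from this PR.

```python
FP4_E2M1_MAX = 6.0
FP8_E4M3_MAX = 448.0
# Assumed reading of the "1/2688^2" factor: (6 * 448) ** -2.
NVFP4_DEQUANT_K = 1.0 / (FP4_E2M1_MAX * FP8_E4M3_MAX) ** 2


def per_token_epilogue(acc, alpha_a, alpha_b):
    """D[i][j] = K * alpha_a[i] * alpha_b[j] * acc[i][j] (bf16 cast omitted)."""
    return [
        [NVFP4_DEQUANT_K * alpha_a[i] * alpha_b[j] * acc[i][j]
         for j in range(len(alpha_b))]
        for i in range(len(alpha_a))
    ]


acc = [[2688.0 ** 2, 0.0], [2688.0 ** 2, 2688.0 ** 2]]
d = per_token_epilogue(acc, alpha_a=[1.0, 2.0], alpha_b=[1.0, 0.5])
```

With unity alphas the result reduces to a scalar-alpha GEMM with alpha = 1/2688^2, which is the identity the bit-exact test below relies on.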
Python bindings (tex.nvfp4_cutlass_gemm / tex.nvfp4_cutlass_per_token_gemm)
plus a/b_sf_swizzled flags for apples-to-apples --gemm-only benching.
Numerical correctness (tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py):
- fused EVT == cuBLASLt per-token within bf16 ULP (rtol=2e-2), across
M,N,K = 256..1024.
- fused EVT with unity alphas == nvfp4_cutlass_gemm(alpha=1/2688^2) BIT-EXACT
(sanity check that the EVT tree and the baked constant are both correct).
Bench (tests/pytorch/nvfp4/bench_nvfp4_per_token.py --gemm-only) streamlined
to the only comparison that matters for shipping: ct_fused (per-token CUTLASS
fused) vs pten_gemm (prod per-tensor cuBLASLt), with the cf/pten ratio.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Extends tests/pytorch/nvfp4/{bench,test}_nvfp4_cutlass_per_token_gemm
with end-to-end forward and backward coverage that aligns the prod
baseline with NVFP4BlockScaling real-ship defaults (input RHT-1D,
weight 2D no-RHT, grad RHT-cols + SR), so per-token (no RHT/SR) is
measured against an actually-shippable prod recipe rather than a
toy quantizer.
bench_nvfp4_per_token.py:
* --e2e-fwd: per-token quant (with_swizzle=True) + fused-EVT CUTLASS
GEMM vs NVFP4Quantizer + general_gemm (the real nn.Linear fwd
dispatch). Quant + GEMM inside the timing loop, N = K. Function
docstring carries an ASCII kernel-pipeline diagram for both paths
(per-call launch budget: per-token ~5 vs prod ~10).
* --e2e-bwd: real prod nn.Linear.bwd lifecycle. Timing loop = 1 x dY
quant + dgrad GEMM + wgrad GEMM; X and W are pre-quantized OUTSIDE
the loop (mirrors prod's reuse of fwd-saved QuantizedTensorStorage,
bwd never re-quantizes). pten side uses RHT-cols + SR grad
quantizer + general_gemm NN (dgrad) / NT (wgrad). Function docstring
carries an ASCII kernel-pipeline diagram (per-step launch budget:
per-token ~4 vs prod ~12).
* --gemm-only: 3-way table adds an lt_post column (cuBLASLt NVFP4 +
bf16 per-row*per-col post-scale, "Route 1") next to the existing
ct_fused fused-EVT path ("Route 2") and the prod pten_gemm
baseline. Headline ratio lp/cf decides whether to dispatch
per-token through cuBLASLt + post_scale or fused EVT; current
data shows ct_fused wins or ties at every shape we care about.
test_nvfp4_cutlass_per_token_gemm.py:
* Layer 2 fwd: per-token quant + fused-EVT GEMM vs BF16 fp32 ground
truth (rel_l2 < 0.30, robust to per-shape noise).
* Layer 3 fwd: dual-SNR table comparing per-token vs prod, both
measured against BF16 ground truth, with a per-token-vs-prod ratio.
* Layer 3 bwd: same dual-SNR pattern for dgrad and wgrad. Prod side
uses real-ship NVFP4BlockScaling grad quantizer (RHT cols + SR);
per-token side has no RHT/SR (numerical-floor comparison).
* Sanity micro-test for weight 2D quant plumbing through general_gemm
(catches breakage cheaper than the broader Layer 3 test).
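For readers unfamiliar with the metrics named in the test layers above, here is a minimal sketch of a relative-L2 error and an SNR in decibels over flat lists. The exact definitions used by the test file are not shown in this PR page, so both formulas (and the names `rel_l2` and `snr_db`) are assumptions.

```python
import math


def rel_l2(test, ref):
    """Relative L2 error: ||test - ref|| / ||ref||."""
    err = math.sqrt(sum((t - r) ** 2 for t, r in zip(test, ref)))
    return err / math.sqrt(sum(r * r for r in ref))


def snr_db(test, ref):
    """Signal-to-noise ratio in dB: 10 * log10(||ref||^2 / ||test - ref||^2)."""
    noise = sum((t - r) ** 2 for t, r in zip(test, ref))
    signal = sum(r * r for r in ref)
    return 10.0 * math.log10(signal / noise)


ref = [3.0, 4.0]
test = [3.0, 4.5]
error = rel_l2(test, ref)
quality = snr_db(test, ref)
```

Under these definitions a "dual-SNR" comparison would compute `snr_db` once for the per-token output and once for the prod output, both against the same BF16 ground truth.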
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
      DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT;  // + align pad

  dim3 grid(static_cast<unsigned>(K / CHUNK_DIM_X), static_cast<unsigned>(M / CHUNK_DIM_Y), 1);
Maybe use DIVUP here to handle the remainder case?
This fast path has a hard precondition that M and K are exact multiples of CHUNK_DIM (128): validate() does NVTE_CHECK(M % CHUNK_DIM_Y == 0) / NVTE_CHECK(K % CHUNK_DIM_X == 0), and is_supported() returns false unless both hold — so any non-multiple shape is rejected / routed to the generic per-token fallback before it ever reaches this launcher.
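The precondition described in this reply can be mirrored in a few lines of Python. This is an illustrative sketch only: `is_supported` and `launch_grid` here are stand-ins for the C++ checks, and the single `CHUNK_DIM = 128` constant follows the value quoted in the reply.

```python
CHUNK_DIM = 128  # value quoted in the reply for CHUNK_DIM_X / CHUNK_DIM_Y


def is_supported(m, k):
    """Fast path only accepts shapes that are exact multiples of CHUNK_DIM."""
    return m % CHUNK_DIM == 0 and k % CHUNK_DIM == 0


def launch_grid(m, k):
    """Grid dims (x over K, y over M); no remainder tiles can exist here."""
    if not is_supported(m, k):
        raise ValueError(f"M={m}, K={k} must be multiples of {CHUNK_DIM}")
    return (k // CHUNK_DIM, m // CHUNK_DIM, 1)


grid = launch_grid(1024, 2048)
```

Since unsupported shapes never reach the launcher, plain integer division and a rounded-up division give the same grid, which is why DIVUP is not needed on this path.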
  // After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot.
  //
  // kWithRht=true: col-wise amax over RHT-rotated 16-row strips (per-thread
  // FHT with random_sign_mask_t). Row direction never sees RHT.
typo: Row direction never sees RHT -> Row direction never uses RHT
    }
  }
#else
  NVTE_DEVICE_ERROR("Per-token amax kernel requires SM 10.0+ (Blackwell).");
For these quantization kernels, TMA only requires SM 9.0+. Are there any other constraints that limit this to SM 10.0+?
The __CUDA_ARCH__ >= 1000 guard is intentional, but not because of a hardware op in this kernel. Two reasons:
- The shared TE PTX wrappers it calls — cp_async_bulk_tensor_2d_global_to_shared and mbarrier_wait_parity_acquire_cta_shared_cta in util/ptx.cuh — are themselves guarded to >= 1000 and emit NVTE_DEVICE_ERROR below that. They were authored/validated only for the Blackwell path.
- The whole NVFP4 quantize path is host-gated to SM100 anyway (NVTE_ERROR("NVFP4 requires SM100 ...")), since NVFP4 is a Blackwell datatype and the downstream FP4 GEMM that consumes these scales only exists on SM100. So the amax kernel is never launched below SM100; the per-arch guard just yields a clean error instead of an undefined symbol.
Add NN/NT GEMM layout dispatch so the per-token NVFP4 path covers dgrad
and wgrad, and let per-token opt into RHT via
NVFP4PerTokenBlockScaling(per_token_rht=...) while SR/2D stay disabled
(kernels unimplemented at this commit). Extends the per-token CUTLASS
GEMM, the torch NVFP4Quantizer, and the NVFP4Tensor plumbing, plus
dgrad/wgrad numerical tests and a fwd+bwd module smoke test.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Thread a Philox rng_state and a kWithSr template flag through the
per-token encode kernel (rowwise + colwise) and the
nvte_nvfp4_per_token_encode/quantize C-API, mirroring the per-tensor SR
path. Drop the SR mutex check in the torch NVFP4Quantizer and build the
rng_state when stochastic rounding is requested.
Add a per_token_sr recipe flag on NVFP4PerTokenBlockScaling wired through
the quantizer factory, plus statistical tests (SR unbiasedness -- lower
RMSE than RN when averaged -- and RN-determinism / SR-nondeterminism)
folded into test_nvfp4_per_token.py.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
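The unbiasedness property that the statistical tests check can be illustrated with a toy stochastic-rounding routine on the FP4 E2M1 magnitude grid. This is a sketch, not the kernel: Python's `random.Random` stands in for the Philox generator, scaling is ignored, and in `round_nearest` ties go to the smaller magnitude, which is a simplification.

```python
import math
import random

# Representable FP4 E2M1 magnitudes.
FP4_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def round_nearest(x):
    """Deterministic rounding to the closest grid point (ties: smaller magnitude)."""
    mag = min(abs(x), FP4_GRID[-1])
    return math.copysign(min(FP4_GRID, key=lambda g: abs(g - mag)), x)


def round_stochastic(x, rng):
    """Round up with probability proportional to the distance from the lower neighbour."""
    mag = min(abs(x), FP4_GRID[-1])
    lo = max(g for g in FP4_GRID if g <= mag)
    hi = min(g for g in FP4_GRID if g >= mag)
    if hi == lo:
        q = lo
    else:
        p_up = (mag - lo) / (hi - lo)
        q = hi if rng.random() < p_up else lo
    return math.copysign(q, x)


rng = random.Random(0)
samples = [round_stochastic(2.3, rng) for _ in range(20000)]
sr_mean = sum(samples) / len(samples)  # close to 2.3 on average
rn_value = round_nearest(2.3)          # always 2.0
```

Averaged over many draws, the stochastic result tracks the input value while round-to-nearest keeps a fixed bias, which is the "lower RMSE than RN when averaged" behaviour the tests assert.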
Wire with_sr + rng_state through the grouped per-token C-API and cast
dispatch, implement the SR FP4 cast in the grouped kernel, and drop the
"per-token does not support SR" guard. Also fix two comment typos
(sees -> uses) in quantize_nvfp4_per_token.cu per review.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Introduce NVTE_NVFP4_PER_TOKEN_WEIGHT_2D (recipe.per_token_weight_2d),
default off so the per-token path stays byte-equal. When enabled, only
the forward WEIGHT switches to the per-tensor 2D cast (16x16 inner tile +
scalar outer amax) re-dressed in per-token tensor layout: the scalar
outer amax is broadcast across the per-row/col alpha vectors and the
inner SF is the same 16-row-replicated 2D tile, so the existing per-token
CUTLASS GEMM consumes it unchanged with no kernel modification.
Activation/gradient casts stay per-token 1D.
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Document the user-facing surface of the NVFP4 per-token recipe and add a
runnable single-GPU example so the recipe can be exercised end to end.
- docs/api/common.rst: list NVFP4PerTokenBlockScaling in the API reference.
- docs/envvars.rst: document the NVTE_NVFP4_* knobs -- per-token activation
  (NVTE_NVFP4_PER_TOKEN) plus the RHT/SR/weight-2D opt-ins, and the
  per-tensor disable flags.
- docs/features/.../nvfp4.rst: add a "Per-token NVFP4" section explaining
  the per-row/per-col outer-amax cast, its differences from the per-tensor
  default (RHT/SR off by default, forced-off knobs, unfused-norm
  requirement), and how to launch it with Megatron-Core.
- recipe/__init__.py: document the
  per_token_rht/per_token_sr/per_token_weight_2d constructor kwargs and
  drop the stale "stochastic rounding unsupported" note.
- pytorch/fp8.py: re-export NVFP4PerTokenBlockScaling.
- examples/pytorch/nvfp4_per_token_megatron: single-GPU MoE example (run +
  sbatch + job-chain scripts and README) comparing per-token vs per-tensor
  vs BF16 with identical model/data/seed.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Greptile Summary
This PR adds a new
Confidence Score: 4/5
Safe to merge for evaluation/research use as described in the PR; the feature is explicitly gated as non-production, and the known unsupported paths (fuse_wgrad_accumulation, comm-overlap, CUDA graphs) all raise clear runtime errors rather than silently corrupting results.
The implementation is large but well-structured, with comprehensive guards in both Python and C++. Fresh findings are style/robustness observations. They concern transformer_engine/pytorch/cpp_extensions/gemm.py (return-type consistency in the grouped per-token path) and transformer_engine/common/recipe/__init__.py (env-var re-evaluation in the nvfp4_per_token classmethod).
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["RecipeState.get_quantizer()"] --> B{nvfp4_per_token?}
B -- "False\n(default NVFP4BlockScaling)" --> C["per-tensor path\n(cuBLASLt 1A/2A)"]
B -- "True\n(NVFP4PerTokenBlockScaling\nOR NVTE_NVFP4_PER_TOKEN=1)" --> D{per_token_weight_2d?}
D -- "False (default)" --> E["NVFP4Quantizer\n(per_token=True)"]
D -- "True\n(NVTE_NVFP4_PER_TOKEN_WEIGHT_2D=1)" --> F["NVFP4Quantizer\n(per_token=True,\nper_token_weight_2d=True)"]
E --> G["quantize_impl:\nnvte_nvfp4_per_token_quantize\n(K1 row/col amax + K2 FP4 encode)"]
F --> H["quantize_impl:\nnvte_compute_amax → nvte_quantize_v2\n(2D cast) → broadcast amax vector"]
G --> I["NVFP4Tensor\n_amax_rowwise shape: (M,)\n_amax_columnwise shape: (K,)"]
H --> I
I --> J["general_gemm / general_grouped_gemm:\n_is_nvfp4_per_token_tensor check"]
J -- "True" --> K["_nvfp4_per_token_gemm\nor\n_nvfp4_per_token_grouped_gemm\n(CUTLASS EVT fused epilogue)"]
J -- "False" --> L["tex.generic_gemm\n(cuBLASLt)"]
K --> M["Output: BF16\nD = alpha_row x alpha_col · (A @ B^T)"]
L --> N["Output: BF16/FP8\n(standard path)"]
Reviews (2): Last reviewed commit: "Add native CUTLASS grouped per-token NVF..."
    NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer for per-token weight-2D.");

    // 1. Single scalar tensor amax -> amax[0] (mirror the per-tensor no-RHT path:
    //    treat the buffer as length 1 for the reduction, then fan out to both
    //    rowwise/columnwise amax[0]).
    out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
    NVTE_SCOPED_GIL_RELEASE(
        { nvte_compute_amax_with_config(input.data(), out.data(), w2d_config, stream); });
    out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector<size_t>{1});
    if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
      NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
                                      cudaMemcpyDeviceToDevice, stream));
    }
    if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
      NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),
Latent null-pointer dereference in per-token weight-2D amax restore
After nvte_compute_amax_with_config writes the global amax into amax_ptr[0], the code restores the output tensor's amax field with out.set_amax(rowwise_amax_ptr, ...). If rowwise_amax_ptr is nullptr (i.e., the quantizer was constructed with rowwise=False), this sets the output's amax descriptor to a null pointer. The immediately following nvte_quantize_v2 then tries to read amax[0] to derive S_enc and will crash.
Currently this path is unreachable because per_token_weight_2d is only set for weight quantizers, and all weight quantizers in the recipe are constructed with rowwise=True, columnwise=True. However, the guard in step 3 (if (rowwise_amax_ptr != nullptr && w2d_rows > 1)) shows the author anticipated both pointers could be null, while the critical out.set_amax call on this line does not. Using amax_ptr (the non-null pointer already validated by the NVTE_CHECK above) would be safe in all configurations: out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1}).
    # Per-token NVFP4 dispatches to fused EVT GEMM that consumes per-row
    # (M,) and per-col (N,) outer-amax vectors directly. cuBLASLt cannot,
    # so this MUST short-circuit before the row-scaled-or-generic fork.
    if _is_nvfp4_per_token_tensor(A) or _is_nvfp4_per_token_tensor(B):
        if not (_is_nvfp4_per_token_tensor(A) and _is_nvfp4_per_token_tensor(B)):
            raise NotImplementedError(
                "NVFP4 per-token GEMM requires both A and B to be per-token tensors. "
                "Mixing per-token + prod NVFP4 in one GEMM is not supported."
            )
        out = _nvfp4_per_token_gemm(
            A,
            B,
            transa=transa,
            transb=transb,
            out=out,
            out_dtype=out_dtype,
            bias=bias,
            grad=grad,
            accumulate=accumulate,
            gelu=gelu,
            quantization_params=quantization_params,
            ub=ub,
            extra_output=extra_output,
        )
alpha scalar silently ignored for per-token GEMM
general_gemm validates and stores alpha in kwargs["alpha"], but the per-token short-circuit path dispatches to _nvfp4_per_token_gemm which has no alpha parameter and never forwards the value. The C++ binding nvfp4_cutlass_per_token_gemm also lacks a global scalar alpha argument — only the per-row/per-col alpha_a/alpha_b vectors are supported. For all current TE module call sites alpha=1.0 is the invariant, so numerical output is unaffected today. If a caller ever passes alpha != 1.0 through general_gemm with per-token tensors, the result will be silently wrong instead of raising an error.
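Two simple ways to avoid the silent mismatch are to reject a non-unity scalar or to fold it into one of the per-row vectors before dispatch. The sketch below illustrates both; `check_per_token_alpha` and `fold_scalar_alpha` are hypothetical helpers, not functions from this PR, and which option fits the dispatcher is a judgment call for the authors.

```python
def check_per_token_alpha(alpha=1.0):
    """Fail loudly instead of silently dropping a global scalar alpha."""
    if alpha != 1.0:
        raise NotImplementedError(
            f"NVFP4 per-token GEMM does not support alpha != 1.0 (got {alpha})."
        )


def fold_scalar_alpha(alpha, alpha_a):
    """Alternative: absorb the scalar into the per-row scale vector.

    alpha * (alpha_a[i] * alpha_b[j] * acc) == (alpha * alpha_a[i]) * alpha_b[j] * acc
    """
    return [alpha * a for a in alpha_a]


check_per_token_alpha(1.0)                     # passes silently
scaled_rows = fold_scalar_alpha(2.0, [1.0, 0.5])
```

Folding costs one extra pass over a length-M vector, whereas the guard keeps the kernel interface unchanged and simply documents the restriction.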
    for i, M_i in enumerate(split_sections):
        if M_i <= 0:
            raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}")
        if M_i % _PER_TOKEN_TILE != 0:
            raise ValueError(f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}")
Public grouped-quantize API unconditionally rejects 0-token splits
split_sections[i] <= 0 raises ValueError, but in MoE training with dynamic token routing, experts commonly receive zero tokens in a given micro-batch. The general_grouped_gemm per-token loop already handles this by skipping the launch when m_splits[i] == 0, so the GEMM side is fine. If users call this Python wrapper directly (e.g., from bench scripts or custom MoE quantization pipelines), they must pre-filter empty experts. A comment or guard skipping allocation for empty splits would make the API usable in unbalanced-routing scenarios.
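One way the suggested guard could look is sketched below: negative sizes are still rejected, empty experts are skipped, and the function returns the indices that actually need a quantize launch. The helper name `active_splits` is hypothetical, and the tile size of 128 is an assumption, since the value of `_PER_TOKEN_TILE` is not shown in the diff.

```python
_PER_TOKEN_TILE = 128  # assumed value; the diff does not show the constant


def active_splits(split_sections, tile=_PER_TOKEN_TILE):
    """Validate split sizes and return indices of non-empty groups."""
    active = []
    for i, m_i in enumerate(split_sections):
        if m_i < 0:
            raise ValueError(f"split_sections[{i}] must be >= 0, got {m_i}")
        if m_i == 0:
            continue  # expert received no tokens: nothing to allocate or launch
        if m_i % tile != 0:
            raise ValueError(
                f"split_sections[{i}] = {m_i} must be a multiple of {tile}"
            )
        active.append(i)
    return active


groups = active_splits([256, 0, 128])
```

Returning indices rather than filtered sizes lets the caller keep per-expert outputs aligned with the original expert numbering.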
Replace the per-expert Python loop for the plain
D = bf16(alpha_a * alpha_b * (A @ B^T)) path with a single ptr-array
CUTLASS grouped kernel (SM100). The dispatcher in general_grouped_gemm
routes to the native kernel when no accumulate/bias/gelu/output-quant is
requested, and otherwise falls back to the per-expert loop
(NVTE_NVFP4_PER_TOKEN_GROUPED_FALLBACK=1 forces the fallback).
The launcher caches the SM count and reuses persistent device scratch +
workspace buffers across launches to avoid per-call cudaMalloc/Free and
cudaGetDeviceProperties overhead. Parity tests assert the grouped kernel
matches the dense per-token GEMM bit-exact per group.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
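The routing rule in this commit message can be expressed as a small predicate. This is a sketch under assumptions: the function name and parameter names are hypothetical, and only the environment variable name is taken from the commit message.

```python
import os


def use_native_grouped_kernel(accumulate=False, bias=None, gelu=False,
                              output_quantizer=None, env=None):
    """Return True when the plain grouped path can use the single native kernel."""
    env = os.environ if env is None else env
    # The override from the commit message forces the per-expert loop.
    if env.get("NVTE_NVFP4_PER_TOKEN_GROUPED_FALLBACK", "0") == "1":
        return False
    # Any epilogue feature beyond the plain rescale falls back to the loop.
    return (not accumulate and bias is None and not gelu
            and output_quantizer is None)


plain = use_native_grouped_kernel(env={})
forced = use_native_grouped_kernel(
    env={"NVTE_NVFP4_PER_TOKEN_GROUPED_FALLBACK": "1"})
```

Passing `env` explicitly keeps the predicate easy to test without mutating the process environment.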
Description
This PR adds an NVFP4 per-token quantization recipe for model pre-training. The default NVFP4BlockScaling recipe computes a single per-tensor outer amax (s_global) per tensor. The per-token variant instead computes a per-row outer amax (length M) for rowwise data and a per-col outer amax (length K) for columnwise data, giving each token/row its own global scale.
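The difference between the two outer-amax schemes can be seen on a tiny matrix. The following is a minimal sketch in plain Python with a hypothetical helper name `outer_amax`; the real recipe computes these reductions in the K1 kernel.

```python
def outer_amax(x):
    """Return (per_tensor, per_row, per_col) absolute maxima of a 2D list."""
    per_row = [max(abs(v) for v in row) for row in x]
    per_col = [max(abs(row[j]) for row in x) for j in range(len(x[0]))]
    per_tensor = max(per_row)
    return per_tensor, per_row, per_col


# Row 1 holds much smaller values than row 0.
x = [[0.5, -2.0, 1.0],
     [0.01, 0.02, -0.03]]
per_tensor, per_row, per_col = outer_amax(x)
# per_tensor is 2.0, while per_row is [2.0, 0.03]
```

With a single per-tensor amax, the small-magnitude second row is scaled by the first row's outlier; the per-row vector gives it its own scale.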
Changes
Ongoing work
The per-token recipe currently targets accuracy evaluation, not optimized production deployment:
Type of change
Checklist: