Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114
Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114vthumbe1503 wants to merge 75 commits into
Conversation
Route grouped Float8CurrentScalingQuantizer through the existing grouped quantize entry point, prepare per-group current-scaling metadata with existing amax/scale helpers, and add focused tests plus a GB200 bandwidth benchmark. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_5507e814ee50f9ff304a4ce708d19768 Orchestra-Run: run_516e1e26891f4ce7d4cde07147c10862
Use wider vectorized grouped FP8 cast-transpose tiles and vectorized masked stores for rowwise and columnwise outputs. Capture all benchmark modes in a single post-warmup profiler range. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_3d6e33eab11e293d72eb4394bad76a81 Orchestra-Run: run_a6e2c31d5fdf850594f71438e53148da
Route non-MXFP8 grouped-linear bias backward through group_quantize plus grouped dbias while keeping MXFP8 bgrad_group_quantize fusion intact. Add focused zero-row grouped FP8 coverage and a current-scaling GroupedLinear bias-backward regression. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_ab566800d87047635cd27f9e64661abe Orchestra-Run: run_5f9bfef17ccd854232c54d56268ef9e8
Use packed FP8 conversion and reduce columnwise transpose staging register and synchronization overhead in group_cast_fp8_kernel. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_7a830e018ceac8de0018280bd0740a54 Orchestra-Run: run_d2f1df4ffc2265d9cfa5ed01028ee476
Match the grouped FP8 conversion helper's element-count template parameter to Vec's uint32_t parameter so rowwise, columnwise, and activation instantiations can build. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_30c4b6ddb896e5ea3ca5b54731d2c819 Orchestra-Run: run_e95cdbb445943304622b95736f0eca49
Use cached grouped offsets to avoid launching FP8 quantization over unused overallocated rows, permit larger grouped backing buffers when split metadata is present, and tighten full-tile vector paths in the grouped FP8 cast kernel. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_c5db93823dc101838cb1323e283cd6e9 Orchestra-Run: run_063e2e4c724e132612aa5597d6765c9b
Use the FP8 grouped output logical shape when computing the tensor-scaling launch grid so overallocated buffers with active metadata avoid empty tail-row launches while preserving the allocated-shape fallback. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_b4abb47c990404d73142342a19996a3f Orchestra-Run: run_8f09e7b9d7af9754ef505f2e2ce3cf90
Use larger grouped FP8 tiles with 8-warp CTAs and 16-row columnwise store fragments. Treat uniform overallocated FP8 grouped outputs as same-shape wrappers during output reuse so the timed path avoids varying-shape metadata overlaunch. Add overallocated current-scaling coverage for all grouped FP8 direction modes. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_3f98ac9c5b82192ec289d8d2a9816c7f Orchestra-Run: run_83f3b99cc950024cf06ee836337fbf72
Stage columnwise transpose fragments through shared-memory vectors with smaller columnwise row tiles to reduce register pressure and barrier overhead while preserving the larger rowwise-only store path. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_495cc57eef84749103aded403a508d99 Orchestra-Run: run_53e038e90f83186bc6c12cb722c986b5
Add fast grouped FP8 rowwise and full-tile columnwise paths for uniform active groups while preserving the general fallback for varying grouped metadata. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_4c33e88776c8a7148e9da5cc2bae84ea Orchestra-Run: run_2caaff219394eb5d59b7be38ab2bf346
Add a same-shape bidirectional full-tile kernel with wider input vectors and rowwise stores while preserving the existing rowwise-only, columnwise-only, and fallback grouped paths. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_87cec01d94f053b53e3c79377ad379ab Orchestra-Run: run_ed48db00a730a4bf56530d551ecd350e
Route same-shape rowwise+columnwise grouped FP8 tensor-scaling quantization through the compact full-tile transpose schedule instead of the wide dynamic-shared-memory variant, preserving the existing single-direction and fallback paths. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_fdddd228a620039c024b4ecf43f3ab42 Orchestra-Run: run_30a2753eea9c893cb0fadb8233da8ce6
Hint the rowwise stores in the full-tile rowwise+columnwise grouped FP8 path as streaming global stores to reduce cache/writeback pressure without changing single-direction launch geometry. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_bf82020032e68276f4e47c65f62d97ae Orchestra-Run: run_754ea4c864f329c6f2003b413b723c43
Add graph-safe grouped FP8 tensor-scaling metadata, support varying last dimensions, preserve same-shape fast paths, adjust grouped FP8 columnwise allocation by architecture, and expand benchmark/test coverage for the reviewed shape cases. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d104e74844fbc3d3b1a98a8d96d76037 Orchestra-Run: run_1314e997c61ffb92ff7120b0b26f0318
Map varying-last columnwise tiles per group to avoid tile-alignment device errors, expand nonaligned boundary coverage, and restore same-shape benchmark baseline criteria. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_14e0e7973300d26f69550bc0aee21acc Orchestra-Run: run_2f42b8ba138ed8b2b4d9dc90b92caf85
Add grouped FP8 benchmark support for baseline-ref same-session reports and update the benchmark request to enforce same-shape baseline regression checks alongside the per-mode throughput thresholds. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d0cada957a4aafdce9d52be86520e182 Orchestra-Run: run_4da74e9bdb4f4a4c72304a385692b6c9
Update the grouped FP8 benchmark driver so same-session baseline checks out and builds the baseline ref into an isolated PyTorch install, verifies the baseline subprocess loads those shared objects, and preserves the required same-shape baseline comparisons. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_4fd88b172872f547f2f2d0053dce73d1 Orchestra-Run: run_6a44ee0467ffff47d4b278de6127354d
Preserve grouped delayed-FP8 amax metadata and keep unsupported FP8 tensor-scaling quantizers out of the grouped GEMM path. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_2aa8e6bf11ae356f4b34d4540b508031 Orchestra-Run: run_302681098d7f4e05b0ad96450f2d9826
Set NVTE_GROUPED_LINEAR_SINGLE_PARAM inside the targeted state-dict tests so they exercise the gated single grouped parameter path without relying on external environment setup. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_261900f987bdc9397965019983a77c41 Orchestra-Run: run_c6624e34717cbe121b3e0edcf490e3d3
Add a segmented flat rowwise kernel for varying-first grouped FP8 tensor-scaling outputs while preserving the existing same-shape fast path. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_c1b7020b27290318848ef6ac9048dd5f Orchestra-Run: run_5c257b8a5d2e7e4aa95e67aa16436166
Omit the last_dims keyword when absent so the same-session baseline can run against the base extension, and refresh the benchmark request to include direct varying-last current-scaling coverage. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_c20a3c94fdc798e741a469bd7bb9c4df Orchestra-Run: run_457448e6cba80fc63ac72b3db71c5fd0
Dispatch varying-first tensor-scaling work per group to reduce inactive-tail CTAs and offset lookup overhead while preserving same-shape fast paths and graph-safe device metadata handling. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d84e1fefef8641e558df064452f4689b Orchestra-Run: run_a361ca2f93fcec53ddd60dd99f4639e5
Add a no-tail rowwise flat kernel for aligned varying-first grouped FP8 tensor-scaling quantization and keep same-shape and varying-last dispatch isolated. Tighten benchmark profiler timing so post-warmup measured ranges exclude profiler start overhead. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_2e478be1fb38195f36d25c51320dc01f Orchestra-Run: run_9a133a75fa3d98dc3b1a63b0ff4d84af
Write grouped FP8 benchmark reports to a sidecar path by default and label script reports as benchmark_raw_report/v1 so regular 100-iteration measurements are fetched instead of the wrapper command report. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_27770b2e1d490b1a3053244d4b4ce248 Orchestra-Run: run_214052d0c1316e231443d645183a2675
Write the grouped FP8 benchmark JSON once and mirror the completed sidecar to ORCHESTRA_BENCHMARK_RAW_REPORT when running under Orchestra so the benchmark fetch path can parse the emitted measurements. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_b2e2747371204088c8e3f7cf10263164 Orchestra-Run: run_1d4ea38266807c8acb59143ee74ba241
Allow the grouped FP8 benchmark to use ORCHESTRA_BENCHMARK_RAW_REPORT as its primary output so the benchmark wrapper can fetch canonical measurements directly. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_10fdcfef6b70de4676b7843e4bbfac31 Orchestra-Run: run_4ce57df9e86d6d03a26f7aa95ac252cc
Write canonical grouped FP8 benchmark measurements to ORCHESTRA_BENCHMARK_RAW_REPORT in a small schema-shaped payload so the benchmark wrapper can materialize per-mode threshold evidence. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_3e862eebd585c74f2a58497fedea3511 Orchestra-Run: run_3770ab3dbbf51329d0839b3d10a91b5c
Write candidate_results and nonempty measurements into the Orchestra raw report path, and fail fast if the benchmark cannot produce threshold-ready evidence. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_aa587a7b0d35aa9c2b715ec1b7c8bec3 Orchestra-Run: run_b42870e5d5e142a6cbf53bb5a3cafc2e
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…TransformerEngine into current_scaling_group_quant
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
|
/te-ci pytorch |
Greptile SummaryThis PR adds grouped FP8 current-scaling quantization (amax-on-the-fly, per-group scale/scale_inv) and extends the existing grouped-quantize infrastructure to support varying last dim and varying both dims, in addition to the previously supported varying first dim.
Confidence Score: 4/5Safe to merge for the FP8 current-scaling and MXFP8 varying-dim paths; the NVFP4 code path needs a missing guard before
Important Files Changed
Sequence DiagramsequenceDiagram
participant PY as Python (group_quantize)
participant CPP as cast.cpp
participant AMAX as nvte_group_compute_amax_with_config
participant NCCL as NCCL allreduce (optional)
participant SCALE as nvte_group_compute_scale_from_amax
participant QUANT as nvte_group_quantize (FP8 cast kernel)
PY->>CPP: group_quantize(tensor, Float8CurrentScalingQuantizer, first_dims, last_dims, noop_flag)
CPP->>CPP: create_grouped_tensor (allocates data, amax, scale, scale_inv)
CPP->>AMAX: compute per-group amax from input data
AMAX-->>CPP: amax[0..N-1] written
alt "with_amax_reduction=True"
CPP->>NCCL: allreduce(amax, MAX)
end
CPP->>SCALE: "derive scale = fp8_max / amax, scale_inv = 1/scale"
SCALE-->>CPP: scale[0..N-1], scale_inv[0..N-1] written
CPP->>QUANT: cast input to FP8 using per-group scale
QUANT-->>CPP: grouped FP8 output tensor
CPP-->>PY: grouped output Python object
|
… specific Signed-off-by: Varun Thumbe <vthumbe@vthumbe-mlt.client.nvidia.com>
|
/te-ci |
|
Pipeline: 54747206 |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…scale from amax Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
…TransformerEngine into current_scaling_group_quant
Removed duplicate brief comment about scaled prefix-sum offsets. Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: