Reject CUDA BERT EmbedLayerNorm/SkipLayerNorm shapes exceeding 32-bit output indexing #29264
Pull request overview
This PR addresses integer overflow in CUDA BERT LayerNorm-family kernels by widening the global element write offset (row * hidden_size) from 32-bit to 64-bit, preventing wrapped output indexing for very large tensors.
Changes:
- Widen LayerNorm device helper offset/index parameters to `int64_t` in `layer_norm.cuh`.
- Compute per-row offsets/indices in 64-bit in `skip_layer_norm_impl.cu` kernels.
- Compute `output_offset` in 64-bit in `embed_layer_norm_impl.cu` and pass it through to `LayerNorm`.
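
To make the overflow in the overview concrete, here is a minimal host-side sketch of the `row * hidden_size` offset computed in 64-bit next to the value a wrapped 32-bit product would hold. The function names `RowOffset64`, `OverflowsInt32`, and `RowOffsetWrapped32` are hypothetical and do not appear in the PR. The wrapped value is modelled with unsigned arithmetic, because signed overflow is undefined behavior in C++.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Offset of the first element of `row`, computed in 64-bit.
// The cast happens before the multiply, so the product cannot wrap
// for non-negative int inputs.
inline int64_t RowOffset64(int row, int hidden_size) {
  assert(row >= 0 && hidden_size >= 0);
  return static_cast<int64_t>(row) * hidden_size;
}

// True when an offset cannot be stored in a signed 32-bit int.
inline bool OverflowsInt32(int64_t offset) {
  return offset > std::numeric_limits<int32_t>::max();
}

// Models the low 32 bits that a 32-bit multiply would keep.
// Unsigned arithmetic wraps modulo 2^32 by definition.
inline uint32_t RowOffsetWrapped32(int row, int hidden_size) {
  assert(row >= 0 && hidden_size >= 0);
  return static_cast<uint32_t>(row) * static_cast<uint32_t>(hidden_size);
}
```

For example, `row = 5000000` with `hidden_size = 1024` gives a true offset of 5,120,000,000, while the low 32 bits are 825,032,704: a plausible-looking index that points at the wrong elements.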
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu | Uses int64_t for per-row offset/idx to avoid overflow when indexing large output tensors. |
| onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh | Updates LayerNorm helpers to accept 64-bit offsets and use 64-bit indices for global element access. |
| onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu | Uses 64-bit output_offset for writing/normalizing large outputs in EmbedLayerNorm. |

Is it needed? The typical maximum sequence length for a BERT model is 512, so an int32 offset is enough.
… output indexing

The CUDA EmbedLayerNormalization and SkipLayerNormalization kernels compute output write offsets (row_index * hidden_size) using 32-bit arithmetic. For very large output tensors the element count can exceed INT32_MAX and the offset would no longer be representable in 32 bits.

Every output write index in these kernels is a pure function of the launch grid and hidden_size (no data-dependent write indexing), so the maximum index is exactly output_element_count - 1, which the host knows from the input shapes before launch.

Add a host-side guard in each ComputeInternal that computes the output element count in 64-bit arithmetic and returns a clear error when it exceeds the supported 32-bit indexing range, instead of silently relying on the int32 kernels for shapes they cannot index.

Kernels are unchanged (int32 baseline); no numeric behavior change for supported shapes.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1379258 to 0b9d5e2
Thanks — the host-side guard approach is sound and keeps the hot kernel untouched (zero cost in the inference path).
SkipLayerNormalization (skip_layer_norm.cc) — looks complete. The guard is fully sound: the output, sum_output, and the skip[idx % skip_size] read in skip_layer_norm_impl.cu are all bounded by row_count * hidden_size, which equals input->Shape().Size() — exactly what the guard checks. Max device index is output_element_count - 1, so every write and read site is covered.
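
The bound claimed above can be checked with a small host-side walk over the index space. This is only a sketch: it assumes one row per block with `hidden_size` elements per row, as described in this thread, and the names `IndexBounds` and `WalkSkipLayerNormIndices` are hypothetical rather than taken from `skip_layer_norm_impl.cu`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Largest indices touched while covering every (row, column) pair.
struct IndexBounds {
  int64_t max_output_index;  // largest idx used for output / sum_output
  int64_t max_skip_index;    // largest idx % skip_size used to read skip
};

// Enumerates idx = row * hidden_size + col for all rows and columns and
// records the maxima. The layout is an assumption, not the kernel code.
inline IndexBounds WalkSkipLayerNormIndices(int64_t row_count,
                                            int64_t hidden_size,
                                            int64_t skip_size) {
  assert(row_count > 0 && hidden_size > 0 && skip_size > 0);
  IndexBounds bounds{-1, -1};
  for (int64_t row = 0; row < row_count; ++row) {
    const int64_t offset = row * hidden_size;
    for (int64_t col = 0; col < hidden_size; ++col) {
      const int64_t idx = offset + col;
      bounds.max_output_index = std::max(bounds.max_output_index, idx);
      bounds.max_skip_index = std::max(bounds.max_skip_index, idx % skip_size);
    }
  }
  return bounds;
}
```

With `row_count = 6`, `hidden_size = 4`, and `skip_size = 8`, the largest output index is 23 (the element count minus one) and the largest `skip` index is 7, so a check on the element count bounds both.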
EmbedLayerNormalization (embed_layer_norm.cc). The 64-bit accumulation is correct (static_cast<int64_t>(batch_size) * sequence_length * hidden_size promotes before any multiply), and the guard correctly bounds the output-write offset.
One residual gap worth noting (same root issue the automated reviewer raised): in embed_layer_norm_impl.cu the embedding-read offsets word_offset = word_id * hidden_size, segment_offset, and position_offset are still 32-bit int. These index into the embedding tables (word_embedding is [vocab_size, hidden_size], etc.), whose sizes are independent of the output element count. A model with a large embedding table (vocab_size * hidden_size > 2^31) but a modest output would still overflow these reads silently; the new guard does not cover that path. Suggest either widening those three offsets to int64_t (with the pointer arithmetic that uses them), or narrowing the PR framing to explicitly cover only output-write indexing.
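
As an illustration of that gap, a host-side check on the table shapes could look like the sketch below. `FitsInt32Indexing` is a hypothetical helper, not an existing ONNX Runtime function, and the shapes in the usage note are made-up numbers chosen only to show a table that fails while the output passes.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// True when a [rows, hidden_size] tensor has at most INT32_MAX elements,
// so every element offset fits in a signed 32-bit int.
// Hypothetical helper for illustration only.
inline bool FitsInt32Indexing(int64_t rows, int64_t hidden_size) {
  assert(rows >= 0 && hidden_size >= 0);
  if (rows == 0 || hidden_size == 0) {
    return true;  // empty tensor, nothing to index
  }
  // Division form: avoids forming rows * hidden_size, which could itself
  // overflow 64 bits for extreme inputs.
  return rows <= std::numeric_limits<int32_t>::max() / hidden_size;
}
```

With a hypothetical `vocab_size = 600000` and `hidden_size = 4096`, the word embedding table has 2,457,600,000 elements and fails the check, while an output of 128 rows by 4096 passes. That is the case an output-only guard would let through.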
tianleiwu left a comment
Suggest adding a check of vocab_size * hidden_size too in this PR or a follow-up PR.
Summary
The CUDA `EmbedLayerNormalization` and `SkipLayerNormalization` kernels compute output write offsets (`row_index * hidden_size`) using 32-bit arithmetic. For very large output tensors the element count can exceed `INT32_MAX`, at which point the offset is no longer representable in 32 bits.

Every output write index in these kernels is a pure function of the launch grid and `hidden_size` (there is no data-dependent write indexing), so the maximum index is exactly `output_element_count - 1`, which the host knows from the input shapes before launch. This PR adds a host-side guard in each op's `ComputeInternal` that computes the output element count in 64-bit arithmetic and returns a clear error when it exceeds the supported 32-bit indexing range.

Design
- `EmbedLayerNormalization` (`embed_layer_norm.cc`): `output_element_count = (int64)batch_size * sequence_length * hidden_size`, guarded with `ORT_RETURN_IF_NOT(... <= INT32_MAX, ...)`.
- `SkipLayerNormalization` (`skip_layer_norm.cc`): `output_element_count = input->Shape().Size()` (output shares the input shape), same guard.

Behavior
This rejects (rather than silently attempting) single-op LayerNorm outputs with 2³¹ or more elements, a regime no real BERT-family model produces (it would require a multi-GB single-op activation). For all supported shapes there is no behavior or numeric change.
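
The guard described under Design can be sketched outside of ONNX Runtime as follows. This is not the PR's code: the `ORT_RETURN_IF_NOT` macro and the status type are replaced by a plain `std::string` (empty means success), and `EmbedOutputElementCount` and `CheckOutputElementCount` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>

// Output element count for a [batch_size, sequence_length, hidden_size]
// output. The first operand is widened before any multiply, so the whole
// product is evaluated in 64-bit. Assumes realistic dimensions whose
// product fits in 64 bits.
inline int64_t EmbedOutputElementCount(int batch_size, int sequence_length,
                                       int hidden_size) {
  assert(batch_size >= 0 && sequence_length >= 0 && hidden_size >= 0);
  return static_cast<int64_t>(batch_size) * sequence_length * hidden_size;
}

// Stand-in for the host-side guard: returns an empty string when the
// count is indexable with a signed 32-bit int, else an error message.
inline std::string CheckOutputElementCount(int64_t output_element_count) {
  if (output_element_count > std::numeric_limits<int32_t>::max()) {
    return "output element count " + std::to_string(output_element_count) +
           " exceeds the supported 32-bit indexing range";
  }
  return std::string();
}
```

A typical BERT-base shape such as `(8, 512, 768)` yields 3,145,728 elements and passes, while `(4096, 1024, 1024)` yields 2³² elements and is rejected before any kernel launch.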
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>