Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/art/megatron/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch import Tensor
from torch.nn.attention.flex_attention import (
BlockMask,
FlexKernelOptions,
create_block_mask,
flex_attention,
)
Expand All @@ -29,17 +30,16 @@ class FlexAttentionWrapper(torch.nn.Module):
"""Compiled `flex_attention` wrapper with Torchtitan-style inductor options."""

# Torchtitan inductor options for compiling flex attention.
_compile_options = {
_compile_options: dict[str, Any] = {
"max_autotune": True,
"coordinate_descent_tuning": True,
"triton.cudagraphs": False,
}
# Skip Inductor's flex_decoding specialization: it has triggered both
# shared-memory OOMs (triton_flex_decoding) and symbolic-shape assertion
# failures (create_flex_decoding_kernel). The regular flex_attention
# kernel autotunes against the actual hardware smem budget, so this
# stays GPU-agnostic.
_kernel_options = {
_kernel_options: FlexKernelOptions = {
"FORCE_USE_FLEX_ATTENTION": True,
}
_compiled_flex_attention: ClassVar = torch.compile(
Expand Down Expand Up @@ -72,14 +72,16 @@ def forward(
)


_compiled_create_block_mask = torch.compile(create_block_mask)
# Sequence-length churn can break the Inductor backend here. Keep this
# on aot_eager instead.
_compiled_create_block_mask = torch.compile(create_block_mask, backend="aot_eager")


def create_shared_prefix_attention_state(
group_ids: Tensor,
parent_ids: Tensor,
) -> SharedPrefixAttentionState:
"""Build a compiled block mask for ART shared-prefix packing.
"""Build a block mask for ART shared-prefix packing.

Initialized on the device of the group_ids tensor.

Expand Down