|
| 1 | +"""Flex attention plumbing for ART's Megatron backend.""" |
| 2 | + |
| 3 | +import math |
| 4 | +from typing import Any, ClassVar, cast |
| 5 | + |
| 6 | +from megatron.core.packed_seq_params import PackedSeqParams |
| 7 | +from megatron.core.process_groups_config import ProcessGroupCollection |
| 8 | +from megatron.core.transformer.enums import AttnMaskType |
| 9 | +from megatron.core.transformer.transformer_config import TransformerConfig |
| 10 | +from megatron.core.utils import divide |
| 11 | +from pydantic import BaseModel, ConfigDict |
| 12 | +import torch |
| 13 | +from torch import Tensor |
| 14 | +from torch.nn.attention.flex_attention import ( |
| 15 | + BlockMask, |
| 16 | + create_block_mask, |
| 17 | + flex_attention, |
| 18 | +) |
| 19 | + |
| 20 | + |
| 21 | +class SharedPrefixAttentionState(BaseModel): |
| 22 | + """Shared-prefix sparsity metadata for one packed ART training sample.""" |
| 23 | + |
| 24 | + model_config = ConfigDict(arbitrary_types_allowed=True) |
| 25 | + block_mask: BlockMask |
| 26 | + |
| 27 | + |
| 28 | +class FlexAttentionWrapper(torch.nn.Module): |
| 29 | + """Compiled `flex_attention` wrapper with Torchtitan-style inductor options.""" |
| 30 | + |
| 31 | + # Torchtitan inductor options for compiling flex attention. |
| 32 | + _compile_options = { |
| 33 | + "max_autotune": True, |
| 34 | + "coordinate_descent_tuning": True, |
| 35 | + "triton.cudagraphs": False, |
| 36 | + } |
| 37 | + _compiled_flex_attention: ClassVar = torch.compile( |
| 38 | + flex_attention, |
| 39 | + options=_compile_options, |
| 40 | + ) |
| 41 | + |
| 42 | + def forward( |
| 43 | + self, |
| 44 | + q: Tensor, |
| 45 | + k: Tensor, |
| 46 | + v: Tensor, |
| 47 | + *, |
| 48 | + block_mask: BlockMask, |
| 49 | + scale: float, |
| 50 | + enable_gqa: bool, |
| 51 | + ) -> Tensor: |
| 52 | + # q, k, v are [B, H, S, D] tensors expected by torch.flex_attention. |
| 53 | + return cast( |
| 54 | + Tensor, |
| 55 | + FlexAttentionWrapper._compiled_flex_attention( |
| 56 | + q, |
| 57 | + k, |
| 58 | + v, |
| 59 | + block_mask=block_mask, |
| 60 | + scale=scale, |
| 61 | + enable_gqa=enable_gqa, |
| 62 | + ), |
| 63 | + ) |
| 64 | + |
| 65 | + |
| 66 | +_compiled_create_block_mask = torch.compile(create_block_mask) |
| 67 | + |
| 68 | + |
| 69 | +def create_shared_prefix_attention_state( |
| 70 | + group_ids: Tensor, |
| 71 | + parent_ids: Tensor, |
| 72 | +) -> SharedPrefixAttentionState: |
| 73 | + """Build a compiled block mask for ART shared-prefix packing. |
| 74 | +
|
| 75 | + Initialized on the device of the group_ids tensor. |
| 76 | +
|
| 77 | + Args: |
| 78 | + group_ids: `[B, S]` group id for each token in a packed sequence. |
| 79 | + parent_ids: `[B, S]` parent group id for each token in a packed sequence. |
| 80 | + """ |
| 81 | + |
| 82 | + def _shared_prefix_mask( |
| 83 | + batch_idx: Tensor, |
| 84 | + head_idx: Tensor, |
| 85 | + query_idx: Tensor, |
| 86 | + kv_idx: Tensor, |
| 87 | + ) -> Tensor: |
| 88 | + del head_idx |
| 89 | + # Token q can attend token k if k is causal and either from the same |
| 90 | + # traj (traj -> traj)/within the shared prefix (prefix -> prefix) (same_group) |
| 91 | + # or from the prefix which q uses (traj -> prefix) (parent_prefix). |
| 92 | + same_group = group_ids[batch_idx, query_idx] == group_ids[batch_idx, kv_idx] |
| 93 | + parent_prefix = parent_ids[batch_idx, query_idx] == group_ids[batch_idx, kv_idx] |
| 94 | + return (query_idx >= kv_idx) & (same_group | parent_prefix) |
| 95 | + |
| 96 | + block_mask = _compiled_create_block_mask( |
| 97 | + _shared_prefix_mask, |
| 98 | + group_ids.shape[0], |
| 99 | + None, |
| 100 | + group_ids.shape[1], |
| 101 | + group_ids.shape[1], |
| 102 | + device=group_ids.device, |
| 103 | + ) |
| 104 | + return SharedPrefixAttentionState(block_mask=block_mask) |
| 105 | + |
| 106 | + |
| 107 | +class FlexDotProductAttention(torch.nn.Module): |
| 108 | + """Megatron core-attention module backed by compiled torch flex attention. |
| 109 | +
|
| 110 | + The current implementation lacks support for fp8 and context parallelism (which are available in TEDotProductAttention) |
| 111 | + """ |
| 112 | + |
| 113 | + def __init__( |
| 114 | + self, |
| 115 | + config: TransformerConfig, |
| 116 | + layer_number: int, |
| 117 | + attn_mask_type: AttnMaskType, |
| 118 | + attention_type: str, |
| 119 | + attention_dropout: float | None = None, |
| 120 | + softmax_scale: float | None = None, |
| 121 | + cp_comm_type: str | None = None, |
| 122 | + pg_collection: ProcessGroupCollection | None = None, |
| 123 | + ): |
| 124 | + super().__init__() |
| 125 | + del ( |
| 126 | + layer_number, |
| 127 | + attn_mask_type, |
| 128 | + attention_type, |
| 129 | + attention_dropout, |
| 130 | + cp_comm_type, |
| 131 | + ) |
| 132 | + self.config = config |
| 133 | + self.flex_attention = FlexAttentionWrapper() |
| 134 | + |
| 135 | + if pg_collection is None: |
| 136 | + tp_world_size = self.config.tensor_model_parallel_size |
| 137 | + else: |
| 138 | + tp_world_size = pg_collection.tp.size() |
| 139 | + |
| 140 | + kv_channels = self.config.kv_channels |
| 141 | + assert kv_channels is not None, "Megatron config must provide kv_channels." |
| 142 | + projection_size = kv_channels * self.config.num_attention_heads |
| 143 | + self.hidden_size_per_partition = divide(projection_size, tp_world_size) |
| 144 | + num_query_groups = ( |
| 145 | + self.config.num_query_groups or self.config.num_attention_heads |
| 146 | + ) |
| 147 | + self.num_attention_heads_per_partition = divide( |
| 148 | + self.config.num_attention_heads, tp_world_size |
| 149 | + ) |
| 150 | + self.num_query_groups_per_partition = divide(num_query_groups, tp_world_size) |
| 151 | + |
| 152 | + if softmax_scale is None: |
| 153 | + head_dim = divide(projection_size, self.config.num_attention_heads) |
| 154 | + self.softmax_scale = 1.0 / math.sqrt(head_dim) |
| 155 | + else: |
| 156 | + self.softmax_scale = softmax_scale |
| 157 | + |
| 158 | + def forward( |
| 159 | + self, |
| 160 | + query: Tensor, |
| 161 | + key: Tensor, |
| 162 | + value: Tensor, |
| 163 | + attention_mask: Tensor, |
| 164 | + attn_mask_type: AttnMaskType | None = None, |
| 165 | + attention_bias: Any = None, |
| 166 | + packed_seq_params: PackedSeqParams | None = None, |
| 167 | + ) -> Tensor: |
| 168 | + """Compute self attention with compiled flex kernels. |
| 169 | +
|
| 170 | + Args: |
| 171 | + query: `[S, B, Hq, D]` |
| 172 | + key: `[S, B, Hkv, D]` |
| 173 | + value: `[S, B, Hkv, D]` |
| 174 | + attention_mask: unused placeholder tensor kept for Megatron checkpoint API. |
| 175 | + attention_bias: `SharedPrefixAttentionState` or `BlockMask`. |
| 176 | + """ |
| 177 | + |
| 178 | + del attention_mask, attn_mask_type |
| 179 | + assert packed_seq_params is None, ( |
| 180 | + "PackedSeqParams is not used in ART Megatron flex path." |
| 181 | + ) |
| 182 | + |
| 183 | + if isinstance(attention_bias, SharedPrefixAttentionState): |
| 184 | + block_mask = attention_bias.block_mask |
| 185 | + else: |
| 186 | + assert isinstance(attention_bias, BlockMask), ( |
| 187 | + "Expected a flex BlockMask in attention_bias." |
| 188 | + ) |
| 189 | + block_mask = attention_bias |
| 190 | + |
| 191 | + # Megatron uses [S, B, H, D], while flex attention expects [B, H, S, D]. |
| 192 | + q = query.permute(1, 2, 0, 3) |
| 193 | + k = key.permute(1, 2, 0, 3) |
| 194 | + v = value.permute(1, 2, 0, 3) |
| 195 | + |
| 196 | + out = self.flex_attention( |
| 197 | + q, |
| 198 | + k, |
| 199 | + v, |
| 200 | + block_mask=block_mask, |
| 201 | + scale=self.softmax_scale, |
| 202 | + enable_gqa=self.num_attention_heads_per_partition |
| 203 | + != self.num_query_groups_per_partition, |
| 204 | + ) |
| 205 | + |
| 206 | + # Return to Megatron's expected layout [S, B, Hq*D]. |
| 207 | + out = out.permute(2, 0, 1, 3).contiguous() |
| 208 | + out = out.view(out.size(0), out.size(1), self.hidden_size_per_partition) |
| 209 | + return out |
0 commit comments