-
Notifications
You must be signed in to change notification settings - Fork 804
Expand file tree
/
Copy pathflex_attention.py
More file actions
225 lines (192 loc) · 7.56 KB
/
flex_attention.py
File metadata and controls
225 lines (192 loc) · 7.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""Flex attention plumbing for ART's Megatron backend."""
from collections.abc import Callable
import math
from typing import Any, ClassVar, TypeAlias, cast
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import divide
from pydantic import BaseModel, ConfigDict
import torch
from torch import Tensor
from torch.nn.attention.flex_attention import (
BlockMask,
FlexKernelOptions,
create_block_mask,
flex_attention,
)
class SharedPrefixAttentionState(BaseModel):
    """Container for the shared-prefix sparsity metadata of one packed ART sample.

    Currently this carries a single field: the flex-attention `BlockMask`
    describing which query/key positions may attend to each other.
    """

    # Precomputed flex-attention block mask for the packed sample.
    block_mask: BlockMask

    # `BlockMask` is not a pydantic-native type, so arbitrary types must be
    # allowed for it to be usable as a field annotation.
    model_config = ConfigDict(arbitrary_types_allowed=True)
# Mapping from option name to option value, used to annotate the dict that is
# passed as `options=` to `torch.compile` in this module.
CompileOptions: TypeAlias = dict[str, str | int | bool | Callable[..., Any]]
class FlexAttentionWrapper(torch.nn.Module):
    """Module wrapper around one shared, compiled `flex_attention` callable.

    The compiled callable and both option dicts live on the class, so every
    instance dispatches to the same compiled function. The Inductor options
    follow the Torchtitan-style configuration.
    """

    # Kernel options forwarded on every call.
    # Skip Inductor's flex_decoding specialization: it has triggered both
    # shared-memory OOMs (triton_flex_decoding) and symbolic-shape assertion
    # failures (create_flex_decoding_kernel). The regular flex_attention
    # kernel autotunes against the actual hardware smem budget, so this
    # stays GPU-agnostic.
    _kernel_options: ClassVar[FlexKernelOptions] = {
        "FORCE_USE_FLEX_ATTENTION": True,
    }

    # Torchtitan inductor options for compiling flex attention.
    _compile_options: ClassVar[CompileOptions] = {
        "max_autotune": True,
        "coordinate_descent_tuning": True,
        "triton.cudagraphs": False,
    }

    # Created once when the class body executes and reused by all instances.
    _compiled_flex_attention: ClassVar = torch.compile(
        flex_attention,
        options=_compile_options,
    )

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        *,
        block_mask: BlockMask,
        scale: float,
        enable_gqa: bool,
    ) -> Tensor:
        """Run the compiled flex attention kernel.

        Args:
            q: Query tensor, `[B, H, S, D]` as expected by torch.flex_attention.
            k: Key tensor, `[B, H, S, D]`.
            v: Value tensor, `[B, H, S, D]`.
            block_mask: Sparsity mask handed to flex attention.
            scale: Softmax scaling factor applied to the attention scores.
            enable_gqa: Forwarded to flex attention as `enable_gqa`.

        Returns:
            The attention output produced by flex attention.
        """
        # The compiled callable is fetched from the class, not from `self`
        # (presumably so it is not bound as a method -- confirm before changing).
        compiled_attention = FlexAttentionWrapper._compiled_flex_attention
        attn_out = compiled_attention(
            q,
            k,
            v,
            block_mask=block_mask,
            scale=scale,
            enable_gqa=enable_gqa,
            kernel_options=FlexAttentionWrapper._kernel_options,
        )
        return cast(Tensor, attn_out)
# Module-level compiled wrapper around `create_block_mask`.
# Sequence-length churn can break the Inductor backend here, so this is
# compiled with the `aot_eager` backend instead.
_compiled_create_block_mask = torch.compile(create_block_mask, backend="aot_eager")
def create_shared_prefix_attention_state(
    group_ids: Tensor,
    parent_ids: Tensor,
) -> SharedPrefixAttentionState:
    """Build a block mask for ART shared-prefix packing.

    The mask is created on the device of the ``group_ids`` tensor. The head
    dimension is passed as ``None`` and the head index is ignored, so the
    same mask is used for every attention head.

    Args:
        group_ids: `[B, S]` group id for each token in a packed sequence.
        parent_ids: `[B, S]` parent group id for each token in a packed
            sequence. Must have exactly the same shape as ``group_ids``.

    Returns:
        A `SharedPrefixAttentionState` wrapping the resulting `BlockMask`.

    Raises:
        ValueError: If ``group_ids`` is not 2-D, or if ``parent_ids`` does not
            have the same shape as ``group_ids``.
    """
    # Validate up front: the mask function indexes both tensors with
    # (batch, position) pairs derived from `group_ids`' shape, so a mismatch
    # would otherwise surface as out-of-range indexing or a wrong mask.
    if group_ids.ndim != 2:
        raise ValueError(
            f"group_ids must be a 2-D [B, S] tensor, got shape {tuple(group_ids.shape)}."
        )
    if parent_ids.shape != group_ids.shape:
        raise ValueError(
            "parent_ids must have the same shape as group_ids, got "
            f"{tuple(parent_ids.shape)} vs {tuple(group_ids.shape)}."
        )

    batch_size, seq_len = group_ids.shape

    def _shared_prefix_mask(
        batch_idx: Tensor,
        head_idx: Tensor,
        query_idx: Tensor,
        kv_idx: Tensor,
    ) -> Tensor:
        del head_idx
        # Token q can attend token k if k is causal and either from the same
        # traj (traj -> traj)/within the shared prefix (prefix -> prefix) (same_group)
        # or from the prefix which q uses (traj -> prefix) (parent_prefix).
        kv_group = group_ids[batch_idx, kv_idx]
        same_group = group_ids[batch_idx, query_idx] == kv_group
        parent_prefix = parent_ids[batch_idx, query_idx] == kv_group
        return (query_idx >= kv_idx) & (same_group | parent_prefix)

    block_mask = _compiled_create_block_mask(
        _shared_prefix_mask,
        batch_size,
        None,  # no per-head dimension: one mask for all heads
        seq_len,  # query length
        seq_len,  # key/value length
        device=group_ids.device,
    )
    return SharedPrefixAttentionState(block_mask=block_mask)
class FlexDotProductAttention(torch.nn.Module):
    """Megatron core-attention module backed by compiled torch flex attention.

    The current implementation lacks support for fp8 and context parallelism
    (which are available in TEDotProductAttention).

    Input validation uses explicit ``raise AssertionError(...)`` rather than
    ``assert`` statements so the checks still run under ``python -O``; the
    exception type and messages are the same as an ``assert`` would produce.
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        attention_dropout: float | None = None,
        softmax_scale: float | None = None,
        cp_comm_type: str | None = None,
        pg_collection: ProcessGroupCollection | None = None,
    ):
        """Derive per-partition sizes and the softmax scale from the config.

        Args:
            config: Megatron transformer config; `kv_channels`,
                `num_attention_heads`, `num_query_groups` and
                `tensor_model_parallel_size` are read from it.
            layer_number: Accepted but unused by this module.
            attn_mask_type: Accepted but unused by this module.
            attention_type: Accepted but unused by this module.
            attention_dropout: Accepted but unused by this module.
            softmax_scale: Explicit softmax scale. When ``None``,
                ``1 / sqrt(head_dim)`` is used.
            cp_comm_type: Accepted but unused by this module.
            pg_collection: Optional process groups. When given, the
                tensor-parallel world size is taken from ``pg_collection.tp``
                instead of the config.

        Raises:
            AssertionError: If ``config.kv_channels`` is ``None``.
        """
        super().__init__()
        # These arguments are part of the constructor signature but are not
        # used by the flex attention path.
        del (
            layer_number,
            attn_mask_type,
            attention_type,
            attention_dropout,
            cp_comm_type,
        )
        self.config = config
        self.flex_attention = FlexAttentionWrapper()

        if pg_collection is None:
            tp_world_size = self.config.tensor_model_parallel_size
        else:
            tp_world_size = pg_collection.tp.size()

        kv_channels = self.config.kv_channels
        if kv_channels is None:
            raise AssertionError("Megatron config must provide kv_channels.")

        projection_size = kv_channels * self.config.num_attention_heads
        self.hidden_size_per_partition = divide(projection_size, tp_world_size)

        # Fall back to one query group per head when the config does not set
        # `num_query_groups`.
        num_query_groups = (
            self.config.num_query_groups or self.config.num_attention_heads
        )
        self.num_attention_heads_per_partition = divide(
            self.config.num_attention_heads, tp_world_size
        )
        self.num_query_groups_per_partition = divide(num_query_groups, tp_world_size)

        if softmax_scale is None:
            head_dim = divide(projection_size, self.config.num_attention_heads)
            self.softmax_scale = 1.0 / math.sqrt(head_dim)
        else:
            self.softmax_scale = softmax_scale

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Tensor,
        attn_mask_type: AttnMaskType | None = None,
        attention_bias: Any = None,
        packed_seq_params: PackedSeqParams | None = None,
    ) -> Tensor:
        """Compute self attention with compiled flex kernels.

        Args:
            query: `[S, B, Hq, D]`
            key: `[S, B, Hkv, D]`
            value: `[S, B, Hkv, D]`
            attention_mask: unused placeholder tensor kept for Megatron checkpoint API.
            attn_mask_type: unused.
            attention_bias: `SharedPrefixAttentionState` or `BlockMask`.
            packed_seq_params: must be ``None``; not used in this path.

        Returns:
            Attention output in Megatron's layout `[S, B, Hq*D]`.

        Raises:
            AssertionError: If ``packed_seq_params`` is not ``None`` or
                ``attention_bias`` is neither a `SharedPrefixAttentionState`
                nor a `BlockMask`.
        """
        del attention_mask, attn_mask_type

        # Explicit raises (not `assert`) so these checks survive `python -O`;
        # otherwise a missing mask would silently run unmasked attention.
        if packed_seq_params is not None:
            raise AssertionError(
                "PackedSeqParams is not used in ART Megatron flex path."
            )
        if isinstance(attention_bias, SharedPrefixAttentionState):
            block_mask = attention_bias.block_mask
        elif isinstance(attention_bias, BlockMask):
            block_mask = attention_bias
        else:
            raise AssertionError("Expected a flex BlockMask in attention_bias.")

        # Megatron uses [S, B, H, D], while flex attention expects [B, H, S, D].
        q = query.permute(1, 2, 0, 3)
        k = key.permute(1, 2, 0, 3)
        v = value.permute(1, 2, 0, 3)

        # Grouped-query attention is enabled whenever there are fewer KV
        # groups than query heads on this partition.
        uses_gqa = (
            self.num_attention_heads_per_partition
            != self.num_query_groups_per_partition
        )
        out = self.flex_attention(
            q,
            k,
            v,
            block_mask=block_mask,
            scale=self.softmax_scale,
            enable_gqa=uses_gqa,
        )

        # Return to Megatron's expected layout [S, B, Hq*D].
        out = out.permute(2, 0, 1, 3).contiguous()
        out = out.view(out.size(0), out.size(1), self.hidden_size_per_partition)
        return out