
Commit 0b1ee70

Flex Attention for Local Megatron Backend (#586)
* megatron: replace fused core attention with compiled flex attention
* Cleanup after codex and minimize a bit.
* Fix bug where Megatron training holds old LoRAs and set Triton/TorchInductor caches explicitly.
1 parent 6559367 commit 0b1ee70

3 files changed

Lines changed: 269 additions & 38 deletions


src/art/megatron/flex_attention.py

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
+"""Flex attention plumbing for ART's Megatron backend."""
+
+import math
+from typing import Any, ClassVar, cast
+
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.transformer.enums import AttnMaskType
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.utils import divide
+from pydantic import BaseModel, ConfigDict
+import torch
+from torch import Tensor
+from torch.nn.attention.flex_attention import (
+    BlockMask,
+    create_block_mask,
+    flex_attention,
+)
+
+
+class SharedPrefixAttentionState(BaseModel):
+    """Shared-prefix sparsity metadata for one packed ART training sample."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    block_mask: BlockMask
+
+
+class FlexAttentionWrapper(torch.nn.Module):
+    """Compiled `flex_attention` wrapper with Torchtitan-style inductor options."""
+
+    # Torchtitan inductor options for compiling flex attention.
+    _compile_options = {
+        "max_autotune": True,
+        "coordinate_descent_tuning": True,
+        "triton.cudagraphs": False,
+    }
+    _compiled_flex_attention: ClassVar = torch.compile(
+        flex_attention,
+        options=_compile_options,
+    )
+
+    def forward(
+        self,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        *,
+        block_mask: BlockMask,
+        scale: float,
+        enable_gqa: bool,
+    ) -> Tensor:
+        # q, k, v are the [B, H, S, D] tensors expected by flex_attention.
+        return cast(
+            Tensor,
+            FlexAttentionWrapper._compiled_flex_attention(
+                q,
+                k,
+                v,
+                block_mask=block_mask,
+                scale=scale,
+                enable_gqa=enable_gqa,
+            ),
+        )
+
+
+_compiled_create_block_mask = torch.compile(create_block_mask)
+
+
+def create_shared_prefix_attention_state(
+    group_ids: Tensor,
+    parent_ids: Tensor,
+) -> SharedPrefixAttentionState:
+    """Build a compiled block mask for ART shared-prefix packing.
+
+    The mask is built on the device of the group_ids tensor.
+
+    Args:
+        group_ids: `[B, S]` group id for each token in a packed sequence.
+        parent_ids: `[B, S]` parent group id for each token in a packed sequence.
+    """
+
+    def _shared_prefix_mask(
+        batch_idx: Tensor,
+        head_idx: Tensor,
+        query_idx: Tensor,
+        kv_idx: Tensor,
+    ) -> Tensor:
+        del head_idx
+        # Token q can attend to token k if k is causal and either belongs to the
+        # same trajectory (traj -> traj) or the same shared prefix (prefix -> prefix)
+        # (same_group), or to the prefix that q builds on (traj -> prefix) (parent_prefix).
+        same_group = group_ids[batch_idx, query_idx] == group_ids[batch_idx, kv_idx]
+        parent_prefix = parent_ids[batch_idx, query_idx] == group_ids[batch_idx, kv_idx]
+        return (query_idx >= kv_idx) & (same_group | parent_prefix)
+
+    block_mask = _compiled_create_block_mask(
+        _shared_prefix_mask,
+        group_ids.shape[0],
+        None,
+        group_ids.shape[1],
+        group_ids.shape[1],
+        device=group_ids.device,
+    )
+    return SharedPrefixAttentionState(block_mask=block_mask)
+
+
+class FlexDotProductAttention(torch.nn.Module):
+    """Megatron core-attention module backed by compiled torch flex attention.
+
+    The current implementation lacks support for fp8 and context parallelism
+    (both of which are available in TEDotProductAttention).
+    """
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        layer_number: int,
+        attn_mask_type: AttnMaskType,
+        attention_type: str,
+        attention_dropout: float | None = None,
+        softmax_scale: float | None = None,
+        cp_comm_type: str | None = None,
+        pg_collection: ProcessGroupCollection | None = None,
+    ):
+        super().__init__()
+        del (
+            layer_number,
+            attn_mask_type,
+            attention_type,
+            attention_dropout,
+            cp_comm_type,
+        )
+        self.config = config
+        self.flex_attention = FlexAttentionWrapper()
+
+        if pg_collection is None:
+            tp_world_size = self.config.tensor_model_parallel_size
+        else:
+            tp_world_size = pg_collection.tp.size()
+
+        kv_channels = self.config.kv_channels
+        assert kv_channels is not None, "Megatron config must provide kv_channels."
+        projection_size = kv_channels * self.config.num_attention_heads
+        self.hidden_size_per_partition = divide(projection_size, tp_world_size)
+        num_query_groups = (
+            self.config.num_query_groups or self.config.num_attention_heads
+        )
+        self.num_attention_heads_per_partition = divide(
+            self.config.num_attention_heads, tp_world_size
+        )
+        self.num_query_groups_per_partition = divide(num_query_groups, tp_world_size)
+
+        if softmax_scale is None:
+            head_dim = divide(projection_size, self.config.num_attention_heads)
+            self.softmax_scale = 1.0 / math.sqrt(head_dim)
+        else:
+            self.softmax_scale = softmax_scale
+
+    def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        attention_mask: Tensor,
+        attn_mask_type: AttnMaskType | None = None,
+        attention_bias: Any = None,
+        packed_seq_params: PackedSeqParams | None = None,
+    ) -> Tensor:
+        """Compute self attention with compiled flex kernels.
+
+        Args:
+            query: `[S, B, Hq, D]`
+            key: `[S, B, Hkv, D]`
+            value: `[S, B, Hkv, D]`
+            attention_mask: unused placeholder tensor kept for the Megatron checkpoint API.
+            attention_bias: `SharedPrefixAttentionState` or `BlockMask`.
+        """
+
+        del attention_mask, attn_mask_type
+        assert packed_seq_params is None, (
+            "PackedSeqParams is not used in ART Megatron flex path."
+        )
+
+        if isinstance(attention_bias, SharedPrefixAttentionState):
+            block_mask = attention_bias.block_mask
+        else:
+            assert isinstance(attention_bias, BlockMask), (
+                "Expected a flex BlockMask in attention_bias."
+            )
+            block_mask = attention_bias
+
+        # Megatron uses [S, B, H, D], while flex attention expects [B, H, S, D].
+        q = query.permute(1, 2, 0, 3)
+        k = key.permute(1, 2, 0, 3)
+        v = value.permute(1, 2, 0, 3)
+
+        out = self.flex_attention(
+            q,
+            k,
+            v,
+            block_mask=block_mask,
+            scale=self.softmax_scale,
+            enable_gqa=self.num_attention_heads_per_partition
+            != self.num_query_groups_per_partition,
+        )
+
+        # Return to Megatron's expected layout [S, B, Hq*D].
+        out = out.permute(2, 0, 1, 3).contiguous()
+        out = out.view(out.size(0), out.size(1), self.hidden_size_per_partition)
+        return out
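For intuition, here is a minimal CPU-only sketch of the rule that `_shared_prefix_mask` encodes, written with dense boolean tensors instead of a flex `BlockMask` (it mirrors the dense `calculate_mask` helper that this commit deletes from `train.py`). The toy packing below (one shared prefix in group 0 followed by two trajectories in groups 1 and 2) is invented purely for illustration.

import torch

# Toy packed sequence: 2 prefix tokens (group 0), then two 2-token
# trajectories (groups 1 and 2) that both continue the shared prefix.
group_ids = torch.tensor([[0, 0, 1, 1, 2, 2]])
parent_ids = torch.tensor([[0, 0, 0, 0, 0, 0]])

# Dense form of the shared-prefix rule: q may attend to k iff k is not in the
# future and k is in q's own group or in the prefix group that q descends from.
causal = torch.tril(torch.ones(6, 6, dtype=torch.bool)).unsqueeze(0)
same_group = group_ids.unsqueeze(2) == group_ids.unsqueeze(1)
parent_prefix = parent_ids.unsqueeze(2) == group_ids.unsqueeze(1)
mask = causal & (same_group | parent_prefix)

# Rows 2-5 (trajectory tokens) see the prefix (columns 0-1) and their own
# trajectory, but the two trajectories never attend to each other.
print(mask[0].int())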

src/art/megatron/provider.py

Lines changed: 37 additions & 1 deletion
@@ -1,9 +1,32 @@
+import copy
+from functools import partial
+import inspect
+from typing import Callable
 from megatron.bridge import AutoBridge
 from megatron.bridge.models.gpt_provider import GPTModelProvider
 from megatron.bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge
 from megatron.core.transformer.enums import AttnBackend
+from megatron.core.transformer.spec_utils import ModuleSpec
 import torch
 
+from art.megatron.flex_attention import FlexDotProductAttention
+
+
+def _resolve_layer_spec(
+    base_layer_spec: ModuleSpec | Callable[[GPTModelProvider], ModuleSpec],
+    config: GPTModelProvider,
+    vp_stage: int | None = None,
+) -> ModuleSpec:
+    if isinstance(base_layer_spec, ModuleSpec):
+        return copy.deepcopy(base_layer_spec)
+    kwargs = (
+        {"vp_stage": vp_stage}
+        if "vp_stage" in inspect.signature(base_layer_spec).parameters
+        else {}
+    )
+    return base_layer_spec(config, **kwargs)
+
 
 def get_provider(model: str) -> GPTModelProvider:
     bridge = AutoBridge.from_hf_pretrained(
@@ -15,7 +38,20 @@ def get_provider(model: str) -> GPTModelProvider:
         "Only Qwen3 MoE models are supported"
     )
     provider = bridge.to_megatron_provider()
-    provider.attention_backend = AttnBackend.fused
+    base_layer_spec = provider.transformer_layer_spec
+
+    def _flex_attention_layer_spec(
+        config: GPTModelProvider, vp_stage: int | None = None
+    ) -> ModuleSpec:
+        layer_spec = _resolve_layer_spec(base_layer_spec, config, vp_stage)
+        # Keep Megatron's standard layer stack and replace only core attention.
+        layer_spec.submodules.self_attention.submodules.core_attention = (  # ty: ignore[unresolved-attribute]
+            FlexDotProductAttention
+        )
+        return layer_spec
+
+    provider.transformer_layer_spec = _flex_attention_layer_spec
+    provider.attention_backend = AttnBackend.auto
     provider.recompute_granularity = "full"
     provider.recompute_method = "uniform"
     provider.recompute_num_layers = 1
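A note on the indirection above: `provider.transformer_layer_spec` can hold either a ready-made `ModuleSpec` or a factory callable that may or may not accept a `vp_stage` (virtual pipeline stage) argument, which is why `_resolve_layer_spec` branches on the type and inspects the factory's signature. Below is a small self-contained sketch of that resolve-then-patch pattern, using a hypothetical `FakeSpec` stand-in rather than Megatron's real `ModuleSpec`.

import copy
import inspect
from dataclasses import dataclass
from typing import Callable


@dataclass
class FakeSpec:
    """Hypothetical stand-in for megatron.core.transformer.spec_utils.ModuleSpec."""
    core_attention: type


def resolve(spec: FakeSpec | Callable[..., FakeSpec], config: object,
            vp_stage: int | None = None) -> FakeSpec:
    if isinstance(spec, FakeSpec):
        # Deep-copy so patching the result never mutates the provider's original spec.
        return copy.deepcopy(spec)
    # Only forward vp_stage when the factory actually declares that parameter.
    kwargs = (
        {"vp_stage": vp_stage}
        if "vp_stage" in inspect.signature(spec).parameters
        else {}
    )
    return spec(config, **kwargs)


def legacy_factory(config: object) -> FakeSpec:  # factory without a vp_stage parameter
    return FakeSpec(core_attention=object)


spec = resolve(legacy_factory, config=None, vp_stage=3)
spec.core_attention = str  # patch a single submodule, as get_provider does with FlexDotProductAttention
print(spec.core_attention)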

src/art/megatron/train.py

Lines changed: 23 additions & 37 deletions
@@ -1,9 +1,18 @@
 # isort: off
 import os
 
+
+def _set_cache_dir(env_var: str, default_path: str) -> None:
+    if not os.environ.get(env_var):
+        os.environ[env_var] = os.path.expanduser(default_path)
+    os.makedirs(os.environ[env_var], exist_ok=True)
+
+
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
+_set_cache_dir("TORCHINDUCTOR_CACHE_DIR", "~/.cache/torchinductor")
+_set_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache")
 # isort: on
 
 import gc
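The `_set_cache_dir` helper fills in a default only when the variable is unset, so an operator-provided cache location wins. A quick illustrative sketch of that behavior follows; the `MY_CACHE_DIR` name and temporary paths are hypothetical, and the helper body is repeated under the assumption it matches the one above.

import os
import tempfile


def _set_cache_dir(env_var: str, default_path: str) -> None:
    # Keep an existing value, otherwise apply the default; ensure the directory exists.
    if not os.environ.get(env_var):
        os.environ[env_var] = os.path.expanduser(default_path)
    os.makedirs(os.environ[env_var], exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    os.environ.pop("MY_CACHE_DIR", None)
    _set_cache_dir("MY_CACHE_DIR", os.path.join(tmp, "default"))
    print(os.environ["MY_CACHE_DIR"])  # <tmp>/default, now created on disk

    os.environ["MY_CACHE_DIR"] = os.path.join(tmp, "custom")
    _set_cache_dir("MY_CACHE_DIR", os.path.join(tmp, "default"))
    print(os.environ["MY_CACHE_DIR"])  # still <tmp>/custom; the default is not applied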
@@ -21,9 +30,11 @@
 from pydantic import BaseModel
 from safetensors.torch import load_file, save_file
 import torch
+from torch._inductor.runtime.cache_dir_utils import cache_dir as inductor_cache_dir
 
 from art import dev, types
 from art.loss import loss_fn, shift_tensor
+from art.megatron.flex_attention import create_shared_prefix_attention_state
 from art.megatron.lora import apply_lora_adapters
 from art.megatron.offload import OffloadState, offload_to_cpu, reload_to_gpu
 from art.megatron.provider import get_provider
@@ -55,6 +66,11 @@ def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]:
 rank = torch.distributed.get_rank()
 world_size = torch.distributed.get_world_size()
 
+if rank == 0:
+    print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"])
+    print("Resolved inductor cache_dir():", inductor_cache_dir())
+    print("TRITON_CACHE_DIR:", os.environ["TRITON_CACHE_DIR"])
+
 for module in model:
     while not isinstance(module, GPTModel) and hasattr(module, "module"):
         module = module.module
@@ -122,31 +138,6 @@ def print0(*values: Any) -> None:
 offload_state = OffloadState()
 
 
-def calculate_mask(
-    batch_size: int,
-    seq_len: int,
-    device: torch.device,
-    group_ids: torch.Tensor,
-    parent_ids: torch.Tensor,
-) -> torch.Tensor:
-    causal_mask = (
-        torch.tril(
-            torch.ones(
-                seq_len,
-                seq_len,
-                dtype=torch.bool,
-                device=device,
-            )
-        )
-        .unsqueeze(0)
-        .expand(batch_size, seq_len, seq_len)
-    )
-    group_mask = group_ids.unsqueeze(2) == group_ids.unsqueeze(1)
-    parent_mask = parent_ids.unsqueeze(2) == group_ids.unsqueeze(1)
-    mask = causal_mask & (group_mask | parent_mask)
-    return mask
-
-
 offload_to_cpu(model, optimizer, rank, offload_state)
 
 while True:
@@ -236,26 +227,19 @@ def calculate_mask(
         for key, value in inputs.items():
             if isinstance(value, torch.Tensor):
                 inputs[key] = value.to(device)  # type: ignore
-        attention_mask = ~calculate_mask(
-            batch_size=inputs["tokens"].shape[0],
-            seq_len=inputs["tokens"].shape[1],
-            device=device,
+        attention_state = create_shared_prefix_attention_state(  # should happen after group_ids is moved to device
            group_ids=inputs["group_ids"],
            parent_ids=inputs["parent_ids"],
-        ).unsqueeze(1)  # add head dimension [B, H=1, S, S]
-        attention_bias = torch.where(
-            attention_mask,
-            torch.tensor(
-                float("-inf"), dtype=next(model[0].parameters()).dtype, device=device
-            ),
-            torch.tensor(0.0, dtype=next(model[0].parameters()).dtype, device=device),
        )
+        # Megatron full-layer recompute saves positional tensor args, so keep a tiny
+        # placeholder Tensor here and pass flex BlockMask state via attention_bias.
+        attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device)
         new_logprobs: torch.Tensor = -model[0](
             input_ids=inputs["tokens"],
             position_ids=inputs["input_pos"],
             attention_mask=attention_mask,
             labels=shift_tensor(inputs["tokens"], 0),
-            extra_block_kwargs={"attention_bias": attention_bias},
+            extra_block_kwargs={"attention_bias": attention_state},
         )
         loss = loss_fn(
             inputs,  # type: ignore
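A rough sense of why the dense `attention_bias` path was worth removing: the old code materialized a `[B, 1, S, S]` additive bias in the model dtype, which grows quadratically with the packed sequence length, while the flex `BlockMask` keeps only block-level sparsity metadata. Back-of-envelope numbers follow, assuming a hypothetical packed length of 32k tokens in bf16.

seq_len, bytes_per_elem = 32_768, 2  # hypothetical packed length, bf16

dense_bias_bytes = 1 * 1 * seq_len * seq_len * bytes_per_elem  # [B=1, H=1, S, S]
print(dense_bias_bytes / 2**30, "GiB")  # 2.0 GiB for the bias tensor alone

# A BlockMask over the same sequence tracks sparsity at (by default) 128x128
# block granularity, i.e. on the order of (32768 / 128) ** 2 = 65536 block entries.
print((seq_len // 128) ** 2, "block entries")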
@@ -332,9 +316,11 @@ def calculate_mask(
     offload_to_cpu(model, optimizer, rank, offload_state)
     # Release mmap-backed packed tensor references on all ranks before rank0 cleanup.
     del packed_tensors
+    del adapter_model
     if "inputs" in locals():
         del inputs
     gc.collect()
+    torch.cuda.empty_cache()
     # Ensure all ranks have finished saving before signaling completion
     torch.distributed.barrier()
     if rank == 0:
