
Commit 64e9e6d

feat: Add MegatronBackend (#545)
* feat: implement Megatron backend and training infrastructure
  - Introduced MegatronBackend for managing model services and training processes.
  - Added MegatronService for handling training jobs and OpenAI server interactions.
  - Created yes-no-maybe-megatron.py for orchestrating model training with prompts.
  - Included setup script for environment configuration and dependencies.
  - Implemented training logic in train.py to facilitate distributed training with LoRA support.

* refactor: improve code formatting and organization in MegatronService
  - Reformatted command construction for better readability.
  - Updated optimizer state path assignment for clarity.
  - Rearranged import statements for consistency and organization.

* feat: enhance LoRA initialization and parameter loading in train.py
  - Added a reset_lora_parameters method to initialize LoRA weights with Kaiming and zero initialization.
  - Improved assertion messages for clarity in various sections of the LoRA class.
  - Refactored loading logic to utilize the new reset method for better parameter handling.
  - Enhanced code readability by restructuring assertions and method calls.

* refactor: improve assertion formatting and readability in train.py
  - Restructured assertions in the LoRA class for better clarity and consistency.
  - Enhanced error messages to provide more informative feedback.
  - Improved code readability by consolidating assertion statements.

* feat: add Docker image ID for PyTorch in SkyPilot configuration
  - Included the Docker image ID for PyTorch version 2.9.0 with CUDA 12.8 and cuDNN 9 in skypilot-config.yaml.
  - This addition enhances the configuration for better compatibility with specific model training requirements.

* feat: enhance setup script for package installation and sudo handling
  - Added logic to create a custom sudo command if not available, ensuring script compatibility.
  - Implemented checks for essential packages (git, curl, tmux) and automated their installation if missing.
  - Updated the installation process for 'uv' to use a script from the official source, improving reliability.

* feat: enhance LocalBackend and MegatronService with improved checkpoint handling and LoRA configuration
  - Updated LocalBackend to copy current checkpoints instead of renaming, ensuring data integrity during training steps.
  - Refactored MegatronService to ensure identity LoRA creation and configuration management, enhancing model training reliability.
  - Improved offloading and reloading of model parameters to optimize memory usage during training.
  - Enhanced error handling and logging for better debugging and user feedback.

* feat: add method to manage optimizer state path in MegatronService
  - Introduced _get_optimizer_state_path method to streamline optimizer state path management.
  - Refactored optimizer state path assignment to ensure consistent directory creation and handling.
  - Improved code clarity and organization within the MegatronService class.

* feat: add megatron dependency and improve code formatting
  - Added "megatron.**" to allowed unresolved imports in pyproject.toml for better dependency management.
  - Refactored code in LocalBackend and MegatronService for improved readability and consistency, including assertion formatting and path handling.
  - Enhanced clarity in the handling of inputs and outputs in training logic.

* refactor: enhance LoRA configuration handling in MegatronService
  - Updated _default_lora_adapter_config method to return a LoraConfig instance for improved type safety and clarity.
  - Refactored _create_identity_lora method to utilize the updated configuration structure.
  - Improved JSON serialization of LoRA configuration by using asdict for better compatibility.
  - Cleaned up import statements for consistency and removed unnecessary imports.

* feat: implement LoRA and offloading functionality in Megatron
  - Added LoRA class for low-rank adaptation, including methods for parameter initialization, loading, and forward pass.
  - Introduced OffloadState class and functions to offload model parameters and optimizer state to CPU, enhancing memory management.
  - Implemented reload functionality to transfer parameters back to GPU, improving training efficiency.
  - Integrated new provider setup for model initialization, streamlining the process of obtaining the GPT model provider.

* feat: add type assertions for linear layers in LoRA classes
  - Introduced type assertions to ensure linear projection layers are of the correct type, enhancing type safety.
  - Added checks for tensor types in various LoRA classes to prevent runtime errors and improve debugging.
  - Updated apply_lora_adapters function to include type checks for expert linear layers, ensuring compatibility with the expected types.

* fix: update import statements and add type assertions in Megatron modules
  - Removed unnecessary imports and added missing type imports for better clarity and type safety.
  - Introduced an assertion to ensure compatibility with Qwen3 MoE models in the provider setup.
  - Enhanced type checking for linear layers in LoRA classes to prevent runtime errors.
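The "Kaiming and zero initialization" mentioned for reset_lora_parameters is the standard LoRA scheme: the A matrix gets a Kaiming-uniform draw while B starts at zero, so the adapter's update B @ A is exactly zero and the freshly initialized adapter is a no-op. A minimal pure-Python sketch for illustration (the commit's real implementation operates on Megatron tensors in train.py; these helper names are hypothetical):

```python
import math
import random

def kaiming_uniform(fan_in: int, rows: int, cols: int, a: float = math.sqrt(5)):
    # Kaiming-uniform bound: gain * sqrt(3 / fan_in), with gain = sqrt(2 / (1 + a^2))
    gain = math.sqrt(2.0 / (1.0 + a * a))
    bound = gain * math.sqrt(3.0 / fan_in)
    return [[random.uniform(-bound, bound) for _ in range(cols)] for _ in range(rows)]

def reset_lora_parameters(r: int, d_in: int, d_out: int):
    lora_a = kaiming_uniform(d_in, r, d_in)      # A: Kaiming-uniform init
    lora_b = [[0.0] * r for _ in range(d_out)]   # B: zeros
    return lora_a, lora_b

lora_a, lora_b = reset_lora_parameters(r=8, d_in=64, d_out=64)

# Because B is all zeros, the low-rank update B @ A is zero, so the
# adapter leaves the base model's output unchanged at step 0.
delta = [
    [sum(lora_b[i][k] * lora_a[k][j] for k in range(8)) for j in range(64)]
    for i in range(64)
]
assert all(v == 0.0 for row in delta for v in row)
```

Starting from a zero update is what makes "identity LoRA" creation (mentioned later in this commit) safe: inference through the adapter at step 0 reproduces the base model exactly.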
1 parent 0dca3a8 commit 64e9e6d

14 files changed

Lines changed: 1501 additions & 13 deletions


dev/yes-no-maybe-megatron.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
import asyncio
from itertools import permutations
import os

from dotenv import load_dotenv
import openai

import art
from art.megatron import MegatronBackend


async def rollout(
    client: openai.AsyncOpenAI, model_name: str, prompt: str
) -> art.Trajectory:
    messages: art.Messages = [{"role": "user", "content": prompt}]
    chat_completion = await client.chat.completions.create(
        messages=messages, model=model_name, max_tokens=100, timeout=100
    )
    choice = chat_completion.choices[0]
    content = choice.message.content
    assert isinstance(content, str)
    if content == "yes":
        reward = 0.5
    elif content == "no":
        reward = 0.75
    elif content == "maybe":
        reward = 1.0
    else:
        reward = 0.0
    return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)


def with_quotes(w: str) -> str:
    return f"'{w}'"


async def main():
    load_dotenv()

    backend = MegatronBackend()
    base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507")
    model = art.TrainableModel(
        name=os.environ.get("MODEL_NAME", "megatron-001"),
        project="yes-no-maybe-megatron",
        base_model=base_model,
    )
    await model.register(backend)

    prompts = [
        f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}"
        for prefix in ["respond", "just respond"]
        for use_quotes in [True, False]
        for words in (
            list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n)
        )
    ]

    openai_client = model.openai_client()
    max_steps = int(os.environ.get("NUM_STEPS", "20"))
    start_step = await model.get_step()

    for step in range(start_step, start_step + max_steps):
        print(f"\n=== Step {step + 1} ===")
        train_groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(
                    rollout(openai_client, model.name, prompt) for _ in range(32)
                )
                for prompt in prompts
            )
        )
        await model.train(
            train_groups,
            config=art.TrainConfig(learning_rate=1e-4),
        )


if __name__ == "__main__":
    asyncio.run(main())
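The nested comprehension that builds `prompts` is dense. Under the same logic it unrolls to ordinary loops, a sketch for illustration (not part of the commit):

```python
from itertools import permutations

def with_quotes(w: str) -> str:
    return f"'{w}'"

prompts = []
for prefix in ["respond", "just respond"]:
    for use_quotes in [True, False]:
        # All orderings of 3 words, then all ordered pairs of 2 words.
        for n in [3, 2]:
            for p in permutations(["yes", "no", "maybe"], n):
                words = list(p)
                if len(words) == 3:
                    joined = ", ".join(
                        with_quotes(w) if use_quotes else w for w in words
                    )
                else:
                    joined = f"{words[0]} or {words[1]}"
                prompts.append(f"{prefix} with {joined}")

# 2 prefixes x 2 quote settings x (6 three-word + 6 two-word orderings)
assert len(prompts) == 48
```

Note that the two-word branch ignores `use_quotes`, so each two-word prompt appears twice; the unrolled version preserves that quirk of the original comprehension.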

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -152,6 +152,8 @@ allowed-unresolved-imports = [
     # plotting deps
     "matplotlib.**",
     "seaborn.**",
+    # megatron deps
+    "megatron.**",
 ]

 [dependency-groups]

scripts/setup.sh

Lines changed: 32 additions & 4 deletions
@@ -13,6 +13,31 @@ if [ -f .env ]; then
     done < .env
 fi

+if ! command -v sudo >/dev/null 2>&1; then
+    sudo_path="/usr/local/bin/sudo"
+    if [ ! -w /usr/local/bin ]; then
+        sudo_path="$HOME/.local/bin/sudo"
+        mkdir -p "$HOME/.local/bin"
+        export PATH="$HOME/.local/bin:$PATH"
+    fi
+
+    cat <<'EOF' > "$sudo_path"
+#!/bin/sh
+exec "$@"
+EOF
+    chmod +x "$sudo_path"
+fi
+
+need_pkgs=()
+command -v git >/dev/null 2>&1 || need_pkgs+=("git")
+command -v curl >/dev/null 2>&1 || need_pkgs+=("curl")
+command -v tmux >/dev/null 2>&1 || need_pkgs+=("tmux")
+
+if [ "${#need_pkgs[@]}" -gt 0 ]; then
+    apt-get update
+    apt-get install -y "${need_pkgs[@]}"
+fi
+
 # Configure git user name and email
 git config --global user.name "${GIT_USER_NAME}"
 git config --global user.email "${GIT_USER_EMAIL}"

@@ -29,14 +54,17 @@ else
 fi

 # Install astral-uv
-sudo snap install --classic astral-uv
+if ! command -v uv >/dev/null 2>&1; then
+    if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
+        echo "Failed to install uv." >&2
+        exit 1
+    fi
+    export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
+fi

 # Update uv
 uv self update

-# Install tmux
-apt install tmux -y
-
 # Sync the dependencies
 if [ "${INSTALL_EXTRAS:-false}" = "true" ]; then
     uv sync --all-extras
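The package check added above follows a common "install only what is missing" shell pattern: probe each tool with `command -v` and accumulate the absentees before touching the package manager. A minimal sketch with placeholder tool names (`qqq-not-installed` is deliberately nonexistent):

```shell
#!/bin/sh
# Collect tools that are not on PATH; only those would be installed.
missing=""
for tool in ls qqq-not-installed; do
    command -v "$tool" >/dev/null 2>&1 || missing="$missing $tool"
done
echo "missing:$missing"
```

Batching the missing tools into one `apt-get install` call, as the script does, avoids repeated `apt-get update` round-trips and keeps the no-op case free of any package-manager invocation.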

skypilot-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -383,6 +383,7 @@
 workdir: .
 resources:
   accelerators: ["H100-SXM:1", "H100:1", "A100-80GB:1"]
+  image_id: docker:pytorch/pytorch:2.9.0-cuda12.8-cudnn9-devel
   ports:
     - 7999 # main ART server
     - 8000 # vLLM server

src/art/local/backend.py

Lines changed: 10 additions & 7 deletions
@@ -2,6 +2,7 @@
 import json
 import math
 import os
+import shutil
 import subprocess
 from types import TracebackType
 from typing import AsyncIterator, Iterable, Literal, cast

@@ -570,20 +571,22 @@ async def _train_model(
            get_model_dir(model=model, art_path=self._path), next_step
        )

-        # If the current checkpoint exists, rename it to the next step
+        # If the current checkpoint exists, copy it to the next step
        if os.path.exists(current_checkpoint_dir):
-            os.rename(current_checkpoint_dir, next_checkpoint_dir)
+            shutil.copytree(
+                current_checkpoint_dir,
+                next_checkpoint_dir,
+                dirs_exist_ok=True,
+            )
        print(
            f"Advanced step from {current_step} to {next_step} (no training occurred)"
        )

        try:
-            # Register the renamed checkpoint as a new LoRA adapter
+            # Register the copied checkpoint as a new LoRA adapter
            # so it's available for inference at the new step
-            from ..unsloth.service import UnslothService
-
-            if isinstance(service, UnslothService):
-                await service.register_lora_for_step(
+            if hasattr(service, "register_lora_for_step"):
+                await service.register_lora_for_step(  # type: ignore[attr-defined]
                    next_step, next_checkpoint_dir
                )
        except ModuleNotFoundError:
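The switch from os.rename to shutil.copytree means the current step's checkpoint stays on disk after the step advances, so anything still reading the old directory (for example an inference server serving the previous adapter) is unaffected. A minimal sketch of that difference using throwaway temp paths:

```python
import os
import shutil
import tempfile

root = tempfile.mkdtemp()
current = os.path.join(root, "0001")
nxt = os.path.join(root, "0002")
os.makedirs(current)
with open(os.path.join(current, "adapter.bin"), "w") as f:
    f.write("weights")

# copytree leaves the source directory intact;
# os.rename would have removed it.
shutil.copytree(current, nxt, dirs_exist_ok=True)

assert os.path.isdir(current)                             # old step still readable
assert os.path.exists(os.path.join(nxt, "adapter.bin"))   # new step has a copy
shutil.rmtree(root)
```

`dirs_exist_ok=True` makes the copy idempotent if the destination step directory was already created, at the cost of silently merging into it.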

src/art/megatron/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .backend import MegatronBackend

__all__ = ["MegatronBackend"]

src/art/megatron/backend.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
from mp_actors import move_to_child_process

from ..local.backend import LocalBackend
from ..local.service import ModelService
from ..model import TrainableModel
from ..utils.output_dirs import get_model_dir


class MegatronBackend(LocalBackend):
    def __init__(
        self,
        *,
        in_process: bool = False,
        path: str | None = None,
    ) -> None:
        super().__init__(in_process=in_process, path=path)

    async def _get_service(self, model: TrainableModel) -> ModelService:
        from ..dev.get_model_config import get_model_config
        from .service import MegatronService

        if model.name not in self._services:
            config = get_model_config(
                base_model=model.base_model,
                output_dir=get_model_dir(model=model, art_path=self._path),
                config=model._internal_config,
            )
            self._services[model.name] = MegatronService(
                model_name=model.name,
                base_model=model.base_model,
                config=config,
                output_dir=get_model_dir(model=model, art_path=self._path),
            )
            if not self._in_process:
                self._services[model.name] = move_to_child_process(
                    self._services[model.name],
                    process_name="megatron-service",
                )
        return self._services[model.name]
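_get_service follows a lazy per-model cache: the first request for a model name constructs (and optionally moves to a child process) a service, and every later request returns the same instance. A minimal sketch of that pattern with hypothetical stand-ins for the real ART classes:

```python
class Service:
    """Hypothetical stand-in for ART's ModelService."""

    def __init__(self, name: str) -> None:
        self.name = name


class Backend:
    def __init__(self) -> None:
        self._services: dict[str, Service] = {}

    def get_service(self, model_name: str) -> Service:
        # Build the service lazily on first request, then reuse the
        # cached instance (mirrors _get_service's dict lookup above).
        if model_name not in self._services:
            self._services[model_name] = Service(model_name)
        return self._services[model_name]


backend = Backend()
s1 = backend.get_service("megatron-001")
s2 = backend.get_service("megatron-001")
assert s1 is s2  # constructed once, cached thereafter
```

Keeping construction inside the cache-miss branch is what makes it safe to wrap the service with move_to_child_process only once; a second lookup must not spawn a second child process.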
