
Commit f5d72b5

Enhance Megatron model configuration and training index handling
- Added internal configuration for model registration in yes-no-maybe-megatron.py to optimize GPU memory utilization and tensor parallel size.
- Refactored index calculation in train.py to improve efficiency and handle cases where indices may be empty, ensuring robust data parallelism during training.
1 parent 0ad9a89 commit f5d72b5

2 files changed: 15 additions & 12 deletions

dev/yes-no-maybe-megatron.py (7 additions, 0 deletions)
```diff
@@ -4,6 +4,7 @@
 
 from dotenv import load_dotenv
 import openai
+import torch
 
 import art
 from art.megatron import MegatronBackend
@@ -43,6 +44,12 @@ async def main():
         name=os.environ.get("MODEL_NAME", "megatron-001"),
         project="yes-no-maybe-megatron",
         base_model=base_model,
+        _internal_config=art.dev.InternalModelConfig(
+            engine_args=art.dev.EngineArgs(
+                gpu_memory_utilization=0.8,
+                tensor_parallel_size=torch.cuda.device_count(),
+            ),
+        ),
     )
     await model.register(backend)
```
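
The two new engine args work together: `tensor_parallel_size=torch.cuda.device_count()` shards the inference engine across every visible GPU, while `gpu_memory_utilization=0.8` reserves a fraction of each device for the engine. A minimal sketch of what those values resolve to on a CUDA host, assuming `art.dev.EngineArgs` forwards to a vLLM-style engine (the printed budget is illustrative; the engine manages the actual allocation):

```python
import torch

# Assumes a CUDA host; shows what the new engine args evaluate to.
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()  # value passed as tensor_parallel_size
    total = torch.cuda.get_device_properties(0).total_memory
    print(f"tensor_parallel_size = {n_gpus}")
    # gpu_memory_utilization=0.8 targets ~80% of each device's memory,
    # leaving headroom for the training process on the same GPUs.
    print(f"per-GPU engine budget ≈ {0.8 * total / 2**30:.1f} GiB")
```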

src/art/megatron/train.py (8 additions, 12 deletions)
```diff
@@ -8,6 +8,7 @@
 
 import gc
 import json
+import math
 import shutil
 import time
 from typing import Any, cast
@@ -213,18 +214,13 @@ def calculate_mask(
     num_sequences = job.disk_packed_tensors["num_sequences"]
     dp_rank = ps.get_data_parallel_rank()
     dp_world_size = ps.get_data_parallel_world_size()
-    indices = list(
-        range(
-            dp_rank,
-            num_sequences,
-            dp_world_size,
-        )
-    )
-    # pad indices
-    if num_sequences % dp_world_size <= dp_rank > 0:
-        indices.append(
-            (list(range(num_sequences)) * (dp_world_size // num_sequences + 1))[dp_rank]
-        )
+    num_indices = math.ceil(num_sequences / dp_world_size)
+    indices = list(range(dp_rank, num_sequences, dp_world_size))
+    if not indices:
+        indices = [dp_rank % num_sequences]
+    # pad indices by repeating & slicing to target length
+    repeat = math.ceil(num_indices / len(indices))
+    indices = (indices * repeat)[:num_indices]
     for index in indices:
         inputs = PackedTensors(  # type: ignore
             **{
```
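
The refactor replaces the old hard-to-read padding condition with a uniform repeat-and-slice scheme: every data-parallel rank ends up with exactly `ceil(num_sequences / dp_world_size)` indices, even when there are more ranks than sequences, so collective operations stay in lockstep across the group. A standalone sketch of the same logic (the function name and the asserts are illustrative, not the module's actual API):

```python
import math

def rank_indices(num_sequences: int, dp_rank: int, dp_world_size: int) -> list[int]:
    # Every rank must process the same number of indices so that the
    # data-parallel group stays in sync.
    num_indices = math.ceil(num_sequences / dp_world_size)
    # Round-robin assignment: rank r takes sequences r, r + world_size, ...
    indices = list(range(dp_rank, num_sequences, dp_world_size))
    if not indices:
        # More ranks than sequences: wrap around so the rank still has
        # work (the old code could leave this list empty).
        indices = [dp_rank % num_sequences]
    # Pad by repeating and slicing to the common target length.
    repeat = math.ceil(num_indices / len(indices))
    return (indices * repeat)[:num_indices]

# 5 sequences over 2 ranks: each rank gets ceil(5/2) = 3 indices.
assert rank_indices(5, 0, 2) == [0, 2, 4]
assert rank_indices(5, 1, 2) == [1, 3, 1]
# More ranks than sequences (2 sequences, 4 ranks): rank 3 wraps to index 1.
assert rank_indices(2, 3, 4) == [1]
```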
