Commit 3990b66

feat: Add dedicated mode for UnslothService (#577)
1 parent 25992ca · commit 3990b66

9 files changed: 770 additions & 12 deletions

src/art/dev/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -10,6 +10,7 @@
 )
 from .openai_server import OpenAIServerConfig, ServerArgs, get_openai_server_config
 from .train import TrainConfig, TrainSFTConfig
+from .validate import is_dedicated_mode, validate_dedicated_config

 __all__ = [
     "EngineArgs",
@@ -21,8 +22,10 @@
     "TinkerTrainingClientArgs",
     "TrainerArgs",
     "get_openai_server_config",
+    "is_dedicated_mode",
     "OpenAIServerConfig",
     "ServerArgs",
     "TrainSFTConfig",
     "TrainConfig",
+    "validate_dedicated_config",
 ]
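
For downstream code, both helpers are now importable straight from art.dev. A minimal usage sketch (the config dict is illustrative; at runtime a plain dict satisfies the InternalModelConfig TypedDict):

from art.dev import is_dedicated_mode, validate_dedicated_config

config = {"trainer_gpu_ids": [0], "inference_gpu_ids": [1]}  # hypothetical GPU split
validate_dedicated_config(config)  # raises ValueError if the split is invalid
assert is_dedicated_mode(config)   # both keys present -> dedicated mode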

src/art/dev/get_model_config.py

Lines changed: 18 additions & 3 deletions

@@ -1,5 +1,6 @@
 from .engine import EngineArgs
 from .model import InitArgs, InternalModelConfig, PeftArgs, TrainerArgs
+from .validate import is_dedicated_mode


 def get_model_config(
@@ -12,13 +13,22 @@ def get_model_config(
     if config is None:
         config = InternalModelConfig()

-    enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True)
+    dedicated = is_dedicated_mode(config)
+
+    if dedicated:
+        enable_sleep_mode = False
+    else:
+        enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True)
+
     init_args = InitArgs(
-        fast_inference=False,
         load_in_4bit=True,
         max_seq_length=32768,
         model_name=base_model,
     )
+    # fast_inference triggers in-process vLLM via Unsloth; dedicated mode runs vLLM as a subprocess
+    if not dedicated:
+        init_args["fast_inference"] = False
+

     engine_args = EngineArgs(
         allowed_local_media_path="/tmp",
         enable_sleep_mode=enable_sleep_mode,
@@ -63,10 +73,15 @@ def get_model_config(
         weight_decay=0.1,
     )
     trainer_args.update(config.get("trainer_args", {}))
-    return InternalModelConfig(
+    result = InternalModelConfig(
         init_args=init_args,
         engine_args=engine_args,
         peft_args=peft_args,
         tinker_args=config.get("tinker_args"),
         trainer_args=trainer_args,
     )
+    if "trainer_gpu_ids" in config:
+        result["trainer_gpu_ids"] = config["trainer_gpu_ids"]
+    if "inference_gpu_ids" in config:
+        result["inference_gpu_ids"] = config["inference_gpu_ids"]
+    return result
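
To illustrate the new branch, a hedged sketch of what the dedicated path produces (the base model name and output_dir are placeholders; the assertions follow from the code above):

from art.dev.get_model_config import get_model_config

result = get_model_config(
    base_model="Qwen/Qwen2.5-7B-Instruct",  # hypothetical base model
    output_dir="/tmp/models/example",       # hypothetical output dir
    config={"trainer_gpu_ids": [0], "inference_gpu_ids": [1]},
)
assert result["engine_args"]["enable_sleep_mode"] is False  # forced off when dedicated
assert "fast_inference" not in result["init_args"]          # only set in shared mode
assert result["inference_gpu_ids"] == [1]                   # GPU IDs pass through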

src/art/dev/model.py

Lines changed: 7 additions & 0 deletions

@@ -115,6 +115,11 @@ class InternalModelConfig(TypedDict, total=False):
         peft: Arguments for creating an Unsloth PEFT model wrapper.
         tinker: Arguments for the Tinker training client.
         trainer: Arguments for the GRPO trainer.
+        trainer_gpu_ids: GPU IDs for training (e.g., [0]). When set with
+            inference_gpu_ids, enables dedicated mode where training and
+            inference run on separate GPUs.
+        inference_gpu_ids: GPU IDs for vLLM inference (e.g., [1]). When set
+            with trainer_gpu_ids, enables dedicated mode.
     """

     init_args: "InitArgs"
@@ -123,6 +128,8 @@ class InternalModelConfig(TypedDict, total=False):
     tinker_args: "TinkerArgs | None"
     tinker_native_args: "TinkerNativeArgs | None"
     trainer_args: "TrainerArgs"
+    trainer_gpu_ids: list[int]
+    inference_gpu_ids: list[int]


 class TinkerArgs(TypedDict, total=False):
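
Because InternalModelConfig is declared with total=False, the two new keys are optional, and it is their presence (not their truthiness) that downstream checks look for. A small illustration with hypothetical IDs:

from art.dev.model import InternalModelConfig

shared: InternalModelConfig = {}  # neither key set -> shared mode
dedicated: InternalModelConfig = {
    "trainer_gpu_ids": [0],    # training runs in-process on GPU 0
    "inference_gpu_ids": [1],  # the vLLM subprocess gets GPU 1
}
print("trainer_gpu_ids" in dedicated)  # True -- presence is what is checked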

src/art/dev/validate.py

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+"""Validation functions for model configuration."""
+
+from .model import InternalModelConfig
+
+
+def is_dedicated_mode(config: InternalModelConfig) -> bool:
+    """Return True if the config specifies dedicated mode (separate training and inference GPUs)."""
+    return "trainer_gpu_ids" in config and "inference_gpu_ids" in config
+
+
+def validate_dedicated_config(config: InternalModelConfig) -> None:
+    """Validate dedicated mode GPU configuration.
+
+    Raises ValueError if the configuration is invalid.
+    Does nothing if neither trainer_gpu_ids nor inference_gpu_ids is set (shared mode).
+    """
+    has_trainer = "trainer_gpu_ids" in config
+    has_inference = "inference_gpu_ids" in config
+
+    if has_trainer != has_inference:
+        raise ValueError(
+            "trainer_gpu_ids and inference_gpu_ids must both be set or both unset"
+        )
+
+    if not has_trainer:
+        return
+
+    trainer_gpu_ids = config["trainer_gpu_ids"]
+    inference_gpu_ids = config["inference_gpu_ids"]
+
+    if not trainer_gpu_ids:
+        raise ValueError("trainer_gpu_ids must be non-empty")
+
+    if not inference_gpu_ids:
+        raise ValueError("inference_gpu_ids must be non-empty")
+
+    if set(trainer_gpu_ids) & set(inference_gpu_ids):
+        raise ValueError("trainer_gpu_ids and inference_gpu_ids must not overlap")
+
+    if len(inference_gpu_ids) > 1:
+        raise ValueError(
+            "Multi-GPU inference not yet supported; inference_gpu_ids must have exactly one GPU"
+        )
+
+    if trainer_gpu_ids[0] != 0:
+        raise ValueError(
+            "trainer_gpu_ids must start at GPU 0 (training runs in-process)"
+        )
+
+    expected = list(range(len(trainer_gpu_ids)))
+    if trainer_gpu_ids != expected:
+        raise ValueError(
+            "trainer_gpu_ids must be contiguous starting from 0 (e.g., [0], [0,1])"
+        )
+
+    # Reject settings that are incompatible with dedicated mode
+    if config.get("init_args", {}).get("fast_inference"):
+        raise ValueError(
+            "fast_inference is incompatible with dedicated mode "
+            "(dedicated mode runs vLLM as a subprocess, not in-process)"
+        )
+
+    if config.get("engine_args", {}).get("enable_sleep_mode"):
+        raise ValueError(
+            "enable_sleep_mode is incompatible with dedicated mode "
+            "(dedicated mode runs vLLM on a separate GPU, sleep/wake is not needed)"
+        )
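
The rules above can be exercised directly; a quick sketch of the failure modes (GPU IDs are hypothetical):

from art.dev import validate_dedicated_config

validate_dedicated_config({})  # shared mode: returns without checks

for bad in (
    {"trainer_gpu_ids": [0]},                              # one key without the other
    {"trainer_gpu_ids": [0], "inference_gpu_ids": [0]},    # overlapping GPUs
    {"trainer_gpu_ids": [1], "inference_gpu_ids": [0]},    # training not starting at GPU 0
    {"trainer_gpu_ids": [0], "inference_gpu_ids": [1, 2]}, # multi-GPU inference
):
    try:
        validate_dedicated_config(bad)
    except ValueError as e:
        print(e)  # each config raises with the matching message above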

src/art/local/backend.py

Lines changed: 47 additions & 5 deletions

@@ -1,5 +1,6 @@
 import asyncio
 import json
+import logging
 import math
 import os
 import shutil
@@ -9,6 +10,8 @@
 from typing import AsyncIterator, Iterable, Literal, cast
 import warnings

+logger = logging.getLogger(__name__)
+
 import aiohttp
 import numpy as np
 from openai import AsyncOpenAI
@@ -97,6 +100,9 @@ async def close(self) -> None:

     def _close(self) -> None:
         for _, service in self._services.items():
+            close = getattr(service, "close", None)
+            if close is not None:
+                close()
             close_proxy(service)

     async def register(
@@ -140,18 +146,39 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:

         # For LocalBackend, vLLM always serves LoRA adapters with @step suffix
         # Default to step 0 when not specified (the initial checkpoint created at registration)
-        actual_step = step if step is not None else self.__get_step(model)
-        return f"{model.name}@{actual_step}"
+        if step is not None:
+            actual_step = step
+        elif model.name in self._services:
+            # In dedicated mode the service tracks which adapter vLLM has
+            # actually loaded. Reading the filesystem would race: the
+            # checkpoint directory appears before the HTTP reload completes.
+            svc = self._services[model.name]
+            loaded_step = getattr(svc, "_latest_step", None)
+            actual_step = (
+                loaded_step if loaded_step is not None else self.__get_step(model)
+            )
+        else:
+            actual_step = self.__get_step(model)
+        name = f"{model.name}@{actual_step}"
+        logger.debug(
+            f"[BACKEND] _model_inference_name: step_arg={step} "
+            f"actual_step={actual_step} -> {name}"
+        )
+        return name

     async def _get_service(self, model: TrainableModel) -> ModelService:
         from ..dev.get_model_config import get_model_config
+        from ..dev.validate import is_dedicated_mode, validate_dedicated_config

         if model.name not in self._services:
             config = get_model_config(
                 base_model=model.base_model,
                 output_dir=get_model_dir(model=model, art_path=self._path),
                 config=model._internal_config,
             )
+            validate_dedicated_config(config)
+            dedicated = is_dedicated_mode(config)
+
             is_tinker = config.get("tinker_args") is not None
             if is_tinker:
                 from ..tinker.service import TinkerService
@@ -164,13 +191,19 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
             # When moving the service to a child process, import unsloth
             # early to maximize optimizations
             os.environ["IMPORT_UNSLOTH"] = "1"
+
+            if dedicated:
+                os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
+                    str(g) for g in config["trainer_gpu_ids"]
+                )
+
             self._services[model.name] = service_class(
                 model_name=model.name,
                 base_model=model.base_model,
                 config=config,
                 output_dir=get_model_dir(model=model, art_path=self._path),
             )
-            if not self._in_process:
+            if not dedicated and not self._in_process:
                 # Kill all "model-service" processes to free up GPU memory
                 subprocess.run(["pkill", "-9", "model-service"])
                 self._services[model.name] = move_to_child_process(
@@ -609,6 +642,10 @@ async def _train_model(
             # Still advance the step by renaming the checkpoint directory
             current_step = self.__get_step(model)
             next_step = current_step + 1
+            logger.info(
+                f"[BACKEND] _train_model SKIP: current_step={current_step} "
+                f"next_step={next_step} (all rewards equal)"
+            )
             current_checkpoint_dir = get_step_checkpoint_dir(
                 get_model_dir(model=model, art_path=self._path), current_step
             )
@@ -623,8 +660,9 @@ async def _train_model(
                 next_checkpoint_dir,
                 dirs_exist_ok=True,
             )
-            print(
-                f"Advanced step from {current_step} to {next_step} (no training occurred)"
+            logger.info(
+                f"[BACKEND] _train_model SKIP: copied checkpoint "
+                f"{current_step} -> {next_step}, calling register_lora_for_step..."
             )

             try:
@@ -634,6 +672,10 @@ async def _train_model(
                 await service.register_lora_for_step(  # type: ignore[attr-defined]
                     next_step, next_checkpoint_dir
                 )
+                logger.info(
+                    f"[BACKEND] _train_model SKIP: register_lora_for_step "
+                    f"completed for step {next_step}"
+                )
             except ModuleNotFoundError:
                 pass  # Unsloth is not installed
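
End to end, a hedged sketch of opting a model into dedicated mode through the backend. The TrainableModel fields and the _internal_config wiring reflect ART's public API as commonly documented, but treat every name and value here as an assumption:

import asyncio

import art
from art.local import LocalBackend

model = art.TrainableModel(
    name="agent-demo",                      # hypothetical
    project="dedicated-mode-demo",          # hypothetical
    base_model="Qwen/Qwen2.5-7B-Instruct",  # hypothetical
    _internal_config={"trainer_gpu_ids": [0], "inference_gpu_ids": [1]},
)

async def main() -> None:
    backend = LocalBackend()
    # _get_service validates the GPU split, pins training to
    # CUDA_VISIBLE_DEVICES="0", and skips the pkill/child-process path.
    await model.register(backend)

asyncio.run(main())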
