Skip to content

Commit c915043

Browse files
authored
feat: Support PipelineTrainer with dedicated LocalBackend (#621)
1 parent d48764f commit c915043

File tree

11 files changed

+615
-37
lines changed

11 files changed

+615
-37
lines changed

docs/features/checkpoint-forking.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import art
3131
from art.local import LocalBackend
3232

3333
async def train():
34-
with LocalBackend() as backend:
34+
async with LocalBackend() as backend:
3535
# Create a new model that will fork from an existing checkpoint
3636
model = art.TrainableModel(
3737
name="my-model-v2",
@@ -115,7 +115,7 @@ low_lr_model = art.TrainableModel(
115115
)
116116

117117
async def experiment():
118-
with LocalBackend() as backend:
118+
async with LocalBackend() as backend:
119119
# Fork the model from the base model
120120
await backend._experimental_fork_checkpoint(
121121
low_lr_model,

docs/fundamentals/art-backend.mdx

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,30 @@ backend = LocalBackend(
7373
)
7474
```
7575

76+
If you're using `PipelineTrainer`, `LocalBackend` is currently supported only in dedicated mode, where training and inference run on separate GPUs.
77+
78+
```python
79+
from art import TrainableModel
80+
from art.dev import InternalModelConfig
81+
from art.local import LocalBackend
82+
83+
backend = LocalBackend(path="./.art")
84+
model = TrainableModel(
85+
name="pipeline-localbackend",
86+
project="my-project",
87+
base_model="Qwen/Qwen3-0.6B",
88+
_internal_config=InternalModelConfig(
89+
trainer_gpu_ids=[0],
90+
inference_gpu_ids=[1],
91+
),
92+
)
93+
```
94+
95+
Shared `LocalBackend` still pauses inference during training, so ART rejects that configuration for `PipelineTrainer`.
96+
97+
In dedicated mode, a new checkpoint becomes the default inference target only after its LoRA has been reloaded into vLLM. That checkpoint publication flow is backend-specific, so `save_checkpoint` does not have identical semantics across every ART backend.
98+
Requests that are already in flight keep using the adapter they started with; the reload only affects subsequent routing to the latest served step.
99+
76100
## Using a backend
77101

78102
Once initialized, a backend can be used in the same way regardless of whether it runs locally or remotely.

docs/fundamentals/training-loop.mdx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ ART's functionality is divided into a [**client**](/fundamentals/art-client) and
2222

2323
This training loop runs until a specified number of inference and training iterations have completed.
2424

25+
This describes the default shared-resource loop. `PipelineTrainer` can also run with `LocalBackend` in dedicated mode, where training and inference stay on separate GPUs and the latest served step advances only after vLLM reloads the new LoRA.
26+
2527
Training and inference use both the ART **client** and **backend**. Learn more by following the links below!
2628

2729
<div className="cards-container">

src/art/local/backend.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def _allocated_gpu_count(self, model: Model) -> int:
160160
def __enter__(self) -> Self:
161161
return self
162162

163+
async def __aenter__(self) -> Self:
164+
return self
165+
163166
def __exit__(
164167
self,
165168
exc_type: type[BaseException] | None,
@@ -168,14 +171,30 @@ def __exit__(
168171
) -> None:
169172
self._close()
170173

174+
async def __aexit__(
175+
self,
176+
exc_type: type[BaseException] | None,
177+
exc: BaseException | None,
178+
tb: TracebackType | None,
179+
) -> None:
180+
await self.close()
181+
171182
async def close(self) -> None:
172183
"""
173184
If running vLLM in a separate process, this will kill that process and close the communication threads.
174185
"""
175-
self._close()
186+
for service in self._services.values():
187+
aclose = getattr(service, "aclose", None)
188+
if aclose is None:
189+
close = getattr(service, "close", None)
190+
if close is not None:
191+
close()
192+
else:
193+
await aclose()
194+
close_proxy(service)
176195

177196
def _close(self) -> None:
178-
for _, service in self._services.items():
197+
for service in self._services.values():
179198
close = getattr(service, "close", None)
180199
if close is not None:
181200
close()
@@ -225,25 +244,27 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
225244
If None, returns name for latest checkpoint (step 0 initially).
226245
"""
227246

228-
# For LocalBackend, vLLM always serves LoRA adapters with @step suffix
229-
# Default to step 0 when not specified (the initial checkpoint created at registration)
230-
if step is not None:
231-
actual_step = step
232-
elif model.name in self._services and self._in_process:
233-
# In dedicated mode the service tracks which adapter vLLM has
234-
# actually loaded. Reading the filesystem would race: the
235-
# checkpoint directory appears before the HTTP reload completes.
236-
svc = self._services[model.name]
237-
loaded_step = getattr(svc, "_latest_step", None)
238-
actual_step = (
239-
loaded_step if loaded_step is not None else self.__get_step(model)
240-
)
241-
else:
242-
actual_step = self.__get_step(model)
243-
name = f"{model.name}@{actual_step}"
247+
requested_step = step
248+
249+
if step is None and isinstance(model, TrainableModel):
250+
from ..dev.validate import is_dedicated_mode
251+
252+
service = self._services.get(model.name)
253+
if service is not None and is_dedicated_mode(
254+
model._internal_config or dev.InternalModelConfig()
255+
):
256+
loaded_step = getattr(service, "_latest_step", None)
257+
if isinstance(loaded_step, int):
258+
step = loaded_step
259+
260+
if step is None:
261+
# The checkpoint directory is written before dedicated-mode
262+
# vLLM finishes reloading the new adapter.
263+
step = self.__get_step(model)
264+
name = f"{model.name}@{step}"
244265
logger.debug(
245-
f"[BACKEND] _model_inference_name: step_arg={step} "
246-
f"actual_step={actual_step} -> {name}"
266+
f"[BACKEND] _model_inference_name: step_arg={requested_step} "
267+
f"actual_step={step} -> {name}"
247268
)
248269
return name
249270

@@ -508,12 +529,14 @@ async def train( # type: ignore[override]
508529
*,
509530
# Core training parameters
510531
learning_rate: float = 5e-6,
532+
loss_fn: Literal["cispo", "ppo"] = "cispo",
533+
loss_fn_config: dict | None = None,
534+
normalize_advantages: bool = True,
535+
adam_params: object | None = None,
511536
# KL-penalized advantage adjustment
512537
kl_penalty_coef: float = 0.0,
513538
kl_penalty_reference_step: int | None = None,
514539
kl_ref_adapter_path: str | None = None,
515-
# RL algorithm settings
516-
ppo: bool = False,
517540
epsilon: float | None = None,
518541
epsilon_high: float | None = None,
519542
# Advantage computation
@@ -550,6 +573,14 @@ async def train( # type: ignore[override]
550573
model: The trainable model to train.
551574
trajectory_groups: Batches of trajectories to train on.
552575
learning_rate: Learning rate for training. Defaults to 5e-6.
576+
loss_fn: RL loss function. LocalBackend currently supports
577+
"cispo" and "ppo".
578+
loss_fn_config: Additional loss-function config. Not supported by
579+
LocalBackend.
580+
normalize_advantages: Whether to normalize advantages. LocalBackend
581+
currently requires True.
582+
adam_params: Custom optimizer params. Not supported by
583+
LocalBackend.
553584
kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
554585
Tokens diverging more from the reference get reduced advantages.
555586
Defaults to 0.0 (disabled).
@@ -559,8 +590,7 @@ async def train( # type: ignore[override]
559590
kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
560591
checkpoint to use as the KL reference. Alternative to
561592
kl_penalty_reference_step.
562-
ppo: Whether to use PPO clipping. Defaults to False.
563-
epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
593+
epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
564594
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
565595
advantage_balance: Balance between negative and positive advantages
566596
in range [-1.0, 1.0]. Defaults to 0.0 (balanced).
@@ -603,6 +633,14 @@ async def train( # type: ignore[override]
603633
# await model.log(metrics=result.metrics, step=result.step)
604634
"""
605635
groups_list = list(trajectory_groups)
636+
if loss_fn not in {"cispo", "ppo"}:
637+
raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.")
638+
if loss_fn_config is not None:
639+
raise ValueError("LocalBackend requires loss_fn_config=None.")
640+
if not normalize_advantages:
641+
raise ValueError("LocalBackend requires normalize_advantages=True.")
642+
if adam_params is not None:
643+
raise ValueError("LocalBackend requires adam_params=None.")
606644

607645
# Build config objects from explicit kwargs
608646
config = TrainConfig(
@@ -615,7 +653,7 @@ async def train( # type: ignore[override]
615653
"kl_penalty_coef": kl_penalty_coef,
616654
"mask_prob_ratio": mask_prob_ratio,
617655
"plot_tensors": plot_tensors,
618-
"ppo": ppo,
656+
"ppo": loss_fn == "ppo",
619657
"precalculate_logprobs": precalculate_logprobs,
620658
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
621659
"scale_rewards": scale_rewards,

src/art/pipeline_trainer/trainer.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def __init__(
154154
total_scenarios=total_scenarios,
155155
num_workers=num_rollout_workers,
156156
)
157+
self._validate_backend_support()
157158

158159
async def train(self, *, handle_signals: bool = True) -> None:
159160
"""Run the training pipeline over the configured scenario iterator."""
@@ -277,6 +278,42 @@ async def _notify_policy() -> None:
277278
except asyncio.QueueFull:
278279
loop.create_task(self._output_queue.put(None))
279280

281+
def _validate_backend_support(self) -> None:
282+
from art.dev.validate import is_dedicated_mode
283+
from art.local.backend import LocalBackend
284+
285+
if not isinstance(self.backend, LocalBackend):
286+
return
287+
288+
model_config = self.model._internal_config or art.dev.InternalModelConfig()
289+
if not is_dedicated_mode(model_config):
290+
raise ValueError(
291+
"PipelineTrainer only supports LocalBackend in dedicated mode. "
292+
"Shared LocalBackend pauses inference during training and is not "
293+
"a supported async PipelineTrainer path. Set both "
294+
"trainer_gpu_ids and inference_gpu_ids on the TrainableModel "
295+
"_internal_config to use LocalBackend with PipelineTrainer."
296+
)
297+
if self.loss_fn not in {"cispo", "ppo"}:
298+
raise ValueError(
299+
"PipelineTrainer + LocalBackend(dedicated) only supports "
300+
"loss_fn='cispo' or loss_fn='ppo'."
301+
)
302+
if self.loss_fn_config is not None:
303+
raise ValueError(
304+
"PipelineTrainer + LocalBackend(dedicated) requires "
305+
"loss_fn_config=None."
306+
)
307+
if not self.normalize_advantages:
308+
raise ValueError(
309+
"PipelineTrainer + LocalBackend(dedicated) requires "
310+
"normalize_advantages=True."
311+
)
312+
if self.adam_params is not None:
313+
raise ValueError(
314+
"PipelineTrainer + LocalBackend(dedicated) requires adam_params=None."
315+
)
316+
280317
async def _skip_scenarios(
281318
self, scenarios: AsyncIterator[ScenarioT], count: int
282319
) -> int:

src/art/test/test_step_skipping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def test_step_skipping():
4444
# Set up backend with custom art path
4545
art_path = os.path.join(tmpdir, ".art")
4646

47-
with LocalBackend(path=art_path) as backend:
47+
async with LocalBackend(path=art_path) as backend:
4848
# Create a test model
4949
model = TrainableModel(
5050
name=f"test-step-skip-{uuid.uuid4()}",

src/art/unsloth/service.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from ..utils.get_model_step import get_step_from_dir
3636
from ..utils.output_dirs import get_step_checkpoint_dir
3737
from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers
38-
from .train import gc_and_empty_cuda_cache, train
38+
from .train import StopTrainingLoop, gc_and_empty_cuda_cache, train
3939

4040
logger = logging.getLogger(__name__)
4141

@@ -55,6 +55,15 @@ class SupportsLoadLora(Protocol):
5555
def load_lora(self, lora_path: str, load_tensors: bool = True) -> LoRARequest: ...
5656

5757

58+
class _StopTrainInputs:
59+
"""Dedicated sentinel for stopping the background trainer loop."""
60+
61+
62+
_STOP_TRAIN_INPUT = _StopTrainInputs()
63+
_TRAIN_TASK_SHUTDOWN_TIMEOUT_S = 5.0
64+
_TrainLoopInput = TrainInputs | _StopTrainInputs
65+
66+
5867
def precalculate_new_logprobs(
5968
trainer: "GRPOTrainer",
6069
peft_model: "PeftModelForCausalLM",
@@ -91,7 +100,7 @@ async def process_train_batch(
91100
packed_tensors: PackedTensors,
92101
config: types.TrainConfig,
93102
_config: dev.TrainConfig,
94-
inputs_queue: asyncio.Queue[TrainInputs],
103+
inputs_queue: asyncio.Queue[_TrainLoopInput],
95104
results_queue: asyncio.Queue[dict[str, float]],
96105
train_task: asyncio.Task[None],
97106
trainer: "GRPOTrainer",
@@ -215,7 +224,7 @@ class UnslothState:
215224
tokenizer: PreTrainedTokenizerBase
216225
peft_model: peft.peft_model.PeftModelForCausalLM
217226
trainer: GRPOTrainer
218-
inputs_queue: asyncio.Queue[TrainInputs]
227+
inputs_queue: asyncio.Queue[_TrainLoopInput]
219228
results_queue: asyncio.Queue[dict[str, float]]
220229
_is_offloaded: bool = False
221230
_pinned_buffers: dict[str, torch.Tensor] | None = None
@@ -316,6 +325,7 @@ class UnslothService:
316325
_vllm_log_file: Any = field(default=None, repr=False)
317326
_vllm_host: str = "127.0.0.1"
318327
_vllm_port: int = 0
328+
_train_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False)
319329

320330
@property
321331
def is_dedicated(self) -> bool:
@@ -326,6 +336,24 @@ def _next_lora_id(self) -> int:
326336
self._lora_id_counter += 1
327337
return self._lora_id_counter
328338

339+
async def aclose(self) -> None:
340+
train_task = self._train_task
341+
self._train_task = None
342+
if train_task is None or train_task.done():
343+
self.close()
344+
return
345+
346+
# `_state` is a cached_property. Read from __dict__ directly so
347+
# closing does not instantiate trainer state only to stop a task.
348+
state = self.__dict__.get("_state")
349+
assert isinstance(state, UnslothState)
350+
state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT)
351+
try:
352+
await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_SHUTDOWN_TIMEOUT_S)
353+
except asyncio.TimeoutError:
354+
train_task.cancel()
355+
self.close()
356+
329357
# =========================================================================
330358
# Dedicated mode: vLLM subprocess lifecycle
331359
# =========================================================================
@@ -595,7 +623,7 @@ async def _train_dedicated(
595623

596624
await self._state.results_queue.join()
597625

598-
if not hasattr(self, "_train_task") or self._train_task is None:
626+
if self._train_task is None:
599627
self._train_task = asyncio.create_task(
600628
train(
601629
trainer=self._state.trainer,
@@ -685,7 +713,7 @@ async def _train_shared(
685713
await self._state.results_queue.join()
686714

687715
# If we haven't already, start the training task
688-
if not hasattr(self, "_train_task") or self._train_task is None:
716+
if self._train_task is None:
689717
self._train_task = asyncio.create_task(
690718
train(
691719
trainer=self._state.trainer,
@@ -981,17 +1009,19 @@ def _state(self) -> UnslothState:
9811009
trainer.create_optimizer()
9821010

9831011
# Initialize queues
984-
inputs_queue: asyncio.Queue[TrainInputs] = asyncio.Queue()
1012+
inputs_queue: asyncio.Queue[_TrainLoopInput] = asyncio.Queue()
9851013
results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue()
9861014

9871015
# Patch trainer _prepare_inputs() to pull from queue
9881016
def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]:
989-
async def get_inputs() -> TrainInputs:
1017+
async def get_inputs() -> _TrainLoopInput:
9901018
return await inputs_queue.get()
9911019

9921020
# Force otherwise synchronous _prepare_inputs() to yield
9931021
# with nested asyncio.run() call
9941022
inputs = asyncio.run(get_inputs())
1023+
if isinstance(inputs, _StopTrainInputs):
1024+
raise StopTrainingLoop()
9951025

9961026
return cast(dict[str, torch.Tensor], inputs)
9971027

0 commit comments

Comments (0)