Skip to content

Commit 6a67282

Browse files
committed
feat: Support PipelineTrainer with dedicated LocalBackend
1 parent abecd93 commit 6a67282

File tree

7 files changed

+495
-26
lines changed

7 files changed

+495
-26
lines changed

src/art/local/backend.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,59 @@ def _allocated_gpu_count(self, model: Model) -> int:
159159
def __enter__(self) -> Self:
160160
return self
161161

162+
async def __aenter__(self) -> Self:
163+
return self
164+
162165
def __exit__(
163166
self,
164167
exc_type: type[BaseException] | None,
165168
exc: BaseException | None,
166169
tb: TracebackType | None,
167170
) -> None:
171+
try:
172+
asyncio.get_running_loop()
173+
except RuntimeError:
174+
running_loop = False
175+
else:
176+
running_loop = True
177+
178+
if running_loop or any(
179+
getattr(service, "aclose", None) is not None
180+
for service in self._services.values()
181+
):
182+
warnings.warn(
183+
"LocalBackend used as a sync context manager. Cleanup uses the "
184+
"best-effort sync shutdown path and cannot await service "
185+
"teardown safely here; use `async with LocalBackend(...)` or "
186+
"`await backend.close()` instead.",
187+
RuntimeWarning,
188+
stacklevel=2,
189+
)
168190
self._close()
169191

192+
async def __aexit__(
193+
self,
194+
exc_type: type[BaseException] | None,
195+
exc: BaseException | None,
196+
tb: TracebackType | None,
197+
) -> None:
198+
await self.close()
199+
170200
async def close(self) -> None:
171201
"""
172202
If running vLLM in a separate process, this will kill that process and close the communication threads.
173203
"""
174-
self._close()
204+
for _, service in self._services.items():
205+
# Keep this logic aligned with _close(), but avoid double-closing
206+
# services that expose an awaited aclose() path.
207+
aclose = getattr(service, "aclose", None)
208+
if aclose is not None:
209+
await aclose()
210+
else:
211+
close = getattr(service, "close", None)
212+
if close is not None:
213+
close()
214+
close_proxy(service)
175215

176216
def _close(self) -> None:
177217
for _, service in self._services.items():
@@ -219,21 +259,31 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
219259
If None, returns name for latest checkpoint (step 0 initially).
220260
"""
221261

262+
def _served_step() -> int | None:
263+
if not isinstance(model, TrainableModel):
264+
return None
265+
if model.name not in self._services:
266+
return None
267+
from ..dev.validate import is_dedicated_mode
268+
269+
if not is_dedicated_mode(
270+
model._internal_config or dev.InternalModelConfig()
271+
):
272+
return None
273+
loaded_step = getattr(self._services[model.name], "_latest_step", None)
274+
return loaded_step if isinstance(loaded_step, int) else None
275+
222276
# For LocalBackend, vLLM always serves LoRA adapters with @step suffix
223277
# Default to step 0 when not specified (the initial checkpoint created at registration)
224278
if step is not None:
225279
actual_step = step
226-
elif model.name in self._services and self._in_process:
227-
# In dedicated mode the service tracks which adapter vLLM has
228-
# actually loaded. Reading the filesystem would race: the
229-
# checkpoint directory appears before the HTTP reload completes.
230-
svc = self._services[model.name]
231-
loaded_step = getattr(svc, "_latest_step", None)
232-
actual_step = (
233-
loaded_step if loaded_step is not None else self.__get_step(model)
234-
)
235280
else:
236-
actual_step = self.__get_step(model)
281+
# In dedicated mode the service tracks which adapter vLLM has
282+
# actually loaded. Reading the filesystem would race: the checkpoint
283+
# directory appears before the HTTP reload completes.
284+
actual_step = _served_step()
285+
if actual_step is None:
286+
actual_step = self.__get_step(model)
237287
name = f"{model.name}@{actual_step}"
238288
logger.debug(
239289
f"[BACKEND] _model_inference_name: step_arg={step} "

src/art/pipeline_trainer/trainer.py

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def __init__(
154154
total_scenarios=total_scenarios,
155155
num_workers=num_rollout_workers,
156156
)
157+
self._validate_backend_support()
157158

158159
async def train(self, *, handle_signals: bool = True) -> None:
159160
"""Run the training pipeline over the configured scenario iterator."""
@@ -277,6 +278,72 @@ async def _notify_policy() -> None:
277278
except asyncio.QueueFull:
278279
loop.create_task(self._output_queue.put(None))
279280

281+
def _is_local_backend(self) -> bool:
282+
from art.local.backend import LocalBackend
283+
284+
return isinstance(self.backend, LocalBackend)
285+
286+
def _local_backend_is_dedicated(self) -> bool:
287+
if not isinstance(self.model, art.TrainableModel):
288+
return False
289+
from art.dev.validate import is_dedicated_mode
290+
291+
return is_dedicated_mode(
292+
self.model._internal_config or art.dev.InternalModelConfig()
293+
)
294+
295+
def _validate_backend_support(self) -> None:
296+
if not self._is_local_backend():
297+
return
298+
if self._local_backend_is_dedicated():
299+
self._validate_local_backend_train_config()
300+
return
301+
raise ValueError(
302+
"PipelineTrainer only supports LocalBackend in dedicated mode. "
303+
"Shared LocalBackend pauses inference during training and is not "
304+
"a supported async PipelineTrainer path. Set both "
305+
"trainer_gpu_ids and inference_gpu_ids on the TrainableModel "
306+
"_internal_config to use LocalBackend with PipelineTrainer."
307+
)
308+
309+
def _validate_local_backend_train_config(self) -> None:
310+
if self.loss_fn not in {"cispo", "ppo"}:
311+
raise ValueError(
312+
"PipelineTrainer + LocalBackend(dedicated) only supports "
313+
"loss_fn='cispo' or loss_fn='ppo'."
314+
)
315+
if self.loss_fn_config is not None:
316+
raise ValueError(
317+
"PipelineTrainer + LocalBackend(dedicated) requires "
318+
"loss_fn_config=None."
319+
)
320+
if not self.normalize_advantages:
321+
raise ValueError(
322+
"PipelineTrainer + LocalBackend(dedicated) requires "
323+
"normalize_advantages=True."
324+
)
325+
if self.adam_params is not None:
326+
raise ValueError(
327+
"PipelineTrainer + LocalBackend(dedicated) requires adam_params=None."
328+
)
329+
330+
def _backend_train_kwargs(self, *, save_checkpoint: bool) -> dict[str, Any]:
331+
if not self._is_local_backend():
332+
return {
333+
"learning_rate": self.learning_rate,
334+
"loss_fn": self.loss_fn,
335+
"loss_fn_config": self.loss_fn_config,
336+
"normalize_advantages": self.normalize_advantages,
337+
"save_checkpoint": save_checkpoint,
338+
"adam_params": self.adam_params,
339+
}
340+
341+
return {
342+
"learning_rate": self.learning_rate,
343+
"ppo": self.loss_fn == "ppo",
344+
"save_checkpoint": save_checkpoint,
345+
}
346+
280347
async def _skip_scenarios(
281348
self, scenarios: AsyncIterator[ScenarioT], count: int
282349
) -> int:
@@ -412,18 +479,14 @@ async def _training_stage(self) -> None:
412479

413480
self._status.note_training_start(len(batch))
414481
train_call_start = time.monotonic()
482+
train_kwargs = self._backend_train_kwargs(save_checkpoint=should_checkpoint)
415483
if os.getenv("ART_TRAIN_STEP_LOG"):
416484
print(f"[train] step {expected_step} starting (batch={len(batch)})")
417485
try:
418486
result = await self.backend.train(
419487
self.model,
420488
batch,
421-
learning_rate=self.learning_rate,
422-
loss_fn=self.loss_fn,
423-
loss_fn_config=self.loss_fn_config,
424-
normalize_advantages=self.normalize_advantages,
425-
save_checkpoint=should_checkpoint,
426-
adam_params=self.adam_params,
489+
**train_kwargs,
427490
)
428491
except Exception:
429492
self._status.note_training_end()

src/art/test/test_step_skipping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def test_step_skipping():
4444
# Set up backend with custom art path
4545
art_path = os.path.join(tmpdir, ".art")
4646

47-
with LocalBackend(path=art_path) as backend:
47+
async with LocalBackend(path=art_path) as backend:
4848
# Create a test model
4949
model = TrainableModel(
5050
name=f"test-step-skip-{uuid.uuid4()}",

src/art/unsloth/service.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from ..utils.get_model_step import get_step_from_dir
3636
from ..utils.output_dirs import get_step_checkpoint_dir
3737
from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers
38-
from .train import gc_and_empty_cuda_cache, train
38+
from .train import StopTrainingLoop, gc_and_empty_cuda_cache, train
3939

4040
logger = logging.getLogger(__name__)
4141

@@ -55,6 +55,15 @@ class SupportsLoadLora(Protocol):
5555
def load_lora(self, lora_path: str, load_tensors: bool = True) -> LoRARequest: ...
5656

5757

58+
class _StopTrainInputs:
59+
"""Dedicated sentinel for stopping the background trainer loop."""
60+
61+
62+
_STOP_TRAIN_INPUT = _StopTrainInputs()
63+
_TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S = 5.0
64+
_TRAIN_TASK_CANCEL_TIMEOUT_S = 1.0
65+
66+
5867
def precalculate_new_logprobs(
5968
trainer: "GRPOTrainer",
6069
peft_model: "PeftModelForCausalLM",
@@ -91,7 +100,7 @@ async def process_train_batch(
91100
packed_tensors: PackedTensors,
92101
config: types.TrainConfig,
93102
_config: dev.TrainConfig,
94-
inputs_queue: asyncio.Queue[TrainInputs],
103+
inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs],
95104
results_queue: asyncio.Queue[dict[str, float]],
96105
train_task: asyncio.Task[None],
97106
trainer: "GRPOTrainer",
@@ -215,7 +224,7 @@ class UnslothState:
215224
tokenizer: PreTrainedTokenizerBase
216225
peft_model: peft.peft_model.PeftModelForCausalLM
217226
trainer: GRPOTrainer
218-
inputs_queue: asyncio.Queue[TrainInputs]
227+
inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs]
219228
results_queue: asyncio.Queue[dict[str, float]]
220229
_is_offloaded: bool = False
221230
_pinned_buffers: dict[str, torch.Tensor] | None = None
@@ -316,6 +325,7 @@ class UnslothService:
316325
_vllm_log_file: Any = field(default=None, repr=False)
317326
_vllm_host: str = "127.0.0.1"
318327
_vllm_port: int = 0
328+
_train_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False)
319329

320330
@property
321331
def is_dedicated(self) -> bool:
@@ -326,6 +336,46 @@ def _next_lora_id(self) -> int:
326336
self._lora_id_counter += 1
327337
return self._lora_id_counter
328338

339+
def _request_train_task_stop(self) -> asyncio.Task[None] | None:
340+
train_task = self._train_task
341+
if train_task is None:
342+
return None
343+
if train_task.done():
344+
return train_task
345+
346+
# `_state` is a cached_property. Read from __dict__ directly so shutdown
347+
# does not instantiate the full trainer state solely to stop a task.
348+
state = self.__dict__.get("_state")
349+
if isinstance(state, UnslothState):
350+
state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT)
351+
return train_task
352+
353+
async def _shutdown_train_task(self) -> None:
354+
train_task = self._request_train_task_stop()
355+
if train_task is None:
356+
return
357+
358+
try:
359+
# Give the trainer loop time to consume the stop sentinel and exit
360+
# normally before falling back to cancellation.
361+
await asyncio.wait_for(
362+
train_task, timeout=_TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S
363+
)
364+
except asyncio.TimeoutError:
365+
train_task.cancel()
366+
try:
367+
await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_CANCEL_TIMEOUT_S)
368+
except (asyncio.CancelledError, asyncio.TimeoutError):
369+
pass
370+
except asyncio.CancelledError:
371+
pass
372+
finally:
373+
self._train_task = None
374+
375+
async def aclose(self) -> None:
376+
await self._shutdown_train_task()
377+
self.close()
378+
329379
# =========================================================================
330380
# Dedicated mode: vLLM subprocess lifecycle
331381
# =========================================================================
@@ -450,6 +500,7 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None:
450500

451501
def close(self) -> None:
452502
"""Terminate vLLM subprocess if running."""
503+
self._request_train_task_stop()
453504
if self._vllm_process is None:
454505
return
455506
self._vllm_process.terminate()
@@ -981,17 +1032,19 @@ def _state(self) -> UnslothState:
9811032
trainer.create_optimizer()
9821033

9831034
# Initialize queues
984-
inputs_queue: asyncio.Queue[TrainInputs] = asyncio.Queue()
1035+
inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs] = asyncio.Queue()
9851036
results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue()
9861037

9871038
# Patch trainer _prepare_inputs() to pull from queue
9881039
def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]:
989-
async def get_inputs() -> TrainInputs:
1040+
async def get_inputs() -> TrainInputs | _StopTrainInputs:
9901041
return await inputs_queue.get()
9911042

9921043
# Force otherwise synchronous _prepare_inputs() to yield
9931044
# with nested asyncio.run() call
9941045
inputs = asyncio.run(get_inputs())
1046+
if isinstance(inputs, _StopTrainInputs):
1047+
raise StopTrainingLoop()
9951048

9961049
return cast(dict[str, torch.Tensor], inputs)
9971050

src/art/unsloth/train.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
}
4040

4141

42+
class StopTrainingLoop(Exception):
43+
"""Signal that the background trainer loop should exit cleanly."""
44+
45+
4246
def _canonicalize_upstream_metric_key(metric: str) -> str:
4347
if "/" in metric:
4448
return metric
@@ -74,7 +78,10 @@ async def train(
7478
if not is_train_dict:
7579
trainer._metrics = {"train": defaultdict(list)}
7680
try:
77-
trainer.train()
81+
try:
82+
trainer.train()
83+
except StopTrainingLoop:
84+
return
7885
finally:
7986
trainer.compute_loss = _compute_loss
8087
trainer.log = _log # ty:ignore[invalid-assignment]

src/art/utils/benchmarking/pull_model_trajectories.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ async def pull_model_trajectories(model: ArtModel) -> None:
3131
"Environment variable BACKUP_BUCKET is required but was not found."
3232
)
3333

34-
# Use the LocalBackend context manager to work with the on-disk artefacts.
35-
with LocalBackend() as backend:
34+
# Use the LocalBackend async context manager so backend cleanup can await
35+
# any background service shutdown before returning.
36+
async with LocalBackend() as backend:
3637
print(
3738
f"Pulling trajectories for model '{model.name}' from S3 bucket '{bucket}'…",
3839
flush=True,

0 commit comments

Comments
 (0)