Skip to content

Commit 47579de

Browse files
committed
refactor: simplify LocalBackend pipeline trainer integration
1 parent e4c6b2b commit 47579de

File tree

6 files changed

+165
-266
lines changed

6 files changed

+165
-266
lines changed

src/art/local/backend.py

Lines changed: 42 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -168,25 +168,6 @@ def __exit__(
168168
exc: BaseException | None,
169169
tb: TracebackType | None,
170170
) -> None:
171-
try:
172-
asyncio.get_running_loop()
173-
except RuntimeError:
174-
running_loop = False
175-
else:
176-
running_loop = True
177-
178-
if running_loop or any(
179-
getattr(service, "aclose", None) is not None
180-
for service in self._services.values()
181-
):
182-
warnings.warn(
183-
"LocalBackend used as a sync context manager. Cleanup uses the "
184-
"best-effort sync shutdown path and cannot await service "
185-
"teardown safely here; use `async with LocalBackend(...)` or "
186-
"`await backend.close()` instead.",
187-
RuntimeWarning,
188-
stacklevel=2,
189-
)
190171
self._close()
191172

192173
async def __aexit__(
@@ -201,20 +182,18 @@ async def close(self) -> None:
201182
"""
202183
If running vLLM in a separate process, this will kill that process and close the communication threads.
203184
"""
204-
for _, service in self._services.items():
205-
# Keep this logic aligned with _close(), but avoid double-closing
206-
# services that expose an awaited aclose() path.
185+
for service in self._services.values():
207186
aclose = getattr(service, "aclose", None)
208-
if aclose is not None:
209-
await aclose()
210-
else:
187+
if aclose is None:
211188
close = getattr(service, "close", None)
212189
if close is not None:
213190
close()
191+
else:
192+
await aclose()
214193
close_proxy(service)
215194

216195
def _close(self) -> None:
217-
for _, service in self._services.items():
196+
for service in self._services.values():
218197
close = getattr(service, "close", None)
219198
if close is not None:
220199
close()
@@ -259,35 +238,27 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
259238
If None, returns name for latest checkpoint (step 0 initially).
260239
"""
261240

262-
def _served_step() -> int | None:
263-
if not isinstance(model, TrainableModel):
264-
return None
265-
if model.name not in self._services:
266-
return None
241+
requested_step = step
242+
243+
if step is None and isinstance(model, TrainableModel):
267244
from ..dev.validate import is_dedicated_mode
268245

269-
if not is_dedicated_mode(
246+
service = self._services.get(model.name)
247+
if service is not None and is_dedicated_mode(
270248
model._internal_config or dev.InternalModelConfig()
271249
):
272-
return None
273-
loaded_step = getattr(self._services[model.name], "_latest_step", None)
274-
return loaded_step if isinstance(loaded_step, int) else None
275-
276-
# For LocalBackend, vLLM always serves LoRA adapters with @step suffix
277-
# Default to step 0 when not specified (the initial checkpoint created at registration)
278-
if step is not None:
279-
actual_step = step
280-
else:
281-
# In dedicated mode the service tracks which adapter vLLM has
282-
# actually loaded. Reading the filesystem would race: the checkpoint
283-
# directory appears before the HTTP reload completes.
284-
actual_step = _served_step()
285-
if actual_step is None:
286-
actual_step = self.__get_step(model)
287-
name = f"{model.name}@{actual_step}"
250+
loaded_step = getattr(service, "_latest_step", None)
251+
if isinstance(loaded_step, int):
252+
step = loaded_step
253+
254+
if step is None:
255+
# The checkpoint directory is written before dedicated-mode
256+
# vLLM finishes reloading the new adapter.
257+
step = self.__get_step(model)
258+
name = f"{model.name}@{step}"
288259
logger.debug(
289-
f"[BACKEND] _model_inference_name: step_arg={step} "
290-
f"actual_step={actual_step} -> {name}"
260+
f"[BACKEND] _model_inference_name: step_arg={requested_step} "
261+
f"actual_step={step} -> {name}"
291262
)
292263
return name
293264

@@ -552,12 +523,14 @@ async def train( # type: ignore[override]
552523
*,
553524
# Core training parameters
554525
learning_rate: float = 5e-6,
526+
loss_fn: Literal["cispo", "ppo", "importance_sampling", "dro"] = "cispo",
527+
loss_fn_config: dict | None = None,
528+
normalize_advantages: bool = True,
529+
adam_params: object | None = None,
555530
# KL-penalized advantage adjustment
556531
kl_penalty_coef: float = 0.0,
557532
kl_penalty_reference_step: int | None = None,
558533
kl_ref_adapter_path: str | None = None,
559-
# RL algorithm settings
560-
ppo: bool = False,
561534
epsilon: float | None = None,
562535
epsilon_high: float | None = None,
563536
# Advantage computation
@@ -594,6 +567,14 @@ async def train( # type: ignore[override]
594567
model: The trainable model to train.
595568
trajectory_groups: Batches of trajectories to train on.
596569
learning_rate: Learning rate for training. Defaults to 5e-6.
570+
loss_fn: RL loss function. LocalBackend currently supports
571+
"cispo" and "ppo".
572+
loss_fn_config: Additional loss-function config. Not supported by
573+
LocalBackend.
574+
normalize_advantages: Whether to normalize advantages. LocalBackend
575+
currently requires True.
576+
adam_params: Custom optimizer params. Not supported by
577+
LocalBackend.
597578
kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
598579
Tokens diverging more from the reference get reduced advantages.
599580
Defaults to 0.0 (disabled).
@@ -603,7 +584,6 @@ async def train( # type: ignore[override]
603584
kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
604585
checkpoint to use as the KL reference. Alternative to
605586
kl_penalty_reference_step.
606-
ppo: Whether to use PPO clipping. Defaults to False.
607587
epsilon: Clip epsilon for importance sampling. Defaults based on whether loss_fn is "ppo".
608588
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
609589
advantage_balance: Balance between negative and positive advantages
@@ -647,6 +627,14 @@ async def train( # type: ignore[override]
647627
# await model.log(metrics=result.metrics, step=result.step)
648628
"""
649629
groups_list = list(trajectory_groups)
630+
if loss_fn not in {"cispo", "ppo"}:
631+
raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.")
632+
if loss_fn_config is not None:
633+
raise ValueError("LocalBackend requires loss_fn_config=None.")
634+
if not normalize_advantages:
635+
raise ValueError("LocalBackend requires normalize_advantages=True.")
636+
if adam_params is not None:
637+
raise ValueError("LocalBackend requires adam_params=None.")
650638

651639
# Build config objects from explicit kwargs
652640
config = TrainConfig(
@@ -659,7 +647,7 @@ async def train( # type: ignore[override]
659647
"kl_penalty_coef": kl_penalty_coef,
660648
"mask_prob_ratio": mask_prob_ratio,
661649
"plot_tensors": plot_tensors,
662-
"ppo": ppo,
650+
"ppo": loss_fn == "ppo",
663651
"precalculate_logprobs": precalculate_logprobs,
664652
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
665653
"scale_rewards": scale_rewards,

src/art/pipeline_trainer/trainer.py

Lines changed: 18 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -278,35 +278,22 @@ async def _notify_policy() -> None:
278278
except asyncio.QueueFull:
279279
loop.create_task(self._output_queue.put(None))
280280

281-
def _is_local_backend(self) -> bool:
282-
from art.local.backend import LocalBackend
283-
284-
return isinstance(self.backend, LocalBackend)
285-
286-
def _local_backend_is_dedicated(self) -> bool:
287-
if not isinstance(self.model, art.TrainableModel):
288-
return False
281+
def _validate_backend_support(self) -> None:
289282
from art.dev.validate import is_dedicated_mode
283+
from art.local.backend import LocalBackend
290284

291-
return is_dedicated_mode(
292-
self.model._internal_config or art.dev.InternalModelConfig()
293-
)
294-
295-
def _validate_backend_support(self) -> None:
296-
if not self._is_local_backend():
297-
return
298-
if self._local_backend_is_dedicated():
299-
self._validate_local_backend_train_config()
285+
if not isinstance(self.backend, LocalBackend):
300286
return
301-
raise ValueError(
302-
"PipelineTrainer only supports LocalBackend in dedicated mode. "
303-
"Shared LocalBackend pauses inference during training and is not "
304-
"a supported async PipelineTrainer path. Set both "
305-
"trainer_gpu_ids and inference_gpu_ids on the TrainableModel "
306-
"_internal_config to use LocalBackend with PipelineTrainer."
307-
)
308287

309-
def _validate_local_backend_train_config(self) -> None:
288+
model_config = self.model._internal_config or art.dev.InternalModelConfig()
289+
if not is_dedicated_mode(model_config):
290+
raise ValueError(
291+
"PipelineTrainer only supports LocalBackend in dedicated mode. "
292+
"Shared LocalBackend pauses inference during training and is not "
293+
"a supported async PipelineTrainer path. Set both "
294+
"trainer_gpu_ids and inference_gpu_ids on the TrainableModel "
295+
"_internal_config to use LocalBackend with PipelineTrainer."
296+
)
310297
if self.loss_fn not in {"cispo", "ppo"}:
311298
raise ValueError(
312299
"PipelineTrainer + LocalBackend(dedicated) only supports "
@@ -327,23 +314,6 @@ def _validate_local_backend_train_config(self) -> None:
327314
"PipelineTrainer + LocalBackend(dedicated) requires adam_params=None."
328315
)
329316

330-
def _backend_train_kwargs(self, *, save_checkpoint: bool) -> dict[str, Any]:
331-
if not self._is_local_backend():
332-
return {
333-
"learning_rate": self.learning_rate,
334-
"loss_fn": self.loss_fn,
335-
"loss_fn_config": self.loss_fn_config,
336-
"normalize_advantages": self.normalize_advantages,
337-
"save_checkpoint": save_checkpoint,
338-
"adam_params": self.adam_params,
339-
}
340-
341-
return {
342-
"learning_rate": self.learning_rate,
343-
"ppo": self.loss_fn == "ppo",
344-
"save_checkpoint": save_checkpoint,
345-
}
346-
347317
async def _skip_scenarios(
348318
self, scenarios: AsyncIterator[ScenarioT], count: int
349319
) -> int:
@@ -479,14 +449,18 @@ async def _training_stage(self) -> None:
479449

480450
self._status.note_training_start(len(batch))
481451
train_call_start = time.monotonic()
482-
train_kwargs = self._backend_train_kwargs(save_checkpoint=should_checkpoint)
483452
if os.getenv("ART_TRAIN_STEP_LOG"):
484453
print(f"[train] step {expected_step} starting (batch={len(batch)})")
485454
try:
486455
result = await self.backend.train(
487456
self.model,
488457
batch,
489-
**train_kwargs,
458+
learning_rate=self.learning_rate,
459+
loss_fn=self.loss_fn,
460+
loss_fn_config=self.loss_fn_config,
461+
normalize_advantages=self.normalize_advantages,
462+
save_checkpoint=should_checkpoint,
463+
adam_params=self.adam_params,
490464
)
491465
except Exception:
492466
self._status.note_training_end()

src/art/unsloth/service.py

Lines changed: 18 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ class _StopTrainInputs:
6060

6161

6262
_STOP_TRAIN_INPUT = _StopTrainInputs()
63-
_TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S = 5.0
64-
_TRAIN_TASK_CANCEL_TIMEOUT_S = 1.0
63+
_TRAIN_TASK_SHUTDOWN_TIMEOUT_S = 5.0
64+
_TrainLoopInput = TrainInputs | _StopTrainInputs
6565

6666

6767
def precalculate_new_logprobs(
@@ -100,7 +100,7 @@ async def process_train_batch(
100100
packed_tensors: PackedTensors,
101101
config: types.TrainConfig,
102102
_config: dev.TrainConfig,
103-
inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs],
103+
inputs_queue: asyncio.Queue[_TrainLoopInput],
104104
results_queue: asyncio.Queue[dict[str, float]],
105105
train_task: asyncio.Task[None],
106106
trainer: "GRPOTrainer",
@@ -224,7 +224,7 @@ class UnslothState:
224224
tokenizer: PreTrainedTokenizerBase
225225
peft_model: peft.peft_model.PeftModelForCausalLM
226226
trainer: GRPOTrainer
227-
inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs]
227+
inputs_queue: asyncio.Queue[_TrainLoopInput]
228228
results_queue: asyncio.Queue[dict[str, float]]
229229
_is_offloaded: bool = False
230230
_pinned_buffers: dict[str, torch.Tensor] | None = None
@@ -336,44 +336,22 @@ def _next_lora_id(self) -> int:
336336
self._lora_id_counter += 1
337337
return self._lora_id_counter
338338

339-
def _request_train_task_stop(self) -> asyncio.Task[None] | None:
339+
async def aclose(self) -> None:
340340
train_task = self._train_task
341-
if train_task is None:
342-
return None
343-
if train_task.done():
344-
return train_task
345-
346-
# `_state` is a cached_property. Read from __dict__ directly so shutdown
347-
# does not instantiate the full trainer state solely to stop a task.
348-
state = self.__dict__.get("_state")
349-
if isinstance(state, UnslothState):
350-
state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT)
351-
return train_task
352-
353-
async def _shutdown_train_task(self) -> None:
354-
train_task = self._request_train_task_stop()
355-
if train_task is None:
341+
self._train_task = None
342+
if train_task is None or train_task.done():
343+
self.close()
356344
return
357345

346+
# `_state` is a cached_property. Read from __dict__ directly so
347+
# closing does not instantiate trainer state only to stop a task.
348+
state = self.__dict__.get("_state")
349+
assert isinstance(state, UnslothState)
350+
state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT)
358351
try:
359-
# Give the trainer loop time to consume the stop sentinel and exit
360-
# normally before falling back to cancellation.
361-
await asyncio.wait_for(
362-
train_task, timeout=_TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S
363-
)
352+
await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_SHUTDOWN_TIMEOUT_S)
364353
except asyncio.TimeoutError:
365354
train_task.cancel()
366-
try:
367-
await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_CANCEL_TIMEOUT_S)
368-
except (asyncio.CancelledError, asyncio.TimeoutError):
369-
pass
370-
except asyncio.CancelledError:
371-
pass
372-
finally:
373-
self._train_task = None
374-
375-
async def aclose(self) -> None:
376-
await self._shutdown_train_task()
377355
self.close()
378356

379357
# =========================================================================
@@ -500,7 +478,6 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None:
500478

501479
def close(self) -> None:
502480
"""Terminate vLLM subprocess if running."""
503-
self._request_train_task_stop()
504481
if self._vllm_process is None:
505482
return
506483
self._vllm_process.terminate()
@@ -646,7 +623,7 @@ async def _train_dedicated(
646623

647624
await self._state.results_queue.join()
648625

649-
if not hasattr(self, "_train_task") or self._train_task is None:
626+
if self._train_task is None:
650627
self._train_task = asyncio.create_task(
651628
train(
652629
trainer=self._state.trainer,
@@ -736,7 +713,7 @@ async def _train_shared(
736713
await self._state.results_queue.join()
737714

738715
# If we haven't already, start the training task
739-
if not hasattr(self, "_train_task") or self._train_task is None:
716+
if self._train_task is None:
740717
self._train_task = asyncio.create_task(
741718
train(
742719
trainer=self._state.trainer,
@@ -1032,12 +1009,12 @@ def _state(self) -> UnslothState:
10321009
trainer.create_optimizer()
10331010

10341011
# Initialize queues
1035-
inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs] = asyncio.Queue()
1012+
inputs_queue: asyncio.Queue[_TrainLoopInput] = asyncio.Queue()
10361013
results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue()
10371014

10381015
# Patch trainer _prepare_inputs() to pull from queue
10391016
def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]:
1040-
async def get_inputs() -> TrainInputs | _StopTrainInputs:
1017+
async def get_inputs() -> _TrainLoopInput:
10411018
return await inputs_queue.get()
10421019

10431020
# Force otherwise synchronous _prepare_inputs() to yield

src/art/unsloth/train.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,9 @@ async def train(
7878
if not is_train_dict:
7979
trainer._metrics = {"train": defaultdict(list)}
8080
try:
81-
try:
82-
trainer.train()
83-
except StopTrainingLoop:
84-
return
81+
trainer.train()
82+
except StopTrainingLoop:
83+
return
8584
finally:
8685
trainer.compute_loss = _compute_loss
8786
trainer.log = _log # ty:ignore[invalid-assignment]

0 commit comments

Comments
 (0)