@@ -168,25 +168,6 @@ def __exit__(
168168 exc : BaseException | None ,
169169 tb : TracebackType | None ,
170170 ) -> None :
171- try :
172- asyncio .get_running_loop ()
173- except RuntimeError :
174- running_loop = False
175- else :
176- running_loop = True
177-
178- if running_loop or any (
179- getattr (service , "aclose" , None ) is not None
180- for service in self ._services .values ()
181- ):
182- warnings .warn (
183- "LocalBackend used as a sync context manager. Cleanup uses the "
184- "best-effort sync shutdown path and cannot await service "
185- "teardown safely here; use `async with LocalBackend(...)` or "
186- "`await backend.close()` instead." ,
187- RuntimeWarning ,
188- stacklevel = 2 ,
189- )
190171 self ._close ()
191172
192173 async def __aexit__ (
@@ -201,20 +182,18 @@ async def close(self) -> None:
201182 """
202183 If running vLLM in a separate process, this will kill that process and close the communication threads.
203184 """
204- for _ , service in self ._services .items ():
205- # Keep this logic aligned with _close(), but avoid double-closing
206- # services that expose an awaited aclose() path.
185+ for service in self ._services .values ():
207186 aclose = getattr (service , "aclose" , None )
208- if aclose is not None :
209- await aclose ()
210- else :
187+ if aclose is None :
211188 close = getattr (service , "close" , None )
212189 if close is not None :
213190 close ()
191+ else :
192+ await aclose ()
214193 close_proxy (service )
215194
216195 def _close (self ) -> None :
217- for _ , service in self ._services .items ():
196+ for service in self ._services .values ():
218197 close = getattr (service , "close" , None )
219198 if close is not None :
220199 close ()
@@ -259,35 +238,27 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
259238 If None, returns name for latest checkpoint (step 0 initially).
260239 """
261240
262- def _served_step () -> int | None :
263- if not isinstance (model , TrainableModel ):
264- return None
265- if model .name not in self ._services :
266- return None
241+ requested_step = step
242+
243+ if step is None and isinstance (model , TrainableModel ):
267244 from ..dev .validate import is_dedicated_mode
268245
269- if not is_dedicated_mode (
246+ service = self ._services .get (model .name )
247+ if service is not None and is_dedicated_mode (
270248 model ._internal_config or dev .InternalModelConfig ()
271249 ):
272- return None
273- loaded_step = getattr (self ._services [model .name ], "_latest_step" , None )
274- return loaded_step if isinstance (loaded_step , int ) else None
275-
276- # For LocalBackend, vLLM always serves LoRA adapters with @step suffix
277- # Default to step 0 when not specified (the initial checkpoint created at registration)
278- if step is not None :
279- actual_step = step
280- else :
281- # In dedicated mode the service tracks which adapter vLLM has
282- # actually loaded. Reading the filesystem would race: the checkpoint
283- # directory appears before the HTTP reload completes.
284- actual_step = _served_step ()
285- if actual_step is None :
286- actual_step = self .__get_step (model )
287- name = f"{ model .name } @{ actual_step } "
250+ loaded_step = getattr (service , "_latest_step" , None )
251+ if isinstance (loaded_step , int ):
252+ step = loaded_step
253+
254+ if step is None :
255+ # The checkpoint directory is written before dedicated-mode
256+ # vLLM finishes reloading the new adapter.
257+ step = self .__get_step (model )
258+ name = f"{ model .name } @{ step } "
288259 logger .debug (
289- f"[BACKEND] _model_inference_name: step_arg={ step } "
290- f"actual_step={ actual_step } -> { name } "
260+ f"[BACKEND] _model_inference_name: step_arg={ requested_step } "
261+ f"actual_step={ step } -> { name } "
291262 )
292263 return name
293264
@@ -552,12 +523,14 @@ async def train( # type: ignore[override]
552523 * ,
553524 # Core training parameters
554525 learning_rate : float = 5e-6 ,
526+ loss_fn : Literal ["cispo" , "ppo" , "importance_sampling" , "dro" ] = "cispo" ,
527+ loss_fn_config : dict | None = None ,
528+ normalize_advantages : bool = True ,
529+ adam_params : object | None = None ,
555530 # KL-penalized advantage adjustment
556531 kl_penalty_coef : float = 0.0 ,
557532 kl_penalty_reference_step : int | None = None ,
558533 kl_ref_adapter_path : str | None = None ,
559- # RL algorithm settings
560- ppo : bool = False ,
561534 epsilon : float | None = None ,
562535 epsilon_high : float | None = None ,
563536 # Advantage computation
@@ -594,6 +567,14 @@ async def train( # type: ignore[override]
594567 model: The trainable model to train.
595568 trajectory_groups: Batches of trajectories to train on.
596569 learning_rate: Learning rate for training. Defaults to 5e-6.
570+ loss_fn: RL loss function. LocalBackend currently supports
571+ "cispo" and "ppo".
572+ loss_fn_config: Additional loss-function config. Not supported by
573+ LocalBackend.
574+ normalize_advantages: Whether to normalize advantages. LocalBackend
575+ currently requires True.
576+ adam_params: Custom optimizer params. Not supported by
577+ LocalBackend.
597578 kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
598579 Tokens diverging more from the reference get reduced advantages.
599580 Defaults to 0.0 (disabled).
@@ -603,7 +584,6 @@ async def train( # type: ignore[override]
603584 kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
604585 checkpoint to use as the KL reference. Alternative to
605586 kl_penalty_reference_step.
606- ppo: Whether to use PPO clipping. Defaults to False.
607587 epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
608588 epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
609589 advantage_balance: Balance between negative and positive advantages
@@ -647,6 +627,14 @@ async def train( # type: ignore[override]
647627 # await model.log(metrics=result.metrics, step=result.step)
648628 """
649629 groups_list = list (trajectory_groups )
630+ if loss_fn not in {"cispo" , "ppo" }:
631+ raise ValueError ("LocalBackend only supports loss_fn='cispo' or 'ppo'." )
632+ if loss_fn_config is not None :
633+ raise ValueError ("LocalBackend requires loss_fn_config=None." )
634+ if not normalize_advantages :
635+ raise ValueError ("LocalBackend requires normalize_advantages=True." )
636+ if adam_params is not None :
637+ raise ValueError ("LocalBackend requires adam_params=None." )
650638
651639 # Build config objects from explicit kwargs
652640 config = TrainConfig (
@@ -659,7 +647,7 @@ async def train( # type: ignore[override]
659647 "kl_penalty_coef" : kl_penalty_coef ,
660648 "mask_prob_ratio" : mask_prob_ratio ,
661649 "plot_tensors" : plot_tensors ,
662- "ppo" : ppo ,
650+ "ppo" : loss_fn == " ppo" ,
663651 "precalculate_logprobs" : precalculate_logprobs ,
664652 "scale_learning_rate_by_reward_std_dev" : scale_learning_rate_by_reward_std_dev ,
665653 "scale_rewards" : scale_rewards ,
0 commit comments