Skip to content

Commit dbefea6

Browse files
committed
refactor: Narrow LocalBackend train types
1 parent 47579de commit dbefea6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/art/local/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ async def train( # type: ignore[override]
523 523   *,
524 524   # Core training parameters
525 525   learning_rate: float = 5e-6,
526     - loss_fn: Literal["cispo", "ppo", "importance_sampling", "dro"] = "cispo",
    526 + loss_fn: Literal["cispo", "ppo"] = "cispo",
527 527   loss_fn_config: dict | None = None,
528 528   normalize_advantages: bool = True,
529 529   adam_params: object | None = None,
@@ -584,7 +584,7 @@ async def train( # type: ignore[override]
584 584   kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
585 585   checkpoint to use as the KL reference. Alternative to
586 586   kl_penalty_reference_step.
587     - epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
    587 + epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
588 588   epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
589 589   advantage_balance: Balance between negative and positive advantages
590 590   in range [-1.0, 1.0]. Defaults to 0.0 (balanced).

0 commit comments

Comments
 (0)