-
Notifications
You must be signed in to change notification settings - Fork 804
Expand file tree
/
Copy pathtrain.py
More file actions
36 lines (31 loc) · 1.3 KB
/
train.py
File metadata and controls
36 lines (31 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import Literal
from typing_extensions import TypedDict
class TrainConfig(TypedDict, total=False):
advantage_balance: float
"""Balance between negative and positive advantages in the range [-1.0, 1.0]. \
-1.0 means only training on negative advantages, 1.0 means only training on \
positive advantages. Defaults to 0.0 (perfectly balanced)."""
allow_training_without_logprobs: bool
epsilon: float # clip epsilon, using the same name as TRL
epsilon_high: (
float | None
) # asymmetric clip upper bound. Defaults to epsilon when None
importance_sampling_level: Literal[
"token", "sequence", "average", "geometric_average"
]
kimi_k2_tau: float | None
kl_penalty_coef: float
kl_penalty_source: Literal["current_learner", "sample"]
kl_ref_adapter_path: str | None
logprob_calculation_chunk_size: int
mask_prob_ratio: bool
max_negative_advantage_importance_sampling_weight: float
num_trajectories_learning_rate_multiplier_power: float
plot_tensors: bool
ppo: bool
precalculate_logprobs: bool
scale_learning_rate_by_reward_std_dev: bool
scale_rewards: bool
truncated_importance_sampling: float | None
class TrainSFTConfig(TypedDict, total=False):
    """Experimental SFT configuration options.

    Use at your own risk. The class is declared with ``total=False`` and
    currently defines no keys of its own.
    """