File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -65,8 +65,15 @@ def loss_fn(
6565 prob_ratio = (prob_ratio + sequence_prob_ratio ) / 2
6666 elif importance_sampling_level == "geometric_average" :
6767 prob_ratio = (prob_ratio ** 0.5 ) * (sequence_prob_ratio ** 0.5 )
68- epsilon = experimental_config .get ("epsilon" , 0.2 )
69- epsilon_high = experimental_config .get ("epsilon_high" , epsilon )
68+ ppo = experimental_config .get ("ppo" , False )
69+ if ppo :
70+ epsilon_default = 0.2
71+ epsilon_high_default = None
72+ else :
73+ epsilon_default = 1.0
74+ epsilon_high_default = 4.0
75+ epsilon = experimental_config .get ("epsilon" , epsilon_default )
76+ epsilon_high = experimental_config .get ("epsilon_high" , epsilon_high_default )
7077 if epsilon_high is None :
7178 epsilon_high = epsilon
7279 if max_negative_advantage_importance_sampling_weight := experimental_config .get (
@@ -83,7 +90,7 @@ def loss_fn(
8390 )
8491 if tau := experimental_config .get ("kimi_k2_tau" , None ):
8592 advantages -= tau * logprob_diff .detach ()
86- if experimental_config . get ( " ppo" , True ) :
93+ if ppo :
8794 policy_loss = - torch .min (
8895 prob_ratio * advantages ,
8996 torch .clip (prob_ratio , 1 - epsilon , 1 + epsilon_high ) * advantages ,
You can’t perform that action at this time.
0 commit comments