Skip to content

Commit 2daa845

Browse files
authored
feat: Default loss to CISPO
1 parent b1da539 commit 2daa845

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

src/art/loss.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,15 @@ def loss_fn(
6565
prob_ratio = (prob_ratio + sequence_prob_ratio) / 2
6666
elif importance_sampling_level == "geometric_average":
6767
prob_ratio = (prob_ratio**0.5) * (sequence_prob_ratio**0.5)
68-
epsilon = experimental_config.get("epsilon", 0.2)
69-
epsilon_high = experimental_config.get("epsilon_high", epsilon)
68+
ppo = experimental_config.get("ppo", False)
69+
if ppo:
70+
epsilon_default = 0.2
71+
epsilon_high_default = None
72+
else:
73+
epsilon_default = 1.0
74+
epsilon_high_default = 4.0
75+
epsilon = experimental_config.get("epsilon", epsilon_default)
76+
epsilon_high = experimental_config.get("epsilon_high", epsilon_high_default)
7077
if epsilon_high is None:
7178
epsilon_high = epsilon
7279
if max_negative_advantage_importance_sampling_weight := experimental_config.get(
@@ -83,7 +90,7 @@ def loss_fn(
8390
)
8491
if tau := experimental_config.get("kimi_k2_tau", None):
8592
advantages -= tau * logprob_diff.detach()
86-
if experimental_config.get("ppo", True):
93+
if ppo:
8794
policy_loss = -torch.min(
8895
prob_ratio * advantages,
8996
torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,

0 commit comments

Comments
 (0)