Skip to content

Commit d69345e

Browse files
arcticflyclaude and Claude Opus 4.6 authored
Remove beta KL divergence from training loss (#607)
Remove the Schulman KL estimator (beta * KL) that was added directly to the training loss. The kl_penalty_coef mechanism (advantage adjustment) remains as the preferred approach for KL regularization.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ec1c174 commit d69345e

7 files changed

Lines changed: 6 additions & 25 deletions

File tree

src/art/local/backend.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,6 @@ async def train( # type: ignore[override]
427427
*,
428428
# Core training parameters
429429
learning_rate: float = 5e-6,
430-
beta: float = 0.0,
431430
# KL-penalized advantage adjustment
432431
kl_penalty_coef: float = 0.0,
433432
kl_penalty_reference_step: int | None = None,
@@ -470,7 +469,6 @@ async def train( # type: ignore[override]
470469
model: The trainable model to train.
471470
trajectory_groups: Batches of trajectories to train on.
472471
learning_rate: Learning rate for training. Defaults to 5e-6.
473-
beta: KL penalty coefficient added to the loss. Defaults to 0.0.
474472
kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
475473
Tokens diverging more from the reference get reduced advantages.
476474
Defaults to 0.0 (disabled).
@@ -527,7 +525,7 @@ async def train( # type: ignore[override]
527525

528526
# Build config objects from explicit kwargs
529527
config = TrainConfig(
530-
learning_rate=learning_rate, beta=beta, kl_penalty_coef=kl_penalty_coef
528+
learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
531529
)
532530
dev_config: dev.TrainConfig = {
533531
"advantage_balance": advantage_balance,

src/art/loss.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
class Loss(BaseModel):
1515
model_config = ConfigDict(arbitrary_types_allowed=True)
1616
mean_policy_loss: torch.Tensor
17-
mean_kl: torch.Tensor
1817
mean_entropy: torch.Tensor | None
1918
policy_loss_sum: torch.Tensor
2019
probs_corr: torch.Tensor
@@ -124,16 +123,8 @@ def loss_fn(
124123
logprob_diff = old_logprobs - original_logprobs
125124
prob_ratio = torch.exp(logprob_diff)
126125
policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach()
127-
if ref_logprobs is not None:
128-
kl_div = (
129-
torch.exp(ref_logprobs - new_logprobs) - (ref_logprobs - new_logprobs) - 1.0
130-
)
131-
else:
132-
kl_div = torch.zeros_like(policy_loss)
133126
policy_loss = policy_loss * weights * assistant_mask
134-
kl_div = kl_div * weights * assistant_mask
135127
mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
136-
mean_kl = kl_div.sum() / (assistant_mask.sum() + 1e-6)
137128
# Compute mean entropy for the current step
138129
if entropies is not None:
139130
shifted_entropies = shift_tensor(entropies, 0.0)
@@ -144,7 +135,6 @@ def loss_fn(
144135
mean_entropy = None
145136
return Loss(
146137
mean_policy_loss=mean_policy_loss,
147-
mean_kl=mean_kl,
148138
mean_entropy=mean_entropy,
149139
policy_loss_sum=policy_loss.sum(),
150140
probs_corr=probs_corr,

src/art/megatron/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def print0(*values: Any) -> None:
250250
)
251251
probs_corr = loss.probs_corr.item()
252252
print0("Correlation between old and new probabilities:", probs_corr)
253-
loss = loss.mean_policy_loss + config.beta * loss.mean_kl
253+
loss = loss.mean_policy_loss
254254
loss.backward()
255255
# Reduce LoRA grads
256256
start = time.perf_counter()

src/art/preprocessing/inputs.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ def create_train_inputs(
4141
[None] if warmup else packed_tensors["image_grid_thw"][offset : offset + 1]
4242
),
4343
config=(
44-
config.model_copy(
45-
update={"learning_rate": 1e-9, "beta": 0.0, "kl_penalty_coef": 0.0}
46-
)
44+
config.model_copy(update={"learning_rate": 1e-9, "kl_penalty_coef": 0.0})
4745
if warmup
4846
else config
4947
),

src/art/serverless/backend.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ async def train( # type: ignore[override]
149149
*,
150150
# Core training parameters
151151
learning_rate: float = 5e-6,
152-
beta: float = 0.0,
153152
# RL algorithm settings
154153
ppo: bool = False,
155154
epsilon: float | None = None,
@@ -179,7 +178,6 @@ async def train( # type: ignore[override]
179178
model: The trainable model to train.
180179
trajectory_groups: Batches of trajectories to train on.
181180
learning_rate: Learning rate for training. Defaults to 5e-6.
182-
beta: KL penalty coefficient. Defaults to 0.0.
183181
ppo: Whether to use PPO clipping. Defaults to False.
184182
epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
185183
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
@@ -212,7 +210,7 @@ async def train( # type: ignore[override]
212210
groups_list = list(trajectory_groups)
213211

214212
# Build config objects from explicit kwargs
215-
config = TrainConfig(learning_rate=learning_rate, beta=beta)
213+
config = TrainConfig(learning_rate=learning_rate)
216214
dev_config: dev.TrainConfig = {
217215
"advantage_balance": advantage_balance,
218216
"importance_sampling_level": importance_sampling_level,

src/art/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
class TrainConfig(pydantic.BaseModel):
1818
learning_rate: float = 5e-6
19-
beta: float = 0.0
2019
kl_penalty_coef: float = 0.0
2120

2221

src/art/unsloth/train.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def compute_loss(
138138
)
139139
if return_new_logprobs:
140140
return torch.nn.functional.pad(new_logprobs[:, :-1], (1, 0), value=0.0)
141-
if config.beta > 0.0 or config.kl_penalty_coef > 0.0:
141+
if config.kl_penalty_coef > 0.0:
142142
ref_adapter = _config.get("kl_ref_adapter_path")
143143
ref_logprobs, _ = calculate_logprobs(
144144
dtype_for_autocasting,
@@ -173,11 +173,9 @@ def compute_loss(
173173
trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item())
174174
if loss.mean_entropy is not None:
175175
trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())
176-
if config.beta > 0.0:
177-
trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item())
178176
if loss.kl_policy_ref is not None:
179177
trainer._metrics["train"]["kl_policy_ref"].append(loss.kl_policy_ref.item())
180-
return loss.mean_policy_loss + config.beta * loss.mean_kl
178+
return loss.mean_policy_loss
181179

182180
return compute_loss
183181

0 commit comments

Comments (0)