
Commit 078990e

arcticflyclaude and Claude Opus 4.6 authored
Add KL-penalized advantage adjustment (#562)
* Add KL-penalized advantage adjustment

  Introduces a new mechanism that adjusts per-token advantages based on KL
  divergence from a reference model. Tokens where the policy has drifted more
  get reduced advantages, while tokens that drifted less get increased
  advantages. The adjustment is zero-mean (centered) across tokens.

  New parameters on LocalBackend.train():

  - kl_penalty_coef: coefficient for the adjustment (0.0 = disabled)
  - kl_penalty_reference_step: use a specific checkpoint step as reference
  - kl_ref_adapter_path: use an arbitrary LoRA adapter path as reference

  Also fixes a pre-existing bug in preprocessing/inputs.py where the warmup
  config used incorrect field names (lr → learning_rate, kl_coef → kl_penalty_coef).

* Fix import sorting and formatting

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 524feb0 commit 078990e

9 files changed

Lines changed: 423 additions & 14 deletions
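
For a concrete sense of the mechanism described above, here is a minimal numeric sketch of the zero-mean adjustment (illustration only, not code from this commit; the actual implementation in src/art/loss.py operates on masked torch tensors):

# With coefficient c and per-token KL estimates k_i, each advantage shifts by
# c * (mean(k) - k_i), so the shifts cancel across tokens.
kl = [0.4, 0.1, 0.1]                     # per-token KL vs. the reference
coef = 0.1                               # kl_penalty_coef
avg = sum(kl) / len(kl)                  # 0.2
adjust = [coef * (avg - k) for k in kl]  # [-0.02, 0.01, 0.01] -> sums to 0.0
# The high-drift token is penalized, the low-drift tokens are boosted, and the
# mean advantage is unchanged.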

dev/run_yes_no_maybe_kl_advantage.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
"""Launch yes-no-maybe-kl-advantage training on SkyPilot (Kubernetes).

Usage:
    uv run dev/run_yes_no_maybe_kl_advantage.py
    uv run dev/run_yes_no_maybe_kl_advantage.py --fast
    uv run dev/run_yes_no_maybe_kl_advantage.py --base-model Qwen/Qwen2.5-7B-Instruct
"""

import argparse
import os
import textwrap

from dotenv import load_dotenv
import sky
from sky import ClusterStatus

load_dotenv()

parser = argparse.ArgumentParser(
    description="Launch yes-no-maybe KL advantage training on SkyPilot."
)
parser.add_argument(
    "--fast", action="store_true", help="Skip setup (for re-runs on existing cluster)."
)
parser.add_argument(
    "--base-model", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
)
parser.add_argument("--num-steps", type=int, default=20)
parser.add_argument("--kl-penalty-coef", type=float, default=0.1)
parser.add_argument("--accelerator", type=str, default="H200:1")
parser.add_argument("--cluster-name", type=str, default=None)
parser.add_argument(
    "--kl-ref-step",
    type=int,
    default=None,
    help="Checkpoint step of training model to use as KL reference",
)
parser.add_argument(
    "--kl-ref-adapter-path",
    type=str,
    default=None,
    help="Path to LoRA adapter checkpoint to use as KL reference",
)
args = parser.parse_args()

cluster_name = args.cluster_name or f"ynm-kl-{args.kl_penalty_coef}"
cluster_prefix = os.environ.get("CLUSTER_PREFIX")
if cluster_prefix:
    cluster_name = f"{cluster_prefix}-{cluster_name}"

setup_script = textwrap.dedent("""\
    echo 'Setting up environment...'
    apt install -y nvtop
    curl -LsSf https://astral.sh/uv/install.sh | sh
    source $HOME/.local/bin/env
    """)

kl_ref_env = ""
if args.kl_ref_step is not None:
    kl_ref_env = f"KL_REF_STEP={args.kl_ref_step} "
elif args.kl_ref_adapter_path is not None:
    kl_ref_env = f"KL_REF_ADAPTER_PATH={args.kl_ref_adapter_path} "

run_script = textwrap.dedent(f"""\
    source $HOME/.local/bin/env
    cd ~/sky_workdir
    {kl_ref_env}BASE_MODEL={args.base_model} NUM_STEPS={args.num_steps} KL_PENALTY_COEF={args.kl_penalty_coef} uv run --python 3.11 --extra backend dev/yes-no-maybe-kl-advantage.py
    """)

task = sky.Task(
    name="yes-no-maybe-kl-advantage",
    setup=setup_script,
    run=run_script,
    workdir=".",
)
task.set_resources(
    sky.Resources(accelerators=args.accelerator, cloud=sky.clouds.Kubernetes())
)
task.set_file_mounts(
    {
        "~/sky_workdir/.env": ".env",
    }
)

print(f"Launching on cluster: {cluster_name}")
print(f"  base_model: {args.base_model}")
print(f"  accelerator: {args.accelerator}")
print(f"  num_steps: {args.num_steps}")
print(f"  kl_penalty_coef: {args.kl_penalty_coef}")
if args.kl_ref_step is not None:
    print(f"  kl_ref_step: {args.kl_ref_step}")
if args.kl_ref_adapter_path is not None:
    print(f"  kl_ref_adapter_path: {args.kl_ref_adapter_path}")

# Cancel any existing jobs on this cluster
cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name]))
if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP:
    print(f"Cluster {cluster_name} is UP. Canceling any active jobs...")
    sky.stream_and_get(sky.cancel(cluster_name, all=True))

job_id, _ = sky.stream_and_get(
    sky.launch(
        task,
        cluster_name=cluster_name,
        retry_until_up=True,
        idle_minutes_to_autostop=60,
        down=True,
        fast=args.fast,
    )
)

print(f"Job submitted (ID: {job_id}). Streaming logs...")
exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True)
print(f"Job {job_id} finished with exit code {exit_code}.")

dev/yes-no-maybe-kl-advantage.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
"""Yes-no-maybe training with KL-penalized advantage adjustment.

Demonstrates the kl_penalty_coef feature: tokens where the policy has drifted
more from the reference model get reduced advantages, while tokens that have
drifted less get increased advantages.

Uses meta-llama/Meta-Llama-3.1-8B-Instruct as the base model (trained locally).
"""

import asyncio
from itertools import permutations
import os

from dotenv import load_dotenv
import openai

import art
from art.local import LocalBackend


async def rollout(
    client: openai.AsyncOpenAI, model: art.TrainableModel, prompt: str
) -> art.Trajectory:
    messages: art.Messages = [
        {
            "role": "user",
            "content": prompt,
        }
    ]
    chat_completion = await client.chat.completions.create(
        messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100
    )
    choice = chat_completion.choices[0]
    content = choice.message.content
    assert isinstance(content, str)
    if content == "yes":
        reward = 0.5
    elif content == "no":
        reward = 0.75
    elif content == "maybe":
        reward = 1.0
    else:
        reward = 0.0
    return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)


def with_quotes(w: str) -> str:
    return f"'{w}'"


async def main():
    load_dotenv()

    backend = LocalBackend()
    base_model = os.environ.get("BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
    kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1"))
    model = art.TrainableModel(
        name=os.environ.get("MODEL_NAME", f"kl-{kl_penalty_coef}"),
        project="yes-no-maybe",
        base_model=base_model,
    )
    await model.register(backend)

    kl_penalty_reference_step: int | None = (
        int(os.environ["KL_REF_STEP"])
        if os.environ.get("KL_REF_STEP") is not None
        else None
    )
    kl_ref_adapter_path: str | None = os.environ.get("KL_REF_ADAPTER_PATH") or None

    prompts = [
        f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}"
        for prefix in ["respond", "just respond"]
        for use_quotes in [True, False]
        for words in (
            list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n)
        )
    ]

    openai_client = model.openai_client()
    max_steps = int(os.environ.get("NUM_STEPS", "20"))
    start_step = await model.get_step()
    for step in range(start_step, start_step + max_steps):
        train_groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(
                    rollout(openai_client, model, prompt) for _ in range(32)
                )
                for prompt in prompts
            )
        )
        result = await backend.train(
            model,
            train_groups,
            learning_rate=1e-4,
            kl_penalty_coef=kl_penalty_coef,
            kl_penalty_reference_step=kl_penalty_reference_step,
            kl_ref_adapter_path=kl_ref_adapter_path,
        )
        await model.log(
            train_groups,
            metrics=result.metrics,
            step=result.step,
            split="train",
        )
        print(f"step {result.step}: {result.metrics}")


if __name__ == "__main__":
    asyncio.run(main())
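
For reference, a few of the 48 strings the prompts comprehension above expands to (reconstructed by hand for illustration; the two-word branch never applies quotes, so those entries repeat across use_quotes):

# "respond with 'yes', 'no', 'maybe'"   three words, use_quotes=True
# "respond with yes, no, maybe"         three words, use_quotes=False
# "just respond with no or maybe"       two words; quotes are ignored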

src/art/dev/train.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@ class TrainConfig(TypedDict, total=False):
         "token", "sequence", "average", "geometric_average"
     ]
     kimi_k2_tau: float | None
+    kl_penalty_coef: float
+    kl_ref_adapter_path: str | None
     logprob_calculation_chunk_size: int
     mask_prob_ratio: bool
     max_negative_advantage_importance_sampling_weight: float

src/art/local/backend.py

Lines changed: 26 additions & 2 deletions
@@ -391,6 +391,10 @@ async def train( # type: ignore[override]
         # Core training parameters
         learning_rate: float = 5e-6,
         beta: float = 0.0,
+        # KL-penalized advantage adjustment
+        kl_penalty_coef: float = 0.0,
+        kl_penalty_reference_step: int | None = None,
+        kl_ref_adapter_path: str | None = None,
         # RL algorithm settings
         ppo: bool = False,
         epsilon: float | None = None,
@@ -429,7 +433,16 @@ async def train( # type: ignore[override]
             model: The trainable model to train.
             trajectory_groups: Batches of trajectories to train on.
             learning_rate: Learning rate for training. Defaults to 5e-6.
-            beta: KL penalty coefficient. Defaults to 0.0.
+            beta: KL penalty coefficient added to the loss. Defaults to 0.0.
+            kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
+                Tokens diverging more from the reference get reduced advantages.
+                Defaults to 0.0 (disabled).
+            kl_penalty_reference_step: Checkpoint step of the training model to
+                use as the KL reference. If None, uses the base model (LoRA
+                disabled) as reference.
+            kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
+                checkpoint to use as the KL reference. Alternative to
+                kl_penalty_reference_step.
             ppo: Whether to use PPO clipping. Defaults to False.
             epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
             epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
@@ -476,11 +489,14 @@ async def train( # type: ignore[override]
         groups_list = list(trajectory_groups)

         # Build config objects from explicit kwargs
-        config = TrainConfig(learning_rate=learning_rate, beta=beta)
+        config = TrainConfig(
+            learning_rate=learning_rate, beta=beta, kl_penalty_coef=kl_penalty_coef
+        )
         dev_config: dev.TrainConfig = {
             "advantage_balance": advantage_balance,
             "allow_training_without_logprobs": allow_training_without_logprobs,
             "importance_sampling_level": importance_sampling_level,
+            "kl_penalty_coef": kl_penalty_coef,
             "mask_prob_ratio": mask_prob_ratio,
             "plot_tensors": plot_tensors,
             "ppo": ppo,
@@ -503,6 +519,14 @@ async def train( # type: ignore[override]
             dev_config["kimi_k2_tau"] = kimi_k2_tau
         if truncated_importance_sampling is not None:
             dev_config["truncated_importance_sampling"] = truncated_importance_sampling
+        if kl_ref_adapter_path is not None:
+            dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
+        elif kl_penalty_reference_step is not None:
+            ref_checkpoint_dir = get_step_checkpoint_dir(
+                get_model_dir(model=model, art_path=self._path),
+                kl_penalty_reference_step,
+            )
+            dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir

         # Collect metrics from training
         training_metrics: list[dict[str, float]] = []
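
Taken together, a call using the new parameters would look roughly like this (hypothetical usage mirroring the demo script above; per the dispatch in the last hunk, kl_ref_adapter_path takes precedence over kl_penalty_reference_step, and with neither set the base model itself serves as the reference):

result = await backend.train(
    model,
    train_groups,
    learning_rate=1e-4,
    kl_penalty_coef=0.1,          # 0.0 leaves advantages unadjusted
    kl_penalty_reference_step=5,  # or kl_ref_adapter_path="/path/to/adapter"
)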

src/art/loss.py

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,7 @@ class Loss(BaseModel):
     mean_entropy: torch.Tensor | None
     policy_loss_sum: torch.Tensor
     probs_corr: torch.Tensor
+    kl_policy_ref: torch.Tensor | None = None


 def loss_fn(
@@ -92,6 +93,14 @@ def loss_fn(
         )
     if tau := experimental_config.get("kimi_k2_tau", None):
         advantages -= tau * logprob_diff.detach()
+    kl_policy_ref: torch.Tensor | None = None
+    kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0)
+    if kl_penalty_coef > 0 and ref_logprobs is not None:
+        kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask
+        avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
+        kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
+        advantages = advantages + kl_penalty
+        kl_policy_ref = avg_kl
     if ppo:
         policy_loss = -torch.min(
             prob_ratio * advantages,
@@ -139,6 +148,7 @@ def loss_fn(
         mean_entropy=mean_entropy,
         policy_loss_sum=policy_loss.sum(),
         probs_corr=probs_corr,
+        kl_policy_ref=kl_policy_ref,
     )
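
A standalone sketch of the hunk above on toy tensors (assumed [batch, seq] shapes; not the library's code path). Note that (new_logprobs - ref_logprobs) is a sampled per-token estimate of the policy-reference KL, detached so the penalty shifts advantages without adding gradient terms, and that masked tokens receive no adjustment:

import torch

new_logprobs = torch.tensor([[-1.0, -2.0, -0.5]])
ref_logprobs = torch.tensor([[-1.5, -2.0, -1.5]])
assistant_mask = torch.tensor([[1.0, 1.0, 0.0]])  # third token untrained
advantages = torch.tensor([[1.0, 1.0, 1.0]])
kl_penalty_coef = 0.1

kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask  # [[0.5, 0.0, 0.0]]
avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)             # ~0.25
kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask # [[-0.025, 0.025, 0.0]]
advantages = advantages + kl_penalty                                    # [[0.975, 1.025, 1.0]]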

src/art/preprocessing/inputs.py

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,9 @@ def create_train_inputs(
             [None] if warmup else packed_tensors["image_grid_thw"][offset : offset + 1]
         ),
         config=(
-            config.model_copy(update={"lr": 1e-9, "beta": 0.0, "kl_coef": 0.0})
+            config.model_copy(
+                update={"learning_rate": 1e-9, "beta": 0.0, "kl_penalty_coef": 0.0}
+            )
             if warmup
             else config
         ),
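
The fix matters because TrainConfig defines learning_rate and kl_penalty_coef, not lr or kl_coef, and pydantic's model_copy(update=...) does not validate keys, so the old spelling presumably left the real learning rate untouched during warmup. A hypothetical illustration (assuming TrainConfig is the pydantic model constructed in backend.py):

cfg = TrainConfig(learning_rate=5e-6, beta=0.0)
bad = cfg.model_copy(update={"lr": 1e-9})             # unknown key is attached, but
assert bad.learning_rate == 5e-6                      # ...learning_rate is unchanged
good = cfg.model_copy(update={"learning_rate": 1e-9})
assert good.learning_rate == 1e-9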
