
Commit 9775490

feat: Add sampled KL support to pipeline backends
1 parent e748071 · commit 9775490

File tree

9 files changed: +364, -3 lines

src/art/dev/train.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ class TrainConfig(TypedDict, total=False):
     ]
     kimi_k2_tau: float | None
     kl_penalty_coef: float
+    kl_penalty_source: Literal["current_learner", "sample"]
     kl_ref_adapter_path: str | None
     logprob_calculation_chunk_size: int
     mask_prob_ratio: bool

src/art/local/backend.py

Lines changed: 11 additions & 1 deletion
@@ -531,6 +531,7 @@ async def train( # type: ignore[override]
         kl_penalty_coef: float = 0.0,
         kl_penalty_reference_step: int | None = None,
         kl_ref_adapter_path: str | None = None,
+        kl_penalty_source: Literal["current_learner", "sample"] = "current_learner",
         epsilon: float | None = None,
         epsilon_high: float | None = None,
         # Advantage computation
# Advantage computation
@@ -584,6 +585,11 @@ async def train( # type: ignore[override]
             kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
                 checkpoint to use as the KL reference. Alternative to
                 kl_penalty_reference_step.
+            kl_penalty_source: Which policy's logprobs to compare against the
+                reference when building the centered KL penalty. Use
+                "current_learner" to match the original ART implementation, or
+                "sample" to shape from the rollout policy logprobs, which is
+                usually better for async/off-policy workloads.
             epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
             epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
             advantage_balance: Balance between negative and positive advantages
@@ -635,16 +641,20 @@ async def train( # type: ignore[override]
             raise ValueError("LocalBackend requires normalize_advantages=True.")
         if adam_params is not None:
             raise ValueError("LocalBackend requires adam_params=None.")
+        assert kl_penalty_source in {"current_learner", "sample"}

         # Build config objects from explicit kwargs
         config = TrainConfig(
-            learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
+            learning_rate=learning_rate,
+            kl_penalty_coef=kl_penalty_coef,
+            kl_penalty_source=kl_penalty_source,
         )
         dev_config: dev.TrainConfig = {
             "advantage_balance": advantage_balance,
             "allow_training_without_logprobs": allow_training_without_logprobs,
             "importance_sampling_level": importance_sampling_level,
             "kl_penalty_coef": kl_penalty_coef,
+            "kl_penalty_source": kl_penalty_source,
             "mask_prob_ratio": mask_prob_ratio,
             "plot_tensors": plot_tensors,
             "ppo": loss_fn == "ppo",

src/art/loss.py

Lines changed: 8 additions & 1 deletion
@@ -95,7 +95,14 @@ def loss_fn(
     kl_policy_ref: torch.Tensor | None = None
     kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0)
     if kl_penalty_coef > 0 and ref_logprobs is not None:
-        kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask
+        match experimental_config.get("kl_penalty_source", "current_learner"):
+            case "sample":
+                kl_source_logprobs = old_logprobs.detach()
+            case "current_learner":
+                kl_source_logprobs = new_logprobs.detach()
+            case other:
+                raise AssertionError(other)
+        kl_per_token = (kl_source_logprobs - ref_logprobs).detach() * assistant_mask
         avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
         kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
         advantages = advantages + kl_penalty
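
The penalty is centered: avg_kl is subtracted back per token, so the penalty redistributes advantage across the masked span without changing its total. A minimal standalone sketch with made-up numbers (not from this commit) illustrating that property:

import torch

# Hypothetical per-token logprobs over a 4-token assistant span.
source_logprobs = torch.tensor([-0.2, -0.9, -0.5, -1.3])  # rollout or learner
ref_logprobs = torch.tensor([-0.5, -0.5, -0.5, -0.5])     # frozen reference
assistant_mask = torch.ones(4)
kl_penalty_coef = 0.5

kl_per_token = (source_logprobs - ref_logprobs) * assistant_mask
avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
# Centering makes the penalty zero-sum over the span: tokens that drifted
# further from the reference than average are penalized, the rest are boosted.
assert torch.isclose(kl_penalty.sum(), torch.tensor(0.0), atol=1e-5)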

src/art/pipeline_trainer/trainer.py

Lines changed: 13 additions & 0 deletions
@@ -78,6 +78,8 @@ def __init__(
         loss_fn_config: dict | None = None,
         normalize_advantages: bool = True,
         adam_params: object | None = None,
+        kl_penalty_coef: float = 0.0,
+        kl_penalty_reference_step: int | None = None,
         max_steps: int | None = None,
         # Discard handling
         discard_queue_multiplier: int = 100,
@@ -129,6 +131,8 @@ def __init__(
         self.loss_fn_config = loss_fn_config
         self.normalize_advantages = normalize_advantages
         self.adam_params = adam_params
+        self.kl_penalty_coef = kl_penalty_coef
+        self.kl_penalty_reference_step = kl_penalty_reference_step
         self.max_steps = max_steps
         self._status_log_interval_seconds = log_interval_seconds
         self.eval_every_n_steps = eval_every_n_steps
@@ -452,6 +456,14 @@ async def _training_stage(self) -> None:
         if os.getenv("ART_TRAIN_STEP_LOG"):
             print(f"[train] step {expected_step} starting (batch={len(batch)})")
         try:
+            kl_train_kwargs: dict[str, object] = {}
+            if self.kl_penalty_coef > 0.0:
+                kl_train_kwargs["kl_penalty_coef"] = self.kl_penalty_coef
+                kl_train_kwargs["kl_penalty_source"] = "sample"
+                if self.kl_penalty_reference_step is not None:
+                    kl_train_kwargs["kl_penalty_reference_step"] = (
+                        self.kl_penalty_reference_step
+                    )
             result = await self.backend.train(
                 self.model,
                 batch,
@@ -461,6 +473,7 @@ async def _training_stage(self) -> None:
                 normalize_advantages=self.normalize_advantages,
                 save_checkpoint=should_checkpoint,
                 adam_params=self.adam_params,
+                **kl_train_kwargs,
             )
         except Exception:
             self._status.note_training_end()
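
Illustrative wiring only (not from this commit): assuming the trainer is constructed with the model and backend first and the remaining arguments keep their defaults, enabling the penalty is just a matter of setting the coefficient; _training_stage then forwards kl_penalty_source="sample" automatically, which fits the pipeline's slightly stale rollouts.

# Hypothetical setup; `model` and `backend` are placeholders.
trainer = PipelineTrainer(
    model,
    backend,
    kl_penalty_coef=0.04,          # > 0 switches the sampled-KL penalty on
    kl_penalty_reference_step=0,   # optional: pin the reference checkpoint
)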

src/art/test/test_kl_advantage.py

Lines changed: 48 additions & 1 deletion
@@ -2,7 +2,7 @@

 import torch

-from art.loss import Loss, loss_fn
+from art.loss import loss_fn, shift_tensor


 def _make_inputs(
@@ -114,3 +114,50 @@ def test_kl_advantage_does_not_affect_when_no_ref():

     loss = loss_fn(inputs, new_logprobs, None, None, {"kl_penalty_coef": 0.5})
     assert loss.kl_policy_ref is None
+
+
+def test_kl_advantage_can_use_sample_logprobs() -> None:
+    """Sample-source KL should use stored rollout logprobs rather than learner logprobs."""
+    inputs = _make_inputs(seq_len=8)
+    inputs["logprobs"] = torch.tensor(
+        [[0.0, -0.2, -0.4, -0.6, -0.8, -1.0, -1.2, -1.4]], dtype=torch.float32
+    )
+    new_logprobs = torch.tensor(
+        [[0.0, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6]], dtype=torch.float32
+    )
+    ref_logprobs = torch.full((1, 8), -0.5)
+    assistant_mask = shift_tensor(inputs["assistant_mask"], False).to(
+        new_logprobs.dtype
+    )
+    sampled_logprobs = torch.where(
+        torch.isnan(shift_tensor(inputs["logprobs"], float("nan"))),
+        new_logprobs.detach(),
+        shift_tensor(inputs["logprobs"], float("nan")),
+    )
+    expected_sample_kl = ((sampled_logprobs - ref_logprobs) * assistant_mask).sum() / (
+        assistant_mask.sum() + 1e-6
+    )
+    expected_current_kl = ((new_logprobs - ref_logprobs) * assistant_mask).sum() / (
+        assistant_mask.sum() + 1e-6
+    )
+
+    sample_loss = loss_fn(
+        inputs,
+        new_logprobs,
+        ref_logprobs,
+        None,
+        {"kl_penalty_coef": 0.5, "kl_penalty_source": "sample"},
+    )
+    learner_loss = loss_fn(
+        inputs,
+        new_logprobs,
+        ref_logprobs,
+        None,
+        {"kl_penalty_coef": 0.5, "kl_penalty_source": "current_learner"},
+    )
+
+    assert sample_loss.kl_policy_ref is not None
+    assert learner_loss.kl_policy_ref is not None
+    assert torch.isclose(sample_loss.kl_policy_ref, expected_sample_kl)
+    assert torch.isclose(learner_loss.kl_policy_ref, expected_current_kl)
+    assert not torch.isclose(sample_loss.kl_policy_ref, learner_loss.kl_policy_ref)

src/art/tinker_native/backend.py

Lines changed: 108 additions & 0 deletions
@@ -24,6 +24,7 @@
 from openai.types.chat.completion_create_params import CompletionCreateParams
 from openai.types.completion_usage import CompletionUsage
 import tinker
+import torch
 import uvicorn

 from art.tinker.cookbook_v import renderers, tokenizer_utils
@@ -82,6 +83,76 @@ def _canonicalize_upstream_metric_key(metric: str) -> str:
     return _UPSTREAM_TRAIN_METRIC_KEYS.get(metric, metric)


+async def _apply_kl_penalty(
+    datums: list[tinker.Datum],
+    reference_sampling_client: tinker.SamplingClient,
+    kl_penalty_coef: float,
+) -> dict[str, float]:
+    assert datums
+    assert kl_penalty_coef > 0.0
+
+    full_sequences: list[tinker.ModelInput] = []
+    sampled_logprobs_by_datum: list[torch.Tensor] = []
+    masks_by_datum: list[torch.Tensor] = []
+    advantages_by_datum: list[torch.Tensor] = []
+    for datum in datums:
+        target_tokens = datum.loss_fn_inputs["target_tokens"].to_torch()
+        assert target_tokens.numel() > 0
+        full_sequences.append(
+            datum.model_input.append_int(int(target_tokens[-1].item()))
+        )
+        sampled_logprobs_by_datum.append(datum.loss_fn_inputs["logprobs"].to_torch())
+        masks_by_datum.append(datum.loss_fn_inputs["mask"].to_torch().float())
+        advantages_by_datum.append(datum.loss_fn_inputs["advantages"].to_torch())
+
+    reference_logprobs_by_datum = await asyncio.gather(
+        *[
+            reference_sampling_client.compute_logprobs_async(full_sequence)
+            for full_sequence in full_sequences
+        ]
+    )
+
+    logprob_diffs_by_datum: list[torch.Tensor] = []
+    for reference_logprobs, sampled_logprobs, mask in zip(
+        reference_logprobs_by_datum,
+        sampled_logprobs_by_datum,
+        masks_by_datum,
+        strict=True,
+    ):
+        reference_values = reference_logprobs[1:]
+        assert len(reference_values) == sampled_logprobs.numel()
+        assert all(value is not None for value in reference_values)
+        reference_logprobs_tensor = torch.tensor(
+            reference_values,
+            dtype=sampled_logprobs.dtype,
+        )
+        logprob_diffs_by_datum.append(
+            (sampled_logprobs - reference_logprobs_tensor) * mask
+        )
+
+    total_tokens = torch.stack([mask.sum() for mask in masks_by_datum]).sum()
+    assert total_tokens.item() > 0
+    avg_logprob_diff = (
+        torch.stack(
+            [logprob_diff.sum() for logprob_diff in logprob_diffs_by_datum]
+        ).sum()
+        / total_tokens
+    )
+
+    for datum, advantages, mask, logprob_diff in zip(
+        datums,
+        advantages_by_datum,
+        masks_by_datum,
+        logprob_diffs_by_datum,
+        strict=True,
+    ):
+        datum.loss_fn_inputs["advantages"] = tinker.TensorData.from_torch(
+            advantages + kl_penalty_coef * (avg_logprob_diff - logprob_diff) * mask
+        )
+
+    return {"loss/kl_policy_ref": float(avg_logprob_diff)}
+
+
 @dataclass
 class ModelState:
     service_client: tinker.ServiceClient
@@ -239,6 +310,9 @@ async def train( # type: ignore[override]
         save_checkpoint: bool = False,
         loss_fn_config: dict | None = None,
         adam_params: tinker.AdamParams | None = None,
+        kl_penalty_coef: float = 0.0,
+        kl_penalty_reference_step: int | None = None,
+        kl_penalty_source: Literal["sample"] = "sample",
     ) -> TrainResult:
         state = self._model_state[model.name]
         groups_list = list(trajectory_groups)
@@ -259,6 +333,10 @@ async def train( # type: ignore[override]
             "data/step_num_datums": float(len(datums)),
         }

+        assert kl_penalty_source == "sample", (
+            "TinkerNativeBackend only supports kl_penalty_source='sample'."
+        )
+
         if not datums:
             return TrainResult(step=state.current_step, metrics=metrics)

@@ -273,6 +351,23 @@ async def train( # type: ignore[override]
         )
         trainer_started = time.monotonic()

+        if kl_penalty_coef > 0:
+            reference_sampling_client = await self._get_kl_reference_sampling_client(
+                state,
+                model.base_model,
+                kl_penalty_reference_step,
+            )
+            metrics.update(
+                await self._tinker_sample_call(
+                    "apply_kl_penalty",
+                    _apply_kl_penalty(
+                        datums,
+                        reference_sampling_client,
+                        kl_penalty_coef,
+                    ),
+                )
+            )
+
         if adam_params is None:
             adam_params = tinker.AdamParams(
                 learning_rate=learning_rate,
@@ -697,6 +792,19 @@ async def _get_sampler_client(
         state.sampler_clients[actual_step] = sampler_client
         return sampler_client

+    async def _get_kl_reference_sampling_client(
+        self,
+        state: ModelState,
+        base_model: str,
+        step: int | None,
+    ) -> tinker.SamplingClient:
+        if step is not None:
+            return await self._get_sampler_client(state, step)
+        return await self._tinker_sample_call(
+            "create_sampling_client_async",
+            state.service_client.create_sampling_client_async(base_model=base_model),
+        )
+
     def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]:
         normalized: list[dict[str, Any]] = []
         for message in messages:
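
One alignment subtlety in _apply_kl_penalty: the reference client scores every position of the full sequence, but the first token has no predecessor to condition on, so reference_logprobs[1:] drops it to line up with the stored per-target rollout logprobs. A schematic illustration with made-up values (no real tinker calls):

# Hypothetical next-token alignment; tokens and numbers are invented.
full_sequence = ["tok0", "tok1", "tok2", "tok3"]  # model_input + final target
reference_logprobs = [None, -0.4, -0.7, -0.2]     # logp(tok_i | tok_<i); tok0 has none
reference_values = reference_logprobs[1:]         # one entry per predicted token
sampled_logprobs = [-0.3, -0.9, -0.1]             # stored rollout logprobs
assert len(reference_values) == len(sampled_logprobs)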

src/art/types.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 class TrainConfig(pydantic.BaseModel):
     learning_rate: float = 5e-6
     kl_penalty_coef: float = 0.0
+    kl_penalty_source: Literal["current_learner", "sample"] = "current_learner"


 class TrainSFTConfig(pydantic.BaseModel):
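
Since TrainConfig is a pydantic model, the new literal is validated at construction time; a minimal sketch (assuming the imports already present in art/types.py):

config = TrainConfig(kl_penalty_coef=0.05, kl_penalty_source="sample")  # ok
TrainConfig(kl_penalty_source="rollout")  # raises pydantic.ValidationError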
