Skip to content

Commit 5ef9eab

Browse files
committed
fix: Preserve sampled KL metric in TinkerNativeBackend
1 parent 79bf8d5 commit 5ef9eab

File tree

2 files changed

+106
-13
lines changed

src/art/tinker_native/backend.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -350,23 +350,23 @@ async def train( # type: ignore[override]
350350
train_tokens, pricing
351351
)
352352
trainer_started = time.monotonic()
353+
sampled_kl_policy_ref: float | None = None
353354

354355
if kl_penalty_coef > 0:
355-
reference_sampling_client = await self._get_kl_reference_sampling_client(
356-
state,
357-
model.base_model,
358-
kl_penalty_reference_step,
359-
)
360-
metrics.update(
361-
await self._tinker_sample_call(
362-
"apply_kl_penalty",
363-
_apply_kl_penalty(
364-
datums,
365-
reference_sampling_client,
366-
kl_penalty_coef,
356+
kl_metrics = await self._tinker_sample_call(
357+
"apply_kl_penalty",
358+
_apply_kl_penalty(
359+
datums,
360+
await self._get_kl_reference_sampling_client(
361+
state,
362+
model.base_model,
363+
kl_penalty_reference_step,
367364
),
368-
)
365+
kl_penalty_coef,
366+
),
369367
)
368+
sampled_kl_policy_ref = kl_metrics["loss/kl_policy_ref"]
369+
metrics.update(kl_metrics)
370370

371371
if adam_params is None:
372372
adam_params = tinker.AdamParams(
@@ -405,13 +405,23 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
405405
if value is None:
406406
continue
407407
canonical_key = _canonicalize_upstream_metric_key(key)
408+
if (
409+
sampled_kl_policy_ref is not None
410+
and canonical_key == "loss/kl_policy_ref"
411+
):
412+
continue
408413
if canonical_key:
409414
metrics[canonical_key] = float(value)
410415
if optim_output.metrics:
411416
for key, value in optim_output.metrics.items():
412417
if value is None:
413418
continue
414419
canonical_key = _canonicalize_upstream_metric_key(key)
420+
if (
421+
sampled_kl_policy_ref is not None
422+
and canonical_key == "loss/kl_policy_ref"
423+
):
424+
continue
415425
if canonical_key:
416426
metrics[canonical_key] = float(value)
417427

tests/integration/test_tinker_native_backend.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import art
1111
from art.tinker_native import TinkerNativeBackend
12+
from art.tinker_native.backend import _apply_kl_penalty
13+
from art.tinker_native.data import trajectory_groups_to_datums
1214

1315
DEFAULT_BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507"
1416

@@ -37,6 +39,8 @@ async def simple_rollout(
3739
max_tokens=10,
3840
timeout=60,
3941
temperature=1,
42+
logprobs=True,
43+
top_logprobs=0,
4044
)
4145
choice = chat_completion.choices[0]
4246
content = (choice.message.content or "").lower()
@@ -115,6 +119,85 @@ async def make_group(prompt: str) -> art.TrajectoryGroup:
115119
await backend.close()
116120

117121

122+
@pytest.mark.skipif(
123+
"TINKER_API_KEY" not in os.environ,
124+
reason="TINKER_API_KEY not set - skipping TinkerNativeBackend KL test",
125+
)
126+
async def test_tinker_native_backend_kl_identity_metric():
127+
model_name = f"test-tinker-native-kl-{uuid.uuid4().hex[:8]}"
128+
with tempfile.TemporaryDirectory() as tmpdir:
129+
backend = TinkerNativeBackend(path=tmpdir)
130+
model = art.TrainableModel(
131+
name=model_name,
132+
project="integration-tests",
133+
base_model=get_base_model(),
134+
)
135+
try:
136+
await model.register(backend)
137+
138+
openai_client = model.openai_client()
139+
current_step = await model.get_step()
140+
model_name_step = model.get_inference_name(step=current_step)
141+
prompts = ["Say yes", "Say no", "Say maybe"]
142+
143+
async def make_group(prompt: str) -> art.TrajectoryGroup:
144+
import asyncio
145+
146+
trajectories = await asyncio.gather(
147+
*[
148+
simple_rollout(openai_client, model_name_step, prompt)
149+
for _ in range(2)
150+
]
151+
)
152+
return art.TrajectoryGroup(trajectories) # type: ignore[attr-defined]
153+
154+
train_groups = await art.gather_trajectory_groups( # type: ignore[attr-defined]
155+
[make_group(prompt) for prompt in prompts]
156+
)
157+
ensure_reward_variance(train_groups)
158+
159+
state = backend._model_state[model.name]
160+
datums = trajectory_groups_to_datums(
161+
train_groups,
162+
state.renderer,
163+
state.tokenizer,
164+
)
165+
assert datums
166+
167+
reference_sampling_client = await backend._get_kl_reference_sampling_client(
168+
state,
169+
model.base_model,
170+
current_step,
171+
)
172+
expected_kl = (
173+
await _apply_kl_penalty(
174+
trajectory_groups_to_datums(
175+
train_groups,
176+
state.renderer,
177+
state.tokenizer,
178+
),
179+
reference_sampling_client,
180+
kl_penalty_coef=0.25,
181+
)
182+
)["loss/kl_policy_ref"]
183+
184+
result = await backend.train(
185+
model,
186+
train_groups,
187+
learning_rate=1e-5,
188+
kl_penalty_coef=0.25,
189+
kl_penalty_reference_step=current_step,
190+
)
191+
192+
assert result.metrics["loss/kl_policy_ref"] == pytest.approx(
193+
expected_kl,
194+
abs=0.05,
195+
)
196+
assert result.metrics["loss/kl_policy_ref"] == pytest.approx(0.0, abs=0.05)
197+
finally:
198+
await backend.close()
199+
200+
118201
@pytest.mark.skipif(
119202
"TINKER_API_KEY" not in os.environ,
120203
reason="TINKER_API_KEY not set - skipping TinkerNativeBackend fork test",

0 commit comments

Comments (0)