Skip to content

Commit 79bf8d5

Browse files
committed
test: Add PipelineTrainer KL smoke coverage
1 parent e113b3b commit 79bf8d5

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/integration/test_pipeline_localbackend_dedicated.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,10 @@
11
"""Dedicated LocalBackend smoke test for PipelineTrainer."""
22

33
import asyncio
4+
import json
5+
import math
46
import os
7+
from pathlib import Path
58
import tempfile
69
import uuid
710

@@ -163,6 +166,8 @@ async def rollout_fn(
163166
min_batch_size=1,
164167
max_batch_size=1,
165168
max_steps=2,
169+
kl_penalty_coef=0.25,
170+
kl_penalty_reference_step=0,
166171
loss_fn="cispo",
167172
eval_fn=None,
168173
)
@@ -180,5 +185,23 @@ async def rollout_fn(
180185
model_ids = [m.id async for m in client.models.list()]
181186
assert f"{model.name}@0" in model_ids
182187
assert f"{model.name}@{latest_step}" in model_ids
188+
189+
history_path = (
190+
Path(tmpdir)
191+
/ model.project
192+
/ "models"
193+
/ model.name
194+
/ "history.jsonl"
195+
)
196+
history_rows = [
197+
json.loads(line) for line in history_path.read_text().splitlines()
198+
]
199+
kl_values = [
200+
row["loss/kl_policy_ref"]
201+
for row in history_rows
202+
if "loss/kl_policy_ref" in row
203+
]
204+
assert kl_values
205+
assert all(math.isfinite(value) for value in kl_values)
183206
finally:
184207
await client.close()

0 commit comments

Comments (0)