File tree Expand file tree Collapse file tree 1 file changed +23
-0
lines changed
Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Original file line number Diff line number Diff line change 11"""Dedicated LocalBackend smoke test for PipelineTrainer."""
22
33import asyncio
4+ import json
5+ import math
46import os
7+ from pathlib import Path
58import tempfile
69import uuid
710
@@ -163,6 +166,8 @@ async def rollout_fn(
163166 min_batch_size = 1 ,
164167 max_batch_size = 1 ,
165168 max_steps = 2 ,
169+ kl_penalty_coef = 0.25 ,
170+ kl_penalty_reference_step = 0 ,
166171 loss_fn = "cispo" ,
167172 eval_fn = None ,
168173 )
@@ -180,5 +185,23 @@ async def rollout_fn(
180185 model_ids = [m .id async for m in client .models .list ()]
181186 assert f"{ model .name } @0" in model_ids
182187 assert f"{ model .name } @{ latest_step } " in model_ids
188+
189+ history_path = (
190+ Path (tmpdir )
191+ / model .project
192+ / "models"
193+ / model .name
194+ / "history.jsonl"
195+ )
196+ history_rows = [
197+ json .loads (line ) for line in history_path .read_text ().splitlines ()
198+ ]
199+ kl_values = [
200+ row ["loss/kl_policy_ref" ]
201+ for row in history_rows
202+ if "loss/kl_policy_ref" in row
203+ ]
204+ assert kl_values
205+ assert all (math .isfinite (value ) for value in kl_values )
183206 finally :
184207 await client .close ()
You can’t perform that action at this time.
0 commit comments