Skip to content

Commit 05f39ef

Browse files
committed
test: Add max batch size regression coverage
1 parent dbefea6 commit 05f39ef

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import asyncio
2+
from pathlib import Path
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
7+
from art import TrainableModel, Trajectory, TrajectoryGroup
8+
from art.pipeline_trainer.trainer import PipelineTrainer
9+
10+
11+
def _make_group() -> TrajectoryGroup:
12+
return TrajectoryGroup(
13+
[
14+
Trajectory(
15+
reward=reward,
16+
initial_policy_version=0,
17+
messages_and_choices=[
18+
{"role": "user", "content": f"prompt-{idx}"},
19+
{"role": "assistant", "content": f"answer-{idx}"},
20+
],
21+
)
22+
for idx, reward in enumerate([0.0, 1.0])
23+
]
24+
)
25+
26+
27+
@pytest.mark.asyncio
28+
async def test_collect_batch_respects_max_batch_size(tmp_path: Path) -> None:
29+
model = TrainableModel(
30+
name="pipeline-max-batch-size-test",
31+
project="pipeline-tests",
32+
base_model="test-model",
33+
base_path=str(tmp_path),
34+
)
35+
trainer = PipelineTrainer(
36+
model=model,
37+
backend=MagicMock(), # type: ignore[arg-type]
38+
rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0),
39+
scenarios=[],
40+
config={},
41+
num_rollout_workers=1,
42+
min_batch_size=1,
43+
max_batch_size=2,
44+
max_steps=1,
45+
eval_fn=None,
46+
)
47+
trainer._output_queue = asyncio.Queue()
48+
49+
first = _make_group()
50+
second = _make_group()
51+
third = _make_group()
52+
await trainer._output_queue.put(first)
53+
await trainer._output_queue.put(second)
54+
await trainer._output_queue.put(third)
55+
await trainer._output_queue.put(None)
56+
57+
batch, discarded, saw_sentinel = await trainer._collect_batch(current_step=0)
58+
59+
assert batch == [first, second]
60+
assert discarded == 0
61+
assert not saw_sentinel
62+
63+
batch, discarded, saw_sentinel = await trainer._collect_batch(current_step=0)
64+
65+
assert batch == [third]
66+
assert discarded == 0
67+
assert saw_sentinel

0 commit comments

Comments
 (0)