Skip to content

Commit 5af1d38

Browse files
corbt and Cursor Bot authored
feat: Backend-First Training API (Phase 1) (#521)
* feat: Backend-First Training API (Phase 1) BREAKING CHANGE: Training is now initiated through backend.train() instead of model.train() ## Migration Guide ### Before (deprecated): ```python await model.train(trajectory_groups, config=art.TrainConfig(learning_rate=5e-6)) ``` ### After (recommended): ```python await model.log(trajectory_groups, split='train') # Log trajectories result = await backend.train(model, trajectory_groups, learning_rate=5e-6) await model.log(metrics=result.metrics, step=result.step, split='train') # Log training metrics ``` ## Key Changes - **New API**: `backend.train(model, trajectory_groups, **kwargs)` with explicit, type-safe parameters - **Explicit logging**: `backend.train()` does NOT automatically log trajectories or metrics - **Extended model.log()**: Now accepts `metrics` and `step` kwargs for logging training metrics directly - **Structured returns**: `LocalTrainResult` and `ServerlessTrainResult` with step, metrics, and backend-specific fields - **Fixed get_inference_name()**: Now correctly returns `model.name@step` for LocalBackend - **Deprecation warning**: `model.train()` emits a warning with migration instructions ## Phase 2 (Future) In a future release, we will: - Remove `model.train()` method entirely - Remove `art.TrainConfig` and `art.dev.TrainConfig` classes Closes #519 * refactor: simplify examples to single model.log() after training Combined trajectory and metrics logging into a single call: result = await backend.train(model, groups, ...) await model.log(groups, metrics=result.metrics, step=result.step, split='train') Removed redundant comments and pre-training log calls. --------- Co-authored-by: Cursor Bot <bot@cursor.com>
1 parent 8636830 commit 5af1d38

23 files changed

Lines changed: 704 additions & 105 deletions

examples/2048/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ async def train():
7171
model,
7272
)
7373

74-
await model.train(
75-
train_groups,
76-
config=art.TrainConfig(learning_rate=1e-5),
74+
result = await backend.train(model, train_groups, learning_rate=1e-5)
75+
await model.log(
76+
train_groups, metrics=result.metrics, step=result.step, split="train"
7777
)
7878

7979

examples/benchmarking_comparison_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ async def train_model(model: art.TrainableModel):
127127
)
128128
for scenario in batch.items
129129
)
130-
await model.train(groups)
130+
result = await backend.train(model, groups)
131+
await model.log(groups, metrics=result.metrics, step=result.step, split="train")
131132

132133
if batch.step % 20 == 0:
133134
# Every 20 steps let's benchmark our model under training so we can

examples/hn_title_generator/train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,11 @@ async def main():
325325
)
326326
continue
327327

328-
await model.train(
329-
valid_train_groups,
330-
config=art.TrainConfig(learning_rate=LEARNING_RATE),
328+
result = await backend.train(
329+
model, valid_train_groups, learning_rate=LEARNING_RATE
330+
)
331+
await model.log(
332+
valid_train_groups, metrics=result.metrics, step=result.step, split="train"
331333
)
332334

333335
if batch.step > 0 and batch.step % EVAL_STEPS == 0:

examples/just-the-facts/just_the_facts/train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ async def train(
8181
),
8282
)
8383

84-
await model.train(
84+
result = await backend.train(
85+
model,
8586
groups,
86-
config=art.TrainConfig(learning_rate=model.config.learning_rate),
87-
_config=art.dev.TrainConfig(
88-
scale_rewards=model.config.scale_rewards,
89-
),
87+
learning_rate=model.config.learning_rate,
88+
scale_rewards=model.config.scale_rewards,
9089
)
90+
await model.log(groups, metrics=result.metrics, step=result.step, split="train")
9191

9292
await backend._experimental_push_to_s3(model)
9393

examples/mcp-rl/mcp_rl/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ async def train_mcp_agent(model: art.TrainableModel, use_skypilot: bool = False)
168168
await model.log(val_groups, split="val")
169169

170170
print("starting train")
171-
await model.train(groups, config=art.TrainConfig(learning_rate=learning_rate))
171+
result = await backend.train(model, groups, learning_rate=learning_rate)
172+
await model.log(groups, metrics=result.metrics, step=result.step, split="train")
172173

173174
await backend._experimental_push_to_s3(
174175
model,

examples/openenv_echo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ async def main() -> None:
8686
[art.TrajectoryGroup(rollout(model, env_client) for env_client in env_pool)]
8787
)
8888

89-
await model.train(groups)
89+
result = await backend.train(model, groups)
90+
await model.log(groups, metrics=result.metrics, step=result.step, split="train")
9091

9192

9293
asyncio.run(main())

examples/prisoners-dilemma.ipynb

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,18 @@
136136
" )\n",
137137
" await model.log([ts[0] for ts in base_play_trajectories], split=\"versus-base\")\n",
138138
" await model.log([ts[1] for ts in base_play_trajectories], split=\"base-model\")\n",
139-
" # Train the model on self-play and base-play trajectories.\n",
140-
" await model.train(\n",
141-
" trajectory_groups=[\n",
142-
" # Since all self-play games have the same starting state and are symmetric, we can gather\n",
143-
" # trajectories from all self-play games into a single trajectory group.\n",
144-
" art.TrajectoryGroup(t for ts in self_play_trajectories for t in ts),\n",
145-
" # We can also gather all base-play _trained model_ trajectories into a single trajectory group.\n",
146-
" # We don't want to train on base model trajectories, because they are sampled from a different distribution.\n",
147-
" art.TrajectoryGroup(ts[0] for ts in base_play_trajectories),\n",
148-
" ],\n",
149-
" config=art.TrainConfig(learning_rate=5e-5),\n",
139+
" # Train the model on self-play and base-play trajectories using the backend-first API.\n",
140+
" trajectory_groups = [\n",
141+
" # Since all self-play games have the same starting state and are symmetric, we can gather\n",
142+
" # trajectories from all self-play games into a single trajectory group.\n",
143+
" art.TrajectoryGroup(t for ts in self_play_trajectories for t in ts),\n",
144+
" # We can also gather all base-play _trained model_ trajectories into a single trajectory group.\n",
145+
" # We don't want to train on base model trajectories, because they are sampled from a different distribution.\n",
146+
" art.TrajectoryGroup(ts[0] for ts in base_play_trajectories),\n",
147+
" ]\n",
148+
" result = await backend.train(model, trajectory_groups, learning_rate=5e-5)\n",
149+
" await model.log(\n",
150+
" trajectory_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
150151
" )"
151152
]
152153
}
@@ -172,4 +173,4 @@
172173
},
173174
"nbformat": 4,
174175
"nbformat_minor": 2
175-
}
176+
}

examples/rock-paper-tool-use.ipynb

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@
5252
"model = art.TrainableModel(\n",
5353
" name=MODEL_NAME, project=\"rock-paper-tool-use\", base_model=BASE_MODEL\n",
5454
")\n",
55-
"await model.register(LocalBackend())\n",
55+
"backend = LocalBackend()\n",
56+
"await model.register(backend)\n",
5657
"client = model.openai_client()\n",
5758
"\n",
5859
"\n",
@@ -180,10 +181,10 @@
180181
" trajectories = await art.gather_trajectories(\n",
181182
" (rollout() for _ in range(64)), max_exceptions=64\n",
182183
" )\n",
183-
" await model.train(\n",
184-
" [art.TrajectoryGroup(trajectories)],\n",
185-
" config=art.TrainConfig(learning_rate=5e-5),\n",
186-
" )"
184+
" # Log trajectories and train using the backend-first API\n",
185+
" groups = [art.TrajectoryGroup(trajectories)]\n",
186+
" result = await backend.train(model, groups, learning_rate=5e-5)\n",
187+
" await model.log(groups, metrics=result.metrics, step=result.step, split=\"train\")"
187188
]
188189
}
189190
],
@@ -208,4 +209,4 @@
208209
},
209210
"nbformat": 4,
210211
"nbformat_minor": 2
211-
}
212+
}

examples/temporal_clue/temporal-clue-7b-async.ipynb

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,11 @@
156156
" for trajectory in group:\n",
157157
" trajectory.metrics[\"max_reward\"] = max_reward\n",
158158
" await model.delete_checkpoints()\n",
159-
" await model.train(\n",
160-
" train_groups,\n",
161-
" config=art.TrainConfig(learning_rate=5e-6),\n",
162-
" _config=art.dev.TrainConfig(precalculate_logprobs=True),\n",
159+
" result = await backend.train(\n",
160+
" model, train_groups, learning_rate=5e-6, precalculate_logprobs=True\n",
161+
" )\n",
162+
" await model.log(\n",
163+
" train_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
163164
" )"
164165
]
165166
}
@@ -185,4 +186,4 @@
185186
},
186187
"nbformat": 4,
187188
"nbformat_minor": 2
188-
}
189+
}

examples/temporal_clue/temporal-clue-7b.ipynb

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,15 @@
118118
" trajectory.metrics[\"max_reward\"] = max_reward\n",
119119
" await model.log(val_groups)\n",
120120
" await model.delete_checkpoints()\n",
121-
" await model.train(\n",
121+
" result = await backend.train(\n",
122+
" model,\n",
122123
" train_groups,\n",
123-
" config=art.TrainConfig(learning_rate=5e-6),\n",
124-
" _config=art.dev.TrainConfig(precalculate_logprobs=True, scale_rewards=False),\n",
124+
" learning_rate=5e-6,\n",
125+
" precalculate_logprobs=True,\n",
126+
" scale_rewards=False,\n",
127+
" )\n",
128+
" await model.log(\n",
129+
" train_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
125130
" )"
126131
]
127132
}
@@ -147,4 +152,4 @@
147152
},
148153
"nbformat": 4,
149154
"nbformat_minor": 2
150-
}
155+
}

0 commit comments

Comments (0)