Skip to content

Commit 3a7574c

Browse files
arcticflyclaude and Claude Opus 4.6 authored
feat: add provenance tracking to training backends (#551)
* feat: add provenance tracking to LocalBackend and ServerlessBackend training

  Records the training technique ("local-rl" or "serverless-rl") in the W&B run
  config at the start of each train() call, enabling downstream consumers to
  understand how a model was trained and to detect technique changes over its
  lifecycle.

* chore: move test_provenance to tests/integration/

* fix: resolve type errors and formatting in test_provenance

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 520aba6 commit 3a7574c

File tree

6 files changed

+129
-14
lines changed

6 files changed

+129
-14
lines changed

src/art/local/backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
get_output_dir_from_model_properties,
2525
get_step_checkpoint_dir,
2626
)
27+
from art.utils.record_provenance import record_provenance
2728
from art.utils.s3 import (
2829
ExcludableOption,
2930
pull_model_from_s3,
@@ -459,6 +460,11 @@ async def train( # type: ignore[override]
459460
"""
460461
groups_list = list(trajectory_groups)
461462

463+
# Record provenance in W&B
464+
wandb_run = model._get_wandb_run()
465+
if wandb_run is not None:
466+
record_provenance(wandb_run, "local-rl")
467+
462468
# Build config objects from explicit kwargs
463469
config = TrainConfig(learning_rate=learning_rate, beta=beta)
464470
dev_config: dev.TrainConfig = {

src/art/preprocessing/tokenize.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import math
2-
import random
31
from dataclasses import dataclass
42
from itertools import takewhile
3+
import math
4+
import random
55
from typing import Any, Generator, cast
66

7-
import torch
87
from PIL import Image
8+
import torch
99
from transformers.image_processing_utils import BaseImageProcessor
1010
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
1111

@@ -167,8 +167,7 @@ def tokenize_trajectory(
167167
):
168168
last_assistant_index = i
169169
elif not isinstance(message, dict) and (
170-
message.logprobs
171-
or allow_training_without_logprobs # ty:ignore[possibly-missing-attribute]
170+
message.logprobs or allow_training_without_logprobs # ty:ignore[possibly-missing-attribute]
172171
):
173172
last_assistant_index = i
174173
# If there are no trainable assistant messages, return None
@@ -241,9 +240,7 @@ def tokenize_trajectory(
241240
continue
242241
if not allow_training_without_logprobs:
243242
continue
244-
elif (
245-
message.logprobs is None and not allow_training_without_logprobs
246-
): # ty:ignore[possibly-missing-attribute]
243+
elif message.logprobs is None and not allow_training_without_logprobs: # ty:ignore[possibly-missing-attribute]
247244
continue
248245
start = token_ids.index(sentinal_token_id)
249246
end = start + 1
@@ -268,16 +265,12 @@ def tokenize_trajectory(
268265
assistant_mask[start:end] = [1] * len(content_token_ids)
269266
else:
270267
choice = message
271-
assert (
272-
choice.logprobs or allow_training_without_logprobs
273-
), ( # ty:ignore[possibly-missing-attribute]
268+
assert choice.logprobs or allow_training_without_logprobs, ( # ty:ignore[possibly-missing-attribute]
274269
"Chat completion choices must have logprobs"
275270
)
276271
if not choice.logprobs: # ty:ignore[possibly-missing-attribute]
277272
continue
278-
token_logprobs = (
279-
choice.logprobs.content or choice.logprobs.refusal or []
280-
) # ty:ignore[possibly-missing-attribute]
273+
token_logprobs = choice.logprobs.content or choice.logprobs.refusal or [] # ty:ignore[possibly-missing-attribute]
281274
if (
282275
bytes(token_logprobs[0].bytes or []).decode("utf-8")
283276
== "<think>"

src/art/serverless/backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ..backend import AnyTrainableModel, Backend
1212
from ..trajectories import TrajectoryGroup
1313
from ..types import ServerlessTrainResult, TrainConfig
14+
from ..utils.record_provenance import record_provenance
1415

1516
if TYPE_CHECKING:
1617
import wandb
@@ -209,6 +210,11 @@ async def train( # type: ignore[override]
209210
"""
210211
groups_list = list(trajectory_groups)
211212

213+
# Record provenance in W&B
214+
wandb_run = model._get_wandb_run()
215+
if wandb_run is not None:
216+
record_provenance(wandb_run, "serverless-rl")
217+
212218
# Build config objects from explicit kwargs
213219
config = TrainConfig(learning_rate=learning_rate, beta=beta)
214220
dev_config: dev.TrainConfig = {

src/art/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
from .iterate_dataset import iterate_dataset
55
from .limit_concurrency import limit_concurrency
66
from .log_http_errors import log_http_errors
7+
from .record_provenance import record_provenance
78
from .retry import retry
89

910
__all__ = [
1011
"format_message",
12+
"record_provenance",
1113
"retry",
1214
"iterate_dataset",
1315
"limit_concurrency",

src/art/utils/record_provenance.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import wandb


def record_provenance(run: wandb.Run, provenance: str) -> None:
    """Record a training technique in the run's ``provenance`` config list.

    Appends *provenance* to ``run.config["provenance"]`` so that the most
    recent technique is always the last element, while consecutive identical
    techniques are deduplicated (repeated ``train()`` calls with the same
    backend do not grow the list).

    Args:
        run: The active W&B run whose config is updated.
        provenance: Technique identifier, e.g. ``"local-rl"`` or
            ``"serverless-rl"``.
    """
    if "provenance" in run.config:
        existing = list(run.config["provenance"])
        # Guard against an empty stored list (previously raised IndexError on
        # existing[-1]); only write back when the technique actually changed.
        if not existing or existing[-1] != provenance:
            existing.append(provenance)
            run.config.update({"provenance": existing})
    else:
        run.config.update({"provenance": [provenance]})
tests/integration/test_provenance.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Integration test: verify provenance tracking in W&B run config via ServerlessBackend."""
2+
3+
import asyncio
4+
from datetime import datetime
5+
6+
from dotenv import load_dotenv
7+
8+
import art
9+
from art.serverless.backend import ServerlessBackend
10+
11+
load_dotenv()
12+
13+
14+
async def simple_rollout(model: art.TrainableModel) -> art.Trajectory:
    """Minimal rollout that produces a single turn with a reward."""
    trajectory = art.Trajectory(
        messages_and_choices=[
            {"role": "system", "content": "Reply with exactly 'hello'."},
        ],
        reward=0.0,
    )

    # Request one short completion from the model under test.
    response = await model.openai_client().chat.completions.create(
        model=model.get_inference_name(),
        messages=trajectory.messages(),
        max_completion_tokens=16,
        timeout=30,
    )
    choice = response.choices[0]
    trajectory.messages_and_choices.append(choice)

    # Reward 1.0 only for an exact (trimmed, case-insensitive) "hello".
    reply = (choice.message.content or "").strip().lower()
    trajectory.reward = 1.0 if reply == "hello" else 0.0
    return trajectory
37+
38+
39+
async def make_group(model: art.TrainableModel) -> art.TrajectoryGroup:
    """Build a trajectory group of four rollout coroutines for *model*."""
    rollouts = [simple_rollout(model) for _ in range(4)]
    return art.TrajectoryGroup(rollouts)
41+
42+
43+
async def main() -> None:
    """Run two train() calls against ServerlessBackend and assert the W&B
    run config records provenance exactly once, with no duplicate entry."""
    backend = ServerlessBackend()

    # Unique model name per invocation so repeated test runs don't collide.
    model = art.TrainableModel(
        name=f"provenance-test-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
        project="provenance-test",
        base_model="OpenPipe/Qwen3-14B-Instruct",
    )
    await model.register(backend)

    # --- Step 1: first training call ---
    groups = await art.gather_trajectory_groups(make_group(model) for _ in range(1))
    result = await backend.train(model, groups)
    await model.log(groups, metrics=result.metrics, step=result.step, split="train")

    # Check provenance after first train call
    run = model._get_wandb_run()
    assert run is not None, "W&B run should exist"
    provenance = run.config.get("provenance")
    print(f"After step 1: provenance = {provenance}")
    assert provenance == ["serverless-rl"], (
        f"Expected ['serverless-rl'], got {provenance}"
    )

    # --- Step 2: second training call (same technique, should NOT duplicate) ---
    # Provenance is recorded at the start of train(), before the remote call,
    # so we can verify deduplication even if the server-side training fails.
    groups2 = await art.gather_trajectory_groups(make_group(model) for _ in range(1))
    try:
        result2 = await backend.train(model, groups2)
        await model.log(
            groups2, metrics=result2.metrics, step=result2.step, split="train"
        )
    except RuntimeError as e:
        # Best-effort: a transient server failure doesn't invalidate the check
        # below, since provenance was already written before the remote call.
        print(f"Step 2 training failed (transient server error, OK for this test): {e}")

    provenance = run.config.get("provenance")
    print(f"After step 2: provenance = {provenance}")
    assert provenance == ["serverless-rl"], (
        f"Expected ['serverless-rl'] (no duplicate), got {provenance}"
    )

    print("\nAll provenance checks passed!")

    await backend.close()
88+
89+
90+
# Allow running this integration test directly as a script.
if __name__ == "__main__":
    asyncio.run(main())

0 commit comments

Comments (0)