Skip to content

Commit e02c60c

Browse files
arcticflyclaude
andauthored
Add provenance to artifacts, not runs (#553)
* feat: copy provenance from source model when forking checkpoints Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: add required provenance field to WandbDeploymentConfig Includes provenance array in W&B artifact metadata when deploying. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: write provenance to W&B artifact metadata instead of run config Moves record_provenance to update the latest artifact version's metadata and calls it after training completes (so the new artifact exists). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: fix test_provenance to work with artifact-based provenance Use list syntax for gather_trajectory_groups and add retry for transient server timeouts. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor: rename provenance metadata key to wandb.provenance Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1539181 commit e02c60c

File tree

6 files changed

+81
-33
lines changed

6 files changed

+81
-33
lines changed

src/art/local/backend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,6 @@ async def train( # type: ignore[override]
460460
"""
461461
groups_list = list(trajectory_groups)
462462

463-
# Record provenance in W&B
464-
wandb_run = model._get_wandb_run()
465-
if wandb_run is not None:
466-
record_provenance(wandb_run, "local-rl")
467-
468463
# Build config objects from explicit kwargs
469464
config = TrainConfig(learning_rate=learning_rate, beta=beta)
470465
dev_config: dev.TrainConfig = {
@@ -521,6 +516,11 @@ async def train( # type: ignore[override]
521516
if not os.path.exists(checkpoint_path):
522517
checkpoint_path = None
523518

519+
# Record provenance on the latest W&B artifact
520+
wandb_run = model._get_wandb_run()
521+
if wandb_run is not None:
522+
record_provenance(wandb_run, "local-rl")
523+
524524
return LocalTrainResult(
525525
step=step,
526526
metrics=avg_metrics,

src/art/serverless/backend.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,6 @@ async def train( # type: ignore[override]
210210
"""
211211
groups_list = list(trajectory_groups)
212212

213-
# Record provenance in W&B
214-
wandb_run = model._get_wandb_run()
215-
if wandb_run is not None:
216-
record_provenance(wandb_run, "serverless-rl")
217-
218213
# Build config objects from explicit kwargs
219214
config = TrainConfig(learning_rate=learning_rate, beta=beta)
220215
dev_config: dev.TrainConfig = {
@@ -260,6 +255,11 @@ async def train( # type: ignore[override]
260255
if model.entity is not None:
261256
artifact_name = f"{model.entity}/{model.project}/{model.name}:step{step}"
262257

258+
# Record provenance on the latest W&B artifact
259+
wandb_run = model._get_wandb_run()
260+
if wandb_run is not None:
261+
record_provenance(wandb_run, "serverless-rl")
262+
263263
return ServerlessTrainResult(
264264
step=step,
265265
metrics=avg_metrics,
@@ -645,6 +645,20 @@ async def _experimental_fork_checkpoint(
645645
run.log_artifact(dest_artifact, aliases=aliases)
646646
run.finish()
647647

648+
# Copy provenance from the source model's W&B run to the destination model
649+
api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute]
650+
try:
651+
source_run = api.run(f"{model.entity}/{from_project}/{from_model}")
652+
source_provenance = source_run.config.get("wandb.provenance")
653+
if source_provenance is not None:
654+
dest_run = model._get_wandb_run()
655+
if dest_run is not None:
656+
dest_run.config.update(
657+
{"wandb.provenance": list(source_provenance)}
658+
)
659+
except Exception:
660+
pass # Source run may not exist (e.g., S3-only models)
661+
648662
if verbose:
649663
print(
650664
f"Successfully forked checkpoint from {from_model} "

src/art/utils/deployment/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ async def deploy_model(
9090
model=model,
9191
checkpoint_path=checkpoint_path,
9292
step=step,
93+
config=config,
9394
verbose=verbose,
9495
)
9596
return DeploymentResult(inference_model_name=inference_name)

src/art/utils/deployment/wandb.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ class WandbDeploymentConfig(DeploymentConfig):
2121
- Qwen/Qwen2.5-14B-Instruct
2222
"""
2323

24-
pass
24+
provenance: list[str]
25+
"""The training provenance history for this model (e.g. ["local-rl", "serverless-rl"])."""
2526

2627

2728
WANDB_SUPPORTED_BASE_MODELS = [
@@ -36,6 +37,7 @@ def deploy_wandb(
3637
model: "TrainableModel",
3738
checkpoint_path: str,
3839
step: int,
40+
config: "WandbDeploymentConfig | None" = None,
3941
verbose: bool = False,
4042
) -> str:
4143
"""Deploy a model to W&B by uploading a LoRA artifact.
@@ -44,6 +46,7 @@ def deploy_wandb(
4446
model: The TrainableModel to deploy.
4547
checkpoint_path: Local path to the checkpoint directory.
4648
step: The step number of the checkpoint.
49+
config: Optional WandbDeploymentConfig with provenance metadata.
4750
verbose: Whether to print verbose output.
4851
4952
Returns:
@@ -74,10 +77,13 @@ def deploy_wandb(
7477
settings=wandb.Settings(api_key=os.environ["WANDB_API_KEY"]),
7578
)
7679
try:
80+
metadata: dict[str, object] = {"wandb.base_model": model.base_model}
81+
if config is not None:
82+
metadata["wandb.provenance"] = config.provenance
7783
artifact = wandb.Artifact(
7884
model.name,
7985
type="lora",
80-
metadata={"wandb.base_model": model.base_model},
86+
metadata=metadata,
8187
storage_region="coreweave-us",
8288
)
8389
artifact.add_dir(checkpoint_path)

src/art/utils/record_provenance.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,22 @@
77

88

99
def record_provenance(run: wandb.Run, provenance: str) -> None:
10-
"""Record provenance in run metadata, ensuring it's the last value in the array."""
11-
if "provenance" in run.config:
12-
existing = list(run.config["provenance"])
10+
"""Record provenance on the latest artifact version's metadata."""
11+
import wandb as wandb_module
12+
13+
api = wandb_module.Api()
14+
artifact_path = f"{run.entity}/{run.project}/{run.name}:latest"
15+
try:
16+
artifact = api.artifact(artifact_path, type="lora")
17+
except wandb_module.errors.CommError:
18+
return # No artifact exists yet
19+
20+
existing = artifact.metadata.get("wandb.provenance")
21+
if existing is not None:
22+
existing = list(existing)
1323
if existing[-1] != provenance:
1424
existing.append(provenance)
15-
run.config.update({"provenance": existing})
25+
artifact.metadata["wandb.provenance"] = existing
1626
else:
17-
run.config.update({"provenance": [provenance]})
27+
artifact.metadata["wandb.provenance"] = [provenance]
28+
artifact.save()

tests/integration/test_provenance.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
"""Integration test: verify provenance tracking in W&B run config via ServerlessBackend."""
1+
"""Integration test: verify provenance tracking on W&B artifact metadata via ServerlessBackend."""
22

33
import asyncio
44
from datetime import datetime
55

66
from dotenv import load_dotenv
7+
import wandb
78

89
import art
910
from art.serverless.backend import ServerlessBackend
@@ -36,8 +37,13 @@ async def simple_rollout(model: art.TrainableModel) -> art.Trajectory:
3637
return traj
3738

3839

39-
async def make_group(model: art.TrainableModel) -> art.TrajectoryGroup:
40-
return art.TrajectoryGroup(simple_rollout(model) for _ in range(4))
40+
def get_latest_artifact_provenance(
41+
entity: str, project: str, name: str
42+
) -> list[str] | None:
43+
"""Fetch provenance from the latest W&B artifact's metadata."""
44+
api = wandb.Api()
45+
artifact = api.artifact(f"{entity}/{project}/{name}:latest", type="lora")
46+
return artifact.metadata.get("wandb.provenance")
4147

4248

4349
async def main() -> None:
@@ -49,25 +55,35 @@ async def main() -> None:
4955
base_model="OpenPipe/Qwen3-14B-Instruct",
5056
)
5157
await model.register(backend)
58+
assert model.entity is not None
5259

53-
# --- Step 1: first training call ---
54-
groups = await art.gather_trajectory_groups(make_group(model) for _ in range(1))
55-
result = await backend.train(model, groups)
56-
await model.log(groups, metrics=result.metrics, step=result.step, split="train")
57-
58-
# Check provenance after first train call
59-
run = model._get_wandb_run()
60-
assert run is not None, "W&B run should exist"
61-
provenance = run.config.get("provenance")
60+
# --- Step 1: first training call (retry on transient server errors) ---
61+
for attempt in range(3):
62+
groups = await art.gather_trajectory_groups(
63+
[art.TrajectoryGroup(simple_rollout(model) for _ in range(4))] # ty: ignore[invalid-argument-type]
64+
)
65+
try:
66+
result = await backend.train(model, groups)
67+
await model.log(
68+
groups, metrics=result.metrics, step=result.step, split="train"
69+
)
70+
break
71+
except RuntimeError as e:
72+
print(f"Step 1 attempt {attempt + 1} failed: {e}")
73+
if attempt == 2:
74+
raise
75+
76+
# Check provenance on the latest artifact after first train call
77+
provenance = get_latest_artifact_provenance(model.entity, model.project, model.name)
6278
print(f"After step 1: provenance = {provenance}")
6379
assert provenance == ["serverless-rl"], (
6480
f"Expected ['serverless-rl'], got {provenance}"
6581
)
6682

6783
# --- Step 2: second training call (same technique, should NOT duplicate) ---
68-
# Provenance is recorded at the start of train(), before the remote call,
69-
# so we can verify deduplication even if the server-side training fails.
70-
groups2 = await art.gather_trajectory_groups(make_group(model) for _ in range(1))
84+
groups2 = await art.gather_trajectory_groups(
85+
[art.TrajectoryGroup(simple_rollout(model) for _ in range(4))] # ty: ignore[invalid-argument-type]
86+
)
7187
try:
7288
result2 = await backend.train(model, groups2)
7389
await model.log(
@@ -76,7 +92,7 @@ async def main() -> None:
7692
except RuntimeError as e:
7793
print(f"Step 2 training failed (transient server error, OK for this test): {e}")
7894

79-
provenance = run.config.get("provenance")
95+
provenance = get_latest_artifact_provenance(model.entity, model.project, model.name)
8096
print(f"After step 2: provenance = {provenance}")
8197
assert provenance == ["serverless-rl"], (
8298
f"Expected ['serverless-rl'] (no duplicate), got {provenance}"

0 commit comments

Comments
 (0)