Skip to content

Commit abecd93

Browse files
authored
Fix W&B run reuse across ART models (#618)
* Fix wandb run reuse across models
* Remove ART wandb regression test
1 parent 4cf171d commit abecd93

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

src/art/model.py

Lines changed: 17 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -470,6 +470,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
470470
id=self.name,
471471
config=self._wandb_config or None,
472472
resume="allow",
473+
reinit="create_new",
473474
settings=wandb.Settings(
474475
x_stats_open_metrics_endpoints={
475476
"vllm": "http://localhost:8000/metrics",
@@ -492,18 +493,18 @@ def _get_wandb_run(self) -> Optional["Run"]:
492493

493494
# Define training_step as the x-axis for all metrics.
494495
# This allows out-of-order logging (e.g., async validation for previous steps).
495-
wandb.define_metric("training_step")
496-
wandb.define_metric("time/wall_clock_sec")
497-
wandb.define_metric("reward/*", step_metric="training_step")
498-
wandb.define_metric("loss/*", step_metric="training_step")
499-
wandb.define_metric("throughput/*", step_metric="training_step")
500-
wandb.define_metric("costs/*", step_metric="training_step")
501-
wandb.define_metric("time/*", step_metric="training_step")
502-
wandb.define_metric("data/*", step_metric="training_step")
503-
wandb.define_metric("train/*", step_metric="training_step")
504-
wandb.define_metric("val/*", step_metric="training_step")
505-
wandb.define_metric("test/*", step_metric="training_step")
506-
wandb.define_metric("discarded/*", step_metric="training_step")
496+
run.define_metric("training_step")
497+
run.define_metric("time/wall_clock_sec")
498+
run.define_metric("reward/*", step_metric="training_step")
499+
run.define_metric("loss/*", step_metric="training_step")
500+
run.define_metric("throughput/*", step_metric="training_step")
501+
run.define_metric("costs/*", step_metric="training_step")
502+
run.define_metric("time/*", step_metric="training_step")
503+
run.define_metric("data/*", step_metric="training_step")
504+
run.define_metric("train/*", step_metric="training_step")
505+
run.define_metric("val/*", step_metric="training_step")
506+
run.define_metric("test/*", step_metric="training_step")
507+
run.define_metric("discarded/*", step_metric="training_step")
507508
self._sync_wandb_config(run)
508509
return self._wandb_run
509510

@@ -562,14 +563,16 @@ def _log_metrics(
562563
run.log(prefixed)
563564

564565
def _define_wandb_step_metrics(self, keys: Iterable[str]) -> None:
565-
import wandb
566+
run = self._wandb_run
567+
if run is None or run._is_finished:
568+
return
566569

567570
for key in keys:
568571
if not key.startswith("costs/"):
569572
continue
570573
if key in self._wandb_defined_metrics:
571574
continue
572-
wandb.define_metric(key, step_metric="training_step")
575+
run.define_metric(key, step_metric="training_step")
573576
self._wandb_defined_metrics.add(key)
574577

575578
def _route_metrics_and_collect_non_costs(

0 commit comments

Comments
 (0)