@@ -658,12 +658,7 @@ def _log_metrics(
658658
659659 # If we have a W&B run, log the data there
660660 if run := self ._get_wandb_run (model ):
661- # Mark the step metric itself as hidden so W&B doesn't create an automatic chart for it
662- wandb .define_metric ("training_step" , hidden = True )
663-
664- # Enabling the following line will cause W&B to use the training_step metric as the x-axis for all metrics
665- # wandb.define_metric(f"{split}/*", step_metric="training_step")
666- run .log ({"training_step" : step , ** metrics }, step = step )
661+ run .log ({"training_step" : step , ** metrics })
667662
668663 def _get_wandb_run (self , model : Model ) -> Run | None :
669664 if "WANDB_API_KEY" not in os .environ :
@@ -688,6 +683,12 @@ def _get_wandb_run(self, model: Model) -> Run | None:
688683 ),
689684 )
690685 self ._wandb_runs [model .name ] = run
686+
687+ # Define training_step as the x-axis for all metrics.
688+ # This allows out-of-order logging (e.g., async validation for previous steps).
689+ wandb .define_metric ("training_step" )
690+ wandb .define_metric ("train/*" , step_metric = "training_step" )
691+ wandb .define_metric ("val/*" , step_metric = "training_step" )
691692 os .environ ["WEAVE_PRINT_CALL_LINK" ] = os .getenv (
692693 "WEAVE_PRINT_CALL_LINK" , "False"
693694 )
0 commit comments