Skip to content

Parallel eval#772

Open
amogh-gulati wants to merge 7 commits into
mainfrom
parallel_eval
Open

Parallel eval#772
amogh-gulati wants to merge 7 commits into
mainfrom
parallel_eval

Conversation

@amogh-gulati

@amogh-gulati amogh-gulati commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

Post-train checkpoint eval/viz sweep

Adds an optional post-training sweep that evaluates selected checkpoints after training completes, always including the final EMA checkpoint, with optional visualization generation.

  • Add post_train_eval.py to discover checkpoints in saved_nets/, shard eval jobs across available GPUs, run standalone inference per checkpoint, and write summary.json.
  • Support either last_n_checkpoints: N or an explicit checkpoints: [...] epoch list, with validation for mutually exclusive settings and missing checkpoints.
  • Edit Trainer.finish() to launch the sweep on the main process via a new post_train_eval config block.
  • Add a DDP barrier and destroy_process_group() before the sweep so all training ranks finish cleanly and release GPUs. (still seems a little flacky)
  • Refactor viz to prepare groundtruth once and reuse it across checkpoint visualizations.
  • Support standalone usage via python -m samudra.post_train_eval for models already trained.

Notes:

  • Viz sweep requires eval.save_zarr=true.
  • In the current config, checkpoints: [50, 55, 60, 65, 70] evaluates those epochs plus the final EMA checkpoint.

#Config

Add a post_train_eval block to the train config to enable the post-training checkpoint sweep:

post_train_eval:
enabled: true
eval_config_path: configs//eval.yaml # required when enabled
viz_config_path: configs//viz.yaml # optional; omit to skip viz

Pick exactly one checkpoint selection mode:
checkpoints: [50, 55, 60, 65, 70] # explicit checkpoint epochs
last_n_checkpoints: 10 # or evaluate the last N checkpoints

The final EMA checkpoint is always included automatically.
wandb is disabled for this, will a follow up PR I will add all the evals in one run (ideally with timeline scrub) which is not tested yet!

@amogh-gulati amogh-gulati requested a review from jder June 25, 2026 21:29
@jder

jder commented Jun 26, 2026

Copy link
Copy Markdown
Member

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 0d7bf27bd7

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread src/samudra/train.py
Comment on lines +1148 to +1149
if main_process:
run_post_train_checkpoint_sweep(self.cfg, self.ckpt_paths)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Free rank-0 GPU before spawning eval sweep

When post_train_eval is enabled after CUDA/DDP training, this call starts the sweep while the Trainer object still owns the rank-0 model, optimizer, EMA state, and dataloader tensors on GPU 0. Destroying the process group releases the other ranks but does not free rank-0 CUDA allocations, and _resolve_worker_count() assigns an eval worker to GPU 0, so large runs can OOM when that worker loads a second model/rollout on the same occupied GPU; delete/move the CUDA state and empty the cache, or run the sweep after the training process exits.

Useful? React with 👍 / 👎.

Comment on lines +380 to +382
prepared_groundtruth = template_cfg.prepare_groundtruth(
LocalLocation(path=Path.cwd())
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Resolve viz ground truth against the training data root

When the sweep is launched from training and the viz config leaves data_root unset with a relative groundtruth_location (the shipped viz configs use OM4.zarr this way), eval gets train_cfg.experiment.data_root but viz prepares ground truth relative to Path.cwd() instead. That makes post-train viz fail by looking under the repo/launch directory, or worse use a different local dataset; pass the training data root into the viz config/default root before preparing ground truth.

Useful? React with 👍 / 👎.

raise ValueError(
f"last_n_checkpoints must be >= 1, got {last_n_checkpoints}"
)
targets = targets[-last_n_checkpoints:]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Slice periodic checkpoints before appending EMA

When last_n_checkpoints is set, the final EMA entry has already been appended before this slice runs, so last_n_checkpoints: 10 evaluates only 9 periodic checkpoints plus final_ema (and 1 evaluates only EMA). The config/commit describes the EMA checkpoint as always added in addition to the selected checkpoints, so users silently get fewer saved epochs than requested; apply the limit to periodic targets before appending EMA.

Useful? React with 👍 / 👎.

@jder jder left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very excited about this. I didn't read everything in detail but a few high-level comments first.

I also have one question about the goal: why are we causing the existing per-GPU torch processes to exit and then re-spawning them? I think will cause us to drop down to a single host for running evals/viz when doing multi-host training. Can we use the existing per-GPU processes instead (and run post_train_eval via slurm/torchrun as we do training when we want it to be standalone?). Basically just have the existing non-main processes go into a worker loop and distribute work via torch.distributed?

Comment on lines +406 to +408
eval_config_path = (
Path(train_cfg.post_train_eval.eval_config_path).expanduser().resolve()
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we need expanduser()/resolve() here given it's in Config.from_yaml_and_cli

logger.warning("No checkpoints selected for post-train eval sweep")
return []

eval_cfg = EvalConfig.from_yaml_and_cli(list(eval_config_args))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably nicer to have a path here instead and expose a Config.from_yaml()

cfg = VizConfig.model_validate(updated)

start = time.perf_counter()
run_with_prepared_groundtruth(cfg, prepared_groundtruth)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this config surgery is a bit of a smell and somewhat fragile. How about:

  • Create a new top-level config VizTemplateConfig with base_output_dir, dataset_name, variables, data_root, etc. But not runs or name. It builds a VizTemplate which is basically the same content as PreparedVizGroundtruth today.
  • In run_checkpoint_sweep we load a VizTemplateConfig and call its build method to get a VizTemplate.
  • In this code we call viz_template.instantiate(output_path, runs) or something to get a Viz.
  • VizConfig now extends VizTemplateConfig and adds runs + name. (VizConfig.build now calls super.build and then calls instantiate on that)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Alternatively, if you want to just land this without the viz stuff that works for me too and we can deal with this as a follow-up.)

Comment thread src/samudra/train.py
cfg.prepare_output_dirs()
cfg.save_yaml(cfg.experiment.output_dir / "config.yaml")

self.cfg = cfg

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern we're trying to move towards (though I see we have failed to actually write this down anywhere) is to avoid passing around cfg values like this (and below), and instead turn the cfg types into "inflated" or "ready to run" types with build(...) as early as possible, passing in needed dependencies. Trainer is the main exception to this pattern since it predates the config system and is a beast to refactor in that way. Would it be possible to avoid saving the cfg here and to apply this pattern to the PostTrainCheckpointSweepConfig type, passing in the needed extra data from TrainConfig to the build method (e.g. the experiment output directory) here? (Perhaps producing a CheckpointSweep instance which has all the needed information to run such a thing?)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants