Separate prognostic and boundary tensors end-to-end #701

Merged: alxmrs merged 8 commits into main from u/alxmrs/kr2/1-split-tensors on Apr 18, 2026

@alxmrs alxmrs commented Apr 15, 2026

Member

Thread prognostic and boundary inputs as separate tensors through the data pipeline, stepper, base model, and all model implementations. Previously they were concatenated along the channel dim before reaching the model; now each model's forward_once receives (prog, boundary, ctx).

FOMO concatenates before its single-stream encoder (the dual-perceiver encoder that enables cross-resolution fusion lands in a follow-up). Samudra and FOMini concatenate inside their forward_once as before.
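
As a shape-level illustration of the convention described above, here is a minimal pure-Python sketch. It is not the repository's code: `concat_channels` and the channel counts are hypothetical, and it assumes `prog` and `boundary` are laid out as (batch, channel, lat, lon) and agree on every dimension except the channel one.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def concat_channels(prog: Shape, boundary: Shape, channel_dim: int = 1) -> Shape:
    """Shape that results from concatenating two tensors along `channel_dim`."""
    if len(prog) != len(boundary):
        raise ValueError("prog and boundary must have the same rank")
    for dim, (p, b) in enumerate(zip(prog, boundary)):
        # Concatenation only tolerates a size difference on the channel axis.
        if dim != channel_dim and p != b:
            raise ValueError(f"dim {dim} differs: {p} vs {b}")
    out = list(prog)
    out[channel_dim] = prog[channel_dim] + boundary[channel_dim]
    return tuple(out)


# Hypothetical sizes: 8 prognostic channels, 3 boundary channels.
prog_shape = (1, 8, 180, 360)
boundary_shape = (1, 3, 180, 360)

# A model that wants the old fused layout concatenates inside its own
# forward_once instead of receiving an already-fused input.
fused_shape = concat_channels(prog_shape, boundary_shape)
print(fused_shape)  # (1, 11, 180, 360)
```

Keeping the two streams separate until the model decides how to fuse them is what leaves room for the follow-up encoder mentioned above; only models that concatenate need the spatial shapes to match.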

This is part 1 of 3 of #681, which has been broken up into a stack of three small PRs.

Also includes unrelated changes from previously-merged PRs that were
on this branch: DataLoadingConfig (#668), Nsight profiling (#674),
samudra highres configs (#679), torch profiling improvements (#677).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

# Conflicts:
#	src/ocean_emulators/stepper.py
#	tests/test_stepper.py
@alxmrs alxmrs force-pushed the u/alxmrs/kr2/1-split-tensors branch from 3392629 to 93bca31 Compare April 15, 2026 02:40

alxmrs commented Apr 15, 2026

Member Author

@codex may I have a review of this PR?

@alxmrs alxmrs marked this pull request as ready for review April 15, 2026 02:46
@alxmrs alxmrs requested a review from jder April 15, 2026 02:46

@jder jder left a comment

Member

A few suggestions but generally LGTM! Love the getting rid of Stepper. I kinda wish Claude did not leave so many vague comments about matching "existing behavior" but 🤷

epoch=None,
) -> ModelInferenceOutput:
out_shape = (num_steps, *dataset[0][1].shape[1:])
# `dataset[idx]` returns `(prog, boundary, label)`.

Member

<3

Comment thread src/ocean_emulators/datasets.py Outdated
Comment thread src/ocean_emulators/models/base.py Outdated
)
+ decodings
)
pred = prog_tensor.to(device=get_device()) + decodings

Member

We seem to have lost the [0, ... part from the old code. IIUC the prog_tensor has a batch dimension which we're trying to get rid of here before adding decodings. I guess we don't have inference tests that would catch this? Or am I missing something?

Member Author

I know this is covered by tests because when I removed the unexpand(0) a test failed. I tried two variants (the status quo and one adhering to the original version) and both worked. Merging this version.

Member

@fomo-bot can you help me understand why this version works given my above comments?

Member

@fomo-bot ok, can you please try again to help me understand why this version works given my above comments?

Collaborator

Short version: your intuition about the batch dim was right. The version without [0] works because this inference path is batch-1 and PyTorch is forgiving about leading singleton dims here, not because the batch dim truly disappeared.

The shapes are:

  • prog_tensor / boundary_tensor: [1, C, H, W] during rollout (InferenceDataset returns a singleton window_dim).
  • decodings = forward_once(...): [1, C, H, W].
  • pred_tensor[step]: [C, H, W], because pred_tensor is allocated from dataset[0][-1].shape[1:].

The old code did input_tensor[0, :out_channels] + decodings. input_tensor[0, ...] is [C, H, W], but adding batched decodings broadcasts the result back to [1, C, H, W]. The version you flagged, prog_tensor + decodings, is also [1, C, H, W] and has the same values because the batch size is 1. Then pred_tensor[step] = pred accepts/stores that leading singleton dimension into the [C, H, W] slot.

So both variants pass for batch-1 inference. The current branch’s prog_tensor[0] + decodings is closer to the old residual-base selection, but it still produces a batched temporary because decodings is batched. If we want the shape intent to be explicit, I would make it prog_tensor[0] + decodings[0] (or squeeze once before assignment) and add an inference residual test; the existing rollout tests mostly cover pred_residuals=False, while residual behavior is covered more directly in validation.
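
The shape reasoning above can be checked without a tensor library. The sketch below is a pure-Python model, not PyTorch itself: `broadcast_shape` implements the standard right-aligned broadcasting rule, and `can_assign` encodes the assumption (taken from the explanation above) that indexed assignment drops leading size-1 dimensions of the source before copying. The sizes `C, H, W = 4, 6, 8` are hypothetical.

```python
from itertools import zip_longest
from typing import Tuple

Shape = Tuple[int, ...]


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """Result shape of an elementwise op under right-aligned broadcasting."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def can_assign(slot: Shape, src: Shape) -> bool:
    """Assumed assignment rule: strip leading 1s from `src`, then broadcast."""
    src_dims = list(src)
    while len(src_dims) > len(slot) and src_dims[0] == 1:
        src_dims.pop(0)
    if len(src_dims) > len(slot):
        return False
    try:
        return broadcast_shape(slot, tuple(src_dims)) == slot
    except ValueError:
        return False


C, H, W = 4, 6, 8                      # hypothetical sizes
prog = decodings = (1, C, H, W)        # batch-1 rollout inputs and outputs
slot = (C, H, W)                       # shape of pred_tensor[step]

old_variant = broadcast_shape(prog[1:], decodings)      # input[0] + decodings
flagged_variant = broadcast_shape(prog, decodings)      # prog + decodings
explicit_variant = broadcast_shape(prog[1:], decodings[1:])

print(old_variant, flagged_variant, explicit_variant)
# (1, 4, 6, 8) (1, 4, 6, 8) (4, 6, 8)
print(can_assign(slot, old_variant), can_assign(slot, explicit_variant))
# True True

# With a batch of 2 the old-style sum no longer fits the unbatched slot.
batched = broadcast_shape((C, H, W), (2, C, H, W))
print(batched, can_assign(slot, batched))  # (2, 4, 6, 8) False
```

Under this model only the explicit variant produces a `[C, H, W]` temporary; the other two depend on the batch size being exactly 1.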

Collaborator

Posted the reply on the PR thread: https://github.com/Open-Athena/Ocean_Emulator/pull/701#discussion_r3118214309

Summary: jder’s batch-dimension concern was right. The version works because inference is batch-1, the old code also broadcast the unbatched residual against batched decodings, and PyTorch accepts the leading singleton dim when assigning into pred_tensor[step].

I also ran `uv run pytest tests/test_stepper.py -q`: 72 passed. No code changes made.

Codex work

Codex transcript

- I’ll inspect the PR discussion and the exact diff around that review comment so I can answer the “why this works” question from the code, not from guesses.
- I found the thread: the key hunk is the residual addition in `BaseModel.inference`, and the branch has later commits that changed that line again. I’m going to inspect the current file plus the rollout tests to separate “why no crash” from “whether the shape is actually t...[truncated]
- The current branch now has `prog_tensor[0] + decodings`, while the older review hunk had `prog_tensor + decodings`. I’m checking the intermediate commits because the explanation differs: one shape is intentional, the other only survives because of PyTorch broadcasting and assig...[truncated]
- The plain `python` executable is not installed here, so I’m using `python3` for a small shape check.
- `python3` exists but this shell environment does not have `torch` importable. I’ll rely on the code paths and, if there is a project environment configured, run the targeted tests through that rather than the bare interpreter.
- One subtlety here is that the old code also did not fully remove the batch dimension from the residual result, because `decodings` is still batched. I’m reading the model outputs now; that determines whether the explanation is “same values, different temporary rank” or whet...[truncated]
- The shape check confirms the surprising part: PyTorch will accept a source like `[1, C, H, W]` when assigning into `pred_tensor[step]` shaped `[C, H, W]` because the extra leading dimension is size 1. I’m running the local stepper tests now to see what this branch actually cove...[truncated]
- I have enough to answer the thread. I’m going to post a concise explanation: your batch-dimension concern is valid, and the reason the version passes is singleton batch broadcasting plus PyTorch accepting that singleton source when writing into the unbatched rollout buffer.
- The first GitHub API endpoint form returned 404, so I’m retrying with the pull-number-scoped review-comment reply endpoint.
- Posted the reply on the PR thread: https://github.com/Open-Athena/Ocean_Emulator/pull/701#discussion_r3118214309

Summary: jder’s batch-dimension concern was right. The version works because inference is batch-1, the old code also broadcast the unbatched residual against batche...[truncated]

Comment thread src/ocean_emulators/datasets.py Outdated
return prognostic_steps
return rearrange(
steps, "batch time variable lat lon -> batch (time variable) lat lon"
)
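
For readers unfamiliar with the pattern in this hunk, the pure-Python sketch below shows what folding `time` and `variable` into one channel axis does to indices. It assumes the usual einops convention that, in a group such as `(time variable)`, the left-most name varies slowest. `flat_channel`, `fold_time_into_channels`, and the sizes are hypothetical, and each string stands in for one lat/lon field.

```python
def flat_channel(t: int, v: int, n_variable: int) -> int:
    """Channel index of (time step t, variable v) after the fold."""
    return t * n_variable + v


def fold_time_into_channels(steps):
    """steps[b][t][v] -> out[b][t * n_variable + v]; lat/lon are untouched."""
    return [
        [field for time_step in sample for field in time_step]
        for sample in steps
    ]


# batch=1, time=2, variable=3
steps = [[["t0_v0", "t0_v1", "t0_v2"],
          ["t1_v0", "t1_v1", "t1_v2"]]]

folded = fold_time_into_channels(steps)
print(len(folded[0]))                   # 6
print(folded[0][flat_channel(1, 2, 3)])  # t1_v2
```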

Member

Nice cleanup!

Member Author

Thanks, I was happy about this too.

Comment thread src/ocean_emulators/stepper.py Outdated
Comment thread src/ocean_emulators/stepper.py Outdated
@alxmrs alxmrs merged commit 027c144 into main Apr 18, 2026
9 checks passed
@alxmrs alxmrs deleted the u/alxmrs/kr2/1-split-tensors branch April 18, 2026 00:07
@fomo-bot

Collaborator

Sorry, I hit an internal failure while handling this mention and couldn't complete the request.

Please mention @fomo-bot again to retry.

14 similar comments

26 similar comments

YuanYuan98 added a commit that referenced this pull request Apr 27, 2026
Resolves conflicts with main following PRs #700 (Samudra 2 paper site),
#701 (separate prognostic/boundary tensors), and #539 (ARM PyTorch 2.9).

- pages.yml: keep our Zensical build pipeline; relocate docs/samudra2/
  under docs/static/samudra2/ so the existing static-copy step publishes
  it to /samudra2/. Delete docs/{index.html,404.html,.nojekyll} from
  #700 — superseded by docs/static/index.html (strict superset adding
  the API Docs card).
- stepper.py: take main's module-level train_batch / validate_batch /
  run_rollout functions; lift docstring to module level so
  mkdocstrings still renders it.
- uv.lock: regenerated to include docs deps on top of main's PyTorch
  2.9 / CUDA 13 ARM index.

Verified locally: zensical build succeeds; /, /docs/, /samudra2/ all
return 200.

Finished by Claude Code

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

3 participants