Separate prognostic and boundary tensors end-to-end #701
Thread prognostic and boundary inputs as separate tensors through the data pipeline, stepper, base model, and all model implementations. Previously they were concatenated along the channel dim before reaching the model; now each model's forward_once receives (prog, boundary, ctx).

FOMO concatenates before its single-stream encoder (the dual-perceiver encoder that enables cross-resolution fusion lands in a follow-up). Samudra and FOMini concatenate inside their forward_once as before.

Also includes unrelated changes from previously-merged PRs that were on this branch: DataLoadingConfig (#668), Nsight profiling (#674), samudra highres configs (#679), torch profiling improvements (#677).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

# Conflicts:
# src/ocean_emulators/stepper.py
# tests/test_stepper.py
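To make the new calling convention concrete, here is a minimal sketch of a model that takes the two tensors separately and concatenates them inside forward_once, as the description says Samudra and FOMini do. The class name `ConcatInsideModel`, the channel counts, and the 1x1 convolution body are hypothetical stand-ins for illustration, not code from this repository.

```python
import torch
from torch import nn


class ConcatInsideModel(nn.Module):
    """Hypothetical stand-in for a model that concatenates inside forward_once."""

    def __init__(self, prog_channels: int, boundary_channels: int) -> None:
        super().__init__()
        # A 1x1 convolution stands in for the real network body (assumption).
        self.body = nn.Conv2d(
            prog_channels + boundary_channels, prog_channels, kernel_size=1
        )

    def forward_once(self, prog, boundary, ctx=None):
        # Inputs arrive as separate tensors; this model joins them
        # along the channel dimension (dim=1) before its own layers.
        x = torch.cat([prog, boundary], dim=1)
        return self.body(x)


model = ConcatInsideModel(prog_channels=4, boundary_channels=2)
prog = torch.randn(1, 4, 8, 16)      # [batch, prognostic channels, lat, lon]
boundary = torch.randn(1, 2, 8, 16)  # [batch, boundary channels, lat, lon]
out = model.forward_once(prog, boundary)
```

Because the concatenation now lives inside the model, a model with a different encoder can consume `prog` and `boundary` at different resolutions without changing the caller.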
3392629 to 93bca31
@codex may I have a review of this PR?
jder left a comment
A few suggestions but generally LGTM! Love the getting rid of Stepper. I kinda wish Claude did not leave so many vague comments about matching "existing behavior" but 🤷
epoch=None,
) -> ModelInferenceOutput:
out_shape = (num_steps, *dataset[0][1].shape[1:])
# `dataset[idx]` returns `(prog, boundary, label)`.
)
+ decodings
)
pred = prog_tensor.to(device=get_device()) + decodings
We seem to have lost the [0, ... part from the old code. IIUC the prog_tensor has a batch dimension which we're trying to get rid of here before adding decodings. I guess we don't have inference tests that would catch this? Or am I missing something?
I know this is covered by tests because when I removed the unexpand(0) a test failed. I tried two variants (the status quo and one adhering to the original version) and both worked. Merging this version.
@fomo-bot can you help me understand why this version works given my above comments?
@fomo-bot ok, can you please try again to help me understand why this version works given my above comments?
Short version: your intuition about the batch dim was right. The version without [0] works because this inference path is batch-1 and PyTorch is forgiving about leading singleton dims here, not because the batch dim truly disappeared.
The shapes are:
- `prog_tensor` / `boundary_tensor`: `[1, C, H, W]` during rollout (`InferenceDataset` returns a singleton `window_dim`).
- `decodings = forward_once(...)`: `[1, C, H, W]`.
- `pred_tensor[step]`: `[C, H, W]`, because `pred_tensor` is allocated from `dataset[0][-1].shape[1:]`.
The old code did input_tensor[0, :out_channels] + decodings. input_tensor[0, ...] is [C, H, W], but adding batched decodings broadcasts the result back to [1, C, H, W]. The version you flagged, prog_tensor + decodings, is also [1, C, H, W] and has the same values because the batch size is 1. Then pred_tensor[step] = pred accepts/stores that leading singleton dimension into the [C, H, W] slot.
So both variants pass for batch-1 inference. The current branch’s prog_tensor[0] + decodings is closer to the old residual-base selection, but it still produces a batched temporary because decodings is batched. If we want the shape intent to be explicit, I would make it prog_tensor[0] + decodings[0] (or squeeze once before assignment) and add an inference residual test; the existing rollout tests mostly cover pred_residuals=False, while residual behavior is covered more directly in validation.
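The shape argument above can be checked in isolation. The sketch below uses made-up sizes and random tensors in place of the real rollout state; only the names `prog_tensor`, `decodings`, and `pred_tensor` come from the thread, and nothing here calls the project's code.

```python
import torch

C, H, W = 3, 4, 5   # illustrative sizes, not the project's
num_steps = 2

prog_tensor = torch.randn(1, C, H, W)   # batched input, batch size 1
decodings = torch.randn(1, C, H, W)     # batched model output
pred_tensor = torch.zeros(num_steps, C, H, W)  # unbatched rollout buffer

# Old-style residual: [C, H, W] + [1, C, H, W] broadcasts back to [1, C, H, W].
old_style = prog_tensor[0] + decodings
# Flagged variant: both operands batched, result is [1, C, H, W].
flagged = prog_tensor + decodings
# Explicit variant: both operands unbatched, result is [C, H, W].
explicit = prog_tensor[0] + decodings[0]

# Indexed assignment accepts the leading singleton dimension of the source.
pred_tensor[0] = flagged
pred_tensor[1] = explicit

print(old_style.shape, flagged.shape, explicit.shape)
print(torch.equal(pred_tensor[0], pred_tensor[1]))
```

All three expressions hold the same values when the batch size is 1; only the explicit variant has the rank of the slot it is written into, which is why it states the intent most clearly.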
Posted the reply on the PR thread: https://github.com/Open-Athena/Ocean_Emulator/pull/701#discussion_r3118214309
Summary: jder’s batch-dimension concern was right. The version works because inference is batch-1, the old code also broadcast the unbatched residual against batched decodings, and PyTorch accepts the leading singleton dim when assigning into pred_tensor[step].
I also ran `uv run pytest tests/test_stepper.py -q`: 72 passed. No code changes made.
Codex work
- Repository: Open-Athena/Ocean_Emulator
- Trigger: https://github.com/Open-Athena/Ocean_Emulator/pull/701#discussion_r3118190421
- Request ID: 99478a69-2ede-44bf-812c-dd8efc1d40de
- Accepted: true
Codex transcript
- I’ll inspect the PR discussion and the exact diff around that review comment so I can answer the “why this works” question from the code, not from guesses.
- I found the thread: the key hunk is the residual addition in `BaseModel.inference`, and the branch has later commits that changed that line again. I’m going to inspect the current file plus the rollout tests to separate “why no crash” from “whether the shape is actually t...[truncated]
- The current branch now has `prog_tensor[0] + decodings`, while the older review hunk had `prog_tensor + decodings`. I’m checking the intermediate commits because the explanation differs: one shape is intentional, the other only survives because of PyTorch broadcasting and assig...[truncated]
- The plain `python` executable is not installed here, so I’m using `python3` for a small shape check.
- `python3` exists but this shell environment does not have `torch` importable. I’ll rely on the code paths and, if there is a project environment configured, run the targeted tests through that rather than the bare interpreter.
- One subtlety here is that the old code also did not fully remove the batch dimension from the residual result, because `decodings` is still batched. I’m reading the model outputs now; that determines whether the explanation is “same values, different temporary rank” or whet...[truncated]
- The shape check confirms the surprising part: PyTorch will accept a source like `[1, C, H, W]` when assigning into `pred_tensor[step]` shaped `[C, H, W]` because the extra leading dimension is size 1. I’m running the local stepper tests now to see what this branch actually cove...[truncated]
- I have enough to answer the thread. I’m going to post a concise explanation: your batch-dimension concern is valid, and the reason the version passes is singleton batch broadcasting plus PyTorch accepting that singleton source when writing into the unbatched rollout buffer.
- The first GitHub API endpoint form returned 404, so I’m retrying with the pull-number-scoped review-comment reply endpoint.
return prognostic_steps
return rearrange(
steps, "batch time variable lat lon -> batch (time variable) lat lon"
)
Thanks, I was happy about this too.
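For readers unfamiliar with the pattern in the hunk above, `"batch time variable lat lon -> batch (time variable) lat lon"` folds the time and variable axes into a single channel axis with time as the slower-varying index. The sketch below reproduces that layout with plain `torch.reshape` rather than einops, so it is an equivalent stand-in and not the project's code; the sizes are made up.

```python
import torch

batch, time, variable, lat, lon = 2, 3, 4, 5, 6  # illustrative sizes

steps = torch.arange(
    batch * time * variable * lat * lon, dtype=torch.float32
).reshape(batch, time, variable, lat, lon)

# Fold (time, variable) into one channel axis; time is the major index,
# so channel c corresponds to t = c // variable and v = c % variable.
flat = steps.reshape(batch, time * variable, lat, lon)

t, v = 1, 2
same = torch.equal(flat[:, t * variable + v], steps[:, t, v])
print(flat.shape, same)
```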
Co-authored-by: Jesse Rusak <jesse@openathena.ai>
|
Sorry, I hit an internal failure while handling this mention and couldn't complete the request. Please mention |
Resolves conflicts with main following PRs #700 (Samudra 2 paper site), #701 (separate prognostic/boundary tensors), and #539 (ARM PyTorch 2.9).

- pages.yml: keep our Zensical build pipeline; relocate docs/samudra2/ under docs/static/samudra2/ so the existing static-copy step publishes it to /samudra2/. Delete docs/{index.html,404.html,.nojekyll} from #700, superseded by docs/static/index.html (strict superset adding the API Docs card).
- stepper.py: take main's module-level train_batch / validate_batch / run_rollout functions; lift docstring to module level so mkdocstrings still renders it.
- uv.lock: regenerated to include docs deps on top of main's PyTorch 2.9 / CUDA 13 ARM index.

Verified locally: zensical build succeeds; /, /docs/, /samudra2/ all return 200.

Finished by Claude Code

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This is #681 broken up into a stack of three small PRs. 1/3