Skip to content

Add FCN3 checkpoint state#923

Open
NickGeneva wants to merge 11 commits into
codex/checkpoint-gaussianfrom
codex/checkpoint-fcn3
Open

Add FCN3 checkpoint state#923
NickGeneva wants to merge 11 commits into
codex/checkpoint-gaussianfrom
codex/checkpoint-fcn3

Conversation

@NickGeneva

@NickGeneva NickGeneva commented Jun 15, 2026

Copy link
Copy Markdown
Collaborator

Earth2Studio Pull Request

Description

Third PR in the checkpointing stack.

Adds checkpoint opt-in state to FCN3 so a selected full checkpoint can resume a prognostic rollout from the model's internal continuation state rather than assuming user-facing IO output is sufficient to reinitialize the model. The iterator consumes the restored checkpoint boundary internally and yields the next forecast state.

Stack

  1. Add checkpoint utilities and workflow support #912: checkpoint utilities and built-in workflow support
  2. Add Gaussian perturbation checkpoint state #922: Gaussian perturbation checkpoint state
  3. This PR: FCN3 checkpoint state
  4. Add model checkpoint update developer skill #924: developer skill for adding model checkpoint support

Validation

  • uv run ruff check earth2studio/models/px/fcn3.py test/models/px/test_fcn3.py
  • uv run pytest test/models/px/test_fcn3.py::test_fcn3_checkpoint_state_round_trip_with_phoo_model -q
  • git diff --check

Full test/models/px/test_fcn3.py was not clean in this local environment because the optional fcn3 dependency group is not installed.

Checklist

Dependencies

Depends on #922.

@greptile-apps

greptile-apps Bot commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds opt-in "full" checkpoint state to FCN3, allowing a prognostic rollout to resume from the model's internal noise states rather than reinitialising from user-facing IO. A new _FCN3CheckpointState dataclass is introduced, and _save_checkpoint_state / _restore_checkpoint_state helpers are wired into both __call__ and _default_generator.

  • _default_generator conditionally skips the initial IC yield when a full checkpoint is loaded, resuming the while loop directly from the saved tensor; a new round-trip test in test_fcn3.py validates this path using a PhooRestartFCN3ModelWrapper that makes the output state-dependent.
  • __call__ also honours the restored checkpoint state, but the user-provided x and coords are silently replaced by the checkpoint's saved values with no documentation or test coverage for this code path.
  • The create_iterator docstring still promises "Will return the initial condition first (0th step)" unconditionally, which is no longer true for a resumed run.

Confidence Score: 3/5

Safe to merge for the iterator use-case that the PR targets, but the direct __call__ path with a loaded checkpoint silently drops user input and is untested, and the create_iterator docstring now mis-describes behavior on a resumed run.

The iterator round-trip is well-tested and the core checkpoint save/restore logic is sound. However, __call__ silently ignores caller-supplied tensors when a full checkpoint is loaded — a behavior change with no documentation and no test — and the stale create_iterator docstring could mislead users who call next(iterator) to skip the IC before looping, causing them to advance one step too far on a resumed run.

earth2studio/models/px/fcn3.py — the __call__ restore path and the create_iterator docstring both need attention before broader adoption of the checkpoint API.

Important Files Changed

Filename Overview
earth2studio/models/px/fcn3.py Adds checkpoint save/restore state to FCN3: new _FCN3CheckpointState dataclass, _save_checkpoint_state and _restore_checkpoint_state helpers, and modified __call__ / _default_generator to integrate them. __call__ silently discards user input when a full checkpoint is loaded (undocumented, untested), create_iterator docstring is stale for the restored path, and a narrow CUDA device-comparison issue in checkpoint_tensor could skip a needed clone().
test/models/px/test_fcn3.py Adds PhooRestartFCN3ModelWrapper, _phoo_fcn3 helper, and test_fcn3_checkpoint_state_round_trip_with_phoo_model which validates iterator-based checkpoint round-trip. Missing test coverage for the __call__ path with a loaded checkpoint.

Comments Outside Diff (1)

  1. earth2studio/models/px/fcn3.py, line 523-542 (link)

    P1 create_iterator docstring no longer accurate when checkpoint is restored

    The docstring says "Will return the initial condition first (0th step)" unconditionally, but when a "full" checkpoint is loaded _default_generator skips the initial yield entirely — the first next() call returns step N+1 (the step after the checkpoint boundary). Users who write next(iterator) to discard the IC before looping will advance one step too far on a resumed run. The docstring (and ideally a runtime note or log line) should document the divergent behavior.

Reviews (1): Last reviewed commit: "Add checkpoint state to FCN3" | Re-trigger Greptile

Comment on lines +368 to +372
def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
tensor = tensor.detach()
if tensor.device == self.checkpoint.device:
return tensor.clone()
return tensor.to(self.checkpoint.device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 checkpoint_tensor can skip clone() on same-device CUDA tensors

When Checkpoint(device="cuda") (without an explicit index) is used and the model tensor lives on cuda:0, torch.device("cuda") != torch.device("cuda:0") evaluates to True, so the branch falls through to .to(self.checkpoint.device). Calling .to() on a tensor that is already on the target device returns self (shared storage), not a copy. Any subsequent in-place operation on the live x would then silently corrupt the saved checkpoint field. Normalising both sides to an indexed device before comparing eliminates the ambiguity.

Suggested change
def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
tensor = tensor.detach()
if tensor.device == self.checkpoint.device:
return tensor.clone()
return tensor.to(self.checkpoint.device)
def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
tensor = tensor.detach()
target = self.checkpoint.device
# Normalise both sides so cuda == cuda:0 comparisons work correctly.
src = torch.device(tensor.device.type, tensor.device.index or 0)
tgt = torch.device(target.type, target.index or 0)
if src == tgt:
return tensor.clone()
return tensor.to(target)

Comment on lines 473 to +498
@@ -413,11 +490,12 @@ def __call__(
tuple[torch.Tensor, CoordSystem]
Output tensor and coordinate system
"""
# Initialize the internal noise states
# for each batch index, we will have a list of noise states for each separate time
self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
output, coords = self._forward(x, coords)
return output, coords
x, coords, restored = self._restore_checkpoint_state(x, coords)
if not restored:
# Initialize the internal noise states for each batch and time.
self._reset_internal_state(len(coords["batch"]), len(coords["time"]))

return self._forward(x, coords)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 __call__ silently ignores user-provided x and coords when a full checkpoint is loaded

When checkpoint_state_loaded=True with policy="full", _restore_checkpoint_state replaces the caller's x and coords with the checkpoint's saved state. The user's inputs are completely discarded with no warning, and the returned tensor/coords come from a different starting point than what was passed in. Direct __call__ invocations inside a with checkpoint.select(-1): block will produce silently wrong outputs. There is no test covering this path, and the method's docstring gives no hint that inputs may be overridden.

@NickGeneva NickGeneva force-pushed the codex/checkpoint-gaussian branch from 689c708 to 3c95e3f Compare June 15, 2026 17:31
@NickGeneva NickGeneva force-pushed the codex/checkpoint-fcn3 branch from ada5c50 to 29da9e2 Compare June 15, 2026 17:31
@NickGeneva NickGeneva force-pushed the codex/checkpoint-gaussian branch from 3c95e3f to ffbc639 Compare June 15, 2026 17:44
@NickGeneva NickGeneva force-pushed the codex/checkpoint-fcn3 branch from 29da9e2 to 2cd1502 Compare June 15, 2026 17:45
@NickGeneva NickGeneva force-pushed the codex/checkpoint-gaussian branch from ffbc639 to abd36dc Compare June 15, 2026 17:48
@NickGeneva NickGeneva force-pushed the codex/checkpoint-fcn3 branch from 2cd1502 to b610945 Compare June 15, 2026 17:48
@NickGeneva NickGeneva force-pushed the codex/checkpoint-gaussian branch from abd36dc to c4b148f Compare June 15, 2026 17:51
@NickGeneva NickGeneva force-pushed the codex/checkpoint-fcn3 branch from b610945 to ae5b4b3 Compare June 15, 2026 17:51
@NickGeneva NickGeneva force-pushed the codex/checkpoint-gaussian branch from c4b148f to d82244e Compare June 15, 2026 17:54
@NickGeneva NickGeneva force-pushed the codex/checkpoint-fcn3 branch from ae5b4b3 to a745cdc Compare June 15, 2026 17:54
@NickGeneva NickGeneva force-pushed the codex/checkpoint-gaussian branch from d82244e to 53c611c Compare June 15, 2026 17:56
@NickGeneva NickGeneva force-pushed the codex/checkpoint-fcn3 branch from a745cdc to 494846c Compare June 15, 2026 17:56
@NickGeneva NickGeneva force-pushed the codex/checkpoint-gaussian branch from 53c611c to 99c112a Compare June 15, 2026 22:29
@NickGeneva NickGeneva force-pushed the codex/checkpoint-fcn3 branch from 494846c to 6cd8881 Compare June 15, 2026 22:30
@NickGeneva NickGeneva force-pushed the codex/checkpoint-gaussian branch from 99c112a to 8eaa079 Compare June 26, 2026 01:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant