Add FCN3 checkpoint state #923
Greptile Summary
This PR adds opt-in "full" checkpoint state to FCN3, allowing a prognostic rollout to resume from the model's internal noise states rather than reinitialising from user-facing IO. A new

| Filename | Overview |
|---|---|
| earth2studio/models/px/fcn3.py | Adds checkpoint save/restore state to FCN3: new _FCN3CheckpointState dataclass, _save_checkpoint_state and _restore_checkpoint_state helpers, and modified __call__ / _default_generator to integrate them. __call__ silently discards user input when a full checkpoint is loaded (undocumented, untested), create_iterator docstring is stale for the restored path, and a narrow CUDA device-comparison issue in checkpoint_tensor could skip a needed clone(). |
| test/models/px/test_fcn3.py | Adds PhooRestartFCN3ModelWrapper, _phoo_fcn3 helper, and test_fcn3_checkpoint_state_round_trip_with_phoo_model which validates iterator-based checkpoint round-trip. Missing test coverage for the __call__ path with a loaded checkpoint. |
Comments Outside Diff (1)
- earth2studio/models/px/fcn3.py, line 523-542: create_iterator docstring no longer accurate when checkpoint is restored
  The docstring says "Will return the initial condition first (0th step)" unconditionally, but when a "full" checkpoint is loaded, _default_generator skips the initial yield entirely: the first next() call returns step N+1 (the step after the checkpoint boundary). Users who write next(iterator) to discard the IC before looping will advance one step too far on a resumed run. The docstring (and ideally a runtime note or log line) should document the divergent behavior.
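The off-by-one described above can be reproduced with a toy generator. This is a minimal sketch, not the FCN3 code: `toy_rollout`, its `resumed` flag, and the integer step indices are hypothetical stand-ins for the real iterator and its checkpoint handling.

```python
from collections.abc import Iterator


def toy_rollout(start_step: int, resumed: bool) -> Iterator[int]:
    """Hypothetical stand-in for a prognostic iterator; yields step indices."""
    if not resumed:
        # Fresh run: the initial condition (0th step) is yielded first.
        yield start_step
    step = start_step
    while True:
        step += 1
        yield step


def take(iterator: Iterator[int], count: int) -> list[int]:
    """Collect the next `count` items from the iterator."""
    return [next(iterator) for _ in range(count)]


fresh = toy_rollout(start_step=0, resumed=False)
next(fresh)  # discards the initial condition, as the docstring implies
fresh_steps = take(fresh, 3)

resumed = toy_rollout(start_step=5, resumed=True)  # checkpoint boundary N = 5
next(resumed)  # the same habit now discards step N + 1
resumed_steps = take(resumed, 3)
```

On the fresh run the loop starts at step 1 as intended; on the resumed run the discarded item is a real forecast step, so the loop starts at step N + 2.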
Reviews (1): Last reviewed commit: "Add checkpoint state to FCN3"

    def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
        tensor = tensor.detach()
        if tensor.device == self.checkpoint.device:
            return tensor.clone()
        return tensor.to(self.checkpoint.device)

checkpoint_tensor can skip clone() on same-device CUDA tensors
When Checkpoint(device="cuda") (without an explicit index) is used and the model tensor lives on cuda:0, torch.device("cuda") != torch.device("cuda:0") evaluates to True, so the branch falls through to .to(self.checkpoint.device). Calling .to() on a tensor that is already on the target device returns self (shared storage), not a copy. Any subsequent in-place operation on the live x would then silently corrupt the saved checkpoint field. Normalising both sides to an indexed device before comparing eliminates the ambiguity.
Suggested change:

    -def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
    -    tensor = tensor.detach()
    -    if tensor.device == self.checkpoint.device:
    -        return tensor.clone()
    -    return tensor.to(self.checkpoint.device)
    +def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
    +    tensor = tensor.detach()
    +    target = self.checkpoint.device
    +    # Normalise both sides so cuda == cuda:0 comparisons work correctly.
    +    src = torch.device(tensor.device.type, tensor.device.index or 0)
    +    tgt = torch.device(target.type, target.index or 0)
    +    if src == tgt:
    +        return tensor.clone()
    +    return tensor.to(target)

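An alternative to normalising device indices is to request a copy unconditionally. The sketch below is not the PR's code: it is a standalone variant of `checkpoint_tensor` that takes the target device as an argument and relies on the `copy=True` argument of `Tensor.to`. It runs on CPU so the aliasing can be observed without a GPU. One point in its favour is that it does not assume an unindexed `cuda` device means index 0, which only holds when device 0 is the current CUDA device.

```python
import torch


def checkpoint_tensor(tensor: torch.Tensor, target: torch.device) -> torch.Tensor:
    """Return a detached copy of `tensor` on `target` that never shares storage."""
    # copy=True forces a new tensor even when device and dtype already match.
    return tensor.detach().to(target, copy=True)


live = torch.zeros(3)
target = torch.device("cpu")

aliased = live.detach().to(target)       # no conversion needed: shares storage
saved = checkpoint_tensor(live, target)  # independent storage

live.add_(1.0)  # in-place update of the live tensor

# The unindexed and indexed CUDA devices compare as different objects.
devices_equal = torch.device("cuda") == torch.device("cuda:0")
```

After the in-place update, `aliased` reflects the new values while `saved` keeps the values it had when the checkpoint was taken.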

    @@ -413,11 +490,12 @@ def __call__(
             tuple[torch.Tensor, CoordSystem]
                 Output tensor and coordinate system
             """
    -        # Initialize the internal noise states
    -        # for each batch index, we will have a list of noise states for each separate time
    -        self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
    -        output, coords = self._forward(x, coords)
    -        return output, coords
    +        x, coords, restored = self._restore_checkpoint_state(x, coords)
    +        if not restored:
    +            # Initialize the internal noise states for each batch and time.
    +            self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
    +
    +        return self._forward(x, coords)

__call__ silently ignores user-provided x and coords when a full checkpoint is loaded
When checkpoint_state_loaded=True with policy="full", _restore_checkpoint_state replaces the caller's x and coords with the checkpoint's saved state. The user's inputs are completely discarded with no warning, and the returned tensor/coords come from a different starting point than what was passed in. Direct __call__ invocations inside a with checkpoint.select(-1): block will produce silently wrong outputs. There is no test covering this path, and the method's docstring gives no hint that inputs may be overridden.
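One way to make the override visible is to warn whenever restored state replaces the caller's input. The following is a hedged sketch only: `ToyPrognostic`, its `restored_x` attribute, and the scalar "forward pass" are hypothetical and much simpler than the real wrapper. It also shows the shape a test for this path could take, since the warning can be captured and asserted on.

```python
import warnings
from typing import Optional


class ToyPrognostic:
    """Hypothetical stand-in for a model wrapper; not the FCN3 implementation."""

    def __init__(self) -> None:
        # Set to a saved input when a "full" checkpoint has been loaded.
        self.restored_x: Optional[float] = None

    def _restore_checkpoint_state(self, x: float) -> tuple[float, bool]:
        if self.restored_x is None:
            return x, False
        warnings.warn(
            "Full checkpoint loaded: caller-provided input is ignored.",
            UserWarning,
            stacklevel=3,
        )
        return self.restored_x, True

    def __call__(self, x: float) -> float:
        x, _restored = self._restore_checkpoint_state(x)
        return x + 1.0  # placeholder for the real forward pass


model = ToyPrognostic()
fresh_output = model(0.0)  # uses the caller's input

model.restored_x = 10.0  # simulate a loaded full checkpoint
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    resumed_output = model(0.0)  # input 0.0 is overridden by 10.0
```

Raising an error instead of warning would be stricter; which is appropriate depends on whether overriding the input is considered intended behavior.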
Earth2Studio Pull Request
Description
Third PR in the checkpointing stack.
Adds opt-in checkpoint state to FCN3 so that a selected full checkpoint can resume a prognostic rollout from the model's internal continuation state, rather than assuming the user-facing IO output is sufficient to reinitialize the model. The iterator consumes the restored checkpoint boundary internally and yields the next forecast state.
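Why the internal state matters can be illustrated with a toy stochastic model. This is a sketch under stated assumptions, not the FCN3 checkpoint API: `ToyStochasticModel`, `save_full`, and `load_full` are hypothetical names, and Python's `random.Random` state stands in for the model's internal noise states.

```python
import random


class ToyStochasticModel:
    """Hypothetical model whose step depends on hidden noise state."""

    def __init__(self, seed: int = 0) -> None:
        self.rng = random.Random(seed)

    def step(self, x: float) -> float:
        return x + self.rng.random()

    def save_full(self, x: float) -> dict:
        # A "full" checkpoint keeps the output and the hidden noise state.
        return {"x": x, "rng": self.rng.getstate()}

    def load_full(self, state: dict) -> float:
        self.rng.setstate(state["rng"])
        return state["x"]


def run(model: ToyStochasticModel, x: float, n_steps: int) -> list[float]:
    outputs = []
    for _ in range(n_steps):
        x = model.step(x)
        outputs.append(x)
    return outputs


reference = run(ToyStochasticModel(), 0.0, 4)  # uninterrupted rollout

first = ToyStochasticModel()
head = run(first, 0.0, 2)
checkpoint = first.save_full(head[-1])

resumed = ToyStochasticModel()
full_tail = run(resumed, resumed.load_full(checkpoint), 2)

# Resuming from the output alone restarts the noise sequence.
io_only_tail = run(ToyStochasticModel(), head[-1], 2)
```

Resuming with the full state reproduces the uninterrupted rollout exactly, while resuming from the output alone diverges because the noise sequence starts over.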
Stack
Validation
- uv run ruff check earth2studio/models/px/fcn3.py test/models/px/test_fcn3.py
- uv run pytest test/models/px/test_fcn3.py::test_fcn3_checkpoint_state_round_trip_with_phoo_model -q
- git diff --check

A full run of test/models/px/test_fcn3.py was not clean in this local environment because the optional fcn3 dependency group is not installed.
Checklist
Dependencies
Depends on #922.