Skip to content

Commit 1775016

Browse files
corbt and Cursor Bot authored
feat: add persistent state methods to Model (#522)
* feat: add persistent state methods to Model

  Add write_state() and read_state() methods to art.Model that persist arbitrary JSON state to the model's output directory (state.json).

  Features:
  - StateType type parameter with default dict[str, Any] for type safety
  - write_state(state: StateType) - persists state as JSON
  - read_state() -> StateType | None - reads state, returns None if not found
  - Full backward compatibility: existing Model[MyConfig] syntax still works
  - Optional type safety: Model[MyConfig, MyState] enforces state type

  This enables filesystem-based state tracking for training resumption, dataset position, and other metadata that was previously tracked via wandb.

* fix: resolve pyright type checking errors

  - Add pyright: ignore[reportInconsistentOverload] to __new__ overloads
  - Update Backend method signatures to use AnyModel/AnyTrainableModel type aliases
  - Add type: ignore comment for return value in Model.__new__

---------

Co-authored-by: Cursor Bot <bot@cursor.com>
1 parent 003bf38 commit 1775016

2 files changed

Lines changed: 70 additions & 24 deletions

File tree

src/art/backend.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal
2+
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal, TypeAlias
33
import warnings
44

55
import httpx
@@ -20,6 +20,10 @@
2020
if TYPE_CHECKING:
2121
from .model import Model, TrainableModel
2222

23+
# Type aliases for models with any config/state type (for backend method signatures)
24+
AnyModel: TypeAlias = "Model[Any, Any]"
25+
AnyTrainableModel: TypeAlias = "TrainableModel[Any, Any]"
26+
2327

2428
class Backend:
2529
def __init__(
@@ -39,7 +43,7 @@ async def close(self) -> None:
3943

4044
async def register(
4145
self,
42-
model: "Model",
46+
model: AnyModel,
4347
) -> None:
4448
"""
4549
Registers a model with the Backend for logging and/or training.
@@ -50,14 +54,14 @@ async def register(
5054
response = await self._client.post("/register", json=model.safe_model_dump())
5155
response.raise_for_status()
5256

53-
async def _get_step(self, model: "TrainableModel") -> int:
57+
async def _get_step(self, model: AnyTrainableModel) -> int:
5458
response = await self._client.post("/_get_step", json=model.safe_model_dump())
5559
response.raise_for_status()
5660
return response.json()
5761

5862
async def _delete_checkpoint_files(
5963
self,
60-
model: "TrainableModel",
64+
model: AnyTrainableModel,
6165
steps_to_keep: list[int],
6266
) -> None:
6367
response = await self._client.post(
@@ -68,7 +72,7 @@ async def _delete_checkpoint_files(
6872

6973
async def _prepare_backend_for_training(
7074
self,
71-
model: "TrainableModel",
75+
model: AnyTrainableModel,
7276
config: dev.OpenAIServerConfig | None,
7377
) -> tuple[str, str]:
7478
response = await self._client.post(
@@ -80,7 +84,7 @@ async def _prepare_backend_for_training(
8084
base_url, api_key = tuple(response.json())
8185
return base_url, api_key
8286

83-
def _model_inference_name(self, model: "Model", step: int | None = None) -> str:
87+
def _model_inference_name(self, model: AnyModel, step: int | None = None) -> str:
8488
"""Return the inference name for a model checkpoint.
8589
8690
Override in subclasses to provide backend-specific naming.
@@ -93,7 +97,7 @@ def _model_inference_name(self, model: "Model", step: int | None = None) -> str:
9397

9498
async def train(
9599
self,
96-
model: "TrainableModel",
100+
model: AnyTrainableModel,
97101
trajectory_groups: Iterable[TrajectoryGroup],
98102
**kwargs: Any,
99103
) -> TrainResult:
@@ -114,7 +118,7 @@ async def train(
114118

115119
async def _train_model(
116120
self,
117-
model: "TrainableModel",
121+
model: AnyTrainableModel,
118122
trajectory_groups: list[TrajectoryGroup],
119123
config: TrainConfig,
120124
dev_config: dev.TrainConfig,
@@ -152,7 +156,7 @@ async def _train_model(
152156
@log_http_errors
153157
async def _experimental_pull_from_s3(
154158
self,
155-
model: "Model",
159+
model: AnyModel,
156160
*,
157161
s3_bucket: str | None = None,
158162
prefix: str | None = None,
@@ -191,7 +195,7 @@ async def _experimental_pull_from_s3(
191195
@log_http_errors
192196
async def _experimental_push_to_s3(
193197
self,
194-
model: "Model",
198+
model: AnyModel,
195199
*,
196200
s3_bucket: str | None = None,
197201
prefix: str | None = None,
@@ -215,7 +219,7 @@ async def _experimental_push_to_s3(
215219
@log_http_errors
216220
async def _experimental_fork_checkpoint(
217221
self,
218-
model: "Model",
222+
model: AnyModel,
219223
from_model: str,
220224
from_project: str | None = None,
221225
from_s3_bucket: str | None = None,

src/art/model.py

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from datetime import datetime
22
import json
33
import os
4-
from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload
4+
from typing import TYPE_CHECKING, Any, Generic, Iterable, Optional, cast, overload
55
import warnings
66

77
import httpx
88
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
99
import polars as pl
1010
from pydantic import BaseModel
11-
from typing_extensions import Never
11+
from typing_extensions import Never, TypeVar
1212

1313
from . import dev
1414
from .trajectories import Trajectory, TrajectoryGroup
@@ -23,11 +23,12 @@
2323

2424

2525
ModelConfig = TypeVar("ModelConfig", bound=BaseModel | None)
26+
StateType = TypeVar("StateType", bound=dict[str, Any], default=dict[str, Any])
2627

2728

2829
class Model(
2930
BaseModel,
30-
Generic[ModelConfig],
31+
Generic[ModelConfig, StateType],
3132
):
3233
"""
3334
A model is an object that can be passed to your `rollout` function, and used
@@ -129,7 +130,7 @@ def __new__(
129130
inference_model_name: str | None = None,
130131
base_path: str = ".art",
131132
report_metrics: list[str] | None = None,
132-
) -> "Model[None]": ...
133+
) -> "Model[None, dict[str, Any]]": ...
133134

134135
@overload
135136
def __new__(
@@ -145,14 +146,14 @@ def __new__(
145146
inference_model_name: str | None = None,
146147
base_path: str = ".art",
147148
report_metrics: list[str] | None = None,
148-
) -> "Model[ModelConfig]": ...
149+
) -> "Model[ModelConfig, dict[str, Any]]": ...
149150

150-
def __new__(
151+
def __new__( # pyright: ignore[reportInconsistentOverload]
151152
cls,
152153
*args,
153154
**kwargs,
154-
) -> "Model[ModelConfig] | Model[None]":
155-
return super().__new__(cls)
155+
) -> "Model[ModelConfig, StateType]":
156+
return super().__new__(cls) # type: ignore[return-value]
156157

157158
def safe_model_dump(self, *args, **kwargs) -> dict:
158159
"""
@@ -260,6 +261,47 @@ def _get_output_dir(self) -> str:
260261
"""Get the output directory for this model."""
261262
return f"{self.base_path}/{self.project}/models/{self.name}"
262263

264+
def write_state(self, state: "StateType") -> None:
    """Write persistent state to the model directory as JSON.

    The state is stored in `state.json` within the model's output directory
    (created if it does not exist) and can be used to track training
    progress, dataset position, or any other information that should
    persist across runs.

    The write is crash-safe: the state is serialized in memory first, then
    written to a sibling temporary file which is moved over `state.json`
    with `os.replace`. A failed serialization or an interrupted write
    therefore never truncates or corrupts previously saved state.

    Args:
        state: A dictionary of JSON-serializable values to persist.

    Raises:
        TypeError: If `state` contains a value `json.dumps` cannot serialize.
        ValueError: If `json.dumps` rejects `state` (e.g. circular reference).
        OSError: If the directory or file cannot be written.

    Example:
        model.write_state({
            "step": 5,
            "dataset_offset": 100,
            "last_checkpoint_time": "2024-01-15T10:30:00",
        })
    """
    # Serialize before touching the filesystem so an unserializable state
    # cannot leave a half-written file behind.
    payload = json.dumps(state, indent=2)

    output_dir = self._get_output_dir()
    os.makedirs(output_dir, exist_ok=True)
    state_path = f"{output_dir}/state.json"
    tmp_path = f"{state_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        # Same-directory rename; replaces any existing state.json in one step
        # (atomic on POSIX per the os.replace documentation).
        os.replace(tmp_path, state_path)
    except BaseException:
        # Best-effort cleanup of the temporary file; the original error wins.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
285+
286+
def read_state(self) -> "StateType | None":
    """Read persistent state from the model directory.

    Loads `state.json` from the model's output directory and parses it as
    JSON. The file is opened directly rather than checked for existence
    first, so a file that disappears between a check and the open cannot
    cause an unexpected error.

    Returns:
        The state dictionary if it exists, or None if no state has been saved.

    Raises:
        json.JSONDecodeError: If `state.json` exists but is not valid JSON.

    Example:
        state = model.read_state()
        if state is not None:
            start_step = state["step"]
            dataset_offset = state["dataset_offset"]
    """
    state_path = f"{self._get_output_dir()}/state.json"
    try:
        with open(state_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        # Missing file, missing directory, or a non-directory path component:
        # all mean "no state saved", matching the previous exists() check.
        return None
304+
263305
def _get_wandb_run(self) -> Optional["Run"]:
264306
"""Get or create the wandb run for this model."""
265307
import wandb
@@ -429,7 +471,7 @@ async def get_step(self) -> int:
429471
# ---------------------------------------------------------------------------
430472

431473

432-
class TrainableModel(Model[ModelConfig], Generic[ModelConfig]):
474+
class TrainableModel(Model[ModelConfig, StateType], Generic[ModelConfig, StateType]):
433475
base_model: str
434476
# Override discriminator field for FastAPI serialization
435477
trainable: bool = True
@@ -480,7 +522,7 @@ def __new__(
480522
base_path: str = ".art",
481523
report_metrics: list[str] | None = None,
482524
_internal_config: dev.InternalModelConfig | None = None,
483-
) -> "TrainableModel[None]": ...
525+
) -> "TrainableModel[None, dict[str, Any]]": ...
484526

485527
@overload
486528
def __new__(
@@ -495,13 +537,13 @@ def __new__(
495537
base_path: str = ".art",
496538
report_metrics: list[str] | None = None,
497539
_internal_config: dev.InternalModelConfig | None = None,
498-
) -> "TrainableModel[ModelConfig]": ...
540+
) -> "TrainableModel[ModelConfig, dict[str, Any]]": ...
499541

500-
def __new__(
542+
def __new__( # pyright: ignore[reportInconsistentOverload]
501543
cls,
502544
*args,
503545
**kwargs,
504-
) -> "TrainableModel[ModelConfig] | TrainableModel[None]":
546+
) -> "TrainableModel[ModelConfig, StateType]":
505547
return super().__new__(cls) # type: ignore
506548

507549
def model_dump(self, *args, **kwargs) -> dict:

0 commit comments

Comments
 (0)