Skip to content

Commit 7d8dc6d

Browse files
corbtCursor Bot
andauthored
feat(tinker): add TinkerNativeBackend (#532)
* feat: add TinkerNativeBackend for native training Separate native Tinker training/inference from LocalBackend to keep the API clear while enabling explicit loss/checkpoint behavior and config. * fix: address pre-commit type and format issues Align tinker native types with OpenAI tooling and update tests to avoid invalid type expressions under pyright. * feat: add safer state merge and policy tracking Use merge_state for backend persistence to avoid clobbering model state, and fail fast on trajectories without Choice objects to prevent no-op training. Expose policy version fields on trajectories for off-policy tracking. * feat(pipeline): add PipelineTrainer for async 3-stage training Add a new PipelineTrainer module that implements an asynchronous 3-stage pipeline (rollout, training, eval) for efficient RL training: - PipelineTrainer: Main trainer class with configurable workers, batch sizes, and off-policy limits - StatusReporter: Live progress reporting with tqdm and periodic logging - PipelineState: Shared state dataclass for stage coordination - Type definitions for RolloutFn, SingleRolloutFn, EvalFn Key features: - Async rollout workers with policy version tracking - Stale sample detection and automatic discard - Zero-variance group handling with collapse detection - Graceful signal handling (SIGINT/SIGTERM) - State persistence for training resumption - Eval scheduling with configurable intervals Also includes: - yes_no_maybe_pipeline.py: Simple example showing basic usage - binary_prefix_tool_pipeline.py: Complex example with tool calls Updates to tinker_native backend: - Add debug logging via ART_TINKER_TRAIN_LOG/ART_TINKER_SAMPLE_LOG - Add fallback for create_conversation_prefix_with_tools - Fix tool_call id handling in OpenAI server responses * fix: resolve type errors after rebasing on main with ty - Fix import path for get_free_port (moved from service to server) - Add cast for merge_state return type - Fix test to use async function for TrajectoryGroup creation - Move tinker deps to separate dependency group - Add tinker to allowed-unresolved-imports for ty --------- Co-authored-by: Cursor Bot <bot@cursor.com>
1 parent cac23d1 commit 7d8dc6d

20 files changed

Lines changed: 3074 additions & 35 deletions

pyproject.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ dependencies = [
99
"typer>=0.15.2",
1010
"litellm>=1.71.1",
1111
"weave>=0.52.23",
12-
"tinker>=0.8.1",
13-
"tinker-cookbook>=0.1.0",
1412
"polars>=1.26.0",
1513
"tblib>=3.0.0",
1614
"nest-asyncio>=1.6.0",
@@ -115,6 +113,9 @@ unused-ignore-comment = "ignore"
115113
# Allow unresolved imports for optional dependencies that may not be installed locally.
116114
# In CI, we install all optional deps so these will be resolved and type-checked.
117115
allowed-unresolved-imports = [
116+
# tinker deps
117+
"tinker.**",
118+
"tinker_cookbook.**",
118119
# backend deps
119120
"accelerate.**",
120121
"awscli.**",
@@ -165,6 +166,12 @@ dev = [
165166
"pyarrow>=15.0.0",
166167
"prek>=0.2.29",
167168
]
169+
tinker = [
170+
"fastapi>=0.128.0",
171+
"tinker>=0.8.1",
172+
"tinker-cookbook>=0.1.0",
173+
"uvicorn>=0.35.0",
174+
]
168175

169176
[tool.uv.sources]
170177
panza = { git = "https://github.com/corbt/panza.git" }

src/art/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ def __init__(self, **kwargs):
5757
from .local import LocalBackend
5858
from .model import Model, TrainableModel
5959
from .serverless import ServerlessBackend
60-
from .tinker import TinkerBackend
60+
61+
try:
62+
from .tinker import TinkerBackend
63+
from .tinker_native import TinkerNativeBackend
64+
except ModuleNotFoundError:
65+
TinkerBackend = None # type: ignore[assignment]
66+
TinkerNativeBackend = None # type: ignore[assignment]
6167
from .trajectories import Trajectory, TrajectoryGroup
6268
from .types import (
6369
LocalTrainResult,
@@ -91,9 +97,10 @@ def __init__(self, **kwargs):
9197
"retry",
9298
"TrainConfig",
9399
"TrainResult",
94-
"TinkerBackend",
95100
"Trajectory",
96101
"TrajectoryGroup",
97102
"capture_yielded_trajectory",
98103
"yield_trajectory",
99104
]
105+
if TinkerBackend is not None:
106+
__all__.extend(["TinkerBackend", "TinkerNativeBackend"])

src/art/dev/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
InternalModelConfig,
55
PeftArgs,
66
TinkerArgs,
7+
TinkerNativeArgs,
78
TinkerTrainingClientArgs,
89
TrainerArgs,
910
)
@@ -16,6 +17,7 @@
1617
"InitArgs",
1718
"PeftArgs",
1819
"TinkerArgs",
20+
"TinkerNativeArgs",
1921
"TinkerTrainingClientArgs",
2022
"TrainerArgs",
2123
"get_openai_server_config",

src/art/dev/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ class InternalModelConfig(TypedDict, total=False):
121121
engine_args: "EngineArgs"
122122
peft_args: "PeftArgs"
123123
tinker_args: "TinkerArgs | None"
124+
tinker_native_args: "TinkerNativeArgs | None"
124125
trainer_args: "TrainerArgs"
125126

126127

@@ -129,6 +130,11 @@ class TinkerArgs(TypedDict, total=False):
129130
training_client_args: "TinkerTrainingClientArgs"
130131

131132

133+
class TinkerNativeArgs(TypedDict, total=False):
134+
renderer_name: Required[str]
135+
training_client_args: "TinkerTrainingClientArgs"
136+
137+
132138
class TinkerTrainingClientArgs(TypedDict, total=False):
133139
rank: int
134140
seed: int | None

src/art/local/backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ async def register(
102102
Args:
103103
model: An art.Model instance.
104104
"""
105+
# Ensure model state/logging uses the backend path
106+
model.base_path = self._path
105107
output_dir = get_model_dir(model=model, art_path=self._path)
106108
os.makedirs(output_dir, exist_ok=True)
107109
with open(f"{output_dir}/model.json", "w") as f:

src/art/model.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,18 +261,22 @@ def _get_output_dir(self) -> str:
261261
"""Get the output directory for this model."""
262262
return f"{self.base_path}/{self.project}/models/{self.name}"
263263

264-
def write_state(self, state: StateType) -> None:
265-
"""Write persistent state to the model directory as JSON.
264+
def overwrite_state(self, state: StateType) -> None:
265+
"""Overwrite persistent state in the model directory as JSON.
266266
267267
This state is stored in `state.json` within the model's output directory
268268
and can be used to track training progress, dataset position, or any
269269
other information that should persist across runs.
270270
271+
Warning:
272+
This overwrites the entire state file. Prefer `merge_state()` unless
273+
you intentionally want to replace all existing keys.
274+
271275
Args:
272276
state: A dictionary of JSON-serializable values to persist.
273277
274278
Example:
275-
model.write_state({
279+
model.overwrite_state({
276280
"step": 5,
277281
"dataset_offset": 100,
278282
"last_checkpoint_time": "2024-01-15T10:30:00",
@@ -283,6 +287,45 @@ def write_state(self, state: StateType) -> None:
283287
with open(f"{output_dir}/state.json", "w") as f:
284288
json.dump(state, f, indent=2)
285289

290+
def write_state(self, state: StateType) -> None:
291+
"""Deprecated: use `overwrite_state()` or `merge_state()` instead."""
292+
warnings.warn(
293+
"write_state() is deprecated. Use overwrite_state() or merge_state() instead.",
294+
DeprecationWarning,
295+
stacklevel=2,
296+
)
297+
self.overwrite_state(state)
298+
299+
def merge_state(self, state: StateType) -> StateType:
300+
"""Deep-merge state into the existing state and persist it.
301+
302+
Args:
303+
state: A dictionary of JSON-serializable values to merge.
304+
305+
Returns:
306+
The merged state dictionary that was persisted.
307+
"""
308+
existing = self.read_state() or {}
309+
merged = self._deep_merge_dicts(existing, state)
310+
self.overwrite_state(merged)
311+
return cast(StateType, merged)
312+
313+
@staticmethod
314+
def _deep_merge_dicts(
315+
base: dict[str, Any], updates: dict[str, Any]
316+
) -> dict[str, Any]:
317+
merged = dict(base)
318+
for key, value in updates.items():
319+
if (
320+
key in merged
321+
and isinstance(merged[key], dict)
322+
and isinstance(value, dict)
323+
):
324+
merged[key] = Model._deep_merge_dicts(merged[key], value)
325+
else:
326+
merged[key] = value
327+
return merged
328+
286329
def read_state(self) -> StateType | None:
287330
"""Read persistent state from the model directory.
288331
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .status import StatusReporter
2+
from .trainer import PipelineTrainer, make_group_rollout_fn
3+
from .types import EvalFn, RolloutFn, ScenarioT, SingleRolloutFn
4+
5+
__all__ = [
6+
"PipelineTrainer",
7+
"make_group_rollout_fn",
8+
"StatusReporter",
9+
"RolloutFn",
10+
"SingleRolloutFn",
11+
"EvalFn",
12+
"ScenarioT",
13+
]

0 commit comments

Comments
 (0)