Commit 1ff0a9a

feat: Add Tinker support to LocalBackend
1 parent 2daa845 commit 1ff0a9a

14 files changed

Lines changed: 1192 additions & 1172 deletions


.python-version

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-3.10
+3.11

dev/yes-no-maybe.ipynb

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.13"
+   "version": "3.11.13"
   }
  },
  "nbformat": 4,

dev/yes-no-maybe.py

Lines changed: 14 additions & 8 deletions
@@ -40,19 +40,25 @@ def with_quotes(w: str) -> str:
 async def main():
     load_dotenv()
 
-    backend = LocalBackend()
+    backend = LocalBackend(in_process=True)
     global model
     base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507")
     model = art.TrainableModel(
-        name=os.environ.get("MODEL_NAME", "011"),
+        name=os.environ.get("MODEL_NAME", "012"),
         project="yes-no-maybe",
         base_model=base_model,
         _internal_config=art.dev.InternalModelConfig(
-            engine_args=art.dev.EngineArgs(
-                max_lora_rank=1,
-            ),
-            peft_args=art.dev.PeftArgs(
-                r=1,
+            # engine_args=art.dev.EngineArgs(
+            #     max_lora_rank=1,
+            # ),
+            # peft_args=art.dev.PeftArgs(
+            #     r=1,
+            # ),
+            tinker_args=art.dev.TinkerArgs(
+                renderer_name="qwen3_instruct",
+                training_client_args=art.dev.TinkerTrainingClientArgs(
+                    rank=1,
+                ),
             ),
         ),
     )
@@ -68,7 +74,7 @@ async def main():
     ]
 
     openai_client = model.openai_client()
-    max_steps = int(os.environ.get("NUM_STEPS", "4"))
+    max_steps = int(os.environ.get("NUM_STEPS", "20"))
     start_step = await model.get_step()
     for _ in range(start_step, start_step + max_steps):
         train_groups = await art.gather_trajectory_groups(
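
For context, the piece of this script that activates the new backend path: presence of `tinker_args` in the internal config is what makes `LocalBackend` select the Tinker service (see the `_get_service` dispatch in src/art/local/backend.py below), while `in_process=True` keeps that service in the calling process instead of a child process. A minimal sketch, using only identifiers from this diff:

import art

# tinker_args replaces the commented-out engine/peft LoRA settings; its
# presence routes this model to TinkerService in LocalBackend.
internal_config = art.dev.InternalModelConfig(
    tinker_args=art.dev.TinkerArgs(
        renderer_name="qwen3_instruct",
        training_client_args=art.dev.TinkerTrainingClientArgs(rank=1),
    ),
)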

pyproject.toml

Lines changed: 7 additions & 4 deletions
@@ -3,12 +3,16 @@ name = "openpipe-art"
 version = "0.5.4"
 description = "The OpenPipe Agent Reinforcement Training (ART) library"
 readme = "README.md"
-requires-python = ">=3.10"
+requires-python = ">=3.11"
 dependencies = [
     "openai>=1.65.5",
     "typer>=0.15.2",
     "litellm==1.74.1",
     "weave>=0.51.51",
+    "tinker>=0.7.0",
+    "tinker-cookbook>=0.1.0",
+    "polars>=1.26.0",
+    "tblib>=3.0.0",
 ]
 
 [project.optional-dependencies]
@@ -27,10 +31,9 @@ backend = [
     "accelerate==1.7.0",
     "awscli>=1.38.1",
     "setproctitle>=1.3.6",
-    "tblib>=3.0.0",
+
     "setuptools>=78.1.0",
     "wandb==0.22.1",
-    "polars>=1.26.0",
     "transformers>=4.55.2,<=4.57.3",
     "duckdb>=1.0.0",
     "pyarrow>=15.0.0",
@@ -91,7 +94,7 @@ select = ["I"]
 [tool.ruff.lint.isort]
 case-sensitive = false
 known-first-party = ["art"]
-known-third-party = ["wandb"]
+known-third-party = ["tinker", "wandb"]
 force-sort-within-sections = true
 
 [tool.pytest.ini_options]
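
Note that `requires-python` now floors at 3.11, matching the `.python-version` bump above, and that `tinker`, `tinker-cookbook`, `polars`, and `tblib` land in the core dependency list (the latter two moving out of the `backend` extra), presumably so the Tinker path works without that extra. For loose scripts run outside an installer, a trivial guard mirroring the new floor might look like this (purely illustrative; pip already enforces it at install time):

import sys

# Mirrors pyproject's requires-python = ">=3.11".
if sys.version_info < (3, 11):
    raise RuntimeError("openpipe-art requires Python 3.11 or newer")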

src/art/dev/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,8 @@
     InitArgs,
     InternalModelConfig,
     PeftArgs,
+    TinkerArgs,
+    TinkerTrainingClientArgs,
     TrainerArgs,
 )
 from .openai_server import OpenAIServerConfig, ServerArgs, get_openai_server_config
@@ -14,6 +16,8 @@
     "InternalModelConfig",
     "InitArgs",
     "PeftArgs",
+    "TinkerArgs",
+    "TinkerTrainingClientArgs",
     "TrainerArgs",
     "get_openai_server_config",
     "OpenAIServerConfig",

src/art/dev/get_model_config.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def get_model_config(
         init_args=init_args,
         engine_args=engine_args,
         peft_args=peft_args,
+        tinker_args=config.get("tinker_args"),
         trainer_args=trainer_args,
         torchtune_args=torchtune_args,
     )

src/art/dev/model.py

Lines changed: 20 additions & 2 deletions
@@ -1,6 +1,6 @@
 from enum import Enum
 
-from typing_extensions import TypedDict
+from typing_extensions import Required, TypedDict
 
 from .engine import EngineArgs
 from .torchtune import TorchtuneArgs
@@ -112,17 +112,35 @@ class InternalModelConfig(TypedDict, total=False):
 
     Args:
         init: Arguments for initializing an Unsloth FastLanguageModel.
+        engine: Arguments for the vLLM engine.
         peft: Arguments for creating an Unsloth PEFT model wrapper.
-        train: Arguments for the GRPO trainer.
+        tinker: Arguments for the Tinker training client.
+        trainer: Arguments for the GRPO trainer.
+        torchtune: Arguments for TorchTune.
     """
 
     init_args: "InitArgs"
     engine_args: "EngineArgs"
     peft_args: "PeftArgs"
+    tinker_args: "TinkerArgs | None"
     trainer_args: "TrainerArgs"
     torchtune_args: TorchtuneArgs | None
 
 
+class TinkerArgs(TypedDict, total=False):
+    renderer_name: Required[str]
+    training_client_args: "TinkerTrainingClientArgs"
+
+
+class TinkerTrainingClientArgs(TypedDict, total=False):
+    rank: int
+    seed: int | None
+    train_mlp: bool
+    train_attn: bool
+    train_unembed: bool
+    user_metadata: dict[str, str] | None
+
+
 class InitArgs(TypedDict, total=False):
     model_name: str
     max_seq_length: int
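
One subtlety worth calling out: `TinkerArgs` is declared with `total=False` but marks `renderer_name` as `Required`, so a type checker treats that single key as mandatory while `training_client_args` and every `TinkerTrainingClientArgs` field stay optional. A minimal, self-contained sketch of how that plays out (classes abridged from this diff; the example values are assumptions):

from typing_extensions import Required, TypedDict


class TinkerTrainingClientArgs(TypedDict, total=False):
    rank: int
    seed: int | None


class TinkerArgs(TypedDict, total=False):
    renderer_name: Required[str]  # mandatory despite total=False
    training_client_args: TinkerTrainingClientArgs


minimal: TinkerArgs = {"renderer_name": "qwen3_instruct"}
full: TinkerArgs = {
    "renderer_name": "qwen3_instruct",
    "training_client_args": {"rank": 1, "seed": 42},
}
# empty: TinkerArgs = {}  # a type checker flags the missing renderer_name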

src/art/local/backend.py

Lines changed: 21 additions & 8 deletions
@@ -129,19 +129,29 @@ async def register(
 
     async def _get_service(self, model: TrainableModel) -> ModelService:
         from ..dev.get_model_config import get_model_config
-        from ..torchtune.service import TorchtuneService
-        from ..unsloth.service import UnslothService
 
         if model.name not in self._services:
             config = get_model_config(
                 base_model=model.base_model,
                 output_dir=get_model_dir(model=model, art_path=self._path),
                 config=model._internal_config,
             )
-            if config.get("torchtune_args") is not None:
+            is_tinker = config.get("tinker_args") is not None
+            if is_tinker:
+                from ..tinker.service import TinkerService
+
+                service_class = TinkerService
+            elif config.get("torchtune_args") is not None:
+                from ..torchtune.service import TorchtuneService
+
                 service_class = TorchtuneService
             else:
+                from ..unsloth.service import UnslothService
+
                 service_class = UnslothService
+                # When moving the service to a child process, import unsloth
+                # early to maximize optimizations
+                os.environ["IMPORT_UNSLOTH"] = "1"
             self._services[model.name] = service_class(
                 model_name=model.name,
                 base_model=model.base_model,
@@ -151,12 +161,9 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
         if not self._in_process:
             # Kill all "model-service" processes to free up GPU memory
             subprocess.run(["pkill", "-9", "model-service"])
-            # When moving the service to a child process, import unsloth
-            # early to maximize optimizations
-            os.environ["IMPORT_UNSLOTH"] = "1"
             self._services[model.name] = move_to_child_process(
                 self._services[model.name],
-                process_name="model-service",
+                process_name="tinker-service" if is_tinker else "model-service",
             )
         return self._services[model.name]
 
@@ -242,6 +249,8 @@ async def _delete_checkpoints(
         benchmark: str,
         benchmark_smoothing: float,
     ) -> None:
+        from ..tinker.service import TinkerService
+
         output_dir = get_model_dir(model=model, art_path=self._path)
         # Keep the latest step
         steps_to_keep = [get_model_step(model, self._path)]
@@ -261,7 +270,11 @@ async def _delete_checkpoints(
             print(f'"{output_dir}/history.jsonl" not found')
         except pl.exceptions.ColumnNotFoundError:
             print(f'No "{benchmark}" metric found in history')
-        delete_checkpoints(output_dir, steps_to_keep)
+        service = await self._get_service(model)
+        if isinstance(service, TinkerService):
+            await service.delete_checkpoints(steps_to_keep)
+        else:
+            delete_checkpoints(output_dir, steps_to_keep)
 
     async def _prepare_backend_for_training(
         self,
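
The refactored `_get_service` moves each service import inside its branch, so only the selected backend's heavyweight dependencies are loaded, and `IMPORT_UNSLOTH` is now set only on the unsloth path rather than for every child process. A stripped-down sketch of the branch precedence, with plain strings standing in for the real service classes:

def pick_service(config: dict) -> str:
    # Same precedence as _get_service: Tinker wins over torchtune,
    # and unsloth is the fallback.
    if config.get("tinker_args") is not None:
        return "TinkerService"
    if config.get("torchtune_args") is not None:
        return "TorchtuneService"
    return "UnslothService"


assert pick_service({"tinker_args": {"renderer_name": "qwen3_instruct"}}) == "TinkerService"
assert pick_service({"torchtune_args": {}}) == "TorchtuneService"
assert pick_service({}) == "UnslothService"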

src/art/loss.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@ class Loss(BaseModel):
     mean_policy_loss: torch.Tensor
     mean_kl: torch.Tensor
     mean_entropy: torch.Tensor | None
+    policy_loss_sum: torch.Tensor
     probs_corr: torch.Tensor
 
 
@@ -135,6 +136,7 @@ def loss_fn(
         mean_policy_loss=mean_policy_loss,
         mean_kl=mean_kl,
         mean_entropy=mean_entropy,
+        policy_loss_sum=policy_loss.sum(),
         probs_corr=probs_corr,
     )
 
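
The new `policy_loss_sum` field carries the unreduced total alongside the existing per-batch mean: sums can be renormalized exactly across batches with different token counts, which per-batch means alone cannot. A tiny illustration (the tensor values are made up, not the real `loss_fn` internals):

import torch

# Per-token policy losses for one batch (values are illustrative).
policy_loss = torch.tensor([0.5, 0.25, 0.25, 1.0])

mean_policy_loss = policy_loss.mean()  # normalized by this batch's 4 tokens
policy_loss_sum = policy_loss.sum()    # leaves the normalization to the caller

# Given sums and token counts from several batches, the exact global mean
# is sum(sums) / sum(counts); combining per-batch means is only correct
# when every batch has the same length.
assert torch.isclose(policy_loss_sum / policy_loss.numel(), mean_policy_loss)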
src/art/preprocessing/inputs.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+import torch
+
+from .pack import PackedTensors
+
+if TYPE_CHECKING:
+    from .. import dev, types
+
+
+class TrainInputs(PackedTensors):
+    """Training inputs with config attached."""
+
+    config: "types.TrainConfig"
+    _config: "dev.TrainConfig"
+    return_new_logprobs: bool
+
+
+def create_train_inputs(
+    packed_tensors: PackedTensors,
+    offset: int,
+    config: "types.TrainConfig",
+    _config: "dev.TrainConfig",
+    warmup: bool,
+) -> TrainInputs:
+    """Create TrainInputs for a single batch offset."""
+    return TrainInputs(
+        **{
+            k: (
+                v[offset : offset + 1, :1024]
+                if warmup and v.dim() > 1
+                else v[offset : offset + 1]
+            )
+            for k, v in packed_tensors.items()
+            if isinstance(v, torch.Tensor)
+        },
+        pixel_values=(
+            [None] if warmup else packed_tensors["pixel_values"][offset : offset + 1]
+        ),
+        image_grid_thw=(
+            [None] if warmup else packed_tensors["image_grid_thw"][offset : offset + 1]
+        ),
+        config=(
+            config.model_copy(update={"lr": 1e-9, "beta": 0.0, "kl_coef": 0.0})
+            if warmup
+            else config
+        ),
+        _config=_config,
+        return_new_logprobs=False,
+    )
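
In this new helper, `warmup=True` builds a deliberately cheap batch: 2-D tensors are capped at their first 1024 positions, image inputs are replaced with `[None]`, and the config is copied with `lr=1e-9` and zeroed `beta`/`kl_coef`, so a warmup step exercises the training graph without meaningfully updating weights. A rough sketch of just the tensor-slicing behavior (the dict below stands in for PackedTensors; its key names are assumptions):

import torch

# Stand-in for the tensor entries of a PackedTensors mapping.
packed = {
    "tokens": torch.zeros(4, 8192, dtype=torch.long),
    "advantages": torch.zeros(4, 8192),
}

offset, warmup = 0, True
batch = {
    k: v[offset : offset + 1, :1024] if warmup and v.dim() > 1 else v[offset : offset + 1]
    for k, v in packed.items()
    if isinstance(v, torch.Tensor)
}
assert batch["tokens"].shape == (1, 1024)  # warmup rows are capped at 1024 tokens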
