Skip to content

Commit 90ec50c

Browse files
corbt and Cursor Bot authored
Vendor tinker_cookbook under art package (#535)
* feat: add TinkerNativeBackend for native training Separate native Tinker training/inference from LocalBackend to keep the API clear while enabling explicit loss/checkpoint behavior and config. * feat(pipeline): add PipelineTrainer for async 3-stage training Add a new PipelineTrainer module that implements an asynchronous 3-stage pipeline (rollout, training, eval) for efficient RL training: - PipelineTrainer: Main trainer class with configurable workers, batch sizes, and off-policy limits - StatusReporter: Live progress reporting with tqdm and periodic logging - PipelineState: Shared state dataclass for stage coordination - Type definitions for RolloutFn, SingleRolloutFn, EvalFn Key features: - Async rollout workers with policy version tracking - Stale sample detection and automatic discard - Zero-variance group handling with collapse detection - Graceful signal handling (SIGINT/SIGTERM) - State persistence for training resumption - Eval scheduling with configurable intervals Also includes: - yes_no_maybe_pipeline.py: Simple example showing basic usage - binary_prefix_tool_pipeline.py: Complex example with tool calls Updates to tinker_native backend: - Add debug logging via ART_TINKER_TRAIN_LOG/ART_TINKER_SAMPLE_LOG - Add fallback for create_conversation_prefix_with_tools - Fix tool_call id handling in OpenAI server responses * feat: vendor tinker_cookbook for tool calls Vendor renderer utilities to keep tool-call parsing and OpenAI message conversion consistent without relying on a git dependency, and wire server/backend pipelines through renderer conversions. * fix: use tinker_cookbook_v wrapper for vendored imports Add a wrapper package that forces the vendored tinker_cookbook on sys.path and switch ART imports to the new name to avoid picking up the installed package. 
* refactor: move vendored tinker_cookbook under art package Relocate vendored cookbook code into art.tinker_cookbook_v and remove the old top-level vendored directory to avoid sys.path manipulation. * chore: fix ruff import order in tinker native test Normalize import spacing to satisfy ruff's sorting rules. * refactor: relocate vendored cookbook under art.tinker.cookbook_v Move the vendored tinker-cookbook files into the art.tinker.cookbook_v subpackage and update import paths to match. --------- Co-authored-by: Cursor Bot <bot@cursor.com>
1 parent 7d8dc6d commit 90ec50c

25 files changed

Lines changed: 3980 additions & 899 deletions

pyproject.toml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ langgraph = [
5151
"langgraph>=0.6.2",
5252
"langchain-openai>=0.3.27",
5353
]
54+
tinker = [
55+
"fastapi>=0.128.0",
56+
"huggingface_hub",
57+
"numpy",
58+
"pillow",
59+
"pydantic>=2.12.5",
60+
"tinker>=0.8.1",
61+
"torch>=2.8.0",
62+
"transformers>=4.55.2,<=4.57.3",
63+
"uvicorn>=0.35.0",
64+
]
5465

5566
[project.scripts]
5667
art = "art.cli:app"
@@ -115,7 +126,6 @@ unused-ignore-comment = "ignore"
115126
allowed-unresolved-imports = [
116127
# tinker deps
117128
"tinker.**",
118-
"tinker_cookbook.**",
119129
# backend deps
120130
"accelerate.**",
121131
"awscli.**",
@@ -166,12 +176,6 @@ dev = [
166176
"pyarrow>=15.0.0",
167177
"prek>=0.2.29",
168178
]
169-
tinker = [
170-
"fastapi>=0.128.0",
171-
"tinker>=0.8.1",
172-
"tinker-cookbook>=0.1.0",
173-
"uvicorn>=0.35.0",
174-
]
175179

176180
[tool.uv.sources]
177181
panza = { git = "https://github.com/corbt/panza.git" }

src/art/__init__.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,6 @@ def __init__(self, **kwargs):
5757
from .local import LocalBackend
5858
from .model import Model, TrainableModel
5959
from .serverless import ServerlessBackend
60-
61-
try:
62-
from .tinker import TinkerBackend
63-
from .tinker_native import TinkerNativeBackend
64-
except ModuleNotFoundError:
65-
TinkerBackend = None # type: ignore[assignment]
66-
TinkerNativeBackend = None # type: ignore[assignment]
6760
from .trajectories import Trajectory, TrajectoryGroup
6861
from .types import (
6962
LocalTrainResult,
@@ -102,5 +95,3 @@ def __init__(self, **kwargs):
10295
"capture_yielded_trajectory",
10396
"yield_trajectory",
10497
]
105-
if TinkerBackend is not None:
106-
__all__.extend(["TinkerBackend", "TinkerNativeBackend"])

src/art/pipeline_trainer/binary_prefix_tool_pipeline.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88
import re
99
from typing import Any, cast
10+
import uuid
1011

1112
from dotenv import load_dotenv
1213
from openai.types.chat.chat_completion_tool_choice_option_param import (
@@ -16,6 +17,7 @@
1617
import polars as pl
1718

1819
import art
20+
from art.tinker_native import TinkerNativeBackend
1921

2022
from . import PipelineTrainer, make_group_rollout_fn
2123

@@ -178,6 +180,8 @@ async def main() -> None:
178180
"BASE_MODEL", "Qwen/Qwen3-4B-Instruct-2507"
179181
) # Qwen/Qwen3-30B-A3B-Instruct-2507
180182
model_name = os.environ.get("MODEL_NAME", "pipeline-binary-prefix-tool")
183+
run_suffix = os.environ.get("RUN_SUFFIX") or uuid.uuid4().hex[:8]
184+
model_name = f"{model_name}-{run_suffix}"
181185
project = os.environ.get("PROJECT", "binary-prefix-tool-pipeline")
182186
art_path = os.environ.get("ART_PATH")
183187

@@ -213,7 +217,7 @@ async def main() -> None:
213217
}
214218
}
215219

216-
backend = art.TinkerNativeBackend(path=art_path)
220+
backend = TinkerNativeBackend(path=art_path)
217221
model = art.TrainableModel(
218222
name=model_name,
219223
project=project,
@@ -239,6 +243,7 @@ async def do_rollout(scenario: Scenario, temp: float) -> art.Trajectory:
239243
)
240244
choice = response.choices[0]
241245
raw_guess, source = extract_guess(choice)
246+
sampled_content = choice.message.content or ""
242247
guess = raw_guess or ""
243248
valid_guess = is_valid_guess(guess)
244249
prefix_len = shared_prefix_len(guess, SECRET_BITS) if valid_guess else 0
@@ -258,6 +263,7 @@ async def do_rollout(scenario: Scenario, temp: float) -> art.Trajectory:
258263
messages_and_choices=[*messages, choice],
259264
tools=TOOLS,
260265
reward=reward,
266+
logs=[f"sampled_content:\n{sampled_content}"],
261267
metrics=metrics,
262268
)
263269

src/art/pipeline_trainer/yes_no_maybe_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from dotenv import load_dotenv
1313

1414
import art
15+
from art.tinker_native import TinkerNativeBackend
1516

1617
from . import PipelineTrainer
1718

@@ -106,7 +107,7 @@ async def main() -> None:
106107
model_name = f"{MODEL_NAME}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
107108

108109
print("Initializing TinkerNativeBackend")
109-
backend = art.TinkerNativeBackend()
110+
backend = TinkerNativeBackend()
110111

111112
print(f"Initializing TrainableModel: {model_name}")
112113
model = art.TrainableModel(name=model_name, project=PROJECT, base_model=BASE_MODEL)

src/art/tinker/cookbook_v/__init__.py

Whitespace-only changes.
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
"""
2+
Utilities for guessing good hyperparameters for fine-tuning.
3+
"""
4+
5+
import json
6+
import math
7+
import struct
8+
from typing import Dict, Tuple
9+
10+
import huggingface_hub
11+
import numpy as np
12+
from transformers import AutoConfig
13+
14+
from .utils.misc_utils import not_none
15+
16+
17+
def _list_param_shapes_from_safetensors_remote(
    repo_id: str,
    revision: str = "main",
    token: str | None = None,
) -> Dict[str, Tuple[int, ...]]:
    """
    Map every tensor name in a Hugging Face repo to its shape tuple.

    Only the safetensors JSON header of each ``*.safetensors`` file is
    fetched (HfFileSystem performs ranged HTTP reads under the hood); the
    tensor data itself is never downloaded. Sharded checkpoints are handled
    by merging the headers of all shards.

    Raises FileNotFoundError if the repo contains no safetensors files and
    IOError if a header is truncated or malformed.
    """
    fs = huggingface_hub.HfFileSystem(token=token)
    info = huggingface_hub.model_info(repo_id, revision=revision, token=token)

    # Collect every shard; a single-file checkpoint is just one entry.
    shard_files = [
        sibling.rfilename
        for sibling in not_none(info.siblings)
        if sibling.rfilename.endswith(".safetensors")
    ]
    if not shard_files:
        raise FileNotFoundError("No .safetensors files found in this repo.")

    shapes: Dict[str, Tuple[int, ...]] = {}
    for fname in shard_files:
        # HfFileSystem path format: "<repo>@<revision>/<file>".
        with fs.open(f"{repo_id}@{revision}/{fname}", "rb") as handle:
            # safetensors layout: bytes [0:8] are a little-endian u64 header
            # length, followed by that many bytes of UTF-8 JSON header.
            length_prefix = handle.read(8)
            assert isinstance(length_prefix, bytes)
            if len(length_prefix) < 8:
                raise IOError(f"File too small or not safetensors: {fname}")
            (header_len,) = struct.unpack("<Q", length_prefix)

            raw_header = handle.read(header_len)
            assert isinstance(raw_header, bytes)
            if len(raw_header) < header_len:
                raise IOError(f"Incomplete header read for {fname}")

        header = json.loads(raw_header.decode("utf-8"))
        # Header entries: tensor_name -> {"dtype", "shape", "data_offsets"}.
        for name, meta in header.items():
            if name == "__metadata__":  # optional global metadata block
                continue
            shapes[name] = tuple(meta["shape"])

    return shapes
66+
67+
68+
def get_lora_lr_over_full_finetune_lr(model_name: str, lora_alpha: int = 32) -> float:
    """
    Factor by which to scale a full fine-tuning learning rate to obtain an
    equivalent LoRA learning rate (lr_lora = factor * lr_full).

    Empirically a flat factor of 10 works well regardless of model or
    ``lora_alpha`` (a more complicated formula was used previously); see
    LoRA Without Regret (https://thinkingmachines.ai/blog/lora/) for details.
    """
    return 10.0
75+
76+
77+
def _get_hidden_size(model_name: str) -> int:
78+
if "meta-llama/Llama-3" in model_name:
79+
# Bypass HF_TOKEN requirement for Llama-3 models
80+
return {
81+
"meta-llama/Llama-3.2-1B": 2048,
82+
"meta-llama/Llama-3.2-1B-Instruct": 2048,
83+
"meta-llama/Llama-3.2-3B": 3072,
84+
"meta-llama/Llama-3.2-3B-Instruct": 3072,
85+
"meta-llama/Llama-3.1-8B": 4096,
86+
"meta-llama/Llama-3.1-8B-Instruct": 4096,
87+
"meta-llama/Llama-3.1-70B": 8192,
88+
"meta-llama/Llama-3.3-70B-Instruct": 8192,
89+
}[model_name]
90+
91+
if model_name in (
92+
"deepseek-ai/DeepSeek-V3.1",
93+
"deepseek-ai/DeepSeek-V3.1-Base",
94+
"moonshotai/Kimi-K2-Thinking",
95+
):
96+
return 7168
97+
98+
config = AutoConfig.from_pretrained(model_name)
99+
return config.hidden_size
100+
101+
102+
def get_lora_param_count(
    model_name: str,
    lora_rank: int = 32,
    detailed: bool = False,
    include_experts: bool = True,
    shared_expert_outer_loras: bool = True,
) -> int | dict[str, int]:
    """
    Count the trainable parameters a rank-``lora_rank`` LoRA adapter adds.

    Only 2-D ``.weight`` matrices are considered; gates, embeddings, and
    MLA projection weights are excluded. When ``shared_expert_outer_loras``
    is true, the outer dimension of MoE expert matrices is counted once
    (shared across experts) while the intermediate dimension is counted per
    expert. Returns a breakdown dict instead of a total when ``detailed``.
    """
    skip_tokens = ["gate", "embed_tokens", "q_b_proj", "kv_b_proj"]
    if not include_experts:
        skip_tokens.append("experts")

    shared_dim_sum = 0
    expert_dim_sum = 0

    for name, shape in _list_param_shapes_from_safetensors_remote(model_name).items():
        if len(shape) != 2 or not name.endswith(".weight"):
            continue
        parts = name.split(".")
        if any(token in parts for token in skip_tokens):
            continue

        if "experts" not in parts or not shared_expert_outer_loras:
            shared_dim_sum += shape[0] + shape[1]
            continue

        # Shared outer LoRAs: the outer dim is identical for every expert,
        # so it is only counted once (for expert index 0).
        experts_at = parts.index("experts")
        expert_idx = int(parts[experts_at + 1])
        weight_name = parts[experts_at + 2]
        assert weight_name in ["gate_proj", "down_proj", "up_proj"], (
            f"Unexpected expert weight name: {weight_name}"
        )
        if weight_name == "down_proj":
            intermediate_dim, outer_dim = shape[1], shape[0]
        else:
            intermediate_dim, outer_dim = shape[0], shape[1]

        expert_dim_sum += intermediate_dim
        if expert_idx == 0:
            expert_dim_sum += outer_dim

    non_expert_params = lora_rank * shared_dim_sum
    expert_params = lora_rank * expert_dim_sum

    if detailed:
        return {
            "expert_params": expert_params,
            "non_expert_params": non_expert_params,
            "total_params": expert_params + non_expert_params,
        }
    return expert_params + non_expert_params
154+
155+
156+
def get_lr(model_name: str, is_lora: bool = True) -> float:
    """
    Guess a good learning rate for fine-tuning ``model_name``.

    Starts from a base LR (scaled 10x for LoRA) and applies a
    hidden-size-dependent correction whose exponent is specific to the
    model family. Raises ValueError for families other than Llama/Qwen.
    """
    base_lr = 5e-05
    lora_multiplier = 10.0

    lr = base_lr * lora_multiplier if is_lora else base_lr

    lowered = model_name.lower()
    if "llama" in lowered:
        exponent_model = 0.781
    elif "qwen" in lowered:
        exponent_model = 0.0775
    else:
        # TODO: sweep to determine LR multipliers for other models
        raise ValueError(f"Unknown model: {model_name}")

    return lr * (2000 / _get_hidden_size(model_name)) ** exponent_model
170+
171+
172+
def get_full_finetune_param_count(model_name: str) -> float:
    """
    Total parameter count of the model, read from safetensors headers only.

    Returned as a float. Scalar tensors contribute 1 (np.prod of an empty
    shape is 1).
    """
    shapes = _list_param_shapes_from_safetensors_remote(model_name)
    return float(sum(np.prod(shape) for shape in shapes.values()))
177+
178+
179+
def get_full_finetune_lr_multiplier(model_name: str) -> float:
    """
    Model-size-based LR multiplier for full fine-tuning: 1/sqrt(param count).
    """
    return 1.0 / math.sqrt(get_full_finetune_param_count(model_name))
181+
182+
183+
def get_lora_lr_multiplier(model_name: str):
    """
    Model-specific multiplier for the learning rate when training with LoRA.

    Given two models A and B, and a learning rate LR_A known to be optimal
    for A, a good guess for B is:

        LR_B = LR_A * get_lora_lr_multiplier(B) / get_lora_lr_multiplier(A)
    """
    full_ft_multiplier = get_full_finetune_lr_multiplier(model_name)
    lora_over_full = get_lora_lr_over_full_finetune_lr(model_name)
    return full_ft_multiplier * lora_over_full
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
Utilities for working with image processors. Create new types to avoid needing to import AutoImageProcessor and BaseImageProcessor.
3+
4+
5+
Avoid importing AutoImageProcessor and BaseImageProcessor until runtime, because they're slow imports.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from functools import cache
11+
from typing import TYPE_CHECKING, Any, TypeAlias
12+
13+
from PIL import Image
14+
15+
if TYPE_CHECKING:
16+
# this import takes a few seconds, so avoid it on the module import when possible
17+
from transformers.image_processing_utils import BaseImageProcessor
18+
19+
ImageProcessor: TypeAlias = BaseImageProcessor
20+
else:
21+
# make it importable from other files as a type in runtime
22+
ImageProcessor: TypeAlias = Any
23+
24+
25+
@cache
def get_image_processor(model_name: str) -> ImageProcessor:
    """
    Load (and memoize) the fast image processor for ``model_name``.

    Any ``:suffix`` on the model name is stripped before lookup. The
    transformers import happens lazily inside the function because it is
    slow at module-import time.
    """
    model_name = model_name.split(":")[0]

    from transformers.models.auto.image_processing_auto import AutoImageProcessor

    loaded = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
    assert loaded.is_fast, f"Could not load fast image processor for {model_name}"
    return loaded
34+
35+
36+
def resize_image(image: Image.Image, max_size: int) -> Image.Image:
    """
    Downscale ``image`` so its longest side is at most ``max_size`` pixels.

    Aspect ratio is preserved and LANCZOS resampling is used for quality.
    The original image object is returned untouched when it already fits.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image

    # Pin the longest side to max_size; derive the other from the ratio.
    if width > height:
        target = (max_size, int(height * max_size / width))
    else:
        target = (int(width * max_size / height), max_size)

    return image.resize(target, Image.Resampling.LANCZOS)

0 commit comments

Comments
 (0)