Skip to content

Commit e164f73

Browse files
committed
feat: Implement TinkerBackend and integrate with existing architecture
- Added TinkerBackend class to support Tinker functionality.
- Updated __init__.py files to include TinkerBackend in the module exports.
- Enhanced TinkerService to require tinker_args and improved renderer name handling.
- Introduced backend.py for TinkerBackend implementation, including service management and renderer configuration.
1 parent e279274 commit e164f73

9 files changed

Lines changed: 130 additions & 25 deletions

File tree

dev/yes-no-maybe.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,27 +40,27 @@ def with_quotes(w: str) -> str:
4040
async def main():
4141
load_dotenv()
4242

43-
backend = LocalBackend(in_process=True)
43+
backend = art.TinkerBackend()
4444
global model
4545
base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507")
4646
model = art.TrainableModel(
4747
name=os.environ.get("MODEL_NAME", "012"),
4848
project="yes-no-maybe",
4949
base_model=base_model,
50-
_internal_config=art.dev.InternalModelConfig(
51-
# engine_args=art.dev.EngineArgs(
52-
# max_lora_rank=1,
53-
# ),
54-
# peft_args=art.dev.PeftArgs(
55-
# r=1,
56-
# ),
57-
tinker_args=art.dev.TinkerArgs(
58-
renderer_name="qwen3_instruct",
59-
training_client_args=art.dev.TinkerTrainingClientArgs(
60-
rank=1,
61-
),
62-
),
63-
),
50+
# _internal_config=art.dev.InternalModelConfig(
51+
# # engine_args=art.dev.EngineArgs(
52+
# # max_lora_rank=1,
53+
# # ),
54+
# # peft_args=art.dev.PeftArgs(
55+
# # r=1,
56+
# # ),
57+
# tinker_args=art.dev.TinkerArgs(
58+
# renderer_name="qwen3_instruct",
59+
# training_client_args=art.dev.TinkerTrainingClientArgs(
60+
# rank=1,
61+
# ),
62+
# ),
63+
# ),
6464
)
6565
await model.register(backend)
6666

src/art/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(self, **kwargs):
4545
from .gather import gather_trajectories, gather_trajectory_groups
4646
from .model import Model, TrainableModel
4747
from .serverless import ServerlessBackend
48+
from .tinker import TinkerBackend
4849
from .trajectories import Trajectory, TrajectoryGroup
4950
from .types import Messages, MessagesAndChoices, Tools, TrainConfig
5051
from .utils import retry
@@ -66,6 +67,7 @@ def __init__(self, **kwargs):
6667
"TrainableModel",
6768
"retry",
6869
"TrainConfig",
70+
"TinkerBackend",
6971
"Trajectory",
7072
"TrajectoryGroup",
7173
"capture_yielded_trajectory",

src/art/loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from pydantic import BaseModel, ConfigDict
44
import torch
55

6-
from art import dev
76
from art.utils.group_aggregate import group_aggregate
87

8+
from . import dev
9+
910
if TYPE_CHECKING:
1011
from art.unsloth.service import TrainInputs
1112

src/art/tinker/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .backend import TinkerBackend
2+
3+
__all__ = ["TinkerBackend"]

src/art/tinker/backend.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import os
2+
3+
from mp_actors import move_to_child_process
4+
5+
from ..local.backend import LocalBackend
6+
from ..local.service import ModelService
7+
from ..model import TrainableModel
8+
from ..utils.output_dirs import get_model_dir
9+
10+
11+
class TinkerBackend(LocalBackend):
    """Backend that delegates training/sampling to the hosted Tinker service.

    Behaves like ``LocalBackend`` except that the per-model service it
    creates is a ``TinkerService`` (which talks to the Tinker API) rather
    than a locally-running training service.
    """

    def __init__(
        self,
        *,
        tinker_api_key: str | None = None,
        in_process: bool = False,
        path: str | None = None,
    ) -> None:
        """Initialize the backend and ensure a Tinker API key is available.

        Args:
            tinker_api_key: API key for the Tinker service. Required when
                the ``TINKER_API_KEY`` environment variable is not set; when
                provided, it overrides any existing environment value.
            in_process: Run the model service in-process instead of in a
                child process (passed through to ``LocalBackend``).
            path: Optional root directory for model artifacts (passed
                through to ``LocalBackend``).

        Raises:
            ValueError: If no API key is available from either the argument
                or the environment.
        """
        if tinker_api_key is not None:
            # An explicit argument takes precedence over the environment.
            # NOTE: never print or log the key itself — it is a secret.
            os.environ["TINKER_API_KEY"] = tinker_api_key
        elif "TINKER_API_KEY" not in os.environ:
            # Raise a real exception rather than `assert`, which is
            # stripped when Python runs with optimizations (-O).
            raise ValueError(
                "TINKER_API_KEY is not set and no tinker_api_key was provided"
            )
        super().__init__(in_process=in_process, path=path)

    async def _get_service(self, model: TrainableModel) -> ModelService:
        """Return the (cached) ``TinkerService`` for ``model``, creating it on first use."""
        # Imported lazily so that merely importing this module does not
        # require the Tinker service dependencies.
        from ..dev.get_model_config import get_model_config
        from ..dev.model import TinkerArgs
        from .service import TinkerService

        if model.name not in self._services:
            output_dir = get_model_dir(model=model, art_path=self._path)
            config = get_model_config(
                base_model=model.base_model,
                output_dir=output_dir,
                config=model._internal_config,
            )
            # TinkerService requires tinker_args; when the user did not
            # supply any, infer a renderer from the base model name.
            config["tinker_args"] = config.get("tinker_args") or TinkerArgs(
                renderer_name=get_renderer_name(model.base_model)
            )
            self._services[model.name] = TinkerService(
                model_name=model.name,
                base_model=model.base_model,
                config=config,
                output_dir=output_dir,
            )
            if not self._in_process:
                # assumes the newly created service should be moved to a
                # child process exactly once, at creation time — this
                # mirrors the LocalBackend service-caching pattern.
                self._services[model.name] = move_to_child_process(
                    self._services[model.name],
                    process_name="tinker-service",
                )
        return self._services[model.name]
53+
54+
55+
renderer_name_message = """
56+
To manually specify a renderer (and silence this message), you can set the "renderer_name" field like so:
57+
58+
model = art.TrainableModel(
59+
name="my-model",
60+
project="my-project",
61+
base_model="Qwen/Qwen3-8B",
62+
_internal_config=art.dev.InternalModelConfig(
63+
tinker_args=art.dev.TinkerArgs(renderer_name="qwen3_disable_thinking"),
64+
),
65+
)
66+
67+
Valid renderer names are:
68+
69+
- llama3
70+
- qwen3
71+
- qwen3_disable_thinking
72+
- qwen3_instruct
73+
- deepseekv3
74+
- deepseekv3_disable_thinking
75+
- gpt_oss_no_sysprompt
76+
- gpt_oss_low_reasoning
77+
- gpt_oss_medium_reasoning
78+
- gpt_oss_high_reasoning
79+
""".strip()
80+
81+
82+
def get_renderer_name(base_model: str) -> str:
    """Infer the Tinker renderer name from a base model identifier.

    Dispatches on the model-name prefix. When a non-obvious default is
    chosen, a notice (including guidance on overriding the renderer via
    ``renderer_name``) is printed to stdout.

    Args:
        base_model: Hugging Face-style model identifier, e.g.
            ``"Qwen/Qwen3-8B"``.

    Returns:
        A renderer name understood by the Tinker service.

    Raises:
        ValueError: If the model family is not recognized.
    """
    if base_model.startswith("meta-llama/"):
        return "llama3"

    if base_model.startswith("Qwen/Qwen3-"):
        if "Instruct" in base_model:
            return "qwen3_instruct"
        print("Defaulting to Qwen3 renderer without thinking for", base_model)
        print(renderer_name_message)
        return "qwen3_disable_thinking"

    if base_model.startswith("deepseek-ai/DeepSeek-V3"):
        print("Defaulting to DeepSeekV3 renderer without thinking for", base_model)
        print(renderer_name_message)
        return "deepseekv3_disable_thinking"

    if base_model.startswith("openai/gpt-oss"):
        print("Defaulting to GPT-OSS renderer without system prompt for", base_model)
        print(renderer_name_message)
        return "gpt_oss_no_sysprompt"

    raise ValueError(f"Unknown base model: {base_model}")

src/art/tinker/service.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ def _state_task(self) -> asyncio.Task["TinkerState"]:
195195
return asyncio.create_task(self._get_state())
196196

197197
async def _get_state(self) -> "TinkerState":
198-
config = self.config.get("tinker_args") or {"renderer_name": "qwen3_instruct"}
198+
config = self.config.get("tinker_args")
199+
assert config is not None, "Tinker args are required"
199200
service_client = tinker.ServiceClient()
200201
rest_client = service_client.create_rest_client()
201202
checkpoint_dir = self._get_last_checkpoint_dir()
@@ -204,9 +205,7 @@ async def _get_state(self) -> "TinkerState":
204205
with log_timing("Creating Tinker training client from checkpoint"):
205206
training_client = await service_client.create_training_client_from_state_with_optimizer_async(
206207
path=info["state_with_optimizer_path"],
207-
user_metadata=(self.config.get("tinker_args") or {}).get(
208-
"user_metadata", None
209-
),
208+
user_metadata=config.get("user_metadata", None),
210209
)
211210
with log_timing("Creating Tinker sampling client from checkpoint"):
212211
sampler_client = await training_client.create_sampling_client_async(
@@ -229,7 +228,7 @@ async def _get_state(self) -> "TinkerState":
229228
training_client=training_client,
230229
sampler_client=sampler_client,
231230
renderer=renderers.get_renderer(
232-
name=config.get("renderer_name"),
231+
name=config["renderer_name"],
233232
tokenizer=tokenizer_utils.get_tokenizer(self.base_model),
234233
),
235234
)

src/art/utils/trajectory_logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from openai.types.chat.chat_completion import Choice
1616
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
1717

18-
from art import Trajectory, TrajectoryGroup
18+
from art.trajectories import Trajectory, TrajectoryGroup
1919

2020

2121
def _flatten_message(msg: dict) -> dict:

src/art/utils/trajectory_migration.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818
import yaml
1919

20-
from art import Trajectory, TrajectoryGroup
21-
from art.trajectories import History
20+
from art.trajectories import History, Trajectory, TrajectoryGroup
2221
from art.types import Choice, Message, MessageOrChoice
2322

2423
# ============================================================================

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)