Skip to content

Commit 8d18989

Browse files
corbtCursor Bot
andauthored
fix: align Tinker server port with client base_url (#531)
Pick a port once in TinkerBackend and reuse it for the server bind and returned base_url to avoid pointing clients at stale endpoints. Co-authored-by: Cursor Bot <bot@cursor.com>
1 parent 6b6d85e commit 8d18989

2 files changed

Lines changed: 39 additions & 4 deletions

File tree

src/art/tinker/backend.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
2+
from typing import Any, cast
23

34
from mp_actors import move_to_child_process
45

6+
from .. import dev
57
from ..local.backend import LocalBackend
68
from ..local.service import ModelService
79
from ..model import TrainableModel
@@ -24,9 +26,41 @@ def __init__(
2426
os.environ["TINKER_API_KEY"] = tinker_api_key
2527
super().__init__(in_process=in_process, path=path)
2628

29+
async def _prepare_backend_for_training(
30+
self,
31+
model: TrainableModel,
32+
config: dev.OpenAIServerConfig | None = None,
33+
) -> tuple[str, str]:
34+
"""Start the local OpenAI server and return its base URL + API key."""
35+
service = await self._get_service(model)
36+
raw_config: dict[str, Any] = cast(dict[str, Any], config) if config else {}
37+
38+
server_args = cast(dict[str, Any], raw_config.get("server_args", {}))
39+
host = server_args.get("host", raw_config.get("host", "0.0.0.0"))
40+
port = server_args.get("port", raw_config.get("port"))
41+
if port is None:
42+
from .service import get_free_port
43+
44+
port = get_free_port()
45+
api_key = server_args.get("api_key", raw_config.get("api_key")) or "default"
46+
47+
# Ensure the Tinker server binds to the same host/port we return.
48+
tinker_config = cast(
49+
dev.OpenAIServerConfig,
50+
{
51+
**raw_config,
52+
"host": host,
53+
"port": port,
54+
},
55+
)
56+
await service.start_openai_server(config=tinker_config)
57+
58+
base_url = f"http://{host}:{port}/v1"
59+
return base_url, api_key
60+
2761
async def _get_service(self, model: TrainableModel) -> ModelService:
2862
from ..dev.get_model_config import get_model_config
29-
from ..dev.model import TinkerArgs
63+
from ..dev.model import TinkerArgs, TinkerTrainingClientArgs
3064
from .service import TinkerService
3165

3266
if model.name not in self._services:
@@ -38,8 +72,9 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
3872
config["tinker_args"] = config.get("tinker_args") or TinkerArgs(
3973
renderer_name=get_renderer_name(model.base_model)
4074
)
41-
config["tinker_args"]["training_client_args"] = (
42-
config["tinker_args"].get("training_client_args") or {}
75+
config["tinker_args"]["training_client_args"] = cast(
76+
TinkerTrainingClientArgs,
77+
config["tinker_args"].get("training_client_args") or {},
4378
)
4479
self._services[model.name] = TinkerService(
4580
model_name=model.name,

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)