11import os
2+ from typing import Any , cast
23
34from mp_actors import move_to_child_process
45
6+ from .. import dev
57from ..local .backend import LocalBackend
68from ..local .service import ModelService
79from ..model import TrainableModel
@@ -24,9 +26,41 @@ def __init__(
2426 os .environ ["TINKER_API_KEY" ] = tinker_api_key
2527 super ().__init__ (in_process = in_process , path = path )
2628
29+ async def _prepare_backend_for_training (
30+ self ,
31+ model : TrainableModel ,
32+ config : dev .OpenAIServerConfig | None = None ,
33+ ) -> tuple [str , str ]:
34+ """Start the local OpenAI server and return its base URL + API key."""
35+ service = await self ._get_service (model )
36+ raw_config : dict [str , Any ] = cast (dict [str , Any ], config ) if config else {}
37+
38+ server_args = cast (dict [str , Any ], raw_config .get ("server_args" , {}))
39+ host = server_args .get ("host" , raw_config .get ("host" , "0.0.0.0" ))
40+ port = server_args .get ("port" , raw_config .get ("port" ))
41+ if port is None :
42+ from .service import get_free_port
43+
44+ port = get_free_port ()
45+ api_key = server_args .get ("api_key" , raw_config .get ("api_key" )) or "default"
46+
47+ # Ensure the Tinker server binds to the same host/port we return.
48+ tinker_config = cast (
49+ dev .OpenAIServerConfig ,
50+ {
51+ ** raw_config ,
52+ "host" : host ,
53+ "port" : port ,
54+ },
55+ )
56+ await service .start_openai_server (config = tinker_config )
57+
58+ base_url = f"http://{ host } :{ port } /v1"
59+ return base_url , api_key
60+
2761 async def _get_service (self , model : TrainableModel ) -> ModelService :
2862 from ..dev .get_model_config import get_model_config
29- from ..dev .model import TinkerArgs
63+ from ..dev .model import TinkerArgs , TinkerTrainingClientArgs
3064 from .service import TinkerService
3165
3266 if model .name not in self ._services :
@@ -38,8 +72,9 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
3872 config ["tinker_args" ] = config .get ("tinker_args" ) or TinkerArgs (
3973 renderer_name = get_renderer_name (model .base_model )
4074 )
41- config ["tinker_args" ]["training_client_args" ] = (
42- config ["tinker_args" ].get ("training_client_args" ) or {}
75+ config ["tinker_args" ]["training_client_args" ] = cast (
76+ TinkerTrainingClientArgs ,
77+ config ["tinker_args" ].get ("training_client_args" ) or {},
4378 )
4479 self ._services [model .name ] = TinkerService (
4580 model_name = model .name ,
0 commit comments