@@ -129,19 +129,29 @@ async def register(
129129
130130 async def _get_service (self , model : TrainableModel ) -> ModelService :
131131 from ..dev .get_model_config import get_model_config
132- from ..torchtune .service import TorchtuneService
133- from ..unsloth .service import UnslothService
134132
135133 if model .name not in self ._services :
136134 config = get_model_config (
137135 base_model = model .base_model ,
138136 output_dir = get_model_dir (model = model , art_path = self ._path ),
139137 config = model ._internal_config ,
140138 )
141- if config .get ("torchtune_args" ) is not None :
139+ is_tinker = config .get ("tinker_args" ) is not None
140+ if is_tinker :
141+ from ..tinker .service import TinkerService
142+
143+ service_class = TinkerService
144+ elif config .get ("torchtune_args" ) is not None :
145+ from ..torchtune .service import TorchtuneService
146+
142147 service_class = TorchtuneService
143148 else :
149+ from ..unsloth .service import UnslothService
150+
144151 service_class = UnslothService
152+ # When moving the service to a child process, import unsloth
153+ # early to maximize optimizations
154+ os .environ ["IMPORT_UNSLOTH" ] = "1"
145155 self ._services [model .name ] = service_class (
146156 model_name = model .name ,
147157 base_model = model .base_model ,
@@ -151,12 +161,9 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
151161 if not self ._in_process :
152162 # Kill all "model-service" processes to free up GPU memory
153163 subprocess .run (["pkill" , "-9" , "model-service" ])
154- # When moving the service to a child process, import unsloth
155- # early to maximize optimizations
156- os .environ ["IMPORT_UNSLOTH" ] = "1"
157164 self ._services [model .name ] = move_to_child_process (
158165 self ._services [model .name ],
159- process_name = "model-service" ,
166+ process_name = "tinker-service" if is_tinker else " model-service" ,
160167 )
161168 return self ._services [model .name ]
162169
@@ -242,6 +249,8 @@ async def _delete_checkpoints(
242249 benchmark : str ,
243250 benchmark_smoothing : float ,
244251 ) -> None :
252+ from ..tinker .service import TinkerService
253+
245254 output_dir = get_model_dir (model = model , art_path = self ._path )
246255 # Keep the latest step
247256 steps_to_keep = [get_model_step (model , self ._path )]
@@ -261,7 +270,11 @@ async def _delete_checkpoints(
261270 print (f'"{ output_dir } /history.jsonl" not found' )
262271 except pl .exceptions .ColumnNotFoundError :
263272 print (f'No "{ benchmark } " metric found in history' )
264- delete_checkpoints (output_dir , steps_to_keep )
273+ service = await self ._get_service (model )
274+ if isinstance (service , TinkerService ):
275+ await service .delete_checkpoints (steps_to_keep )
276+ else :
277+ delete_checkpoints (output_dir , steps_to_keep )
265278
266279 async def _prepare_backend_for_training (
267280 self ,
0 commit comments