@@ -175,20 +175,37 @@ def custom_loss_fn(
175175 }
176176 last_checkpoint_dir = self ._get_last_checkpoint_dir ()
177177 assert last_checkpoint_dir is not None , "No checkpoint found"
178- state .sampler_client = await self ._save_checkpoint (
179- last_checkpoint_dir .with_name (f"{ int (last_checkpoint_dir .name ) + 1 :04d} " ),
178+ next_step = int (last_checkpoint_dir .name ) + 1
179+ new_sampler_client = await self ._save_checkpoint (
180+ last_checkpoint_dir .with_name (f"{ next_step :04d} " ),
180181 state .training_client ,
181182 )
183+ # Add new sampler client to the dict and update latest step
184+ state .sampler_clients [next_step ] = new_sampler_client
185+ state .latest_step = next_step
182186
async def delete_checkpoints(self, steps_to_keep: list[int]) -> None:
    """Delete every checkpoint whose step is not listed in ``steps_to_keep``.

    Each entry of ``self._checkpoints_path`` is interpreted as a step number
    via ``int(entry.name)``. Every entry whose step is absent from
    ``steps_to_keep`` is handed to ``delete_checkpoint`` together with the
    state's REST client; all deletions run concurrently. Once they have
    finished, the sampler clients cached for the deleted steps are removed
    from ``state.sampler_clients``.

    Args:
        steps_to_keep: Step numbers whose checkpoints must be preserved.

    Raises:
        ValueError: If an entry in the checkpoints directory has a name that
            cannot be parsed as an integer.
    """
    state = await self._state_task
    keep = set(steps_to_keep)  # O(1) membership while scanning the directory

    # Remember the directory each step was actually read from. Rebuilding the
    # path from the parsed integer (f"{step:04d}") does not round-trip for
    # names such as "12" or "00012" and would target a path that is not there.
    doomed = []
    for checkpoint_dir in self._checkpoints_path.iterdir():
        step = int(checkpoint_dir.name)
        if step not in keep:
            doomed.append((step, checkpoint_dir))

    # Run all deletions concurrently.
    await asyncio.gather(
        *[
            delete_checkpoint(checkpoint_dir, state.rest_client)
            for _, checkpoint_dir in doomed
        ]
    )

    # Drop the sampler clients that belonged to the deleted checkpoints.
    # NOTE(review): if ``state.latest_step`` is among the deleted steps, a
    # later ``get_sampler_client()`` call without an explicit step will raise;
    # confirm callers always keep the latest step.
    for step, _ in doomed:
        if state.sampler_clients.pop(step, None) is not None:
            print(f"Removed sampler client for step {step}")
192209
193210 @cached_property
194211 def _state_task (self ) -> asyncio .Task ["TinkerState" ]:
@@ -201,6 +218,7 @@ async def _get_state(self) -> "TinkerState":
201218 rest_client = service_client .create_rest_client ()
202219 checkpoint_dir = self ._get_last_checkpoint_dir ()
203220 if checkpoint_dir :
221+ current_step = int (checkpoint_dir .name )
204222 info = yaml .safe_load (open (checkpoint_dir / "info.yaml" , "r" ))
205223 with log_timing ("Creating Tinker training client from checkpoint" ):
206224 training_client = await service_client .create_training_client_from_state_with_optimizer_async (
@@ -212,6 +230,7 @@ async def _get_state(self) -> "TinkerState":
212230 model_path = info ["sampler_weights_path" ],
213231 )
214232 else :
233+ current_step = 0
215234 with log_timing ("Creating Tinker training client" ):
216235 training_client_args = config .get ("training_client_args" , {})
217236 if "rank" not in training_client_args :
@@ -231,7 +250,8 @@ async def _get_state(self) -> "TinkerState":
231250 service_client = service_client ,
232251 rest_client = rest_client ,
233252 training_client = training_client ,
234- sampler_client = sampler_client ,
253+ sampler_clients = {current_step : sampler_client },
254+ latest_step = current_step ,
235255 renderer = renderers .get_renderer (
236256 name = config ["renderer_name" ],
237257 tokenizer = tokenizer_utils .get_tokenizer (self .base_model ),
@@ -296,14 +316,23 @@ async def completions() -> dict:
296316 async def chat_completions (
297317 request : Request , body : CompletionCreateParams
298318 ) -> ChatCompletion :
319+ # Parse model name to extract optional @step suffix
320+ model_name = body .get ("model" , self .model_name )
321+ step : int | None = None
322+ if "@" in str (model_name ):
323+ base_name , step_str = str (model_name ).rsplit ("@" , 1 )
324+ step = int (step_str )
325+
326+ sampler_client = state .get_sampler_client (step )
327+
299328 prompt = tinker .ModelInput .from_ints (
300329 tokens = state .renderer .tokenizer .apply_chat_template (
301330 list (body ["messages" ]), # type: ignore
302331 tools = body .get ("tools" ), # type: ignore
303332 add_generation_prompt = True ,
304333 )
305334 )
306- sample_response = await state . sampler_client .sample_async (
335+ sample_response = await sampler_client .sample_async (
307336 prompt = prompt ,
308337 num_samples = body .get ("n" ) or 1 ,
309338 sampling_params = tinker .SamplingParams (
@@ -417,5 +446,16 @@ class TinkerState:
417446 service_client : tinker .ServiceClient
418447 rest_client : TinkerRestClient
419448 training_client : tinker .TrainingClient
420- sampler_client : tinker .SamplingClient
449+ sampler_clients : dict [int , tinker .SamplingClient ]
450+ latest_step : int
421451 renderer : renderers .Renderer
452+
def get_sampler_client(self, step: int | None = None) -> tinker.SamplingClient:
    """Return the sampler client registered for ``step``.

    Args:
        step: Step number to look up. ``None`` selects ``self.latest_step``.

    Returns:
        The entry of ``self.sampler_clients`` stored under the resolved step.

    Raises:
        ValueError: If no sampler client is stored for the resolved step; the
            message lists the steps that are available, in ascending order.
    """
    # Compare against None explicitly: step 0 is a valid key.
    wanted = self.latest_step if step is None else step
    clients = self.sampler_clients
    if wanted in clients:
        return clients[wanted]
    known_steps = sorted(clients.keys())
    raise ValueError(
        f"No sampler client for step {wanted}. Available steps: {known_steps}"
    )
0 commit comments