
Commit 913b35d

corbt and Cursor Bot authored
Multi-checkpoint inference for pipelined training (RFC #513) (#515)
* Use training_step for W&B x-axis to allow out-of-order logging

* Implement multi-checkpoint inference for pipelined training

* Fix formatting and typing issues

---------

Co-authored-by: Cursor Bot <bot@cursor.com>
1 parent 619ed3e commit 913b35d

10 files changed: 809 additions & 33 deletions
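Before the per-file diffs, a hedged usage sketch of what this commit enables (server address and model name are hypothetical; api_key="default" matches the ServerArgs default set below): inference can target any registered checkpoint via the `name@step` suffix, so evaluating an older checkpoint can overlap with training the next one.

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="default")
    # Query two checkpoints of the same trainable model concurrently.
    old, new = await asyncio.gather(
        client.chat.completions.create(
            model="my-model@4",  # pinned to the checkpoint saved at step 4
            messages=[{"role": "user", "content": "Hello"}],
        ),
        client.chat.completions.create(
            model="my-model@5",  # the newly trained checkpoint
            messages=[{"role": "user", "content": "Hello"}],
        ),
    )
    print(old.choices[0].message.content)
    print(new.choices[0].message.content)


asyncio.run(main())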


src/art/dev/openai_server.py

Lines changed: 11 additions & 3 deletions
@@ -12,15 +12,23 @@ def get_openai_server_config(
     lora_path: str | None = None,
     config: "OpenAIServerConfig | None" = None,
 ) -> "OpenAIServerConfig":
+    import os
+
     if config is None:
         config = OpenAIServerConfig()
     log_file = config.get("log_file", log_file)
+
+    # Extract step from lora_path for multi-checkpoint support
+    # lora_path format is: {output_dir}/checkpoints/{step:04d}
+    lora_name = model_name
+    if lora_path:
+        step = int(os.path.basename(lora_path))
+        lora_name = f"{model_name}@{step}"
+
     server_args = ServerArgs(
         api_key="default",
         lora_modules=(
-            [f'{{"name": "{model_name}", "path": "{lora_path}"}}']
-            if lora_path
-            else None
+            [f'{{"name": "{lora_name}", "path": "{lora_path}"}}'] if lora_path else None
         ),
         return_tokens_as_token_ids=True,
         enable_auto_tool_choice=True,
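As a quick illustration of the parsing above (paths hypothetical): the step is the final path component of lora_path, zero-padded to four digits, and int() strips the padding.

import os

lora_path = "/tmp/art/my-model/checkpoints/0007"  # {output_dir}/checkpoints/{step:04d}
step = int(os.path.basename(lora_path))  # 7
lora_name = f"my-model@{step}"  # "my-model@7" — the name this adapter is served under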

src/art/local/backend.py

Lines changed: 7 additions & 0 deletions
@@ -484,6 +484,13 @@ async def _train_model(
                 f"Advanced step from {current_step} to {next_step} (no training occurred)"
             )
 
+            # Register the renamed checkpoint as a new LoRA adapter
+            # so it's available for inference at the new step
+            from ..unsloth.service import UnslothService
+
+            if isinstance(service, UnslothService):
+                await service.register_lora_for_step(next_step, next_checkpoint_dir)
+
             # Log metrics showing no groups were trainable
             self._log_metrics(
                 model,

src/art/model.py

Lines changed: 19 additions & 5 deletions
@@ -187,9 +187,15 @@ def openai_client(
             )
         return self._openai_client
 
-    def litellm_completion_params(self) -> dict:
-        """Return the parameters that should be sent to litellm.completion."""
-        model_name = self.inference_model_name
+    def litellm_completion_params(self, step: int | None = None) -> dict:
+        """Return the parameters that should be sent to litellm.completion.
+
+        Args:
+            step: If provided, returns params for specific checkpoint using
+                the `name@step` convention. If None, returns params for
+                latest checkpoint (default, backwards compatible).
+        """
+        model_name = self.get_inference_name(step)
         if self.trainable:
             model_name = f"hosted_vllm/{model_name}"
         return {
@@ -203,13 +209,21 @@ def litellm_completion_params(self) -> dict:
     # Inference name helpers
     # ------------------------------------------------------------------
 
-    def get_inference_name(self) -> str:
+    def get_inference_name(self, step: int | None = None) -> str:
         """Return the name that should be sent to the inference endpoint.
 
         If `inference_model_name` is provided we use that, otherwise we fall
         back to the model's own `name`.
+
+        Args:
+            step: If provided, returns name for specific checkpoint using
+                the `name@step` convention. If None, returns name for
+                latest checkpoint (default, backwards compatible).
         """
-        return self.inference_model_name or self.name
+        base_name = self.inference_model_name or self.name
+        if step is not None:
+            return f"{base_name}@{step}"
+        return base_name
 
     async def log(
         self,
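A standalone restatement of the helper's semantics (a free function for illustration, not the repo's API), runnable as-is:

def get_inference_name(
    name: str, inference_model_name: str | None = None, step: int | None = None
) -> str:
    # Mirrors Model.get_inference_name above: prefer the explicit inference
    # name, then append "@step" only when a checkpoint is pinned.
    base_name = inference_model_name or name
    return f"{base_name}@{step}" if step is not None else base_name


assert get_inference_name("agent") == "agent"  # latest checkpoint, backwards compatible
assert get_inference_name("agent", step=3) == "agent@3"  # pinned checkpoint
assert get_inference_name("agent", "served-agent", step=3) == "served-agent@3"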

src/art/serverless/backend.py

Lines changed: 15 additions & 2 deletions
@@ -74,9 +74,22 @@ async def delete(
         assert model.id is not None, "Model ID is required"
         await self._client.models.delete(model_id=model.id)
 
-    def _model_inference_name(self, model: "TrainableModel") -> str:
+    def _model_inference_name(
+        self, model: "TrainableModel", step: int | None = None
+    ) -> str:
+        """Return the inference name for a model checkpoint.
+
+        Args:
+            model: The trainable model.
+            step: If provided, returns name for specific checkpoint using
+                W&B artifact versioning (e.g., :step5). If None, returns
+                name for latest checkpoint (default, backwards compatible).
+        """
         assert model.entity is not None, "Model entity is required"
-        return f"wandb-artifact:///{model.entity}/{model.project}/{model.name}"
+        base_name = f"wandb-artifact:///{model.entity}/{model.project}/{model.name}"
+        if step is not None:
+            return f"{base_name}:step{step}"
+        return base_name
 
     async def _get_step(self, model: "Model") -> int:
         if model.trainable:
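The two naming schemes side by side (entity/project/name values are hypothetical):

entity, project, name = "my-entity", "my-project", "my-model"
base_name = f"wandb-artifact:///{entity}/{project}/{name}"

assert base_name == "wandb-artifact:///my-entity/my-project/my-model"  # latest
assert f"{base_name}:step5" == "wandb-artifact:///my-entity/my-project/my-model:step5"  # step 5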

src/art/tinker/service.py

Lines changed: 48 additions & 8 deletions
@@ -175,20 +175,37 @@ def custom_loss_fn(
         }
         last_checkpoint_dir = self._get_last_checkpoint_dir()
         assert last_checkpoint_dir is not None, "No checkpoint found"
-        state.sampler_client = await self._save_checkpoint(
-            last_checkpoint_dir.with_name(f"{int(last_checkpoint_dir.name) + 1:04d}"),
+        next_step = int(last_checkpoint_dir.name) + 1
+        new_sampler_client = await self._save_checkpoint(
+            last_checkpoint_dir.with_name(f"{next_step:04d}"),
             state.training_client,
         )
+        # Add new sampler client to the dict and update latest step
+        state.sampler_clients[next_step] = new_sampler_client
+        state.latest_step = next_step
 
     async def delete_checkpoints(self, steps_to_keep: list[int]) -> None:
         state = await self._state_task
+        # Find steps to delete
+        steps_to_delete = [
+            int(checkpoint_dir.name)
+            for checkpoint_dir in self._checkpoints_path.iterdir()
+            if int(checkpoint_dir.name) not in steps_to_keep
+        ]
+        # Delete checkpoints from disk and Tinker
         await asyncio.gather(
             *[
-                delete_checkpoint(checkpoint_dir, state.rest_client)
-                for checkpoint_dir in self._checkpoints_path.iterdir()
-                if int(checkpoint_dir.name) not in steps_to_keep
+                delete_checkpoint(
+                    self._checkpoints_path / f"{step:04d}", state.rest_client
+                )
+                for step in steps_to_delete
             ]
         )
+        # Also remove corresponding sampler clients from state
+        for step in steps_to_delete:
+            if step in state.sampler_clients:
+                del state.sampler_clients[step]
+                print(f"Removed sampler client for step {step}")
 
     @cached_property
     def _state_task(self) -> asyncio.Task["TinkerState"]:
@@ -201,6 +218,7 @@ async def _get_state(self) -> "TinkerState":
         rest_client = service_client.create_rest_client()
         checkpoint_dir = self._get_last_checkpoint_dir()
         if checkpoint_dir:
+            current_step = int(checkpoint_dir.name)
             info = yaml.safe_load(open(checkpoint_dir / "info.yaml", "r"))
             with log_timing("Creating Tinker training client from checkpoint"):
                 training_client = await service_client.create_training_client_from_state_with_optimizer_async(
@@ -212,6 +230,7 @@ async def _get_state(self) -> "TinkerState":
                 model_path=info["sampler_weights_path"],
             )
         else:
+            current_step = 0
             with log_timing("Creating Tinker training client"):
                 training_client_args = config.get("training_client_args", {})
                 if "rank" not in training_client_args:
@@ -231,7 +250,8 @@ async def _get_state(self) -> "TinkerState":
             service_client=service_client,
             rest_client=rest_client,
             training_client=training_client,
-            sampler_client=sampler_client,
+            sampler_clients={current_step: sampler_client},
+            latest_step=current_step,
             renderer=renderers.get_renderer(
                 name=config["renderer_name"],
                 tokenizer=tokenizer_utils.get_tokenizer(self.base_model),
@@ -296,14 +316,23 @@ async def completions() -> dict:
         async def chat_completions(
             request: Request, body: CompletionCreateParams
         ) -> ChatCompletion:
+            # Parse model name to extract optional @step suffix
+            model_name = body.get("model", self.model_name)
+            step: int | None = None
+            if "@" in str(model_name):
+                base_name, step_str = str(model_name).rsplit("@", 1)
+                step = int(step_str)
+
+            sampler_client = state.get_sampler_client(step)
+
             prompt = tinker.ModelInput.from_ints(
                 tokens=state.renderer.tokenizer.apply_chat_template(
                     list(body["messages"]),  # type: ignore
                     tools=body.get("tools"),  # type: ignore
                     add_generation_prompt=True,
                 )
             )
-            sample_response = await state.sampler_client.sample_async(
+            sample_response = await sampler_client.sample_async(
                 prompt=prompt,
                 num_samples=body.get("n") or 1,
                 sampling_params=tinker.SamplingParams(
@@ -417,5 +446,16 @@ class TinkerState:
     service_client: tinker.ServiceClient
     rest_client: TinkerRestClient
     training_client: tinker.TrainingClient
-    sampler_client: tinker.SamplingClient
+    sampler_clients: dict[int, tinker.SamplingClient]
+    latest_step: int
     renderer: renderers.Renderer
+
+    def get_sampler_client(self, step: int | None = None) -> tinker.SamplingClient:
+        if step is None:
+            step = self.latest_step
+        if step not in self.sampler_clients:
+            available = sorted(self.sampler_clients.keys())
+            raise ValueError(
+                f"No sampler client for step {step}. Available steps: {available}"
+            )
+        return self.sampler_clients[step]
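A self-contained sketch of the get_sampler_client lookup semantics, with plain strings standing in for tinker.SamplingClient instances:

sampler_clients = {0: "client@0", 3: "client@3"}
latest_step = 3


def get_sampler_client(step: int | None = None) -> str:
    # step=None falls back to the latest checkpoint; unknown steps fail
    # loudly instead of silently sampling from the wrong weights.
    if step is None:
        step = latest_step
    if step not in sampler_clients:
        raise ValueError(
            f"No sampler client for step {step}. Available steps: {sorted(sampler_clients)}"
        )
    return sampler_clients[step]


assert get_sampler_client() == "client@3"  # defaults to latest
assert get_sampler_client(0) == "client@0"  # explicit older step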

src/art/unsloth/service.py

Lines changed: 56 additions & 13 deletions
@@ -261,6 +261,13 @@ class UnslothService:
     config: dev.InternalModelConfig
     output_dir: str
     _is_sleeping: bool = False
+    _latest_step: int = 0
+    _lora_id_counter: int = 1  # Start from 1 since 0 is reserved
+
+    def _next_lora_id(self) -> int:
+        """Return a new unique LoRA ID to avoid collisions in vLLM."""
+        self._lora_id_counter += 1
+        return self._lora_id_counter
 
     async def start_openai_server(self, config: dev.OpenAIServerConfig | None) -> None:
         lora_path = get_last_checkpoint_dir(self.output_dir)
@@ -269,24 +276,50 @@ async def start_openai_server(self, config: dev.OpenAIServerConfig | None) -> None:
             lora_path = get_step_checkpoint_dir(self.output_dir, 0)
             os.makedirs(os.path.dirname(lora_path), exist_ok=True)
             self._state.trainer.save_model(lora_path)
+            self._latest_step = 0
+        else:
+            # Extract step from checkpoint path
+            self._latest_step = get_step_from_dir(self.output_dir)
 
         # Offload training model to CPU before vLLM starts to free GPU memory
         self._state.offload_to_cpu()
 
+        server_config = dev.get_openai_server_config(
+            model_name=self.model_name,
+            base_model=self.base_model,
+            log_file=f"{self.output_dir}/logs/vllm.log",
+            lora_path=lora_path,
+            config=config,
+        )
         await openai_server_task(
             engine=await self.llm,
-            config=dev.get_openai_server_config(
-                model_name=self.model_name,
-                base_model=self.base_model,
-                log_file=f"{self.output_dir}/logs/vllm.log",
-                lora_path=lora_path,
-                config=config,
-            ),
+            config=server_config,
         )
 
     async def vllm_engine_is_sleeping(self) -> bool:
         return self._is_sleeping
 
+    async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
+        """Register a LoRA adapter for a specific checkpoint step.
+
+        This is called when training is skipped but the checkpoint is renamed.
+        """
+        llm = await self.llm
+        await llm.pause_generation()
+        added = await llm.add_lora(
+            LoRARequest(
+                lora_name=f"{self.model_name}@{step}",
+                lora_int_id=self._next_lora_id(),
+                lora_path=checkpoint_dir,
+            )
+        )
+        if not added:
+            raise RuntimeError(
+                f"Failed to add LoRA adapter for step {step} at {checkpoint_dir}"
+            )
+        self._latest_step = step
+        await llm.resume_generation()
+
     async def train(
         self,
         disk_packed_tensors: DiskPackedTensors,
@@ -371,17 +404,26 @@ async def train(
         await run_on_workers(llm, do_wake_up)
         self._is_sleeping = False
 
-        # Swap out the LoRA adapter with the newly trained checkpoint
-        await llm.remove_lora(1)
-        await llm.add_lora(
+        # Determine the new step from the checkpoint directory
+        # checkpoint_dir format is: {output_dir}/checkpoints/{step:04d}
+        new_step = int(os.path.basename(checkpoint_dir))
+
+        # Add the new LoRA adapter
+        # We keep old LoRAs loaded - vLLM will page them out as needed
+        added = await llm.add_lora(
             LoRARequest(
-                lora_name=self.model_name,
-                lora_int_id=1,
+                lora_name=f"{self.model_name}@{new_step}",
+                lora_int_id=self._next_lora_id(),
                 lora_path=checkpoint_dir,
             )
         )
+        if not added:
+            raise RuntimeError(
+                f"Failed to add LoRA adapter for step {new_step} at {checkpoint_dir}"
+            )
+        self._latest_step = new_step
 
-        # Resume generation after LoRA swap is complete
+        # Resume generation after LoRA add is complete
         await llm.resume_generation()
 
         if verbose:
@@ -461,6 +503,7 @@ def llm(self) -> asyncio.Task[AsyncLLM]:
         engine_args = {
            **self.config.get("engine_args", {}),
            "enable_lora": True,
+           "max_loras": self.config.get("engine_args", {}).get("max_loras", 2),
        }
        # Remove boolean flags that vLLM's argparse doesn't accept as =False
        for key in ["enable_log_requests", "disable_log_requests"]:
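A standalone sketch of the max_loras defaulting above: a user-supplied value in engine_args wins; otherwise the engine reserves room for two resident adapters (the previous checkpoint plus the newly trained one), paging others out as needed.

user_engine_args = {"max_loras": 4}  # hypothetical user override
engine_args = {
    **user_engine_args,
    "enable_lora": True,
    "max_loras": user_engine_args.get("max_loras", 2),
}
assert engine_args["max_loras"] == 4  # the override survives the default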

src/art/vllm/server.py

Lines changed: 23 additions & 2 deletions
@@ -4,7 +4,7 @@
 from contextlib import asynccontextmanager
 import logging
 import os
-from typing import Any, AsyncIterator, Coroutine
+from typing import Any, AsyncIterator, Coroutine, cast
 
 from openai import AsyncOpenAI
 from uvicorn.config import LOGGING_CONFIG
@@ -16,6 +16,8 @@
 
 from ..dev.openai_server import OpenAIServerConfig
 
+_openai_serving_models: Any | None = None
+
 
 async def openai_server_task(
     engine: EngineClient,
@@ -44,6 +46,22 @@ async def openai_server_task(
     subclass_chat_completion_request()
     from vllm.entrypoints.openai import api_server
 
+    # Capture the OpenAIServingModels instance so dynamically added LoRAs
+    # are reflected in the model list.
+    from vllm.entrypoints.openai import serving_models
+
+    serving_models_any = cast(Any, serving_models)
+    if not getattr(serving_models_any, "_art_openai_serving_models_patched", False):
+        serving_models_any._art_openai_serving_models_patched = True
+        original_init = serving_models.OpenAIServingModels.__init__
+
+        def _init(self, *args: Any, **kwargs: Any) -> None:
+            original_init(self, *args, **kwargs)
+            global _openai_serving_models
+            _openai_serving_models = self
+
+        serving_models.OpenAIServingModels.__init__ = _init
+
     patch_listen_for_disconnect()
     patch_tool_parser_manager()
     set_vllm_log_file(config.get("log_file", "vllm.log"))
@@ -65,7 +83,10 @@ async def _add_lora(lora_request) -> bool:
             long_lora_max_len=getattr(lora_request, "long_lora_max_len", None),
             base_model_name=getattr(lora_request, "base_model_name", None),
         )
-        return await add_lora(lora_request)
+        added = await add_lora(lora_request)
+        if added and _openai_serving_models is not None:
+            _openai_serving_models.lora_requests[lora_request.lora_name] = lora_request
+        return added
 
     engine.add_lora = _add_lora
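A generic, self-contained sketch of the capture pattern used above (a toy Widget class, not vLLM): wrap __init__ once, guarded by a sentinel attribute, and stash the most recently constructed instance in a module-level global.

from typing import Any

_captured: Any | None = None


class Widget:
    def __init__(self, name: str) -> None:
        self.name = name


if not getattr(Widget, "_patched", False):
    Widget._patched = True  # sentinel: never wrap twice
    _original_init = Widget.__init__

    def _init(self, *args: Any, **kwargs: Any) -> None:
        _original_init(self, *args, **kwargs)
        global _captured
        _captured = self

    Widget.__init__ = _init


w = Widget("w1")
assert _captured is w  # the wrapper recorded the instance at construction time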

tests/integration/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+"""Integration tests for ART multi-checkpoint inference."""
