Commit 86d347b

feat: upgrade vLLM to 0.15.1 (#561)
1 parent 0a3752e commit 86d347b

7 files changed: 909 additions & 802 deletions


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,6 @@ backend = [
     "bitsandbytes>=0.45.2",
     "unsloth==2025.12.9",
     "unsloth-zoo==2025.12.7",
-    "vllm==0.13.0",
     "torch>=2.8.0",
     "torchao==0.14.1",
     "accelerate==1.7.0",
@@ -39,6 +38,7 @@ backend = [
     "pytest>=8.4.1",
     "nbmake>=1.5.5",
     "gql<4",
+    "vllm==0.15.1 ; sys_platform == 'linux'",
 ]

 langgraph = [
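
The re-added vLLM pin now carries a PEP 508 environment marker, so the 0.15.1 wheel is only required on Linux. A minimal sketch of how such a marker evaluates, using the packaging library (illustrative; not part of this commit):

import sys
from packaging.markers import Marker

# The marker string from the new vLLM requirement in pyproject.toml.
marker = Marker("sys_platform == 'linux'")

# Evaluates against the current interpreter's environment:
# True on Linux, False on e.g. macOS or Windows, so pip/uv skip the pin there.
print(sys.platform, marker.evaluate())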

src/art/vllm/patches.py

Lines changed: 4 additions & 4 deletions
@@ -7,16 +7,16 @@ def subclass_chat_completion_request() -> None:
     """
     Subclass ChatCompletionRequest so that logprobs are always returned.
     """
-    import vllm.entrypoints.openai.protocol
+    from vllm.entrypoints.openai.chat_completion import protocol

-    class ChatCompletionRequest(vllm.entrypoints.openai.protocol.ChatCompletionRequest):
+    class ChatCompletionRequest(protocol.ChatCompletionRequest):
         def __init__(self, *args: object, **kwargs: object) -> None:
             super().__init__(*args, **kwargs)  # ty:ignore[invalid-argument-type]
             self.logprobs = True
             if self.top_logprobs is None:
                 self.top_logprobs = 0

-    vllm.entrypoints.openai.protocol.ChatCompletionRequest = ChatCompletionRequest  # ty:ignore[invalid-assignment]
+    protocol.ChatCompletionRequest = ChatCompletionRequest  # ty:ignore[invalid-assignment]


 def patch_listen_for_disconnect() -> None:
@@ -39,7 +39,7 @@ def patch_tool_parser_manager() -> None:
     """
     Patch ToolParserManager to support streaming tool call logprobs.
     """
-    from vllm.entrypoints.openai.protocol import DeltaMessage
+    from vllm.entrypoints.openai.engine.protocol import DeltaMessage
     from vllm.tool_parsers.abstract_tool_parser import ToolParserManager

     get_tool_parser = ToolParserManager.get_tool_parser
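
Because the OpenAI protocol module moved in vLLM 0.15, code that has to work against both layouts can guard the import. A minimal sketch, assuming the old and new paths shown in this diff are the only two layouts to support (the fallback is illustrative, not part of this commit):

try:
    # vLLM >= 0.15: protocol lives under the chat_completion package.
    from vllm.entrypoints.openai.chat_completion import protocol
except ImportError:
    # Older vLLM (e.g. 0.13): the flat protocol module.
    import vllm.entrypoints.openai.protocol as protocol

ChatCompletionRequest = protocol.ChatCompletionRequest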

src/art/vllm/server.py

Lines changed: 6 additions & 8 deletions
@@ -46,7 +46,8 @@ async def openai_server_task(
     subclass_chat_completion_request()
     # Capture the OpenAIServingModels instance so dynamically added LoRAs
     # are reflected in the model list.
-    from vllm.entrypoints.openai import api_server, serving_models
+    from vllm.entrypoints.openai import api_server
+    from vllm.entrypoints.openai.models import serving as serving_models

     serving_models_any = cast(Any, serving_models)
     if not getattr(serving_models_any, "_art_openai_serving_models_patched", False):
@@ -64,22 +65,19 @@ def _init(self, *args: Any, **kwargs: Any) -> None:
     patch_tool_parser_manager()
     set_vllm_log_file(config.get("log_file", "vllm.log"))

-    # Patch engine.add_lora to ensure lora_tensors attribute exists
-    # This is needed for compatibility with Unsloth
+    # Patch engine.add_lora to normalize requests across vLLM schema changes.
     add_lora = engine.add_lora

     async def _add_lora(lora_request) -> bool:
-        # Ensure lora_tensors attribute exists on the request
-        if not hasattr(lora_request, "lora_tensors"):
-            # For msgspec.Struct, we need to create a new instance with the attribute
-            from vllm.lora.request import LoRARequest
+        from vllm.lora.request import LoRARequest

+        if not isinstance(lora_request, LoRARequest):
             lora_request = LoRARequest(
                 lora_name=lora_request.lora_name,
                 lora_int_id=lora_request.lora_int_id,
                 lora_path=lora_request.lora_path,
-                long_lora_max_len=getattr(lora_request, "long_lora_max_len", None),
                 base_model_name=getattr(lora_request, "base_model_name", None),
+                load_inplace=getattr(lora_request, "load_inplace", False),
             )
         added = await add_lora(lora_request)
         if added and _openai_serving_models is not None:
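
The rewritten wrapper no longer papers over a missing lora_tensors attribute for Unsloth; it rebuilds any object that is not already a LoRARequest, drops the removed long_lora_max_len field, and forwards the new load_inplace flag. A standalone sketch of that normalization, with a hypothetical LegacyLoraRequest standing in for an older request shape (constructor keywords taken from the diff above):

from dataclasses import dataclass

from vllm.lora.request import LoRARequest


@dataclass
class LegacyLoraRequest:
    # Hypothetical stand-in for an older, duck-typed request object.
    lora_name: str
    lora_int_id: int
    lora_path: str


def normalize(lora_request) -> LoRARequest:
    """Coerce any request-like object into the LoRARequest vLLM 0.15 expects."""
    if isinstance(lora_request, LoRARequest):
        return lora_request
    return LoRARequest(
        lora_name=lora_request.lora_name,
        lora_int_id=lora_request.lora_int_id,
        lora_path=lora_request.lora_path,
        base_model_name=getattr(lora_request, "base_model_name", None),
        load_inplace=getattr(lora_request, "load_inplace", False),
    )


request = normalize(LegacyLoraRequest("my-lora", 1, "/tmp/my-lora"))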
New test file (end-to-end vLLM contract tests for ART LocalBackend): 148 additions & 0 deletions
@@ -0,0 +1,148 @@
"""End-to-end vLLM contract tests for ART LocalBackend."""

import os
import tempfile
import uuid

import openai
import pytest

torch = pytest.importorskip("torch")
pytest.importorskip("vllm")

import art
from art.local import LocalBackend
from art.types import LocalTrainResult

DEFAULT_BASE_MODEL = "Qwen/Qwen3-0.6B"
DEFAULT_GPU_MEMORY_UTILIZATION = 0.2
DEFAULT_MAX_MODEL_LEN = 2048
DEFAULT_MAX_SEQ_LENGTH = 2048


def get_base_model() -> str:
    return os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL)


def get_safe_gpu_memory_utilization() -> float:
    requested = float(
        os.environ.get(
            "ART_TEST_GPU_MEMORY_UTILIZATION",
            str(DEFAULT_GPU_MEMORY_UTILIZATION),
        )
    )
    min_free_gib = float(os.environ.get("ART_TEST_MIN_FREE_GPU_GIB", "8"))
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    free_gib = free_bytes / (1024**3)
    if free_gib < min_free_gib:
        pytest.skip(
            f"Insufficient free GPU memory for vLLM contract test: {free_gib:.1f} GiB free < {min_free_gib:.1f} GiB required."
        )
    # Keep requested utilization below currently free memory with headroom.
    return max(0.02, min(requested, (free_bytes / total_bytes) * 0.8))


def get_vllm_test_config() -> art.dev.InternalModelConfig:
    return {
        "engine_args": {
            "gpu_memory_utilization": get_safe_gpu_memory_utilization(),
            "max_model_len": int(
                os.environ.get("ART_TEST_MAX_MODEL_LEN", str(DEFAULT_MAX_MODEL_LEN))
            ),
            "max_num_seqs": 8,
            "enforce_eager": True,
        },
        "init_args": {
            "max_seq_length": int(
                os.environ.get("ART_TEST_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH))
            ),
        },
    }


async def simple_rollout(
    client: openai.AsyncOpenAI, model_name: str, prompt: str
) -> art.Trajectory:
    messages: art.Messages = [{"role": "user", "content": prompt}]
    completion = await client.chat.completions.create(
        messages=messages,
        model=model_name,
        max_tokens=10,
        timeout=60,
        temperature=1,
        logprobs=True,
        top_logprobs=0,
    )
    choice = completion.choices[0]
    content = (choice.message.content or "").lower()
    if "yes" in content:
        reward = 1.0
    elif "no" in content:
        reward = 0.5
    elif "maybe" in content:
        reward = 0.25
    else:
        reward = 0.0
    return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)


async def assert_chat_logprobs(
    client: openai.AsyncOpenAI,
    model_name: str,
) -> None:
    completion = await client.chat.completions.create(
        messages=[{"role": "user", "content": "Say hello."}],
        model=model_name,
        max_tokens=8,
        timeout=60,
        logprobs=True,
        top_logprobs=0,
    )
    assert completion.choices[0].logprobs is not None


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="No CUDA available in this environment",
)
async def test_local_backend_vllm_contract() -> None:
    model_name = f"test-vllm-contract-{uuid.uuid4().hex[:8]}"
    with tempfile.TemporaryDirectory() as tmpdir:
        backend = LocalBackend(path=tmpdir)
        model = art.TrainableModel(
            name=model_name,
            project="integration-tests",
            base_model=get_base_model(),
        )
        object.__setattr__(model, "_internal_config", get_vllm_test_config())
        try:
            await model.register(backend)
            client = model.openai_client()

            step0_name = model.get_inference_name(step=0)
            await assert_chat_logprobs(client, step0_name)

            model_ids = [m.id async for m in client.models.list()]
            assert f"{model.name}@0" in model_ids

            train_groups = await art.gather_trajectory_groups(
                [
                    art.TrajectoryGroup(
                        [simple_rollout(client, step0_name, prompt) for _ in range(2)]
                    )
                    for prompt in ("Say yes", "Say no")
                ]  # ty:ignore[invalid-argument-type]
            )
            result = await backend.train(model, train_groups, learning_rate=1e-5)
            assert isinstance(result, LocalTrainResult)
            assert result.step > 0

            latest_name = model.get_inference_name(step=result.step)
            await assert_chat_logprobs(client, latest_name)
            await assert_chat_logprobs(client, step0_name)

            model_ids_after = [m.id async for m in client.models.list()]
            assert f"{model.name}@0" in model_ids_after
            assert f"{model.name}@{result.step}" in model_ids_after
        finally:
            await backend.close()
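
get_safe_gpu_memory_utilization caps the requested fraction at 80% of the memory currently free, never dropping below 0.02. A worked example with hypothetical numbers, a 24 GiB GPU with 10 GiB free and the default 0.2 request:

# Hypothetical values, mirroring get_safe_gpu_memory_utilization() above.
requested = 0.2
free_bytes = 10 * 1024**3
total_bytes = 24 * 1024**3

# free/total * 0.8 ≈ 0.333, so the requested 0.2 is already safe;
# with only 3 GiB free the cap would drop to 3/24 * 0.8 = 0.1 instead.
safe = max(0.02, min(requested, (free_bytes / total_bytes) * 0.8))
print(safe)  # 0.2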

tests/test_backend_train_api.py

Lines changed: 52 additions & 0 deletions
@@ -8,12 +8,59 @@
 """

 import asyncio
+import os
 import tempfile

 import art
 from art.local import LocalBackend
 from art.types import LocalTrainResult

+DEFAULT_GPU_MEMORY_UTILIZATION = 0.2
+DEFAULT_MAX_MODEL_LEN = 2048
+DEFAULT_MAX_SEQ_LENGTH = 2048
+
+
+def get_vllm_test_config() -> tuple[art.dev.InternalModelConfig, str | None]:
+    requested = float(
+        os.environ.get(
+            "ART_TEST_GPU_MEMORY_UTILIZATION",
+            str(DEFAULT_GPU_MEMORY_UTILIZATION),
+        )
+    )
+    min_free_gib = float(os.environ.get("ART_TEST_MIN_FREE_GPU_GIB", "8"))
+    safe_utilization = requested
+    skip_reason: str | None = None
+    try:
+        import torch
+
+        if torch.cuda.is_available():
+            free_bytes, total_bytes = torch.cuda.mem_get_info()
+            free_gib = free_bytes / (1024**3)
+            if free_gib < min_free_gib:
+                skip_reason = (
+                    f"Skipping backend.train API test: free GPU memory is too low "
+                    f"({free_gib:.2f} GiB < {min_free_gib:.2f} GiB)."
+                )
+            safe_utilization = min(requested, (free_bytes / total_bytes) * 0.8)
+    except Exception:
+        pass
+
+    return {
+        "engine_args": {
+            "gpu_memory_utilization": safe_utilization,
+            "max_model_len": int(
+                os.environ.get("ART_TEST_MAX_MODEL_LEN", str(DEFAULT_MAX_MODEL_LEN))
+            ),
+            "max_num_seqs": 8,
+            "enforce_eager": True,
+        },
+        "init_args": {
+            "max_seq_length": int(
+                os.environ.get("ART_TEST_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH))
+            ),
+        },
+    }, skip_reason
+

 async def simple_rollout(client, model_name: str, prompt: str) -> art.Trajectory:
     """A simple rollout function for testing."""
@@ -53,6 +100,11 @@ async def main():
         project="api-test",
         base_model="Qwen/Qwen3-0.6B",
     )
+    test_config, skip_reason = get_vllm_test_config()
+    if skip_reason is not None:
+        print(f"\n{skip_reason}")
+        return
+    object.__setattr__(model, "_internal_config", test_config)

     try:
         print("\n1. Registering model with backend...")
New test file (unit tests for ART's vLLM patch contract): 58 additions & 0 deletions
@@ -0,0 +1,58 @@
"""Unit tests for ART's vLLM patch contract."""

import importlib

import pytest

pytest.importorskip("cloudpickle")
pytest.importorskip("vllm")

from art.vllm.patches import patch_tool_parser_manager, subclass_chat_completion_request


def test_subclass_chat_completion_request_forces_logprobs() -> None:
    protocol = importlib.import_module(
        "vllm.entrypoints.openai.chat_completion.protocol"
    )
    original = getattr(protocol, "ChatCompletionRequest")

    try:
        subclass_chat_completion_request()
        request_cls = getattr(protocol, "ChatCompletionRequest")
        request = request_cls(
            messages=[{"role": "user", "content": "hello"}],
            model="dummy-model",
        )
        assert request.logprobs is True
        assert request.top_logprobs == 0
    finally:
        setattr(protocol, "ChatCompletionRequest", original)


def test_patch_tool_parser_manager_falls_back_to_empty_delta_message() -> None:
    protocol = importlib.import_module("vllm.entrypoints.openai.engine.protocol")
    DeltaMessage = protocol.DeltaMessage

    from vllm.tool_parsers.abstract_tool_parser import ToolParserManager

    class DummyToolParser:
        @staticmethod
        def extract_tool_calls_streaming(*_args, **_kwargs):
            return None

    original_get_tool_parser = ToolParserManager.get_tool_parser

    try:
        setattr(
            ToolParserManager,
            "get_tool_parser",
            classmethod(lambda _cls, _name: DummyToolParser),
        )
        patch_tool_parser_manager()

        parser_cls = ToolParserManager.get_tool_parser("dummy")
        result = parser_cls.extract_tool_calls_streaming("", "", "", [], [], [], None)  # ty:ignore[missing-argument,invalid-argument-type]

        assert isinstance(result, DeltaMessage)
    finally:
        setattr(ToolParserManager, "get_tool_parser", original_get_tool_parser)
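
The second test pins the behavior the tool-parser patch must provide: when a parser's extract_tool_calls_streaming returns None mid-stream, callers still receive a DeltaMessage. A minimal sketch of a wrapper with that contract (illustrative only; the real patch lives in art.vllm.patches.patch_tool_parser_manager):

from vllm.entrypoints.openai.engine.protocol import DeltaMessage


def wrap_streaming_extraction(extract):
    """Return a version of extract_tool_calls_streaming that never yields None."""

    def wrapped(*args, **kwargs):
        delta = extract(*args, **kwargs)
        # Fall back to an empty delta so streaming tool-call logprob handling
        # always has a message object to attach to, even when the parser
        # produced nothing for this chunk.
        return delta if delta is not None else DeltaMessage()

    return wrapped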
