Skip to content

Commit ca97bff

Browse files
committed
feat: Enhance OpenAICompatibleTinkerServer with model management improvements
- Introduced `max_concurrent_sampling_clients` to manage concurrent sampling.
- Refactored model access to use tenant-scoped properties for better encapsulation.
- Updated sampling client handling to use an async context manager for improved resource management.
- Enhanced error handling and response structure in model-related API endpoints.
1 parent ca77e97 commit ca97bff

File tree

1 file changed

+77
-35
lines changed

1 file changed

+77
-35
lines changed

src/art/tinker/server.py

Lines changed: 77 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import asyncio
2+
from collections import defaultdict
3+
from contextlib import asynccontextmanager
24
from dataclasses import dataclass, field
35
from itertools import cycle
46
import json
57
import os
68
import socket
79
import time
8-
from typing import Annotated, Literal
10+
from typing import Annotated, AsyncGenerator, Literal
911
import uuid
1012

1113
from fastapi import FastAPI, HTTPException, Request
@@ -50,6 +52,7 @@ class OpenAICompatibleTinkerServer:
5052
host: str | None = None
5153
port: int | None = None
5254
num_workers: int | None = None
55+
max_concurrent_sampling_clients: int | None = None
5356
_prefix_cache: LRUTrieCache = field(default_factory=LRUTrieCache)
5457
_task: asyncio.Task[None] | None = None
5558
_tenants: dict[str, "OpenAICompatibleTinkerServerTenant"] = field(
@@ -61,13 +64,13 @@ class OpenAICompatibleTinkerServer:
6164
def models(self) -> dict[str, str]:
6265
if "TINKER_API_KEY" not in os.environ:
6366
raise ValueError("TINKER_API_KEY is not set")
64-
return self._get_tenant(os.environ["TINKER_API_KEY"])._models
67+
return self._get_tenant(os.environ["TINKER_API_KEY"]).models
6568

6669
@models.setter
6770
def models(self, models: dict[str, str]) -> None:
6871
if "TINKER_API_KEY" not in os.environ:
6972
raise ValueError("TINKER_API_KEY is not set")
70-
self._get_tenant(os.environ["TINKER_API_KEY"])._models = models
73+
self._get_tenant(os.environ["TINKER_API_KEY"]).models = models
7174

7275
async def start(self) -> tuple[str, int]:
7376
host = self.host or "0.0.0.0"
@@ -139,25 +142,25 @@ async def list_models(request: Request) -> ModelList:
139142
data=[
140143
Model(
141144
id=model,
142-
created=tenant._model_timestamps.get(model, 0),
145+
created=tenant.model_timestamps.get(model, 0),
143146
object="model",
144147
owned_by="tinker",
145148
)
146-
for model in tenant._models
149+
for model in tenant.models
147150
],
148151
)
149152

150153
@app.get("/v1/models/{model}")
151154
async def get_model(request: Request, model: str) -> Model:
152155
tenant = self._get_request_tenant(request)
153-
if model not in tenant._models:
156+
if model not in tenant.models:
154157
raise HTTPException(
155158
status_code=404,
156159
detail=f"Model not found: {model}",
157160
)
158161
return Model(
159162
id=model,
160-
created=tenant._model_timestamps.get(model, 0),
163+
created=tenant.model_timestamps.get(model, 0),
161164
object="model",
162165
owned_by="tinker",
163166
)
@@ -169,25 +172,25 @@ async def put_model(
169172
body: ModelUpsert,
170173
) -> Model:
171174
tenant = self._get_request_tenant(request)
172-
tenant._models[model] = body.target
173-
tenant._model_timestamps.setdefault(model, int(time.time()))
175+
tenant.models[model] = body.target
176+
tenant.model_timestamps.setdefault(model, int(time.time()))
174177
return Model(
175178
id=model,
176-
created=tenant._model_timestamps[model],
179+
created=tenant.model_timestamps[model],
177180
object="model",
178181
owned_by="tinker",
179182
)
180183

181184
@app.delete("/v1/models/{model}")
182185
async def delete_model(request: Request, model: str) -> ModelDeleted:
183186
tenant = self._get_request_tenant(request)
184-
if model not in tenant._models:
187+
if model not in tenant.models:
185188
raise HTTPException(
186189
status_code=404,
187190
detail=f"Model not found: {model}",
188191
)
189-
tenant._models.pop(model)
190-
tenant._model_timestamps.pop(model, None)
192+
tenant.models.pop(model)
193+
tenant.model_timestamps.pop(model, None)
191194
return ModelDeleted(
192195
id=model,
193196
deleted=True,
@@ -200,7 +203,7 @@ async def chat_completions(
200203
) -> ChatCompletion:
201204
worker = next(workers)
202205
tenant = self._get_request_tenant(request)
203-
samplable_model = await tenant._get_samplable_model(body["model"])
206+
samplable_model = await tenant.get_samplable_model(body["model"])
204207
rendered_prompt_tokens = await worker.prompt_tokens(
205208
base_model=samplable_model.base_model,
206209
messages=list(body["messages"]),
@@ -216,20 +219,21 @@ async def chat_completions(
216219
+ rendered_prompt_tokens[prefix_entry.rendered_len :]
217220
)
218221
try:
219-
sample_response = await samplable_model.sampling_client.sample_async(
220-
prompt=tinker.ModelInput.from_ints(tokens=prompt_tokens),
221-
num_samples=body.get("n") or 1,
222-
sampling_params=tinker.SamplingParams(
223-
max_tokens=body.get("max_completion_tokens")
224-
or body.get("max_tokens"),
225-
seed=body.get("seed"),
226-
temperature=(
227-
t if (t := body.get("temperature")) is not None else 1.0
222+
async with samplable_model.sampling_client() as sampling_client:
223+
sample_response = await sampling_client.sample_async(
224+
prompt=tinker.ModelInput.from_ints(tokens=prompt_tokens),
225+
num_samples=body.get("n") or 1,
226+
sampling_params=tinker.SamplingParams(
227+
max_tokens=body.get("max_completion_tokens")
228+
or body.get("max_tokens"),
229+
seed=body.get("seed"),
230+
temperature=(
231+
t if (t := body.get("temperature")) is not None else 1.0
232+
),
233+
top_k=body.get("top_k") or -1,
234+
top_p=body.get("top_p") or 1.0,
228235
),
229-
top_k=body.get("top_k") or -1,
230-
top_p=body.get("top_p") or 1.0,
231-
),
232-
)
236+
)
233237
except tinker.APIStatusError as e:
234238
error_body = e.body
235239
if isinstance(error_body, dict) and "detail" in error_body:
@@ -272,30 +276,52 @@ def _default_num_workers(self) -> int:
272276

273277
def _get_tenant(self, api_key: str) -> "OpenAICompatibleTinkerServerTenant":
274278
if api_key not in self._tenants:
275-
self._tenants[api_key] = OpenAICompatibleTinkerServerTenant(api_key)
279+
self._tenants[api_key] = OpenAICompatibleTinkerServerTenant(
280+
api_key, self.max_concurrent_sampling_clients or 32
281+
)
276282
return self._tenants[api_key]
277283

278284

279285
@dataclass
280286
class OpenAICompatibleTinkerServerSamplableModel:
281-
sampling_client: tinker.SamplingClient
282287
base_model: str
288+
_sampling_client: tinker.SamplingClient
289+
_concurrent_sampling_client_semaphore: asyncio.Semaphore
290+
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
291+
_yields: int = 0
292+
293+
@asynccontextmanager
294+
async def sampling_client(self) -> AsyncGenerator[tinker.SamplingClient, None]:
295+
async with self._lock:
296+
if self._yields == 0:
297+
await self._concurrent_sampling_client_semaphore.acquire()
298+
self._yields += 1
299+
try:
300+
yield self._sampling_client
301+
finally:
302+
async with self._lock:
303+
self._yields -= 1
304+
if self._yields == 0:
305+
self._concurrent_sampling_client_semaphore.release()
283306

284307

285308
class OpenAICompatibleTinkerServerTenant:
286-
def __init__(self, api_key: str) -> None:
287-
self._models: dict[str, str] = {}
288-
self._model_timestamps: dict[str, int] = {}
309+
def __init__(self, api_key: str, max_concurrent_sampling_clients: int) -> None:
310+
self.models: dict[str, str] = {}
311+
self.model_timestamps: dict[str, int] = {}
289312
self._service_client = tinker.ServiceClient(api_key=api_key)
290313
self._rest_client = self._service_client.create_rest_client()
291314
self._samplable_models: dict[
292315
str, asyncio.Task[OpenAICompatibleTinkerServerSamplableModel]
293316
] = dict()
317+
self._concurrent_sampling_client_semaphores: defaultdict[
318+
str, asyncio.Semaphore
319+
] = defaultdict(lambda: asyncio.Semaphore(max_concurrent_sampling_clients))
294320

295-
async def _get_samplable_model(
321+
async def get_samplable_model(
296322
self, model: str
297323
) -> OpenAICompatibleTinkerServerSamplableModel:
298-
model_path_or_base_model = self._models.get(model, model)
324+
model_path_or_base_model = self.models.get(model, model)
299325
if not model_path_or_base_model.startswith("tinker://"):
300326
try:
301327
get_renderer_name(model_path_or_base_model)
@@ -331,9 +357,25 @@ async def _load_samplable_model(
331357
base_model = sampler_response.base_model
332358
else:
333359
base_model = model_path_or_base_model
360+
# on_queue_state_change = sampling_client.on_queue_state_change
361+
362+
# def patched_on_queue_state_change(
363+
# queue_state: TinkerQueueState, queue_state_reason: str | None
364+
# ) -> None:
365+
# on_queue_state_change(queue_state, queue_state_reason)
366+
# if queue_state == TinkerQueueState.PAUSED_RATE_LIMIT:
367+
# # implicit upper-bound on the number of concurrent sampling clients found
368+
# # do not allow this number of concurrent sampling clients again
369+
# semaphore = self._concurrent_sampling_client_semaphores[base_model]
370+
# semaphore._value = max(semaphore._value - 1, -4)
371+
372+
# sampling_client.on_queue_state_change = patched_on_queue_state_change
334373
return OpenAICompatibleTinkerServerSamplableModel(
335-
sampling_client=sampling_client,
336374
base_model=base_model,
375+
_sampling_client=sampling_client,
376+
_concurrent_sampling_client_semaphore=self._concurrent_sampling_client_semaphores[
377+
base_model
378+
],
337379
)
338380

339381

0 commit comments

Comments (0)