11import asyncio
2+ from collections import defaultdict
3+ from contextlib import asynccontextmanager
24from dataclasses import dataclass , field
35from itertools import cycle
46import json
57import os
68import socket
79import time
8- from typing import Annotated , Literal
10+ from typing import Annotated , AsyncGenerator , Literal
911import uuid
1012
1113from fastapi import FastAPI , HTTPException , Request
@@ -50,6 +52,7 @@ class OpenAICompatibleTinkerServer:
5052 host : str | None = None
5153 port : int | None = None
5254 num_workers : int | None = None
55+ max_concurrent_sampling_clients : int | None = None
5356 _prefix_cache : LRUTrieCache = field (default_factory = LRUTrieCache )
5457 _task : asyncio .Task [None ] | None = None
5558 _tenants : dict [str , "OpenAICompatibleTinkerServerTenant" ] = field (
@@ -61,13 +64,13 @@ class OpenAICompatibleTinkerServer:
6164 def models (self ) -> dict [str , str ]:
6265 if "TINKER_API_KEY" not in os .environ :
6366 raise ValueError ("TINKER_API_KEY is not set" )
64- return self ._get_tenant (os .environ ["TINKER_API_KEY" ])._models
67+ return self ._get_tenant (os .environ ["TINKER_API_KEY" ]).models
6568
6669 @models .setter
6770 def models (self , models : dict [str , str ]) -> None :
6871 if "TINKER_API_KEY" not in os .environ :
6972 raise ValueError ("TINKER_API_KEY is not set" )
70- self ._get_tenant (os .environ ["TINKER_API_KEY" ])._models = models
73+ self ._get_tenant (os .environ ["TINKER_API_KEY" ]).models = models
7174
7275 async def start (self ) -> tuple [str , int ]:
7376 host = self .host or "0.0.0.0"
@@ -139,25 +142,25 @@ async def list_models(request: Request) -> ModelList:
139142 data = [
140143 Model (
141144 id = model ,
142- created = tenant ._model_timestamps .get (model , 0 ),
145+ created = tenant .model_timestamps .get (model , 0 ),
143146 object = "model" ,
144147 owned_by = "tinker" ,
145148 )
146- for model in tenant ._models
149+ for model in tenant .models
147150 ],
148151 )
149152
150153 @app .get ("/v1/models/{model}" )
151154 async def get_model (request : Request , model : str ) -> Model :
152155 tenant = self ._get_request_tenant (request )
153- if model not in tenant ._models :
156+ if model not in tenant .models :
154157 raise HTTPException (
155158 status_code = 404 ,
156159 detail = f"Model not found: { model } " ,
157160 )
158161 return Model (
159162 id = model ,
160- created = tenant ._model_timestamps .get (model , 0 ),
163+ created = tenant .model_timestamps .get (model , 0 ),
161164 object = "model" ,
162165 owned_by = "tinker" ,
163166 )
@@ -169,25 +172,25 @@ async def put_model(
169172 body : ModelUpsert ,
170173 ) -> Model :
171174 tenant = self ._get_request_tenant (request )
172- tenant ._models [model ] = body .target
173- tenant ._model_timestamps .setdefault (model , int (time .time ()))
175+ tenant .models [model ] = body .target
176+ tenant .model_timestamps .setdefault (model , int (time .time ()))
174177 return Model (
175178 id = model ,
176- created = tenant ._model_timestamps [model ],
179+ created = tenant .model_timestamps [model ],
177180 object = "model" ,
178181 owned_by = "tinker" ,
179182 )
180183
181184 @app .delete ("/v1/models/{model}" )
182185 async def delete_model (request : Request , model : str ) -> ModelDeleted :
183186 tenant = self ._get_request_tenant (request )
184- if model not in tenant ._models :
187+ if model not in tenant .models :
185188 raise HTTPException (
186189 status_code = 404 ,
187190 detail = f"Model not found: { model } " ,
188191 )
189- tenant ._models .pop (model )
190- tenant ._model_timestamps .pop (model , None )
192+ tenant .models .pop (model )
193+ tenant .model_timestamps .pop (model , None )
191194 return ModelDeleted (
192195 id = model ,
193196 deleted = True ,
@@ -200,7 +203,7 @@ async def chat_completions(
200203 ) -> ChatCompletion :
201204 worker = next (workers )
202205 tenant = self ._get_request_tenant (request )
203- samplable_model = await tenant ._get_samplable_model (body ["model" ])
206+ samplable_model = await tenant .get_samplable_model (body ["model" ])
204207 rendered_prompt_tokens = await worker .prompt_tokens (
205208 base_model = samplable_model .base_model ,
206209 messages = list (body ["messages" ]),
@@ -216,20 +219,21 @@ async def chat_completions(
216219 + rendered_prompt_tokens [prefix_entry .rendered_len :]
217220 )
218221 try :
219- sample_response = await samplable_model .sampling_client .sample_async (
220- prompt = tinker .ModelInput .from_ints (tokens = prompt_tokens ),
221- num_samples = body .get ("n" ) or 1 ,
222- sampling_params = tinker .SamplingParams (
223- max_tokens = body .get ("max_completion_tokens" )
224- or body .get ("max_tokens" ),
225- seed = body .get ("seed" ),
226- temperature = (
227- t if (t := body .get ("temperature" )) is not None else 1.0
222+ async with samplable_model .sampling_client () as sampling_client :
223+ sample_response = await sampling_client .sample_async (
224+ prompt = tinker .ModelInput .from_ints (tokens = prompt_tokens ),
225+ num_samples = body .get ("n" ) or 1 ,
226+ sampling_params = tinker .SamplingParams (
227+ max_tokens = body .get ("max_completion_tokens" )
228+ or body .get ("max_tokens" ),
229+ seed = body .get ("seed" ),
230+ temperature = (
231+ t if (t := body .get ("temperature" )) is not None else 1.0
232+ ),
233+ top_k = body .get ("top_k" ) or - 1 ,
234+ top_p = body .get ("top_p" ) or 1.0 ,
228235 ),
229- top_k = body .get ("top_k" ) or - 1 ,
230- top_p = body .get ("top_p" ) or 1.0 ,
231- ),
232- )
236+ )
233237 except tinker .APIStatusError as e :
234238 error_body = e .body
235239 if isinstance (error_body , dict ) and "detail" in error_body :
@@ -272,30 +276,52 @@ def _default_num_workers(self) -> int:
272276
273277 def _get_tenant (self , api_key : str ) -> "OpenAICompatibleTinkerServerTenant" :
274278 if api_key not in self ._tenants :
275- self ._tenants [api_key ] = OpenAICompatibleTinkerServerTenant (api_key )
279+ self ._tenants [api_key ] = OpenAICompatibleTinkerServerTenant (
280+ api_key , self .max_concurrent_sampling_clients or 32
281+ )
276282 return self ._tenants [api_key ]
277283
278284
@dataclass
class OpenAICompatibleTinkerServerSamplableModel:
    """A resolved, sampling-ready model plus concurrency bookkeeping.

    Samplable models that share a base model share one semaphore. A single
    slot is claimed while this model has at least one in-flight
    ``sampling_client()`` context open, and returned once the last one
    closes (reference counting via ``_yields``).
    """

    base_model: str
    _sampling_client: tinker.SamplingClient
    _concurrent_sampling_client_semaphore: asyncio.Semaphore
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Count of currently-open sampling_client() contexts on this model.
    _yields: int = 0

    @asynccontextmanager
    async def sampling_client(self) -> AsyncGenerator[tinker.SamplingClient, None]:
        """Yield the underlying client while holding one shared semaphore slot.

        The first concurrent consumer acquires the per-base-model slot;
        parallel uses of the same model share it, and the last consumer
        releases it.
        """
        async with self._lock:
            if not self._yields:
                # First active consumer claims the shared slot.
                await self._concurrent_sampling_client_semaphore.acquire()
            self._yields += 1
        try:
            yield self._sampling_client
        finally:
            # NOTE(review): re-taking the lock inside `finally` can be
            # interrupted by task cancellation before the decrement runs —
            # confirm this is acceptable (or shield the cleanup).
            async with self._lock:
                self._yields -= 1
                if not self._yields:
                    self._concurrent_sampling_client_semaphore.release()
283306
284307
285308class OpenAICompatibleTinkerServerTenant :
286- def __init__ (self , api_key : str ) -> None :
287- self ._models : dict [str , str ] = {}
288- self ._model_timestamps : dict [str , int ] = {}
309+ def __init__ (self , api_key : str , max_concurrent_sampling_clients : int ) -> None :
310+ self .models : dict [str , str ] = {}
311+ self .model_timestamps : dict [str , int ] = {}
289312 self ._service_client = tinker .ServiceClient (api_key = api_key )
290313 self ._rest_client = self ._service_client .create_rest_client ()
291314 self ._samplable_models : dict [
292315 str , asyncio .Task [OpenAICompatibleTinkerServerSamplableModel ]
293316 ] = dict ()
317+ self ._concurrent_sampling_client_semaphores : defaultdict [
318+ str , asyncio .Semaphore
319+ ] = defaultdict (lambda : asyncio .Semaphore (max_concurrent_sampling_clients ))
294320
295- async def _get_samplable_model (
321+ async def get_samplable_model (
296322 self , model : str
297323 ) -> OpenAICompatibleTinkerServerSamplableModel :
298- model_path_or_base_model = self ._models .get (model , model )
324+ model_path_or_base_model = self .models .get (model , model )
299325 if not model_path_or_base_model .startswith ("tinker://" ):
300326 try :
301327 get_renderer_name (model_path_or_base_model )
@@ -331,9 +357,25 @@ async def _load_samplable_model(
331357 base_model = sampler_response .base_model
332358 else :
333359 base_model = model_path_or_base_model
360+ # on_queue_state_change = sampling_client.on_queue_state_change
361+
362+ # def patched_on_queue_state_change(
363+ # queue_state: TinkerQueueState, queue_state_reason: str | None
364+ # ) -> None:
365+ # on_queue_state_change(queue_state, queue_state_reason)
366+ # if queue_state == TinkerQueueState.PAUSED_RATE_LIMIT:
367+ # # implicit upper-bound on the number of concurrent sampling clients found
368+ # # do not allow this number of concurrent sampling clients again
369+ # semaphore = self._concurrent_sampling_client_semaphores[base_model]
370+ # semaphore._value = max(semaphore._value - 1, -4)
371+
372+ # sampling_client.on_queue_state_change = patched_on_queue_state_change
334373 return OpenAICompatibleTinkerServerSamplableModel (
335- sampling_client = sampling_client ,
336374 base_model = base_model ,
375+ _sampling_client = sampling_client ,
376+ _concurrent_sampling_client_semaphore = self ._concurrent_sampling_client_semaphores [
377+ base_model
378+ ],
337379 )
338380
339381
0 commit comments