@@ -3,6 +3,7 @@
import os
import socket
import time
+from typing import cast
import uuid

from fastapi import FastAPI, HTTPException, Request
@@ -20,6 +21,7 @@
import uvicorn

from art.tinker.cookbook_v import renderers
+from art.tinker.prefix_cache import LRUTrieCache


@dataclass
@@ -29,6 +31,11 @@ class OpenAICompatibleTinkerServer:
    sampling_clients_and_renderers: dict[
        str, tuple[tinker.SamplingClient, renderers.Renderer]
    ] = field(default_factory=dict)
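+    # Maps rendered (decode-then-re-encode) token prefixes back to the raw
+    # tokens the model actually sampled for them.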
+    prompt_prefix_cache: LRUTrieCache = field(
+        default_factory=lambda: LRUTrieCache(max_entries=1000)
+    )
    _task: asyncio.Task[None] | None = None

    async def start(self) -> tuple[str, int]:
@@ -84,17 +91,28 @@ async def chat_completions(
            raise HTTPException(
                status_code=404, detail=f"Model {body['model']} not found"
            )
-
-        prompt = tinker.ModelInput.from_ints(
-            tokens=renderer.tokenizer.apply_chat_template(
+        rendered_prompt_tokens = cast(
+            list[int],
+            renderer.tokenizer.apply_chat_template(
                list(body["messages"]),  # type: ignore
                tools=body.get("tools"),  # type: ignore
                add_generation_prompt=True,
-            )  # ty:ignore[invalid-argument-type]
+            ),
        )
+        prompt_tokens = rendered_prompt_tokens
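+        # Splice cached raw sampled tokens in place of any previously seen
+        # rendered prefix, so sampling conditions on the original token ids.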
+        prefix_entry = self.prompt_prefix_cache.lookup(rendered_prompt_tokens)
+        if prefix_entry is not None and prefix_entry.rendered_len <= len(
+            rendered_prompt_tokens
+        ):
+            prompt_tokens = (
+                list(prefix_entry.raw_prefix)
+                + rendered_prompt_tokens[prefix_entry.rendered_len :]
+            )
        try:
            sample_response = await sampler_client.sample_async(
-                prompt=prompt,
+                prompt=tinker.ModelInput.from_ints(tokens=prompt_tokens),
                num_samples=body.get("n") or 1,
                sampling_params=tinker.SamplingParams(
                    max_tokens=body.get("max_completion_tokens")
@@ -119,9 +137,19 @@ async def chat_completions(
        choices: list[Choice] = []
        for i, sequence in enumerate(sample_response.sequences):
            assert sequence.logprobs is not None, "Logprobs are required"
-            assert len(sequence.tokens) == len(sequence.logprobs), (
-                "Tokens and logprobs must have the same length"
+            assert len(sequence.tokens) == len(
+                sequence.logprobs
+            ), "Tokens and logprobs must have the same length"
+            rendered_response_tokens = renderer.tokenizer.encode(
+                renderer.tokenizer.decode(sequence.tokens)
            )
+            if rendered_response_tokens != sequence.tokens:
+                self.prompt_prefix_cache.insert(
+                    rendered_prompt_tokens + rendered_response_tokens,
+                    prompt_tokens + sequence.tokens,
+                )
            message, _ = renderer.parse_response(sequence.tokens)
            openai_message = renderer.to_openai_message(message)
            tool_calls = (
@@ -176,8 +204,8 @@ async def chat_completions(
            object="chat.completion",
            usage=CompletionUsage(
                completion_tokens=completion_tokens,
-                prompt_tokens=prompt.length,
-                total_tokens=completion_tokens + prompt.length,
+                prompt_tokens=len(prompt_tokens),
+                total_tokens=completion_tokens + len(prompt_tokens),
            ),
        )

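Note: `art.tinker.prefix_cache.LRUTrieCache` is imported above but its implementation is not part of this diff. Below is a minimal sketch of the interface the handler relies on, inferred from the call sites (`max_entries`, `lookup`, `insert`, and entries exposing `rendered_len` and `raw_prefix`); the `PrefixEntry` name and the LRU bookkeeping are assumptions, and a linear scan stands in for a real trie.

from collections import OrderedDict
from dataclasses import dataclass


@dataclass
class PrefixEntry:
    rendered_len: int      # length of the cached rendered prefix
    raw_prefix: list[int]  # raw sampled tokens covering that prefix


class LRUTrieCache:
    def __init__(self, max_entries: int = 1000) -> None:
        self.max_entries = max_entries
        # Insertion order doubles as the LRU queue (oldest entry first).
        self._entries: OrderedDict[tuple[int, ...], PrefixEntry] = OrderedDict()

    def insert(self, rendered_tokens: list[int], raw_tokens: list[int]) -> None:
        key = tuple(rendered_tokens)
        self._entries[key] = PrefixEntry(len(rendered_tokens), list(raw_tokens))
        self._entries.move_to_end(key)
        while len(self._entries) > self.max_entries:
            self._entries.popitem(last=False)  # evict least recently used

    def lookup(self, rendered_tokens: list[int]) -> PrefixEntry | None:
        # Longest cached entry whose rendered tokens are a prefix of the
        # query; a real trie would find this in O(len(rendered_tokens)).
        query = tuple(rendered_tokens)
        best_key: tuple[int, ...] | None = None
        for key in self._entries:
            if len(key) <= len(query) and query[: len(key)] == key:
                if best_key is None or len(key) > len(best_key):
                    best_key = key
        if best_key is None:
            return None
        self._entries.move_to_end(best_key)  # refresh LRU position
        return self._entries[best_key]

Because `lookup` returns the longest cached rendered prefix, the handler can splice `raw_prefix` onto the uncached suffix and send the sampler the exact token ids it produced in earlier turns.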