Skip to content

Commit 01c2070

Browse files
committed
feat: implement prompt prefix caching in OpenAICompatibleTinkerServer
- Introduced LRUTrieCache for caching prompt prefixes to optimize token processing.
- Enhanced token handling by integrating cached prefixes into the sampling process.
- Improved error handling and assertions for token and logprob consistency.
- Updated response handling to cache rendered prompts for future use.
1 parent 656c0e4 commit 01c2070

1 file changed

Lines changed: 31 additions & 9 deletions

File tree

src/art/tinker/server.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import socket
55
import time
6+
from typing import cast
67
import uuid
78

89
from fastapi import FastAPI, HTTPException, Request
@@ -20,6 +21,7 @@
2021
import uvicorn
2122

2223
from art.tinker.cookbook_v import renderers
24+
from art.tinker.prefix_cache import LRUTrieCache
2325

2426

2527
@dataclass
@@ -29,6 +31,9 @@ class OpenAICompatibleTinkerServer:
2931
sampling_clients_and_renderers: dict[
3032
str, tuple[tinker.SamplingClient, renderers.Renderer]
3133
] = field(default_factory=dict)
34+
prompt_prefix_cache: LRUTrieCache = field(
35+
default_factory=lambda: LRUTrieCache(max_entries=1000)
36+
)
3237
_task: asyncio.Task[None] | None = None
3338

3439
async def start(self) -> tuple[str, int]:
@@ -84,17 +89,26 @@ async def chat_completions(
8489
raise HTTPException(
8590
status_code=404, detail=f"Model {body['model']} not found"
8691
)
87-
88-
prompt = tinker.ModelInput.from_ints(
89-
tokens=renderer.tokenizer.apply_chat_template(
92+
rendered_prompt_tokens = cast(
93+
list[int],
94+
renderer.tokenizer.apply_chat_template(
9095
list(body["messages"]), # type: ignore
9196
tools=body.get("tools"), # type: ignore
9297
add_generation_prompt=True,
93-
) # ty:ignore[invalid-argument-type]
98+
),
9499
)
100+
prompt_tokens = rendered_prompt_tokens
101+
prefix_entry = self.prompt_prefix_cache.lookup(rendered_prompt_tokens)
102+
if prefix_entry is not None and prefix_entry.rendered_len <= len(
103+
rendered_prompt_tokens
104+
):
105+
prompt_tokens = (
106+
list(prefix_entry.raw_prefix)
107+
+ rendered_prompt_tokens[prefix_entry.rendered_len :]
108+
)
95109
try:
96110
sample_response = await sampler_client.sample_async(
97-
prompt=prompt,
111+
prompt=tinker.ModelInput.from_ints(tokens=prompt_tokens),
98112
num_samples=body.get("n") or 1,
99113
sampling_params=tinker.SamplingParams(
100114
max_tokens=body.get("max_completion_tokens")
@@ -119,9 +133,17 @@ async def chat_completions(
119133
choices: list[Choice] = []
120134
for i, sequence in enumerate(sample_response.sequences):
121135
assert sequence.logprobs is not None, "Logprobs are required"
122-
assert len(sequence.tokens) == len(sequence.logprobs), (
123-
"Tokens and logprobs must have the same length"
136+
assert len(sequence.tokens) == len(
137+
sequence.logprobs
138+
), "Tokens and logprobs must have the same length"
139+
rendered_response_tokens = renderer.tokenizer.encode(
140+
renderer.tokenizer.decode(sequence.tokens)
124141
)
142+
if rendered_response_tokens != sequence.tokens:
143+
self.prompt_prefix_cache.insert(
144+
rendered_prompt_tokens + rendered_response_tokens,
145+
prompt_tokens + sequence.tokens,
146+
)
125147
message, _ = renderer.parse_response(sequence.tokens)
126148
openai_message = renderer.to_openai_message(message)
127149
tool_calls = (
@@ -176,8 +198,8 @@ async def chat_completions(
176198
object="chat.completion",
177199
usage=CompletionUsage(
178200
completion_tokens=completion_tokens,
179-
prompt_tokens=prompt.length,
180-
total_tokens=completion_tokens + prompt.length,
201+
prompt_tokens=len(prompt_tokens),
202+
total_tokens=completion_tokens + len(prompt_tokens),
181203
),
182204
)
183205

0 commit comments

Comments
 (0)