Skip to content

Commit 4683fd0

Browse files
committed
feat: add datrie dependency and implement LRUTrieCache for prefix caching
- Added datrie as a dependency in pyproject.toml. - Introduced LRUTrieCache class for efficient caching of token sequence prefixes. - Implemented methods for inserting and looking up cached prefixes. - Added unit tests for LRUTrieCache to ensure functionality and eviction behavior. - Updated server code to handle tool calls with type safety.
1 parent 01c2070 commit 4683fd0

5 files changed

Lines changed: 108 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ tinker = [
6161
"torch>=2.8.0",
6262
"transformers>=4.55.2,<=4.57.3",
6363
"uvicorn>=0.35.0",
64+
"datrie>=0.8.3",
6465
]
6566

6667
[project.scripts]

src/art/tinker/prefix_cache.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from __future__ import annotations
2+
3+
from collections import OrderedDict
4+
from dataclasses import dataclass
5+
import struct
6+
from typing import Sequence
7+
8+
import datrie
9+
10+
# Width of one encoded token: struct format ">I" packs each token id into
# 4 big-endian bytes.
_TOKEN_BYTES = 4
# datrie.Trie requires an explicit alphabet of allowed key characters; use
# all 256 latin-1 code points so every possible encoded byte is representable.
_TRIE_ALPHABET = "".join(chr(i) for i in range(256))
12+
13+
14+
@dataclass(frozen=True)
class PrefixEntry:
    """Immutable cache value: the raw-token rewrite for a rendered prefix."""

    # Number of rendered tokens this entry covers (length of the cached prefix).
    rendered_len: int
    # Raw token ids corresponding to the rendered prefix, stored as a tuple
    # so the frozen dataclass stays hashable and the payload is immutable.
    raw_prefix: tuple[int, ...]
18+
19+
20+
class LRUTrieCache:
    """LRU-bounded prefix trie for token sequence rewrites.

    Rendered-token prefixes are encoded into fixed-width byte strings and
    stored in a character trie (``datrie``), so the longest cached prefix of
    a query can be found efficiently.  An ``OrderedDict`` tracks recency and
    bounds the cache at ``max_entries``.
    """

    def __init__(self, max_entries: int = 1024) -> None:
        """Create an empty cache holding at most *max_entries* prefixes.

        Raises:
            ValueError: If *max_entries* is not positive.
        """
        if max_entries <= 0:
            raise ValueError("max_entries must be positive")
        self._trie: datrie.Trie = datrie.Trie(_TRIE_ALPHABET)
        # Keys in access order: the leftmost entry is the eviction candidate.
        self._lru: OrderedDict[str, None] = OrderedDict()
        self._max_entries = max_entries

    @staticmethod
    def _encode_tokens(tokens: Sequence[int]) -> str:
        """Pack *tokens* into a latin-1 string usable as a trie key.

        Each token occupies exactly ``_TOKEN_BYTES`` big-endian bytes, so a
        stored key is a character-prefix of a query key iff its token
        sequence is a token-prefix — fixed width guarantees alignment.
        """
        # One struct call instead of packing token-by-token in a Python loop.
        return struct.pack(f">{len(tokens)}I", *tokens).decode("latin-1")

    def lookup(self, rendered_tokens: Sequence[int]) -> PrefixEntry | None:
        """Return the entry for the longest cached prefix of *rendered_tokens*.

        A hit is marked most recently used.  Returns ``None`` when no cached
        prefix matches.
        """
        key = self._encode_tokens(rendered_tokens)
        # datrie can return the longest matching prefix directly; no need to
        # iterate every shorter match as prefix_items() would.
        match = self._trie.longest_prefix_item(key, None)
        if match is None:
            return None
        match_key, entry = match
        self._lru.move_to_end(match_key)
        return entry

    def insert(self, rendered_prefix: Sequence[int], raw_prefix: Sequence[int]) -> None:
        """Cache *raw_prefix* as the rewrite for *rendered_prefix*.

        Re-inserting an existing prefix overwrites its entry and refreshes
        its recency.  Evicts least-recently-used entries past the bound.
        """
        key = self._encode_tokens(rendered_prefix)
        self._trie[key] = PrefixEntry(
            rendered_len=len(rendered_prefix), raw_prefix=tuple(raw_prefix)
        )
        self._lru[key] = None
        # move_to_end covers the overwrite case, where the key already existed.
        self._lru.move_to_end(key)
        self._evict()

    def _evict(self) -> None:
        """Drop least-recently-used entries until within ``max_entries``."""
        while len(self._lru) > self._max_entries:
            old_key, _ = self._lru.popitem(last=False)
            # Guard against trie/LRU drift; deleting a missing key would raise.
            if old_key in self._trie:
                del self._trie[old_key]

src/art/tinker/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ async def chat_completions(
168168
message=ChatCompletionMessage(
169169
content=openai_message.get("content") or None,
170170
role="assistant",
171-
tool_calls=tool_calls,
171+
tool_calls=tool_calls, # type: ignore
172172
),
173173
logprobs=ChoiceLogprobs(
174174
content=[

tests/unit/test_prefix_cache.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Tests for the LRUTrieCache prefix rewrite helper."""
2+
3+
import pytest
4+
5+
pytest.importorskip("datrie")
6+
7+
from art.tinker.prefix_cache import LRUTrieCache
8+
9+
10+
class TestLRUTrieCache:
    def test_longest_prefix_match(self) -> None:
        """When several stored prefixes match, the longest one wins."""
        trie_cache = LRUTrieCache(max_entries=10)
        trie_cache.insert([1, 2], [10, 11])
        trie_cache.insert([1, 2, 3], [20, 21, 22])

        hit = trie_cache.lookup([1, 2, 3, 4])

        assert hit is not None
        assert (hit.rendered_len, hit.raw_prefix) == (3, (20, 21, 22))

    def test_lru_eviction(self) -> None:
        """A lookup refreshes recency, so the untouched entry is evicted."""
        trie_cache = LRUTrieCache(max_entries=2)
        trie_cache.insert([1], [10])
        trie_cache.insert([2], [20])

        # Touch [1] so that [2] becomes the least recently used entry.
        assert trie_cache.lookup([1, 99]) is not None

        trie_cache.insert([3], [30])

        assert trie_cache.lookup([2, 0]) is None
        assert trie_cache.lookup([1, 0]) is not None

    def test_invalid_size(self) -> None:
        """A non-positive bound must be rejected at construction time."""
        with pytest.raises(ValueError):
            LRUTrieCache(max_entries=0)

uv.lock

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)