Skip to content

Commit a437c9d

Browse files
committed
refactor: Enhance LRUTrieCache with radix tree structure and statistics tracking
- Replaced the existing trie implementation with a radix tree for improved performance in token sequence rewrites.
- Introduced a new PrefixCacheStats dataclass to track cache statistics such as lookups, hits, misses, and evictions.
- Updated the LRUTrieCache initialization to allow for a configurable maximum number of entries.
- Refactored the insert and lookup methods to utilize the new radix tree structure and maintain cache statistics.

These changes optimize the cache's efficiency and provide better insights into its usage.
1 parent 6c3be3f commit a437c9d

2 files changed

Lines changed: 177 additions & 36 deletions

File tree

src/art/tinker/prefix_cache.py

Lines changed: 176 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,64 +2,207 @@
22

33
from collections import OrderedDict
44
from dataclasses import dataclass
5-
import struct
65
from typing import Sequence
76

8-
import datrie
9-
10-
_TRIE_ALPHABET = "0123456789abcdef"
11-
127

138
@dataclass(frozen=True)
class PrefixEntry:
    """Immutable cache value: a rewrite record for one rendered-token prefix.

    Maps a rendered prefix of length ``rendered_len`` back to the raw
    token ids it was derived from.
    """

    # Number of rendered tokens this entry covers (length of the cached prefix).
    rendered_len: int
    # Raw token ids corresponding to the rendered prefix.
    raw_prefix: tuple[int, ...]
1712

1813

14+
@dataclass
class PrefixCacheStats:
    """Mutable counters describing cache activity since construction."""

    # Configured capacity of the owning cache (entries, not nodes).
    max_entries: int
    lookups: int = 0  # total lookup() calls
    hits: int = 0  # lookups that returned an entry
    misses: int = 0  # lookups that returned None
    inserts: int = 0  # total insert() calls
    replaced_entries: int = 0  # inserts that overwrote an existing entry
    evictions: int = 0  # entries removed to enforce the LRU bound
    splits: int = 0  # radix edges split during insert
    pruned_nodes: int = 0  # empty nodes deleted while collapsing evicted branches
    merged_nodes: int = 0  # single-child nodes merged into their parent edge
    lru_repairs: int = 0  # nodes found missing from the LRU and re-added on lookup
27+
28+
29+
class _RadixEdge:
30+
__slots__ = ("label", "child")
31+
32+
def __init__(self, label: tuple[int, ...], child: _RadixNode) -> None:
33+
self.label = label
34+
self.child = child
35+
36+
37+
class _RadixNode:
38+
__slots__ = ("entry", "children", "parent", "parent_token")
39+
40+
def __init__(
41+
self, parent: _RadixNode | None = None, parent_token: int | None = None
42+
) -> None:
43+
self.entry: PrefixEntry | None = None
44+
self.children: dict[int, _RadixEdge] = {}
45+
self.parent = parent
46+
self.parent_token = parent_token
47+
48+
49+
def _common_prefix_len(
50+
tokens: Sequence[int], start: int, label: tuple[int, ...]
51+
) -> int:
52+
max_len = min(len(tokens) - start, len(label))
53+
i = 0
54+
while i < max_len and tokens[start + i] == label[i]:
55+
i += 1
56+
return i
57+
58+
1959
class LRUTrieCache:
    """LRU-bounded radix trie for token sequence rewrites.

    Maps rendered token prefixes to ``PrefixEntry`` records.  ``lookup``
    returns the entry for the longest cached prefix of the query.  Capacity
    is enforced with an LRU ordering over entry-bearing nodes; evicted
    branches are pruned or merged so the node count stays proportional to
    the number of live entries.
    """

    def __init__(self, max_entries: int = 16_384) -> None:
        """Create an empty cache holding at most ``max_entries`` entries.

        Raises:
            ValueError: if ``max_entries`` is not positive.
        """
        if max_entries <= 0:
            raise ValueError("max_entries must be positive")
        self._root = _RadixNode()
        # LRU order over nodes that currently hold an entry (oldest first).
        self._lru: OrderedDict[_RadixNode, None] = OrderedDict()
        self._max_entries = max_entries
        self.stats = PrefixCacheStats(max_entries=max_entries)

    def lookup(self, rendered_tokens: Sequence[int]) -> PrefixEntry | None:
        """Return the entry for the longest cached prefix of ``rendered_tokens``.

        Returns None when no cached prefix matches.  A hit refreshes the
        matched node's LRU position.
        """
        self.stats.lookups += 1
        node = self._root
        idx = 0
        # Seed with the root so an empty-prefix entry (inserted via
        # ``insert([], ...)``) matches every query, mirroring the previous
        # datrie behaviour where "" is a prefix of every key.  Previously the
        # root's entry was unreachable because best_node was only assigned
        # after descending an edge.
        best_node = self._root if self._root.entry is not None else None
        while idx < len(rendered_tokens):
            edge = node.children.get(rendered_tokens[idx])
            if edge is None:
                break
            matched = _common_prefix_len(rendered_tokens, idx, edge.label)
            if matched != len(edge.label):
                # Query ends mid-edge: no cached node terminates there.
                break
            idx += matched
            node = edge.child
            if node.entry is not None:
                best_node = node
        if best_node is None:
            self.stats.misses += 1
            return None
        self.stats.hits += 1
        try:
            self._lru.move_to_end(best_node)
        except KeyError:
            # Node fell out of the LRU bookkeeping; re-add it defensively.
            self.stats.lru_repairs += 1
            self._lru[best_node] = None
            self._lru.move_to_end(best_node)
        self._evict()
        return best_node.entry

    def insert(self, rendered_prefix: Sequence[int], raw_prefix: Sequence[int]) -> None:
        """Cache ``raw_prefix`` as the rewrite for ``rendered_prefix``.

        Walks the radix trie — extending or splitting edges as needed — to
        the node for ``rendered_prefix``, overwrites any existing entry
        there, then enforces the LRU bound.
        """
        self.stats.inserts += 1
        node = self._root
        idx = 0
        while idx < len(rendered_prefix):
            token = rendered_prefix[idx]
            edge = node.children.get(token)
            if edge is None:
                # No edge shares the next token: hang the remainder as one edge.
                child = _RadixNode(parent=node, parent_token=token)
                node.children[token] = _RadixEdge(tuple(rendered_prefix[idx:]), child)
                node = child
                idx = len(rendered_prefix)
                break

            matched = _common_prefix_len(rendered_prefix, idx, edge.label)
            if matched == len(edge.label):
                # Whole edge matches: descend and keep walking.
                idx += matched
                node = edge.child
                continue

            # Partial match: split the edge at the divergence point.
            mid = _RadixNode(parent=node, parent_token=token)
            self.stats.splits += 1
            old_suffix = edge.label[matched:]
            old_child = edge.child
            old_child.parent = mid
            old_child.parent_token = old_suffix[0]
            mid.children[old_suffix[0]] = _RadixEdge(old_suffix, old_child)
            edge.label = edge.label[:matched]
            edge.child = mid
            node = mid
            idx += matched
            if idx < len(rendered_prefix):
                # Attach the unmatched remainder of the new prefix to the split node.
                new_token = rendered_prefix[idx]
                child = _RadixNode(parent=node, parent_token=new_token)
                node.children[new_token] = _RadixEdge(
                    tuple(rendered_prefix[idx:]), child
                )
                node = child
            break

        if node.entry is not None:
            self.stats.replaced_entries += 1
        node.entry = PrefixEntry(
            rendered_len=len(rendered_prefix), raw_prefix=tuple(raw_prefix)
        )
        self._lru[node] = None
        self._lru.move_to_end(node)
        self._evict()

    def _evict(self) -> None:
        """Drop least-recently-used entries until the LRU bound holds."""
        while len(self._lru) > self._max_entries:
            old_node, _ = self._lru.popitem(last=False)
            self.stats.evictions += 1
            old_node.entry = None
            self._prune(old_node)

    def _prune(self, node: _RadixNode) -> None:
        # Collapse empty branches after eviction so the bounded cache stays bounded.
        while node.parent is not None:
            parent = node.parent
            parent_token = node.parent_token
            assert parent_token is not None

            if node.entry is None and not node.children:
                # Leaf with no entry: delete it and keep pruning upward.
                del parent.children[parent_token]
                self.stats.pruned_nodes += 1
                node = parent
                continue

            if node.entry is None and len(node.children) == 1:
                # Pass-through node: merge its single outgoing edge into the
                # incoming edge so path lengths stay compressed.
                _, child_edge = next(iter(node.children.items()))
                parent_edge = parent.children[parent_token]
                parent_edge.label = parent_edge.label + child_edge.label
                parent_edge.child = child_edge.child
                child_edge.child.parent = parent
                child_edge.child.parent_token = parent_token
                self.stats.merged_nodes += 1
                node = parent
                continue

            break

    def snapshot_stats(self) -> dict[str, int | float]:
        """Return a point-in-time view of cache statistics for reporting."""
        hit_rate = self.stats.hits / self.stats.lookups if self.stats.lookups else 0.0
        return {
            "enabled": True,
            "max_entries": self.stats.max_entries,
            "current_entries": len(self._lru),
            "node_count": self._node_count(),
            "lookups": self.stats.lookups,
            "hits": self.stats.hits,
            "misses": self.stats.misses,
            "hit_rate": hit_rate,
            "inserts": self.stats.inserts,
            "replaced_entries": self.stats.replaced_entries,
            "evictions": self.stats.evictions,
            "splits": self.stats.splits,
            "pruned_nodes": self.stats.pruned_nodes,
            "merged_nodes": self.stats.merged_nodes,
            "lru_repairs": self.stats.lru_repairs,
        }

    def _node_count(self) -> int:
        """Count all trie nodes (including the root) via an iterative DFS."""
        count = 0
        stack = [self._root]
        while stack:
            node = stack.pop()
            count += 1
            stack.extend(edge.child for edge in node.children.values())
        return count

src/art/tinker/server.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ class OpenAICompatibleTinkerServer:
4040
port: int | None = None
4141
num_workers: int | None = None
4242
models: dict[str, str] = field(default_factory=dict)
43-
_prefix_cache: LRUTrieCache = field(
44-
default_factory=lambda: LRUTrieCache(max_entries=1000)
45-
)
43+
_prefix_cache: LRUTrieCache = field(default_factory=LRUTrieCache)
4644
_workers: list["Worker"] = field(default_factory=list)
4745
_task: asyncio.Task[None] | None = None
4846
_tenant_clients: dict[str, tuple[tinker.ServiceClient, TinkerRestClient]] = field(

0 commit comments

Comments
 (0)