Skip to content

Commit b220140

Browse files
corbt and Cursor Bot authored
fix: preserve tool-call context in tokenization (#527)
fix: preserve tool-call context in tokenization (#527)

* fix: preserve tool-call context in tokenization

  Only splice trainable assistant spans and keep tool_calls in the template; error if tool_calls would be dropped.

* Fix type errors in tool-call tokenization

  - Use .get() instead of direct [] access for tool_calls to handle message types that don't have this key
  - Cast message to dict[str, Any] when appending to token_template_messages

---------

Co-authored-by: Cursor Bot <bot@cursor.com>
1 parent 79d27f2 commit b220140

1 file changed

Lines changed: 35 additions & 16 deletions

File tree

src/art/preprocessing/tokenize.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -201,32 +201,46 @@ def tokenize_trajectory(
201201
set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids)
202202
)
203203
sentinal_token = tokenizer.decode(sentinal_token_id)
204+
token_template_messages: list[dict[str, Any]] = []
205+
for original, message in zip(messages_and_choices, messages):
206+
trainable_assistant = (
207+
not isinstance(original, dict) and original.logprobs is not None
208+
) or (
209+
allow_training_without_logprobs
210+
and isinstance(original, dict)
211+
and original.get("role") == "assistant"
212+
)
213+
if trainable_assistant:
214+
token_template_messages.append(
215+
{
216+
"role": "assistant",
217+
"content": sentinal_token,
218+
**(
219+
{"tool_calls": message.get("tool_calls")} # type: ignore[call-overload]
220+
if message.get("tool_calls") # type: ignore[call-overload]
221+
else {}
222+
),
223+
}
224+
)
225+
else:
226+
token_template_messages.append(cast(dict[str, Any], message))
204227
token_ids = cast(
205228
list[int],
206229
tokenizer.apply_chat_template(
207-
cast(
208-
list[dict],
209-
[
210-
(
211-
message_or_choice
212-
if isinstance(message_or_choice, dict)
213-
and not message_or_choice["role"] == "assistant"
214-
else {
215-
"role": "assistant",
216-
"content": sentinal_token,
217-
}
218-
)
219-
for message_or_choice in messages_and_choices
220-
],
221-
),
230+
cast(list[dict], token_template_messages),
222231
tools=tools, # type: ignore
223232
continue_final_message=True,
224233
),
225234
)
226235
assistant_mask: list[int] = [0] * len(token_ids)
227236
logprobs = [float("nan")] * len(token_ids)
228237
for message in messages_and_choices:
229-
if isinstance(message, dict) and not message["role"] == "assistant":
238+
if isinstance(message, dict):
239+
if message["role"] != "assistant":
240+
continue
241+
if not allow_training_without_logprobs:
242+
continue
243+
elif message.logprobs is None and not allow_training_without_logprobs:
230244
continue
231245
start = token_ids.index(sentinal_token_id)
232246
end = start + 1
@@ -235,6 +249,11 @@ def tokenize_trajectory(
235249
except IndexError:
236250
end_token_id = None
237251
if isinstance(message, dict):
252+
if message.get("tool_calls"):
253+
raise ValueError(
254+
"Assistant message has tool_calls but is being tokenized "
255+
"via tokenizer.encode(content). This path ignores tool calls."
256+
)
238257
content = message.get("content")
239258
assert isinstance(content, str)
240259
content_token_ids = tokenizer.encode(

0 commit comments

Comments (0)