Skip to content

Commit fd939af

Browse files
committed
use validation instead
1 parent 4fa368e commit fd939af

3 files changed

Lines changed: 42 additions & 15 deletions

File tree

src/art/preprocessing/tokenize.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,14 @@ def tokenize_trajectory(
139139
# Find the index of the last assistant message
140140
last_assistant_index = -1
141141
for i, message in enumerate(history.messages_and_choices):
142-
if (
143-
isinstance(message, dict)
144-
and message["role"] == "assistant"
145-
and allow_training_without_logprobs
146-
):
147-
last_assistant_index = i
148-
elif not isinstance(message, dict) and (
149-
message.logprobs or allow_training_without_logprobs
150-
):
151-
last_assistant_index = i
142+
if isinstance(message, dict):
143+
# Message dict
144+
if message["role"] == "assistant" and allow_training_without_logprobs:
145+
last_assistant_index = i
146+
else:
147+
# Choice object
148+
if message.logprobs is not None or allow_training_without_logprobs:
149+
last_assistant_index = i
152150
# If there are no trainable assistant messages, return None
153151
if last_assistant_index == -1:
154152
return None
@@ -189,7 +187,7 @@ def tokenize_trajectory(
189187
(
190188
message_or_choice
191189
if isinstance(message_or_choice, dict)
192-
and not message_or_choice["role"] == "assistant"
190+
and message_or_choice["role"] != "assistant"
193191
else {
194192
"role": "assistant",
195193
"content": sentinal_token,
@@ -205,7 +203,7 @@ def tokenize_trajectory(
205203
assistant_mask: list[int] = [0] * len(token_ids)
206204
logprobs = [float("nan")] * len(token_ids)
207205
for message in messages_and_choices:
208-
if isinstance(message, dict) and not message["role"] == "assistant":
206+
if isinstance(message, dict) and message["role"] != "assistant":
209207
continue
210208
start = token_ids.index(sentinal_token_id)
211209
end = start + 1
@@ -214,6 +212,7 @@ def tokenize_trajectory(
214212
except IndexError:
215213
end_token_id = None
216214
if isinstance(message, dict):
215+
# Message dict
217216
content = message.get("content")
218217
assert isinstance(content, str)
219218
content_token_ids = tokenizer.encode(
@@ -224,6 +223,7 @@ def tokenize_trajectory(
224223
logprobs[start:end] = [float("nan")] * len(content_token_ids)
225224
assistant_mask[start:end] = [1] * len(content_token_ids)
226225
else:
226+
# Choice object
227227
choice = message
228228
assert choice.logprobs or allow_training_without_logprobs, (
229229
"Chat completion choices must have logprobs"

src/art/trajectories.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,19 @@ class History(pydantic.BaseModel):
3131
messages_and_choices: MessagesAndChoices
3232
tools: Tools | None = None
3333

34+
@pydantic.field_validator("messages_and_choices", mode="before")
35+
@classmethod
36+
def deserialize_choices(cls, v: list[Any]) -> list[Any]:
37+
"""Convert serialized Choice dicts back to Choice objects."""
38+
result = []
39+
for item in v:
40+
if isinstance(item, dict) and "message" in item and "index" in item:
41+
# This is a serialized Choice dict - convert back to Choice object
42+
result.append(Choice.model_validate(item))
43+
else:
44+
result.append(item)
45+
return result
46+
3447
def messages(self) -> Messages:
3548
return get_messages(self.messages_and_choices)
3649

@@ -46,6 +59,19 @@ class Trajectory(pydantic.BaseModel):
4659
logs: list[str] = []
4760
start_time: datetime = pydantic.Field(default_factory=datetime.now, exclude=True)
4861

62+
@pydantic.field_validator("messages_and_choices", mode="before")
63+
@classmethod
64+
def deserialize_choices(cls, v: list[Any]) -> list[Any]:
65+
"""Convert serialized Choice dicts back to Choice objects."""
66+
result = []
67+
for item in v:
68+
if isinstance(item, dict) and "message" in item and "index" in item:
69+
# This is a serialized Choice dict - convert back to Choice object
70+
result.append(Choice.model_validate(item))
71+
else:
72+
result.append(item)
73+
return result
74+
4975
def __init__(self, **data: Any):
5076
super().__init__(**data)
5177
self.start_time = datetime.now()
@@ -97,6 +123,7 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
97123
messages: Messages = []
98124
for message_or_choice in messages_and_choices:
99125
if isinstance(message_or_choice, Choice):
126+
# Choice object (always a Choice after Pydantic validation)
100127
content = message_or_choice.message.content or ""
101128
tool_calls = message_or_choice.message.tool_calls or []
102129
messages.append(
@@ -116,7 +143,7 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
116143
}
117144
)
118145
else:
119-
# Ensure content is always a string for tokenizer chat templates
146+
# Regular Message dict
120147
msg = dict(message_or_choice)
121148
if msg.get("content") is None:
122149
msg["content"] = ""

src/art/unsloth/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ def compute_loss(
7272
if inputs.get("pixel_values") and inputs["pixel_values"][0] is not None:
7373
inputs["pixel_values"] = inputs["pixel_values"][0] # type: ignore
7474
else:
75-
del inputs["pixel_values"] # type: ignore
75+
inputs.pop("pixel_values", None)
7676
if inputs.get("image_grid_thw") and inputs["image_grid_thw"][0] is not None:
7777
inputs["image_grid_thw"] = inputs["image_grid_thw"][0] # type: ignore
7878
else:
79-
del inputs["image_grid_thw"] # type: ignore
79+
inputs.pop("image_grid_thw", None)
8080

8181
# Move tensors to the correct device
8282
inputs = {

0 commit comments

Comments (0)