Skip to content

Commit 0e567fb

Browse files
committed
Revert "use validation instead"
1 parent fd939af commit 0e567fb

3 files changed

Lines changed: 15 additions & 42 deletions

File tree

src/art/preprocessing/tokenize.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -139,14 +139,16 @@ def tokenize_trajectory(
139139
# Find the index of the last assistant message
140140
last_assistant_index = -1
141141
for i, message in enumerate(history.messages_and_choices):
142-
if isinstance(message, dict):
143-
# Message dict
144-
if message["role"] == "assistant" and allow_training_without_logprobs:
145-
last_assistant_index = i
146-
else:
147-
# Choice object
148-
if message.logprobs is not None or allow_training_without_logprobs:
149-
last_assistant_index = i
142+
if (
143+
isinstance(message, dict)
144+
and message["role"] == "assistant"
145+
and allow_training_without_logprobs
146+
):
147+
last_assistant_index = i
148+
elif not isinstance(message, dict) and (
149+
message.logprobs or allow_training_without_logprobs
150+
):
151+
last_assistant_index = i
150152
# If there are no trainable assistant messages, return None
151153
if last_assistant_index == -1:
152154
return None
@@ -187,7 +189,7 @@ def tokenize_trajectory(
187189
(
188190
message_or_choice
189191
if isinstance(message_or_choice, dict)
190-
and message_or_choice["role"] != "assistant"
192+
and not message_or_choice["role"] == "assistant"
191193
else {
192194
"role": "assistant",
193195
"content": sentinal_token,
@@ -203,7 +205,7 @@ def tokenize_trajectory(
203205
assistant_mask: list[int] = [0] * len(token_ids)
204206
logprobs = [float("nan")] * len(token_ids)
205207
for message in messages_and_choices:
206-
if isinstance(message, dict) and message["role"] != "assistant":
208+
if isinstance(message, dict) and not message["role"] == "assistant":
207209
continue
208210
start = token_ids.index(sentinal_token_id)
209211
end = start + 1
@@ -212,7 +214,6 @@ def tokenize_trajectory(
212214
except IndexError:
213215
end_token_id = None
214216
if isinstance(message, dict):
215-
# Message dict
216217
content = message.get("content")
217218
assert isinstance(content, str)
218219
content_token_ids = tokenizer.encode(
@@ -223,7 +224,6 @@ def tokenize_trajectory(
223224
logprobs[start:end] = [float("nan")] * len(content_token_ids)
224225
assistant_mask[start:end] = [1] * len(content_token_ids)
225226
else:
226-
# Choice object
227227
choice = message
228228
assert choice.logprobs or allow_training_without_logprobs, (
229229
"Chat completion choices must have logprobs"

src/art/trajectories.py

Lines changed: 1 addition & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -31,19 +31,6 @@ class History(pydantic.BaseModel):
3131
messages_and_choices: MessagesAndChoices
3232
tools: Tools | None = None
3333

34-
@pydantic.field_validator("messages_and_choices", mode="before")
35-
@classmethod
36-
def deserialize_choices(cls, v: list[Any]) -> list[Any]:
37-
"""Convert serialized Choice dicts back to Choice objects."""
38-
result = []
39-
for item in v:
40-
if isinstance(item, dict) and "message" in item and "index" in item:
41-
# This is a serialized Choice dict - convert back to Choice object
42-
result.append(Choice.model_validate(item))
43-
else:
44-
result.append(item)
45-
return result
46-
4734
def messages(self) -> Messages:
4835
return get_messages(self.messages_and_choices)
4936

@@ -59,19 +46,6 @@ class Trajectory(pydantic.BaseModel):
5946
logs: list[str] = []
6047
start_time: datetime = pydantic.Field(default_factory=datetime.now, exclude=True)
6148

62-
@pydantic.field_validator("messages_and_choices", mode="before")
63-
@classmethod
64-
def deserialize_choices(cls, v: list[Any]) -> list[Any]:
65-
"""Convert serialized Choice dicts back to Choice objects."""
66-
result = []
67-
for item in v:
68-
if isinstance(item, dict) and "message" in item and "index" in item:
69-
# This is a serialized Choice dict - convert back to Choice object
70-
result.append(Choice.model_validate(item))
71-
else:
72-
result.append(item)
73-
return result
74-
7549
def __init__(self, **data: Any):
7650
super().__init__(**data)
7751
self.start_time = datetime.now()
@@ -123,7 +97,6 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
12397
messages: Messages = []
12498
for message_or_choice in messages_and_choices:
12599
if isinstance(message_or_choice, Choice):
126-
# Choice object (always a Choice after Pydantic validation)
127100
content = message_or_choice.message.content or ""
128101
tool_calls = message_or_choice.message.tool_calls or []
129102
messages.append(
@@ -143,7 +116,7 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
143116
}
144117
)
145118
else:
146-
# Regular Message dict
119+
# Ensure content is always a string for tokenizer chat templates
147120
msg = dict(message_or_choice)
148121
if msg.get("content") is None:
149122
msg["content"] = ""

src/art/unsloth/train.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -72,11 +72,11 @@ def compute_loss(
7272
if inputs.get("pixel_values") and inputs["pixel_values"][0] is not None:
7373
inputs["pixel_values"] = inputs["pixel_values"][0] # type: ignore
7474
else:
75-
inputs.pop("pixel_values", None)
75+
del inputs["pixel_values"] # type: ignore
7676
if inputs.get("image_grid_thw") and inputs["image_grid_thw"][0] is not None:
7777
inputs["image_grid_thw"] = inputs["image_grid_thw"][0] # type: ignore
7878
else:
79-
inputs.pop("image_grid_thw", None)
79+
del inputs["image_grid_thw"] # type: ignore
8080

8181
# Move tensors to the correct device
8282
inputs = {

0 commit comments

Comments
 (0)