
Commit 121e6ea

SFT Preprocessing (#525)
* SFT preprocessing
* Use unsloth-zoo _train_on_response_only
* Fix ruff check
* reorder import
* Add comment for unsloth import
* Import unsloth-zoo inside of tokenize_sft_batches
1 parent 1775016 commit 121e6ea

1 file changed: src/art/preprocessing/tokenize.py (155 additions, 4 deletions)

@@ -1,15 +1,19 @@
-from dataclasses import dataclass
-from itertools import takewhile
+# ruff: noqa: I001
+# Import order is intentional - unsloth MUST be imported before transformers
 import math
 import random
+from dataclasses import dataclass
+from itertools import takewhile
 from typing import Any, Generator, cast
 
-from PIL import Image
+import unsloth  # noqa: F401 # Must import first to set UNSLOTH_IS_PRESENT env var
+
 import torch
+from PIL import Image
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from ..trajectories import History, TrajectoryGroup, get_messages
+from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages
 
 
 @dataclass
@@ -44,6 +48,23 @@ def without_prompt(self) -> "TokenizedResult":
     )
 
 
+@dataclass
+class SFTBatch:
+    """A batch of tokenized trajectories for supervised fine-tuning.
+    Attributes:
+        trajectory_tensors: List of tensor dictionaries, one per trajectory.
+            Each dict contains 'input_ids', 'attention_mask', and 'labels'.
+        learning_rate: Learning rate to use for this batch.
+        num_trajectories: Number of trajectories in this batch.
+        num_trainable_tokens: Total number of tokens being trained on (labels != -100).
+    """
+
+    trajectory_tensors: list[dict[str, torch.Tensor]]
+    learning_rate: float
+    num_trajectories: int
+    num_trainable_tokens: int
+
+
 def tokenize_trajectory_groups(
     tokenizer: "PreTrainedTokenizerBase",
     trajectory_groups: list[TrajectoryGroup],
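
A minimal sketch, not part of this commit, of how a consumer might apply one SFTBatch in a training step: the model and optimizer below are hypothetical stand-ins, but each per-trajectory tensor dict is already shaped as the keyword arguments of a Hugging Face causal-LM forward call, and label positions set to -100 are excluded from the cross-entropy loss.

def train_on_sft_batch(model, optimizer, batch: SFTBatch) -> float:
    # Sketch only: `model` is assumed to be a Hugging Face causal LM that
    # returns an object with `.loss` when `labels` are passed; `optimizer`
    # is any torch optimizer over its parameters. Neither is defined here.
    for group in optimizer.param_groups:
        group["lr"] = batch.learning_rate  # per-batch learning rate from SFTBatch
    optimizer.zero_grad()
    total_loss = 0.0
    for tensors in batch.trajectory_tensors:
        outputs = model(**tensors)  # input_ids, attention_mask, labels
        loss = outputs.loss / batch.num_trajectories  # average over trajectories
        loss.backward()  # gradients accumulate across the whole batch
        total_loss += loss.item()
    optimizer.step()
    return total_loss
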
@@ -312,3 +333,133 @@ def tokenize_trajectory(
         pixel_values=pixel_values,
         image_grid_thw=image_grid_thw,
     )
+
+
+def tokenize_sft_batches(
+    trajectories: list[Trajectory],
+    batch_size: int,
+    learning_rates: list[float],
+    tokenizer: PreTrainedTokenizerBase,
+    instruction_part: str,
+    response_part: str,
+) -> Generator[SFTBatch, None, None]:
+    """
+    Tokenize trajectories into batches for supervised fine-tuning.
+    Args:
+        trajectories: Flat list of trajectories
+        batch_size: Number of trajectories per batch
+        learning_rates: Learning rate for each batch
+        tokenizer: Tokenizer to use for encoding
+        instruction_part: Instruction template part (e.g., "User:")
+        response_part: Response template part (e.g., "Assistant:")
+    Yields:
+        SFTBatch object containing:
+            - trajectory_tensors: List of tensors for each trajectory
+            - learning_rate: Learning rate for this batch
+            - num_trajectories: Number of trajectories in this batch
+            - num_trainable_tokens: Total number of trainable tokens
+    """
+    # Import Unsloth Zoo utility for training on responses only
+    # Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py
+    # This function handles edge cases with tokenization (newlines, spaces, etc.)
+    from unsloth_zoo.dataset_utils import train_on_responses_only
+
+    # Validate inputs
+    num_trajectories = len(trajectories)
+    num_learning_rates = len(learning_rates)
+    expected_num_batches = math.ceil(num_trajectories / batch_size)
+
+    if num_learning_rates != expected_num_batches:
+        raise ValueError(
+            f"Mismatch between trajectories and learning_rates: "
+            f"{num_trajectories} trajectories with batch_size={batch_size} "
+            f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates"
+        )
+
+    # Handle missing pad_token_id (common for LLaMA and similar models)
+    pad_token_id = tokenizer.pad_token_id
+    if pad_token_id is None:
+        pad_token_id = tokenizer.eos_token_id
+
+    _train_on_responses_only = train_on_responses_only(
+        trainer=None,
+        instruction_part=instruction_part,
+        response_part=response_part,
+        force_match=False,
+        tokenizer=tokenizer,
+        return_function=True,
+    )
+
+    # TODO Process input_ids in batch for better efficiency
+    for batch_idx, lr in enumerate(learning_rates):
+        start_idx = batch_idx * batch_size
+        end_idx = start_idx + batch_size
+        trajectory_batch = trajectories[start_idx:end_idx]
+
+        # First pass: tokenize all trajectories
+        tokenized_trajectories = []
+        for trajectory in trajectory_batch:
+            messages = trajectory.messages_and_choices
+            tools = trajectory.tools
+
+            # Single-step tokenization: apply_chat_template with tokenize=True
+            input_ids = cast(
+                list[int],
+                tokenizer.apply_chat_template(
+                    cast(Any, messages),
+                    tools=cast(Any, tools),
+                    tokenize=True,
+                    add_generation_prompt=False,
+                ),
+            )
+
+            # Create attention mask (all 1s - no padding yet)
+            attention_mask = [1] * len(input_ids)
+
+            labels = _train_on_responses_only({"input_ids": [input_ids]})["labels"][0]
+
+            tokenized_trajectories.append(
+                {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "labels": labels,
+                }
+            )
+
+        # Find max length in this batch for padding
+        max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories)
+
+        # Second pass: pad all trajectories to max_seq_length
+        trajectory_tensors = []
+        for tokenized in tokenized_trajectories:
+            input_ids = tokenized["input_ids"]
+            attention_mask = tokenized["attention_mask"]
+            labels = tokenized["labels"]
+
+            # Pad to max_seq_length
+            padding_length = max_seq_length - len(input_ids)
+            if padding_length > 0:
+                input_ids = input_ids + [pad_token_id] * padding_length
+                attention_mask = attention_mask + [0] * padding_length
+                labels = labels + [-100] * padding_length
+
+            trajectory_tensor = {
+                "input_ids": torch.tensor([input_ids], dtype=torch.long),
+                "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
+                "labels": torch.tensor([labels], dtype=torch.long),
+            }
+
+            trajectory_tensors.append(trajectory_tensor)
+
+        # Calculate total trainable tokens (labels != -100)
+        num_trainable_tokens = sum(
+            (tensor_dict["labels"] != -100).sum().item()
+            for tensor_dict in trajectory_tensors
+        )
+
+        yield SFTBatch(
+            trajectory_tensors=trajectory_tensors,
+            learning_rate=lr,
+            num_trajectories=len(trajectory_tensors),
+            num_trainable_tokens=num_trainable_tokens,
+        )
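
A hypothetical usage sketch, not part of this commit. The tokenizer name, the load_trajectories helper, and the module path in the import are assumptions for illustration (the path is inferred from src/art/preprocessing/tokenize.py); the instruction_part/response_part values are the Llama 3 chat-template headers commonly passed to unsloth-zoo's train_on_responses_only.

import math

from transformers import AutoTokenizer

# Module path inferred from the file path above; adjust to the real package layout.
from art.preprocessing.tokenize import tokenize_sft_batches

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

trajectories = load_trajectories()  # hypothetical helper returning list[Trajectory]
# Exactly one learning rate per batch, matching the ceil-division check above.
learning_rates = [2e-5] * math.ceil(len(trajectories) / 4)

for batch in tokenize_sft_batches(
    trajectories=trajectories,
    batch_size=4,
    learning_rates=learning_rates,
    tokenizer=tokenizer,
    instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
    response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
):
    print(batch.num_trajectories, batch.num_trainable_tokens, batch.learning_rate)

Because prompt and padding positions are masked to -100, num_trainable_tokens counts only assistant-response tokens, which is the figure a trainer would use for token-throughput or loss-normalization accounting.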
