1 | | -from dataclasses import dataclass |
2 | | -from itertools import takewhile |
| 1 | +# ruff: noqa: I001 |
| 2 | +# Import order is intentional - unsloth MUST be imported before transformers |
3 | 3 | import math |
4 | 4 | import random |
| 5 | +from dataclasses import dataclass |
| 6 | +from itertools import takewhile |
5 | 7 | from typing import Any, Generator, cast |
6 | 8 |
7 | | -from PIL import Image |
| 9 | +import unsloth # noqa: F401 # Must import first to set UNSLOTH_IS_PRESENT env var |
| 10 | + |
8 | 11 | import torch |
| 12 | +from PIL import Image |
9 | 13 | from transformers.image_processing_utils import BaseImageProcessor |
10 | 14 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase |
11 | 15 |
12 | | -from ..trajectories import History, TrajectoryGroup, get_messages |
| 16 | +from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages |
13 | 17 |
14 | 18 |
15 | 19 | @dataclass |
@@ -44,6 +48,23 @@ def without_prompt(self) -> "TokenizedResult": |
44 | 48 | ) |
45 | 49 |
46 | 50 |
| 51 | +@dataclass |
| 52 | +class SFTBatch: |
| 53 | + """A batch of tokenized trajectories for supervised fine-tuning. |
| | +
| 54 | +    Attributes:
| 55 | + trajectory_tensors: List of tensor dictionaries, one per trajectory. |
| 56 | + Each dict contains 'input_ids', 'attention_mask', and 'labels'. |
| 57 | + learning_rate: Learning rate to use for this batch. |
| 58 | + num_trajectories: Number of trajectories in this batch. |
| 59 | + num_trainable_tokens: Total number of tokens being trained on (labels != -100). |
| 60 | + """ |
| 61 | + |
| 62 | + trajectory_tensors: list[dict[str, torch.Tensor]] |
| 63 | + learning_rate: float |
| 64 | + num_trajectories: int |
| 65 | + num_trainable_tokens: int |
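| | +
| | +    # Shape note (hypothetical sizes): with max_seq_length=512, each dict in
| | +    # trajectory_tensors holds 'input_ids', 'attention_mask', and 'labels'
| | +    # tensors of shape [1, 512]; equal per-batch padding lets a trainer
| | +    # torch.cat them along dim 0.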
| 66 | + |
| 67 | + |
47 | 68 | def tokenize_trajectory_groups( |
48 | 69 | tokenizer: "PreTrainedTokenizerBase", |
49 | 70 | trajectory_groups: list[TrajectoryGroup], |
@@ -312,3 +333,133 @@ def tokenize_trajectory( |
312 | 333 | pixel_values=pixel_values, |
313 | 334 | image_grid_thw=image_grid_thw, |
314 | 335 | ) |
| 336 | + |
| 337 | + |
| 338 | +def tokenize_sft_batches( |
| 339 | + trajectories: list[Trajectory], |
| 340 | + batch_size: int, |
| 341 | + learning_rates: list[float], |
| 342 | + tokenizer: PreTrainedTokenizerBase, |
| 343 | + instruction_part: str, |
| 344 | + response_part: str, |
| 345 | +) -> Generator[SFTBatch, None, None]: |
| 346 | + """ |
| 347 | + Tokenize trajectories into batches for supervised fine-tuning. |
| | +
| 348 | +    Args:
| 349 | +        trajectories: Flat list of trajectories
| 350 | +        batch_size: Number of trajectories per batch
| 351 | +        learning_rates: Learning rate for each batch; must hold one value per
| | +            batch, i.e. ceil(len(trajectories) / batch_size) entries
| 352 | +        tokenizer: Tokenizer to use for encoding
| 353 | +        instruction_part: Template text that opens a user turn (e.g., "User:")
| 354 | +        response_part: Template text that opens an assistant turn (e.g., "Assistant:")
| | +
| 355 | +    Yields:
| 356 | + SFTBatch object containing: |
| 357 | + - trajectory_tensors: List of tensors for each trajectory |
| 358 | + - learning_rate: Learning rate for this batch |
| 359 | + - num_trajectories: Number of trajectories in this batch |
| 360 | + - num_trainable_tokens: Total number of trainable tokens |
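| | +
| | +    Example (hypothetical values):
| | +        >>> batches = tokenize_sft_batches(
| | +        ...     trajectories=trajectories,          # e.g. 10 trajectories
| | +        ...     batch_size=4,                       # -> ceil(10 / 4) = 3 batches
| | +        ...     learning_rates=[2e-5, 2e-5, 1e-5],  # one value per batch
| | +        ...     tokenizer=tokenizer,
| | +        ...     instruction_part="User:",
| | +        ...     response_part="Assistant:",
| | +        ... )
| | +        >>> for batch in batches:
| | +        ...     pass  # feed batch.trajectory_tensors to the training step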
| 361 | + """ |
| 362 | + # Import Unsloth Zoo utility for training on responses only |
| 363 | + # Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py |
| 364 | + # This function handles edge cases with tokenization (newlines, spaces, etc.) |
| 365 | + from unsloth_zoo.dataset_utils import train_on_responses_only |
| 366 | + |
| 367 | + # Validate inputs |
| 368 | + num_trajectories = len(trajectories) |
| 369 | + num_learning_rates = len(learning_rates) |
| 370 | + expected_num_batches = math.ceil(num_trajectories / batch_size) |
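| | +    # e.g. (hypothetical) 10 trajectories with batch_size=4 -> ceil(10 / 4) = 3
| | +    # batches, so learning_rates must hold exactly 3 values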
| 371 | + |
| 372 | + if num_learning_rates != expected_num_batches: |
| 373 | + raise ValueError( |
| 374 | + f"Mismatch between trajectories and learning_rates: " |
| 375 | + f"{num_trajectories} trajectories with batch_size={batch_size} " |
| 376 | + f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates" |
| 377 | + ) |
| 378 | + |
| 379 | + # Handle missing pad_token_id (common for LLaMA and similar models) |
| 380 | + pad_token_id = tokenizer.pad_token_id |
| 381 | + if pad_token_id is None: |
| 382 | + pad_token_id = tokenizer.eos_token_id |
| 383 | + |
| 384 | + _train_on_responses_only = train_on_responses_only( |
| 385 | + trainer=None, |
| 386 | + instruction_part=instruction_part, |
| 387 | + response_part=response_part, |
| 388 | + force_match=False, |
| 389 | + tokenizer=tokenizer, |
| 390 | + return_function=True, |
| 391 | + ) |
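| | +    # The returned callable maps {"input_ids": [...]} to {"labels": [...]},
| | +    # keeping token ids inside assistant responses and masking every other
| | +    # position with -100 so the loss ignores it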
| 392 | + |
| 393 | +    # TODO: process input_ids for whole batches at once for better efficiency
| 394 | + for batch_idx, lr in enumerate(learning_rates): |
| 395 | + start_idx = batch_idx * batch_size |
| 396 | + end_idx = start_idx + batch_size |
| 397 | + trajectory_batch = trajectories[start_idx:end_idx] |
| 398 | + |
| 399 | + # First pass: tokenize all trajectories |
| 400 | + tokenized_trajectories = [] |
| 401 | + for trajectory in trajectory_batch: |
| 402 | + messages = trajectory.messages_and_choices |
| 403 | + tools = trajectory.tools |
| 404 | + |
| 405 | + # Single-step tokenization: apply_chat_template with tokenize=True |
| 406 | + input_ids = cast( |
| 407 | + list[int], |
| 408 | + tokenizer.apply_chat_template( |
| 409 | + cast(Any, messages), |
| 410 | + tools=cast(Any, tools), |
| 411 | + tokenize=True, |
| 412 | + add_generation_prompt=False, |
| 413 | + ), |
| 414 | + ) |
| 415 | + |
| 416 | +            # Create attention mask (all 1s; no padding yet)
| 417 | + attention_mask = [1] * len(input_ids) |
| 418 | + |
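| | +            # Wrap input_ids as a one-element batch and unwrap the matching
| | +            # labels row; non-response positions come back masked as -100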
| 419 | + labels = _train_on_responses_only({"input_ids": [input_ids]})["labels"][0] |
| 420 | + |
| 421 | + tokenized_trajectories.append( |
| 422 | + { |
| 423 | + "input_ids": input_ids, |
| 424 | + "attention_mask": attention_mask, |
| 425 | + "labels": labels, |
| 426 | + } |
| 427 | + ) |
| 428 | + |
| 429 | + # Find max length in this batch for padding |
| 430 | + max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories) |
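| | +        # Padding to the per-batch max rather than a fixed global length keeps
| | +        # padding overhead bounded by the longest sequence in each batch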
| 431 | + |
| 432 | + # Second pass: pad all trajectories to max_seq_length |
| 433 | + trajectory_tensors = [] |
| 434 | + for tokenized in tokenized_trajectories: |
| 435 | + input_ids = tokenized["input_ids"] |
| 436 | + attention_mask = tokenized["attention_mask"] |
| 437 | + labels = tokenized["labels"] |
| 438 | + |
| 439 | + # Pad to max_seq_length |
| 440 | + padding_length = max_seq_length - len(input_ids) |
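| | +            # e.g. (hypothetical) a 300-token sequence in a batch whose longest
| | +            # sequence has 512 tokens receives 212 pad ids, 0s, and -100s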
| 441 | + if padding_length > 0: |
| 442 | + input_ids = input_ids + [pad_token_id] * padding_length |
| 443 | + attention_mask = attention_mask + [0] * padding_length |
| 444 | + labels = labels + [-100] * padding_length |
| 445 | + |
| 446 | + trajectory_tensor = { |
| 447 | + "input_ids": torch.tensor([input_ids], dtype=torch.long), |
| 448 | + "attention_mask": torch.tensor([attention_mask], dtype=torch.long), |
| 449 | + "labels": torch.tensor([labels], dtype=torch.long), |
| 450 | + } |
| 451 | + |
| 452 | + trajectory_tensors.append(trajectory_tensor) |
| 453 | + |
| 454 | + # Calculate total trainable tokens (labels != -100) |
| 455 | + num_trainable_tokens = sum( |
| 456 | + (tensor_dict["labels"] != -100).sum().item() |
| 457 | + for tensor_dict in trajectory_tensors |
| 458 | + ) |
| 459 | + |
| 460 | + yield SFTBatch( |
| 461 | + trajectory_tensors=trajectory_tensors, |
| 462 | + learning_rate=lr, |
| 463 | + num_trajectories=len(trajectory_tensors), |
| 464 | + num_trainable_tokens=num_trainable_tokens, |
| 465 | + ) |