
Commit 5207db0

Kovbo and claude authored
Remove batch-level padding from tokenize_sft_batch (#582)
* Remove batch-level padding from tokenize_sft_batch

  tokenize_sft_batch was padding all trajectories to the longest sequence in
  the batch, but every consumer (unsloth, megatron) processes them
  individually. This wasted CPU memory and GPU compute on padding tokens. Now
  each trajectory tensor keeps its natural length. The unsloth training loop
  strips any residual padding before .to(device) for robustness.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Use microbatch pattern in unsloth training loop

  Match the serverless-training microbatch approach: process trajectories in
  configurable microbatch groups with padding trimmed to the longest in each
  group. Changing microbatch_size is a one-line change.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Fix ruff formatting

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Simplify training loop back to plain per-trajectory iteration

  Padding is now removed at the source in tokenize_sft_batch, so the training
  loop doesn't need microbatch trimming logic.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f79fa5e commit 5207db0
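The commit message notes that the unsloth training loop strips any residual padding before .to(device). A minimal sketch of that defensive step, assuming the tensor-dict layout produced by tokenize_sft_batch below (tensors of shape [1, seq_len] with padding marked by zeros in attention_mask); strip_padding is a hypothetical helper name for illustration, not code from this repository:

import torch

def strip_padding(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # attention_mask holds 1 for real tokens and 0 for padding, so its sum is
    # the true (unpadded) sequence length.
    seq_len = int(tensors["attention_mask"].sum().item())
    # input_ids, attention_mask, and labels are all shaped [1, padded_len];
    # slice each down to the true length before moving it to the device.
    return {key: value[:, :seq_len] for key, value in tensors.items()}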

1 file changed: src/art/preprocessing/tokenize.py (8 additions & 43 deletions)
@@ -387,13 +387,9 @@ def tokenize_sft_batch(
         tokenizer=tokenizer,
         return_function=True,
     )
-    # Handle missing pad_token_id (common for LLaMA and similar models)
-    pad_token_id = tokenizer.pad_token_id
-    if pad_token_id is None:
-        pad_token_id = tokenizer.eos_token_id
-
-    # First pass: tokenize all trajectories
-    tokenized_trajectories = []
+    # Tokenize all trajectories (no padding — each keeps its natural length)
+    trajectory_tensors = []
+    num_trainable_tokens = 0
     for trajectory in trajectory_batch:
         messages = trajectory.messages_and_choices
         tools = trajectory.tools
@@ -409,49 +405,18 @@
             ),
         )

-        # Create attention mask (all 1s - no padding yet)
         attention_mask = [1] * len(input_ids)

         labels = train_on_responses_only_fn({"input_ids": [input_ids]})["labels"][0]

-        tokenized_trajectories.append(
+        trajectory_tensors.append(
             {
-                "input_ids": input_ids,
-                "attention_mask": attention_mask,
-                "labels": labels,
+                "input_ids": torch.tensor([input_ids], dtype=torch.long),
+                "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
+                "labels": torch.tensor([labels], dtype=torch.long),
             }
         )
-
-    # Find max length in this batch for padding
-    max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories)
-
-    # Second pass: pad all trajectories to max_seq_length
-    trajectory_tensors = []
-    for tokenized in tokenized_trajectories:
-        input_ids = tokenized["input_ids"]
-        attention_mask = tokenized["attention_mask"]
-        labels = tokenized["labels"]
-
-        # Pad to max_seq_length
-        padding_length = max_seq_length - len(input_ids)
-        if padding_length > 0:
-            input_ids = input_ids + [pad_token_id] * padding_length
-            attention_mask = attention_mask + [0] * padding_length
-            labels = labels + [-100] * padding_length
-
-        trajectory_tensor = {
-            "input_ids": torch.tensor([input_ids], dtype=torch.long),
-            "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
-            "labels": torch.tensor([labels], dtype=torch.long),
-        }
-
-        trajectory_tensors.append(trajectory_tensor)
-
-    # Calculate total trainable tokens (labels != -100)
-    num_trainable_tokens = sum(
-        (tensor_dict["labels"] != -100).sum().item()
-        for tensor_dict in trajectory_tensors
-    )
+        num_trainable_tokens += sum(1 for l in labels if l != -100)

     return SFTBatch(
         trajectory_tensors=trajectory_tensors,
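Since each trajectory tensor now keeps its natural length, consumers can iterate over the returned batch one trajectory at a time with no trimming step. A rough sketch of that plain per-trajectory iteration, assuming a Hugging Face-style model whose forward call accepts input_ids, attention_mask, and labels and returns an object with a .loss attribute; model, optimizer, and device are illustrative placeholders, not code from this commit:

def train_on_sft_batch(model, optimizer, batch, device="cuda"):
    for tensors in batch.trajectory_tensors:
        # Each dict already has its natural length, so it can be moved to the
        # device and fed to the model directly.
        tensors = {key: value.to(device) for key, value in tensors.items()}
        loss = model(**tensors).loss
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()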
