
Commit 6559367

Kovbo and claude authored
Update transformers to v5.x, unsloth, and add MoE LoRA conversion (#576)
* feat: update transformers to v5.x, unsloth, and add MoE LoRA conversion

  Update core dependencies for the transformers v5 ecosystem:
  - transformers: >=4.55.2,<=4.57.3 → >=5.1.0
  - unsloth: 2025.12.9 → 2026.2.1
  - unsloth-zoo: 2025.12.7 → 2026.2.1 (+ updated VCS pin)
  - trl: 0.20.0 → >=0.28.0
  - peft: >=0.14.0 → >=0.18.0 (required by transformers v5)

  Fix transformers v5 breaking changes:
  - Replace the removed dummy_pt_objects import with a direct transformers import
  - Update the masking_utils patch return type (now returns 5 values)
  - Remove deprecated TrainerArgs fields (overwrite_output_dir, jit_mode_eval, mp_parameters, logging_dir, fp16_backend, push_to_hub_token/model_id/organization)

  Add a MoE LoRA adapter conversion utility for vLLM compatibility:
  - Unsloth + transformers v5 saves MoE LoRA as fused 2D tensors
  - vLLM expects the per-expert format
  - Auto-detect and convert after each checkpoint save

  Closes #575

* fix: pin trl<=0.24.0 for unsloth 2026.2.1 compatibility

  Unsloth 2026.2.1 requires trl>0.18.2,!=0.19.0,<=0.24.0.

* fix: override unsloth dep constraints for transformers v5 + trl compat

  Unsloth 2026.2.1's pyproject.toml has overly strict constraints (transformers<=4.57.6, trl<=0.24.0), but the February 2026 release notes confirm that v5.1.0 + trl 0.27.1 work well. Use uv override-dependencies to allow the upgrade.

* fix: add warnings_issued attr for transformers v5 + unsloth compat

  Transformers v5 removed `warnings_issued` from PreTrainedModel, but Unsloth's GRPOTrainer still accesses it during initialization. Add it as an empty dict on the PEFT model before creating the trainer.

* fix: add return_dict=False to apply_chat_template calls for transformers v5

  Transformers v5 changed apply_chat_template to return a BatchEncoding by default when tokenize=True. Add return_dict=False to all calls that expect a list[int] return type.

* fix: remove unsloth-zoo VCS pin, use PyPI 2026.2.1 instead

  The bradhilton/unsloth-zoo fork is at version 2025.8.4, which is missing modules needed by unsloth 2026.2.1 (e.g. unsloth_zoo.device_type). Switch to the official PyPI release, which matches unsloth 2026.2.1.

* revert: remove unnecessary changes to backend.vcs.txt and model.py

  These changes were not needed for the transformers v5 upgrade:
  - backend.vcs.txt: not used for installation (pyproject.toml handles deps)
  - model.py TrainerArgs: TypedDict fields don't cause runtime errors

* fix: remove TrainerArgs fields removed in transformers v5

  Remove fields that transformers v5 dropped from TrainingArguments: overwrite_output_dir, logging_dir, jit_mode_eval, half_precision_backend, tpu_num_cores, past_index, fp16_backend, push_to_hub_model_id, push_to_hub_organization, push_to_hub_token, mp_parameters, torchdynamo, ray_scope.

* fix: pin transformers==5.1.0 to avoid breakage from future releases

* fix: restore trl==0.20.0 pin and remove unnecessary trl override

  trl was originally pinned to 0.20.0. There is no reason to loosen it: 0.20.0 already satisfies unsloth's trl<=0.24.0 constraint.

* refactor: centralize apply_chat_template return_dict=False patch

  Instead of adding return_dict=False to every call site, patch PreTrainedTokenizerBase.apply_chat_template once in patches.py to default return_dict=False. This restores the transformers v4 behavior (returning list[int]) globally.

* fix: correct comment about warnings_issued workaround

  The attribute wasn't removed in transformers v5; rather, Unsloth's model patching can leave the PEFT model without it.

* fix: remove invalid exclude-dependencies and add apex dependency-metadata

  exclude-dependencies is not a valid [tool.uv] field in uv 0.8.x, which caused the entire settings section to be silently ignored. This meant dependency-metadata, no-build-isolation-package, and extra-build-dependencies were all skipped, forcing uv to build apex from source during resolution, which fails on non-GPU machines missing torch.

* add exclude-dependencies = ["pynvml"] back
* clean pyproject
* update uv lock
* update transformers to v5.2.0
* cleaner types
* lint fix
* add extra fix to lora conversion
* fix build
* ruff fix
* revert pyproject change

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6560dfd · commit 6559367

8 files changed

Lines changed: 260 additions & 41 deletions


pyproject.toml

Lines changed: 10 additions & 4 deletions
@@ -19,7 +19,7 @@ dependencies = [
 plotting = ["matplotlib>=3.10.1", "seaborn>=0.13.2"]
 
 backend = [
-    "peft>=0.14.0",
+    "peft>=0.18.0",
     "hf-xet>=1.1.0",
     "bitsandbytes>=0.45.2",
     "unsloth==2026.2.1",
@@ -30,7 +30,7 @@ backend = [
     "awscli>=1.38.1",
     "setuptools>=78.1.0",
     "wandb==0.25.0",
-    "transformers>=4.55.2,<=4.57.3",
+    "transformers==5.2.0",
     "duckdb>=1.0.0",
     "pyarrow>=15.0.0",
     "trl==0.20.0",
@@ -65,7 +65,7 @@ tinker = [
     "pydantic>=2.12.5",
     "tinker>=0.8.1",
     "torch>=2.8.0",
-    "transformers>=4.55.2,<=4.57.3",
+    "transformers==5.2.0",
     "uvicorn>=0.35.0",
     "datrie>=0.8.3",
 ]
@@ -122,7 +122,13 @@ required-version = ">=0.6.15"
 # Override numpy to <2.0 for compatibility with megatron-core in the training
 # environment. vLLM 0.15.1 pulls opencv-python-headless>=4.13 which wants
 # numpy>=2 on Python 3.9+, but megatron-core requires numpy<2.
-override-dependencies = ["transformer-engine>=2.11.0", "numpy<2"]
+override-dependencies = [
+    "transformer-engine>=2.11.0",
+    "numpy<2",
+    # Override unsloth's overly strict constraint on transformers — v5.x
+    # is confirmed working per unsloth February-2026 release notes
+    "transformers==5.2.0",
+]
 exclude-dependencies = ["pynvml"]
 no-build-isolation-package = ["apex", "transformer-engine", "transformer-engine-cu12", "transformer-engine-torch", "megatron-core", "megatron-bridge", "nv-grouped-gemm", "mamba-ssm", "causal-conv1d"]

src/art/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -40,9 +40,13 @@ def __init__(self, **kwargs):
     import transformers
 
     try:
-        from .transformers.patches import patch_preprocess_mask_arguments
+        from .transformers.patches import (
+            patch_apply_chat_template,
+            patch_preprocess_mask_arguments,
+        )
 
         patch_preprocess_mask_arguments()
+        patch_apply_chat_template()
     except Exception:
         pass
 except ImportError:

src/art/dev/model.py

Lines changed: 0 additions & 13 deletions
@@ -197,7 +197,6 @@ class PeftArgs(TypedDict, total=False):
 
 class TrainerArgs(TypedDict, total=False):
     output_dir: str | None
-    overwrite_output_dir: bool
     do_train: bool
     do_eval: bool
     do_predict: bool
@@ -226,7 +225,6 @@ class TrainerArgs(TypedDict, total=False):
     log_level: str
     log_level_replica: str
     log_on_each_node: bool
-    logging_dir: str | None
     logging_strategy: "IntervalStrategy | str"
     logging_first_step: bool
     logging_steps: float
@@ -243,25 +241,21 @@ class TrainerArgs(TypedDict, total=False):
     use_mps_device: bool
     seed: int
     data_seed: int | None
-    jit_mode_eval: bool
     use_ipex: bool
     bf16: bool
     fp16: bool
     fp16_opt_level: str
-    half_precision_backend: str
     bf16_full_eval: bool
     fp16_full_eval: bool
     tf32: bool | None
     local_rank: int
     ddp_backend: str | None
-    tpu_num_cores: int | None
     tpu_metrics_debug: bool
     debug: str | list[DebugOption]
     dataloader_drop_last: bool
     eval_steps: float | None
     dataloader_num_workers: int
     dataloader_prefetch_factor: int | None
-    past_index: int
     run_name: str | None
     disable_tqdm: bool | None
     remove_unused_columns: bool | None
@@ -302,15 +296,8 @@ class TrainerArgs(TypedDict, total=False):
     include_inputs_for_metrics: bool
     include_for_metrics: list[str]
     eval_do_concat_batches: bool
-    fp16_backend: str
-    push_to_hub_model_id: str | None
-    push_to_hub_organization: str | None
-    push_to_hub_token: str | None
-    mp_parameters: str
     auto_find_batch_size: bool
     full_determinism: bool
-    torchdynamo: str | None
-    ray_scope: str | None
     ddp_timeout: int
     torch_compile: bool
     torch_compile_backend: str | None
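
For illustration, a minimal sketch of the failure mode these removals guard against: under transformers v5 (per the commit message), TrainingArguments no longer defines these fields, so forwarding one of them as a keyword argument fails at construction time. The error text in the comment is illustrative, not copied from transformers.

from transformers import TrainingArguments

try:
    # overwrite_output_dir is one of the fields dropped in transformers v5.
    TrainingArguments(output_dir="out", overwrite_output_dir=True)
except TypeError as err:
    # e.g. "__init__() got an unexpected keyword argument 'overwrite_output_dir'"
    print(err)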

src/art/preprocessing/tokenize.py

Lines changed: 10 additions & 9 deletions
@@ -197,9 +197,7 @@ def tokenize_trajectory(
             continue_final_message=True,
         ),
     )
-    sentinal_token_id = max(
-        set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids)
-    )
+    sentinal_token_id = max(set(range(tokenizer.vocab_size)) - set(original_token_ids))
     sentinal_token = tokenizer.decode(sentinal_token_id)
     token_template_messages: list[dict[str, Any]] = []
     for original, message in zip(messages_and_choices, messages):
@@ -287,11 +285,14 @@ def tokenize_trajectory(
             except (IndexError, ValueError):
                 token_ids[start:end] = [
                     token_id if token_id is not None else tokenizer.eos_token_id
-                    for token_id in tokenizer.convert_tokens_to_ids(
-                        [
-                            token_logprob.token or tokenizer.eos_token
-                            for token_logprob in token_logprobs
-                        ]
+                    for token_id in cast(
+                        list[int],
+                        tokenizer.convert_tokens_to_ids(
+                            [
+                                token_logprob.token or tokenizer.eos_token
+                                for token_logprob in token_logprobs
+                            ]
+                        ),
                     )
                 ]
                 logprobs[start:end] = (
@@ -346,7 +347,7 @@ def tokenize_trajectory(
     return TokenizedResult(
         advantage=advantage,
         chat=chat,
-        tokens=[tokenizer.decode(token_id) for token_id in token_ids],
+        tokens=[cast(str, tokenizer.decode(token_id)) for token_id in token_ids],
         token_ids=token_ids,
        input_pos=list(range(len(token_ids))),
         assistant_mask=assistant_mask,

src/art/transformers/patches.py

Lines changed: 18 additions & 0 deletions
@@ -1,9 +1,11 @@
+import functools
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 from transformers import masking_utils
 from transformers.cache_utils import Cache
 from transformers.configuration_utils import PretrainedConfig
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 if TYPE_CHECKING:
     from torch.nn.attention.flex_attention import BlockMask
@@ -35,3 +37,19 @@ def _patched_preprocess_mask_arguments(
 
 def patch_preprocess_mask_arguments() -> None:
     masking_utils._preprocess_mask_arguments = _patched_preprocess_mask_arguments  # ty:ignore[invalid-assignment]
+
+
+def patch_apply_chat_template() -> None:
+    """Default return_dict=False in apply_chat_template for transformers v5.
+
+    Transformers v5 changed the default from list[int] to BatchEncoding.
+    This restores the v4 behavior so all call sites get list[int] back.
+    """
+    original = PreTrainedTokenizerBase.apply_chat_template
+
+    @functools.wraps(original)
+    def _patched(self, *args, **kwargs):  # type: ignore
+        kwargs.setdefault("return_dict", False)
+        return original(self, *args, **kwargs)
+
+    PreTrainedTokenizerBase.apply_chat_template = _patched  # type: ignore
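
A short usage sketch of the patched default, assuming a transformers v5 install and a tokenizer that ships a chat template; the model name is only an example, and the import path follows the file layout above:

from transformers import AutoTokenizer

from art.transformers.patches import patch_apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [{"role": "user", "content": "Hello"}]

patch_apply_chat_template()

# With the patch applied, tokenize=True yields the v4-style list[int] again.
token_ids = tokenizer.apply_chat_template(messages, tokenize=True)
assert isinstance(token_ids, list)

# Callers that want the v5 BatchEncoding can still request it explicitly, since
# setdefault() only fills in return_dict when the caller omitted it.
encoding = tokenizer.apply_chat_template(messages, tokenize=True, return_dict=True)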

src/art/unsloth/service.py

Lines changed: 9 additions & 1 deletion
@@ -13,8 +13,8 @@
 from datasets import Dataset
 import peft
 import torch
+from transformers import GenerationMixin, PreTrainedModel
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.utils.dummy_pt_objects import GenerationMixin, PreTrainedModel
 from trl import GRPOConfig, GRPOTrainer
 from vllm import AsyncEngineArgs
 from vllm.lora.request import LoRARequest
@@ -30,6 +30,7 @@
     packed_tensors_from_dir,
 )
 from ..preprocessing.tokenize import SFTBatch
+from ..utils.convert_moe_lora import convert_checkpoint_if_needed
 from ..utils.get_model_step import get_step_from_dir
 from ..utils.output_dirs import get_step_checkpoint_dir
 from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers
@@ -156,6 +157,7 @@ def save_checkpoint(
     checkpoint_dir = get_step_checkpoint_dir(output_dir, next_step)
     os.makedirs(checkpoint_dir, exist_ok=True)
     trainer.save_model(checkpoint_dir)
+    convert_checkpoint_if_needed(checkpoint_dir)
     return checkpoint_dir
 
 
@@ -436,6 +438,7 @@ async def start_openai_server(
             lora_path = get_step_checkpoint_dir(self.output_dir, 0)
             os.makedirs(os.path.dirname(lora_path), exist_ok=True)
             self._state.trainer.save_model(lora_path)
+            convert_checkpoint_if_needed(lora_path)
             self._latest_step = 0
         else:
             self._latest_step = get_step_from_dir(self.output_dir)
@@ -921,6 +924,11 @@ def _state(self) -> UnslothState:
             ),
         )
 
+        # Unsloth's model patching can leave the PEFT model without
+        # `warnings_issued`, which GRPOTrainer expects during init.
+        if not hasattr(peft_model, "warnings_issued"):
+            peft_model.warnings_issued = {}  # type: ignore[attr-defined]
+
         # Initialize trainer with dummy dataset
         data = {"prompt": ""}
         trainer = GRPOTrainer(
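
convert_checkpoint_if_needed comes from src/art/utils/convert_moe_lora.py, one of the 8 changed files. As a rough sketch of the kind of conversion the commit message describes, splitting fused 2D MoE LoRA tensors into the per-expert entries vLLM can load, assuming a safetensors adapter file and an illustrative key pattern; the shipped utility may detect and name things differently:

# Illustrative sketch only: split fused MoE LoRA tensors into per-expert entries.
# The key pattern, the fused layout (experts stacked along dim 0), and the output
# naming are assumptions; the real convert_checkpoint_if_needed may differ.
import os
import re

import torch
from safetensors.torch import load_file, save_file


def split_fused_moe_lora(checkpoint_dir: str, num_experts: int) -> None:
    adapter_path = os.path.join(checkpoint_dir, "adapter_model.safetensors")
    if not os.path.exists(adapter_path):
        return
    tensors = load_file(adapter_path)
    fused_key = re.compile(
        r"^(?P<prefix>.*\.mlp\.experts)\.(?P<proj>\w+)\.(?P<lora>lora_[AB])\.weight$"
    )
    converted: dict[str, torch.Tensor] = {}
    changed = False
    for name, tensor in tensors.items():
        match = fused_key.match(name)
        # Leave anything alone that does not look like a fused 2D expert tensor.
        if match is None or tensor.ndim != 2 or tensor.shape[0] % num_experts:
            converted[name] = tensor
            continue
        for i, chunk in enumerate(tensor.chunk(num_experts, dim=0)):
            per_expert = f"{match['prefix']}.{i}.{match['proj']}.{match['lora']}.weight"
            converted[per_expert] = chunk.contiguous()
        changed = True
    if changed:
        save_file(converted, adapter_path)

In practice num_experts would come from the model config rather than being passed in, which is presumably part of the auto-detection the commit message mentions.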
