A modular, configuration-driven framework for SFT (Supervised Fine-Tuning) and DPO (Direct Preference Optimization). Built on TRL, DeepSpeed, and Accelerate with multi-node SLURM support.
This repo supports two training backends:
- TRL: SFT and DPO via `accelerate launch`
- LlamaFactory: SFT, DPO, and long-context tuning via Singularity containers
This project uses uv for dependency management. To create the Python environment, run:
```bash
uv sync
```
To include dev dependencies (required for pre-commit):
```bash
uv sync --extra dev
```
This project uses pre-commit with ruff (lint + format) and black (format). To install the hooks:
```bash
uv run pre-commit install
```
Hooks will now run automatically on every git commit. To run them manually against all files:
```bash
uv run pre-commit run --all-files
```
To run training locally, use `accelerate launch`. You must specify the distributed flags explicitly.
SFT:
```bash
accelerate launch \
  --num_machines 1 \
  --num_processes 4 \
  --dynamo_backend=inductor \
  --use_deepspeed \
  --same_network \
  --rdzv_backend static \
  --mixed_precision bf16 \
  scripts/train.py \
  --config configs/trl/sft.yaml \
  training.max_steps=100 \
  offline=true
```
DPO:
```bash
accelerate launch \
  --num_machines 1 \
  --num_processes 4 \
  --dynamo_backend=inductor \
  --use_deepspeed \
  --same_network \
  --rdzv_backend static \
  --mixed_precision bf16 \
  scripts/train.py \
  --config configs/trl/dpo.yaml \
  training.max_steps=100 \
  offline=true
```
Note: the `--mixed_precision` flag passed to `accelerate launch` must match `model.dtype` in your config.
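The launch flags above have counterparts in the `accelerate` section of the YAML config, and the precision rule in the note can be checked mechanically. The following is a minimal, hypothetical sketch, not the framework's launcher: `accelerate_args` and the precision-to-dtype mapping are assumptions made for illustration.

```python
# Hypothetical sketch: render the YAML `accelerate` section as launch flags
# and enforce that --mixed_precision agrees with model.dtype.

# Assumed mapping from accelerate precision values to model.dtype strings.
MIXED_PRECISION_TO_DTYPE = {"bf16": "bfloat16", "fp16": "float16", "no": "float32"}


def accelerate_args(
    accelerate_cfg: dict, num_nodes: int, gpus_per_node: int, model_dtype: str
) -> list[str]:
    """Return `accelerate launch` flags for the given config section."""
    precision = accelerate_cfg.get("mixed_precision", "no")
    if MIXED_PRECISION_TO_DTYPE.get(precision) != model_dtype:
        raise ValueError(
            f"mixed_precision={precision!r} does not match model.dtype={model_dtype!r}"
        )
    args = [
        "--num_machines", str(num_nodes),
        # accelerate counts processes across all machines, not per node
        "--num_processes", str(num_nodes * gpus_per_node),
    ]
    for key, value in accelerate_cfg.items():
        if value is True:
            args.append(f"--{key}")  # boolean switch, e.g. --use_deepspeed
        elif value is False or value is None:
            continue  # disabled or unset: omit the flag entirely
        else:
            args.append(f"--{key}={value}")
    return args


# Values taken from the reference SFT configuration (1 node x 4 GPUs)
cfg = {
    "mixed_precision": "bf16",
    "use_deepspeed": True,
    "deepspeed_multinode_launcher": "standard",
    "same_network": True,
    "rdzv_backend": "static",
    "dynamo_backend": "inductor",
}
args = accelerate_args(cfg, num_nodes=1, gpus_per_node=4, model_dtype="bfloat16")
print(" ".join(args))
```

Per-node values that a real multi-node launch also needs (machine rank, main process address and port) are deliberately left out of this sketch.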
For cluster environments, use the submission script. It auto-generates a SLURM batch script based on your YAML configuration and submits it.
- SLURM job template: `src/post_training/slurm/job.sh.jinja`

```bash
python scripts/submit.py --config configs/trl/sft.yaml
```

Repository layout:
```
post-training/
├── configs/
│   ├── trl/
│   │   └── sft.yaml                 # TRL SFT example config
│   ├── llamafactory/
│   │   └── long-context.yaml        # LlamaFactory long-context SFT config
│   └── deepspeed/
│       ├── zero2.yaml               # DeepSpeed ZeRO Stage 2 config
│       ├── zero3.yaml               # DeepSpeed ZeRO Stage 3 config
│       └── z3_partial_offload.json  # ZeRO Stage 3 with CPU offloading
├── src/post_training/
│   ├── config.py                    # OmegaConf dataclass schema + validation
│   ├── methods/                     # Trainer builders (SFT/DPO)
│   ├── data/                        # Dataset loading, transforms, mixing
│   ├── chat_templates/              # Chat template registry + Jinja templates
│   ├── callbacks/                   # Custom callbacks (e.g., inference checkpoints)
│   ├── slurm/                       # SLURM script rendering + submission
│   └── utils/                       # Logging + run directory utilities
├── scripts/
│   ├── train.py                     # Training entrypoint (supports CLI overrides)
│   ├── submit.py                    # SLURM submission entrypoint
│   ├── data.py                      # Data pipeline debugger + token-stats
│   └── wb.py                        # Weights & Biases utilities
└── pyproject.toml
```
This is the golden rule: all run configuration lives in a single YAML file.
This YAML file specifies:
- The hyper-parameters of the target script
- The SLURM configuration, which might be cluster-specific
- The Singularity container to use, if applicable
You do not need to edit Python scripts to change these settings. Either:
- Override any YAML value via the CLI using dot-notation
- Or create a new YAML config specific to your run
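Conceptually, a dot-notation override walks the nested config and replaces one leaf value. The framework does this with OmegaConf; the plain-dict sketch below only illustrates the idea, and its helper names (`parse_value`, `apply_overrides`) and value-parsing rules are assumptions.

```python
# Simplified illustration of `a.b.c=value` overrides on a nested dict.
# The real framework uses OmegaConf; this sketch only mimics the idea.


def parse_value(raw: str):
    """Convert a CLI string into a bool, None, int, float, or str."""
    lowered = raw.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    if lowered in ("null", "none"):
        return None
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw.strip("\"'")  # plain string; drop surrounding quotes if any


def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Apply each `dotted.key=value` override to `config` in place."""
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override {item!r} is not of the form key=value")
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})  # create missing sections
        node[leaf] = parse_value(raw)
    return config


config = {"training": {"learning_rate": 2.0e-5, "max_steps": None}, "sft": {"packing": True}}
apply_overrides(
    config,
    ["model.name_or_path=meta-llama/Llama-3.1-8B", "training.learning_rate=5e-6", "sft.packing=false"],
)
```

An OmegaConf structured schema such as the one in `config.py` can reject unknown keys; this sketch does no validation, so a mistyped key silently creates a new entry.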
```bash
# Prepend `accelerate launch` and the distributed flags shown above
scripts/train.py \
  --config configs/trl/sft.yaml \
  model.name_or_path="meta-llama/Llama-3.1-8B" \
  training.learning_rate=5e-6 \
  sft.packing=false
```
A job submission on the HPC should be a single line as follows:
```bash
python scripts/submit.py --config /path/to/config.yaml
```
Given the heterogeneity of cluster environments, training jobs should, where possible, run inside a Singularity (or Apptainer) container that bundles all required dependencies (such as PyTorch, CUDA, Flash Attention, and any cluster-specific backend libraries) into a single, portable environment. This simplifies both setup and reproducibility across systems.
Container images are specified in the config under container.image. Set container: null for TRL bare-metal runs. When a container image is configured, the SLURM launcher passes it to singularity exec and bind-mounts the repository into the container at runtime, so no rebuild is needed when the code changes.
```yaml
container:
  image: /path/to/image.sif
  bind_mounts:
    - /data:/data
  env_file: env/cluster.env # required when image is set; sourced before launch
```
Both the LlamaFactory and containerized TRL backends use this mechanism. Building containers for different HPCs is a work in progress, so if your cluster-specific container is not available yet, please use the uv environment instead (or open a pull request with a recipe for your cluster-specific container!). For TRL, use the uv environment by setting `container: null` or `container.image: null`; LlamaFactory requires a container image.
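Roughly, the launcher turns the `container` block into a `singularity exec` invocation. A hypothetical sketch of that translation follows; the helper name, the `--nv` GPU flag, and mounting the repository at its own host path are assumptions, and sourcing `env_file` is not modelled.

```python
from typing import Optional


def build_exec_command(container: Optional[dict], repo_dir: str, command: list[str]) -> list[str]:
    """Wrap `command` in `singularity exec` when a container image is configured."""
    if not container or container.get("image") is None:
        return list(command)  # bare-metal run, e.g. TRL with the uv environment
    binds = list(container.get("bind_mounts") or [])
    # Mount the repository at runtime so code changes need no image rebuild
    binds.append(f"{repo_dir}:{repo_dir}")
    return [
        "singularity",
        "exec",
        "--nv",  # expose the host's NVIDIA GPUs inside the container
        "--bind",
        ",".join(binds),  # Singularity accepts a comma-separated bind list
        container["image"],
        *command,
    ]


cmd = build_exec_command(
    {"image": "/path/to/image.sif", "bind_mounts": ["/data:/data"]},
    repo_dir="/home/user/post-training",
    command=["python", "scripts/train.py", "--config", "configs/trl/sft.yaml"],
)
print(" ".join(cmd))
```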
Select your training strategy using `method`.

- SFT (Supervised Fine-Tuning)
  - Key: `method: "sft"`
  - Packing: set `sft.packing: true` to pack multiple short examples into a single sequence (recommended for efficiency)
  - Sequence length: controlled by `sft.max_seq_length`
- DPO (Direct Preference Optimization)
  - Key: `method: "dpo"`
  - Loss type: set `dpo.loss_type` (e.g., `sigmoid`, `hinge`, `ipo`)
  - Reference model: set `dpo.ref_model_name_or_path`
    - If `null`, TRL creates an implicit copy of the active model
    - If using ZeRO Stage 3, consider specifying the reference model explicitly (implicit copy creation can be unstable with Stage 3)
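For reference, the three `dpo.loss_type` values listed above differ only in how they penalize the preference margin m = (log pi_theta(y_w|x) - log pi_ref(y_w|x)) - (log pi_theta(y_l|x) - log pi_ref(y_l|x)): sigmoid is -log sigmoid(beta * m), hinge is max(0, 1 - beta * m), and ipo is (m - 1/(2 * beta))^2. The scalar sketch below is written from those standard formulas, not from TRL's source. `beta` is the usual DPO temperature, whose config key this README does not show, and library implementations add details (batching, length normalization for some losses) that are omitted here.

```python
import math


def dpo_loss(
    policy_chosen: float,
    policy_rejected: float,
    ref_chosen: float,
    ref_rejected: float,
    beta: float = 0.1,
    loss_type: str = "sigmoid",
) -> float:
    """Per-example DPO loss from summed sequence log-probabilities."""
    # How much more the policy prefers the chosen response than the reference does
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    if loss_type == "sigmoid":
        x = beta * margin
        # -log(sigmoid(x)) == log(1 + exp(-x)), computed without overflow
        return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))
    if loss_type == "hinge":
        return max(0.0, 1.0 - beta * margin)
    if loss_type == "ipo":
        return (margin - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f"Unsupported loss_type: {loss_type!r}")


for name in ("sigmoid", "hinge", "ipo"):
    print(name, dpo_loss(-10.0, -20.0, -12.0, -18.0, beta=0.1, loss_type=name))
```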
The data pipeline is modularized into four distinct stages.
Define multiple datasets in data.datasets. The loader samples each dataset independently according to its weight, concatenates the sampled datasets, and shuffles the final mix with data.seed. A weight of 1.0 means the full dataset after transforms and filters, values below 1.0 undersample, values above 1.0 oversample, and 0.0 omits that dataset.
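The weighting rule can be sketched on plain Python lists. This is an assumption-laden illustration, not the loader's code: here the integer part of a weight becomes whole repetitions, the fractional part becomes a random subset drawn without replacement, and `mix_datasets` is a hypothetical name.

```python
import random


def mix_datasets(datasets: list[tuple[list, float]], seed: int = 42) -> list:
    """Resample each (examples, weight) pair, concatenate, then shuffle."""
    rng = random.Random(seed)  # plays the role of data.seed
    mixed: list = []
    for examples, weight in datasets:
        if weight < 0:
            raise ValueError("dataset weights must be non-negative")
        full_copies, fraction = divmod(weight, 1.0)
        # Whole repetitions for weights >= 1.0 (oversampling)
        mixed.extend(examples * int(full_copies))
        # Random subset without replacement for the fractional part (undersampling)
        mixed.extend(rng.sample(examples, round(len(examples) * fraction)))
    rng.shuffle(mixed)
    return mixed


small = list(range(10))
tiny = list(range(100, 104))
mix = mix_datasets([(small, 0.5), (tiny, 2.0)], seed=42)  # 5 + 8 examples
```

A weight of `0.0` contributes nothing, and a weight of `1.0` contributes every example exactly once.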
```yaml
data:
  seed: 42
  datasets:
    - name: "my_dataset"
      path: "org/dataset"
      split: "train"
      weight: 1.0 # 1 = full dataset, <1 undersamples, >1 oversamples
```
Raw datasets often come in varying formats. Transforms normalize them into a standard `messages` list format before templating. SFT loading keeps the `messages` column and enforces its feature schema during mapping so that the schema is not inferred incorrectly. Concretely, a transform may map some samples to an empty list of messages. If the first sample is one of them, the automatically inferred schema of the mapped column is wrong, and mapping the subsequent samples (each a list of message dicts) raises exceptions.
- Config: `transform: "transform_name"` (in the dataset entry)
- Registry: `src/post_training/data/transforms.py`
- Customization: decorate a function with `@register_transform("name")` to add your own logic
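A decorator-based registry of this kind usually amounts to a dictionary keyed by name. The sketch below is a hypothetical, self-contained re-implementation for illustration only (`TRANSFORMS`, `apply_transform`, and the `qa_pair` transform are made-up names); the actual code in `src/post_training/data/transforms.py` may differ.

```python
from typing import Callable, Dict, Optional

Transform = Callable[[dict], dict]
TRANSFORMS: Dict[str, Transform] = {}  # name -> transform function


def register_transform(name: str) -> Callable[[Transform], Transform]:
    """Decorator that stores a transform function under `name`."""
    def decorator(fn: Transform) -> Transform:
        if name in TRANSFORMS:
            raise ValueError(f"transform {name!r} is already registered")
        TRANSFORMS[name] = fn
        return fn  # the function itself is returned unchanged
    return decorator


def apply_transform(example: dict, name: Optional[str]) -> dict:
    """Look up and apply a transform; `None` means already conversational."""
    if name is None:
        return example
    if name not in TRANSFORMS:
        raise KeyError(f"unknown transform {name!r}; registered: {sorted(TRANSFORMS)}")
    return TRANSFORMS[name](example)


@register_transform("qa_pair")
def qa_pair(example: dict) -> dict:
    return {
        "messages": [
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": example["answer"]},
        ]
    }
```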
Example (normalize raw fields into messages):
```python
from post_training.data.transforms import register_transform


@register_transform("my_transform")
def my_transform(example: dict) -> dict:
    # Map raw prompt/answer fields to the standard messages format
    return {
        "messages": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["answer"]},
        ]
    }
```
Templates convert the list of messages into a single string for the model.
- Config: `data.chat_template: "name"`
- Source: Jinja files located in `src/post_training/chat_templates/templates/`
SFT in this framework uses TRL's `assistant_only_loss=True`, which masks the
cross-entropy loss on every non-assistant token (system + user). This depends on
the chat template wrapping the assistant content emission in
`{% generation %}...{% endgeneration %}` markers: transformers'
`apply_chat_template(..., return_assistant_tokens_mask=True)` uses them to build
the per-token loss mask.
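To make the masking concrete: once the per-token assistant mask exists, non-assistant positions are typically given the label -100, the default ignore index of PyTorch's `CrossEntropyLoss`, so they contribute nothing to the loss. A minimal sketch on plain lists follows; the token IDs are made up, `build_labels` is a hypothetical helper, and TRL's internals may differ.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(input_ids: list[int], assistant_mask: list[int]) -> list[int]:
    """Keep labels on assistant tokens; ignore every other position."""
    if len(input_ids) != len(assistant_mask):
        raise ValueError("input_ids and assistant_mask must have the same length")
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, assistant_mask)]


# Hypothetical sequence: three system/user tokens followed by three assistant tokens
labels = build_labels([11, 12, 13, 21, 22, 23], [0, 0, 0, 1, 1, 1])
```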
If your template lacks those markers, build_sft_trainer() will refuse to start
with a ValueError that names the template and points to the fix. This is
deliberate: silently training without the mask measurably degrades downstream
performance (the framework previously had this bug; SFT computed CE loss on
every token in the packed sequence).
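A guard like the one in `build_sft_trainer()` can be approximated with a regular-expression check. This is a hypothetical sketch, not the framework's implementation; accepting Jinja whitespace-control variants such as `{%- generation -%}` is an assumption about what a real check should allow.

```python
import re

# Opening and closing markers, tolerating Jinja whitespace control ("-")
_GEN_OPEN = re.compile(r"\{%-?\s*generation\s*-?%\}")
_GEN_CLOSE = re.compile(r"\{%-?\s*endgeneration\s*-?%\}")


def require_generation_markers(template_name: str, template: str) -> None:
    """Raise ValueError if the chat template cannot produce an assistant mask."""
    if not (_GEN_OPEN.search(template) and _GEN_CLOSE.search(template)):
        raise ValueError(
            f"Chat template {template_name!r} has no {{% generation %}} markers; "
            "add them around the assistant content before using it for SFT."
        )


require_generation_markers("ok", "{% generation %}{{ content }}{% endgeneration %}")
```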
Templates that are safe for SFT today:

| Name | Source | Notes |
|---|---|---|
| `olmo3-instruct-sft` | `allenai/OLMo-3-7B-Instruct-SFT` (HF Hub) | Use to reproduce the Instruct-SFT recipe. |
| `olmo3-think-sft` | `allenai/Olmo-3-7B-Think-SFT` (HF Hub) | Use to reproduce the Think-SFT recipe. |
Templates that are not safe for SFT (kept for inference / DPO compatibility):

| Name | Notes |
|---|---|
| `olmo3` | Legacy alias for the markerless Think-SFT template; preserved for inference parity only. |
| `chatml`, `tulu3`, `apertus` | Markerless; need `{% generation %}` markers added before they can be used for SFT. |
To use a custom template for SFT, wrap exactly the tokens that should contribute
to loss: typically the content, `function_calls`/`tool_calls`, and the closing
`<|im_end|>` or `eos_token`. Do not wrap the leading role-tag prefix
(`<|im_start|>assistant\n`); it's a deterministic control sequence the model
shouldn't have to predict.
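As an illustration of that placement rule, here is a ChatML-style template held in a Python string, plus a helper that extracts what falls between the markers. The template is an assumption written for this example: it is not one of the registry's templates and it ignores tool calls.

```python
# Illustrative ChatML-style template (hypothetical). The role prefix stays
# outside the markers; the content and closing <|im_end|> sit inside them.
CHATML_SFT_TEMPLATE = (
    "{%- for message in messages -%}"
    "{%- if message['role'] == 'assistant' -%}"
    "{{ '<|im_start|>assistant\\n' }}"
    "{% generation %}{{ message['content'] + '<|im_end|>' }}{% endgeneration %}"
    "{{ '\\n' }}"
    "{%- else -%}"
    "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}"
    "{%- endif -%}"
    "{%- endfor -%}"
    "{%- if add_generation_prompt -%}{{ '<|im_start|>assistant\\n' }}{%- endif -%}"
)


def wrapped_span(template: str) -> str:
    """Return the template text between the first generation marker pair."""
    start = template.index("{% generation %}") + len("{% generation %}")
    end = template.index("{% endgeneration %}", start)
    return template[start:end]


print(wrapped_span(CHATML_SFT_TEMPLATE))
```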
Use the data script to debug the pipeline stages (Raw → Transformed → Formatted → Tokenized) and to compute token statistics.
```bash
python scripts/data.py inspect --config configs/trl/sft.yaml --show-formatted --num-samples 3
python scripts/data.py token-stats --config configs/trl/sft.yaml
```
You must specify exactly one determining factor for training duration in the `training` section:
- Step-based: `training.max_steps` (fixed number of optimizer steps)
- Sample-based: `training.num_training_samples` (steps = `ceil(samples / effective_batch_size)`)
- Token-based: `training.num_training_tokens` (steps = `ceil(tokens / (effective_batch_size * sft.max_seq_length))`)
  - Only valid when `method: "sft"` and `sft.packing: true`
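The three options reduce to a step count as follows. This is a hypothetical sketch: the formulas mirror the comments in the reference configuration (the token-based one assumes every packed sequence holds a full `sft.max_seq_length` tokens), and `resolve_max_steps` is a made-up name.

```python
import math
from typing import Optional


def resolve_max_steps(
    effective_batch_size: int,
    max_steps: Optional[int] = None,
    num_training_samples: Optional[int] = None,
    num_training_tokens: Optional[int] = None,
    method: str = "sft",
    packing: bool = True,
    max_seq_length: int = 4096,
) -> int:
    """Turn exactly one duration setting into a number of optimizer steps."""
    settings = (max_steps, num_training_samples, num_training_tokens)
    if sum(value is not None for value in settings) != 1:
        raise ValueError("set exactly one of max_steps, num_training_samples, num_training_tokens")
    if max_steps is not None:
        return max_steps
    if num_training_samples is not None:
        return math.ceil(num_training_samples / effective_batch_size)
    if method != "sft" or not packing:
        raise ValueError("num_training_tokens requires method='sft' and sft.packing=true")
    # With packing, one optimizer step consumes up to batch * seq_len tokens
    tokens_per_step = effective_batch_size * max_seq_length
    return math.ceil(num_training_tokens / tokens_per_step)


steps = resolve_max_steps(32, num_training_tokens=10**9, max_seq_length=4096)
```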
- DeepSpeed: configured via `deepspeed.config_path` (e.g., `configs/deepspeed/zero3.yaml`)
- Accelerate flags: the `accelerate` section in the YAML mirrors the CLI flags required for multi-node setups (`mixed_precision`, `dynamo_backend`, `rdzv_backend`, etc.). The SLURM launcher uses them to generate the correct job script.
- Self-healing: the SLURM launcher (`src/post_training/slurm/`) supports auto-requeueing. `slurm.signal_time_seconds` ensures the job saves a checkpoint and requeues itself before the wall time expires.
- Account: set `slurm.account` when your cluster requires an explicit `#SBATCH --account` directive; leave it `null` to omit the directive.
Full checkpoints:
- What: full training state (optimizer + model)
- Location: `checkpoints/checkpoint-*`
- Logic: training automatically resumes from the latest checkpoint found here

Inference checkpoints:
- What: model + tokenizer only
- Location: `inference_checkpoints/step-*`
- Config: `checkpointing.inference_checkpoint_steps` (set to `null` to disable)
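Finding "the latest checkpoint" should compare step numbers numerically, because lexicographically `checkpoint-1000` sorts before `checkpoint-200`. A minimal sketch follows (`latest_checkpoint` is a hypothetical helper; Hugging Face Transformers ships a comparable utility, `transformers.trainer_utils.get_last_checkpoint`).

```python
import re
from pathlib import Path
from typing import Optional

_CKPT = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(checkpoints_dir: str) -> Optional[Path]:
    """Return the checkpoint-* directory with the highest step, or None."""
    root = Path(checkpoints_dir)
    if not root.is_dir():
        return None
    candidates = []
    for child in root.iterdir():
        match = _CKPT.match(child.name)
        if match and child.is_dir():
            candidates.append((int(match.group(1)), child))
    # Compare the parsed step numbers, not the directory names
    return max(candidates)[1] if candidates else None
```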
- Offline: `offline: true` disables Hugging Face Hub / Weights & Biases network calls (essential for air-gapped nodes).
- Debug: `debug.enabled: true` forces `report_to: none`, uses a separate output directory, and allows overwriting existing runs.
- Tokenize only: `--tokenize-only` (CLI flag on `train.py` / `submit.py`) exits immediately after the trainer is initialized. Dataset loading, tokenization, and packing all run, but the training loop is never entered. This is useful for pretokenizing the dataset before committing to a full run. When passed to `submit.py`, the job is automatically constrained to 1 node and 1 GPU.
  ```bash
  python scripts/submit.py --config configs/trl/sft.yaml --tokenize-only
  ```
The framework supports multiple logging backends and handles offline environments (e.g., air-gapped clusters).
For multi-node runs, SLURM output and error logs are stored within each run's specific directory:
- `<run_directory>/slurm/slurm-<job_id>.out`: standard output (including console logs and progress bars)
- `<run_directory>/slurm/slurm-<job_id>.err`: standard error (including stack traces and warnings)
- Online: logs are streamed directly to the WandB cloud. The project name is controlled by `logging.wandb_project`.
- Offline: when `offline: true` is set, WandB logs are saved locally to the `wandb/` directory in the project root.
To upload offline runs to the cloud (e.g., from a login node with internet access), use the utility script:
```bash
# Interactive mode: view and select runs to sync
python scripts/wb.py sync --interactive

# Sync a specific run by its training run name
python scripts/wb.py sync --run-name <run_name>
```
Each run generates a unique directory based on `paths.output_base` (or `paths.debug_base`) and a run name auto-generated from the model, method, and dataset mix.
```
<output_base>/<run_name>/
├── config.yaml              # Frozen configuration for reproducibility
├── checkpoints/             # Full TRL training state (resumable)
│   └── checkpoint-500/
├── inference_checkpoints/   # Lightweight model + tokenizer only
│   └── step-500/
├── logs/                    # TensorBoard / Weights & Biases logs
└── slurm/                   # SLURM artifacts
    ├── job.sh               # The generated submission script
    ├── slurm-<id>.out       # Standard output
    ├── slurm-<id>.err       # Standard error
    └── failure_count        # Tracks retries for self-healing
```
An alternative backend using LlamaFactory for training, running inside a Singularity container.
- Build the Singularity container:
  ```bash
  singularity build --fakeroot llamafactory.sif containers/llamafactory_jupiter.def
  ```
- Set the container path in `env/jupiter.env`:
  ```bash
  export CONTAINER=/path/to/llamafactory.sif
  ```

Then submit the job:
```bash
python scripts/submit.py --config configs/llamafactory/long-context.yaml
```
- Config: `configs/llamafactory/long-context.yaml`
- DeepSpeed: `configs/deepspeed/z3_partial_offload.json`
- Dataset registry: `data/llamafactory/dataset_info.json`
Full reference configuration for the default SFT setup:
```yaml
# ============================================================================
# SFT (Supervised Fine-Tuning) Configuration
# ============================================================================
# Override any value via CLI dot-notation:
#   accelerate launch \
#     --num_machines 1 \
#     --num_processes 4 \
#     --dynamo_backend=inductor \
#     --use_deepspeed \
#     --same_network \
#     --rdzv_backend static \
#     --mixed_precision bf16 \
#     scripts/train.py \
#     --config configs/trl/sft.yaml \
#     training.max_steps=100 \
#     offline=true
# ============================================================================
method: sft
backend: trl
run_name: null  # auto-generated from model + datasets if null
offline: false  # set true to disable all HuggingFace / wandb network calls

# -- Container ---------------------------------------------------------------
container: null  # null = bare-metal; set image/binds/env_file for Singularity

# -- Model -------------------------------------------------------------------
model:
  name_or_path: "allenai/Olmo-3-1025-7B"
  attn_implementation: "flash_attention_3"
  dtype: "bfloat16"

# -- Training hyper-parameters -----------------------------------------------
training:
  max_steps: null             # Set explicitly, OR use num_training_samples below
  num_training_samples: null  # If set: max_steps = ceil(num_samples / effective_batch_size)
  # num_training_tokens: null # Only valid when sft.packing=true (max_steps = ceil(tokens / (effective_batch_size * sft.max_seq_length)))
  learning_rate: 2.0e-5
  effective_batch_size: 32    # per_device * grad_accum * world_size
  per_device_train_batch_size: 8
  warmup_steps: 0.03
  lr_scheduler_type: "cosine_with_min_lr"
  lr_scheduler_kwargs:        # set to null for schedulers with no kwargs
    min_lr_rate: 0.1
  gradient_checkpointing: true
  gradient_checkpointing_kwargs: null  # null = use TRL/Transformers defaults
  bf16: true
  seed: 42
  use_liger_kernel: true

# -- SFT method parameters ---------------------------------------------------
sft:
  max_seq_length: 4096
  packing: true

# -- Checkpointing -----------------------------------------------------------
checkpointing:
  save_steps: 200
  save_total_limit: 2                                 # Full checkpoints to keep
  inference_checkpoint_steps: 157                     # Minimal inference model interval (set to null to disable)
  inference_checkpoint_path: "inference_checkpoints"  # Relative to run dir

# -- Data mix ----------------------------------------------------------------
data:
  chat_template: "olmo3-instruct-sft"  # Name from chat template registry
  num_proc: null                       # null = auto-detect, capped at 32
  seed: 42                             # RNG seed for dataset resampling and final shuffle
  datasets:
    - name: "nemotron_pt_v2"
      path: "nvidia/Nemotron-Post-Training-Dataset-v2"
      split: "stem"
      weight: 1.0      # 1 = full dataset, <1 undersamples, >1 oversamples
      transform: null  # null = already conversational

# -- DeepSpeed ---------------------------------------------------------------
deepspeed:
  config_path: "configs/deepspeed/zero2.yaml"

# -- Accelerate launch flags (explicit multi-node control) -------------------
accelerate:
  mixed_precision: "bf16"
  use_deepspeed: true
  deepspeed_multinode_launcher: "standard"  # "standard" | "pdsh" | etc.
  same_network: true                        # All nodes on same network
  rdzv_backend: "static"                    # "static" | "c10d" | "etcd"
  dynamo_backend: "inductor"                # "inductor" | "no" | etc.

# -- Logging & tracking ------------------------------------------------------
logging:
  report_to:
    - "wandb"
    - "tensorboard"
  wandb_project: "sft-training"
  logging_steps: 1
  include_num_input_tokens_seen: "non_padding"

# -- SLURM -------------------------------------------------------------------
slurm:
  account: null  # set when your cluster requires #SBATCH --account
  partition: "booster"
  num_nodes: 1
  gpus_per_node: 4
  cpus_per_task: 32
  wall_time: "02:00:00"
  job_name: "sft-training"
  signal_time_seconds: 300  # SIGUSR1 sent this many seconds before timeout to trigger self-healing
  max_failures: 3           # Self-healing retry limit

# -- Debug mode --------------------------------------------------------------
debug:
  enabled: false
  override_existing: false

# -- Output paths -------------------------------------------------------------
paths:
  output_base: "outputs"
  debug_base: "outputs/debug"
```
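The comment on `effective_batch_size` implies a gradient-accumulation count. With the reference values (32 effective, 8 per device, 1 node x 4 GPUs) it comes out to 1. Below is a hypothetical helper; whether the framework rejects non-divisible combinations, as this sketch does, is an assumption.

```python
def gradient_accumulation_steps(
    effective_batch_size: int, per_device_train_batch_size: int, world_size: int
) -> int:
    """Solve effective = per_device * grad_accum * world_size for grad_accum."""
    per_step = per_device_train_batch_size * world_size  # samples per forward/backward pass
    if effective_batch_size % per_step != 0:
        raise ValueError(
            f"effective_batch_size={effective_batch_size} is not divisible by "
            f"per_device_train_batch_size * world_size = {per_step}"
        )
    return effective_batch_size // per_step


# Reference config: num_nodes=1, gpus_per_node=4 -> world_size=4
grad_accum = gradient_accumulation_steps(32, 8, world_size=1 * 4)
```

Scaling `slurm.num_nodes` to 2 with the same values would make the per-step batch 64, larger than the effective batch of 32, so one of the batch settings would need to change.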