A modular, configuration-driven framework for SFT (Supervised Fine-Tuning) and DPO (Direct Preference Optimization). Built on TRL, DeepSpeed, and Accelerate with multi-node SLURM support.
This project uses [uv](https://github.com/astral-sh/uv) for dependency management. To create the Python environment, run:

```bash
uv sync
```

To run training locally, use `accelerate launch`. You must specify the distributed flags explicitly.
SFT example:

```bash
accelerate launch \
  --num_machines 1 \
  --num_processes 4 \
  --dynamo_backend=inductor \
  --use_deepspeed \
  --same_network \
  --rdzv_backend static \
  --mixed_precision bf16 \
  scripts/train.py \
  --config configs/sft.yaml \
  training.max_steps=100 \
  offline=true
```

DPO example:

```bash
accelerate launch \
  --num_machines 1 \
  --num_processes 4 \
  --dynamo_backend=inductor \
  --use_deepspeed \
  --same_network \
  --rdzv_backend static \
  --mixed_precision bf16 \
  scripts/train.py \
  --config configs/dpo.yaml \
  training.max_steps=100 \
  offline=true
```

**Note:** The `--mixed_precision` flag passed to `accelerate launch` must match `model.dtype` in your config.
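For example, with the defaults from the reference config at the end of this document, `bfloat16` model weights pair with the `bf16` launch flag:

```yaml
model:
  dtype: "bfloat16"        # must agree with --mixed_precision bf16
accelerate:
  mixed_precision: "bf16"  # mirrored by the launch flag above
```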
For cluster environments, use the submission script. It auto-generates a SLURM batch script based on your YAML configuration and submits it.
- SLURM job template: `src/post_training/slurm/job.sh.jinja`
```bash
python scripts/submit.py --config configs/sft.yaml
```

The repository is organized as follows:

```
post-training/
├── configs/
│   ├── sft.yaml              # SFT example config
│   ├── dpo.yaml              # DPO example config
│   └── deepspeed/
│       ├── zero2.yaml        # DeepSpeed ZeRO Stage 2 config
│       └── zero3.yaml        # DeepSpeed ZeRO Stage 3 config
├── src/post_training/
│   ├── config.py             # OmegaConf dataclass schema + validation
│   ├── methods/              # Trainer builders (SFT/DPO)
│   ├── data/                 # Dataset loading, transforms, mixing
│   ├── chat_templates/       # Chat template registry + Jinja templates
│   ├── callbacks/            # Custom callbacks (e.g., inference checkpoints)
│   ├── slurm/                # SLURM script rendering + submission
│   └── utils/                # Logging + run directory utilities
├── scripts/
│   ├── train.py              # Training entrypoint (supports CLI overrides)
│   ├── submit.py             # SLURM submission entrypoint
│   ├── data.py               # Data pipeline debugger + token stats
│   └── wb.py                 # Weights & Biases utilities
└── pyproject.toml
```
All run configuration lives in a single YAML file.
You do not need to edit Python scripts to change hyperparameters, models, or data mixtures.
- Override any YAML value via the CLI using dot-notation
- Or create a new YAML config specific to your run
```bash
scripts/train.py \
  --config configs/sft.yaml \
  model.name_or_path="meta-llama/Llama-3.1-8B" \
  training.learning_rate=5e-6 \
  sft.packing=false
```

Select your training strategy using `method`.
- SFT (Supervised Fine-Tuning)
  - Key: `method: "sft"`
  - Packing: set `sft.packing: true` to pack multiple short examples into a single sequence (recommended for efficiency)
  - Sequence length: controlled by `sft.max_seq_length`
- DPO (Direct Preference Optimization)
  - Key: `method: "dpo"`
  - Loss type: set `dpo.loss_type` (e.g., `sigmoid`, `hinge`, `ipo`)
  - Reference model: set `dpo.ref_model_name_or_path`
    - If `null`, TRL creates an implicit copy of the active model
    - If using ZeRO Stage 3, consider specifying the reference model explicitly (implicit copy creation can be unstable with Stage 3)
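For reference, a minimal DPO method block might look like the following (the loss type and reference model values shown here are illustrative, not requirements):

```yaml
method: dpo
dpo:
  loss_type: "sigmoid"           # or "hinge", "ipo"
  ref_model_name_or_path: null   # null = TRL creates an implicit copy of the active model
```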
The data pipeline is modularized into four distinct stages.
Define multiple datasets in `data.datasets`. The loader interleaves them according to each entry's `weight`, normalizing the weights automatically (e.g., weights of 1.0 and 3.0 yield a 25%/75% mix).
```yaml
data:
  datasets:
    - name: "my_dataset"
      path: "org/dataset"
      split: "train"
      weight: 1.0   # Mixing weight (normalized automatically)
```

Raw datasets often come in varying formats. Transforms normalize them into a standard `messages` list format before templating.
- Config: `transform: "transform_name"` (in the dataset entry)
- Registry: `src/post_training/data/transforms.py`
- Customization: decorate a function with `@register_transform("name")` to add your own logic
Example (normalize raw fields into messages):
```python
from post_training.data.transforms import register_transform

@register_transform("my_transform")
def my_transform(example: dict) -> dict:
    # Map raw dataset fields into the standard messages format
    return {
        "messages": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["answer"]},
        ]
    }
```

Templates convert the list of messages into a single string for the model.
- Config: `data.chat_template: "name"`
- Source: Jinja files located in `src/post_training/chat_templates/templates/`
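For example, the reference config at the end of this document selects the bundled `olmo3` template:

```yaml
data:
  chat_template: "olmo3"   # name from the chat template registry
```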
Use the data script to debug the pipeline stages (Raw → Transformed → Formatted → Tokenized) and to compute token statistics.
```bash
# Inspect formatted samples
python scripts/data.py --config configs/sft.yaml --show-formatted --num-samples 3

# Compute token statistics
python scripts/data.py --config configs/sft.yaml token-stats
```

You must specify exactly one determining factor for training duration in the `training` section (see the example after this list):
- Step-based: `training.max_steps` (fixed number of optimizer steps)
- Sample-based: `training.num_training_samples` (steps = `ceil(samples / global_batch_size)`)
- Token-based: `training.num_training_tokens` (steps based on total token count)
  - Only valid when `method: "sft"` and `sft.packing: true`
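For example, a sample-based run (the sample count below is illustrative):

```yaml
training:
  max_steps: null                 # leave the other two factors unset
  num_training_samples: 100000    # steps = ceil(100000 / effective_batch_size)
```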
- DeepSpeed: configured via `deepspeed.config_path` (e.g., `configs/deepspeed/zero3.yaml`)
- Accelerate flags: the `accelerate` section in the YAML mirrors the CLI flags required for multi-node setups (`mixed_precision`, `dynamo_backend`, `rdzv_backend`, etc.). These are used by the SLURM launcher to generate the correct job script.
- Self-healing: the SLURM launcher (`src/post_training/slurm/`) supports auto-requeueing. `slurm.signal_time_seconds` ensures the job saves a checkpoint and requeues itself before the wall time expires (see the snippet below).
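The relevant keys, taken from the reference config at the end of this document:

```yaml
slurm:
  wall_time: "02:00:00"
  signal_time_seconds: 300   # SIGUSR1 sent 300 s before timeout: checkpoint + requeue
  max_failures: 3            # self-healing retry limit
```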
Full checkpoints (resumable):
- What: full training state (optimizer + model)
- Location: `checkpoints/checkpoint-*`
- Logic: training automatically resumes from the latest checkpoint found here

Inference checkpoints:
- What: model + tokenizer only
- Location: `inference_checkpoints/step-*`
- Config: `checkpointing.inference_checkpoint_steps` (set to `null` to disable)
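The corresponding `checkpointing` section from the reference config:

```yaml
checkpointing:
  save_steps: 200                   # full-checkpoint interval
  save_total_limit: 2               # full checkpoints to keep
  inference_checkpoint_steps: 157   # set to null to disable
  inference_checkpoint_path: "inference_checkpoints"   # relative to run dir
```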
- Offline: `offline: true` disables Hugging Face Hub / Weights & Biases network calls (essential for air-gapped nodes).
- Debug: `debug.enabled: true` forces `report_to: none`, uses a separate output directory, and allows overwriting existing runs.
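Both flags can be combined, e.g. for a quick smoke test on an air-gapped node:

```yaml
offline: true     # no Hugging Face Hub / wandb network calls
debug:
  enabled: true   # report_to: none, separate output dir, overwrite allowed
```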
The framework supports multiple logging backends and handles offline environments (e.g., air-gapped clusters).
For multi-node runs, SLURM output and error logs are stored within each run's specific directory:
- `<run_directory>/slurm/slurm-<job_id>.out`: Standard output (including console logs and progress bars)
- `<run_directory>/slurm/slurm-<job_id>.err`: Standard error (including stack traces and warnings)
- Online: logs are streamed directly to the WandB cloud. The project name is controlled by `logging.wandb_project`.
- Offline: when `offline: true` is set, WandB logs are saved locally to the `wandb/` directory in the project root.
To upload offline runs to the cloud (e.g., from a login node with internet access), use the utility script:
```bash
# Interactive mode - view and select runs to sync
python scripts/wb.py sync --interactive

# Sync a specific run by its training run name
python scripts/wb.py sync --run-name <run_name>
```

Each run generates a unique directory based on `paths.output_base` (or `paths.debug_base`) and a run name auto-generated from the model, method, and dataset mix.
```
<output_base>/<run_name>/
├── config.yaml               # Frozen configuration for reproducibility
├── checkpoints/              # Full TRL training state (resumable)
│   └── checkpoint-500/
├── inference_checkpoints/    # Lightweight model + tokenizer only
│   └── step-500/
├── logs/                     # TensorBoard / Weights & Biases logs
└── slurm/                    # SLURM artifacts
    ├── job.sh                # The generated submission script
    ├── slurm-<id>.out        # Standard output
    ├── slurm-<id>.err        # Standard error
    └── failure_count         # Tracks retries for self-healing
```
Full reference configuration for the default SFT setup:
```yaml
# ============================================================================
# SFT (Supervised Fine-Tuning) Configuration
# ============================================================================
# Override any value via CLI dot-notation:
#   accelerate launch \
#     --num_machines 1 \
#     --num_processes 4 \
#     --dynamo_backend=inductor \
#     --use_deepspeed \
#     --same_network \
#     --rdzv_backend static \
#     --mixed_precision bf16 \
#     scripts/train.py \
#     --config configs/sft.yaml \
#     training.max_steps=100 \
#     offline=true
# ============================================================================

method: sft
run_name: null    # auto-generated from model + datasets if null
offline: false    # set true to disable all HuggingFace / wandb network calls

# -- Model -------------------------------------------------------------------
model:
  name_or_path: "allenai/Olmo-3-1025-7B"
  attn_implementation: "flash_attention_3"
  dtype: "bfloat16"

# -- Training hyper-parameters -----------------------------------------------
training:
  max_steps: null                 # Set explicitly, OR use num_training_samples below
  num_training_samples: null      # If set: max_steps = ceil(num_samples / effective_batch_size)
  # num_training_tokens: null     # Only valid when sft.packing=true (max_steps = ceil(tokens / (effective_batch_size * sft.max_seq_length)))
  learning_rate: 2.0e-5
  effective_batch_size: 32        # per_device * grad_accum * world_size
  per_device_train_batch_size: 8
  warmup_ratio: 0.03
  lr_scheduler_type: "cosine_with_min_lr"
  lr_scheduler_kwargs:
    min_lr_rate: 0.1
  gradient_checkpointing: true
  bf16: true
  seed: 42
  use_liger_kernel: true

# -- SFT method parameters ---------------------------------------------------
sft:
  max_seq_length: 4096
  packing: true

# -- Checkpointing -----------------------------------------------------------
checkpointing:
  save_steps: 200
  save_total_limit: 2               # Full checkpoints to keep
  inference_checkpoint_steps: 157   # Minimal inference model interval (set to null to disable)
  inference_checkpoint_path: "inference_checkpoints"   # Relative to run dir

# -- Data mix ----------------------------------------------------------------
data:
  chat_template: "olmo3"    # Name from chat template registry
  num_proc: null            # null = auto-detect, capped at 32
  datasets:
    - name: "nemotron_pt_v2"
      path: "nvidia/Nemotron-Post-Training-Dataset-v2"
      split: "stem"
      weight: 1.0
      transform: null       # null = already conversational

# -- DeepSpeed ---------------------------------------------------------------
deepspeed:
  config_path: "configs/deepspeed/zero2.yaml"

# -- Accelerate launch flags (explicit multi-node control) -------------------
accelerate:
  mixed_precision: "bf16"
  use_deepspeed: true
  deepspeed_multinode_launcher: "standard"   # "standard" | "pdsh" | etc.
  same_network: true           # All nodes on same network
  rdzv_backend: "static"       # "static" | "c10d" | "etcd"
  dynamo_backend: "inductor"   # "inductor" | "no" | etc.

# -- Logging & tracking ------------------------------------------------------
logging:
  report_to:
    - "wandb"
    - "tensorboard"
  wandb_project: "sft-training"
  logging_steps: 1
  include_num_input_tokens_seen: "non_padding"

# -- SLURM -------------------------------------------------------------------
slurm:
  partition: "booster"
  num_nodes: 1
  gpus_per_node: 4
  cpus_per_gpu: 32
  wall_time: "02:00:00"
  job_name: "sft-training"
  signal_time_seconds: 300   # SIGUSR1 sent this many seconds before timeout to trigger self-healing
  max_failures: 3            # Self-healing retry limit

# -- Debug mode --------------------------------------------------------------
debug:
  enabled: false
  override_existing: false

# -- Output paths -------------------------------------------------------------
paths:
  output_base: "outputs"
  debug_base: "outputs/debug"
```