We extend the test-time reinforcement learning (TTRL) framework to a single-sample adaptation paradigm and conduct a comprehensive study across datasets and models, comparing against other single-sample adaptation methods.
STTRL operates at test time, treating the generation of a solution as a sequential decision-making problem. For each test sample, we use a reward model to guide a policy (the language model) towards a better solution through reinforcement learning. This allows the model to refine its internal reasoning process on-the-fly, customized for the specific complexities of the question at hand.
Key Features:
- Sample-Wise Adaptation: The model's generation strategy is optimized for each unique test instance.
- Model-Agnostic: Can be applied to a wide range of autoregressive language models.
- Improved Reasoning: Boosts performance on tasks requiring complex, multi-step thought processes.
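The per-sample adaptation described above needs a reward signal without ground-truth labels. The original TTRL derives it from agreement among sampled answers, which is presumably what `--voting_function majority` selects here. The following is a minimal sketch of that idea, not this repository's implementation: `majority_vote_rewards` is a hypothetical helper, and the answers are assumed to be already extracted from the rollouts as strings.

```python
from collections import Counter
from typing import List, Tuple


def majority_vote_rewards(answers: List[str]) -> Tuple[str, List[float]]:
    """Return the majority answer and a 0/1 reward for each sampled answer.

    Hypothetical helper for illustration; not part of this repository.
    """
    if not answers:
        raise ValueError("need at least one sampled answer")
    # The most frequent answer serves as the pseudo-label.
    # Ties are broken by first appearance, so the result is deterministic.
    pseudo_label, _ = Counter(answers).most_common(1)[0]
    # A rollout is rewarded only if its answer agrees with the pseudo-label.
    rewards = [1.0 if a == pseudo_label else 0.0 for a in answers]
    return pseudo_label, rewards


# Made-up final answers from eight rollouts for a single test question.
sampled = ["42", "41", "42", "42", "7", "42", "41", "42"]
label, rewards = majority_vote_rewards(sampled)
print(label, rewards)
```

In a reinforcement learning loop, these rewards would then weight the policy update for the rollouts of that one test sample.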
Follow these steps to set up the environment and reproduce our results.
All experiments were conducted on a machine with 4 x NVIDIA H100 80GB GPUs.
- Clone the repository:
    git clone https://github.com/your-username/sttrl.git
    cd sttrl
- Create and activate the Conda environment:
    conda env create -f environment.yaml
    conda activate sttrl
All scripts should be run from the verl/ directory.
cd verl
The main script is main.py. You can specify the dataset, model, and other parameters.
Base Command Template:
python main.py --dataset <DATASET_NAME> --model <MODEL_NAME> --voting_function majority
Example:
To reproduce our results on the AIME-TTT dataset with the Qwen/Qwen2.5-Math-1.5B model:
python main.py --dataset AIME-TTT --model Qwen/Qwen2.5-Math-1.5B --voting_function majority
Supported Datasets (--dataset):
- AIME-TTT
- AMC-TTT
- MATH-TTT
- GSM8K-TTT
- GPQA-TTT
Supported Models (--model):
- Qwen/Qwen2.5-Math-1.5B
- Qwen/Qwen2.5-Math-7B
- Qwen/Qwen2.5-7B
- meta-llama/Llama-3.1-8B-Instruct
Note for Non-Math Models: When using general-purpose models like Qwen/Qwen2.5-7B or Llama-3.1-8B-Instruct, we found the following flags improve performance:
python main.py \
--dataset AIME-TTT \
--model meta-llama/Llama-3.1-8B-Instruct \
--voting_function majority \
--val_temp 0.6 \
--separate_ref_gpu
Supported Voting Functions (--voting_function):
- majority
- energy
- confidence
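To run several configurations in a row, the commands can be generated rather than typed by hand. The sketch below uses only the flags and values listed above. `build_command` is a hypothetical helper, not part of this repository, and it assumes every voting function can be combined with every dataset.

```python
import itertools
import shlex
from typing import List

DATASETS = ["AIME-TTT", "AMC-TTT", "MATH-TTT", "GSM8K-TTT", "GPQA-TTT"]
VOTING_FUNCTIONS = ["majority", "energy", "confidence"]
MATH_MODELS = {"Qwen/Qwen2.5-Math-1.5B", "Qwen/Qwen2.5-Math-7B"}


def build_command(dataset: str, model: str, voting_function: str) -> List[str]:
    """Build the argument list for one run of main.py (hypothetical helper)."""
    cmd = [
        "python", "main.py",
        "--dataset", dataset,
        "--model", model,
        "--voting_function", voting_function,
    ]
    # General-purpose (non-math) models get the extra flags from the note above.
    if model not in MATH_MODELS:
        cmd += ["--val_temp", "0.6", "--separate_ref_gpu"]
    return cmd


# One command per (dataset, voting function) pair for a single model.
commands = [
    build_command(dataset, "meta-llama/Llama-3.1-8B-Instruct", voting)
    for dataset, voting in itertools.product(DATASETS, VOTING_FUNCTIONS)
]
for cmd in commands:
    print(shlex.join(cmd))
```

Each list can be passed to `subprocess.run(cmd, check=True)` from the verl/ directory. Printing the commands first makes it easy to review the sweep before launching it.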
Run the following commands for single-sample and continuous adaptation, respectively.
python main.py \
--dataset AIME-TTT \
--model Qwen/Qwen2.5-Math-1.5B \
--voting_function majority \
--adaptation_method ttft \
--lr 0.0001

python main.py \
--dataset AIME-TTT \
--model Qwen/Qwen2.5-Math-1.5B \
--voting_function majority \
--adaptation_method ttft \
--continuous \
--lr 5e-5

To reproduce the baseline results, set the --adaptation_method flag.
Example (SLOT):
python main.py \
--dataset AIME-TTT \
--model Qwen/Qwen2.5-Math-1.5B \
--voting_function majority \
--adaptation_method slot
Example (MEMO):
python main.py \
--dataset AIME-TTT \
--model Qwen/Qwen2.5-Math-1.5B \
--voting_function majority \
--adaptation_method memo
To run the original TTRL method, use the following command:
cd examples/ttrl/<MODEL_NAME>
bash aime.sh
Here, <MODEL_NAME> can be replaced with the desired model, such as Math-1.5B or Math-7B. Each model folder also contains scripts for the other datasets.