Code for the MiST chemical reasoning experiments, including task definitions, GRPO training recipes, evaluation utilities, and cluster launchers used in the paper.
The reinforcement learning stack in this repository builds on Open-R1, with MiST-specific task implementations, recipes, and reproducibility tooling layered on top.
- Python 3.10 or newer
- Linux is the primary target environment for training and evaluation
- Shell utilities compatible with `bash`
- Python dependencies listed in `setup.py` and `dev-requirements.txt`
The repository targets Python >=3.10.9 in packaging metadata. Cluster
launchers are intended for Linux HPC environments such as SwissAI / CSCS and
Kuma. For the lightweight demo included with this repository, a standard CPU
workstation is sufficient.
The SCS diagnostic requires `vllm` and a model checkpoint that vLLM can load. For
the full 10k-row benchmark run, use at least one CUDA-capable GPU, and
preferably a multi-GPU Linux environment.
- Full training workflows were developed for Linux HPC environments on SwissAI / CSCS and Kuma.
- The lightweight demo is designed to run in a standard local Python environment with the dependencies listed in `dev-requirements.txt`.
Create a Python environment and install the dependencies required for the lightweight demo:
```bash
python3 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -r dev-requirements.txt
pip install -e . --no-deps
```

For full GRPO training and large-scale evaluation, use the cluster-specific setup described below and install the complete training dependencies defined in `setup.py`.
For the SCS diagnostic or any workflow that relies on `vllm`, install the full
project dependencies instead of the lightweight demo-only environment:
```bash
pip install -e .
```

The datasets and model checkpoints are released on Figshare (v3, ~7.3 GB total).
To download and set up everything automatically:
```bash
# Download datasets only (~2.2 GB)
python scripts/setup_data.py --data-dir ./data --skip-models

# Or full setup including model checkpoints (~7.3 GB)
python scripts/setup_data.py --data-dir ./data
```

This downloads from Figshare, verifies MD5 checksums, extracts into the directory layout expected by the training recipes, and writes a `.env.local` file. Then:

```bash
source .env.local && export MIST_DATA_DIR
```

> [!IMPORTANT]
> Cluster-specific paths should live in `.env.local`, not in committed files.
See `cluster/README.md` and the example env files for CSCS and Kuma.
Run one of the setup scripts if you want a starter .env.local for a supported
cluster:
```bash
python3 CSCS_setup.py
# or
python3 kuma_setup.py
```

For a manual setup, define at least:

```bash
export MIST_MODELS_DIR=/path/to/models
export MIST_DATA_DIR=/path/to/data
export MIST_CACHE_DIR=/path/to/cache
```

The recipes and model registry expand these variables automatically.
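For reference, `${VAR}`-style expansion can be reproduced with Python's `os.path.expandvars`; this is an illustrative sketch, not the repository's actual loader code:

```python
import os

def expand_recipe_path(value: str) -> str:
    """Expand ${VAR}-style references using the current environment.

    Approximates what a recipe loader does with MIST_DATA_DIR /
    MIST_MODELS_DIR / MIST_CACHE_DIR; the real implementation may differ.
    """
    return os.path.expandvars(value)

os.environ["MIST_DATA_DIR"] = "/scratch/mist/data"
print(expand_recipe_path("${MIST_DATA_DIR}/my_dataset.csv"))
# -> /scratch/mist/data/my_dataset.csv
```

Unset variables are left untouched by `os.path.expandvars`, so a missing `MIST_DATA_DIR` shows up as a literal `${MIST_DATA_DIR}` in paths rather than an error.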
Supported cluster launchers:
- `launch_CSCS.slurm` for SwissAI / CSCS
- `launch_kuma.slurm` for Kuma
`[MODEL]` is any model specified in `model_paths.txt` (for example
`Qwen2.5-3B`) and `[TASK]` is the recipe short name under `recipes/` (without
the `.yaml` suffix).
```bash
sbatch launch_CSCS.slurm [MODEL] [TASK]
# Example: train Qwen2.5-3B as specified in recipes/rxnpred.yaml
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred
```

A third, optional parameter is `[RESUME_JOB_ID]`: the job ID of a previous run whose training you want to continue. To start from scratch (without a previous checkpoint), set it to `0`, or simply omit it (the default value is `0`).
```bash
sbatch launch_CSCS.slurm [MODEL] [TASK] [RESUME_JOB_ID]
# Example: train Qwen2.5-3B as specified in recipes/rxnpred.yaml, continuing from job ID 123456
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 123456
# Example: launch from scratch
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0
```

A fourth optional parameter is `[TASK_MODE]`. It lets you select a specific mode of a task directly from the launch SLURM script, so the same recipe file and `Task` class can be run with small variations (without writing multiple subclasses of the same class), for example:
- If you would like to process the dataset in a different manner.
- If you would like to apply different chat templates / prompt templates.
- If you would like to compute the rewards in a different manner.
- Etc. (the mode can drive any behavior you choose)
If provided, `[TASK_MODE]` is passed to the task as `self.task_mode`. This makes it possible to run the same recipe file under multiple task modes without writing dozens of near-identical recipes and task classes.
Notes:
- To use `[TASK_MODE]`, you must pass it as the fourth parameter, which means the third parameter `[RESUME_JOB_ID]` must be specified as well (set it to `0` if you don't need it).
- You can omit the fourth parameter entirely; the default value for `[TASK_MODE]` is `"base"`.
- The `task_mode` parameter should never be specified in a recipe file.
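As an illustration of how a task might branch on the mode, here is a standalone, hypothetical sketch (`SampleTask` and its templates are invented for this example; real tasks inherit from `RLTask`):

```python
# Hypothetical illustration of task-mode branching (not code from this repository).
class SampleTask:
    def __init__(self, task_mode: str = "base"):
        # Filled from the fourth SLURM launch parameter; defaults to "base".
        self.task_mode = task_mode

    def question_template(self) -> str:
        # Same task class, slightly different prompt per mode.
        if self.task_mode == "base":
            return "What is the product of: {}"
        if self.task_mode == "verbose":
            return "Think step by step. What is the product of: {}"
        raise ValueError(f"Unknown task mode: {self.task_mode}")

print(SampleTask("verbose").question_template())
```

The same pattern applies to dataset preprocessing or reward computation: check `self.task_mode` inside the relevant method instead of subclassing.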
```bash
sbatch launch_CSCS.slurm [MODEL] [TASK] [RESUME_JOB_ID] [TASK_MODE]
# Example: train Qwen2.5-3B as specified in recipes/rxnpred.yaml, from scratch with task mode "base"
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0 base
```

A fifth optional parameter is `[SAMPLING_PARAMS_CONFIG_NAME]`. It selects a specific sampling-parameters configuration file for the experiment (overriding the default sampling parameters).
- See the "Sampling parameters" section of the documentation for more information.
- This parameter is the name of the sampling-parameters configuration file (without the `.json` suffix). These files live in the folder `sampling_params/`.
```bash
sbatch launch_CSCS.slurm [MODEL] [TASK] [RESUME_JOB_ID] [TASK_MODE] [SAMPLING_PARAMS_CONFIG_NAME]
# Example: train Qwen2.5-3B_pretrained-v4-cot as specified in recipes/rxnpred.yaml, from scratch
# with task mode "base", using the sampling parameters in sampling_params/pretrained_models_v1.json
sbatch launch_CSCS.slurm Qwen2.5-3B_pretrained-v4-cot rxnpred 0 base pretrained_models_v1
```

The default values are:

- `[RESUME_JOB_ID]` = `0` (start from scratch)
- `[TASK_MODE]` = `"base"` (base task mode)
- `[SAMPLING_PARAMS_CONFIG_NAME]` = `"default"` (default sampling parameters)
The four following commands are therefore equivalent:

```bash
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0 base
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0 base default
```

As a final note, remember that parameter order matters: to use a later parameter, you must also specify all of the preceding parameters (using the default values listed above where needed).
The documentation is built using Sphinx. To build and view the documentation locally:
```bash
cd docs
make html
python -m http.server -d build/html
```

Then open http://localhost:8000 in your browser.
This repository includes lightweight 50-row fixtures under demo/ for all
implemented tasks, enabling smoke testing without downloading full datasets.
- `demo/rxnpred_tiny/` — reaction prediction (40 train + 10 test)
- `demo/datasets/*.csv` — 50-row CSV fixtures for I2S, canonicalization, permutation, hydrogen, reaction naming, inversion, replacement, and true/false tasks
- `demo/kinetic_tiny/` — synthetic kinetic data (40 train + 10 val)
- `demo/crystalrelax_tiny/` — M2S crystal structures (40 train + 10 test)
- `demo/condmatgen_tiny/` — element lists for material generation (50 entries)
Run the demo from the repository root:
```bash
PYTHONPATH=src python demo/run_demo.py
```

To smoke-test the dataset loaders across all tasks:

```bash
PYTHONPATH=src python demo/run_fixture_smoke.py
```

This exercises fixtures for all tasks: `rxnpred`, `iupacsm`,
`iupacsm_with_tags`, `canonic`, `canonmc`, `smi_permute`, `smhydrogen`,
`kinetic`, `rxn_inversion`, `rxn_replacement`, `rxn_naming`, `rxn_truefalse`,
and conditionally `crystalrelax` and `condmatgen` (when their heavy dependencies
are installed).
See demo/fixture_manifest.csv for the complete task-to-fixture mapping.
For the Nature software checklist and the public release, use the release
tracking files under release/:
- `release/submission_package_checklist.md` tracks which checklist items are already covered by the repository and which assets must still be exported.
- `release/figshare_upload_manifest.csv` lists the planned GitHub and Figshare artifacts for code, datasets, manifests, and model snapshots.
- `release/dataset_components.csv` is the detailed working index for the datasets and derived components referenced in the manuscript.
The repository contains the task code, recipes, and launcher examples used for the MiST GRPO experiments. Full reproduction of the training results reported in the manuscript requires a Linux multi-GPU environment and the datasets described in the manuscript appendix.
The SMILES Competence Score (SCS) workflow is implemented in
`src/open_r1/diagnostic/smiles_competence.py`.
Before running SCS, install the full project dependencies so that vllm is
available:
```bash
pip install -e .
```

For a small reviewer smoke test on the bundled 50-row fixture:
```bash
PYTHONPATH=src python src/open_r1/diagnostic/smiles_competence.py \
    --model /path/to/model \
    --data-path demo/datasets/CRLLM-PubChem-compounds1M.sample.csv \
    --output-dir output/scs-smoke \
    --num-rows 50 \
    --tensor-parallel-size 1
```

This writes:

- `output/scs-smoke/lps_canonical.csv`
- `output/scs-smoke/lps_random.csv`
- `output/scs-smoke/lps_corrupt.csv`
- `output/scs-smoke/summary.json`
For the full 10k-example diagnostic run described in the manuscript:
```bash
PYTHONPATH=src python src/open_r1/diagnostic/smiles_competence.py \
    --model /path/to/model \
    --data-path "${MIST_DATA_DIR}/CRLLM-PubChem-compounds1M.csv" \
    --output-dir "${MIST_OUTPUT_DIR}/scs/full" \
    --num-rows 10000 \
    --tensor-parallel-size 1
```

For a multi-GPU cluster launch on CSCS-style infrastructure:

```bash
sbatch launch_diagnostics.slurm Qwen2.5-3B
```

The diagnostics launcher also accepts optional overrides:
```bash
sbatch launch_diagnostics.slurm Qwen2.5-3B \
    "${MIST_DATA_DIR}/CRLLM-PubChem-compounds1M.csv" \
    10000 \
    "${MIST_OUTPUT_DIR}/scs/qwen25-3b" \
    4
```

At a minimum, full reproduction requires:
- the released MiST code in this repository
- the released datasets and derived task splits listed in `release/`
- the model paths or MiST checkpoints referenced in `model_paths.txt`
- a cluster environment compatible with the provided CSCS or Kuma launchers
For release tracking, see:
- `release/release_scope.md`
- `release/dataset_components.csv`
- `release/reviewer_runs.md`
GRPO training is covered by this repository and the cluster launchers above.
The mid-training / pretraining pipeline is only partially represented here at
the moment: the final public release still needs the preprocessing scripts,
manifests, and split definitions tracked under release/.
For a minimal end-to-end GRPO smoke run on a single GPU, use the bundled 50-example reaction-prediction fixture and the smoke accelerate config:
```bash
accelerate launch --config_file configs/smoke_single_gpu.yaml \
    src/open_r1/run_r1_grpo.py \
    --config recipes/rxnpred.smoke.yaml \
    --model_name_or_path Qwen/Qwen2.5-3B \
    --output_dir output/rxnpred-smoke \
    --run_name rxnpred-smoke-qwen25-3b \
    --base_model_name Qwen/Qwen2.5-3B \
    --base_model_id Qwen/Qwen2.5-3B
```

This smoke path does not require a MiST checkpoint. It is intended to validate that the GRPO loop, task loading, reward wiring, and checkpoint output all work with any compatible base model that fits on the available GPU.
MiST is designed to be easily extensible with new chemistry tasks suitable
for reasoning. Each task inherits from the base RLTask class and implements
specific logic for data handling and reward calculation.
- Create a new file in `src/open_r1/tasks/` for your task, e.g. `sampletask.py`.
- Inherit from the base `RLTask` class and implement the required methods, e.g. `SampleTask(RLTask)`.
  - During GRPO training, the methods `load`, `dataset_preprocess`, and the different `*_reward` functions will be called.
- Add the class to `CHEMTASKS` in `src/open_r1/tasks/__init__.py`, e.g. `'sampletask': SampleTask`.
- Write a recipe with the same name as the task: `recipes/sampletask.yaml`.
  - The run will be logged on wandb under a project named `r1-[TASK]` (e.g. `r1-sampletask`), so runs using different recipe files are logged under different wandb projects.
  - If you add a dot in your recipe filename (e.g. `sampletask.variant1.yaml`), the run will also be logged under `r1-sampletask` (everything after the dot is ignored). This is useful if you want to run multiple experiments with different recipe files but keep them under the same wandb project for analysis.
- Add documentation:
  - Create an entry under `docs/source/tasks/sampletask.rst` (use the `template.rst`).
  - Add it to the modules index `docs/source/modules.rst` as `tasks/sampletask.rst`.
Here's a template for creating a new task:
```python
from open_r1.tasks.base import RLTask
from datasets import DatasetDict


class NewTask(RLTask):
    """
    Description of your new task.

    This task should [describe what the task does and its purpose].
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.question_template = "Your task-specific question format: {}"

    def load(self) -> DatasetDict:
        """
        Load and prepare the dataset for the task.

        Returns:
            DatasetDict: Dataset with 'train' and 'test' splits
        """
        # Implement dataset loading logic
        pass

    def accuracy_reward(self, completions, solution, **kwargs):
        """
        Calculate rewards for model completions.

        Args:
            completions (List[str]): Model generated responses
            solution (List[str]): Ground truth solutions

        Returns:
            List[float]: Rewards for each completion
        """
        # Implement reward calculation
        pass
```

Here is a list of the methods that you can implement in your task class:
- `load()` (mandatory): Loads the dataset (called during GRPO training). It usually uses the `dataset_id_or_path` defined in the recipe file (automatically parsed from the recipe and set as `self.dataset_id_or_path` in your task) and creates the class variable `self.dataset`.
  - Input: nothing
  - Output: nothing
- `dataset_preprocess()` (optional): Called after `load()` during GRPO training to preprocess the dataset. A default implementation exists in the base class, so you can omit it if you don't need custom preprocessing. To add custom preprocessing, override this method, taking inspiration from the base-class (`RLTask`) implementation.
  - Input: tokenizer (Hugging Face `Tokenizer`)
    - The tokenizer can be used, for example, to apply a chat template (`tokenizer.apply_chat_template()`).
  - Output: dataset (Hugging Face `Dataset`)
    - The dataset should contain two splits: `dataset["train"]` and `dataset["test"]`.
    - Each split should at least contain a column named `"prompt"`; you can add as many other columns as you need (they can then be used when computing the rewards).
- `*_reward()` (optional): You can implement as many reward functions as you want. The reward functions used during an experiment are listed in the recipe file.
  - Input: completions, **kwargs
    - During GRPO training, the reward functions are called automatically with the following parameters:
      - `completions`: list of strings containing the generated completions (without the prompts). With the standard format, a completion usually contains the thinking and the answer.
      - `**kwargs`: any additional column found in `self.dataset` is passed as a keyword argument (a list aligned with `completions`). For example, if you have a column named `"prompt"`, the parameter `prompt=...` is passed. This is useful for computing rewards and checking whether the expected solution appears in the completions.
  - Output: rewards (list of float reward values, with the same length as `completions`)
- `accuracy_reward()` (mandatory): Reward function used to evaluate accuracy (whether the answer equals the expected solution).
- `format_reward()` (optional): Predefined reward for correct formatting of the completions (`<think>` and `<answer>` tags correctly formatted).
- `reasoning_steps_reward()` (optional): Predefined reward for step-by-step thinking in the completions.
- `get_metrics()` (optional): Optional function used to log additional metrics to the wandb run. It is called automatically during GRPO training.
  - Input: nothing
  - Output: dictionary of metrics with the format `{key[str]: value[float]}`
    - Each metric is logged in wandb as `custom/[key]` with value `[value]`.
    - These metrics can, for example, be computed inside the reward functions and saved in a class variable of your choice; `get_metrics()` only needs to return them for logging.
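To make the reward calling convention concrete, here is a minimal standalone sketch of an accuracy-style reward (simplified and hypothetical: a real task implements this as a method and first extracts the answer from the completion, e.g. via `self.preprocess_response`):

```python
def accuracy_reward(completions, solution, **kwargs):
    """Return one float per completion.

    `solution` is a dataset column delivered as a list aligned with
    `completions`; any other dataset columns arrive in **kwargs the same way.
    """
    rewards = []
    for content, sol in zip(completions, solution):
        # Simplified check: real tasks extract and normalize the answer first.
        rewards.append(1.0 if sol in content else -1.0)
    return rewards

print(accuracy_reward(["... the answer is CCO"], ["CCO"]))  # [1.0]
```

The invariant to preserve in any `*_reward` function is that the returned list has exactly one float per completion, in the same order.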
The Forward Reaction task demonstrates how to implement a chemical reaction prediction task:
```python
from open_r1.tasks.base import RLTask
from rdkit import Chem


class ForwardReaction(RLTask):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.question_template = (
            f"What is the product of the following reaction? "
            f"Reactants: {self.begin_smiles_tag} {{}} {self.end_smiles_tag}"
        )

    def accuracy_reward(self, completions, solution, **kwargs):
        rewards = []
        for content, sol in zip(completions, solution):
            ans = self.preprocess_response(content)
            try:
                # Compare canonical SMILES so equivalent notations match
                if Chem.MolToSmiles(Chem.MolFromSmiles(ans)) == \
                        Chem.MolToSmiles(Chem.MolFromSmiles(sol)):
                    rewards.append(1)
                else:
                    rewards.append(-0.5)
            except Exception:
                # Unparsable SMILES gets the strongest penalty
                rewards.append(-1)
        return rewards
```

To specify a recipe for your task, copy from `recipes/template.yaml` and modify the section `# Chemical Task Arguments`:
```yaml
# recipes/my_task.yaml
chem_task: my_task
dataset_id_or_path: ${MIST_DATA_DIR}/my_dataset.csv  # can also be a HF dataset id
rewards:
  - accuracy
  - format
task_kwargs:
  my_kwarg1: my_value1
  my_kwarg2: my_value2
```

- `chem_task`: Name of the task class. Task names are defined as the keys of the `CHEMTASKS` variable in `src/open_r1/tasks/__init__.py`.
- `dataset_id_or_path`: Path to the dataset. This argument can be used anywhere in the task class, but is usually used in the `load` method.
- `rewards`: List of reward functions to use. The available reward functions depend on the task (you can implement as many as you want). Each reward function in the task class must end with `_reward`, and the suffix is omitted in the recipe file: to use your function `accuracy_reward()`, specify `accuracy` in the recipe.
- `task_kwargs`: Special dict-like argument that can contain any additional keyword arguments you would like to pass to your task.
  - This argument is optional and can be omitted from the recipe file if you don't need it.
- There are many other training parameters in the recipe file. You can keep the default values (as in `recipes/template.yaml`), but feel free to modify them if needed (though this could lead to unexpected results or crashes). If you just built your task, it is recommended to keep the default parameters, make sure the task works, and only then adjust the parameters to your needs.
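The name-to-function mapping implied by the `rewards` list can be pictured as follows (a hypothetical sketch using `getattr`; the repository's actual resolution code may differ):

```python
class DummyTask:
    """Stand-in task with two reward methods, as a recipe might reference."""

    def accuracy_reward(self, completions, **kwargs):
        return [1.0] * len(completions)

    def format_reward(self, completions, **kwargs):
        return [0.0] * len(completions)


def resolve_rewards(task, names):
    """Map recipe entries like 'accuracy' to bound methods like accuracy_reward."""
    return [getattr(task, f"{name}_reward") for name in names]


funcs = resolve_rewards(DummyTask(), ["accuracy", "format"])
print([f(["some completion"]) for f in funcs])  # [[1.0], [0.0]]
```

This is why the `_reward` suffix is mandatory on the method name and omitted in the recipe: the name in the YAML plus the suffix must resolve to an attribute of the task class.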
When creating a new task, ensure:
- **Base Class Inheritance**: Inherit from `RLTask`.
- **Required Methods**: Implement at minimum:
  - `load()`: dataset loading
  - `accuracy_reward()`: reward calculation
- **Documentation**:
  - Class docstring explaining the task
  - Method docstrings
  - Example usage
- **Testing**: Add tests for your task in `tests/`.

- Create a new RST file in `docs/source/tasks/` for your task, using the template under `tasks/template.rst`.
- Add your task to `docs/source/modules.rst`.
- Include examples and usage instructions.
- Build and verify the documentation:

```bash
cd docs; make clean; make html; python -m http.server 7000
```
- Forward Reaction (`rxnpred`): Chemical reaction product prediction
- IUPAC to SMILES (`iupacsm`): Convert IUPAC names to SMILES notation
- Canonicalize SMILES (`canonic`): SMILES canonicalization
- Canonicalize SMILES MCQA (`canonmc`): Multiple-choice SMILES canonicalization
- SMILES Permutation (`smi_permute`): Generate alternative SMILES for the same molecule
- SMILES Hydrogen (`smhydrogen`): Add/remove implicit hydrogens
- Kinetic Classification (`kinetic`): Kinetic reaction mechanism classification
- Reaction Inversion (`rxn_inversion`): MCQ — identify the correct reaction among inverted fakes
- Reaction Replacement (`rxn_replacement`): MCQ — identify the correct reaction among modified fakes
- Reaction Naming (`rxn_naming`): Classify reactions into 10 named categories
- Reaction True/False (`rxn_truefalse`): Binary reaction validity classification
- Conditional Material Generation (`condmatgen`): Generate novel crystal compositions from element sets
- Crystal Structure Relaxation (`crystalrelax`): Relax perturbed binary crystal structures
For detailed examples and API reference, see the documentation.
The list of models can be found in the file `model_paths.txt`.

- The models are stored in the folder `LLM_models/`.
- To add a new model, add it to `model_paths.txt` (on a new line with the format `[model_id]: [path]`) and place the model in the appropriate folder (`LLM_models/`).
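A `[model_id]: [path]` line can be parsed as in this sketch (hypothetical helper, not the repository's own parsing code):

```python
def parse_model_paths(text: str) -> dict:
    """Parse 'model_id: path' lines into a registry dict.

    Blank lines and '#' comments are skipped; only the first ':' splits,
    so paths containing ':' survive intact.
    """
    registry = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        model_id, _, path = line.partition(":")
        registry[model_id.strip()] = path.strip()
    return registry

sample = "Qwen2.5-3B: LLM_models/Qwen2.5-3B\n"
print(parse_model_paths(sample))  # {'Qwen2.5-3B': 'LLM_models/Qwen2.5-3B'}
```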
Multiple custom models were pretrained:
- `Qwen2.5-3B_pretrained-v1` (original name: `qwen_3b_pretrained`)
- `Qwen2.5-3B_pretrained-v1_cot-v1` (original name: `qwen-cot`)
- `Qwen2.5-3B_pretrained-v2` (original name: `qwen_3b_sft`)
- `Qwen2.5-3B_pretrained-v3` (original name: `qwen_sft_full`)
- `Qwen2.5-3B_pretrained-v4-cot` (original name: `qwen_cot_v3`)
It is possible to modify the sampling parameters used during training by adding configurations in the folder `sampling_params/`:

- `sampling_params/model_default_sampling_params.txt`: contains the default sampling-parameters configuration for each model, one line per model in the format `[model_id]: [sampling_params_config_name]`. It is optional: if a model is not listed, the default sampling parameters are used.
  - /!\ The default sampling parameters should only be set when adding a new model and never modified afterwards (for experiment-tracking purposes).
  - The defaults have been changed only for custom pretrained models that specifically need a different configuration (these models can't be used with the default configuration).
- `sampling_params/*.json`: the sampling-parameters configuration files.
  - /!\ The `default` configuration is a reserved keyword: you can't create a file called `sampling_params/default.json`. Use a different name; the default configuration should never be modified.
  - You can create a new configuration file if you want to experiment with different sampling parameters. In that case, do not modify `sampling_params/model_default_sampling_params.txt`, which contains the default configurations only. Pass the name of your configuration (without the `.json` suffix) when launching the SLURM job (see the launcher examples at the start of this documentation).
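The lookup behavior described above could be sketched as follows (hypothetical code, not the repository's loader; note that `'default'` deliberately maps to no file on disk):

```python
import json
from pathlib import Path

def load_sampling_params(name: str, params_dir: str = "sampling_params") -> dict:
    """Resolve a named sampling-parameters configuration.

    'default' is reserved: it means "use the built-in defaults" and must
    never correspond to a default.json file on disk.
    """
    if name == "default":
        return {}  # caller falls back to the built-in defaults
    path = Path(params_dir) / f"{name}.json"
    return json.loads(path.read_text())

print(load_sampling_params("default"))  # {}
```

Any other name (e.g. `pretrained_models_v1`, as in the launcher example earlier) is read from the matching `.json` file in `sampling_params/`.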