
schwallergroup/mist


MiST

Code for the MiST chemical reasoning experiments, including task definitions, GRPO training recipes, evaluation utilities, and cluster launchers used in the paper.

The reinforcement learning stack in this repository builds on Open-R1, with MiST-specific task implementations, recipes, and reproducibility tooling layered on top.

System Requirements

  • Python 3.10.9 or newer
  • Linux is the primary target environment for training and evaluation
  • Shell utilities compatible with bash
  • Python dependencies listed in setup.py and dev-requirements.txt

The repository targets Python >=3.10.9 in packaging metadata. Cluster launchers are intended for Linux HPC environments such as SwissAI / CSCS and Kuma. For the lightweight demo included with this repository, a standard CPU workstation is sufficient.

The SCS diagnostic requires vllm and a model that vLLM can load. For the full 10k-row benchmark run, use at least one CUDA-capable GPU, preferably in a multi-GPU Linux environment.

Tested Environments

  • Full training workflows were developed for Linux HPC environments on SwissAI / CSCS and Kuma.
  • The lightweight demo is designed to run from a standard local Python environment with the dependencies listed in dev-requirements.txt.

Installation

Create a Python environment and install the dependencies required for the lightweight demo:

python3 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -r dev-requirements.txt
pip install -e . --no-deps

For full GRPO training and large-scale evaluation, use the cluster-specific setup described below and install the complete training dependencies defined in setup.py.

For the SCS diagnostic or any workflow that relies on vllm, install the full project dependencies instead of the lightweight demo-only environment:

pip install -e .

Quick Start — Download Data and Run

The datasets and model checkpoints are released on Figshare (v3, ~7.3 GB total).

To download and set up everything automatically:

# Download datasets only (~2.2 GB)
python scripts/setup_data.py --data-dir ./data --skip-models

# Or full setup including model checkpoints (~7.3 GB)
python scripts/setup_data.py --data-dir ./data

This downloads from Figshare, verifies MD5 checksums, extracts into the directory layout expected by the training recipes, and writes a .env.local file. Then:

source .env.local && export MIST_DATA_DIR
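You may need to re-check a downloaded archive by hand. In that case, the MD5 comparison that setup_data.py performs can be approximated with Python's hashlib. This is a minimal sketch, not the script's own code. The helper names md5_of_file and verify_md5 are hypothetical, and the expected digest must come from the Figshare record for that file.

```python
import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    """Compute the MD5 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        # Read in chunks so multi-GB archives never have to fit in memory.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_md5(path, expected_md5):
    """Return True if the file's digest matches the expected hex string."""
    # Normalize case and stray whitespace from copy-pasted checksums.
    return md5_of_file(path) == expected_md5.strip().lower()
```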

Setup

Important

Cluster-specific paths should live in .env.local, not in committed files. See cluster/README.md and the example env files for CSCS and Kuma.

Run one of the setup scripts if you want a starter .env.local for a supported cluster:

python3 CSCS_setup.py
# or
python3 kuma_setup.py

For a manual setup, define at least:

export MIST_MODELS_DIR=/path/to/models
export MIST_DATA_DIR=/path/to/data
export MIST_CACHE_DIR=/path/to/cache

The recipes and model registry expand these variables automatically.
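This README does not document how the recipes and registry perform the expansion. As an assumption, the standard library's os.path.expandvars reproduces the same effect. That function leaves references to unset variables untouched, so the hypothetical helper below turns that case into an explicit error. A forgotten `source .env.local` is then caught early.

```python
import os


def expand_recipe_path(value):
    """Expand ${VAR}/$VAR references and a leading ~ in a recipe value.

    Raises ValueError if a ${...} reference survives expansion, which usually
    means the variable was not exported (e.g. .env.local was not sourced).
    """
    # expandvars leaves unknown variables unchanged instead of failing.
    expanded = os.path.expanduser(os.path.expandvars(value))
    if "${" in expanded:
        raise ValueError(f"Unset environment variable in {value!r}")
    return expanded
```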

Supported cluster launchers:

  • launch_CSCS.slurm for SwissAI / CSCS
  • launch_kuma.slurm for Kuma

[MODEL] is any model listed in model_paths.txt (for example Qwen2.5-3B), and [TASK] is the recipe short name under recipes/ (without the .yaml suffix).

sbatch launch_CSCS.slurm [MODEL] [TASK]

# Example: launch a job for training Qwen2.5-3B as specified in recipes/rxnpred.yaml
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred

A third optional parameter, [RESUME_JOB_ID], continues training from a previous job. Set it to the job ID of the run you want to resume. To start from scratch (without a previous checkpoint), set it to 0 or simply omit it, since 0 is the default.

sbatch launch_CSCS.slurm [MODEL] [TASK] [RESUME_JOB_ID]

# Example: launch a job for training Qwen2.5-3B as specified in recipes/rxnpred.yaml, continuing from job ID 123456
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 123456

# Example: launch from scratch
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0

A fourth optional parameter, [TASK_MODE], selects a task mode directly from the SLURM launcher. It lets you run the same recipe file and the same Task class with small variations, without writing multiple subclasses of that class. For example, you can use it to:

  • If you would like to process the dataset in a different manner.
  • If you would like to apply different chat templates / prompt templates.
  • If you would like to compute the rewards in a different manner.
  • Apply any other variation your task needs.

If set, [TASK_MODE] is passed to the task as self.task_mode. A single recipe and task class can then serve several modes, instead of dozens of near-identical recipes and task classes.

Notes:

  • To use [TASK_MODE], pass it as the fourth parameter. You must then also give the third parameter, [RESUME_JOB_ID], even if you are not resuming a run; set it to 0 in that case.
  • The fourth parameter can be omitted entirely. Its default value is "base".
  • task_mode should never be specified in a recipe file.

sbatch launch_CSCS.slurm [MODEL] [TASK] [RESUME_JOB_ID] [TASK_MODE]

# Example: launch a job for training Qwen2.5-3B as specified in recipes/rxnpred.yaml, running from scratch with task mode "base"
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0 base
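The sketch below shows how a task can branch on self.task_mode. ExampleModeTask and the "short" mode are hypothetical. The class does not inherit from RLTask, so it runs without MiST installed. For the same reason, task_mode is passed to the constructor directly instead of coming from the launcher.

```python
class ExampleModeTask:
    """Hypothetical stand-in for an RLTask subclass that branches on task_mode."""

    # One prompt template per mode; "short" is an illustrative extra mode.
    TEMPLATES = {
        "base": "What is the product of the following reaction? {}",
        "short": "Product of: {}",
    }

    def __init__(self, task_mode="base"):
        if task_mode not in self.TEMPLATES:
            raise ValueError(f"Unknown task_mode: {task_mode!r}")
        self.task_mode = task_mode
        self.question_template = self.TEMPLATES[task_mode]

    def build_prompt(self, reactants):
        return self.question_template.format(reactants)

    def accuracy_reward(self, completions, solution, **kwargs):
        # Same class, different scoring per mode: "base" requires an exact
        # match, other modes ignore surrounding whitespace.
        if self.task_mode == "base":
            return [1.0 if c == s else 0.0 for c, s in zip(completions, solution)]
        return [1.0 if c.strip() == s.strip() else 0.0 for c, s in zip(completions, solution)]
```

In a real task, the same check on self.task_mode could also be placed in load() or dataset_preprocess() to change how the data is processed.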

A fifth optional parameter, [SAMPLING_PARAMS_CONFIG_NAME], selects a specific sampling-parameters configuration file for an experiment. It overrides the default sampling parameters.

  • See the "Sampling parameters" section for more information.
  • The value is the name of a configuration file in sampling_params/, without the .json suffix.

sbatch launch_CSCS.slurm [MODEL] [TASK] [RESUME_JOB_ID] [TASK_MODE] [SAMPLING_PARAMS_CONFIG_NAME]

# Example: launch a job for training Qwen2.5-3B_pretrained-v4-cot as specified in recipes/rxnpred.yaml, running from scratch with task mode "base" and using the sampling parameters specified in sampling_params/pretrained_models_v1.json
sbatch launch_CSCS.slurm Qwen2.5-3B_pretrained-v4-cot rxnpred 0 base pretrained_models_v1

The default values are:

  • [RESUME_JOB_ID] = 0 (start from scratch)
  • [TASK_MODE] = "base" (base task mode)
  • [SAMPLING_PARAMS_CONFIG_NAME] = "default" (default sampling parameters)

Therefore, the following four commands are equivalent:

sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0 base
sbatch launch_CSCS.slurm Qwen2.5-3B rxnpred 0 base default

Note that parameter order matters. To set a later parameter, you must also specify all the parameters before it, using the default values listed above where needed.
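The defaulting rule amounts to padding any missing trailing arguments with their default values. The Python sketch below models that rule for illustration only. It is not the logic inside launch_CSCS.slurm, and resolve_launch_args is a hypothetical name.

```python
# Defaults for RESUME_JOB_ID, TASK_MODE and SAMPLING_PARAMS_CONFIG_NAME.
DEFAULTS = ("0", "base", "default")
NAMES = ("model", "task", "resume_job_id", "task_mode", "sampling_params_config_name")


def resolve_launch_args(args):
    """Map positional launcher arguments to names, filling trailing defaults.

    MODEL and TASK are required; optional parameters can only be omitted
    from the end of the list, which is why their order matters.
    """
    if not 2 <= len(args) <= len(NAMES):
        raise ValueError(
            "expected MODEL TASK [RESUME_JOB_ID] [TASK_MODE] [SAMPLING_PARAMS_CONFIG_NAME]"
        )
    # Append only the defaults for the parameters that were not given.
    values = list(args) + list(DEFAULTS[len(args) - 2:])
    return dict(zip(NAMES, values))
```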

📖 Documentation

The documentation is built using Sphinx. To build and view the documentation locally:

cd docs
make html
python -m http.server -d build/html

Then open http://localhost:8000 in your browser.

Demo

This repository includes lightweight 50-row fixtures under demo/ for all implemented tasks, enabling smoke testing without downloading full datasets.

  • demo/rxnpred_tiny/ — reaction prediction (40 train + 10 test)
  • demo/datasets/*.csv — 50-row CSV fixtures for I2S, canonicalization, permutation, hydrogen, reaction naming, inversion, replacement, and true/false tasks
  • demo/kinetic_tiny/ — synthetic kinetic data (40 train + 10 val)
  • demo/crystalrelax_tiny/ — M2S crystal structures (40 train + 10 test)
  • demo/condmatgen_tiny/ — element lists for material generation (50 entries)

Run the demo from the repository root:

PYTHONPATH=src python demo/run_demo.py

To smoke-test dataset loaders across all tasks:

PYTHONPATH=src python demo/run_fixture_smoke.py

This exercises the fixtures for all tasks: rxnpred, iupacsm, iupacsm_with_tags, canonic, canonmc, smi_permute, smhydrogen, kinetic, rxn_inversion, rxn_replacement, rxn_naming and rxn_truefalse. It also covers crystalrelax and condmatgen when their heavy dependencies are installed.

See demo/fixture_manifest.csv for the complete task-to-fixture mapping.
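One way to handle the conditional tasks (crystalrelax and condmatgen) is to probe for optional imports before a task is exercised. The sketch below does this with importlib.util.find_spec. It is an assumption about the approach, not a description of run_fixture_smoke.py, and the function name and dependency lists are hypothetical.

```python
import importlib.util


def tasks_to_smoke_test(core_tasks, optional_tasks):
    """Return the task list, keeping optional tasks only if their deps exist.

    optional_tasks maps a task name to the top-level module names it needs.
    find_spec returns None for a missing top-level module without importing it.
    """
    selected = list(core_tasks)
    for task, modules in optional_tasks.items():
        if all(importlib.util.find_spec(name) is not None for name in modules):
            selected.append(task)
    return selected
```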

Submission and Release Inventory

For the Nature software checklist and the public release, use the release tracking files under release/:

  • release/submission_package_checklist.md tracks which checklist items are already covered by the repository and which assets must still be exported.
  • release/figshare_upload_manifest.csv lists the planned GitHub and Figshare artifacts for code, datasets, manifests, and model snapshots.
  • release/dataset_components.csv is the detailed working index for the datasets and derived components referenced in the manuscript.

Reproducing Manuscript Results

The repository contains the task code, recipes, and launcher examples used for the MiST GRPO experiments. Full reproduction of the training results reported in the manuscript requires a Linux multi-GPU environment and the datasets described in the manuscript appendix.

SCS Diagnostic

The key SMILES Competence Score (SCS) workflow is implemented in src/open_r1/diagnostic/smiles_competence.py.

Before running SCS, install the full project dependencies so that vllm is available:

pip install -e .

For a small reviewer smoke test on the bundled 50-row fixture:

PYTHONPATH=src python src/open_r1/diagnostic/smiles_competence.py \
  --model /path/to/model \
  --data-path demo/datasets/CRLLM-PubChem-compounds1M.sample.csv \
  --output-dir output/scs-smoke \
  --num-rows 50 \
  --tensor-parallel-size 1

This writes:

  • output/scs-smoke/lps_canonical.csv
  • output/scs-smoke/lps_random.csv
  • output/scs-smoke/lps_corrupt.csv
  • output/scs-smoke/summary.json

For the full 10k-example diagnostic run described in the manuscript:

PYTHONPATH=src python src/open_r1/diagnostic/smiles_competence.py \
  --model /path/to/model \
  --data-path "${MIST_DATA_DIR}/CRLLM-PubChem-compounds1M.csv" \
  --output-dir "${MIST_OUTPUT_DIR}/scs/full" \
  --num-rows 10000 \
  --tensor-parallel-size 1

For a multi-GPU cluster launch on CSCS-style infrastructure:

sbatch launch_diagnostics.slurm Qwen2.5-3B

The diagnostics launcher also accepts optional overrides:

sbatch launch_diagnostics.slurm Qwen2.5-3B \
  "${MIST_DATA_DIR}/CRLLM-PubChem-compounds1M.csv" \
  10000 \
  "${MIST_OUTPUT_DIR}/scs/qwen25-3b" \
  4

At a minimum, full reproduction requires:

  • the released MiST code in this repository
  • the released datasets and derived task splits listed in release/
  • the model paths or MiST checkpoints referenced in model_paths.txt
  • a cluster environment compatible with the provided CSCS or Kuma launchers

For release tracking, see:

  • release/release_scope.md
  • release/dataset_components.csv
  • release/reviewer_runs.md

Training Scope

GRPO training is covered by this repository and the cluster launchers above. The mid-training / pretraining pipeline is only partially represented here at the moment: the final public release still needs the preprocessing scripts, manifests, and split definitions tracked under release/.

Single-GPU RL Smoke Run

For a minimal end-to-end GRPO smoke run on a single GPU, use the bundled 50-example reaction-prediction fixture and the smoke accelerate config:

accelerate launch --config_file configs/smoke_single_gpu.yaml \
  src/open_r1/run_r1_grpo.py \
  --config recipes/rxnpred.smoke.yaml \
  --model_name_or_path Qwen/Qwen2.5-3B \
  --output_dir output/rxnpred-smoke \
  --run_name rxnpred-smoke-qwen25-3b \
  --base_model_name Qwen/Qwen2.5-3B \
  --base_model_id Qwen/Qwen2.5-3B

This smoke path does not require a MiST checkpoint. It is intended to validate that the GRPO loop, task loading, reward wiring, and checkpoint output all work with any compatible base model that fits on the available GPU.

Contributing New Tasks

MiST is designed to be easily extensible with new chemistry tasks suitable for reasoning. Each task inherits from the base RLTask class and implements specific logic for data handling and reward calculation.

Creating a New Task

  1. Create a new file in src/open_r1/tasks/ for your task, e.g. sampletask.py
  2. Inherit from the base RLTask class and implement required methods, e.g. SampleTask(RLTask)
    • The GRPO training script calls the methods load and dataset_preprocess, as well as the task's reward functions (*_reward).
  3. Add class to CHEMTASKS in src/open_r1/tasks/__init__.py, e.g. 'sampletask': SampleTask
  4. Write a recipe with the same name as the task, e.g. recipes/sampletask.yaml
    • The run will be logged on wandb under the project named r1-[TASK] (e.g. r1-sampletask). Therefore, runs using different recipe files will be logged under different wandb projects.
    • If you add a dot in your recipe filename (e.g. sampletask.variant1.yaml), the run will also be logged under r1-sampletask (everything after the dot will be ignored). This is useful if you want to run multiple experiments with different recipe files but keep them under the same wandb project for analysis.
  5. Add documentation:
    • Create an entry under docs/source/tasks/sampletask.rst (use the template.rst)
    • Add it to the modules index: docs/source/modules.rst as tasks/sampletask.rst
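The wandb naming rule from step 4 can be written as a small function. This sketches the documented behavior only. wandb_project_for_recipe is a hypothetical helper, not the code used by the training script.

```python
from pathlib import PurePath


def wandb_project_for_recipe(recipe_path):
    """Derive the wandb project name r1-[TASK] from a recipe filename.

    Everything after the first dot of the file name is dropped, so
    sampletask.yaml and sampletask.variant1.yaml share one project.
    """
    task = PurePath(recipe_path).name.split(".", 1)[0]
    return f"r1-{task}"
```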

Here's a template for creating a new task:

from open_r1.tasks.base import RLTask
from datasets import DatasetDict

class NewTask(RLTask):
    """
    Description of your new task.
    
    This task should [describe what the task does and its purpose].
    """
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.question_template = "Your task-specific question format: {}"
        
    def load(self) -> DatasetDict:
        """
        Load and prepare the dataset for the task.
        
        Returns:
            DatasetDict: Dataset with 'train' and 'test' splits
        """
        # Implement dataset loading logic
        raise NotImplementedError
        
    def accuracy_reward(self, completions, solution, **kwargs):
        """
        Calculate rewards for model completions.
        
        Args:
            completions (List[str]): Model generated responses
            solution (List[str]): Ground truth solutions
            
        Returns:
            List[float]: Rewards for each completion
        """
        # Implement reward calculation
        raise NotImplementedError

Additional information to create a task class

Here is a list of the methods that you can implement in your task class:

  1. load() (mandatory): Loads the dataset and is called during GRPO training. It usually reads the dataset_id_or_path defined in the recipe file (parsed automatically and exposed as self.dataset_id_or_path in your task). It then sets the instance attribute self.dataset.
    • Input: nothing
    • Output: nothing
  2. dataset_preprocess() (optional): Called after load() during GRPO training to preprocess the dataset. The base class provides a default implementation, so you can omit this method if you do not need custom preprocessing. To customize it, override the method and use the implementation in the base class (RLTask) as a starting point.
    • Input: tokenizer (huggingface Tokenizer)
      • The tokenizer can be used in the method to apply a chat template for example (tokenizer.apply_chat_template()).
    • Output: dataset (huggingface Dataset)
      • This dataset should contain two splits: dataset["train"] and dataset["test"].
      • Each split must contain at least a column named "prompt". You can add as many other columns as you need; they are then available to the reward functions.
  3. *_reward() (optional): You can implement as many reward functions as you want. The list of reward functions used during an experiment should be specified in the recipe file.
    • Input: completions, **kwargs
      • During the GRPO training, the reward functions will be called automatically with the following parameters:
        • completions: list of strings containing the generated text completions (without the prompts). If you follow the standard format, each completion usually contains the thinking and the answer.
        • **kwargs: every additional column in self.dataset is passed as a keyword argument, as a list aligned with completions. For example, a column named "prompt" is passed as prompt=[...]. This is how reward functions access the expected solution and check whether it appears in each completion.
    • Output: rewards (list of reward float values with the same length as completions)
  4. accuracy_reward() (mandatory): Reward function that evaluates the accuracy of each completion, i.e. whether the answer matches the expected solution.
  5. format_reward() (optional): Predefined reward for correctly formatted completions, i.e. the thinking and answer tags are present and well-formed.
  6. reasoning_steps_reward() (optional): Predefined reward used to reward a step-by-step thinking in the completions.
  7. get_metrics() (optional): Logs additional metrics to the wandb run. It is called automatically during GRPO training.
    • Input: nothing
    • Output: dictionary of metrics with the format {key[str]: value[float]}
      • Each metric is logged in wandb as custom/[key] with the value [value].
      • These metrics can, for example, be computed inside the reward functions and stored in an attribute of your choice. get_metrics() only needs to return them.
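The contract described above can be illustrated with a self-contained toy. Reward functions receive completions plus every dataset column as aligned lists, and get_metrics() returns a flat dict. ToyExactMatchTask and the batch_accuracy metric are hypothetical. The class deliberately does not inherit from RLTask so it runs on its own, and it compares whole completions rather than first extracting an answer.

```python
class ToyExactMatchTask:
    """Hypothetical task-like class illustrating the reward/metrics contract."""

    def __init__(self):
        # Updated by accuracy_reward() and read back by get_metrics().
        self._last_accuracy = 0.0

    def accuracy_reward(self, completions, solution, **kwargs):
        # `solution` is a dataset column, so it arrives as a list aligned with
        # `completions`; other columns (e.g. "prompt") land in **kwargs.
        rewards = [1.0 if c.strip() == s else 0.0 for c, s in zip(completions, solution)]
        self._last_accuracy = sum(rewards) / len(rewards) if rewards else 0.0
        return rewards

    def get_metrics(self):
        # Each key would be logged in wandb as custom/[key].
        return {"batch_accuracy": self._last_accuracy}


# Usage: play the trainer's role by passing the dataset columns as kwargs.
task = ToyExactMatchTask()
batch = {"prompt": ["Q1", "Q2"], "solution": ["CCO", "CC"]}
rewards = task.accuracy_reward(["CCO ", "C"], **batch)  # → [1.0, 0.0]
metrics = task.get_metrics()  # → {"batch_accuracy": 0.5}
```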

Example: Forward Reaction Task

The Forward Reaction task demonstrates how to implement a chemical reaction prediction task:

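from rdkit import Chem

from open_r1.tasks.base import RLTask


class ForwardReaction(RLTask):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The doubled braces survive the f-string as a literal "{}" placeholder,
        # which is later filled with the reactant SMILES.
        self.question_template = (
            "What is the product of the following reaction? "
            f"Reactants: {self.begin_smiles_tag} {{}} {self.end_smiles_tag}"
        )

    def accuracy_reward(self, completions, solution, **kwargs):
        """Score each completion by comparing canonical SMILES with the solution."""
        rewards = []
        for content, sol in zip(completions, solution):
            ans = self.preprocess_response(content)
            try:
                # Chem.MolFromSmiles returns None for invalid SMILES, and
                # Chem.MolToSmiles(None) raises, so unparsable input lands in except.
                if Chem.MolToSmiles(Chem.MolFromSmiles(ans)) == \
                   Chem.MolToSmiles(Chem.MolFromSmiles(sol)):
                    rewards.append(1.0)   # same molecule as the reference product
                else:
                    rewards.append(-0.5)  # valid SMILES, but the wrong product
            except Exception:
                rewards.append(-1.0)      # answer (or reference) could not be parsed
        return rewards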

Recipes

To specify a recipe for your task, copy recipes/template.yaml and modify the section # Chemical Task Arguments:

# recipes/my_task.yaml
chem_task: my_task
dataset_id_or_path: ${MIST_DATA_DIR}/my_dataset.csv  # can also be a HF dataset id
rewards:
- accuracy
- format
task_kwargs:
  my_kwarg1: my_value1
  my_kwarg2: my_value2

  1. chem_task: Name of the task. Valid names are the keys of the variable CHEMTASKS in src/open_r1/tasks/__init__.py.
  2. dataset_id_or_path: Path to the dataset (or a Hugging Face dataset id). It is available in the task class as self.dataset_id_or_path and is usually used in the load method.
  3. rewards: List of reward functions to use. The available reward functions depend on the task, and you can implement as many as you want. Each reward function in the task class must end with _reward, and this suffix is omitted in the recipe file. For example, to use your function accuracy_reward(), specify accuracy in the recipe file.
  4. task_kwargs: Special dict-like argument that can contain any additional keyword arguments you want to pass to your task.
    • This argument is optional and can be omitted from the recipe file if you don't need it.
  5. The recipe file contains many other training parameters. You can keep the default values from recipes/template.yaml or change them if needed, but changes may lead to unexpected results or crashes. If you have just built your task, keep the default parameters, confirm that the task works, and only then tune these parameters.
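Conceptually, the recipe fields map onto the task class as follows. chem_task selects a class from the registry, each entry of rewards is resolved to a method by appending _reward, and task_kwargs are forwarded to the task. The sketch below models that lookup with a hypothetical stand-in registry and SampleTask class. It is not the training script's actual code, and forwarding task_kwargs to the constructor is an assumption.

```python
class SampleTask:
    """Hypothetical task used only to demonstrate name resolution."""

    def __init__(self, **task_kwargs):
        self.task_kwargs = task_kwargs

    def accuracy_reward(self, completions, **kwargs):
        return [0.0 for _ in completions]

    def format_reward(self, completions, **kwargs):
        return [0.0 for _ in completions]


# Hypothetical stand-in for CHEMTASKS in src/open_r1/tasks/__init__.py.
REGISTRY = {"sampletask": SampleTask}


def build_task_and_rewards(recipe):
    """Resolve chem_task to a class and each `rewards` entry to a *_reward method."""
    task_cls = REGISTRY[recipe["chem_task"]]
    task = task_cls(**recipe.get("task_kwargs", {}))
    reward_fns = []
    for name in recipe["rewards"]:
        # "accuracy" in the recipe refers to the method accuracy_reward().
        method = getattr(task, f"{name}_reward", None)
        if method is None:
            raise AttributeError(f"{task_cls.__name__} has no method {name}_reward()")
        reward_fns.append(method)
    return task, reward_fns
```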

Task Requirements

When creating a new task, ensure:

  1. Base Class Inheritance: Inherit from RLTask
  2. Required Methods: Implement at minimum:
    • load(): Dataset loading
    • accuracy_reward(): Reward calculation
  3. Documentation:
    • Class docstring explaining the task
    • Method docstrings
    • Example usage
  4. Testing: Add tests for your task in tests/

Adding Documentation

  1. Create a new RST file in docs/source/tasks/ for your task. Use the template under tasks/template.rst.
  2. Add your task to docs/source/modules.rst
  3. Include examples and usage instructions
  4. Build and verify the documentation: cd docs; make clean; make html; python -m http.server 7000

Current Tasks

  • Forward Reaction (rxnpred): Chemical reaction product prediction
  • IUPAC to SMILES (iupacsm): Convert IUPAC names to SMILES notation
  • Canonicalize SMILES (canonic): SMILES canonicalization
  • Canonicalize SMILES MCQA (canonmc): Multiple-choice SMILES canonicalization
  • SMILES Permutation (smi_permute): Generate alternative SMILES for same molecule
  • SMILES Hydrogen (smhydrogen): Add/remove implicit hydrogens
  • Kinetic Classification (kinetic): Kinetic reaction mechanism classification
  • Reaction Inversion (rxn_inversion): MCQ — identify correct reaction among inverted fakes
  • Reaction Replacement (rxn_replacement): MCQ — identify correct reaction among modified fakes
  • Reaction Naming (rxn_naming): Classify reactions into 10 named categories
  • Reaction True/False (rxn_truefalse): Binary reaction validity classification
  • Conditional Material Generation (condmatgen): Generate novel crystal compositions from element sets
  • Crystal Structure Relaxation (crystalrelax): Relax perturbed binary crystal structures

For detailed examples and API reference, see the documentation.

Models

The list of models can be found in the file model_paths.txt.

  • The models are stored in the folder LLM_models/.
  • To add a new model, add a line to model_paths.txt in the format [model_id]: [path] and place the model in the LLM_models/ folder.
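A minimal parser for the [model_id]: [path] line format might look like this. parse_model_paths is hypothetical. Skipping blank lines and '#' comments is an assumption about the file, not a documented feature.

```python
def parse_model_paths(text):
    """Parse lines of the form `[model_id]: [path]` into a dict.

    Only the first ':' separates the id from the path, so paths that
    contain ':' themselves are preserved.
    """
    models = {}
    for lineno, raw in enumerate(text.splitlines(), start=1):
        line = raw.strip()
        if not line or line.startswith("#"):  # assumed: blanks/comments ignored
            continue
        model_id, sep, path = line.partition(":")
        if not sep or not model_id.strip() or not path.strip():
            raise ValueError(f"line {lineno}: expected '[model_id]: [path]', got {raw!r}")
        models[model_id.strip()] = path.strip()
    return models
```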

Multiple custom models were pretrained:

  • Qwen2.5-3B_pretrained-v1
    • Original name: qwen_3b_pretrained
  • Qwen2.5-3B_pretrained-v1_cot-v1
    • Original name: qwen-cot
  • Qwen2.5-3B_pretrained-v2
    • Original name: qwen_3b_sft
  • Qwen2.5-3B_pretrained-v3
    • Original name: qwen_sft_full
  • Qwen2.5-3B_pretrained-v4-cot
    • Original name: qwen_cot_v3

Sampling parameters

It is possible to modify the sampling parameters used during the training by writing configurations in the folder sampling_params/:

  • sampling_params/model_default_sampling_params.txt: Contains the default sampling-parameters configuration for each model, one per line in the format [model_id]: [sampling_params_config_name]. Listing a model is optional; models that are not listed use the default sampling parameters.
    • /!\ Only set a model's default entry when adding that model, and never change it afterwards (for experiment tracking purposes).
    • Custom pretrained models that need a different configuration have non-default entries, because these models can't be used with the default configuration.
  • sampling_params/*.json: These files contain the sampling parameters configurations.
    • /!\ 'default' is a reserved name: you can't create a file called sampling_params/default.json. Please choose a different name; the default configuration should never be modified.
    • To experiment with different sampling parameters, create a new configuration file, but do not modify sampling_params/model_default_sampling_params.txt, which contains only the default configurations. Pass the name of your configuration (without the .json suffix) when launching the SLURM job (see the launcher examples in the Setup section).
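The lookup order implied by these rules can be sketched as follows. This is an illustration built on assumptions, not the project's loader. The function names are hypothetical. The sketch also interprets a launcher value of "default" as "use the model's entry if it has one".

```python
import json
from pathlib import Path


def resolve_sampling_config_name(model_id, requested="default", model_defaults=None):
    """Choose the sampling-parameters configuration name for a run.

    An explicit non-default name from the launcher wins; otherwise use the
    model's entry from model_default_sampling_params.txt, else "default".
    """
    if requested != "default":
        return requested
    return (model_defaults or {}).get(model_id, "default")


def load_sampling_params(config_name, params_dir="sampling_params"):
    """Load sampling_params/<config_name>.json as a dict.

    "default" is reserved and has no JSON file; returning an empty dict here
    (meaning "keep the built-in defaults") is a hypothetical convention.
    """
    if config_name == "default":
        return {}
    path = Path(params_dir) / f"{config_name}.json"
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
```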
