atomicarchitects/equiformer_v3

EquiformerV3: Scaling Efficient, Expressive, and General SE(3)-Equivariant Graph Attention Transformers

Paper (will be on arXiv soon) | Checkpoint

This repository contains the official PyTorch implementation of the work "EquiformerV3: Scaling Efficient, Expressive, and General SE(3)-Equivariant Graph Attention Transformers". We provide the code for training on the OC20 S2EF-2M, MPtrj, OMat24, and sAlex datasets and for evaluation on Matbench Discovery.

Additional training configs for more datasets will be added in the future.

This repository is based on this version of fairchem. We include the original codebase for ease of reproducibility and place the code relevant to our work under experimental.

Content

  1. Environment Setup
  2. File Structure
  3. Training
  4. Checkpoint
  5. Evaluation
  6. Acknowledgement

Environment Setup

Environment

See here for setting up the environment.

OC20

  1. The OC20 S2EF dataset can be downloaded by following instructions in their GitHub repository.

  2. For example, we can download the OC20 S2EF-2M dataset by running:

        cd ocp
        python scripts/download_data.py --task s2ef --split "2M" --num-workers 8
  3. Note that we omit the --ref-energy flag, since we now train on total energy labels instead of adsorption energy labels.
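
To illustrate the distinction: with referencing, per-element reference energies are subtracted from the DFT total energy to produce adsorption-style labels; without it, the raw total energy is kept as the label. A minimal sketch with made-up reference values (the real OC20 reference energies live in the fairchem data pipeline, not here):

```python
# Hypothetical per-element reference energies in eV, for illustration only.
REF_PER_ELEMENT = {"H": -3.5, "O": -7.0}

def referenced_energy(total_energy, symbols):
    """Subtract per-element references from a DFT total energy (roughly what referencing does)."""
    return total_energy - sum(REF_PER_ELEMENT[s] for s in symbols)

# Without --ref-energy the label stays as the raw total energy (-20.0 eV here);
# with referencing it becomes -6.0 eV for this toy H2O-like system.
print(referenced_energy(-20.0, ["H", "H", "O"]))
```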

MPtrj

  1. Download the dataset (MPtrj_2022.9_full.json) here.

  2. Update the path to the .json dataset and the path to save the processed dataset in this file.

  3. Run the following command to convert .json dataset into .aselmdb dataset:

        python experimental/datasets/mptrj_convert_json_to_aselmdb.py
  4. We remove structures in which any atom has no neighbor within 6 Å.

  5. We create metadata.npz, which records the number of edges for each structure for better load balancing:

        # Path to the directory containing .aselmdb files
        ASELMDB_DATASET=""
    
        python experimental/datasets/create_metadata_num_edges.py --input_dir $ASELMDB_DATASET
        
        # Rename to `metadata.npz`
        cd $ASELMDB_DATASET
        cp metadata_num-edges.npz metadata.npz
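
The 6 Å neighbor filter in step 4 can be sketched as below. This is a simplified, non-periodic version (the repo's preprocessing presumably handles periodic boundary conditions, e.g. via ASE neighbor lists); `isolated_atom_indices` is a hypothetical helper, not the repo's API:

```python
import numpy as np

def isolated_atom_indices(positions, cutoff=6.0):
    """Indices of atoms with no neighbor within `cutoff` Angstrom (non-periodic sketch)."""
    pos = np.asarray(positions, dtype=float)
    n = len(pos)
    if n < 2:
        return list(range(n))
    # Pairwise distance matrix; the diagonal is set to inf so an atom is not its own neighbor.
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(dist, np.inf)
    return [i for i in range(n) if dist[i].min() > cutoff]

# Toy structure: two nearby atoms plus one far-away outlier.
positions = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [20.0, 0.0, 0.0]]
print(isolated_atom_indices(positions))  # only atom 2 is isolated
```

A structure would be dropped whenever this list is non-empty.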

OMat24 and sAlex

  1. The datasets can be found here.

  2. As with MPtrj above, we create metadata.npz, which records the number of edges for each structure for better load balancing:

        # Path to the directory containing .aselmdb files
        ASELMDB_DATASET=""
    
        python experimental/datasets/create_metadata_num_edges.py --input_dir $ASELMDB_DATASET
        
        cd $ASELMDB_DATASET
        
        # Deprecate the original `metadata.npz`
        mv metadata.npz metadata_num-nodes.npz
        # Instead, use the one recording the number of edges
        cp metadata_num-edges.npz metadata.npz
  3. For OMat24, repeat step 2 for every directory under train and val, i.e., aimd-from-PBE-1000-npt, aimd-from-PBE-1000-nvt, aimd-from-PBE-3000-npt, aimd-from-PBE-3000-nvt, rattled-1000, rattled-1000-subsampled, rattled-300, rattled-300-subsampled, rattled-500, rattled-500-subsampled, and rattled-relax.
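
On why the metadata records edges rather than atoms: message-passing cost scales with the number of edges, so distributing structures by edge count keeps per-GPU work more even. The real balancing happens inside the fairchem data loader; a greedy sketch of the idea, with hypothetical function names:

```python
import heapq

def assign_balanced(num_edges, n_workers):
    """Greedily assign structure indices to workers, evening out total edge counts."""
    heap = [(0, w, []) for w in range(n_workers)]  # (edge total, worker id, indices)
    heapq.heapify(heap)
    # Largest structures first, each going to the currently lightest worker.
    for idx in sorted(range(len(num_edges)), key=lambda i: -num_edges[i]):
        total, w, items = heapq.heappop(heap)
        items.append(idx)
        heapq.heappush(heap, (total + num_edges[idx], w, items))
    return sorted(heap)  # one (total, worker id, indices) tuple per worker

for total, w, items in assign_balanced([500, 400, 300, 300, 200, 100], 2):
    print(w, total, items)  # both workers end up with 900 edges
```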

File Structure

We place all the files relevant to our work under experimental.

  1. configs contains config files for training and evaluation.
  2. datasets contains utility functions to preprocess datasets.
  3. models contains EquiformerV3 (+ DeNS) models.
  4. scripts contains the scripts for training and evaluation.
  5. tasks contains the code of running simulations in Matbench Discovery, testing equivariance, and conducting body-order experiments.
  6. trainers contains the code for training and evaluation.

Training

OC20

  1. OC20 S2EF-2M dataset (Index 7 in Table 1).

    a. Modify the path to the training set and the path to the validation set in the config file.

    b. Update the path to save results in the training script.

    c. Run:

        sh experimental/scripts/train/oc20/s2ef/equiformer_v3/equiformer_v3_splits@2M_g@8.sh

MPtrj

We provide the config and script for training EquiformerV3 with $L_{max} = 4$ here. The preprocessing of MPtrj data is here.

  1. Direct pre-training

    a. Modify the path to the training set (the full MPtrj dataset) and the path to the validation set (we used a subset of sAlex validation set as the final evaluation is on Matbench Discovery) in the config file.

    b. Modify the path to save results in the training script.

    c. The training script requires two nodes with 8 GPUs each. It shows one example of launching distributed training on 2 nodes; other launch methods also work.

    d. Run:

        bash experimental/scripts/train/omat24/equiformer_v3/equiformer_v3_mptrj.sh
  2. Remove energy head from pre-trained checkpoint

    a. After direct pre-training, we remove the energy head from the checkpoints by running:

        # Path to the checkpoint of direct pre-training
        CHECKPOINT=""
    
        python experimental/tasks/remove_key_from_checkpoint.py --input-path $CHECKPOINT --remove-key energy_block

    b. This creates a new checkpoint (.../checkpoint_no-energy_block.pt), which is used to initialize model weights during gradient fine-tuning.

  3. Gradient fine-tuning

    a. Modify the path to the training set (the full MPtrj dataset) and the path to the validation set (we used a subset of sAlex validation set as the final evaluation is on Matbench Discovery) in the config file.

    b. Modify the path to the pre-trained checkpoint. The path should be something like .../checkpoint_no-energy_block.pt, obtained by running step 2 above.

    c. Modify the path to save results in the training script.

    d. The training script requires two nodes with 8 GPUs each. It shows one example of launching distributed training on 2 nodes; other launch methods also work.

    e. Run:

        bash experimental/scripts/train/omat24/equiformer_v3/equiformer_v3_grad_mptrj.sh
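
The energy-head removal in step 2 amounts to filtering matching keys out of the checkpoint's state dict; remove_key_from_checkpoint.py presumably does this around torch.load/torch.save. A torch-free sketch with hypothetical parameter names:

```python
def remove_key(state_dict, key):
    """Drop every parameter whose name contains `key` (e.g. 'energy_block')."""
    return {k: v for k, v in state_dict.items() if key not in k}

ckpt = {  # hypothetical parameter names, for illustration only
    "backbone.blocks.0.attn.weight": 0,
    "energy_block.linear.weight": 1,
    "force_block.linear.weight": 2,
}
print(sorted(remove_key(ckpt, "energy_block")))  # energy_block.* keys are gone
```

The remaining weights can then initialize the model for gradient fine-tuning, where the energy head is replaced.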

OMat24 → MPtrj and sAlex

We provide the config and script for training EquiformerV3 with $L_{max} = 4$ and $L_{max} = 6$ here. The preprocessing of MPtrj data is here.

  1. Direct pre-training on OMat24

    a. Modify the path to the training set and the path to the validation set in the config file.

    b. Modify the path to save results in the training script.

    c. The training script requires four nodes with 8 GPUs each. It shows one example of launching distributed training on four nodes; other launch methods also work.

    d. Run:

        bash experimental/scripts/train/omat24/equiformer_v3/equiformer_v3_omat24.sh

    e. Repeat the above steps for $L_{max} = 6$.

  2. Gradient fine-tuning on OMat24

    a. Modify the path to the training set and the path to the validation set in the config file.

    b. Modify the path to the pre-trained checkpoint. The path should be something like .../checkpoint.pt.

    c. Modify the path to save results in the training script.

    d. The training script requires four nodes with 8 GPUs each. It shows one example of launching distributed training on four nodes; other launch methods also work.

    e. Run:

        bash experimental/scripts/train/omat24/equiformer_v3/equiformer_v3_grad_omat24.sh

    f. Repeat the above steps for $L_{max} = 6$.

  3. Gradient fine-tuning on MPtrj and sAlex

    a. Modify the path to the training set and the path to the validation set in the config file.

    b. Modify the path to the pre-trained checkpoint. The path should be something like .../checkpoint.pt.

    c. Modify the path to save results in the training script.

    d. The training script requires four nodes with 8 GPUs each. It shows one example of launching distributed training on four nodes; other launch methods also work.

    e. Run:

        bash experimental/scripts/train/omat24/equiformer_v3/equiformer_v3_grad_salex-mptrj.sh

    f. Repeat the above steps for $L_{max} = 6$.
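
On the $L_{max} = 4$ vs. $L_{max} = 6$ variants: an equivariant feature carrying all degrees up to $L_{max}$ has $(L_{max} + 1)^2$ spherical-harmonic components per channel, so raising $L_{max}$ from 4 to 6 roughly doubles the per-channel feature size (25 → 49) and increases compute accordingly. A quick check of that count:

```python
def sph_components(lmax):
    """Number of spherical-harmonic components for degrees 0..lmax: sum of (2l + 1)."""
    return sum(2 * l + 1 for l in range(lmax + 1))  # equals (lmax + 1) ** 2

print(sph_components(4), sph_components(6))  # 25 49
```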

Checkpoint

Trained checkpoints can be found on the Hugging Face page.

Evaluation

OC20

  1. OC20 S2EF-2M dataset (Index 7 in Table 1)

    a. Follow step 1.a here to update the config file.

    b. Download the OC20 checkpoint here.

    c. Modify the path to save results, the path to the validation set, and the path to the checkpoint in the evaluation script.

    d. Run the script:

        sh experimental/scripts/eval/oc20/s2ef/equiformer_v3/equiformer_v3_splits@2M_g@8.sh

OMat24

  1. Evaluate on OMat24 validation set

    a. Follow here to update the config files (direct and gradient).

    b. Download the OMat24 checkpoint here.

    c. Modify the path to save results, the path to the config file, and the path to the checkpoint in the evaluation script.

    d. Run the script:

        sh experimental/scripts/eval/omat24/s2ef/equiformer_v3/equiformer_v3_g@8.sh

Matbench Discovery

  1. Evaluate on discovery metrics

    a. Modify the path to save preprocessed data and then run the command:

        cd experimental/datasets
        python matbench_discovery_create_aselmdb.py

    b. Modify the path to the checkpoint, the path to save results, and the path to the dataset in the evaluation script.

    c. Run the calculation script:

        sh experimental/scripts/eval/matbench_discovery/discovery.sh

    d. Postprocess the calculation:

        # Path to save results in b. and c.
        INPUT_DIR=""
    
        python experimental/tasks/matbench_discovery/join_preds.py --input-dir $INPUT_DIR

    e. Evaluate the calculations and print results like F1 score and RMSD:

        # Path to save results in b. and c.
        INPUT_DIR=""
    
        python experimental/tasks/matbench_discovery/evaluate_discovery.py --input-dir $INPUT_DIR

    Obtaining the RMSD results takes about 30 minutes.

  2. Evaluate on thermal conductivity task ($\kappa_{\text{SRME}}$)

    a. Modify the path to save results and the path to the checkpoint in the evaluation script.

    b. Run the script:

        sh experimental/scripts/eval/matbench_discovery/kappa.sh

    c. The results are printed after the above command finishes. If a value of 2.0 appears, a structure likely hit an out-of-memory error on GPU; rerun the command on CPU for those structures.
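
For reference, Matbench Discovery's headline F1 treats a material as stable when its energy above the convex hull is at or below 0 eV/atom; the official implementation is in the matbench-discovery package. A sketch of just the definition, on toy numbers:

```python
def discovery_f1(e_hull_true, e_hull_pred, threshold=0.0):
    """F1 for stability classification: stable iff energy above hull <= threshold (eV/atom)."""
    tp = fp = fn = 0
    for t, p in zip(e_hull_true, e_hull_pred):
        actual, pred = (t <= threshold), (p <= threshold)
        tp += actual and pred
        fp += pred and not actual
        fn += actual and not pred
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

# Toy example: 2 truly stable materials; the model finds 1 of them plus 1 false positive.
print(discovery_f1([-0.1, 0.2, -0.05, 0.3], [-0.2, -0.1, 0.1, 0.4]))  # 0.5
```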

Acknowledgement

Our implementation is based on PyTorch, PyG, e3nn, timm, fairchem, Equiformer, EquiformerV2, DeNS, Matbench Discovery, and Nequix.