44 commits
982d36c  added cond_mat_gen task (Meng0o, May 13, 2025)
b84e908  fixed slurm and model_paths (Meng0o, May 13, 2025)
7cd3a60  changed launch_CSCS.slurm (Meng0o, May 13, 2025)
8f2297b  Fixed config (Meng0o, May 13, 2025)
4bca3cd  added deps (Meng0o, May 13, 2025)
909f4e3  fixed recipe filename (Meng0o, May 13, 2025)
27384d9  fixed data filename (Meng0o, May 13, 2025)
8baeea0  removed mkdir in condmatgen (Meng0o, May 13, 2025)
12e46ed  fixing dataloader (Meng0o, May 13, 2025)
5d92141  added debugging code for reading data (Meng0o, May 13, 2025)
3a9d100  fixed data loading error (Meng0o, May 13, 2025)
16a6ca9  added placeholder solutions (Meng0o, May 13, 2025)
bb7aa25  moved the randomisation into the read file function (Meng0o, May 13, 2025)
568bfa6  moved the randomisation into the read file function (Meng0o, May 13, 2025)
3848abc  fixed random import (Meng0o, May 13, 2025)
3903f47  fixed random seed (Meng0o, May 13, 2025)
f586bf4  added debugging code (Meng0o, May 13, 2025)
55f317b  fixed try statement (Meng0o, May 13, 2025)
011e47a  added debugging code (Meng0o, May 13, 2025)
5828d3e  see if its coz of the double quotes (Meng0o, May 13, 2025)
acfce96  see if its coz of the get (Meng0o, May 13, 2025)
7d1ad6a  try this keyword (Meng0o, May 13, 2025)
7802ef6  try this dataset conversion (Meng0o, May 13, 2025)
ea2dfb8  try this override (Meng0o, May 13, 2025)
68c3f34  added debugging code (Meng0o, May 13, 2025)
b808b56  added debugging code (Meng0o, May 13, 2025)
45e5cfd  added debugging code (Meng0o, May 13, 2025)
b4fcfd4  added debugging code (Meng0o, May 13, 2025)
d3c935f  Maybe its this (Meng0o, May 13, 2025)
d807a76  made problems a list of dicts (Meng0o, May 13, 2025)
1767bf1  added generate prompt (Meng0o, May 13, 2025)
fc03efc  changed to 12 hrs (Meng0o, May 13, 2025)
0f6f4b0  found bug in reward (Meng0o, May 14, 2025)
e41b663  Added debugging code (Meng0o, May 14, 2025)
43b5815  added condmatgen.yaml (Meng0o, Sep 10, 2025)
6fedda6  cleaned up cgm (Meng0o, Sep 10, 2025)
bbd2473  changed gitignore (Sep 16, 2025)
033d4a0  Stop tracking wandb_api_key.txt (Sep 16, 2025)
014f8c2  bug fixing cgm (Sep 16, 2025)
66d6bd7  improving cgm training (Sep 23, 2025)
5f11827  continued (Sep 30, 2025)
9d6efd2  Run black and isort for CI compliance (doncamilom, Apr 6, 2026)
ccf87e6  Make pymatgen/smact imports lazy for CI compatibility (doncamilom, Apr 6, 2026)
4449832  Fix lint, add missing doc stub, use portable kinetic test (doncamilom, Apr 6, 2026)
7 changes: 2 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -129,7 +129,6 @@ celerybeat.pid

# Environments
.env
.env.local
.venv
env/
venv/
@@ -184,12 +183,10 @@ docs/build/
# WANDB API KEY (PERSONAL & SECRET) and setup
wandb_api_key.txt
kuma.env
cluster/*.local.env
cluster/*.local.env.sh

completion_samples/
output/
core*
*.ipynb
slurm_logs/*.out
slurm_logs/*.err
src/open_r1/dataset/
checkpoints/
26 changes: 26 additions & 0 deletions docs/source/tasks/condmatgen.rst
@@ -0,0 +1,26 @@
Conditional Material Generation (CMG)
======================================

.. currentmodule:: open_r1.tasks.condmatgen.condmatgen

ConditionalMaterialGeneration
-----------------------------

.. autoclass:: ConditionalMaterialGeneration
:members:
:show-inheritance:

Task Description
----------------

Given a set of chemical elements, the model proposes a novel crystalline
compound (element list and space group number). The model wraps its reasoning
in ``<think>...</think>`` tags and its answer in ``<answer>...</answer>`` tags.

Reward Functions
----------------

- **accuracy**: multi-component scoring including SMACT validity, element
precision, space group validity, and novelty bonus.
- **format**: checks presence and ordering of think/answer tags, penalizes
short reasoning.
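The format reward described above can be sketched as a small checker; this is a minimal illustration, not the repository's implementation — the exact score values (0.0 / 0.5 / 1.0) and the ``min_think_chars`` threshold are assumptions.

```python
import re

# One think block followed by one answer block, nothing else.
THINK_ANSWER = re.compile(
    r"^<think>(?P<think>.+?)</think>\s*<answer>(?P<answer>.+?)</answer>$",
    re.DOTALL,
)


def format_reward(completion: str, min_think_chars: int = 200) -> float:
    """Check presence and ordering of think/answer tags; penalize short reasoning."""
    m = THINK_ANSWER.match(completion.strip())
    if m is None:
        return 0.0  # tags missing, out of order, or malformed
    if len(m.group("think").strip()) < min_think_chars:
        return 0.5  # well-formed, but the reasoning is too short
    return 1.0
```

A reversed tag order (``<answer>`` before ``<think>``) fails the anchored match and scores zero, which is what "checks presence and ordering" requires.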
40 changes: 18 additions & 22 deletions launch_CSCS.slurm
@@ -1,17 +1,17 @@
#!/bin/bash
#SBATCH --job-name=grpo-chem
#SBATCH --ntasks-per-node=1
#SBATCH --time=00:30:00
#SBATCH --nodes=1
#SBATCH --time=12:00:00
#SBATCH --nodes=4
#SBATCH --gres=gpu:4
#SBATCH --output=slurm_logs/%x-%j.out
#SBATCH --err=slurm_logs/%x-%j.err
#SBATCH --output=../sink_logs/%x-%j.out
#SBATCH --err=../sink_logs/%x-%j.err
#SBATCH --environment=vllm071
#SBATCH -A a-a05
#SBATCH -A a131

# Run like:
# Example 1: sbatch launch_CSCS.slurm Qwen2.5-0.5B canonmc
# Example 2: sbatch launch_CSCS.slurm Qwen2.5-0.5B canonmc base
# Example 1: sbatch launch.slurm Qwen2.5-0.5B canonmc
# Example 2: sbatch launch.slurm Qwen2.5-0.5B canonmc base
# Argument 1: model id (from model_paths.txt)
# Argument 2: task name (in folder recipes)
# Argument 3 (optional): job_id to continue training from (if none, the default is 0)
@@ -22,12 +22,8 @@

set -x -e

REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck disable=SC1091
source "${REPO_ROOT}/cluster/common.sh"

source ~/.bashrc
cd "${MIST_REPO_ROOT}"
cd /capstor/store/cscs/swissai/a131/jmeng/sink
echo "START TIME: $(date)"

NUM_NODES=$SLURM_NNODES
@@ -38,7 +34,7 @@ NUM_GPUS_FOR_TRAINING=$(($WORLD_SIZE - 1))

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
MASTER_PORT=23456

MODEL_ID=$1
TASK=$2
@@ -65,13 +61,13 @@ echo "TASK_MODE: ${TASK_MODE}"
echo "SAMPLING_PARAMS_CONFIG_NAME: ${SAMPLING_PARAMS_CONFIG_NAME}"


MODEL=$(mist_resolve_model_path "${MODEL_ID}")
MODEL=$(grep "^${MODEL_ID}:" model_paths.txt | cut -d':' -f2- | sed 's/^[[:space:]]*//; s/[[:space:]]*$//')
MODEL_NAME=$(echo $MODEL | sed -E 's/.*models--(.*)--(.*)\/snapshots.*/\1\/\2/' | sed 's/--/\//')
CONFIG_FILE=recipes/$TASK.yaml
CONFIG_FILE=/capstor/store/cscs/swissai/a131/jmeng/sink/recipes/$TASK.yaml

export HF_HOME="${MIST_HF_HOME}"
export HF_HOME=/cache/huggingface
export WANDB_PROJECT="r1-${TASK%%.*}"
export WANDB_API_KEY=$(cat "${MIST_WANDB_API_KEY_FILE}")
export WANDB_API_KEY=$(cat wandb_api_key.txt)

export NCCL_TIMEOUT=3600
export TORCH_DISTRIBUTED_TIMEOUT=3600
@@ -89,7 +85,7 @@ export CUDA_LAUNCH_BLOCKING=1
export CMD=" \
src/open_r1/run_r1_grpo.py --config $CONFIG_FILE \
--model_name_or_path=$MODEL \
--output_dir=${MIST_CHECKPOINT_DIR}/${MODEL_ID}/${RESUME_JOB_ID} \
--output_dir=/capstor/store/cscs/swissai/a131/jmeng/sink/checkpoints/${MODEL_ID}/${RESUME_JOB_ID} \
--run_name grpo-${MODEL_ID}${TASK_SUFFIX}${TASK_MODE_SUFFIX}-${SLURM_JOB_ID}_from_${RESUME_JOB_ID} \
--base_model_name=$MODEL_NAME \
--base_model_id=$MODEL_ID \
@@ -102,8 +98,8 @@ export CMD=" \
"

export LAUNCHER="conda deactivate;
cd ${MIST_REPO_ROOT};
pip install hf_transfer gdown levenshtein;
cd /capstor/store/cscs/swissai/a131/jmeng/sink;
pip install hf_transfer gdown levenshtein smact pymatgen datasets;
pip install -e . --no-deps;
HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
--config_file configs/deepspeed_zero3.yaml \
@@ -118,7 +114,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=inf
--tee 3 \
"

export WANDBLOG="wandb sync \$(cat ${MIST_LOG_DIR}/grpo-chem-${SLURM_JOB_ID}.out | grep 'saved locally' | awk '{print \$8}')"
export WANDBLOG="wandb sync \$(cat ../sink_logs/grpo-chem-${SLURM_JOB_ID}.out | grep 'saved locally' | awk '{print \$8}')"

export NCCL_ASYNC_ERROR_HANDLING=1

@@ -133,4 +129,4 @@ SRUN_ARGS=" \

clear; srun $SRUN_ARGS bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD; $WANDBLOG" 2>&1

echo "END TIME: $(date)"
echo "END TIME: $(date)"
35 changes: 18 additions & 17 deletions model_paths.txt
@@ -1,17 +1,18 @@
# Set MIST_MODELS_DIR in .env.local or your shell before launching jobs.
Qwen2.5-0.5B: ${MIST_MODELS_DIR}/models--Qwen--Qwen2.5-0.5B/snapshots/060db6499f32faf8b98477b0a26969ef7d8b9987
Qwen2.5-3B: ${MIST_MODELS_DIR}/models--Qwen--Qwen2.5-3B/snapshots/3aab1f1954e9cc14eb9509a215f9e5ca08227a9b
Qwen2.5-7B: ${MIST_MODELS_DIR}/models--Qwen--Qwen2.5-7B/snapshots/d149729398750b98c0af14eb82c78cfe92750796
Qwen2.5-3B-Instruct: ${MIST_MODELS_DIR}/models--Qwen--Qwen2.5-3B-Instruct/snapshots/aa8e72537993ba99e69dfaafa59ed015b17504d1
Llama-3.1-8B: ${MIST_MODELS_DIR}/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b
Llama-3-8B: ${MIST_MODELS_DIR}/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920
DeepSeek-R1-Distill-Qwen-1.5B: ${MIST_MODELS_DIR}/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-1.5B/snapshots/530ca3e1ad39d440e182c2e4317aa40f012512fa
DeepSeek-R1-Distill-Qwen-7B: ${MIST_MODELS_DIR}/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/916b56a44061fd5cd7d6a8fb632557ed4f724f60
DeepSeek-R1-Distill-Qwen-14B: ${MIST_MODELS_DIR}/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-14B/snapshots/1df8507178afcc1bef68cd8c393f61a886323761
DeepSeek-R1-Distill-Qwen-32B: ${MIST_MODELS_DIR}/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-32B/snapshots/711ad2ea6aa40cfca18895e8aca02ab92df1a746
DeepSeek-R1-Distill-Llama-70B: ${MIST_MODELS_DIR}/models--deepseek-ai--DeepSeek-R1-Distill-Llama-70B/snapshots/b1c0b44b4369b597ad119a196caf79a9c40e141e
Qwen2.5-3B_pretrained-v1: ${MIST_MODELS_DIR}/Qwen2.5-3B_pretrained-v1
Qwen2.5-3B_pretrained-v1_cot-v1: ${MIST_MODELS_DIR}/Qwen2.5-3B_pretrained-v1_cot-v1
Qwen2.5-3B_pretrained-v2: ${MIST_MODELS_DIR}/Qwen2.5-3B_pretrained-v2
Qwen2.5-3B_pretrained-v3: ${MIST_MODELS_DIR}/Qwen2.5-3B_pretrained-v3
Qwen2.5-3B_pretrained-v4-cot: ${MIST_MODELS_DIR}/Qwen2.5-3B_pretrained-v4-cot
Qwen2.5-0.5B: /LLM_models/models--Qwen--Qwen2.5-0.5B/snapshots/060db6499f32faf8b98477b0a26969ef7d8b9987
Qwen2.5-3B: /LLM_models/models--Qwen--Qwen2.5-3B/snapshots/3aab1f1954e9cc14eb9509a215f9e5ca08227a9b
Qwen2.5-7B: /LLM_models/models--Qwen--Qwen2.5-7B/snapshots/d149729398750b98c0af14eb82c78cfe92750796
Qwen2.5-3B-Instruct: /LLM_models/models--Qwen--Qwen2.5-3B-Instruct/snapshots/aa8e72537993ba99e69dfaafa59ed015b17504d1
Llama-3.1-8B: /LLM_models/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b
Llama-3-8B: /LLM_models/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920
DeepSeek-R1-Distill-Qwen-1.5B: /LLM_models/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-1.5B/snapshots/530ca3e1ad39d440e182c2e4317aa40f012512fa
DeepSeek-R1-Distill-Qwen-7B: /LLM_models/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/916b56a44061fd5cd7d6a8fb632557ed4f724f60
DeepSeek-R1-Distill-Qwen-14B: /LLM_models/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-14B/snapshots/1df8507178afcc1bef68cd8c393f61a886323761
DeepSeek-R1-Distill-Qwen-32B: /LLM_models/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-32B/snapshots/711ad2ea6aa40cfca18895e8aca02ab92df1a746
DeepSeek-R1-Distill-Llama-70B: /LLM_models/models--deepseek-ai--DeepSeek-R1-Distill-Llama-70B/snapshots/b1c0b44b4369b597ad119a196caf79a9c40e141e
Qwen2.5-3B_pretrained-v1: /LLM_models/Qwen2.5-3B_pretrained-v1
Qwen2.5-3B_pretrained-v1_cot-v1: /LLM_models/Qwen2.5-3B_pretrained-v1_cot-v1
Qwen2.5-3B_pretrained-v2: /LLM_models/Qwen2.5-3B_pretrained-v2
Qwen2.5-3B_pretrained-v3: /LLM_models/Qwen2.5-3B_pretrained-v3
Qwen2.5-3B_pretrained-v4-cot: /LLM_models/Qwen2.5-3B_pretrained-v4-cot
Qwen2.5-3B_pretrained-v6-1: /LLM_models/Qwen2.5-3B_pretrained-v6-1
Qwen2.5-3B_pretrained_sft_epoch8: /capstor/store/cscs/swissai/a131/jmeng/megatron/qwen-ckpts/pretrained_sft_qwen2.5_3B/pretrained_sft_qwen2.5_3B_epoch8
49 changes: 49 additions & 0 deletions recipes/condmatgen.yaml
@@ -0,0 +1,49 @@
# Model arguments
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
bf16: true
tf32: true

# Chemical Task arguments
chem_task: condmatgen
dataset_id_or_path: /capstor/store/cscs/swissai/a131/jmeng/sink/src/open_r1/dataset/
rewards:
- accuracy

# Lora Arguments
# No LoRA is used here

# Training arguments
max_steps: 1450
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
learning_rate: 2.0e-6 # 1.0e-6 as in the deepseek math paper 5-e7 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
lr_scheduler_type: cosine
warmup_ratio: 0.03
# GRPO specific parameters
beta: 0.04 # 0.04 as in the deepseek math paper 0.001 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
max_prompt_length: 600
max_completion_length: 2048
num_generations: 4
use_vllm: true
vllm_device: "cuda:3"
vllm_gpu_memory_utilization: 0.8
vllm_max_model_len: 2048

# Logging arguments
logging_strategy: steps
logging_steps: 1
report_to:
- wandb

save_strategy: "steps"
save_steps: 25
seed: 42

# Hugging Face Hub
push_to_hub: false
# hub_model_id: llama-3-1-8b-math-orca-qlora-10k-ep1 # if not defined same as output_dir
2 changes: 2 additions & 0 deletions run_rl.sh
@@ -0,0 +1,2 @@
#!/bin/bash
sbatch launch_CSCS.slurm Qwen2.5-3B_pretrained_sft_epoch8 condmatgen
2 changes: 0 additions & 2 deletions src/open_r1/run_r1_grpo.py
@@ -4,7 +4,6 @@
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
from transformers import AutoTokenizer

from paths import expand_path
from tasks import CHEMTASKS
from trl import GRPOConfig, ModelConfig, TrlParser, get_peft_config
from utils import (
@@ -104,7 +103,6 @@ def grpo_function(model_args: ModelConfig, training_args: GRPOConfig):
def main():
parser = TrlParser((ModelConfig, ExtendedGRPOConfig))
model_args, training_args = parser.parse_args_and_config()
training_args.dataset_id_or_path = expand_path(training_args.dataset_id_or_path)
training_args = load_sampling_params_config(training_args) # Load sampling parameters
grpo_function(model_args, training_args)

7 changes: 7 additions & 0 deletions src/open_r1/tasks/__init__.py
@@ -18,3 +18,10 @@
"smhydrogen": SmilesHydrogen,
"kinetic": KineticDataClassification,
}

try:
from .condmatgen.condmatgen import ConditionalMaterialGeneration

CHEMTASKS["condmatgen"] = ConditionalMaterialGeneration
except ImportError:
pass
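This try/except guard is what makes the pymatgen/smact imports "lazy for CI compatibility": the task registers itself only when its heavy dependencies import cleanly. The same pattern in isolation, with a deliberately missing module standing in for the optional dependencies (the module name and registry contents here are illustrative, not the repository's):

```python
# Minimal reproduction of the lazy-registration pattern used in
# tasks/__init__.py: a task appears in the registry only when its
# optional dependencies are installed.
REGISTRY = {"kinetic": "KineticDataClassification"}  # always-available task

try:
    import some_missing_optional_dep  # hypothetical; stands in for pymatgen/smact
    REGISTRY["condmatgen"] = "ConditionalMaterialGeneration"
except ImportError:
    pass  # a CI image without pymatgen/smact can still import the package
```

Callers then get a normal `KeyError` for the missing task rather than an import crash at package load time.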
2 changes: 1 addition & 1 deletion src/open_r1/tasks/base.py
@@ -168,7 +168,7 @@ def get_metrics(self) -> dict:
"""
return dict()

def random_print(self, print_data: dict, out_rate=0.01):
def random_print(self, print_data: dict, out_rate=0.1):
if random.random() < out_rate:  # print a completion with probability out_rate
out = "\n\n=======<RANDOM_RESPONSE>=======\n"
for k, v in print_data.items():
1 change: 1 addition & 0 deletions src/open_r1/tasks/condmatgen/comps_used_in_sft.json
@@ -0,0 +1 @@
[]