diff --git a/.gitignore b/.gitignore index 173cf8c3..2511b116 100644 --- a/.gitignore +++ b/.gitignore @@ -193,3 +193,6 @@ core* *.ipynb slurm_logs/*.out slurm_logs/*.err + +# BINARY COMPOUND EXPERIMENT LOG FILES +src/open_r1/tasks/crystal_structure/reward_server/reward_logs/ diff --git a/demo/crystalrelax_tiny/src-test.txt b/demo/crystalrelax_tiny/src-test.txt new file mode 100644 index 00000000..9a3812c2 --- /dev/null +++ b/demo/crystalrelax_tiny/src-test.txt @@ -0,0 +1,60 @@ +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.57000000 b 4.57000000 c 4.57000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.58000000 b 6.58000000 c 6.58000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.57000000 b 4.57000000 c 4.57000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.58000000 b 6.58000000 c 6.58000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.57000000 b 4.57000000 c 4.57000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.58000000 b 6.58000000 c 6.58000000 alpha 90.00000000 beta 90.00000000 
gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.57000000 b 4.57000000 c 4.57000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.58000000 b 6.58000000 c 6.58000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.57000000 b 4.57000000 c 4.57000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.58000000 b 6.58000000 c 6.58000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + diff --git a/demo/crystalrelax_tiny/src-train.txt b/demo/crystalrelax_tiny/src-train.txt new file mode 100644 index 00000000..0f147494 --- /dev/null +++ b/demo/crystalrelax_tiny/src-train.txt @@ -0,0 +1,240 @@ +serialized_cif formula Na 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.69000000 b 5.69000000 c 5.69000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21000000 b 4.21000000 c 4.21000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ca 1_int O 1_int +space_group_symbol 
Fm-3m_sg +lattice_parameters a 4.81000000 b 4.81000000 c 4.81000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.41000000 b 5.41000000 c 5.41000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.03000000 b 4.03000000 c 4.03000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.60000000 b 6.60000000 c 6.60000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.52000000 b 5.52000000 c 5.52000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.16000000 b 5.16000000 c 5.16000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Na 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.69000000 b 5.69000000 c 5.69000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21000000 b 4.21000000 c 4.21000000 alpha 90.00000000 beta 90.00000000 gamma 
90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ca 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.81000000 b 4.81000000 c 4.81000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.41000000 b 5.41000000 c 5.41000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.03000000 b 4.03000000 c 4.03000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.60000000 b 6.60000000 c 6.60000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.52000000 b 5.52000000 c 5.52000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.16000000 b 5.16000000 c 5.16000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Na 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.69000000 b 5.69000000 c 5.69000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif 
formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21000000 b 4.21000000 c 4.21000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ca 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.81000000 b 4.81000000 c 4.81000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.41000000 b 5.41000000 c 5.41000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.03000000 b 4.03000000 c 4.03000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.60000000 b 6.60000000 c 6.60000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.52000000 b 5.52000000 c 5.52000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.16000000 b 5.16000000 c 5.16000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Na 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.69000000 b 5.69000000 c 5.69000000 
alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21000000 b 4.21000000 c 4.21000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ca 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.81000000 b 4.81000000 c 4.81000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.41000000 b 5.41000000 c 5.41000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.03000000 b 4.03000000 c 4.03000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.60000000 b 6.60000000 c 6.60000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.52000000 b 5.52000000 c 5.52000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.16000000 b 5.16000000 c 5.16000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 
0.50000000 0.50000000 + +serialized_cif formula Na 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.69000000 b 5.69000000 c 5.69000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21000000 b 4.21000000 c 4.21000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ca 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.81000000 b 4.81000000 c 4.81000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.41000000 b 5.41000000 c 5.41000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.03000000 b 4.03000000 c 4.03000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.60000000 b 6.60000000 c 6.60000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.52000000 b 5.52000000 c 5.52000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters 
a 5.16000000 b 5.16000000 c 5.16000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + diff --git a/demo/crystalrelax_tiny/tgt-test.txt b/demo/crystalrelax_tiny/tgt-test.txt new file mode 100644 index 00000000..098c1b0f --- /dev/null +++ b/demo/crystalrelax_tiny/tgt-test.txt @@ -0,0 +1,60 @@ +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.56700000 b 4.56700000 c 4.56700000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.57800000 b 6.57800000 c 6.57800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.56700000 b 4.56700000 c 4.56700000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.57800000 b 6.57800000 c 6.57800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.56700000 b 4.56700000 c 4.56700000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.57800000 b 6.57800000 c 6.57800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 
0.50000000 + +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.56700000 b 4.56700000 c 4.56700000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.57800000 b 6.57800000 c 6.57800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Cs 1_int I 1_int +space_group_symbol Pm-3m_sg +lattice_parameters a 4.56700000 b 4.56700000 c 4.56700000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Cs 1_int 0.00000000 0.00000000 0.00000000 +I 1_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Rb 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.57800000 b 6.57800000 c 6.57800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Rb 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + diff --git a/demo/crystalrelax_tiny/tgt-train.txt b/demo/crystalrelax_tiny/tgt-train.txt new file mode 100644 index 00000000..91d4669c --- /dev/null +++ b/demo/crystalrelax_tiny/tgt-train.txt @@ -0,0 +1,240 @@ +serialized_cif formula Na 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.64000000 b 5.64000000 c 5.64000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21200000 b 4.21200000 c 4.21200000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ca 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.81100000 b 4.81100000 c 4.81100000 alpha 90.00000000 beta 
90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.40900000 b 5.40900000 c 5.40900000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.02800000 b 4.02800000 c 4.02800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.59500000 b 6.59500000 c 6.59500000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.51800000 b 5.51800000 c 5.51800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.15900000 b 5.15900000 c 5.15900000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Na 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.64000000 b 5.64000000 c 5.64000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21200000 b 4.21200000 c 4.21200000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 
+ +serialized_cif formula Ca 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.81100000 b 4.81100000 c 4.81100000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.40900000 b 5.40900000 c 5.40900000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.02800000 b 4.02800000 c 4.02800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.59500000 b 6.59500000 c 6.59500000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.51800000 b 5.51800000 c 5.51800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.15900000 b 5.15900000 c 5.15900000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Na 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.64000000 b 5.64000000 c 5.64000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21200000 b 
4.21200000 c 4.21200000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ca 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.81100000 b 4.81100000 c 4.81100000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.40900000 b 5.40900000 c 5.40900000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.02800000 b 4.02800000 c 4.02800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.59500000 b 6.59500000 c 6.59500000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.51800000 b 5.51800000 c 5.51800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.15900000 b 5.15900000 c 5.15900000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Na 1_int Cl 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.64000000 b 5.64000000 c 5.64000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 
0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21200000 b 4.21200000 c 4.21200000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ca 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.81100000 b 4.81100000 c 4.81100000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.40900000 b 5.40900000 c 5.40900000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.02800000 b 4.02800000 c 4.02800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.59500000 b 6.59500000 c 6.59500000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.51800000 b 5.51800000 c 5.51800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.15900000 b 5.15900000 c 5.15900000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Na 1_int Cl 1_int 
+space_group_symbol Fm-3m_sg +lattice_parameters a 5.64000000 b 5.64000000 c 5.64000000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Na 4_int 0.00000000 0.00000000 0.00000000 +Cl 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Mg 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.21200000 b 4.21200000 c 4.21200000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Mg 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ca 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.81100000 b 4.81100000 c 4.81100000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ca 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Zn 1_int S 1_int +space_group_symbol F-43m_sg +lattice_parameters a 5.40900000 b 5.40900000 c 5.40900000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Zn 4_int 0.00000000 0.00000000 0.00000000 +S 4_int 0.25000000 0.25000000 0.25000000 + +serialized_cif formula Li 1_int F 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 4.02800000 b 4.02800000 c 4.02800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Li 4_int 0.00000000 0.00000000 0.00000000 +F 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula K 1_int Br 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 6.59500000 b 6.59500000 c 6.59500000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +K 4_int 0.00000000 0.00000000 0.00000000 +Br 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Ba 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.51800000 b 5.51800000 c 5.51800000 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Ba 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + +serialized_cif formula Sr 1_int O 1_int +space_group_symbol Fm-3m_sg +lattice_parameters a 5.15900000 b 5.15900000 c 5.15900000 alpha 90.00000000 beta 
90.00000000 gamma 90.00000000 +Sr 4_int 0.00000000 0.00000000 0.00000000 +O 4_int 0.50000000 0.50000000 0.50000000 + diff --git a/demo/fixture_manifest.csv b/demo/fixture_manifest.csv index 1996d9a7..f2067b6e 100644 --- a/demo/fixture_manifest.csv +++ b/demo/fixture_manifest.csv @@ -4,7 +4,7 @@ rxnpred_with_tags,demo/rxnpred_tiny,50,"40 train / 10 test",repo-local fixture,r iupacsm,demo/datasets/CRLLM-PubChem-compounds1M.sample.csv,50,"loader split -> 45 train / 5 test",Figshare datasets.zip slice,ready,"50-row CSV extracted from the released PubChem bundle" iupacsm_with_tags,demo/datasets/CRLLM-PubChem-compounds1M-simple.sample.csv,50,"loader split -> 45 train / 5 test",Figshare datasets.zip slice,ready,"Same CSV as iupacsm" canonic,demo/datasets/CRLLM-PubChem-compounds1M.sample.csv,50,"loader split -> 45 train / 5 test",Figshare datasets.zip slice,ready,"Uses SMILES and SMILES_variant1 columns" -canonmc,demo/datasets/CRLLM-PubChem-compounds1M.sample.csv,50,"loader split -> 45 train / 5 test",Figshare datasets.zip slice,ready,"Uses canonical plus variant columns from the main 50-row PubChem slice" +canonmc,demo/datasets/CRLLM-PubChem-compounds1M.sample.csv,50,"~45 train / ~5 test",Figshare datasets.zip slice,ready,"Uses canonical plus variant columns from the main 50-row PubChem slice" smi_permute,demo/datasets/CRLLM-PubChem-compounds1M-very_very_simple.sample.csv,50,"loader split -> 45 train / 5 test",Figshare datasets.zip slice,ready,"50-row CSV extracted from the released PubChem bundle" smhydrogen,demo/datasets/CRLLM-PubChem-compounds1M_hydrogen.sample.csv,50,"loader expansion -> 90 train / 10 test",Figshare datasets.zip slice,ready,"Each input row yields add-H and remove-H examples" kinetic,demo/kinetic_tiny,50,"40 train / 10 validation",repo-local synthetic fixture,ready,"Matches the pickle layout expected by the kinetic loader" @@ -12,3 +12,4 @@ rxn_inversion,demo/datasets/rxn_inversion_sample.csv,50,"~45 train / ~5 test",re 
rxn_replacement,demo/datasets/rxn_inversion_sample.csv,50,"~45 train / ~5 test",repo-local fixture,ready,"Same schema as rxn_inversion (MCQ with 4 options)" rxn_naming,demo/datasets/rxn_naming_sample.csv,50,"~45 train / ~5 test",repo-local fixture,ready,"Reaction classification into 10 named categories" rxn_truefalse,demo/datasets/rxn_truefalse_sample.csv,50,"~45 train / ~5 test",repo-local fixture,ready,"Binary true/false reaction validity" +crystalrelax,demo/crystalrelax_tiny,50,"40 train / 10 test",repo-local fixture,ready,"M2S format binary compound structures (perturbed src / relaxed tgt)" diff --git a/demo/run_fixture_smoke.py b/demo/run_fixture_smoke.py index 4316f664..7d2b63c5 100644 --- a/demo/run_fixture_smoke.py +++ b/demo/run_fixture_smoke.py @@ -43,6 +43,10 @@ def main(): "rxn_truefalse": datasets_dir / "rxn_truefalse_sample.csv", } + # Tasks with heavy optional dependencies — only test if registered + if "crystalrelax" in CHEMTASKS: + task_configs["crystalrelax"] = demo_dir / "crystalrelax_tiny" + summary = {} for task_name, dataset_path in task_configs.items(): task_class = CHEMTASKS[task_name] diff --git a/docs/source/modules.rst b/docs/source/modules.rst index e9a2c264..10eb39b8 100644 --- a/docs/source/modules.rst +++ b/docs/source/modules.rst @@ -16,5 +16,6 @@ Modules Reference tasks/rxn_replacement tasks/rxn_naming tasks/rxn_truefalse + tasks/crystalrelax tasks/template tasks/smi_permute \ No newline at end of file diff --git a/docs/source/tasks/_static/structure_relaxing_result.png b/docs/source/tasks/_static/structure_relaxing_result.png new file mode 100644 index 00000000..b6b94591 Binary files /dev/null and b/docs/source/tasks/_static/structure_relaxing_result.png differ diff --git a/docs/source/tasks/_static/structure_relaxing_success_rate.png b/docs/source/tasks/_static/structure_relaxing_success_rate.png new file mode 100644 index 00000000..e12c3d02 Binary files /dev/null and b/docs/source/tasks/_static/structure_relaxing_success_rate.png 
differ diff --git a/docs/source/tasks/crystalrelax.rst b/docs/source/tasks/crystalrelax.rst new file mode 100644 index 00000000..9e882b81 --- /dev/null +++ b/docs/source/tasks/crystalrelax.rst @@ -0,0 +1,112 @@ +Crystal Relaxing +=================== + +.. currentmodule:: open_r1.tasks.crystal_structure.relaxing + +BinaryCompoundRelaxing +---------------------- + +.. autoclass:: BinaryCompoundRelaxing + :members: + :show-inheritance: + +Task Description +---------------- + +The `BinaryCompoundRelaxing` task guides a language model through multiple steps of structural relaxation on perturbed binary compounds. Given a serialized CIF description of a compound, the model must iteratively propose adjustments to reduce the internal energy, documenting its reasoning within ``<think>`` tags and outputting a final relaxed structure within ``<answer>`` tags. + +Features +-------- + +- Uses an m2s-style serialized CIF representation as the task input and expected output format +- Prompts the model to provide crystallographic reasoning inside ``<think>`` tags and the final relaxed structure inside ``<answer>`` tags +- Evaluates predictions with structure deserialization, CIF validity checks, and an energy-based reward computed from the generated and reference structures + +Usage Example +------------- + +.. 
code-block:: python + + from open_r1.tasks.crystal_structure.relaxing import BinaryCompoundRelaxing + + # Initialize the task, pointing to a local dataset directory + task = BinaryCompoundRelaxing(dataset_id_or_path="/path/to/cif_data") + + # Load datasets + dataset = task.load() + train_ds = dataset["train"] + test_ds = dataset["test"] + + # Compute accuracy rewards for an example prediction + completions = ["M2S serialized_cif …"] + solutions = ["M2S serialized_cif …"] + rewards = task.accuracy_reward(completions, solutions) + + +Data Format +----------- + +The task reads paired text files with multi-line CIF records separated by blank lines: + +- `src-train.txt / src-test.txt`: Each record is a serialized CIF string of a perturbed binary structure. +- `tgt-train.txt / tgt-test.txt`: Each record is the ground-truth CIF string after DFT relaxation. + +Dataset +------- + +The CrystalRelax dataset is published on figshare: + +- DOI: https://doi.org/10.6084/m9.figshare.31948860 +- Download URL: https://figshare.com/ndownloader/articles/31948860/versions/1 + +Task +---------------- +Generate a structure with lower internal energy. + +Base model +---------------- +`Qwen/Qwen2.5-3B-Instruct`, fine-tuned on the MPtraj dataset via supervised fine-tuning (SFT). + +Reward Functions +---------------- + +1. **Accuracy Reward (accuracy_reward)** + - The completion is first reduced to the last ``<answer>...</answer>`` block. If no answer block is found, the reward is ``0``. + - The answer is deserialized with ``CIFTokenizer.deserialize(answer, solution)`` and then checked as a CIF structure with ``gemmi`` and ``pymatgen``. Any deserialization, CIF validation, or structure parsing failure returns ``0``. + - Valid structures are scored by comparing per-atom MACE potential energy between the reference solution CIF and the generated answer CIF: + - ``1``: the generated answer has lower per-atom potential energy than the reference solution.
+ - ``0.5``: the reference solution has lower per-atom potential energy than the generated answer. + - ``0``: both structures have equal per-atom potential energy, or the answer exactly matches the solution text before energy scoring. + +2. **Format Reward (format_reward)** + - The completion is expected to contain ``<think>...</think>`` followed by ``<answer>...</answer>``. If this tag pattern cannot be found, the reward is ``0.0``. + - If the completion does not start with ``<think>``, the reward function prepends that opening tag before matching. + - Once the tags are present, the reward is a keyword-based bonus score up to ``1.0``: + - ``+0.2`` for math-like reasoning patterns. + - ``+0.1`` each for CIF/crystallographic terms, position or coordinate terms, recognized space groups, crystallographic concepts, energy or force terms, dynamical stability terms, and structure-classification terms. + - ``+0.05`` each for lattice-angle terms and chemistry terms. + +Task Example +------------ + +.. code-block:: text + + Input: unstable Crystal structure [M2S format] + Output: relaxed Crystal structure [M2S format] + +.. image:: _static/structure_relaxing_result.png + :width: 400 + :align: center + :alt: accuracy and response length matrices + +Around step 50, the model experiences a pivotal shift (the “aha moment”) where it transitions from moderate performance gains to a pronounced acceleration in accuracy. By the final stages of training, the model achieves a 91% success rate in generating lower-energy structures, underscoring the effectiveness of the learning process after redesigning the experiments. +We selected the checkpoint at step 130 for evaluation on a larger, external test set. Among 471 binary crystal structures, the model achieved a success rate of 81% in generating structures with lower internal energy. + +..
image:: _static/structure_relaxing_success_rate.png + :width: 400 + :align: center + :alt: success rate of the model in generating structures with lower internal energy + +Motivation +---------------- +Crystal structure relaxation is usually the fundamental first step in materials science, establishing the foundation for the development of real-world applications. Traditional crystal structure relaxation algorithms are computationally expensive and scale poorly with system size, due to the intensive computational demands of their iterative procedures. To accelerate this process, we introduce a large language model (LLM) trained via Group Relative Policy Optimisation (GRPO), which is designed to rapidly reduce the crystal’s internal energy and converge toward a DFT-relaxed configuration. By pre-relaxing structures, the DFT method can complete the residual relaxation steps in fewer iterations, significantly reducing the time required for the relaxation process, especially for complex structures. Our model is first trained with supervised fine-tuning (SFT) on the Materials Project Trajectory dataset, which comprises multiple intermediate relaxation frames, and is subsequently optimised with GRPO reinforcement learning on perturbed structures. Evaluation on a binary compound validation set shows a 91% success rate, defined as generating structures with lower internal energy than their inputs. These results validate the effectiveness of our current approach and suggest feasibility for further scaling to more complex material systems.
diff --git a/recipes/crystalrelax.yaml b/recipes/crystalrelax.yaml new file mode 100644 index 00000000..2721d4ff --- /dev/null +++ b/recipes/crystalrelax.yaml @@ -0,0 +1,49 @@ +# Model arguments +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true + +# Chemical Task arguments +chem_task: crystalrelax +dataset_id_or_path: ${MIST_DATA_DIR}/binary_compound_relaxing +rewards: +- accuracy + +# LoRA Arguments +# No LoRA is used here + +# Training arguments +max_steps: 1450 +per_device_train_batch_size: 2 +gradient_accumulation_steps: 8 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +learning_rate: 5.0e-7 # 1.0e-6 in the DeepSeekMath paper; 5.0e-7 taken from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05 +lr_scheduler_type: cosine +warmup_ratio: 0.03 +# GRPO specific parameters +beta: 0.001 # 0.04 in the DeepSeekMath paper; 0.001 taken from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05 +max_prompt_length: 600 +max_completion_length: 2200 +num_generations: 16 +use_vllm: true +vllm_device: "cuda:3" +vllm_gpu_memory_utilization: 0.7 +vllm_max_model_len: 3000 + +# Logging arguments +logging_strategy: steps +logging_steps: 2 +report_to: +- wandb + +save_strategy: "steps" +save_steps: 25 +seed: 42 + +# Hugging Face Hub +push_to_hub: false + # hub_model_id: llama-3-1-8b-math-orca-qlora-10k-ep1 # if not defined same as output_dir diff --git a/setup.py b/setup.py index 5c34488c..c64b71db 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ # * If a dependency is fast-moving (e.g.
transformers), pin to the exact version _deps = [ "accelerate>=1.2.1", + "ase", "bitsandbytes>=0.43.0", "black>=24.4.2", "datasets>=3.2.0", @@ -48,17 +49,21 @@ "distilabel[vllm,ray,openai]>=1.5.2", "einops>=0.8.0", "flake8>=6.0.0", + "gemmi", "hf_transfer>=0.1.4", "huggingface-hub[cli]>=0.19.2,<1.0", "isort>=5.12.0", "liger_kernel==0.5.2", "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]", + "mace-torch", "math-verify>=0.3.2", # Used for math verification in grpo "packaging>=23.0", "parameterized>=0.9.0", + "pymatgen", "pytest", "safetensors>=0.3.3", "sentencepiece>=0.1.99", + "smact", "torch>=2.5.1", "transformers @ git+https://github.com/huggingface/transformers.git@main", "trl @ git+https://github.com/huggingface/trl.git@main", @@ -89,16 +94,21 @@ def deps_list(*pkgs): # core dependencies shared across the whole project - keep this to a bare minimum :) install_requires = [ deps["accelerate"], + deps["ase"], deps["bitsandbytes"], deps["einops"], deps["datasets"], deps["deepspeed"], + deps["gemmi"], deps["hf_transfer"], deps["huggingface-hub"], deps["liger_kernel"], + deps["mace-torch"], + deps["pymatgen"], deps["packaging"], # utilities from PyPA to e.g., compare versions deps["safetensors"], deps["sentencepiece"], + deps["smact"], deps["transformers"], deps["trl"], ] @@ -115,6 +125,11 @@ def deps_list(*pkgs): url="https://github.com/schwallergroup/mist", package_dir={"": "src"}, packages=find_packages("src"), + package_data={ + "open_r1.dataset": ["spacegroups.txt"], + "open_r1.tasks.smiles_understanding": ["smiles_hydrogen_prompt_guiding.json"], + "open_r1.tasks.crystal_structure.AIRS_preprocess": ["spacegroups.txt"], + }, zip_safe=False, extras_require=extras, python_requires=">=3.10.9", diff --git a/src/open_r1/dataset/__init__.py b/src/open_r1/dataset/__init__.py new file mode 100644 index 00000000..798a05e5 --- /dev/null +++ b/src/open_r1/dataset/__init__.py @@ -0,0 +1 @@ +"""Dataset 
helpers for Open R1.""" diff --git a/src/open_r1/dataset/crystal_structure_relaxing.py b/src/open_r1/dataset/crystal_structure_relaxing.py new file mode 100644 index 00000000..ebce41c1 --- /dev/null +++ b/src/open_r1/dataset/crystal_structure_relaxing.py @@ -0,0 +1,444 @@ +import argparse +import os +import random +import re + +import pandas as pd +from torch.utils.data import Dataset +from tqdm import tqdm + +from verl.utils.hdfs_io import copy, makedirs + +# This file contains code adapted from the AIRS project: +# https://github.com/divelab/AIRS/blob/main/OpenMat/Mat2Seq/mat2seq/_tokenizer.py +# +# Copyright (c) 2023 Luis M. Antunes +# Licensed under the MIT License. +# +# The CIFTokenizer class and related utilities are reused and modified here +# for downstream crystallographic tasks. +# +# Modifications by Ruizhi Xu, 2025 + + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +with open(os.path.join(THIS_DIR, "spacegroups.txt"), "rt") as f: + SPACE_GROUPS = [sg.strip() for sg in f.readlines()] + + +ATOMS = [ + "Si", + "C", + "Pb", + "I", + "Br", + "Cl", + "Eu", + "O", + "Fe", + "Sb", + "In", + "S", + "N", + "U", + "Mn", + "Lu", + "Se", + "Tl", + "Hf", + "Ir", + "Ca", + "Ta", + "Cr", + "K", + "Pm", + "Mg", + "Zn", + "Cu", + "Sn", + "Ti", + "B", + "W", + "P", + "H", + "Pd", + "As", + "Co", + "Np", + "Tc", + "Hg", + "Pu", + "Al", + "Tm", + "Tb", + "Ho", + "Nb", + "Ge", + "Zr", + "Cd", + "V", + "Sr", + "Ni", + "Rh", + "Th", + "Na", + "Ru", + "La", + "Re", + "Y", + "Er", + "Ce", + "Pt", + "Ga", + "Li", + "Cs", + "F", + "Ba", + "Te", + "Mo", + "Gd", + "Pr", + "Bi", + "Sc", + "Ag", + "Rb", + "Dy", + "Yb", + "Nd", + "Au", + "Os", + "Pa", + "Sm", + "Be", + "Ac", + "Xe", + "Kr", + "He", + "Ne", + "Ar", +] + +DIGITS = [str(d) for d in list(range(10))] + +INTS = [str(d) for d in list(range(300))] + +KEYWORDS = ["space_group_symbol", "formula", "atoms", "lattice_parameters", "a", "b", "c", "alpha", "beta", "gamma"] + +UNK_TOKEN = "" + + +def 
get_spacegroup_number(sg_symbol): + try: + from pymatgen.symmetry.groups import SpaceGroup + + sg = SpaceGroup(sg_symbol) + return sg + except Exception as e: + print("Err:", e) + return None + + +def parse_formula(formula): + formula = formula.replace("'", "").replace('"', "").strip() + pattern = r"([A-Z][a-z]*)(\d*)" + counts = {} + for element, count in re.findall(pattern, formula): + counts[element] = counts.get(element, 0) + (int(count) if count else 1) + return counts + + +def compute_cell_formula_units_Z(formula_sum, formula_structural): + counts_sum = parse_formula(formula_sum) + counts_struct = parse_formula(formula_structural) + + ratios = [] + for element, count_struct in counts_struct.items(): + if element not in counts_sum: + raise ValueError(f"{element}") + ratio = counts_sum[element] / count_struct + if ratio != int(ratio): + raise ValueError(f"{element}, {ratio} not int") + ratios.append(int(ratio)) + + if len(set(ratios)) != 1: + raise ValueError(f"{ratios} != 1") + return ratios[0] + + +class CIFTokenizer: + def __init__(self): + self._tokens = [""] + self._tokens.extend(self.atoms()) + self._tokens.extend(self.digits()) + self._tokens.extend(self.keywords()) + self._tokens.extend(self.symbols()) + + space_groups = list(self.space_groups()) + # Replace 'Pm' space group with 'Pm_sg' to disambiguate from atom 'Pm', + # or 'P1' with 'P1_sg' to disambiguate from atom 'P' and number '1' + space_groups_sg = [sg + "_sg" for sg in space_groups] + self._tokens.extend(space_groups_sg) + + digits_int = [v + "_int" for v in INTS] + self._tokens.extend(digits_int) + + self._escaped_tokens = [re.escape(token) for token in self._tokens] + self._escaped_tokens.sort(key=len, reverse=True) + + # a mapping from characters to integers + self._token_to_id = {ch: i for i, ch in enumerate(self._tokens)} + self._id_to_token = {i: ch for i, ch in enumerate(self._tokens)} + # map the id of 'Pm_sg' back to 'Pm', or 'P1_sg' to 'P1', + # for decoding convenience + for sg in 
space_groups_sg: + self._id_to_token[self.token_to_id[sg]] = sg.replace("_sg", "") + + for v_int in digits_int: + self._id_to_token[self.token_to_id[v_int]] = v_int.replace("_int", "") + + @staticmethod + def atoms(): + return ATOMS + + @staticmethod + def digits(): + return DIGITS + + @staticmethod + def keywords(): + kws = list(KEYWORDS) + return kws + + @staticmethod + def symbols(): + # return ["x", "y", "z", ".", "(", ")", "+", "-", "/", "'", ",", " ", "\n"] + return [",", " ", ":", ".", "\n"] + + @staticmethod + def space_groups(): + return SPACE_GROUPS + + @property + def token_to_id(self): + return dict(self._token_to_id) + + @property + def id_to_token(self): + return dict(self._id_to_token) + + def encode(self, tokens): + # encoder: take a list of tokens, output a list of integers + return [self._token_to_id[t] for t in tokens] + + def decode(self, ids): + # decoder: take a list of integers (i.e. encoded tokens), output a string + return "".join([self._id_to_token[i] for i in ids]) + + def serialize(self, cif_string): + spacegroups = "|".join(SPACE_GROUPS) + cif_string = re.sub(rf"(_symmetry_space_group_name_H-M *\b({spacegroups}))\n", r"\1_sg\n", cif_string) + extracted_data = self.tokenize_cif_preprocess(cif_string) + + seq_res = "" + # formula + seq_res += "formula " + formula = extracted_data["formula"] + elements_counts = re.findall(r"([A-Z][a-z]*)(\d*)", formula) + for element, count in elements_counts: + if not element: + break + if not count: + count = "1" + seq_res += element + " " + count + "_int " + seq_res += "\n" + # space group name + seq_res += "space_group_symbol " + extracted_data["space_group_symbol"] + "\n" + # lattice + seq_res += "lattice_parameters " + "a " + extracted_data["lattice_parameters"]["a"] + " " + seq_res += "b " + extracted_data["lattice_parameters"]["b"] + " " + seq_res += "c " + extracted_data["lattice_parameters"]["c"] + " " + seq_res += "alpha " + extracted_data["lattice_parameters"]["alpha"] + " " + seq_res += "beta 
" + extracted_data["lattice_parameters"]["beta"] + " " + seq_res += "gamma " + extracted_data["lattice_parameters"]["gamma"] + " " + seq_res += "\n" + # atoms + for idx in range(len(extracted_data["atoms"])): + tmp = extracted_data["atoms"][idx] + seq_res += ( + tmp["type"] + + " " + + tmp["num"] + + "_int " + + tmp["coordinates"][0] + + " " + + tmp["coordinates"][1] + + " " + + tmp["coordinates"][2] + + "\n" + ) + seq_res += "\n" + # Create a regex pattern by joining the escaped tokens with '|' + token_pattern = "|".join(self._escaped_tokens) + # Add a regex pattern to match any sequence of characters separated by whitespace or punctuation + full_pattern = f"({token_pattern}|\\w+|[\\.,;!?])" + # Tokenize the input string using the regex pattern + seq_res = re.sub(r"[ \t]+", " ", seq_res) + return seq_res + + def tokenize_cif_preprocess(self, cif_string): + # Re-initialize the dictionary to hold the extracted data + extracted_data = {"space_group_symbol": "", "formula": "", "atoms": [], "lattice_parameters": {}} + + # Split the text into lines for processing + lines = cif_string.split("\n") + + # Iterate through each line to extract the required information + atom_line_idx = -1 + for line_idx in range(len(lines)): + line = lines[line_idx] + # Extract space group symbol + if "_symmetry_space_group_name_H-M" in line: + spacegroup_match = re.search(r"_symmetry_space_group_name_H-M\s+([^\n]+)", line) + spacegroup = spacegroup_match.group(1).strip() + extracted_data["space_group_symbol"] = spacegroup + # Extract formula + elif line.startswith("data_"): + extracted_data["formula"] = line.split("_")[1] + # Extract lattice parameters + elif line.startswith("_cell_length_a"): + extracted_data["lattice_parameters"]["a"] = line.split()[-1] + elif line.startswith("_cell_length_b"): + extracted_data["lattice_parameters"]["b"] = line.split()[-1] + elif line.startswith("_cell_length_c"): + extracted_data["lattice_parameters"]["c"] = line.split()[-1] + elif 
line.startswith("_cell_angle_alpha"): + extracted_data["lattice_parameters"]["alpha"] = line.split()[-1] + elif line.startswith("_cell_angle_beta"): + extracted_data["lattice_parameters"]["beta"] = line.split()[-1] + elif line.startswith("_cell_angle_gamma"): + extracted_data["lattice_parameters"]["gamma"] = line.split()[-1] + elif "_atom_site_occupancy" in line: + atom_line_idx = line_idx + 1 + break + + for line_idx in range(atom_line_idx, len(lines)): + line = lines[line_idx] + if len(line) < 2: + continue + atom_info = line.split() + atom_type = atom_info[0] + num_atoms = atom_info[2] + x, y, z = atom_info[3], atom_info[4], atom_info[5] + extracted_data["atoms"].append({"type": atom_type, "num": num_atoms, "coordinates": (x, y, z)}) + + return extracted_data + + +# Initialize the tokenizer +cif_tokenizer = CIFTokenizer() + + +def load_cif_dataset(binary_dir: str, perturbed_dir: str, size: int, local_dir: str) -> list: + """ + Load the dataset: + - Read a parquet dataframe from local_dir (assumed to be "perturbed_df_cif.parquet"). + - Load the ground truth CIF file and the perturbed CIF file based on material_id. + - Serialize the loaded content using cif_tokenizer.serialize. 
+ """ + parquet_path = os.path.join(local_dir, "perturbed_df_cif.parquet") + df = pd.read_parquet(parquet_path) + df = df.head(size) + samples = [] + + for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading CIF dataset"): + material_id = row["material_id"] + # Extract the ground truth file name from material_id + ground_truth_material_id = material_id.split("_random_")[0] + gt_file = os.path.join(binary_dir, f"{ground_truth_material_id}.cif") + try: + with open(gt_file, "r", encoding="utf-8") as f: + gt_content = f.read() + except Exception as e: + print(f"Error reading ground truth file {gt_file}: {e}") + continue + + perturbed_file = os.path.join(perturbed_dir, f"{material_id}.cif") + try: + with open(perturbed_file, "r", encoding="utf-8") as f: + perturbed_content = f.read() + except Exception as e: + print(f"Error reading perturbed file {perturbed_file}: {e}") + continue + + # Note: Both ground truth and perturbed content are serialized here. + sample = { + "compound_id": material_id, + "ground_truth": cif_tokenizer.serialize(gt_content), + "perturbed": cif_tokenizer.serialize(perturbed_content), + } + print("material_id:", material_id) + print("ground_truth:", cif_tokenizer.serialize(gt_content)) + print("perturbed:", cif_tokenizer.serialize(perturbed_content)) + samples.append(sample) + + print(f"{len(samples)} samples loaded") + return samples + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--local_dir", default="./binary_compounds_dataset", help="Directory where the dataset is saved locally" + ) + parser.add_argument("--hdfs_dir", default=None, help="HDFS directory (optional)") + parser.add_argument("--train_size", type=int, default=500, help="Number of training set samples") + parser.add_argument("--test_size", type=int, default=100, help="Number of test set samples") + args = parser.parse_args() + + # Construct the paths for CIF files + binary_dir = os.path.join(args.local_dir, 
"binary_compounds_cifs") + perturbed_dir = os.path.join(args.local_dir, "perturbed_binary_compounds_cifs") + + samples = load_cif_dataset(binary_dir, perturbed_dir, args.train_size + args.test_size, args.local_dir) + random.shuffle(samples) + total_samples = len(samples) + print(f"Total number of samples loaded: {total_samples}") + + # If the number of samples is insufficient, use all samples as the training set. + if total_samples < (args.train_size + args.test_size): + print("Warning: Insufficient samples, will use all samples as training set.") + train_samples = samples + test_samples = [] + else: + train_samples = samples[: args.train_size] + test_samples = samples[args.train_size : args.train_size + args.test_size] + + # Construct the file paths for generating the dataset required by BinaryCompoundRelaxing. + src_train_path = os.path.join(args.local_dir, "src-train.txt") + tgt_train_path = os.path.join(args.local_dir, "tgt-train.txt") + src_test_path = os.path.join(args.local_dir, "src-test.txt") + tgt_test_path = os.path.join(args.local_dir, "tgt-test.txt") + + # Write the training set text: each line of the question uses the 'perturbed' field, and the corresponding answer uses the 'ground_truth' field. + with open(src_train_path, "w", encoding="utf-8") as f_src, open(tgt_train_path, "w", encoding="utf-8") as f_tgt: + for sample in train_samples: + f_src.write(sample["perturbed"] + "\n") + f_tgt.write(sample["ground_truth"] + "\n") + + # Write the test set text files. + if test_samples: + with open(src_test_path, "w", encoding="utf-8") as f_src, open(tgt_test_path, "w", encoding="utf-8") as f_tgt: + for sample in test_samples: + f_src.write(sample["perturbed"] + "\n") + f_tgt.write(sample["ground_truth"] + "\n") + else: + # If the test set is empty, create empty files. + open(src_test_path, "w", encoding="utf-8").close() + open(tgt_test_path, "w", encoding="utf-8").close() + + # If an HDFS directory is specified, copy the local_dir to HDFS. 
+ if args.hdfs_dir is not None: + makedirs(args.hdfs_dir) + copy(src=args.local_dir, dst=args.hdfs_dir) diff --git a/src/open_r1/dataset/spacegroups.txt b/src/open_r1/dataset/spacegroups.txt new file mode 100644 index 00000000..767fcd16 --- /dev/null +++ b/src/open_r1/dataset/spacegroups.txt @@ -0,0 +1,227 @@ +P6/mmm +Imma +P4_32_12 +P4_2/mnm +Fd-3m +P3m1 +P-3 +P4mm +P4_332 +P4/nnc +P2_12_12 +Pnn2 +Pbcn +P4_2/n +Cm +R3m +Cmce +Aea2 +P-42_1m +P-42m +P2_13 +R-3 +Fm-3 +Cmm2 +Pn-3n +P6/mcc +P-6m2 +P3_2 +P-3m1 +P3_212 +I23 +P-62m +P4_2nm +Pma2 +Pmma +I-42m +P-31c +Pa-3 +Pmmn +Pmmm +P4_2/ncm +I4/mcm +I-4m2 +P3_1 +Pcc2 +Cmcm +I222 +Fddd +P312 +Cccm +P6_1 +F-43c +P6_322 +Pm-3 +P3_121 +P6_4 +Ia-3d +Pm-3m +P2_1/c +C222_1 +Pc +P4/n +Pba2 +Ama2 +Pbcm +P31m +Pcca +P222 +P-43n +Pccm +P6_422 +F23 +P42_12 +C222 +Pnnn +P6_3cm +P4_12_12 +P6/m +Fmm2 +I4_1/a +P4/mbm +Pmn2_1 +P4_2bc +P4_22_12 +I-43d +I4/m +P4bm +Fdd2 +P3 +P6_122 +Pnc2 +P4_2/mcm +P4_122 +Cmc2_1 +P-6c2 +R32 +P4_1 +P4_232 +Pnna +P422 +Pban +Cc +I4_122 +P6_3/m +P6_3mc +I4_1/amd +P4_2 +P4/nmm +Pmna +P4/m +Fm-3m +P4/mmm +Imm2 +P4/ncc +P-62c +Ima2 +P6_5 +P2/c +P4/nbm +Ibam +P6_522 +P6_3/mmc +I4/mmm +Fmmm +P2/m +P-4b2 +I-4 +C2/m +P4_2/mmc +P4 +Fd-3c +P4_3 +P2_1/m +I-43m +P-42c +F4_132 +Pm +Pccn +P-4n2 +P4_132 +P23 +I4cm +R3c +Amm2 +Immm +Iba2 +I4 +Fd-3 +P1 +Pbam +P4_2/nbc +Im-3 +P4_2/nnm +Pmc2_1 +P-31m +R-3m +Ia-3 +P622 +F222 +P2 +P-1 +Pmm2 +P-4 +Aem2 +P6_222 +P-3c1 +P4_322 +I422 +Pnma +P6_3 +P3c1 +Pn-3 +P4nc +P-6 +P4/mcc +I2_12_12_1 +P4_2/mbc +P31c +Ccc2 +P4_2/nmc +P6_3/mcm +C2 +Pbca +P-4c2 +I4_1cd +P2_1 +P3_112 +P4_2mc +Pn-3m +C2/c +R3 +P-43m +I432 +P222_1 +I-42d +I-4c2 +P6cc +P6_2 +P3_221 +P321 +Pca2_1 +I4_1/acd +I4_132 +F432 +Pna2_1 +Ccce +Ibca +P4/mnc +I4_1md +P2_12_12_1 +R-3c +I2_13 +P-4m2 +Pm-3n +I4mm +F-43m +Pnnm +P-42_1c +Cmmm +P6mm +P4_2cm +P4_2/m +Im-3m +Fm-3c +I4_1 +P4cc +Cmme diff --git a/src/open_r1/tasks/__init__.py b/src/open_r1/tasks/__init__.py index 8812ce63..a4b9d6ec 100644 --- 
a/src/open_r1/tasks/__init__.py +++ b/src/open_r1/tasks/__init__.py @@ -26,3 +26,10 @@ "rxn_naming": Smiles2Name, "rxn_truefalse": ReactionTrueFalse, } + +try: + from .crystal_structure.relaxing import BinaryCompoundRelaxing + + CHEMTASKS["crystalrelax"] = BinaryCompoundRelaxing +except ImportError: + pass diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preprocess/LICENSE b/src/open_r1/tasks/crystal_structure/AIRS_preprocess/LICENSE new file mode 100644 index 00000000..b552f185 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/AIRS_preprocess/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Luis M. Antunes + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preprocess/__init__.py b/src/open_r1/tasks/crystal_structure/AIRS_preprocess/__init__.py new file mode 100644 index 00000000..2eaaa4c3 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/AIRS_preprocess/__init__.py @@ -0,0 +1 @@ +"""AIRS preprocessing utilities for crystal structure tasks.""" diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preprocess/_tokenizer.py b/src/open_r1/tasks/crystal_structure/AIRS_preprocess/_tokenizer.py new file mode 100644 index 00000000..d0c6899f --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/AIRS_preprocess/_tokenizer.py @@ -0,0 +1,594 @@ +import math +import os +import re + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +with open(os.path.join(THIS_DIR, "spacegroups.txt"), "rt") as f: + SPACE_GROUPS = [sg.strip() for sg in f.readlines()] + + +ATOMS = [ + "Si", + "C", + "Pb", + "I", + "Br", + "Cl", + "Eu", + "O", + "Fe", + "Sb", + "In", + "S", + "N", + "U", + "Mn", + "Lu", + "Se", + "Tl", + "Hf", + "Ir", + "Ca", + "Ta", + "Cr", + "K", + "Pm", + "Mg", + "Zn", + "Cu", + "Sn", + "Ti", + "B", + "W", + "P", + "H", + "Pd", + "As", + "Co", + "Np", + "Tc", + "Hg", + "Pu", + "Al", + "Tm", + "Tb", + "Ho", + "Nb", + "Ge", + "Zr", + "Cd", + "V", + "Sr", + "Ni", + "Rh", + "Th", + "Na", + "Ru", + "La", + "Re", + "Y", + "Er", + "Ce", + "Pt", + "Ga", + "Li", + "Cs", + "F", + "Ba", + "Te", + "Mo", + "Gd", + "Pr", + "Bi", + "Sc", + "Ag", + "Rb", + "Dy", + "Yb", + "Nd", + "Au", + "Os", + "Pa", + "Sm", + "Be", + "Ac", + "Xe", + "Kr", + "He", + "Ne", + "Ar", +] + +DIGITS = [str(d) for d in list(range(10))] + +INTS = [str(d) for d in list(range(300))] + +KEYWORDS = ["space_group_symbol", "formula", "atoms", "lattice_parameters", "a", "b", "c", "alpha", "beta", "gamma"] + +UNK_TOKEN = "" + + +def get_spacegroup_number(sg_symbol): + try: + from pymatgen.symmetry.groups import SpaceGroup + + sg = SpaceGroup(sg_symbol) + return sg + except Exception 
as e: + print("Err:", e) + return None + + +def parse_formula(formula): + formula = formula.replace("'", "").replace('"', "").strip() + pattern = r"([A-Z][a-z]*)(\d*)" + counts = {} + for element, count in re.findall(pattern, formula): + counts[element] = counts.get(element, 0) + (int(count) if count else 1) + return counts + + +def compute_cell_formula_units_Z(formula_sum, formula_structural): + counts_sum = parse_formula(formula_sum) + counts_struct = parse_formula(formula_structural) + + ratios = [] + for element, count_struct in counts_struct.items(): + if element not in counts_sum: + raise ValueError(f"{element}") + ratio = counts_sum[element] / count_struct + if ratio != int(ratio): + raise ValueError(f"{element}, {ratio} not int") + ratios.append(int(ratio)) + + if len(set(ratios)) != 1: + raise ValueError(f"{ratios} != 1") + return ratios[0] + + +class CIFTokenizer: + def __init__(self): + self._tokens = [""] + self._tokens.extend(self.atoms()) + self._tokens.extend(self.digits()) + self._tokens.extend(self.keywords()) + self._tokens.extend(self.symbols()) + + space_groups = list(self.space_groups()) + # Replace 'Pm' space group with 'Pm_sg' to disambiguate from atom 'Pm', + # or 'P1' with 'P1_sg' to disambiguate from atom 'P' and number '1' + space_groups_sg = [sg + "_sg" for sg in space_groups] + self._tokens.extend(space_groups_sg) + + digits_int = [v + "_int" for v in INTS] + self._tokens.extend(digits_int) + + self._escaped_tokens = [re.escape(token) for token in self._tokens] + self._escaped_tokens.sort(key=len, reverse=True) + + # a mapping from characters to integers + self._token_to_id = {ch: i for i, ch in enumerate(self._tokens)} + self._id_to_token = {i: ch for i, ch in enumerate(self._tokens)} + # map the id of 'Pm_sg' back to 'Pm', or 'P1_sg' to 'P1', + # for decoding convenience + for sg in space_groups_sg: + self._id_to_token[self.token_to_id[sg]] = sg.replace("_sg", "") + + for v_int in digits_int: + 
self._id_to_token[self.token_to_id[v_int]] = v_int.replace("_int", "") + + @staticmethod + def atoms(): + return ATOMS + + @staticmethod + def digits(): + return DIGITS + + @staticmethod + def keywords(): + kws = list(KEYWORDS) + return kws + + @staticmethod + def symbols(): + # return ["x", "y", "z", ".", "(", ")", "+", "-", "/", "'", ",", " ", "\n"] + return [",", " ", ":", ".", "\n"] + + @staticmethod + def space_groups(): + return SPACE_GROUPS + + @property + def token_to_id(self): + return dict(self._token_to_id) + + @property + def id_to_token(self): + return dict(self._id_to_token) + + def prompt_tokenize(self, cif): + token_pattern = "|".join(self._escaped_tokens) + # Add a regex pattern to match any sequence of characters separated by whitespace or punctuation + full_pattern = f"({token_pattern}|\\w+|[\\.,;!?])" + # Tokenize the input string using the regex pattern + cif = re.sub(r"[ \t]+", " ", cif) + tokens = re.findall(full_pattern, cif) + return tokens + + def encode(self, tokens): + # encoder: take a list of tokens, output a list of integers + return [self._token_to_id[t] for t in tokens] + + def decode(self, ids): + # decoder: take a list of integers (i.e. 
encoded tokens), output a string + return "".join([self._id_to_token[i] for i in ids]) + + def serialize(self, cif_string): + spacegroups = "|".join(SPACE_GROUPS) + cif_string = re.sub(rf"(_symmetry_space_group_name_H-M *\b({spacegroups}))\n", r"\1_sg\n", cif_string) + extracted_data = self.tokenize_cif_preprocess(cif_string) + + seq_res = "" + # formula + seq_res += "formula " + formula = extracted_data["formula"] + elements_counts = re.findall(r"([A-Z][a-z]*)(\d*)", formula) + for element, count in elements_counts: + if not element: + break + if not count: + count = "1" + seq_res += element + " " + count + "_int " + seq_res += "\n" + # space group name + seq_res += "space_group_symbol " + extracted_data["space_group_symbol"] + "\n" + # lattice + seq_res += "lattice_parameters " + "a " + extracted_data["lattice_parameters"]["a"] + " " + seq_res += "b " + extracted_data["lattice_parameters"]["b"] + " " + seq_res += "c " + extracted_data["lattice_parameters"]["c"] + " " + seq_res += "alpha " + extracted_data["lattice_parameters"]["alpha"] + " " + seq_res += "beta " + extracted_data["lattice_parameters"]["beta"] + " " + seq_res += "gamma " + extracted_data["lattice_parameters"]["gamma"] + " " + seq_res += "\n" + # atoms + for idx in range(len(extracted_data["atoms"])): + tmp = extracted_data["atoms"][idx] + seq_res += ( + tmp["type"] + + " " + + tmp["num"] + + "_int " + + tmp["coordinates"][0] + + " " + + tmp["coordinates"][1] + + " " + + tmp["coordinates"][2] + + "\n" + ) + seq_res += "\n" + # Create a regex pattern by joining the escaped tokens with '|' + token_pattern = "|".join(self._escaped_tokens) + # Add a regex pattern to match any sequence of characters separated by whitespace or punctuation + full_pattern = f"({token_pattern}|\\w+|[\\.,;!?])" + # Tokenize the input string using the regex pattern + seq_res = re.sub(r"[ \t]+", " ", seq_res) + return seq_res + + def deserialize(self, custom_str, ground_truth=None): + print("self", self) + print("custom_str", 
custom_str) + print("ground_truth", ground_truth) + pattern_structural = re.compile(r"_chemical_formula_structural\s+['\"]?([^\n'\"]+)['\"]?") + pattern_sum = re.compile(r"_chemical_formula_sum\s+['\"]?([^'\"]+)['\"]?") + pattern_units = re.compile(r"_cell_formula_units_Z\s+(\d+)") + + structural_match = pattern_structural.search(ground_truth) + sum_match = pattern_sum.search(ground_truth) + units_match = pattern_units.search(ground_truth) + + symmetry_equiv_pos_pattern = re.compile( + r"loop_\s*\n\s*_symmetry_equiv_pos_site_id\s*\n\s*_symmetry_equiv_pos_as_xyz\s*\n(.*?)(?:\nloop_|\Z)", + re.DOTALL, + ) + symmetry_equiv_pos_match = symmetry_equiv_pos_pattern.search(ground_truth) + if symmetry_equiv_pos_match: + sym_ops_block = symmetry_equiv_pos_match.group(1).strip() + + formula_structural = structural_match.group(1) if structural_match else None + formula_sum = sum_match.group(1) if sum_match else None + units_Z = int(units_match.group(1)) if units_match else None + print("formula_structural", formula_structural) + lines = custom_str.strip().splitlines() + data = {} + + if lines: + tokens = lines[0].split() + if tokens[0] != "formula": + raise ValueError("'formula' missing") + formula = "" + for i in range(1, len(tokens), 2): + element = tokens[i] + count_token = tokens[i + 1] if i + 1 < len(tokens) else "" + if count_token.endswith("_int"): + count = count_token[:-4] + else: + count = count_token + formula += f"{element}{count}" + data["formula"] = formula + + if len(lines) >= 2: + tokens = lines[1].split() + if tokens[0] != "space_group_symbol": + raise ValueError("'space_group_symbol' missing") + data["space_group_symbol"] = " ".join(tokens[1:]) + + if len(lines) >= 3: + tokens = lines[2].split() + if tokens[0] != "lattice_parameters": + raise ValueError("'lattice_parameters' missing") + lattice = {} + for i in range(1, len(tokens), 2): + key = tokens[i] + value = tokens[i + 1] if i + 1 < len(tokens) else "" + lattice[key] = value + data["lattice_parameters"] 
= lattice + + atoms = [] + for line in lines[3:]: + if not line.strip(): + break + tokens = line.split() + if len(tokens) < 5: + continue + atom_type = tokens[0] + num_token = tokens[1] + if num_token.endswith("_int"): + num = num_token[:-4] + else: + num = num_token + coords = tokens[2:5] + atoms.append({"type": atom_type, "num": num, "coordinates": coords}) + data["atoms"] = atoms + + cif_lines = [] + cif_lines.append(f"data_{formula_structural}") + cif_lines.append(f"_symmetry_space_group_name_H-M {data['space_group_symbol'].split('_sg')[0]}") + lattice = data["lattice_parameters"] + cif_lines.append(f"_cell_length_a {lattice.get('a', '')}") + cif_lines.append(f"_cell_length_b {lattice.get('b', '')}") + cif_lines.append(f"_cell_length_c {lattice.get('c', '')}") + cif_lines.append(f"_cell_angle_alpha {lattice.get('alpha', '')}") + cif_lines.append(f"_cell_angle_beta {lattice.get('beta', '')}") + cif_lines.append(f"_cell_angle_gamma {lattice.get('gamma', '')}") + space_group_symbol = str(get_spacegroup_number(data["space_group_symbol"].split("_sg")[0].strip("'"))) + space_group_symbol = re.search(r"number\s+(\d+)", space_group_symbol).group(1) + cif_lines.append(f"_symmetry_Int_Tables_number {space_group_symbol}") + cif_lines.append(f"_chemical_formula_structural {formula_structural}") + cif_lines.append(f"_chemical_formula_sum '{formula_sum}'") + + a = float(lattice.get("a", 0)) + b = float(lattice.get("b", 0)) + c = float(lattice.get("c", 0)) + alpha = float(lattice.get("alpha", 90)) + beta = float(lattice.get("beta", 90)) + gamma = float(lattice.get("gamma", 90)) + alpha_rad = math.radians(alpha) + beta_rad = math.radians(beta) + gamma_rad = math.radians(gamma) + + cos_alpha = math.cos(alpha_rad) + cos_beta = math.cos(beta_rad) + cos_gamma = math.cos(gamma_rad) + cell_volume = ( + a * b * c * math.sqrt(1 - cos_alpha**2 - cos_beta**2 - cos_gamma**2 + 2 * cos_alpha * cos_beta * cos_gamma) + ) + cif_lines.append(f"_cell_volume {cell_volume:.8f}") + 
cif_lines.append(f"_cell_formula_units_Z '{units_Z}'") + cif_lines.append("loop_") + cif_lines.append(" _symmetry_equiv_pos_site_id") + cif_lines.append(" _symmetry_equiv_pos_as_xyz") + cif_lines.append(f" {sym_ops_block}") + cif_lines.append("loop_") + cif_lines.append("_atom_site_type_symbol") + cif_lines.append("_atom_site_label") + cif_lines.append("_atom_site_symmetry_multiplicity") + cif_lines.append("_atom_site_fract_x") + cif_lines.append("_atom_site_fract_y") + cif_lines.append("_atom_site_fract_z") + cif_lines.append("_atom_site_occupancy") + unique_counts = {} + for atom in data["atoms"]: + label = f"{atom['type']}" + if label not in unique_counts: + unique_counts[label] = len(unique_counts) + label = label + str(unique_counts[label]) + else: + label = label + str(unique_counts[label]) + cif_lines.append( + f"{ atom['type']} {label} {atom['num']} {atom['coordinates'][0]} {atom['coordinates'][1]} {atom['coordinates'][2]} 1" + ) + cif_string_reconstructed = "\n".join(cif_lines) + return cif_string_reconstructed + + def tokenize_cif(self, cif_string, max_length=1385): + # Preprocessing step to replace '_symmetry_space_group_name_H-M Pm' + # with '_symmetry_space_group_name_H-M Pm_sg',to disambiguate from atom 'Pm', + # or any space group symbol to avoid problematic cases, like 'P1' + spacegroups = "|".join(SPACE_GROUPS) + cif_string = re.sub(rf"(_symmetry_space_group_name_H-M *\b({spacegroups}))\n", r"\1_sg\n", cif_string) + + extracted_data = self.tokenize_cif_preprocess(cif_string) + + seq_res = "" + # formula + seq_res += "formula " + formula = extracted_data["formula"] + elements_counts = re.findall(r"([A-Z][a-z]*)(\d*)", formula) + for element, count in elements_counts: + if not element: + break + if not count: + count = "1" + seq_res += element + " " + count + "_int " + seq_res += "\n" + # space group name + seq_res += "space_group_symbol " + extracted_data["space_group_symbol"] + "\n" + # lattice + seq_res += "lattice_parameters " + "a " + 
extracted_data["lattice_parameters"]["a"] + " " + seq_res += "b " + extracted_data["lattice_parameters"]["b"] + " " + seq_res += "c " + extracted_data["lattice_parameters"]["c"] + " " + seq_res += "alpha " + extracted_data["lattice_parameters"]["alpha"] + " " + seq_res += "beta " + extracted_data["lattice_parameters"]["beta"] + " " + seq_res += "gamma " + extracted_data["lattice_parameters"]["gamma"] + " " + seq_res += "\n" + # atoms + for idx in range(len(extracted_data["atoms"])): + tmp = extracted_data["atoms"][idx] + seq_res += ( + tmp["type"] + + " " + + tmp["num"] + + "_int " + + tmp["coordinates"][0] + + " " + + tmp["coordinates"][1] + + " " + + tmp["coordinates"][2] + + "\n" + ) + seq_res += "\n" + # Create a regex pattern by joining the escaped tokens with '|' + token_pattern = "|".join(self._escaped_tokens) + # Add a regex pattern to match any sequence of characters separated by whitespace or punctuation + full_pattern = f"({token_pattern}|\\w+|[\\.,;!?])" + # Tokenize the input string using the regex pattern + seq_res = re.sub(r"[ \t]+", " ", seq_res) + # print(seq_res) + tokens = re.findall(full_pattern, seq_res) + # print(tokens) + padding_length = max_length - len(tokens) + if padding_length > 0: + tokens.extend([""] * padding_length) + + return tokens + + def tokenize_cif_preprocess(self, cif_string): + # Re-initialize the dictionary to hold the extracted data + extracted_data = {"space_group_symbol": "", "formula": "", "atoms": [], "lattice_parameters": {}} + + # Split the text into lines for processing + lines = cif_string.split("\n") + + # Iterate through each line to extract the required information + atom_line_idx = -1 + for line_idx in range(len(lines)): + line = lines[line_idx] + # Extract space group symbol + if "_symmetry_space_group_name_H-M" in line: + spacegroup_match = re.search(r"_symmetry_space_group_name_H-M\s+([^\n]+)", line) + spacegroup = spacegroup_match.group(1).strip() + extracted_data["space_group_symbol"] = spacegroup + # 
Extract formula + elif line.startswith("data_"): + extracted_data["formula"] = line.split("_")[1] + # Extract lattice parameters + elif line.startswith("_cell_length_a"): + extracted_data["lattice_parameters"]["a"] = line.split()[-1] + elif line.startswith("_cell_length_b"): + extracted_data["lattice_parameters"]["b"] = line.split()[-1] + elif line.startswith("_cell_length_c"): + extracted_data["lattice_parameters"]["c"] = line.split()[-1] + elif line.startswith("_cell_angle_alpha"): + extracted_data["lattice_parameters"]["alpha"] = line.split()[-1] + elif line.startswith("_cell_angle_beta"): + extracted_data["lattice_parameters"]["beta"] = line.split()[-1] + elif line.startswith("_cell_angle_gamma"): + extracted_data["lattice_parameters"]["gamma"] = line.split()[-1] + elif "_atom_site_occupancy" in line: + atom_line_idx = line_idx + 1 + break + + for line_idx in range(atom_line_idx, len(lines)): + line = lines[line_idx] + if len(line) < 2: + continue + atom_info = line.split() + atom_type = atom_info[0] + num_atoms = atom_info[2] + x, y, z = atom_info[3], atom_info[4], atom_info[5] + extracted_data["atoms"].append({"type": atom_type, "num": num_atoms, "coordinates": (x, y, z)}) + + return extracted_data + + +if __name__ == "__main__": + llm_output = """ +formula Si 1_int C 1_int +space_group_symbol P 1 +lattice_parameters a 3.07486950 b 3.07486950 c 5.04587300 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 +Si 1_int 0.66666700 0.33333300 0.46331700 +Si 1_int 0.33333300 0.66666700 0.96331700 +C 1_int 0.66666700 0.33333300 0.96331700 +C 1_int 0.33333300 0.66666700 0.46331700 +""" + + ground_truth = """data_MoS2 +_symmetry_space_group_name_H-M P6_3/mmc +_cell_length_a 3.19223791 +_cell_length_b 3.19223791 +_cell_length_c 13.37829400 +_cell_angle_alpha 90.00000000 +_cell_angle_beta 90.00000000 +_cell_angle_gamma 120.00000000 +_symmetry_Int_Tables_number 194 +_chemical_formula_structural MoS2 +_chemical_formula_sum 'Mo2 S4' +_cell_volume 118.06518982 
+_cell_formula_units_Z 2 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' + 2 '-x, -y, -z' + 3 'x-y, x, z+1/2' + 4 '-x+y, -x, -z+1/2' + 5 '-y, x-y, z' + 6 'y, -x+y, -z' + 7 '-x, -y, z+1/2' + 8 'x, y, -z+1/2' + 9 '-x+y, -x, z' + 10 'x-y, x, -z' + 11 'y, -x+y, z+1/2' + 12 '-y, x-y, -z+1/2' + 13 '-y, -x, -z+1/2' + 14 'y, x, z+1/2' + 15 '-x, -x+y, -z' + 16 'x, x-y, z' + 17 '-x+y, y, -z+1/2' + 18 'x-y, -y, z+1/2' + 19 'y, x, -z' + 20 '-y, -x, z' + 21 'x, x-y, -z+1/2' + 22 '-x, -x+y, z+1/2' + 23 'x-y, -y, -z' + 24 '-x+y, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Mo Mo0 2 0.33333333 0.66666667 0.75000000 1 + S S1 4 0.33333333 0.66666667 0.13308200 1 +""" + + cif_tokenizer = CIFTokenizer() + result = cif_tokenizer.deserialize(llm_output, ground_truth) + + print("\n=== Deserialized Output ===") + print(result) diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preprocess/spacegroups.txt b/src/open_r1/tasks/crystal_structure/AIRS_preprocess/spacegroups.txt new file mode 100644 index 00000000..767fcd16 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/AIRS_preprocess/spacegroups.txt @@ -0,0 +1,227 @@ +P6/mmm +Imma +P4_32_12 +P4_2/mnm +Fd-3m +P3m1 +P-3 +P4mm +P4_332 +P4/nnc +P2_12_12 +Pnn2 +Pbcn +P4_2/n +Cm +R3m +Cmce +Aea2 +P-42_1m +P-42m +P2_13 +R-3 +Fm-3 +Cmm2 +Pn-3n +P6/mcc +P-6m2 +P3_2 +P-3m1 +P3_212 +I23 +P-62m +P4_2nm +Pma2 +Pmma +I-42m +P-31c +Pa-3 +Pmmn +Pmmm +P4_2/ncm +I4/mcm +I-4m2 +P3_1 +Pcc2 +Cmcm +I222 +Fddd +P312 +Cccm +P6_1 +F-43c +P6_322 +Pm-3 +P3_121 +P6_4 +Ia-3d +Pm-3m +P2_1/c +C222_1 +Pc +P4/n +Pba2 +Ama2 +Pbcm +P31m +Pcca +P222 +P-43n +Pccm +P6_422 +F23 +P42_12 +C222 +Pnnn +P6_3cm +P4_12_12 +P6/m +Fmm2 +I4_1/a +P4/mbm +Pmn2_1 +P4_2bc +P4_22_12 +I-43d +I4/m +P4bm +Fdd2 +P3 +P6_122 +Pnc2 +P4_2/mcm +P4_122 +Cmc2_1 +P-6c2 +R32 +P4_1 +P4_232 +Pnna +P422 +Pban +Cc +I4_122 +P6_3/m +P6_3mc 
+I4_1/amd +P4_2 +P4/nmm +Pmna +P4/m +Fm-3m +P4/mmm +Imm2 +P4/ncc +P-62c +Ima2 +P6_5 +P2/c +P4/nbm +Ibam +P6_522 +P6_3/mmc +I4/mmm +Fmmm +P2/m +P-4b2 +I-4 +C2/m +P4_2/mmc +P4 +Fd-3c +P4_3 +P2_1/m +I-43m +P-42c +F4_132 +Pm +Pccn +P-4n2 +P4_132 +P23 +I4cm +R3c +Amm2 +Immm +Iba2 +I4 +Fd-3 +P1 +Pbam +P4_2/nbc +Im-3 +P4_2/nnm +Pmc2_1 +P-31m +R-3m +Ia-3 +P622 +F222 +P2 +P-1 +Pmm2 +P-4 +Aem2 +P6_222 +P-3c1 +P4_322 +I422 +Pnma +P6_3 +P3c1 +Pn-3 +P4nc +P-6 +P4/mcc +I2_12_12_1 +P4_2/mbc +P31c +Ccc2 +P4_2/nmc +P6_3/mcm +C2 +Pbca +P-4c2 +I4_1cd +P2_1 +P3_112 +P4_2mc +Pn-3m +C2/c +R3 +P-43m +I432 +P222_1 +I-42d +I-4c2 +P6cc +P6_2 +P3_221 +P321 +Pca2_1 +I4_1/acd +I4_132 +F432 +Pna2_1 +Ccce +Ibca +P4/mnc +I4_1md +P2_12_12_1 +R-3c +I2_13 +P-4m2 +Pm-3n +I4mm +F-43m +Pnnm +P-42_1c +Cmmm +P6mm +P4_2cm +P4_2/m +Im-3m +Fm-3c +I4_1 +P4cc +Cmme diff --git a/src/open_r1/tasks/crystal_structure/__init__.py b/src/open_r1/tasks/crystal_structure/__init__.py new file mode 100644 index 00000000..b4345109 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/__init__.py @@ -0,0 +1 @@ +"""Crystal structure task package.""" diff --git a/src/open_r1/tasks/crystal_structure/relaxing.py b/src/open_r1/tasks/crystal_structure/relaxing.py new file mode 100644 index 00000000..50cb987b --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/relaxing.py @@ -0,0 +1,670 @@ +import os +import re +from dataclasses import field +from io import StringIO +from random import random +from typing import Dict, Optional + +import pandas as pd +from datasets import Dataset, DatasetDict + +from open_r1.tasks.base import RLTask + +from .AIRS_preprocess._tokenizer import CIFTokenizer + +cif_tokenizer = CIFTokenizer() + +CRYSTALRELAX_DATA_URL = "https://drive.google.com/drive/folders/1gt3E8OIGHbs2B8cKJOAnXaoM3K_loa9k?usp=sharing" +CRYSTALRELAX_REQUIRED_FILES = ( + "src-train.txt", + "tgt-train.txt", + "src-test.txt", + "tgt-test.txt", +) + + +def _expand_local_path(path: str) -> str: + raw_path = os.fspath(path) + if 
"MIST_DATA_DIR" in raw_path and "MIST_DATA_DIR" not in os.environ: + repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")) + data_root = os.path.join(repo_root, "data") + raw_path = raw_path.replace("${MIST_DATA_DIR}", data_root).replace("$MIST_DATA_DIR", data_root) + return os.path.expandvars(os.path.expanduser(raw_path)) + + +def _missing_crystalrelax_files(data_path: str) -> list[str]: + return [name for name in CRYSTALRELAX_REQUIRED_FILES if not os.path.exists(os.path.join(data_path, name))] + + +def download_crystalrelax_data(data_path: str) -> str: + """Download the released crystal relaxation dataset if it is not already present.""" + data_path = _expand_local_path(data_path) + if not _missing_crystalrelax_files(data_path): + return data_path + + os.makedirs(data_path, exist_ok=True) + + import gdown + + try: + gdown.download_folder( + url=CRYSTALRELAX_DATA_URL, + output=data_path, + quiet=False, + use_cookies=False, + remaining_ok=True, + ) + except TypeError: + gdown.download_folder( + url=CRYSTALRELAX_DATA_URL, + output=data_path, + quiet=False, + use_cookies=False, + ) + + missing_files = _missing_crystalrelax_files(data_path) + if missing_files: + raise FileNotFoundError( + f"Downloaded crystalrelax data to {data_path}, but missing required files: {', '.join(missing_files)}" + ) + return data_path + + +class BinaryCompoundRelaxing(RLTask): + _mace_calculator = None + _mace_device = None + + src_train_file: str = "" + tgt_train_file: str = "" + src_test_file: str = "" + tgt_test_file: str = "" + question_template: str = "" + log_custom_metrics: bool = True + custom_metrics: dict = field(default_factory=dict) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dataset_id_or_path = download_crystalrelax_data(self.dataset_id_or_path) + + self.src_train_file = os.path.join(self.dataset_id_or_path, "src-train.txt") + self.tgt_train_file = os.path.join(self.dataset_id_or_path, "tgt-train.txt") + src_test_path = 
os.path.join(self.dataset_id_or_path, "src-test.txt") + tgt_test_path = os.path.join(self.dataset_id_or_path, "tgt-test.txt") + self.src_test_file = src_test_path if os.path.exists(src_test_path) else None + self.tgt_test_file = tgt_test_path if os.path.exists(tgt_test_path) else None + self.question_template = ( + "system You are a seasoned crystallographic structure analysis expert. " + "Your task is to relax a binary compound to a stable state.\n" + "user Given a perturbed binary compound:\n" + "{}\n, perform multiple steps of Structural Relaxation on the given perturbed binary compound " + "and reduce the internal energy. Please document your thought process within tags, and provide " + "the final corrected structure in tags using the proper m2s format as given in the example:\n" + "serialized_cif formula Cd 1_int As 2_int \n" + "space_group_symbol I4_122_sg\n" + "lattice_parameters a 8.03811770 b 8.03811770 c 4.72563470 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 \n" + "Cd 4_int 0.00000000 0.00000000 0.00000000\n" + "As 8_int 0.06170692 0.25000000 0.62500000\n" + ) + self.log_custom_metrics = True + self.custom_metrics = { + "val/rewards": [], + } + + # Dataset here: /iopsstor/store/cscs/swissai/a05/chem/binary_compound_relaxing + + @classmethod + def _get_mace_device(cls) -> str: + if cls._mace_device is not None: + return cls._mace_device + + import torch + + cls._mace_device = "cuda" if torch.cuda.is_available() else "cpu" + + return cls._mace_device + + @classmethod + def _get_mace_calculator(cls): + if cls._mace_calculator is None: + from mace.calculators import mace_mp + + cls._mace_calculator = mace_mp(model="large", device=cls._get_mace_device()) + return cls._mace_calculator + + def read_files(self, src_file: str, tgt_file: str) -> Dict: + """Read source and target files and create dataset dictionary.""" + + def read_records(file_path: str) -> list: + """Helper function to read multi-line records separated by blank lines.""" + with 
open(file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + records = [] + current_record = [] + for line in lines: + if line.strip() == "": # Blank line indicates end of a record + if current_record: + records.append("\n".join(current_record)) + current_record = [] + else: + current_record.append(line.strip()) + if current_record: # Append the last record if file doesn't end with blank line + records.append("\n".join(current_record)) + return records + + # Read records from source and target files + src_records = read_records(src_file) + tgt_records = read_records(tgt_file) + + # Generate problems using the question template + problems = [self.question_template.format(record) for record in src_records] + # Solutions are the raw target records (assuming no further processing needed) + solutions = tgt_records + + return { + "problem": problems, + "solution": solutions, + } + + def load(self) -> DatasetDict: + """Load and return the complete dataset.""" + # Load training data + train_dict = self.read_files(self.src_train_file, self.tgt_train_file) + train_dataset = Dataset.from_dict(train_dict) + + # Load or create test data + if self.src_test_file and self.tgt_test_file: + test_dict = self.read_files(self.src_test_file, self.tgt_test_file) + test_dataset = Dataset.from_dict(test_dict) + else: + # Create test split from training data + train_test_split = train_dataset.train_test_split(test_size=0.1) + train_dataset = train_test_split["train"].unique(column="solution") + test_dataset = train_test_split["test"] + + # Combine into DatasetDict + self.dataset = DatasetDict({"train": train_dataset, "test": test_dataset}) + + return self.dataset + + def accuracy_reward(self, completions, solution, **kwargs): + """Reward function - check that completion is same as ground truth.""" + + import gemmi + from ase.io import read as ase_read + from pymatgen.core import Structure + + def compute_internal_score(answer_cif, ground_truth_dict, alpha=5.0): + def 
sanitize_cif(cif_str): + lines = cif_str.splitlines() + in_symmetry_loop = False + new_lines = [] + for line in lines: + stripped = line.strip() + if stripped.startswith("loop_"): + in_symmetry_loop = False + new_lines.append(line) + continue + if not in_symmetry_loop and "_symmetry_equiv_pos_as_xyz" in line: + in_symmetry_loop = True + new_lines.append(line) + continue + if in_symmetry_loop: + if stripped == "" or stripped.startswith("_") or stripped.startswith("loop_"): + in_symmetry_loop = False + new_lines.append(line) + else: + line = re.sub(r'"([^"]+)"', r"'\1'", line) + new_lines.append(line) + else: + new_lines.append(line) + return "\n".join(new_lines) + + def parse_llm_structure(cif_content): + sanitized = sanitize_cif(cif_content) + try: + return Structure.from_str(sanitized, fmt="cif") + except Exception: + return None + + def compare_internal_energy(cif1, cif2): + # uses ASE + MACE to get per‐atom potential energies + atoms1 = ase_read(StringIO(cif1), format="cif") + atoms2 = ase_read(StringIO(cif2), format="cif") + calc = self._get_mace_calculator() + atoms1.calc = calc + atoms2.calc = calc + e1 = atoms1.get_potential_energy() / len(atoms1) + e2 = atoms2.get_potential_energy() / len(atoms2) + if e1 < e2: + return 0.5 + elif e1 > e2: + return 1 + else: + return 0 + + gt_cif = ground_truth_dict + if not gt_cif: + return 0 + # first, reformat / deserialize via tokenizer + try: + answer_cif = cif_tokenizer.deserialize(answer_cif, gt_cif) + except Exception: + return 0 + + # quick gemmi checks + try: + for s in (gt_cif, answer_cif): + doc = gemmi.cif.read_string(s) + doc.check_for_missing_values() + doc.check_for_duplicates() + except Exception: + return 0 + + # parse Pymatgen structures + try: + Structure.from_str(gt_cif, fmt="cif") + except Exception: + return 0 + + llm_struct = parse_llm_structure(answer_cif) + if llm_struct is None: + return 0 + + # energy‐based reward + energy_reward = compare_internal_energy(gt_cif, answer_cif) + + # choose which to 
return (here using energy check as original) + return energy_reward + + rewards = [] + # Here task is simple: check that the smiles is the same as the target s + for content, sol in zip(completions, solution): + content = self.preprocess_response(content) + if content == "NONE": + rewards.append(0) + continue + + # server_url = os.environ.get("SERVER_URL", "http://10.197.48.175:9001/compute_score") + if content == sol: + rewards.append(0) + continue + + try: + reward = compute_internal_score(content, sol) + rewards.append(reward) + except Exception: + rewards.append(0) + if self.log_custom_metrics: + self.custom_metrics["val/rewards"].extend(rewards) + return rewards + + def format_reward(self, completions, **kwargs): + """ + Format: ...... + Args: + completions (list[str]): Generated outputs + + Returns: + list[float]: Reward scores + """ + rewards = [] + + # detect malformed or missing tags + tag_regex = re.compile(r"(.*?)\s*(.*?)", re.DOTALL) + space_groups = [ + "P6/mmm", + "Imma", + "P4_32_12", + "P4_2/mnm", + "Fd-3m", + "P3m1", + "P-3", + "P4mm", + "P4_332", + "P4/nnc", + "P2_12_12", + "Pnn2", + "Pbcn", + "P4_2/n", + "Cm", + "R3m", + "Cmce", + "Aea2", + "P-42_1m", + "P-42m", + "P2_13", + "R-3", + "Fm-3", + "Cmm2", + "Pn-3n", + "P6/mcc", + "P-6m2", + "P3_2", + "P-3m1", + "P3_212", + "I23", + "P-62m", + "P4_2nm", + "Pma2", + "Pmma", + "I-42m", + "P-31c", + "Pa-3", + "Pmmn", + "Pmmm", + "P4_2/ncm", + "I4/mcm", + "I-4m2", + "P3_1", + "Pcc2", + "Cmcm", + "I222", + "Fddd", + "P312", + "Cccm", + "P6_1", + "F-43c", + "P6_322", + "Pm-3", + "P3_121", + "P6_4", + "Ia-3d", + "Pm-3m", + "P2_1/c", + "C222_1", + "Pc", + "P4/n", + "Pba2", + "Ama2", + "Pbcm", + "P31m", + "Pcca", + "P222", + "P-43n", + "Pccm", + "P6_422", + "F23", + "P42_12", + "C222", + "Pnnn", + "P6_3cm", + "P4_12_12", + "P6/m", + "Fmm2", + "I4_1/a", + "P4/mbm", + "Pmn2_1", + "P4_2bc", + "P4_22_12", + "I-43d", + "I4/m", + "P4bm", + "Fdd2", + "P3", + "P6_122", + "Pnc2", + "P4_2/mcm", + "P4_122", + "Cmc2_1", + 
"P-6c2", + "R32", + "P4_1", + "P4_232", + "Pnna", + "P422", + "Pban", + "Cc", + "I4_122", + "P6_3/m", + "P6_3mc", + "I4_1/amd", + "P4_2", + "P4/nmm", + "Pmna", + "P4/m", + "Fm-3m", + "P4/mmm", + "Imm2", + "P4/ncc", + "P-62c", + "Ima2", + "P6_5", + "P2/c", + "P4/nbm", + "Ibam", + "P6_522", + "P6_3/mmc", + "I4/mmm", + "Fmmm", + "P2/m", + "P-4b2", + "I-4", + "C2/m", + "P4_2/mmc", + "P4", + "Fd-3c", + "P4_3", + "P2_1/m", + "I-43m", + "P-42c", + "F4_132", + "Pm", + "Pccn", + "P-4n2", + "P4_132", + "P23", + "I4cm", + "R3c", + "Amm2", + "Immm", + "Iba2", + "I4", + "Fd-3", + "P1", + "Pbam", + "P4_2/nbc", + "Im-3", + "P4_2/nnm", + "Pmc2_1", + "P-31m", + "R-3m", + "Ia-3", + "P622", + "F222", + "P2", + "P-1", + "Pmm2", + "P-4", + "Aem2", + "P6_222", + "P-3c1", + "P4_322", + "I422", + "Pnma", + "P6_3", + "P3c1", + "Pn-3", + "P4nc", + "P-6", + "P4/mcc", + "I2_12_12_1", + "P4_2/mbc", + "P31c", + "Ccc2", + "P4_2/nmc", + "P6_3/mcm", + "C2", + "Pbca", + "P-4c2", + "I4_1cd", + "P2_1", + "P3_112", + "P4_2mc", + "Pn-3m", + "C2/c", + "R3", + "P-43m", + "I432", + "P222_1", + "I-42d", + "I-4c2", + "P6cc", + "P6_2", + "P3_221", + "P321", + "Pca2_1", + "I4_1/acd", + "I4_132", + "F432", + "Pna2_1", + "Ccce", + "Ibca", + "P4/mnc", + "I4_1md", + "P2_12_12_1", + "R-3c", + "I2_13", + "P-4m2", + "Pm-3n", + "I4mm", + "F-43m", + "Pnnm", + "P-42_1c", + "Cmmm", + "P6mm", + "P4_2cm", + "P4_2/m", + "Im-3m", + "Fm-3c", + "I4_1", + "P4cc", + "Cmme", + ] + escaped = [re.escape(sg) for sg in space_groups] + pattern = r"\b(?:" + "|".join(escaped) + r")\b" + + symmetry_pattern = re.compile(pattern, re.IGNORECASE) + # bonus patterns + cif_pattern = re.compile( + r"\b(" + r"cif|space\s+group|unit\s+cell|lattice|" + r"symmetry|fractional\s+coordinates|cell\s+parameters|" + r"bond\s+length|bond\s+angle|volume|Wyckoff|" + r"atomic\s+positions|occupancy|site\s+multiplicity" + r")\b", + re.IGNORECASE, + ) + math_pattern = re.compile( + # looks for sqrt(…), (…)^2, (…)=(…), or simple a±b/c etc. 
+ r"(?:sqrt\s*\(|\([^)]*\)\s*\^\s*2|[0-9\.\)]+\s*[\+\-\*/=]\s*[0-9\.\(])", + re.IGNORECASE, + ) + position_pattern = re.compile( + r"\b(" + r"position|pos\.?|coordinate|coord\.?|site|" + r"atomic\s+position|fractional\s+coord(?:inate)?s?|" + r"xyz|uvw" + r")\b", + re.IGNORECASE, + ) + lattice_angle_pattern = re.compile(r"\b(a=|b=|c=|c/a|gamma\s*=?\s*\d+(\.\d+)?°?)\b", re.IGNORECASE) + crystallographic_pattern = re.compile( + r"\b(" + r"Wyckoff|multiplicity|asymmetric unit|mirror plane|inversion center|" + r"Bravais lattice|primitive cell|supercell" + r")\b", + re.IGNORECASE, + ) + energy_force_pattern = re.compile( + r"\b(" + r"formation energy|total energy|enthalpy|residual force|stress|" + r"converged energy|converged stress|force\s*<\s*0\.01\s*eV/Å" + r")\b", + re.IGNORECASE, + ) + dynamical_pattern = re.compile( + r"\b(" + r"phonon dispersion|imaginary mode|soft mode|dynamical stability|" + r"elastic constant|Born criteria" + r")\b", + re.IGNORECASE, + ) + classification_pattern = re.compile( + r"\b(" + r"perovskite|spinel|rocksalt|layered oxide|phase transition|" + r"Jahn[- ]Teller distortion|olivine|rutile" + r")\b", + re.IGNORECASE, + ) + chemical_pattern = re.compile( + r"\b(" + r"bond length|bond angle|electronegativity|Bader charge|" + r"electron localization|coordination number|coordination environment|ionic radius" + r")\b", + re.IGNORECASE, + ) + + for completion in completions: + try: + # ensure it at least starts in the right place + if not completion.startswith(""): + completion = "" + completion + + m = tag_regex.search(completion) + if not m: + rewards.append(0.0) + continue + + bonus = 0.0 + if cif_pattern.search(completion): + bonus += 0.1 + if math_pattern.search(completion): + bonus += 0.2 + if position_pattern.search(completion): + bonus += 0.1 + if symmetry_pattern.search(completion): + bonus += 0.1 + if lattice_angle_pattern.search(completion): + bonus += 0.05 + if crystallographic_pattern.search(completion): + bonus += 0.1 + if 
energy_force_pattern.search(completion): + bonus += 0.1 + if dynamical_pattern.search(completion): + bonus += 0.1 + if classification_pattern.search(completion): + bonus += 0.1 + if chemical_pattern.search(completion): + bonus += 0.05 + + rewards.append(bonus) + + except Exception: + rewards.append(0.0) + + return rewards + + def preprocess_response(self, response): + """Preprocess the response before checking for accuracy.""" + pattern = r"(.*)<\/answer>" + m = re.findall(pattern, response, re.DOTALL) + if m: + return m[-1].strip() + else: + return "NONE" + + def get_metrics(self) -> Dict: + """ + Get task metrics to log in WANDB. + This function takes no arguments and returns a dictionary of metrics {key[str]: value[float]}. + """ + metrics = dict() + if self.log_custom_metrics: + rewards = self.custom_metrics["val/rewards"] + if rewards: + correct_count = sum(1 for r in rewards if r == 1) + total_count = len(rewards) + accuracy = correct_count / total_count if total_count > 0 else 0.0 + metrics["val/accuracy"] = accuracy + self.custom_metrics["val/rewards"] = [] + return metrics