diff --git a/docs/source/user_guide/command_line.rst b/docs/source/user_guide/command_line.rst index 154cac46..88855693 100644 --- a/docs/source/user_guide/command_line.rst +++ b/docs/source/user_guide/command_line.rst @@ -659,7 +659,7 @@ Training and fine-tuning MLIPs ------------------------------ .. note:: - Currently only MACE and Nequip models are supported. + Currently MACE, Nequip, and SevenNet models are supported. Models can be trained by passing an archictecture and an archictecture specific configuration file as options to the ``janus train`` command. The configuration file will be passed to the corresponding MLIPs command line interface. For example to train a MACE MLIP: @@ -699,6 +699,37 @@ Configuration of Nequip training is outlined in the `Nequip user guide `_ and the `MACE run_train CLI `_. + + +Training Nequip MLIPs ++++++++++++++++++++++ + +Configuration of Nequip training is outlined in the `Nequip user guide `_. In particular note that the configuration file must have a ``.yaml`` extension. + +The results directory contents depend on the options selected in the configuration file, but may typically contain model checkpoint, ``.ckpt``, files and a metrics directory. + + +Training SevenNet MLIPs ++++++++++++++++++++++++ + +The `SevenNet documentation `_ contains information on training SevenNet MLIPs. The SevenNet `tutorial repository `_ also contains some example ``.yaml`` configuration files for training and fine-tuning. 
Preprocessing training data ---------------------------- diff --git a/janus_core/cli/train.py b/janus_core/cli/train.py index b81b146c..67d6dfa0 100644 --- a/janus_core/cli/train.py +++ b/janus_core/cli/train.py @@ -122,6 +122,45 @@ def train( """Fine-tuning requested but there is no checkpoint or package specified in your config.""" ) + + case "sevennet": + continue_section = config["train"].get("continue") + if continue_section is None and fine_tune: + raise ValueError( + """Fine-tuning requested but there is no continue + section in your config.""" + ) + model = continue_section.get("checkpoint") + if model is None: + raise ValueError( + """No model specified as a checkpoint for + fine-tuning. + """ + ) + if not fine_tune and continue_section is not None: + raise ValueError( + """Fine-tuning not requested but a continue + section is in your config. Please use + --fine-tune""" + ) + + case "grace": + if "potential" not in config: + raise ValueError("No potential is specified in your config.") + + if fine_tune: + model = config["potential"].get("finetune_foundation_model") + if model is None: + raise ValueError( + """Fine-tuning was requested but your config + does not contain a finetune_foundation_model""" + ) + elif "finetune_foundation_model" in config["potential"]: + raise ValueError( + """Fine-tuning not requested but finetune_foundation_model + is in your config. Please use --fine-tune. 
+ """ + ) + case _: raise ValueError(f"Unsupported Architecture ({arch})") diff --git a/janus_core/training/train.py b/janus_core/training/train.py index f8ffe4a7..65d33f6e 100644 --- a/janus_core/training/train.py +++ b/janus_core/training/train.py @@ -2,6 +2,8 @@ from __future__ import annotations +from argparse import ArgumentParser +from pathlib import Path from typing import Any import yaml @@ -93,6 +95,20 @@ def train( ) foundation_model = model["checkpoint_path"] + case "sevennet": + from sevenn.main.sevenn import cmd_parser_train, run + + parser = ArgumentParser() + cmd_parser_train(parser) + mlip_args = parser.parse_args( + [str(mlip_config), "--working_dir", str(file_prefix), "-s"] + ) + + case "grace": + from tensorpotential.cli.gracemaker import main as run + + mlip_args = [str(mlip_config)] + case _: raise ValueError(f"{arch} is currently unsupported in train.") @@ -120,6 +136,11 @@ def train( run(mlip_args) + if arch == "grace" and (Path.cwd() / "seed").exists(): + # Gracemaker always works in ./seed. + file_prefix.mkdir(parents=True, exist_ok=True) + (Path.cwd() / "seed").rename(file_prefix.resolve() / "seed") + if logger: logger.info("Training complete") if tracker: diff --git a/tests/data/grace_fine_tune.yml b/tests/data/grace_fine_tune.yml new file mode 100644 index 00000000..0831936a --- /dev/null +++ b/tests/data/grace_fine_tune.yml @@ -0,0 +1,35 @@ +seed: 42 +cutoff: 6 + +data: + filename: "tests/data/mlip_train.pkl.gz" + reference_energy: 0 + +potential: + finetune_foundation_model: "GRACE-1L-OAM" + +fit: + loss: + energy: + type: huber + weight: 17 + delta: 0.01 + forces: + type: huber + weight: 32. 
+ delta: 0.01 + + maxiter: 1 # Max number of optimization epochs + optimizer: Adam + opt_params: { learning_rate: 0.008, use_ema: True, ema_momentum: 0.99, weight_decay: 1.e-20, clipnorm: 1.0} + scheduler: cosine_decay # scheduler for learning-rate reduction during training + scheduler_params: {"minimal_learning_rate": 0.0001} + + batch_size: 32 # Important hyperparameter for Adam and irrelevant (but must be) for L-BFGS-B/BFGS + test_batch_size: 200 # test batch size (optional) + + jit_compile: True # for XLA compilation, must be used in almost all cases + train_max_n_buckets: 10 ## max number of buckets in train set + test_max_n_buckets: 3 ## same for test + + checkpoint_freq: 10 # frequency for **REGULAR** checkpoints. diff --git a/tests/data/mlip_train.pkl.gz b/tests/data/mlip_train.pkl.gz new file mode 100644 index 00000000..281e4a98 Binary files /dev/null and b/tests/data/mlip_train.pkl.gz differ diff --git a/tests/data/sevennet_fine_tune.yml b/tests/data/sevennet_fine_tune.yml new file mode 100644 index 00000000..965a9438 --- /dev/null +++ b/tests/data/sevennet_fine_tune.yml @@ -0,0 +1,92 @@ +model: + chemical_species: auto + + cutoff: 2.0 + irreps_manual: + - 128x0e + - 128x0e+64x1e+32x2e+32x3e + - 128x0e+64x1e+32x2e+32x3e + - 128x0e+64x1e+32x2e+32x3e + - 128x0e+64x1e+32x2e+32x3e + - 128x0e + channel: 128 + lmax: 3 + num_convolution_layer: 5 + is_parity: false + radial_basis: + radial_basis_name: bessel + bessel_basis_num: 8 + cutoff_function: + cutoff_function_name: poly_cut + poly_cut_p_value: 6 + + act_radial: silu + weight_nn_hidden_neurons: + - 64 + - 64 + act_scalar: + e: silu + o: tanh + act_gate: + e: silu + o: tanh + + train_denominator: false + train_shift_scale: false + use_bias_in_linear: false + + readout_as_fcn: false + self_connection_type: linear + interaction_type: nequip + +train: + random_seed: 1 + is_train_stress: True + epoch: 1 + + + + optimizer: 'adam' + optim_param: + lr: 0.005 + scheduler: 'exponentiallr' + scheduler_param: + gamma: 
0.99 + + force_loss_weight: 0.1 + stress_loss_weight: 1e-06 + + per_epoch: 1 + + + + error_record: + - ['Energy', 'RMSE'] + - ['Force', 'RMSE'] + - ['Stress', 'RMSE'] + - ['TotalLoss', 'None'] + + continue: + reset_optimizer: True + reset_scheduler: True + reset_epoch: True + checkpoint: 'tests/models/extra/SevenNet_l3i5.pth' + + use_statistic_values_of_checkpoint: True + +data: + batch_size: 4 + data_divide_ratio: 0.1 + + shift: 'per_atom_energy_mean' + scale: 'force_rms' + + + + data_format: 'ase' + data_format_args: + index: ':' + + + + load_dataset_path: ['tests/data/mlip_train.xyz'] diff --git a/tests/data/sevennet_train.yml b/tests/data/sevennet_train.yml new file mode 100644 index 00000000..09eba5fa --- /dev/null +++ b/tests/data/sevennet_train.yml @@ -0,0 +1,55 @@ +model: + chemical_species: 'Auto' + cutoff: 2.0 + channel: 4 + lmax: 1 + num_convolution_layer: 1 + + weight_nn_hidden_neurons: [4, 4] + radial_basis: + radial_basis_name: 'bessel' + bessel_basis_num: 8 + cutoff_function: + cutoff_function_name: 'poly_cut' + poly_cut_p_value: 6 + + act_gate: {'e': 'silu', 'o': 'tanh'} + act_scalar: {'e': 'silu', 'o': 'tanh'} + + is_parity: False + + self_connection_type: 'nequip' + + conv_denominator: "avg_num_neigh" + train_denominator: False + train_shift_scale: False + +train: + random_seed: 1 + is_train_stress: True + epoch: 2 + optimizer: 'adam' + optim_param: + lr: 0.005 + scheduler: 'exponentiallr' + scheduler_param: + gamma: 0.99 + force_loss_weight: 0.1 + stress_loss_weight: 1e-06 + per_epoch: 1 + error_record: + - ['Energy', 'RMSE'] + - ['Force', 'RMSE'] + - ['Stress', 'RMSE'] + - ['TotalLoss', 'None'] + +data: + batch_size: 4 + data_divide_ratio: 0.1 + + shift: 'per_atom_energy_mean' + scale: 'force_rms' + data_format: 'ase' + data_format_args: + index: ':' + load_dataset_path: ['tests/data/mlip_train.xyz'] diff --git a/tests/models/extra_models.py b/tests/models/extra_models.py index d05948c1..c93d9777 100644 --- a/tests/models/extra_models.py +++ 
b/tests/models/extra_models.py @@ -14,7 +14,13 @@ args = parser.parse_args() args.path.mkdir(parents=True, exist_ok=True) + urlretrieve( "https://zenodo.org/records/16980200/files/NequIP-MP-L-0.1.nequip.zip", filename=args.path / "NequIP-MP-L-0.1.nequip.zip", ) + + urlretrieve( + "https://github.com/MDIL-SNU/SevenNet/raw/dff008ac9c53d368b5bee30a27fa4bdfd73f19b2/sevenn/pretrained_potentials/SevenNet_l3i5/checkpoint_l3i5.pth", + filename=args.path / "SevenNet_l3i5.pth", + ) diff --git a/tests/models/extra_models.sh b/tests/models/extra_models.sh new file mode 100644 index 00000000..b6836315 --- /dev/null +++ b/tests/models/extra_models.sh @@ -0,0 +1,6 @@ +if [ ! -d tests/models/extra ] +then + mkdir tests/models/extra +fi + +(cd tests/models/extra; curl --output NequIP-MP-L-0.1.nequip.zip https://zenodo.org/records/16980200/files/NequIP-MP-L-0.1.nequip.zip) diff --git a/tests/test_train_cli.py b/tests/test_train_cli.py index e06d032e..b565ab54 100644 --- a/tests/test_train_cli.py +++ b/tests/test_train_cli.py @@ -4,6 +4,7 @@ from pathlib import Path +from ase.io import read, write import pytest from typer.testing import CliRunner import yaml @@ -14,15 +15,16 @@ chdir, check_output_files, clear_log_handlers, + rename_atoms_attributes, skip_extras, strip_ansi_codes, ) DATA_PATH = Path(__file__).parent / "data" MODEL_PATH = Path(__file__).parent / "models" -NEQUIP_EXTRA_MODEL_PATH = ( - Path(__file__).parent / "models" / "extra" / "NequIP-MP-L-0.1.nequip.zip" -) +EXTRA_MODEL_PATH = MODEL_PATH / "extra" +NEQUIP_EXTRA_MODEL_PATH = EXTRA_MODEL_PATH / "NequIP-MP-L-0.1.nequip.zip" +SEVENNET_EXTRA_MODEL_PATH = EXTRA_MODEL_PATH / "SevenNet_l3i5.pth" runner = CliRunner() @@ -80,8 +82,10 @@ def write_tmp_config_nequip( Path to yaml config file to be fixed. tmp_path Temporary path from pytest in which to write corrected config. - model_path - Path to a saved model. + fine_tune + Whether fine tuning. Default is False. + model_type + If using a package or checkpoint file. 
Returns ------- Path Temporary path to corrected config file. """ @@ -109,6 +113,13 @@ if (MODEL_PATH / pth).is_file(): model_dict[f"{model_type}_path"] = str(MODEL_PATH / pth) + if fine_tune: + model_dict = config["training_module"]["model"] + model = Path(model_dict[f"{model_type}_path"]).name + for pth in (model, f"extra/{model}"): + if (MODEL_PATH / pth).is_file(): + model_dict[f"{model_type}_path"] = str(MODEL_PATH / pth) + # Write out temporary config with corrected paths tmp_config = tmp_path / "config.yaml" with open(tmp_config, "w", encoding="utf8") as file: @@ -117,6 +128,92 @@ return tmp_config +def write_tmp_data_sevennet( + config_path: Path, tmp_path: Path, fine_tune: bool = False +) -> Path: + """ + Fix paths and data columns, write config and data to tmp_path. + + Parameters + ---------- + config_path + Path to yaml config file to be fixed. + tmp_path + Temporary path from pytest in which to write corrected config. + + Returns + ------- + Path. + Temporary path to corrected config file. + """ + # Load config from tests/data + with open(config_path, encoding="utf8") as file: + config = yaml.safe_load(file) + + # Use DATA_PATH to set paths relative to this test file + for dataset in config["data"].keys() & {"load_dataset_path", "load_validset_path"}: + files = config["data"][dataset] + for i, file in enumerate(files): + name = Path(file).name + path = DATA_PATH / name + if path.exists(): + frames = read(path, index=":") + # There is currently no option to rename these. 
+ rename_info = {"dft_energy": "energy", "dft_stress": "stress"} + rename_arrays = {"dft_forces": "forces"} + for frame in frames: + rename_atoms_attributes(frame, rename_info, rename_arrays) + write(tmp_path / name, frames) + files[i] = str(tmp_path / name) + + if fine_tune: + model = Path(config["train"]["continue"]["checkpoint"]).name + if (MODEL_PATH / "extra" / model).exists(): + config["train"]["continue"]["checkpoint"] = str( + MODEL_PATH / "extra" / model + ) + + # Write out temporary config with corrected paths + tmp_config = tmp_path / "config.yml" + with open(tmp_config, "w", encoding="utf8") as file: + yaml.dump(config, file) + + return tmp_config + + +def write_tmp_config_grace(config_path: Path, tmp_path: Path) -> Path: + """ + Fix paths in config files and write corrected config to tmp_path for grace. + + Parameters + ---------- + config_path + Path to yaml config file to be fixed. + tmp_path + Temporary path from pytest in which to write corrected config. + + Returns + ------- + Path + Temporary path to corrected config file. 
+ """ + # Load config from tests/data + with open(config_path, encoding="utf8") as file: + config = yaml.safe_load(file) + + # Use DATA_PATH to set paths relative to this test file + for file in config["data"].keys() & {"filename", "test_filename"}: + if (DATA_PATH / Path(config["data"][file]).name).exists(): + config["data"][file] = str(DATA_PATH / Path(config["data"][file]).name) + + # Write out temporary config with corrected paths + tmp_config = tmp_path / "config.yml" + with open(tmp_config, "w", encoding="utf8") as file: + yaml.dump(config, file) + + return tmp_config + + def test_help(): """Test calling `janus train --help`.""" result = runner.invoke(app, ["train", "--help"]) @@ -227,7 +324,7 @@ def test_fine_tune(tmp_path): summary_path = tmp_path / "summary.yml" logs_path = results_dir / "logs" - config = write_tmp_config_mace(DATA_PATH / "mlip_fine_tune.yml", Path()) + config = write_tmp_config_mace(DATA_PATH / "mlip_fine_tune.yml", Path.cwd()) result = runner.invoke( app, @@ -469,3 +566,154 @@ def test_nequip_fine_tune_foundation(tmp_path): header = metrics.readline().split(",") assert header[:3] == ["epoch", "lr-Adam", "step"] + + +def test_sevennet_train(tmp_path): + """Test training with sevennet.""" + skip_extras("sevennet") + + with chdir(tmp_path): + log_path = tmp_path / "test.log" + summary_path = tmp_path / "summary.yml" + + results_dir = Path("janus_results") + + checkpoints_paths = [ + results_dir / f"checkpoint_{ver}.pth" for ver in ("0", "1", "best") + ] + sevennet_log_path = results_dir / "log.sevenn" + sevenn_data_path = results_dir / "sevenn_data" + metrics_path = results_dir / "lc.csv" + + config_path = DATA_PATH / "sevennet_train.yml" + config_path = write_tmp_data_sevennet(config_path, tmp_path) + + result = runner.invoke( + app, + [ + "train", + "sevennet", + "--mlip-config", + config_path, + "--log", + log_path, + "--summary", + summary_path, + ], + ) + assert result.exit_code == 0 + + assert results_dir.exists() + assert 
log_path.exists() + assert summary_path.exists() + assert sevennet_log_path.exists() + assert sevenn_data_path.is_dir() + assert metrics_path.exists() + + for checkpoint in checkpoints_paths: + assert checkpoint.exists() + + with open(metrics_path) as metrics: + lines = metrics.readlines() + assert len(lines) == 3 + assert lines[0].split(",")[0] == "epoch" + + +@pytest.mark.skipif( + not SEVENNET_EXTRA_MODEL_PATH.exists(), + reason=f"Extra model: {SEVENNET_EXTRA_MODEL_PATH} not downloaded.", +) +def test_sevennet_fine_tune_foundation(tmp_path): + """Test training with sevennet.""" + skip_extras("sevennet") + + with chdir(tmp_path): + log_path = tmp_path / "test.log" + summary_path = tmp_path / "summary.yml" + + results_dir = Path("janus_results") + + checkpoints_paths = [ + results_dir / f"checkpoint_{ver}.pth" for ver in ("0", "1", "best") + ] + sevennet_log_path = results_dir / "log.sevenn" + sevenn_data_path = results_dir / "sevenn_data" + metrics_path = results_dir / "lc.csv" + + config_path = DATA_PATH / "sevennet_fine_tune.yml" + config_path = write_tmp_data_sevennet(config_path, tmp_path, True) + + result = runner.invoke( + app, + [ + "train", + "sevennet", + "--mlip-config", + config_path, + "--fine-tune", + "--log", + log_path, + "--summary", + summary_path, + ], + ) + assert result.exit_code == 0 + + assert results_dir.exists() + assert log_path.exists() + assert summary_path.exists() + assert sevennet_log_path.exists() + assert sevenn_data_path.is_dir() + assert metrics_path.exists() + + for checkpoint in checkpoints_paths: + assert checkpoint.exists() + + with open(metrics_path) as metrics: + lines = metrics.readlines() + assert len(lines) == 2 + assert lines[0].split(",")[0] == "epoch" + + +def test_grace_fine_tune_foundation(tmp_path): + """Test fine tuning grace.""" + skip_extras("grace") + + with chdir(tmp_path): + log_path = tmp_path / "test.log" + summary_path = tmp_path / "summary.yml" + + results_dir = Path("janus_results") + grace_dir = 
results_dir / "seed" / "42" + + metrics_path = grace_dir / "train_metrics.yaml" + + config_path = DATA_PATH / "grace_fine_tune.yml" + config_path = write_tmp_config_grace(config_path, tmp_path) + result = runner.invoke( + app, + [ + "train", + "grace", + "--mlip-config", + config_path, + "--fine-tune", + "--log", + log_path, + "--summary", + summary_path, + ], + ) + assert result.exit_code == 0 + + assert results_dir.exists() + assert log_path.exists() + assert summary_path.exists() + assert grace_dir.is_dir() + assert metrics_path.exists() + + with open(metrics_path, encoding="utf8") as file: + metrics = yaml.safe_load(file) + assert len(metrics) == 1 + for epoch in metrics: + assert "total_loss/train" in epoch diff --git a/tests/utils.py b/tests/utils.py index 21599dc2..64e159d3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -199,3 +199,15 @@ def chdir(path): yield finally: os.chdir(prev_cwd) + + +def rename_atoms_attributes( + atoms: Atoms, rename_info: dict[str, str], rename_arrays: dict[str, str] +) -> None: + """Rename an Atoms objects info and arrays entries in place.""" + for name_map, store in zip( + (rename_info, rename_arrays), (atoms.info, atoms.arrays), strict=True + ): + for old_name, new_name in name_map.items(): + if old_name in store: + store[new_name] = store.pop(old_name)