# ===========================================================================
# git format-patch header (reconstructed as comments; the original patch text
# was collapsed onto a few physical lines):
#   From 4b331bddd50ec98722d95608fb3bd02299e5eb29 Mon Sep 17 00:00:00 2001
#   From: Ruizhi Xu
#   Date: Sat, 12 Apr 2025 17:48:49 +1000
#   Subject: [PATCH 01/37] init
#   7 files changed, 2741 insertions(+)
# ===========================================================================

# --- new file: src/open_r1/dataset/crystal_structure_relaxing.py -----------
import os
import argparse
import random

import pandas as pd
from tqdm import tqdm

from verl.utils.hdfs_io import copy, makedirs
from AIRS_preporcess._tokenizer import CIFTokenizer

# Single tokenizer instance shared by every sample.
cif_tokenizer = CIFTokenizer()


def load_cif_dataset(binary_dir: str, perturbed_dir: str, size: int, local_dir: str) -> list:
    """Load (ground-truth, perturbed) CIF pairs and serialize them.

    Reads "perturbed_df_cif.parquet" from ``local_dir``; for each of the first
    ``size`` rows, loads the ground-truth CIF (named after the portion of
    ``material_id`` before "_random_") from ``binary_dir`` and the perturbed
    CIF from ``perturbed_dir``, serializing both with
    ``cif_tokenizer.serialize``.

    Returns:
        list[dict]: items with keys ``compound_id``, ``ground_truth`` and
        ``perturbed`` (the latter two already serialized). Rows whose files
        cannot be read are skipped with a warning.
    """
    parquet_path = os.path.join(local_dir, "perturbed_df_cif.parquet")
    df = pd.read_parquet(parquet_path).head(size)
    samples = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Loading CIF dataset"):
        material_id = row['material_id']
        # The ground-truth file name is the material_id without the
        # "_random_<n>" perturbation suffix.
        ground_truth_material_id = material_id.split("_random_")[0]
        gt_file = os.path.join(binary_dir, f"{ground_truth_material_id}.cif")
        try:
            with open(gt_file, 'r', encoding='utf-8') as f:
                gt_content = f.read()
        except OSError as e:
            print(f"Error reading ground truth file {gt_file}: {e}")
            continue

        perturbed_file = os.path.join(perturbed_dir, f"{material_id}.cif")
        try:
            with open(perturbed_file, 'r', encoding='utf-8') as f:
                perturbed_content = f.read()
        except OSError as e:
            print(f"Error reading perturbed file {perturbed_file}: {e}")
            continue

        # Fix: the original serialized each file twice (once for the sample
        # dict and once more inside per-sample debug prints). Serialize once
        # and reuse; the noisy per-sample prints are removed.
        gt_serialized = cif_tokenizer.serialize(gt_content)
        perturbed_serialized = cif_tokenizer.serialize(perturbed_content)
        samples.append({
            "compound_id": material_id,
            "ground_truth": gt_serialized,
            "perturbed": perturbed_serialized,
        })

    print(f"{len(samples)} samples loaded")
    return samples


def _write_pairs(src_path: str, tgt_path: str, pairs: list) -> None:
    """Write one serialized sample per line: perturbed -> src, ground truth -> tgt.

    Always creates both files, so an empty ``pairs`` list yields empty files
    (same observable result as the original's explicit empty-file branch).
    """
    with open(src_path, 'w', encoding='utf-8') as f_src, \
            open(tgt_path, 'w', encoding='utf-8') as f_tgt:
        for sample in pairs:
            f_src.write(sample['perturbed'] + "\n")
            f_tgt.write(sample['ground_truth'] + "\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--local_dir',
        default='./binary_compounds_dataset',
        help="Directory where the dataset is saved locally"
    )
    parser.add_argument(
        '--hdfs_dir',
        default=None,
        help="HDFS directory (optional)"
    )
    parser.add_argument(
        '--train_size',
        type=int,
        default=500,
        help="Number of training set samples"
    )
    parser.add_argument(
        '--test_size',
        type=int,
        default=100,
        help="Number of test set samples"
    )
    args = parser.parse_args()

    # Construct the paths for CIF files.
    binary_dir = os.path.join(args.local_dir, 'binary_compounds_cifs')
    perturbed_dir = os.path.join(args.local_dir, 'perturbed_binary_compounds_cifs')

    samples = load_cif_dataset(binary_dir, perturbed_dir,
                               args.train_size + args.test_size, args.local_dir)
    random.shuffle(samples)
    total_samples = len(samples)
    print(f"Total number of samples loaded: {total_samples}")

    # If there are not enough samples, fall back to using everything for training.
    if total_samples < (args.train_size + args.test_size):
        print("Warning: Insufficient samples, will use all samples as training set.")
        train_samples, test_samples = samples, []
    else:
        train_samples = samples[:args.train_size]
        test_samples = samples[args.train_size:args.train_size + args.test_size]

    # File layout expected by BinaryCompoundRelaxing: the question is the
    # perturbed structure (src), the answer is the relaxed ground truth (tgt).
    _write_pairs(os.path.join(args.local_dir, 'src-train.txt'),
                 os.path.join(args.local_dir, 'tgt-train.txt'),
                 train_samples)
    _write_pairs(os.path.join(args.local_dir, 'src-test.txt'),
                 os.path.join(args.local_dir, 'tgt-test.txt'),
                 test_samples)

    # Mirror the local dataset directory to HDFS when requested.
    if args.hdfs_dir is not None:
        makedirs(args.hdfs_dir)
        copy(src=args.local_dir, dst=args.hdfs_dir)


# --- diff hunk: src/open_r1/tasks/__init__.py (shown as comments; it adds
#     the new task to the registry) ----------------------------------------
#   +from .crystal_structure.relaxing import BinaryCompoundRelaxing
#    CHEMTASKS = {
#        ...
#   +    "crystalrelax": BinaryCompoundRelaxing,
#    }

# --- new file: src/open_r1/tasks/crystal_structure/AIRS_preporcess/_tokenizer.py
import re
import math

from torch.utils.data import Dataset

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

# Space-group symbols shipped next to this module, one symbol per line.
with open(os.path.join(THIS_DIR, "spacegroups.txt"), "rt") as f:
    SPACE_GROUPS = [sg.strip() for sg in f.readlines()]

# Element symbols the tokenizer recognises (order defines token ids).
ATOMS = ["Si", "C", "Pb", "I", "Br", "Cl", "Eu", "O", "Fe", "Sb", "In", "S", "N", "U", "Mn", "Lu", "Se", "Tl", "Hf",
         "Ir", "Ca", "Ta", "Cr", "K", "Pm", "Mg", "Zn", "Cu", "Sn", "Ti", "B", "W", "P", "H", "Pd", "As", "Co", "Np",
         "Tc", "Hg", "Pu", "Al", "Tm", "Tb", "Ho", "Nb", "Ge", "Zr", "Cd", "V", "Sr", "Ni", "Rh", "Th", "Na", "Ru",
         "La", "Re", "Y", "Er", "Ce", "Pt", "Ga", "Li", "Cs", "F", "Ba", "Te", "Mo", "Gd", "Pr", "Bi", "Sc", "Ag", "Rb",
         "Dy", "Yb", "Nd", "Au", "Os", "Pa", "Sm", "Be", "Ac", "Xe", "Kr", "He", "Ne", "Ar"]
"Sm", "Be", "Ac", "Xe", "Kr", "He", "Ne", "Ar"] + +DIGITS = [str(d) for d in list(range(10))] + +INTS = [str(d) for d in list(range(300))] + +KEYWORDS = [ + "space_group_symbol", + "formula", + "atoms", + "lattice_parameters", + "a", + "b", + "c", + "alpha", + "beta", + "gamma" +] + +UNK_TOKEN = "" + +def get_spacegroup_number(sg_symbol): + try: + from pymatgen.symmetry.groups import SpaceGroup + sg = SpaceGroup(sg_symbol) + return sg + except Exception as e: + print("Err:", e) + return None + +def parse_formula(formula): + formula = formula.replace("'", "").replace('"', '').strip() + pattern = r"([A-Z][a-z]*)(\d*)" + counts = {} + for element, count in re.findall(pattern, formula): + counts[element] = counts.get(element, 0) + (int(count) if count else 1) + return counts + +def compute_cell_formula_units_Z(formula_sum, formula_structural): + counts_sum = parse_formula(formula_sum) + counts_struct = parse_formula(formula_structural) + + ratios = [] + for element, count_struct in counts_struct.items(): + if element not in counts_sum: + raise ValueError(f"{element}") + ratio = counts_sum[element] / count_struct + if ratio != int(ratio): + raise ValueError(f"{element}, {ratio} not int") + ratios.append(int(ratio)) + + if len(set(ratios)) != 1: + raise ValueError(f"{ratios} != 1") + return ratios[0] + +class CIFTokenizer: + def __init__(self): + self._tokens = [""] + self._tokens.extend(self.atoms()) + self._tokens.extend(self.digits()) + self._tokens.extend(self.keywords()) + self._tokens.extend(self.symbols()) + + space_groups = list(self.space_groups()) + # Replace 'Pm' space group with 'Pm_sg' to disambiguate from atom 'Pm', + # or 'P1' with 'P1_sg' to disambiguate from atom 'P' and number '1' + space_groups_sg = [sg+"_sg" for sg in space_groups] + self._tokens.extend(space_groups_sg) + + digits_int = [v+"_int" for v in INTS] + self._tokens.extend(digits_int) + + self._escaped_tokens = [re.escape(token) for token in self._tokens] + self._escaped_tokens.sort(key=len, 
reverse=True) + + # a mapping from characters to integers + self._token_to_id = {ch: i for i, ch in enumerate(self._tokens)} + self._id_to_token = {i: ch for i, ch in enumerate(self._tokens)} + # map the id of 'Pm_sg' back to 'Pm', or 'P1_sg' to 'P1', + # for decoding convenience + for sg in space_groups_sg: + self._id_to_token[self.token_to_id[sg]] = sg.replace("_sg", "") + + for v_int in digits_int: + self._id_to_token[self.token_to_id[v_int]] = v_int.replace("_int", "") + + @staticmethod + def atoms(): + return ATOMS + + @staticmethod + def digits(): + return DIGITS + + @staticmethod + def keywords(): + kws = list(KEYWORDS) + return kws + + @staticmethod + def symbols(): + # return ["x", "y", "z", ".", "(", ")", "+", "-", "/", "'", ",", " ", "\n"] + return [",", " ", ":", ".", "\n"] + + @staticmethod + def space_groups(): + return SPACE_GROUPS + + @property + def token_to_id(self): + return dict(self._token_to_id) + + @property + def id_to_token(self): + return dict(self._id_to_token) + + def prompt_tokenize(self, cif): + token_pattern = '|'.join(self._escaped_tokens) + # Add a regex pattern to match any sequence of characters separated by whitespace or punctuation + full_pattern = f'({token_pattern}|\\w+|[\\.,;!?])' + # Tokenize the input string using the regex pattern + cif = re.sub(r'[ \t]+', ' ', cif) + tokens = re.findall(full_pattern, cif) + return tokens + + def encode(self, tokens): + # encoder: take a list of tokens, output a list of integers + return [self._token_to_id[t] for t in tokens] + + def decode(self, ids): + # decoder: take a list of integers (i.e. 
encoded tokens), output a string + return ''.join([self._id_to_token[i] for i in ids]) + + def serialize(self, cif_string): + spacegroups = "|".join(SPACE_GROUPS) + cif_string = re.sub(fr'(_symmetry_space_group_name_H-M *\b({spacegroups}))\n', r'\1_sg\n', cif_string) + extracted_data = self.tokenize_cif_preprocess(cif_string) + + seq_res = '' + # formula + seq_res += "formula " + formula = extracted_data["formula"] + elements_counts = re.findall(r'([A-Z][a-z]*)(\d*)', formula) + for element, count in elements_counts: + if not element: break + if not count: count ="1" + seq_res += element + " " + count + "_int " + seq_res += "\n" + # space group name + seq_res += "space_group_symbol " + extracted_data["space_group_symbol"] + "\n" + # lattice + seq_res += "lattice_parameters " + "a " + extracted_data["lattice_parameters"]["a"] + " " + seq_res += "b " + extracted_data["lattice_parameters"]["b"] + " " + seq_res += "c " + extracted_data["lattice_parameters"]["c"] + " " + seq_res += "alpha " + extracted_data["lattice_parameters"]["alpha"] + " " + seq_res += "beta " + extracted_data["lattice_parameters"]["beta"] + " " + seq_res += "gamma " + extracted_data["lattice_parameters"]["gamma"] + " " + seq_res += "\n" + # atoms + for idx in range(len(extracted_data["atoms"])): + tmp = extracted_data["atoms"][idx] + seq_res += tmp["type"] + " " + tmp["num"] + "_int " + tmp["coordinates"][0] + " " + tmp["coordinates"][1] + " " + tmp["coordinates"][2] + "\n" + seq_res += "\n" + # Create a regex pattern by joining the escaped tokens with '|' + token_pattern = '|'.join(self._escaped_tokens) + # Add a regex pattern to match any sequence of characters separated by whitespace or punctuation + full_pattern = f'({token_pattern}|\\w+|[\\.,;!?])' + # Tokenize the input string using the regex pattern + seq_res = re.sub(r'[ \t]+', ' ', seq_res) + return seq_res + + def deserialize(self, custom_str, ground_truth=None): + print("self", self) + print("custom_str", custom_str) + 
print("ground_truth", ground_truth) + pattern_structural = re.compile(r"_chemical_formula_structural\s+['\"]?([^\n'\"]+)['\"]?") + pattern_sum = re.compile(r"_chemical_formula_sum\s+['\"]?([^'\"]+)['\"]?") + pattern_units = re.compile(r"_cell_formula_units_Z\s+(\d+)") + + structural_match = pattern_structural.search(ground_truth) + sum_match = pattern_sum.search(ground_truth) + units_match = pattern_units.search(ground_truth) + + symmetry_equiv_pos_pattern = re.compile( + r"loop_\s*\n\s*_symmetry_equiv_pos_site_id\s*\n\s*_symmetry_equiv_pos_as_xyz\s*\n(.*?)(?:\nloop_|\Z)", + re.DOTALL + ) + symmetry_equiv_pos_match = symmetry_equiv_pos_pattern.search(ground_truth) + if symmetry_equiv_pos_match: + sym_ops_block = symmetry_equiv_pos_match.group(1).strip() + + formula_structural = structural_match.group(1) if structural_match else None + formula_sum = sum_match.group(1) if sum_match else None + units_Z = int(units_match.group(1)) if units_match else None + print("formula_structural", formula_structural) + lines = custom_str.strip().splitlines() + data = {} + + if lines: + tokens = lines[0].split() + if tokens[0] != "formula": + raise ValueError("'formula' missing") + formula = "" + for i in range(1, len(tokens), 2): + element = tokens[i] + count_token = tokens[i+1] if i+1 < len(tokens) else "" + if count_token.endswith("_int"): + count = count_token[:-4] + else: + count = count_token + formula += f"{element}{count}" + data["formula"] = formula + + if len(lines) >= 2: + tokens = lines[1].split() + if tokens[0] != "space_group_symbol": + raise ValueError("'space_group_symbol' missing") + data["space_group_symbol"] = " ".join(tokens[1:]) + + if len(lines) >= 3: + tokens = lines[2].split() + if tokens[0] != "lattice_parameters": + raise ValueError("'lattice_parameters' missing") + lattice = {} + for i in range(1, len(tokens), 2): + key = tokens[i] + value = tokens[i+1] if i+1 < len(tokens) else "" + lattice[key] = value + data["lattice_parameters"] = lattice + + atoms = 
[] + for line in lines[3:]: + if not line.strip(): + break + tokens = line.split() + if len(tokens) < 5: + continue + atom_type = tokens[0] + num_token = tokens[1] + if num_token.endswith("_int"): + num = num_token[:-4] + else: + num = num_token + coords = tokens[2:5] + atoms.append({"type": atom_type, "num": num, "coordinates": coords}) + data["atoms"] = atoms + + cif_lines = [] + cif_lines.append(f"data_{formula_structural}") + cif_lines.append(f"_symmetry_space_group_name_H-M {data['space_group_symbol'].split('_sg')[0]}") + lattice = data["lattice_parameters"] + cif_lines.append(f"_cell_length_a {lattice.get('a', '')}") + cif_lines.append(f"_cell_length_b {lattice.get('b', '')}") + cif_lines.append(f"_cell_length_c {lattice.get('c', '')}") + cif_lines.append(f"_cell_angle_alpha {lattice.get('alpha', '')}") + cif_lines.append(f"_cell_angle_beta {lattice.get('beta', '')}") + cif_lines.append(f"_cell_angle_gamma {lattice.get('gamma', '')}") + space_group_symbol = str(get_spacegroup_number(data['space_group_symbol'].split("_sg")[0].strip("'"))) + space_group_symbol = re.search(r'number\s+(\d+)', space_group_symbol).group(1) + cif_lines.append(f"_symmetry_Int_Tables_number {space_group_symbol}") + cif_lines.append(f"_chemical_formula_structural {formula_structural}") + cif_lines.append(f"_chemical_formula_sum '{formula_sum}'") + + a = float(lattice.get("a", 0)) + b = float(lattice.get("b", 0)) + c = float(lattice.get("c", 0)) + alpha = float(lattice.get("alpha", 90)) + beta = float(lattice.get("beta", 90)) + gamma = float(lattice.get("gamma", 90)) + alpha_rad = math.radians(alpha) + beta_rad = math.radians(beta) + gamma_rad = math.radians(gamma) + + cos_alpha = math.cos(alpha_rad) + cos_beta = math.cos(beta_rad) + cos_gamma = math.cos(gamma_rad) + cell_volume = a * b * c * math.sqrt(1 - cos_alpha**2 - cos_beta**2 - cos_gamma**2 + + 2 * cos_alpha * cos_beta * cos_gamma) + cif_lines.append(f"_cell_volume {cell_volume:.8f}") + cif_lines.append(f"_cell_formula_units_Z 
'{units_Z}'") + cif_lines.append("loop_") + cif_lines.append(" _symmetry_equiv_pos_site_id") + cif_lines.append(" _symmetry_equiv_pos_as_xyz") + cif_lines.append(f" {sym_ops_block}") + cif_lines.append("loop_") + cif_lines.append("_atom_site_type_symbol") + cif_lines.append("_atom_site_label") + cif_lines.append("_atom_site_symmetry_multiplicity") + cif_lines.append("_atom_site_fract_x") + cif_lines.append("_atom_site_fract_y") + cif_lines.append("_atom_site_fract_z") + cif_lines.append("_atom_site_occupancy") + unique_counts = {} + for atom in data["atoms"]: + label = f"{atom['type']}" + if label not in unique_counts: + unique_counts[label] = len(unique_counts) + label = label + str(unique_counts[label]) + else: + label = label + str(unique_counts[label]) + cif_lines.append(f"{ atom['type']} {label} {atom['num']} {atom['coordinates'][0]} {atom['coordinates'][1]} {atom['coordinates'][2]} 1") + cif_string_reconstructed = "\n".join(cif_lines) + return cif_string_reconstructed + + def tokenize_cif(self, cif_string, max_length=1385): + # Preprocessing step to replace '_symmetry_space_group_name_H-M Pm' + # with '_symmetry_space_group_name_H-M Pm_sg',to disambiguate from atom 'Pm', + # or any space group symbol to avoid problematic cases, like 'P1' + spacegroups = "|".join(SPACE_GROUPS) + cif_string = re.sub(fr'(_symmetry_space_group_name_H-M *\b({spacegroups}))\n', r'\1_sg\n', cif_string) + + extracted_data = self.tokenize_cif_preprocess(cif_string) + + seq_res = '' + # formula + seq_res += "formula " + formula = extracted_data["formula"] + elements_counts = re.findall(r'([A-Z][a-z]*)(\d*)', formula) + for element, count in elements_counts: + if not element: break + if not count: count ="1" + seq_res += element + " " + count + "_int " + seq_res += "\n" + # space group name + seq_res += "space_group_symbol " + extracted_data["space_group_symbol"] + "\n" + # lattice + seq_res += "lattice_parameters " + "a " + extracted_data["lattice_parameters"]["a"] + " " + seq_res += 
"b " + extracted_data["lattice_parameters"]["b"] + " " + seq_res += "c " + extracted_data["lattice_parameters"]["c"] + " " + seq_res += "alpha " + extracted_data["lattice_parameters"]["alpha"] + " " + seq_res += "beta " + extracted_data["lattice_parameters"]["beta"] + " " + seq_res += "gamma " + extracted_data["lattice_parameters"]["gamma"] + " " + seq_res += "\n" + # atoms + for idx in range(len(extracted_data["atoms"])): + tmp = extracted_data["atoms"][idx] + seq_res += tmp["type"] + " " + tmp["num"] + "_int " + tmp["coordinates"][0] + " " + tmp["coordinates"][1] + " " + tmp["coordinates"][2] + "\n" + seq_res += "\n" + # Create a regex pattern by joining the escaped tokens with '|' + token_pattern = '|'.join(self._escaped_tokens) + # Add a regex pattern to match any sequence of characters separated by whitespace or punctuation + full_pattern = f'({token_pattern}|\\w+|[\\.,;!?])' + # Tokenize the input string using the regex pattern + seq_res = re.sub(r'[ \t]+', ' ', seq_res) + # print(seq_res) + tokens = re.findall(full_pattern, seq_res) + # print(tokens) + padding_length = max_length - len(tokens) + if padding_length > 0: + tokens.extend([""] * padding_length) + + return tokens + + def tokenize_cif_preprocess(self, cif_string): + # Re-initialize the dictionary to hold the extracted data + extracted_data = { + "space_group_symbol": "", + "formula": "", + "atoms": [], + "lattice_parameters": {} + } + + # Split the text into lines for processing + lines = cif_string.split('\n') + + # Iterate through each line to extract the required information + atom_line_idx = -1 + for line_idx in range(len(lines)): + line = lines[line_idx] + # Extract space group symbol + if "_symmetry_space_group_name_H-M" in line: + spacegroup_match = re.search(r'_symmetry_space_group_name_H-M\s+([^\n]+)', line) + spacegroup = spacegroup_match.group(1).strip() + extracted_data["space_group_symbol"] = spacegroup + # Extract formula + elif line.startswith("data_"): + extracted_data["formula"] = 
line.split("_")[1] + # Extract lattice parameters + elif line.startswith("_cell_length_a"): + extracted_data["lattice_parameters"]["a"] = line.split()[-1] + elif line.startswith("_cell_length_b"): + extracted_data["lattice_parameters"]["b"] = line.split()[-1] + elif line.startswith("_cell_length_c"): + extracted_data["lattice_parameters"]["c"] = line.split()[-1] + elif line.startswith("_cell_angle_alpha"): + extracted_data["lattice_parameters"]["alpha"] = line.split()[-1] + elif line.startswith("_cell_angle_beta"): + extracted_data["lattice_parameters"]["beta"] = line.split()[-1] + elif line.startswith("_cell_angle_gamma"): + extracted_data["lattice_parameters"]["gamma"] = line.split()[-1] + elif "_atom_site_occupancy" in line: + atom_line_idx = line_idx + 1 + break + + for line_idx in range(atom_line_idx, len(lines)): + line = lines[line_idx] + if len(line) < 2: + continue + atom_info = line.split() + atom_type = atom_info[0] + num_atoms = atom_info[2] + x, y, z = atom_info[3], atom_info[4], atom_info[5] + extracted_data["atoms"].append({ + "type": atom_type, + "num": num_atoms, + "coordinates": (x, y, z) + }) + + return extracted_data + + +class CinDataset(Dataset): + def __init__(self, texts): + self.texts = texts + + def __len__(self): + return len(self.texts) + + def __getitem__(self, idx): + text = self.texts[idx][:1500] + # if self.conditions is not None: + # raw_input_ids = raw_input_ids[1:] # Remove the first token () + input_ids = text[:-1] + targets = text[1:] + return input_ids, targets diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preporcess/mycif.py b/src/open_r1/tasks/crystal_structure/AIRS_preporcess/mycif.py new file mode 100644 index 00000000..2f84a691 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/AIRS_preporcess/mycif.py @@ -0,0 +1,1668 @@ +"""Wrapper classes for Cif input and output from Structures.""" + +from __future__ import annotations + +import math +import os +import re +import textwrap +import warnings +from collections 
import defaultdict, deque +from datetime import datetime +from functools import partial +from inspect import getfullargspec as getargspec +from io import StringIO +from itertools import groupby +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +from monty.io import zopen +from monty.serialization import loadfn + +from pymatgen.core import Composition, DummySpecies, Element, Lattice, PeriodicSite, Species, Structure, get_el_sp +from pymatgen.core.operations import MagSymmOp, SymmOp +from pymatgen.electronic_structure.core import Magmom +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer, SpacegroupOperations +from pymatgen.symmetry.groups import SYMM_DATA, SpaceGroup +from pymatgen.symmetry.maggroups import MagneticSpaceGroup +from pymatgen.symmetry.structure import SymmetrizedStructure +from pymatgen.util.coord import find_in_coord_list_pbc, in_coord_list_pbc + +if TYPE_CHECKING: + from pymatgen.core.trajectory import Vector3D + +__author__ = "Shyue Ping Ong, Will Richards, Matthew Horton" + +sub_spgrp = partial(re.sub, r"[\s_]", "") + +space_groups = {sub_spgrp(key): key for key in SYMM_DATA["space_group_encoding"]} # type: ignore + +space_groups.update({sub_spgrp(key): key for key in SYMM_DATA["space_group_encoding"]}) # type: ignore + + +class CifBlock: + """ + Object for storing cif data. All data is stored in a single dictionary. + Data inside loops are stored in lists in the data dictionary, and + information on which keys are grouped together are stored in the loops + attribute. + """ + + max_len = 70 # not quite 80 so we can deal with semicolons and things + + def __init__(self, data, loops, header): + """ + Args: + data: dict of data to go into the cif. Values should be convertible to string, + or lists of these if the key is in a loop + loops: list of lists of keys, grouped by which loop they should appear in + header: name of the block (appears after the data_ on the first line). 
+ """ + self.loops = loops + self.data = data + # AJ (@computron) says: CIF Block names cannot be more than 75 characters or you get an Exception + self.header = header[:74] + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CifBlock): + return NotImplemented + return self.loops == other.loops and self.data == other.data and self.header == other.header + + def __getitem__(self, key): + return self.data[key] + + def __str__(self) -> str: + """Returns the cif string for the data block.""" + out = [f"data_{self.header}"] + keys = list(self.data) + written = [] + for key in keys: + if key in written: + continue + for loop in self.loops: + # search for a corresponding loop + if key in loop: + out.append(self._loop_to_str(loop)) + written.extend(loop) + break + if key not in written: + # k didn't belong to a loop + v = self._format_field(self.data[key]) + if len(key) + len(v) + 3 < self.max_len: + out.append(f"{key} {v}") + else: + out.extend([key, v]) + return "\n".join(out) + + def _loop_to_str(self, loop): + out = "loop_" + for line in loop: + out += "\n " + line + for fields in zip(*(self.data[k] for k in loop)): + line = "\n" + for val in map(self._format_field, fields): + if val[0] == ";": + out += line + "\n" + val + line = "\n" + elif len(line) + len(val) + 2 < self.max_len: + line += " " + val + else: + out += line + line = "\n " + val + out += line + return out + + def _format_field(self, val) -> str: + val = str(val).strip() + if len(val) > self.max_len: + return f";\n{textwrap.fill(val, self.max_len)}\n;" + # add quotes if necessary + if val == "": + return '""' + if ( + (" " in val or val[0] == "_") + and not (val[0] == "'" and val[-1] == "'") + and not (val[0] == '"' and val[-1] == '"') + ): + quote = '"' if "'" in val else "'" + val = quote + val + quote + return val + + @classmethod + def _process_string(cls, string): + # remove comments + string = re.sub(r"(\s|^)#.*$", "", string, flags=re.MULTILINE) + # remove empty lines + string = 
re.sub(r"^\s*\n", "", string, flags=re.MULTILINE) + # remove non_ascii + string = string.encode("ascii", "ignore").decode("ascii") + # since line breaks in .cif files are mostly meaningless, + # break up into a stream of tokens to parse, rejoining multiline + # strings (between semicolons) + deq = deque() + multiline = False + ml = [] + # this regex splits on spaces, except when in quotes. starting quotes must not be + # preceded by non-whitespace (these get eaten by the first expression). ending + # quotes must not be followed by non-whitespace + pattern = re.compile(r"""([^'"\s][\S]*)|'(.*?)'(?!\S)|"(.*?)"(?!\S)""") + for line in string.splitlines(): + if multiline: + if line.startswith(";"): + multiline = False + deq.append(("", "", "", " ".join(ml))) + ml = [] + line = line[1:].strip() + else: + ml.append(line) + continue + if line.startswith(";"): + multiline = True + ml.append(line[1:].strip()) + else: + for string in pattern.findall(line): + # location of the data in string depends on whether it was quoted in the input + deq.append(tuple(string)) + return deq + + @classmethod + def from_str(cls, string): + """ + Reads CifBlock from string. + + :param string: String representation. 
+ + Returns: + CifBlock + """ + q = cls._process_string(string) + header = q.popleft()[0][5:] + data = {} + loops = [] + while q: + s = q.popleft() + # cif keys aren't in quotes, so show up in s[0] + if s[0] == "_eof": + break + if s[0].startswith("_"): + try: + data[s[0]] = "".join(q.popleft()) + except IndexError: + data[s[0]] = "" + elif s[0].startswith("loop_"): + columns = [] + items = [] + while q: + s = q[0] + if s[0].startswith("loop_") or not s[0].startswith("_"): + break + columns.append("".join(q.popleft())) + data[columns[-1]] = [] + while q: + s = q[0] + if s[0].startswith(("loop_", "_")): + break + items.append("".join(q.popleft())) + n = len(items) // len(columns) + assert len(items) % n == 0 + loops.append(columns) + for k, v in zip(columns * n, items): + data[k].append(v.strip()) + elif issue := "".join(s).strip(): + warnings.warn(f"Possible issue in CIF file at line: {issue}") + return cls(data, loops, header) + + +class CifFile: + """Reads and parses CifBlocks from a .cif file or string.""" + + def __init__(self, data: dict, orig_string: str | None = None, comment: str | None = None) -> None: + """ + Args: + data (dict): Of CifBlock objects. + orig_string (str): The original cif string. + comment (str): Comment string. + """ + self.data = data + self.orig_string = orig_string + self.comment = comment or "# generated using pymatgen" + + def __str__(self): + out = "\n".join(map(str, self.data.values())) + return f"{self.comment}\n{out}\n" + + @classmethod + def from_str(cls, string) -> CifFile: + """Reads CifFile from a string. + + :param string: String representation. + + Returns: + CifFile + """ + dct = {} + + for block_str in re.split(r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL)[1:]: + # Skip over Cif block that contains powder diffraction data. + # Some elements in this block were missing from CIF files in + # Springer materials/Pauling file DBs. 
+ # This block does not contain any structure information anyway, and + # CifParser was also not parsing it. + if "powder_pattern" in re.split(r"\n", block_str, maxsplit=1)[0]: + continue + block = CifBlock.from_str(f"data_{block_str}") + # TODO (@janosh, 2023-10-11) multiple CIF blocks with equal header will overwrite each other, + # latest taking precedence. maybe something to fix and test e.g. in test_cif_writer_write_file + dct[block.header] = block + + return cls(dct, string) + + @classmethod + def from_file(cls, filename: str | Path) -> CifFile: + """ + Reads CifFile from a filename. + + :param filename: Filename + + Returns: + CifFile + """ + with zopen(str(filename), mode="rt", errors="replace") as file: + return cls.from_str(file.read()) + + +class CifParser: + """ + Parses a CIF file. Attempts to fix CIFs that are out-of-spec, but will issue warnings + if corrections applied. These are also stored in the CifParser's errors attribute. + """ + + def __init__( + self, + filename: str | StringIO, + occupancy_tolerance: float = 1.0, + site_tolerance: float = 1e-4, + frac_tolerance: float = 1e-4, + check_cif: bool = True, + comp_tol: float = 0.01, + ) -> None: + """ + Args: + filename (str): CIF filename, gzipped or bzipped CIF files are fine too. + occupancy_tolerance (float): If total occupancy of a site is between 1 and occupancy_tolerance, the + occupancies will be scaled down to 1. + site_tolerance (float): This tolerance is used to determine if two sites are sitting in the same position, + in which case they will be combined to a single disordered site. Defaults to 1e-4. + frac_tolerance (float): This tolerance is used to determine is a coordinate should be rounded to an ideal + value. E.g., 0.6667 is rounded to 2/3. This is desired if symmetry operations are going to be applied. + However, for very large CIF files, this may need to be set to 0. 
+ check_cif (bool): Whether to check that stoichiometry reported in CIF matches + that of resulting Structure, and whether elements are missing. Defaults to True. + comp_tol (float): Tolerance for how closely stoichiometries of CIF file and pymatgen should match. + Defaults to 0.01. Context: Experimental CIF files often don't report hydrogens positions due to being + hard-to-locate with X-rays. pymatgen warns if the stoichiometry of the CIF file and the Structure + don't match to within comp_tol. + """ + self._occupancy_tolerance = occupancy_tolerance + self._site_tolerance = site_tolerance + self._frac_tolerance = frac_tolerance + if isinstance(filename, (str, Path)): + self._cif = CifFile.from_file(filename) + else: + self._cif = CifFile.from_str(filename.read()) + + # options related to checking CIFs for missing elements + # or incorrect stoichiometries + self.check_cif = check_cif + self.comp_tol = comp_tol + + # store if CIF contains features from non-core CIF dictionaries + # e.g. magCIF + self.feature_flags = {} + self.warnings: list[str] = [] + + def is_magcif() -> bool: + """Checks to see if file appears to be a magCIF file (heuristic).""" + # Doesn't seem to be a canonical way to test if file is magCIF or + # not, so instead check for magnetic symmetry datanames + prefixes = ["_space_group_magn", "_atom_site_moment", "_space_group_symop_magn"] + for d in self._cif.data.values(): + for k in d.data: + for prefix in prefixes: + if prefix in k: + return True + return False + + self.feature_flags["magcif"] = is_magcif() + + def is_magcif_incommensurate() -> bool: + """ + Checks to see if file contains an incommensurate magnetic + structure (heuristic). 
+ """ + # Doesn't seem to be a canonical way to test if magCIF file + # describes incommensurate structure or not, so instead check + # for common datanames + if not self.feature_flags["magcif"]: + return False + prefixes = ["_cell_modulation_dimension", "_cell_wave_vector"] + for d in self._cif.data.values(): + for k in d.data: + for prefix in prefixes: + if prefix in k: + return True + return False + + self.feature_flags["magcif_incommensurate"] = is_magcif_incommensurate() + + for key in self._cif.data: + # pass individual CifBlocks to _sanitize_data + self._cif.data[key] = self._sanitize_data(self._cif.data[key]) + + @classmethod + def from_str(cls, cif_string: str, **kwargs) -> CifParser: + """ + Creates a CifParser from a string. + + Args: + cif_string (str): String representation of a CIF. + **kwargs: Passthrough of all kwargs supported by CifParser. + + Returns: + CifParser + """ + stream = StringIO(cif_string) + return cls(stream, **kwargs) + + def _sanitize_data(self, data): + """ + Some CIF files do not conform to spec. This function corrects + known issues, particular in regards to Springer materials/ + Pauling files. + + This function is here so that CifParser can assume its + input conforms to spec, simplifying its implementation. + :param data: CifBlock + + Returns: + data CifBlock + """ + """ + This part of the code deals with handling formats of data as found in + CIF files extracted from the Springer Materials/Pauling File + databases, and that are different from standard ICSD formats. + """ + # check for implicit hydrogens, warn if any present + if "_atom_site_attached_hydrogens" in data.data: + attached_hydrogens = [str2float(x) for x in data.data["_atom_site_attached_hydrogens"] if str2float(x) != 0] + if len(attached_hydrogens) > 0: + self.warnings.append( + "Structure has implicit hydrogens defined, parsed structure unlikely to be " + "suitable for use in calculations unless hydrogens added." 
+ ) + + # Check to see if "_atom_site_type_symbol" exists, as some test CIFs do + # not contain this key. + if "_atom_site_type_symbol" in data.data: + # Keep a track of which data row needs to be removed. + # Example of a row: Nb,Zr '0.8Nb + 0.2Zr' .2a .m-3m 0 0 0 1 14 + # 'rhombic dodecahedron, Nb14' + # Without this code, the above row in a structure would be parsed + # as an ordered site with only Nb (since + # CifParser would try to parse the first two characters of the + # label "Nb,Zr") and occupancy=1. + # However, this site is meant to be a disordered site with 0.8 of + # Nb and 0.2 of Zr. + idxs_to_remove = [] + + new_atom_site_label = [] + new_atom_site_type_symbol = [] + new_atom_site_occupancy = [] + new_fract_x = [] + new_fract_y = [] + new_fract_z = [] + + for idx, el_row in enumerate(data["_atom_site_label"]): + # CIF files from the Springer Materials/Pauling File have + # switched the label and symbol. Thus, in the + # above shown example row, '0.8Nb + 0.2Zr' is the symbol. + # Below, we split the strings on ' + ' to + # check if the length (or number of elements) in the label and + # symbol are equal. + if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len(el_row.split(" + ")): + # Dictionary to hold extracted elements and occupancies + els_occu = {} + + # parse symbol to get element names and occupancy and store + # in "els_occu" + symbol_str = data["_atom_site_type_symbol"][idx] + symbol_str_lst = symbol_str.split(" + ") + for elocc_idx, sym in enumerate(symbol_str_lst): + # Remove any bracketed items in the string + symbol_str_lst[elocc_idx] = re.sub(r"\([0-9]*\)", "", sym.strip()) + + # Extract element name and its occupancy from the + # string, and store it as a + # key-value pair in "els_occ". 
+
+ # NOTE(review): upstream pymatgen calls .replace("<sup>", "") here to strip
+ # HTML superscript markup from Springer/Pauling symbols; the empty-string
+ # pattern below looks like the "<sup>" literal was lost (HTML-stripped) in
+ # transit — confirm against upstream before relying on this line.
+ # Index [1] is used because for a symbol like "0.8Nb", re.findall(r"\D+", ...)
+ # yields ['.', 'Nb'] (the leading occupancy digits split the non-digit runs),
+ # so the element name is the second match. Similarly re.findall(r"\.?\d+", ...)
+ # yields ['0', '.8'] and "0" + ".8" reconstructs the occupancy 0.8.
+ els_occu[
+ str(re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1]).replace("", "")
+ ] = float("0" + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[1])
+
+ x = str2float(data["_atom_site_fract_x"][idx])
+ y = str2float(data["_atom_site_fract_y"][idx])
+ z = str2float(data["_atom_site_fract_z"][idx])
+
+ # Fan the disordered site out into one new row per element, all sharing
+ # the same fractional coordinates.
+ for et, occu in els_occu.items():
+ # new atom site labels have 'fix' appended
+ new_atom_site_label.append(f"{et}_fix{len(new_atom_site_label)}")
+ new_atom_site_type_symbol.append(et)
+ new_atom_site_occupancy.append(str(occu))
+ new_fract_x.append(str(x))
+ new_fract_y.append(str(y))
+ new_fract_z.append(str(z))
+
+ idxs_to_remove.append(idx)
+
+ # Remove the original row by iterating over all keys in the CIF
+ # data looking for lists, which indicates
+ # multiple data items, one for each row, and remove items from the
+ # list that corresponds to the removed row,
+ # so that it's not processed by the rest of this function (which
+ # would result in an error).
+ # Deleting in reverse index order keeps the remaining indices valid.
+ for original_key in data.data:
+ if isinstance(data.data[original_key], list):
+ for idx in sorted(idxs_to_remove, reverse=True):
+ del data.data[original_key][idx]
+
+ if len(idxs_to_remove) > 0:
+ self.warnings.append("Pauling file corrections applied.")
+
+ # Append the expanded per-element rows built above.
+ data.data["_atom_site_label"] += new_atom_site_label
+ data.data["_atom_site_type_symbol"] += new_atom_site_type_symbol
+ data.data["_atom_site_occupancy"] += new_atom_site_occupancy
+ data.data["_atom_site_fract_x"] += new_fract_x
+ data.data["_atom_site_fract_y"] += new_fract_y
+ data.data["_atom_site_fract_z"] += new_fract_z
+ # This fixes inconsistencies in naming of several magCIF tags as a result of magCIF
+ # being in widespread use prior to specification being finalized (on advice of Branton Campbell).
+ if self.feature_flags["magcif"]:
+ # CIF-1 style has all underscores, interim standard
+ # had period before magn instead of before the final
+ # component (e.g. 
xyz) + # we want to standardize on a specific key, to simplify + # parsing code + correct_keys = [ + "_space_group_symop_magn_operation.xyz", + "_space_group_symop_magn_centering.xyz", + "_space_group_magn.name_BNS", + "_space_group_magn.number_BNS", + "_atom_site_moment_crystalaxis_x", + "_atom_site_moment_crystalaxis_y", + "_atom_site_moment_crystalaxis_z", + "_atom_site_moment_label", + ] + + # cannot mutate dict during enumeration, so store changes we want to make + changes_to_make = {} + + for original_key in data.data: + for correct_key in correct_keys: + # convert to all underscore + trial_key = "_".join(correct_key.split(".")) + test_key = "_".join(original_key.split(".")) + if trial_key == test_key: + changes_to_make[correct_key] = original_key + + # make changes + for correct_key, original_key in changes_to_make.items(): + data.data[correct_key] = data.data[original_key] + + # renamed_keys maps interim_keys to final_keys + renamed_keys = { + "_magnetic_space_group.transform_to_standard_Pp_abc": "_space_group_magn.transform_BNS_Pp_abc" + } + changes_to_make = {} + + for interim_key, final_key in renamed_keys.items(): + if data.data.get(interim_key): + changes_to_make[final_key] = interim_key + + if len(changes_to_make) > 0: + self.warnings.append("Keys changed to match new magCIF specification.") + + for final_key, interim_key in changes_to_make.items(): + data.data[final_key] = data.data[interim_key] + + # check for finite precision frac coordinates (e.g. 0.6667 instead of 0.6666666...7) + # this can sometimes cause serious issues when applying symmetry operations + important_fracs = (1 / 3, 2 / 3) + fracs_to_change = {} + for label in ("_atom_site_fract_x", "_atom_site_fract_y", "_atom_site_fract_z"): + if label in data.data: + for idx, frac in enumerate(data.data[label]): + try: + frac = str2float(frac) + except Exception: + # coordinate might not be defined e.g. '?' 
+
+ continue
+ # Relative comparison: |1 - frac/ideal| < tol, so the tolerance scales
+ # with the magnitude of the ideal fraction (1/3 or 2/3).
+ for comparison_frac in important_fracs:
+ if abs(1 - frac / comparison_frac) < self._frac_tolerance:
+ # Stored as str(float), e.g. "0.3333333333333333", since the
+ # CIF data lists hold strings.
+ fracs_to_change[(label, idx)] = str(comparison_frac)
+ if fracs_to_change:
+ self.warnings.append(
+ f"{len(fracs_to_change)} fractional coordinates rounded to ideal values to avoid issues with "
+ "finite precision."
+ )
+ for (label, idx), val in fracs_to_change.items():
+ data.data[label][idx] = val
+
+ return data
+
+ def _unique_coords(
+ self,
+ coords: list[Vector3D],
+ magmoms: list[Magmom] | None = None,
+ lattice: Lattice | None = None,
+ labels: dict[Vector3D, str] | None = None,
+ ):
+ """
+ Generate unique coordinates using coord and symmetry positions
+ and also their corresponding magnetic moments, if supplied.
+ """
+ coords_out: list[np.ndarray] = []
+ labels_out = []
+ labels = labels or {}
+
+ if magmoms:
+ magmoms_out = []
+ # NOTE(review): message-less ValueError — a caller sees no hint that
+ # magmoms and coords lengths disagreed; consider adding a message.
+ if len(magmoms) != len(coords):
+ raise ValueError
+ for tmp_coord, tmp_magmom in zip(coords, magmoms):
+ for op in self.symmetry_operations:
+ coord = op.operate(tmp_coord)
+ # Wrap each component into [0, 1) (fractional-coordinate PBC).
+ coord = np.array([i - math.floor(i) for i in coord])
+ if isinstance(op, MagSymmOp):
+ # Up to this point, magmoms have been defined relative
+ # to crystal axis. Now convert to Cartesian and into
+ # a Magmom object. 
+ magmom = Magmom.from_moment_relative_to_crystal_axes( + op.operate_magmom(tmp_magmom), lattice=lattice + ) + else: + magmom = Magmom(tmp_magmom) + if not in_coord_list_pbc(coords_out, coord, atol=self._site_tolerance): + coords_out.append(coord) + magmoms_out.append(magmom) + labels_out.append(labels.get(tmp_coord)) + return coords_out, magmoms_out, labels_out + + for tmp_coord in coords: + for op in self.symmetry_operations: + coord = op.operate(tmp_coord) + coord = np.array([i - math.floor(i) for i in coord]) + if not in_coord_list_pbc(coords_out, coord, atol=self._site_tolerance): + coords_out.append(coord) + labels_out.append(labels.get(tmp_coord)) + + dummy_magmoms = [Magmom(0)] * len(coords_out) + return coords_out, dummy_magmoms, labels_out + + def get_lattice( + self, + data, + length_strings=("a", "b", "c"), + angle_strings=("alpha", "beta", "gamma"), + lattice_type=None, + ): + """ + Generate the lattice from the provided lattice parameters. In + the absence of all six lattice parameters, the crystal system + and necessary parameters are parsed. 
+ """ + try: + return self.get_lattice_no_exception( + data=data, angle_strings=angle_strings, lattice_type=lattice_type, length_strings=length_strings + ) + + except KeyError: + # Missing Key search for cell setting + for lattice_label in ["_symmetry_cell_setting", "_space_group_crystal_system"]: + if data.data.get(lattice_label): + lattice_type = data.data.get(lattice_label).lower() + try: + required_args = getargspec(getattr(Lattice, lattice_type)).args + + lengths = (length for length in length_strings if length in required_args) + angles = (a for a in angle_strings if a in required_args) + return self.get_lattice(data, lengths, angles, lattice_type=lattice_type) + except AttributeError as exc: + self.warnings.append(str(exc)) + warnings.warn(exc) + + else: + return None + return None + + @staticmethod + def get_lattice_no_exception( + data, length_strings=("a", "b", "c"), angle_strings=("alpha", "beta", "gamma"), lattice_type=None + ): + """ + Take a dictionary of CIF data and returns a pymatgen Lattice object. + + Args: + data: a dictionary of the CIF file + length_strings: The strings that are used to identify the length parameters in the CIF file. + angle_strings: The strings that are used to identify the angles in the CIF file. + lattice_type: The type of lattice. This is a string, and can be any of the following: + + Returns: + Lattice object + """ + lengths = [str2float(data["_cell_length_" + i]) for i in length_strings] + angles = [str2float(data["_cell_angle_" + i]) for i in angle_strings] + if not lattice_type: + return Lattice.from_parameters(*lengths, *angles) + return getattr(Lattice, lattice_type)(*(lengths + angles)) + + def get_symops(self, data): + """ + In order to generate symmetry equivalent positions, the symmetry + operations are parsed. If the symops are not present, the space + group symbol is parsed, and symops are generated. 
+ """ + sym_ops = [] + for symmetry_label in [ + "_symmetry_equiv_pos_as_xyz", + "_symmetry_equiv_pos_as_xyz_", + "_space_group_symop_operation_xyz", + "_space_group_symop_operation_xyz_", + ]: + if data.data.get(symmetry_label): + xyz = data.data.get(symmetry_label) + if isinstance(xyz, str): + msg = "A 1-line symmetry op P1 CIF is detected!" + warnings.warn(msg) + self.warnings.append(msg) + xyz = [xyz] + try: + sym_ops = [SymmOp.from_xyz_str(s) for s in xyz] + break + except ValueError: + continue + if not sym_ops: + # Try to parse symbol + for symmetry_label in [ + "_symmetry_space_group_name_H-M", + "_symmetry_space_group_name_H_M", + "_symmetry_space_group_name_H-M_", + "_symmetry_space_group_name_H_M_", + "_space_group_name_Hall", + "_space_group_name_Hall_", + "_space_group_name_H-M_alt", + "_space_group_name_H-M_alt_", + "_symmetry_space_group_name_hall", + "_symmetry_space_group_name_hall_", + "_symmetry_space_group_name_h-m", + "_symmetry_space_group_name_h-m_", + ]: + sg = data.data.get(symmetry_label) + msg_template = "No _symmetry_equiv_pos_as_xyz type key found. Spacegroup from {} used." 
+ + if sg: + sg = sub_spgrp(sg) + try: + spg = space_groups.get(sg) + if spg: + sym_ops = SpaceGroup(spg).symmetry_ops + msg = msg_template.format(symmetry_label) + warnings.warn(msg) + self.warnings.append(msg) + break + except ValueError: + # Ignore any errors + pass + + try: + cod_data = loadfn( + os.path.join(os.path.dirname(os.path.dirname(__file__)), "symmetry", "symm_ops.json") + ) + for d in cod_data: + if sg == re.sub(r"\s+", "", d["hermann_mauguin"]): + xyz = d["symops"] + sym_ops = [SymmOp.from_xyz_str(s) for s in xyz] + msg = msg_template.format(symmetry_label) + warnings.warn(msg) + self.warnings.append(msg) + break + except Exception: + continue + + if sym_ops: + break + if not sym_ops: + # Try to parse International number + for symmetry_label in [ + "_space_group_IT_number", + "_space_group_IT_number_", + "_symmetry_Int_Tables_number", + "_symmetry_Int_Tables_number_", + ]: + if data.data.get(symmetry_label): + try: + i = int(str2float(data.data.get(symmetry_label))) + sym_ops = SpaceGroup.from_int_number(i).symmetry_ops + break + except ValueError: + continue + + if not sym_ops: + msg = "No _symmetry_equiv_pos_as_xyz type key found. Defaulting to P1." + warnings.warn(msg) + self.warnings.append(msg) + sym_ops = [SymmOp.from_xyz_str(s) for s in ["x", "y", "z"]] + + return sym_ops + + def get_magsymops(self, data): + """ + Equivalent to get_symops except for magnetic symmetry groups. + Separate function since additional operation for time reversal symmetry + (which changes magnetic moments on sites) needs to be returned. 
+ """ + mag_symm_ops = [] + bns_name = data.data.get("_space_group_magn.name_BNS") # get BNS label for MagneticSpaceGroup() + bns_num = data.data.get("_space_group_magn.number_BNS") # get BNS number for MagneticSpaceGroup() + + # check to see if magCIF file explicitly contains magnetic symmetry operations + if xyzt := data.data.get("_space_group_symop_magn_operation.xyz"): + if isinstance(xyzt, str): + xyzt = [xyzt] + mag_symm_ops = [MagSymmOp.from_xyzt_str(s) for s in xyzt] + + if data.data.get("_space_group_symop_magn_centering.xyz"): + xyzt = data.data.get("_space_group_symop_magn_centering.xyz") + if isinstance(xyzt, str): + xyzt = [xyzt] + centering_symops = [MagSymmOp.from_xyzt_str(s) for s in xyzt] + + all_ops = [] + for op in mag_symm_ops: + for centering_op in centering_symops: + new_translation = [ + i - np.floor(i) for i in op.translation_vector + centering_op.translation_vector + ] + new_time_reversal = op.time_reversal * centering_op.time_reversal + all_ops.append( + MagSymmOp.from_rotation_and_translation_and_time_reversal( + rotation_matrix=op.rotation_matrix, + translation_vec=new_translation, + time_reversal=new_time_reversal, + ) + ) + mag_symm_ops = all_ops + + # else check to see if it specifies a magnetic space group + elif bns_name or bns_num: + label = bns_name if bns_name else list(map(int, (bns_num.split(".")))) + + if data.data.get("_space_group_magn.transform_BNS_Pp_abc") != "a,b,c;0,0,0": + jonas_faithful = data.data.get("_space_group_magn.transform_BNS_Pp_abc") + msg = MagneticSpaceGroup(label, jonas_faithful) + + elif data.data.get("_space_group_magn.transform_BNS_Pp"): + return NotImplementedError("Incomplete specification to implement.") + else: + msg = MagneticSpaceGroup(label) + + mag_symm_ops = msg.symmetry_ops + + if not mag_symm_ops: + msg = "No magnetic symmetry detected, using primitive symmetry." 
+
+ warnings.warn(msg)
+ self.warnings.append(msg)
+ # Fall back to the identity operation with time reversal +1.
+ mag_symm_ops = [MagSymmOp.from_xyzt_str("x, y, z, 1")]
+
+ return mag_symm_ops
+
+ @staticmethod
+ def parse_oxi_states(data):
+ """Parse oxidation states from data dictionary.
+
+ Returns a dict mapping atom-type symbol (and its oxidation-state-stripped
+ form) to a float oxidation number, or None if the required keys are
+ missing or unparseable.
+ """
+ try:
+ oxi_states = {
+ data["_atom_type_symbol"][i]: str2float(data["_atom_type_oxidation_number"][i])
+ for i in range(len(data["_atom_type_symbol"]))
+ }
+ # attempt to strip oxidation state from _atom_type_symbol
+ # in case the label does not contain an oxidation state
+ # NOTE(review): the class [\+,\-] also matches a literal comma;
+ # presumably [+\-] was intended — confirm against upstream pymatgen.
+ for i, symbol in enumerate(data["_atom_type_symbol"]):
+ oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float(data["_atom_type_oxidation_number"][i])
+
+ except (ValueError, KeyError):
+ oxi_states = None
+ return oxi_states
+
+ @staticmethod
+ def parse_magmoms(data, lattice=None):
+ """Parse atomic magnetic moments from data dictionary.
+
+ Returns a dict mapping moment label to a length-3 numpy array of
+ crystal-axis components, or None if the moment keys are absent/invalid.
+ """
+ # NOTE(review): lattice defaults to None but None is rejected
+ # immediately, so lattice is effectively required; a generic Exception
+ # is raised rather than ValueError/TypeError.
+ if lattice is None:
+ raise Exception("Magmoms given in terms of crystal axes in magCIF spec.")
+ try:
+ magmoms = {
+ data["_atom_site_moment_label"][i]: np.array(
+ [
+ str2float(data["_atom_site_moment_crystalaxis_x"][i]),
+ str2float(data["_atom_site_moment_crystalaxis_y"][i]),
+ str2float(data["_atom_site_moment_crystalaxis_z"][i]),
+ ]
+ )
+ for i in range(len(data["_atom_site_moment_label"]))
+ }
+ except (ValueError, KeyError):
+ return None
+ return magmoms
+
+ def _parse_symbol(self, sym):
+ """
+ Parse a string with a symbol to extract a string representing an element.
+
+ Args:
+ sym (str): A symbol to be parsed.
+
+ Returns:
+ A string with the parsed symbol. None if no parsing was possible.
+ """
+ # Common representations for elements/water in cif files
+ # TODO: fix inconsistent handling of water
+ # Empty-string values ("OH", "OH2") deliberately yield a falsy parsed
+ # symbol so such sites are dropped by the caller.
+ special = {
+ "Hw": "H",
+ "Ow": "O",
+ "Wat": "O",
+ "wat": "O",
+ "OH": "",
+ "OH2": "",
+ "NO3": "N",
+ }
+
+ parsed_sym = None
+ # try with special symbols, otherwise check the first two letters,
+ # then the first letter alone. If everything fails try extracting the
+ # first letters. 
+ m_sp = re.match("|".join(special), sym) + if m_sp: + parsed_sym = special[m_sp.group()] + elif Element.is_valid_symbol(sym[:2].title()): + parsed_sym = sym[:2].title() + elif Element.is_valid_symbol(sym[0].upper()): + parsed_sym = sym[0].upper() + else: + m = re.match(r"w?[A-Z][a-z]*", sym) + if m: + parsed_sym = m.group() + + if parsed_sym is not None and (m_sp or not re.match(rf"{parsed_sym}\d*", sym)): + msg = f"{sym} parsed as {parsed_sym}" + warnings.warn(msg) + self.warnings.append(msg) + + return parsed_sym + + def _get_structure( + self, data: dict[str, Any], primitive: bool, symmetrized: bool, check_occu: bool = False + ) -> Structure | None: + """Generate structure from part of the cif.""" + + def get_num_implicit_hydrogens(sym): + num_h = {"Wat": 2, "wat": 2, "O-H": 1} + return num_h.get(sym[:3], 0) + + lattice = self.get_lattice(data) + + # if magCIF, get magnetic symmetry moments and magmoms + # else standard CIF, and use empty magmom dict + if self.feature_flags["magcif_incommensurate"]: + raise NotImplementedError("Incommensurate structures not currently supported.") + if self.feature_flags["magcif"]: + self.symmetry_operations = self.get_magsymops(data) + magmoms = self.parse_magmoms(data, lattice=lattice) + else: + self.symmetry_operations = self.get_symops(data) + magmoms = {} + + oxi_states = self.parse_oxi_states(data) + + coord_to_species = {} # type: ignore + coord_to_magmoms = {} + labels = {} + + def get_matching_coord(coord): + keys = list(coord_to_species) + coords = np.array(keys) + for op in self.symmetry_operations: + frac_coord = op.operate(coord) + indices = find_in_coord_list_pbc(coords, frac_coord, atol=self._site_tolerance) + if len(indices) > 0: + return keys[indices[0]] + return False + + for idx, label in enumerate(data["_atom_site_label"]): + try: + # If site type symbol exists, use it. Otherwise, we use the label. 
+ symbol = self._parse_symbol(data["_atom_site_type_symbol"][idx]) + num_h = get_num_implicit_hydrogens(data["_atom_site_type_symbol"][idx]) + except KeyError: + symbol = self._parse_symbol(label) + num_h = get_num_implicit_hydrogens(label) + if not symbol: + continue + + if oxi_states is not None: + o_s = oxi_states.get(symbol, 0) + # use _atom_site_type_symbol if possible for oxidation state + if "_atom_site_type_symbol" in data.data: # type: ignore[attr-defined] + oxi_symbol = data["_atom_site_type_symbol"][idx] + o_s = oxi_states.get(oxi_symbol, o_s) + try: + el = Species(symbol, o_s) + except Exception: + el = DummySpecies(symbol, o_s) + else: + el = get_el_sp(symbol) # type: ignore + + x = str2float(data["_atom_site_fract_x"][idx]) + y = str2float(data["_atom_site_fract_y"][idx]) + z = str2float(data["_atom_site_fract_z"][idx]) + magmom = magmoms.get(label, np.array([0, 0, 0])) + + try: + occu = str2float(data["_atom_site_occupancy"][idx]) + except (KeyError, ValueError): + occu = 1 + # If check_occu is True or the occupancy is greater than 0, create comp_d + if not check_occu or occu > 0: + coord = (x, y, z) + match = get_matching_coord(coord) + comp_dict = {el: max(occu, 1e-8)} + + if num_h > 0: + comp_dict["H"] = num_h # type: ignore + self.warnings.append( + "Structure has implicit hydrogens defined, parsed structure unlikely to be " + "suitable for use in calculations unless hydrogens added." + ) + comp = Composition(comp_dict) + + if not match: + coord_to_species[coord] = comp + coord_to_magmoms[coord] = magmom + labels[coord] = label + else: + coord_to_species[match] += comp + # disordered magnetic not currently supported + coord_to_magmoms[match] = None + labels[match] = label + sum_occu = [ + sum(c.values()) for c in coord_to_species.values() if set(c.elements) != {Element("O"), Element("H")} + ] + if any(occu > 1 for occu in sum_occu): + msg = ( + f"Some occupancies ({sum_occu}) sum to > 1! 
If they are within " + "the occupancy_tolerance, they will be rescaled. " + f"The current occupancy_tolerance is set to: {self._occupancy_tolerance}" + ) + warnings.warn(msg) + self.warnings.append(msg) + + all_species = [] + all_coords = [] + all_magmoms = [] + all_hydrogens = [] + equivalent_indices = [] + all_labels = [] + + # check to see if magCIF file is disordered + if self.feature_flags["magcif"]: + for v in coord_to_magmoms.values(): + if v is None: + # Proposed solution to this is to instead store magnetic + # moments as Species 'spin' property, instead of site + # property, but this introduces ambiguities for end user + # (such as unintended use of `spin` and Species will have + # fictitious oxidation state). + raise NotImplementedError("Disordered magnetic structures not currently supported.") + + if coord_to_species.items(): + for idx, (comp, group) in enumerate( + groupby( + sorted(coord_to_species.items(), key=lambda x: x[1]), + key=lambda x: x[1], + ) + ): + tmp_coords = [site[0] for site in group] + tmp_magmom = [coord_to_magmoms[tmp_coord] for tmp_coord in tmp_coords] + + if self.feature_flags["magcif"]: + coords, magmoms, new_labels = self._unique_coords( + tmp_coords, magmoms=tmp_magmom, labels=labels, lattice=lattice + ) + else: + coords, magmoms, new_labels = self._unique_coords(tmp_coords, labels=labels) + + if set(comp.elements) == {Element("O"), Element("H")}: + # O with implicit hydrogens + im_h = comp["H"] + species = Composition({"O": comp["O"]}) + else: + im_h = 0 + species = comp + + # The following might be a more natural representation of equivalent indices, + # but is not in the format expect by SymmetrizedStructure: + # equivalent_indices.append(list(range(len(all_coords), len(coords)+len(all_coords)))) + # The above gives a list like: + # [[0, 1, 2, 3], [4, 5, 6, 7, 8, 9, 10, 11]] where the + # integers are site indices, whereas the version used below will give a version like: + # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + # which is 
a list in the same order as the sites, but where if a site has the same integer + # it is equivalent. + equivalent_indices += len(coords) * [idx] + + all_hydrogens.extend(len(coords) * [im_h]) + all_coords.extend(coords) + all_species.extend(len(coords) * [species]) + all_magmoms.extend(magmoms) + all_labels.extend(new_labels) + + # rescale occupancies if necessary + all_species_noedit = all_species.copy() # save copy before scaling in case of check_occu=False, used below + for idx, species in enumerate(all_species): + total_occu = sum(species.values()) + if 1 < total_occu <= self._occupancy_tolerance: + all_species[idx] = species / total_occu + + if all_species and len(all_species) == len(all_coords) and len(all_species) == len(all_magmoms): + site_properties = {} + if any(all_hydrogens): + assert len(all_hydrogens) == len(all_coords) + site_properties["implicit_hydrogens"] = all_hydrogens + + if self.feature_flags["magcif"]: + site_properties["magmom"] = all_magmoms + + if len(site_properties) == 0: + site_properties = None # type: ignore + + if any(all_labels): + assert len(all_labels) == len(all_species) + else: + all_labels = None # type: ignore + + struct = Structure(lattice, all_species, all_coords, site_properties=site_properties, labels=all_labels) + + if symmetrized: + # Wyckoff labels not currently parsed, note that not all CIFs will contain Wyckoff labels + # TODO: extract Wyckoff labels (or other CIF attributes) and include as site_properties + wyckoffs = ["Not Parsed"] * len(struct) + + # space groups names are likewise not parsed (again, not all CIFs will contain this information) + # What is stored are the lists of symmetry operations used to generate the structure + # TODO: ensure space group labels are stored if present + sg = SpacegroupOperations("Not Parsed", -1, self.symmetry_operations) + struct = SymmetrizedStructure(struct, sg, equivalent_indices, wyckoffs) + + if not check_occu: + for idx in range(len(struct)): + struct[idx] = PeriodicSite( 
+ all_species_noedit[idx], all_coords[idx], lattice, properties=site_properties, skip_checks=True + ) + + if symmetrized or not check_occu: + return struct + + struct = struct.get_sorted_structure() + + if primitive and self.feature_flags["magcif"]: + struct = struct.get_primitive_structure(use_site_props=True) + elif primitive: + struct = struct.get_primitive_structure() + struct = struct.get_reduced_structure() + + if self.check_cif: + cif_failure_reason = self.check(struct) + if cif_failure_reason is not None: + warnings.warn(cif_failure_reason) + + return struct + return None + + @np.deprecate( + message="get_structures is deprecated and will be removed in 2024. Use parse_structures instead." + "The only difference is that primitive defaults to False in the new parse_structures method." + "So parse_structures(primitive=True) is equivalent to the old behavior of get_structures().", + ) + def get_structures(self, *args, **kwargs) -> list[Structure]: + """ + Deprecated. Use parse_structures instead. Only difference between the two methods is the + default primitive=False in parse_structures. + So parse_structures(primitive=True) is equivalent to the old behavior of get_structures(). + """ + if len(args) > 0: # extract primitive if passed as arg + kwargs["primitive"] = args[0] + args = args[1:] + kwargs.setdefault("primitive", True) + return self.parse_structures(*args, **kwargs) + + def parse_structures( + self, + primitive: bool | None = None, + symmetrized: bool = False, + check_occu: bool = True, + on_error: Literal["ignore", "warn", "raise"] = "warn", + ) -> list[Structure]: + """Return list of structures in CIF file. + + Args: + primitive (bool): Set to True to return primitive unit cells. + Defaults to False. With magnetic CIF files, True will return primitive + magnetic cell which may be larger than nuclear primitive cell. 
+ symmetrized (bool): If True, return a SymmetrizedStructure which will + include the equivalent indices and symmetry operations used to + create the Structure as provided by the CIF (if explicit symmetry + operations are included in the CIF) or generated from information + in the CIF (if only space group labels are provided). Note that + currently Wyckoff labels and space group labels or numbers are + not included in the generated SymmetrizedStructure, these will be + notated as "Not Parsed" or -1 respectively. + check_occu (bool): If False, site occupancy will not be checked, allowing unphysical + occupancy != 1. Useful for experimental results in which occupancy was allowed + to refine to unphysical values. Warning: unphysical site occupancies are incompatible + with many pymatgen features. Defaults to True. + on_error ('ignore' | 'warn' | 'raise'): What to do in case of KeyError or ValueError + while parsing CIF file. Defaults to 'warn'. + + Returns: + list[Structure]: All structures in CIF file. + """ + if os.getenv("CI") and datetime.now() > datetime(2024, 3, 1): # March 2024 seems long enough # pragma: no cover + raise RuntimeError("remove the change of default primitive=True to False made on 2023-10-24") + if primitive is None: + primitive = False + warnings.warn( + "The default value of primitive was changed from True to False in " + "https://github.com/materialsproject/pymatgen/pull/3419. CifParser now returns the cell " + "in the CIF file as is. If you want the primitive cell, please set primitive=True explicitly.", + UserWarning, + ) + if not check_occu: # added in https://github.com/materialsproject/pymatgen/pull/2836 + warnings.warn("Structures with unphysical site occupancies are not compatible with many pymatgen features.") + if primitive and symmetrized: + raise ValueError( + "Using both 'primitive' and 'symmetrized' arguments is not currently supported " + "since unexpected behavior might result." 
+ ) + + structures = [] + for idx, dct in enumerate(self._cif.data.values()): + try: + struct = self._get_structure(dct, primitive, symmetrized, check_occu=check_occu) + if struct: + structures.append(struct) + except (KeyError, ValueError) as exc: + # A user reported a problem with cif files produced by Avogadro + # in which the atomic coordinates are in Cartesian coords. + msg = f"No structure parsed for section {idx + 1} in CIF.\n{exc}" + if on_error == "raise": + raise ValueError(msg) from exc + if on_error == "warn": + warnings.warn(msg) + self.warnings.append(msg) + # continue silently if on_error == "ignore" + + # if on_error == "raise" we don't get to here so no need to check + if self.warnings and on_error == "warn": + warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings)) + + if len(structures) == 0: + raise ValueError("Invalid CIF file with no structures!") + return structures + + def get_bibtex_string(self): + """ + Get BibTeX reference from CIF file. + :param data: + + Returns: + BibTeX string. + """ + try: + from pybtex.database import BibliographyData, Entry + except ImportError: + raise RuntimeError("Bibliographic data extraction requires pybtex.") + + bibtex_keys = { + "author": ("_publ_author_name", "_citation_author_name"), + "title": ("_publ_section_title", "_citation_title"), + "journal": ( + "_journal_name_full", + "_journal_name_abbrev", + "_citation_journal_full", + "_citation_journal_abbrev", + ), + "volume": ("_journal_volume", "_citation_journal_volume"), + "year": ("_journal_year", "_citation_year"), + "number": ("_journal_number", "_citation_number"), + "page_first": ("_journal_page_first", "_citation_page_first"), + "page_last": ("_journal_page_last", "_citation_page_last"), + "doi": ("_journal_DOI", "_citation_DOI"), + } + + entries = {} + + # TODO: parse '_publ_section_references' when it exists? + # TODO: CIF specification supports multiple citations. 
+ + for idx, data in enumerate(self._cif.data.values()): + # convert to lower-case keys, some cif files inconsistent + data = {k.lower(): v for k, v in data.data.items()} + + bibtex_entry = {} + + for field, tags in bibtex_keys.items(): + for tag in tags: + if tag in data: + if isinstance(data[tag], list): + bibtex_entry[field] = data[tag][0] + else: + bibtex_entry[field] = data[tag] + + # convert to bibtex author format ('and' delimited) + if "author" in bibtex_entry: + # separate out semicolon authors + if isinstance(bibtex_entry["author"], str) and ";" in bibtex_entry["author"]: + bibtex_entry["author"] = bibtex_entry["author"].split(";") + + if isinstance(bibtex_entry["author"], list): + bibtex_entry["author"] = " and ".join(bibtex_entry["author"]) + + # convert to bibtex page range format, use empty string if not specified + if ("page_first" in bibtex_entry) or ("page_last" in bibtex_entry): + bibtex_entry["pages"] = bibtex_entry.get("page_first", "") + "--" + bibtex_entry.get("page_last", "") + bibtex_entry.pop("page_first", None) # and remove page_first, page_list if present + bibtex_entry.pop("page_last", None) + + # cite keys are given as cif-reference-idx in order they are found + entries[f"cifref{idx}"] = Entry("article", list(bibtex_entry.items())) + + return BibliographyData(entries).to_string(bib_format="bibtex") + + def as_dict(self): + """MSONable dict""" + dct = {} + for k, v in self._cif.data.items(): + dct[k] = {} + for k2, v2 in v.data.items(): + dct[k][k2] = v2 + return dct + + @property + def has_errors(self): + """Whether there are errors/warnings detected in CIF parsing.""" + return len(self.warnings) > 0 + + def check(self, structure: Structure) -> str | None: + """Check whether a structure constructed from CIF passes sanity checks. + + Args: + structure (Structure) : structure created from CIF + + Returns: + str | None: If any check fails, on output, returns a human-readable str for the + reason why (e.g., which elements are missing). 
Returns None if all checks pass. + + Checks: + - Composition from CIF is valid + - CIF composition contains only valid elements + - CIF and structure contain the same elements (often hydrogens + are omitted from CIFs, as their positions cannot be determined from + X-ray diffraction, needs more difficult neutron diffraction) + - CIF and structure have same relative stoichiometry. Thus + if CIF reports stoichiometry LiFeO, and the structure has + composition (LiFeO)4, this check passes. + """ + failure_reason = None + + cif_as_dict = self.as_dict() + head_key = next(iter(cif_as_dict)) + + cif_formula = None + for key in ("_chemical_formula_sum", "_chemical_formula_structural"): + if cif_as_dict[head_key].get(key): + cif_formula = cif_as_dict[head_key][key] + break + + if cif_formula is None and cif_as_dict[head_key].get("_atom_site_type_symbol"): + cif_formula = " ".join(cif_as_dict[head_key]["_atom_site_type_symbol"]) + + try: + cif_composition = Composition(cif_formula) + except Exception as exc: + return f"Cannot determine chemical composition from CIF! 
{exc}" + + try: + orig_comp = cif_composition.remove_charges().as_dict() + struct_comp = structure.composition.remove_charges().as_dict() + except Exception as exc: + return str(exc) + + orig_comp_elts = {str(elt) for elt in orig_comp} + struct_comp_elts = {str(elt) for elt in struct_comp} + + if orig_comp_elts != struct_comp_elts: + # hard failure - missing elements + + missing = set(orig_comp_elts).difference(set(struct_comp_elts)) + addendum = "from PMG structure composition" + if len(missing) == 0: + addendum = "from CIF-reported composition" + missing = set(struct_comp_elts).difference(set(orig_comp_elts)) + missing_str = ", ".join([str(x) for x in missing]) + failure_reason = f"Missing elements {missing_str} {addendum}" + + elif not all(struct_comp[elt] - orig_comp[elt] == 0 for elt in orig_comp): + # Check that stoichiometry is same, i.e., same relative ratios of elements + ratios = {elt: struct_comp[elt] / orig_comp[elt] for elt in orig_comp_elts} + + same_stoich = all( + abs(ratios[elt_a] - ratios[elt_b]) < self.comp_tol + for elt_a in orig_comp_elts + for elt_b in orig_comp_elts + ) + + if not same_stoich: + failure_reason = f"Incorrect stoichiometry:\n CIF={orig_comp}\n PMG={struct_comp}\n {ratios=}" + + return failure_reason + + +class CifWriter: + """A wrapper around CifFile to write CIF files from pymatgen structures.""" + + def __init__( + self, + struct: Structure, + symprec: float | None = None, + write_magmoms: bool = False, + significant_figures: int = 8, + angle_tolerance: float = 5, + refine_struct: bool = True, + write_site_properties: bool = False, + printout: bool = False, + pos_order: bool = False, + full_order: bool = False, + ) -> None: + """ + Args: + struct (Structure): structure to write + symprec (float): If not none, finds the symmetry of the structure + and writes the cif with symmetry information. Passes symprec + to the SpacegroupAnalyzer. See also refine_struct. + write_magmoms (bool): If True, will write magCIF file. 
Incompatible + with symprec + significant_figures (int): Specifies precision for formatting of floats. + Defaults to 8. + angle_tolerance (float): Angle tolerance for symmetry finding. Passes + angle_tolerance to the SpacegroupAnalyzer. Used only if symprec + is not None. + refine_struct: Used only if symprec is not None. If True, get_refined_structure + is invoked to convert input structure from primitive to conventional. + write_site_properties (bool): Whether to write the Structure.site_properties + to the CIF as _atom_site_{property name}. Defaults to False. + """ + if write_magmoms and symprec: + warnings.warn("Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection.") + symprec = None + + format_str = f"{{:.{significant_figures}f}}" + + block: dict[str, Any] = {} + loops = [] + spacegroup = ("P 1", 1) + if symprec is not None: + spg_analyzer = SpacegroupAnalyzer(struct, symprec, angle_tolerance=angle_tolerance) + spacegroup = (spg_analyzer.get_space_group_symbol(), spg_analyzer.get_space_group_number()) + + if refine_struct: + # Needs the refined structure when using symprec. This converts + # primitive to conventional structures, the standard for CIF. 
+ struct = spg_analyzer.get_refined_structure() + + lattice = struct.lattice + comp = struct.composition + no_oxi_comp = comp.element_composition + block["_symmetry_space_group_name_H-M"] = spacegroup[0] + for cell_attr in ["a", "b", "c"]: + block["_cell_length_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) + for cell_attr in ["alpha", "beta", "gamma"]: + block["_cell_angle_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) + block["_symmetry_Int_Tables_number"] = spacegroup[1] + block["_chemical_formula_structural"] = no_oxi_comp.reduced_formula + block["_chemical_formula_sum"] = no_oxi_comp.formula + block["_cell_volume"] = format_str.format(lattice.volume) + + _, fu = no_oxi_comp.get_reduced_composition_and_factor() + block["_cell_formula_units_Z"] = str(int(fu)) + + if symprec is None: + block["_symmetry_equiv_pos_site_id"] = ["1"] + block["_symmetry_equiv_pos_as_xyz"] = ["x, y, z"] + else: + spg_analyzer = SpacegroupAnalyzer(struct, symprec) + + symm_ops: list[SymmOp] = [] + for op in spg_analyzer.get_symmetry_operations(): + v = op.translation_vector + symm_ops.append(SymmOp.from_rotation_and_translation(op.rotation_matrix, v)) + + ops = [op.as_xyz_str() for op in symm_ops] + block["_symmetry_equiv_pos_site_id"] = [f"{i}" for i in range(1, len(ops) + 1)] + block["_symmetry_equiv_pos_as_xyz"] = ops + + loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"]) + + try: + symbol_to_oxi_num = {str(el): float(el.oxi_state or 0) for el in sorted(comp.elements)} + block["_atom_type_symbol"] = list(symbol_to_oxi_num) + block["_atom_type_oxidation_number"] = symbol_to_oxi_num.values() + loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"]) + except (TypeError, AttributeError): + symbol_to_oxi_num = {el.symbol: 0 for el in sorted(comp.elements)} + + atom_site_type_symbol = [] + atom_site_symmetry_multiplicity = [] + atom_site_fract_x = [] + atom_site_fract_y = [] + atom_site_fract_z = [] + atom_site_label = 
[] + atom_site_occupancy = [] + atom_site_moment_label = [] + atom_site_moment_crystalaxis_x = [] + atom_site_moment_crystalaxis_y = [] + atom_site_moment_crystalaxis_z = [] + atom_site_properties: dict[str, list] = defaultdict(list) + count = 0 + if symprec is None: + for site in struct: + for sp, occu in sorted(site.species.items()): + atom_site_type_symbol.append(str(sp)) + atom_site_symmetry_multiplicity.append("1") + atom_site_fract_x.append(format_str.format(site.a)) + atom_site_fract_y.append(format_str.format(site.b)) + atom_site_fract_z.append(format_str.format(site.c)) + atom_site_occupancy.append(str(occu)) + site_label = f"{sp.symbol}{count}" + + if "magmom" in site.properties: + mag = site.properties["magmom"] + elif getattr(sp, "spin", None) is not None: + mag = sp.spin + else: + # Use site label if available for regular sites + site_label = site.label if site.label != site.species_string else site_label + mag = 0 + + atom_site_label.append(site_label) + + magmom = Magmom(mag) + if write_magmoms and abs(magmom) > 0: + moment = Magmom.get_moment_relative_to_crystal_axes(magmom, lattice) + atom_site_moment_label.append(f"{sp.symbol}{count}") + atom_site_moment_crystalaxis_x.append(format_str.format(moment[0])) + atom_site_moment_crystalaxis_y.append(format_str.format(moment[1])) + atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) + + if write_site_properties: + for key, val in site.properties.items(): + atom_site_properties[key].append(format_str.format(val)) + + count += 1 + else: + # The following just presents a deterministic ordering. 
+ if full_order: + unique_sites = [site for site in struct] + for site in sorted( + unique_sites, + key=lambda t: ( + t.species.average_electroneg, + t.a, + t.b, + t.c, + ), + ): + for sp, occu in site.species.items(): + atom_site_type_symbol.append(str(sp)) + atom_site_symmetry_multiplicity.append("1") + atom_site_fract_x.append(format_str.format(site.a)) + atom_site_fract_y.append(format_str.format(site.b)) + atom_site_fract_z.append(format_str.format(site.c)) + site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" + atom_site_label.append(site_label) + atom_site_occupancy.append(str(occu)) + count += 1 + else: + unique_sites = [ + ( + sorted(sites, key=lambda s: tuple(round(x % 1., 7) % 1. for x in s.frac_coords))[0], + len(sites), + ) + for sites in spg_analyzer.get_symmetrized_structure().equivalent_sites + ] + if not pos_order: + for site, mult in sorted( + unique_sites, + key=lambda t: ( + t[0].species.average_electroneg, + -t[1], + t[0].a, + t[0].b, + t[0].c, + ), + ): + for sp, occu in site.species.items(): + atom_site_type_symbol.append(str(sp)) + atom_site_symmetry_multiplicity.append(f"{mult}") + atom_site_fract_x.append(format_str.format(site.a)) + atom_site_fract_y.append(format_str.format(site.b)) + atom_site_fract_z.append(format_str.format(site.c)) + site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" + atom_site_label.append(site_label) + atom_site_occupancy.append(str(occu)) + count += 1 + else: + for site, mult in sorted( + unique_sites, + key=lambda t: ( + t[0].a, + t[0].b, + t[0].c, + ), + ): + for sp, occu in site.species.items(): + atom_site_type_symbol.append(str(sp)) + atom_site_symmetry_multiplicity.append(f"{mult}") + atom_site_fract_x.append(format_str.format(site.a)) + atom_site_fract_y.append(format_str.format(site.b)) + atom_site_fract_z.append(format_str.format(site.c)) + site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" + 
atom_site_label.append(site_label) + atom_site_occupancy.append(str(occu)) + count += 1 + + block["_atom_site_type_symbol"] = atom_site_type_symbol + block["_atom_site_label"] = atom_site_label + block["_atom_site_symmetry_multiplicity"] = atom_site_symmetry_multiplicity + block["_atom_site_fract_x"] = atom_site_fract_x + block["_atom_site_fract_y"] = atom_site_fract_y + block["_atom_site_fract_z"] = atom_site_fract_z + block["_atom_site_occupancy"] = atom_site_occupancy + loop_labels = [ + "_atom_site_type_symbol", + "_atom_site_label", + "_atom_site_symmetry_multiplicity", + "_atom_site_fract_x", + "_atom_site_fract_y", + "_atom_site_fract_z", + "_atom_site_occupancy", + ] + if write_site_properties: + for key, vals in atom_site_properties.items(): + block[f"_atom_site_{key}"] = vals + loop_labels += [f"_atom_site_{key}"] + loops.append(loop_labels) + + if write_magmoms: + block["_atom_site_moment_label"] = atom_site_moment_label + block["_atom_site_moment_crystalaxis_x"] = atom_site_moment_crystalaxis_x + block["_atom_site_moment_crystalaxis_y"] = atom_site_moment_crystalaxis_y + block["_atom_site_moment_crystalaxis_z"] = atom_site_moment_crystalaxis_z + loops.append( + [ + "_atom_site_moment_label", + "_atom_site_moment_crystalaxis_x", + "_atom_site_moment_crystalaxis_y", + "_atom_site_moment_crystalaxis_z", + ] + ) + dct = {} + dct[comp.reduced_formula] = CifBlock(block, loops, comp.reduced_formula) + self._cf = CifFile(dct) + + @property + def cif_file(self): + """Returns: CifFile associated with the CifWriter.""" + return self._cf + + def __str__(self): + """Returns the CIF as a string.""" + return str(self._cf) + + def write_file(self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w") -> None: + """Write the CIF file.""" + with zopen(filename, mode=mode) as file: + file.write(str(self)) + + +def str2float(text): + """Remove uncertainty brackets from strings and return the float.""" + try: + # Note that the ending ) is sometimes missing. 
That is why the code has + # been modified to treat it as optional. Same logic applies to lists. + return float(re.sub(r"\(.+\)*", "", text)) + except TypeError: + if isinstance(text, list) and len(text) == 1: + return float(re.sub(r"\(.+\)*", "", text[0])) + except ValueError as exc: + if text.strip() == ".": + return 0 + raise exc + raise ValueError(f"{text} cannot be converted to float") diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preporcess/spacegroups.txt b/src/open_r1/tasks/crystal_structure/AIRS_preporcess/spacegroups.txt new file mode 100644 index 00000000..767fcd16 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/AIRS_preporcess/spacegroups.txt @@ -0,0 +1,227 @@ +P6/mmm +Imma +P4_32_12 +P4_2/mnm +Fd-3m +P3m1 +P-3 +P4mm +P4_332 +P4/nnc +P2_12_12 +Pnn2 +Pbcn +P4_2/n +Cm +R3m +Cmce +Aea2 +P-42_1m +P-42m +P2_13 +R-3 +Fm-3 +Cmm2 +Pn-3n +P6/mcc +P-6m2 +P3_2 +P-3m1 +P3_212 +I23 +P-62m +P4_2nm +Pma2 +Pmma +I-42m +P-31c +Pa-3 +Pmmn +Pmmm +P4_2/ncm +I4/mcm +I-4m2 +P3_1 +Pcc2 +Cmcm +I222 +Fddd +P312 +Cccm +P6_1 +F-43c +P6_322 +Pm-3 +P3_121 +P6_4 +Ia-3d +Pm-3m +P2_1/c +C222_1 +Pc +P4/n +Pba2 +Ama2 +Pbcm +P31m +Pcca +P222 +P-43n +Pccm +P6_422 +F23 +P42_12 +C222 +Pnnn +P6_3cm +P4_12_12 +P6/m +Fmm2 +I4_1/a +P4/mbm +Pmn2_1 +P4_2bc +P4_22_12 +I-43d +I4/m +P4bm +Fdd2 +P3 +P6_122 +Pnc2 +P4_2/mcm +P4_122 +Cmc2_1 +P-6c2 +R32 +P4_1 +P4_232 +Pnna +P422 +Pban +Cc +I4_122 +P6_3/m +P6_3mc +I4_1/amd +P4_2 +P4/nmm +Pmna +P4/m +Fm-3m +P4/mmm +Imm2 +P4/ncc +P-62c +Ima2 +P6_5 +P2/c +P4/nbm +Ibam +P6_522 +P6_3/mmc +I4/mmm +Fmmm +P2/m +P-4b2 +I-4 +C2/m +P4_2/mmc +P4 +Fd-3c +P4_3 +P2_1/m +I-43m +P-42c +F4_132 +Pm +Pccn +P-4n2 +P4_132 +P23 +I4cm +R3c +Amm2 +Immm +Iba2 +I4 +Fd-3 +P1 +Pbam +P4_2/nbc +Im-3 +P4_2/nnm +Pmc2_1 +P-31m +R-3m +Ia-3 +P622 +F222 +P2 +P-1 +Pmm2 +P-4 +Aem2 +P6_222 +P-3c1 +P4_322 +I422 +Pnma +P6_3 +P3c1 +Pn-3 +P4nc +P-6 +P4/mcc +I2_12_12_1 +P4_2/mbc +P31c +Ccc2 +P4_2/nmc +P6_3/mcm +C2 +Pbca +P-4c2 +I4_1cd +P2_1 +P3_112 +P4_2mc +Pn-3m +C2/c +R3 +P-43m 
+I432 +P222_1 +I-42d +I-4c2 +P6cc +P6_2 +P3_221 +P321 +Pca2_1 +I4_1/acd +I4_132 +F432 +Pna2_1 +Ccce +Ibca +P4/mnc +I4_1md +P2_12_12_1 +R-3c +I2_13 +P-4m2 +Pm-3n +I4mm +F-43m +Pnnm +P-42_1c +Cmmm +P6mm +P4_2cm +P4_2/m +Im-3m +Fm-3c +I4_1 +P4cc +Cmme diff --git a/src/open_r1/tasks/crystal_structure/relaxing.py b/src/open_r1/tasks/crystal_structure/relaxing.py new file mode 100644 index 00000000..19db9ff7 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/relaxing.py @@ -0,0 +1,134 @@ +import os +import re +from random import random +from typing import Dict, Optional +from open_r1.download_data import download_data +import pandas as pd +from datasets import Dataset, DatasetDict +from rdkit import Chem +from ..base import RLTask +import requests + + +class BinaryCompoundRelaxing(RLTask): + question_template: str = "" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if not os.path.exists(self.dataset_id_or_path): + os.makedirs(self.dataset_id_or_path) + download_data(self.dataset_id_or_path) + + self.src_train_file = os.path.join( + self.dataset_id_or_path, "src-train.txt" + ) + self.tgt_train_file = os.path.join( + self.dataset_id_or_path, "tgt-train.txt" + ) + self.src_test_file = ( + os.path.join(self.dataset_id_or_path, "src-test.txt") + if "src-test.txt" + else None + ) + self.tgt_test_file = ( + os.path.join(self.dataset_id_or_path, "tgt-test.txt") + if "tgt-test.txt" + else None + ) + self.question_template = ( + "<|im_start|>system You are a seasoned crystallographic structure analysis expert. " + "Your task is to relax a binary compound to a stable state. <|im_end|>\n" + "<|im_start|>user Given a perturbed binary compound:\n" + "{}\n, perform multiple steps of Structural Relaxation on the given perturbed binary compound " + "and reduce the internal energy. 
Please document your thought process within tags, and provide " + "the final corrected structure in tags using the proper format as given in the example:\n" + "serialized_cif formula Cd 1_int As 2_int \n" + "space_group_symbol I4_122_sg\n" + "lattice_parameters a 8.03811770 b 8.03811770 c 4.72563470 alpha 90.00000000 beta 90.00000000 gamma 90.00000000 \n" + "Cd 4_int 0.00000000 0.00000000 0.00000000\n" + "As 8_int 0.06170692 0.25000000 0.62500000\n" + "<|im_end|>\n" + ) + + # Dataset here: /iopsstor/store/cscs/swissai/a05/chem/CRLLM-PubChem-compounds1M.csv + + def read_files(self, src_file: str, tgt_file: str) -> Dict: + """Read source and target files and create dataset dictionary.""" + with open(src_file, "r", encoding="utf-8") as f: + problems = [ + self.question_template.format(self.process_line(line)) + for line in f.readlines() + ] + + with open(tgt_file, "r", encoding="utf-8") as f: + solutions = [self.process_line(line) for line in f.readlines()] + + return { + "problem": problems, + "solution": solutions, + } + + def load(self) -> DatasetDict: + """Load and return the complete dataset.""" + # Load training data + train_dict = self.read_files(self.src_train_file, self.tgt_train_file) + train_dataset = Dataset.from_dict(train_dict) + + # Load or create test data + if self.src_test_file and self.tgt_test_file: + test_dict = self.read_files(self.src_test_file, self.tgt_test_file) + test_dataset = Dataset.from_dict(test_dict) + else: + # Create test split from training data + train_test_split = train_dataset.train_test_split(test_size=0.1) + train_dataset = train_test_split["train"].unique(column="solution") + test_dataset = train_test_split["test"] + + # Combine into DatasetDict + self.dataset = DatasetDict( + {"train": train_dataset, "test": test_dataset} + ) + + return self.dataset + + def accuracy_reward(self, completions, solution, **kwargs): + """Reward function - check that completion is same as ground truth.""" + + answers = [self.preprocess_response(c) 
for c in completions] + + rewards = [] + + # Here task is simple: check that the smiles is the same as the target smiles + for content, sol in zip(answers, solution): + if content == "NONE": + rewards.append(-10) + continue + + server_url = os.environ.get("SERVER_URL", "http://10.197.48.175:9001/compute_score") + if content == sol: + rewards.append(-10) + continue + + payload = { + "answer_text": content, + "ground_truth": sol + } + + try: + response = requests.post(server_url, json=payload, timeout=20) + response.raise_for_status() + data = response.json() + reward = data.get("reward", -10) + rewards.append(reward) + except Exception as e: + rewards.append(-10) + return rewards + + def preprocess_response(self, response): + """Preprocess the response before checking for accuracy.""" + pattern = r"(.*)<\/answer>" + m = re.findall(pattern, response, re.DOTALL) + if m: + return m[-1].strip() + else: + return "NONE" diff --git a/src/open_r1/tasks/crystal_structure/reward_server.py b/src/open_r1/tasks/crystal_structure/reward_server.py new file mode 100644 index 00000000..2ac1e969 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/reward_server.py @@ -0,0 +1,139 @@ +from flask import Flask, request, jsonify +import gc +import random +from io import StringIO +from pymatgen.core import Structure +from pymatgen.analysis.structure_matcher import StructureMatcher +import gemmi +from pymatgen.io.cif import CifWriter +from AIRS_preporcess._tokenizer import CIFTokenizer +from mace.calculators import mace_mp +from ase.io import read + +def compare_internal_energy(cif1, cif2): + atoms1 = read(StringIO(cif1),format='cif') + atoms2 = read(StringIO(cif2),format='cif') + calc = mace_mp(model="large", device='cuda') + atoms1.calc = calc + atoms2.calc = calc + energy1_total = atoms1.get_potential_energy() + energy2_total = atoms2.get_potential_energy() + + energy1_per_atom = energy1_total / len(atoms1) + energy2_per_atom = energy2_total / len(atoms2) + 
print(";;;;;;;;;;;;;;;;;;;;;") + print("Orginal Internal Energy:", energy1_per_atom) + print("LLM Energy:", energy2_per_atom) + print(";;;;;;;;;;;;;;;;;;;;;") + if energy1_per_atom < energy2_per_atom: + return -4 + elif energy1_per_atom > energy2_per_atom: + return 1 + else: + return -10 + +app = Flask(__name__) +cif_tokenizer = CIFTokenizer() + +def parse_llm_structure(answer_text): + """ + """ + try: + return Structure.from_str(answer_text, fmt="cif") + except Exception as e: + print("Error in parse_llm_structure:", e) + return None + +def compute_score(answer_text, ground_truth, alpha=5.0): + """ + Calculate the score based on the structure match: + Logic description: + + """ + try: + answer_text = cif_tokenizer.deserialize(answer_text, ground_truth.get("ground_truth", "")) + except Exception as e: + print("format error 1", e) + return -10 + do_print = random.randint(1, 1) == 1 + if do_print: + print("-------------- START ------------------") + print("answer_text:", answer_text) + print("ground_cif:", ground_truth.get("ground_truth", "")) + + dft_cif = ground_truth.get("ground_truth", "") + if not dft_cif: + print("No ground truth CIF content provided.") + return -10 + + try: + doc = gemmi.cif.read_string(dft_cif) + doc.check_for_missing_values() + doc.check_for_duplicates() + + doc = gemmi.cif.read_string(answer_text) + doc.check_for_missing_values() + doc.check_for_duplicates() + except Exception as e: + print("CIF error:", e) + return -10 + + try: + dft_structure = Structure.from_str(dft_cif, fmt="cif") + if do_print: + print("dft_structure OK") + except Exception as e: + if do_print: + print("Error parsing DFT structure:", e) + return -10 + + try: + llm_structure = parse_llm_structure(answer_text) + if llm_structure is None: + return -10 + if do_print: + print("llm_structure OK") + except Exception as e: + print("Error parsing LLM-generated structure:", e) + return -10 + reward = -5 + try: + reward = compare_internal_energy(dft_cif, answer_text) + except 
Exception as e: + print("**************************") + print('CALC ERROR:', e) + print("**************************") + print("-------------- END ------------------") + return reward + + if do_print: + print(f"Reward: {reward}") + print("-------------- END ------------------") + return reward + +@app.route('/compute_score', methods=['POST']) +def compute_score_endpoint(): + """ + The interface /compute_score receives POST requests, and the JSON format content needs to include: + - answer_text: CIF content string generated by LLM + - ground_truth: dictionary containing the key "ground_truth", and the value is the CIF content after DFT optimization + + Return the calculation result in JSON format, for example: + { "reward": -0.123 } + """ + data = request.get_json() + if not data: + return jsonify({"error": "No JSON data provided"}), 400 + + answer_text = data.get("answer_text", "") + ground_truth = data.get("ground_truth", {}) + + if not answer_text or not ground_truth: + return jsonify({"error": "Missing required fields: answer_text and ground_truth"}), 400 + + reward = compute_score(answer_text, ground_truth) + gc.collect() + return jsonify({"reward": reward}) + +if __name__ == '__main__': + app.run(host='0.0.0.0', port=9001, debug=True) \ No newline at end of file From 6f2e93eca76cb757fd9f2e885fc450f6415a9c0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E7=9D=BF=E6=99=BA?= Date: Sat, 12 Apr 2025 22:03:27 +1000 Subject: [PATCH 02/37] ADD MIT LICENSE --- .../dataset/crystal_structure_relaxing.py | 245 ++- .../crystal_structure/AIRS_preporcess/LICENSE | 21 + .../AIRS_preporcess/mycif.py | 1668 ----------------- 3 files changed, 265 insertions(+), 1669 deletions(-) create mode 100644 src/open_r1/tasks/crystal_structure/AIRS_preporcess/LICENSE delete mode 100644 src/open_r1/tasks/crystal_structure/AIRS_preporcess/mycif.py diff --git a/src/open_r1/dataset/crystal_structure_relaxing.py b/src/open_r1/dataset/crystal_structure_relaxing.py index 5b775511..0c351349 
100644 --- a/src/open_r1/dataset/crystal_structure_relaxing.py +++ b/src/open_r1/dataset/crystal_structure_relaxing.py @@ -4,7 +4,250 @@ import pandas as pd import random from verl.utils.hdfs_io import copy, makedirs -from AIRS_preporcess._tokenizer import CIFTokenizer + +# This file contains code adapted from the AIRS project: +# https://github.com/divelab/AIRS/blob/main/OpenMat/Mat2Seq/mat2seq/_tokenizer.py +# +# Copyright (c) 2023 Luis M. Antunes +# Licensed under the MIT License. +# +# The CIFTokenizer class and related utilities are reused and modified here +# for downstream crystallographic tasks. +# +# Modifications by Ruizhi Xu, 2025 + +import re +from torch.utils.data import Dataset + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +with open(os.path.join(THIS_DIR, "spacegroups.txt"), "rt") as f: + SPACE_GROUPS = [sg.strip() for sg in f.readlines()] + + +ATOMS = ["Si", "C", "Pb", "I", "Br", "Cl", "Eu", "O", "Fe", "Sb", "In", "S", "N", "U", "Mn", "Lu", "Se", "Tl", "Hf", + "Ir", "Ca", "Ta", "Cr", "K", "Pm", "Mg", "Zn", "Cu", "Sn", "Ti", "B", "W", "P", "H", "Pd", "As", "Co", "Np", + "Tc", "Hg", "Pu", "Al", "Tm", "Tb", "Ho", "Nb", "Ge", "Zr", "Cd", "V", "Sr", "Ni", "Rh", "Th", "Na", "Ru", + "La", "Re", "Y", "Er", "Ce", "Pt", "Ga", "Li", "Cs", "F", "Ba", "Te", "Mo", "Gd", "Pr", "Bi", "Sc", "Ag", "Rb", + "Dy", "Yb", "Nd", "Au", "Os", "Pa", "Sm", "Be", "Ac", "Xe", "Kr", "He", "Ne", "Ar"] + +DIGITS = [str(d) for d in list(range(10))] + +INTS = [str(d) for d in list(range(300))] + +KEYWORDS = [ + "space_group_symbol", + "formula", + "atoms", + "lattice_parameters", + "a", + "b", + "c", + "alpha", + "beta", + "gamma" +] + +UNK_TOKEN = "" + +def get_spacegroup_number(sg_symbol): + try: + from pymatgen.symmetry.groups import SpaceGroup + sg = SpaceGroup(sg_symbol) + return sg + except Exception as e: + print("Err:", e) + return None + +def parse_formula(formula): + formula = formula.replace("'", "").replace('"', '').strip() + pattern = r"([A-Z][a-z]*)(\d*)" 
+ counts = {} + for element, count in re.findall(pattern, formula): + counts[element] = counts.get(element, 0) + (int(count) if count else 1) + return counts + +def compute_cell_formula_units_Z(formula_sum, formula_structural): + counts_sum = parse_formula(formula_sum) + counts_struct = parse_formula(formula_structural) + + ratios = [] + for element, count_struct in counts_struct.items(): + if element not in counts_sum: + raise ValueError(f"{element}") + ratio = counts_sum[element] / count_struct + if ratio != int(ratio): + raise ValueError(f"{element}, {ratio} not int") + ratios.append(int(ratio)) + + if len(set(ratios)) != 1: + raise ValueError(f"{ratios} != 1") + return ratios[0] + +class CIFTokenizer: + def __init__(self): + self._tokens = [""] + self._tokens.extend(self.atoms()) + self._tokens.extend(self.digits()) + self._tokens.extend(self.keywords()) + self._tokens.extend(self.symbols()) + + space_groups = list(self.space_groups()) + # Replace 'Pm' space group with 'Pm_sg' to disambiguate from atom 'Pm', + # or 'P1' with 'P1_sg' to disambiguate from atom 'P' and number '1' + space_groups_sg = [sg+"_sg" for sg in space_groups] + self._tokens.extend(space_groups_sg) + + digits_int = [v+"_int" for v in INTS] + self._tokens.extend(digits_int) + + self._escaped_tokens = [re.escape(token) for token in self._tokens] + self._escaped_tokens.sort(key=len, reverse=True) + + # a mapping from characters to integers + self._token_to_id = {ch: i for i, ch in enumerate(self._tokens)} + self._id_to_token = {i: ch for i, ch in enumerate(self._tokens)} + # map the id of 'Pm_sg' back to 'Pm', or 'P1_sg' to 'P1', + # for decoding convenience + for sg in space_groups_sg: + self._id_to_token[self.token_to_id[sg]] = sg.replace("_sg", "") + + for v_int in digits_int: + self._id_to_token[self.token_to_id[v_int]] = v_int.replace("_int", "") + + @staticmethod + def atoms(): + return ATOMS + + @staticmethod + def digits(): + return DIGITS + + @staticmethod + def keywords(): + kws = 
list(KEYWORDS) + return kws + + @staticmethod + def symbols(): + # return ["x", "y", "z", ".", "(", ")", "+", "-", "/", "'", ",", " ", "\n"] + return [",", " ", ":", ".", "\n"] + + @staticmethod + def space_groups(): + return SPACE_GROUPS + + @property + def token_to_id(self): + return dict(self._token_to_id) + + @property + def id_to_token(self): + return dict(self._id_to_token) + + def encode(self, tokens): + # encoder: take a list of tokens, output a list of integers + return [self._token_to_id[t] for t in tokens] + + def decode(self, ids): + # decoder: take a list of integers (i.e. encoded tokens), output a string + return ''.join([self._id_to_token[i] for i in ids]) + + def serialize(self, cif_string): + spacegroups = "|".join(SPACE_GROUPS) + cif_string = re.sub(fr'(_symmetry_space_group_name_H-M *\b({spacegroups}))\n', r'\1_sg\n', cif_string) + extracted_data = self.tokenize_cif_preprocess(cif_string) + + seq_res = '' + # formula + seq_res += "formula " + formula = extracted_data["formula"] + elements_counts = re.findall(r'([A-Z][a-z]*)(\d*)', formula) + for element, count in elements_counts: + if not element: break + if not count: count ="1" + seq_res += element + " " + count + "_int " + seq_res += "\n" + # space group name + seq_res += "space_group_symbol " + extracted_data["space_group_symbol"] + "\n" + # lattice + seq_res += "lattice_parameters " + "a " + extracted_data["lattice_parameters"]["a"] + " " + seq_res += "b " + extracted_data["lattice_parameters"]["b"] + " " + seq_res += "c " + extracted_data["lattice_parameters"]["c"] + " " + seq_res += "alpha " + extracted_data["lattice_parameters"]["alpha"] + " " + seq_res += "beta " + extracted_data["lattice_parameters"]["beta"] + " " + seq_res += "gamma " + extracted_data["lattice_parameters"]["gamma"] + " " + seq_res += "\n" + # atoms + for idx in range(len(extracted_data["atoms"])): + tmp = extracted_data["atoms"][idx] + seq_res += tmp["type"] + " " + tmp["num"] + "_int " + tmp["coordinates"][0] + " " + 
tmp["coordinates"][1] + " " + tmp["coordinates"][2] + "\n" + seq_res += "\n" + # Create a regex pattern by joining the escaped tokens with '|' + token_pattern = '|'.join(self._escaped_tokens) + # Add a regex pattern to match any sequence of characters separated by whitespace or punctuation + full_pattern = f'({token_pattern}|\\w+|[\\.,;!?])' + # Tokenize the input string using the regex pattern + seq_res = re.sub(r'[ \t]+', ' ', seq_res) + return seq_res + + def tokenize_cif_preprocess(self, cif_string): + # Re-initialize the dictionary to hold the extracted data + extracted_data = { + "space_group_symbol": "", + "formula": "", + "atoms": [], + "lattice_parameters": {} + } + + # Split the text into lines for processing + lines = cif_string.split('\n') + + # Iterate through each line to extract the required information + atom_line_idx = -1 + for line_idx in range(len(lines)): + line = lines[line_idx] + # Extract space group symbol + if "_symmetry_space_group_name_H-M" in line: + spacegroup_match = re.search(r'_symmetry_space_group_name_H-M\s+([^\n]+)', line) + spacegroup = spacegroup_match.group(1).strip() + extracted_data["space_group_symbol"] = spacegroup + # Extract formula + elif line.startswith("data_"): + extracted_data["formula"] = line.split("_")[1] + # Extract lattice parameters + elif line.startswith("_cell_length_a"): + extracted_data["lattice_parameters"]["a"] = line.split()[-1] + elif line.startswith("_cell_length_b"): + extracted_data["lattice_parameters"]["b"] = line.split()[-1] + elif line.startswith("_cell_length_c"): + extracted_data["lattice_parameters"]["c"] = line.split()[-1] + elif line.startswith("_cell_angle_alpha"): + extracted_data["lattice_parameters"]["alpha"] = line.split()[-1] + elif line.startswith("_cell_angle_beta"): + extracted_data["lattice_parameters"]["beta"] = line.split()[-1] + elif line.startswith("_cell_angle_gamma"): + extracted_data["lattice_parameters"]["gamma"] = line.split()[-1] + elif "_atom_site_occupancy" in line: + 
atom_line_idx = line_idx + 1 + break + + for line_idx in range(atom_line_idx, len(lines)): + line = lines[line_idx] + if len(line) < 2: + continue + atom_info = line.split() + atom_type = atom_info[0] + num_atoms = atom_info[2] + x, y, z = atom_info[3], atom_info[4], atom_info[5] + extracted_data["atoms"].append({ + "type": atom_type, + "num": num_atoms, + "coordinates": (x, y, z) + }) + + return extracted_data # Initialize the tokenizer cif_tokenizer = CIFTokenizer() diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preporcess/LICENSE b/src/open_r1/tasks/crystal_structure/AIRS_preporcess/LICENSE new file mode 100644 index 00000000..b552f185 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/AIRS_preporcess/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Luis M. Antunes + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preporcess/mycif.py b/src/open_r1/tasks/crystal_structure/AIRS_preporcess/mycif.py deleted file mode 100644 index 2f84a691..00000000 --- a/src/open_r1/tasks/crystal_structure/AIRS_preporcess/mycif.py +++ /dev/null @@ -1,1668 +0,0 @@ -"""Wrapper classes for Cif input and output from Structures.""" - -from __future__ import annotations - -import math -import os -import re -import textwrap -import warnings -from collections import defaultdict, deque -from datetime import datetime -from functools import partial -from inspect import getfullargspec as getargspec -from io import StringIO -from itertools import groupby -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal - -import numpy as np -from monty.io import zopen -from monty.serialization import loadfn - -from pymatgen.core import Composition, DummySpecies, Element, Lattice, PeriodicSite, Species, Structure, get_el_sp -from pymatgen.core.operations import MagSymmOp, SymmOp -from pymatgen.electronic_structure.core import Magmom -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer, SpacegroupOperations -from pymatgen.symmetry.groups import SYMM_DATA, SpaceGroup -from pymatgen.symmetry.maggroups import MagneticSpaceGroup -from pymatgen.symmetry.structure import SymmetrizedStructure -from pymatgen.util.coord import find_in_coord_list_pbc, in_coord_list_pbc - -if TYPE_CHECKING: - from pymatgen.core.trajectory import Vector3D - -__author__ = "Shyue Ping Ong, Will Richards, Matthew Horton" - -sub_spgrp = partial(re.sub, r"[\s_]", "") - -space_groups = {sub_spgrp(key): key for key in SYMM_DATA["space_group_encoding"]} # type: ignore - -space_groups.update({sub_spgrp(key): key for key in SYMM_DATA["space_group_encoding"]}) # type: ignore - - -class CifBlock: - """ - Object for storing cif data. All data is stored in a single dictionary. 
- Data inside loops are stored in lists in the data dictionary, and - information on which keys are grouped together are stored in the loops - attribute. - """ - - max_len = 70 # not quite 80 so we can deal with semicolons and things - - def __init__(self, data, loops, header): - """ - Args: - data: dict of data to go into the cif. Values should be convertible to string, - or lists of these if the key is in a loop - loops: list of lists of keys, grouped by which loop they should appear in - header: name of the block (appears after the data_ on the first line). - """ - self.loops = loops - self.data = data - # AJ (@computron) says: CIF Block names cannot be more than 75 characters or you get an Exception - self.header = header[:74] - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CifBlock): - return NotImplemented - return self.loops == other.loops and self.data == other.data and self.header == other.header - - def __getitem__(self, key): - return self.data[key] - - def __str__(self) -> str: - """Returns the cif string for the data block.""" - out = [f"data_{self.header}"] - keys = list(self.data) - written = [] - for key in keys: - if key in written: - continue - for loop in self.loops: - # search for a corresponding loop - if key in loop: - out.append(self._loop_to_str(loop)) - written.extend(loop) - break - if key not in written: - # k didn't belong to a loop - v = self._format_field(self.data[key]) - if len(key) + len(v) + 3 < self.max_len: - out.append(f"{key} {v}") - else: - out.extend([key, v]) - return "\n".join(out) - - def _loop_to_str(self, loop): - out = "loop_" - for line in loop: - out += "\n " + line - for fields in zip(*(self.data[k] for k in loop)): - line = "\n" - for val in map(self._format_field, fields): - if val[0] == ";": - out += line + "\n" + val - line = "\n" - elif len(line) + len(val) + 2 < self.max_len: - line += " " + val - else: - out += line - line = "\n " + val - out += line - return out - - def 
_format_field(self, val) -> str: - val = str(val).strip() - if len(val) > self.max_len: - return f";\n{textwrap.fill(val, self.max_len)}\n;" - # add quotes if necessary - if val == "": - return '""' - if ( - (" " in val or val[0] == "_") - and not (val[0] == "'" and val[-1] == "'") - and not (val[0] == '"' and val[-1] == '"') - ): - quote = '"' if "'" in val else "'" - val = quote + val + quote - return val - - @classmethod - def _process_string(cls, string): - # remove comments - string = re.sub(r"(\s|^)#.*$", "", string, flags=re.MULTILINE) - # remove empty lines - string = re.sub(r"^\s*\n", "", string, flags=re.MULTILINE) - # remove non_ascii - string = string.encode("ascii", "ignore").decode("ascii") - # since line breaks in .cif files are mostly meaningless, - # break up into a stream of tokens to parse, rejoining multiline - # strings (between semicolons) - deq = deque() - multiline = False - ml = [] - # this regex splits on spaces, except when in quotes. starting quotes must not be - # preceded by non-whitespace (these get eaten by the first expression). ending - # quotes must not be followed by non-whitespace - pattern = re.compile(r"""([^'"\s][\S]*)|'(.*?)'(?!\S)|"(.*?)"(?!\S)""") - for line in string.splitlines(): - if multiline: - if line.startswith(";"): - multiline = False - deq.append(("", "", "", " ".join(ml))) - ml = [] - line = line[1:].strip() - else: - ml.append(line) - continue - if line.startswith(";"): - multiline = True - ml.append(line[1:].strip()) - else: - for string in pattern.findall(line): - # location of the data in string depends on whether it was quoted in the input - deq.append(tuple(string)) - return deq - - @classmethod - def from_str(cls, string): - """ - Reads CifBlock from string. - - :param string: String representation. 
- - Returns: - CifBlock - """ - q = cls._process_string(string) - header = q.popleft()[0][5:] - data = {} - loops = [] - while q: - s = q.popleft() - # cif keys aren't in quotes, so show up in s[0] - if s[0] == "_eof": - break - if s[0].startswith("_"): - try: - data[s[0]] = "".join(q.popleft()) - except IndexError: - data[s[0]] = "" - elif s[0].startswith("loop_"): - columns = [] - items = [] - while q: - s = q[0] - if s[0].startswith("loop_") or not s[0].startswith("_"): - break - columns.append("".join(q.popleft())) - data[columns[-1]] = [] - while q: - s = q[0] - if s[0].startswith(("loop_", "_")): - break - items.append("".join(q.popleft())) - n = len(items) // len(columns) - assert len(items) % n == 0 - loops.append(columns) - for k, v in zip(columns * n, items): - data[k].append(v.strip()) - elif issue := "".join(s).strip(): - warnings.warn(f"Possible issue in CIF file at line: {issue}") - return cls(data, loops, header) - - -class CifFile: - """Reads and parses CifBlocks from a .cif file or string.""" - - def __init__(self, data: dict, orig_string: str | None = None, comment: str | None = None) -> None: - """ - Args: - data (dict): Of CifBlock objects. - orig_string (str): The original cif string. - comment (str): Comment string. - """ - self.data = data - self.orig_string = orig_string - self.comment = comment or "# generated using pymatgen" - - def __str__(self): - out = "\n".join(map(str, self.data.values())) - return f"{self.comment}\n{out}\n" - - @classmethod - def from_str(cls, string) -> CifFile: - """Reads CifFile from a string. - - :param string: String representation. - - Returns: - CifFile - """ - dct = {} - - for block_str in re.split(r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL)[1:]: - # Skip over Cif block that contains powder diffraction data. - # Some elements in this block were missing from CIF files in - # Springer materials/Pauling file DBs. 
- # This block does not contain any structure information anyway, and - # CifParser was also not parsing it. - if "powder_pattern" in re.split(r"\n", block_str, maxsplit=1)[0]: - continue - block = CifBlock.from_str(f"data_{block_str}") - # TODO (@janosh, 2023-10-11) multiple CIF blocks with equal header will overwrite each other, - # latest taking precedence. maybe something to fix and test e.g. in test_cif_writer_write_file - dct[block.header] = block - - return cls(dct, string) - - @classmethod - def from_file(cls, filename: str | Path) -> CifFile: - """ - Reads CifFile from a filename. - - :param filename: Filename - - Returns: - CifFile - """ - with zopen(str(filename), mode="rt", errors="replace") as file: - return cls.from_str(file.read()) - - -class CifParser: - """ - Parses a CIF file. Attempts to fix CIFs that are out-of-spec, but will issue warnings - if corrections applied. These are also stored in the CifParser's errors attribute. - """ - - def __init__( - self, - filename: str | StringIO, - occupancy_tolerance: float = 1.0, - site_tolerance: float = 1e-4, - frac_tolerance: float = 1e-4, - check_cif: bool = True, - comp_tol: float = 0.01, - ) -> None: - """ - Args: - filename (str): CIF filename, gzipped or bzipped CIF files are fine too. - occupancy_tolerance (float): If total occupancy of a site is between 1 and occupancy_tolerance, the - occupancies will be scaled down to 1. - site_tolerance (float): This tolerance is used to determine if two sites are sitting in the same position, - in which case they will be combined to a single disordered site. Defaults to 1e-4. - frac_tolerance (float): This tolerance is used to determine is a coordinate should be rounded to an ideal - value. E.g., 0.6667 is rounded to 2/3. This is desired if symmetry operations are going to be applied. - However, for very large CIF files, this may need to be set to 0. 
- check_cif (bool): Whether to check that stoichiometry reported in CIF matches - that of resulting Structure, and whether elements are missing. Defaults to True. - comp_tol (float): Tolerance for how closely stoichiometries of CIF file and pymatgen should match. - Defaults to 0.01. Context: Experimental CIF files often don't report hydrogens positions due to being - hard-to-locate with X-rays. pymatgen warns if the stoichiometry of the CIF file and the Structure - don't match to within comp_tol. - """ - self._occupancy_tolerance = occupancy_tolerance - self._site_tolerance = site_tolerance - self._frac_tolerance = frac_tolerance - if isinstance(filename, (str, Path)): - self._cif = CifFile.from_file(filename) - else: - self._cif = CifFile.from_str(filename.read()) - - # options related to checking CIFs for missing elements - # or incorrect stoichiometries - self.check_cif = check_cif - self.comp_tol = comp_tol - - # store if CIF contains features from non-core CIF dictionaries - # e.g. magCIF - self.feature_flags = {} - self.warnings: list[str] = [] - - def is_magcif() -> bool: - """Checks to see if file appears to be a magCIF file (heuristic).""" - # Doesn't seem to be a canonical way to test if file is magCIF or - # not, so instead check for magnetic symmetry datanames - prefixes = ["_space_group_magn", "_atom_site_moment", "_space_group_symop_magn"] - for d in self._cif.data.values(): - for k in d.data: - for prefix in prefixes: - if prefix in k: - return True - return False - - self.feature_flags["magcif"] = is_magcif() - - def is_magcif_incommensurate() -> bool: - """ - Checks to see if file contains an incommensurate magnetic - structure (heuristic). 
- """ - # Doesn't seem to be a canonical way to test if magCIF file - # describes incommensurate structure or not, so instead check - # for common datanames - if not self.feature_flags["magcif"]: - return False - prefixes = ["_cell_modulation_dimension", "_cell_wave_vector"] - for d in self._cif.data.values(): - for k in d.data: - for prefix in prefixes: - if prefix in k: - return True - return False - - self.feature_flags["magcif_incommensurate"] = is_magcif_incommensurate() - - for key in self._cif.data: - # pass individual CifBlocks to _sanitize_data - self._cif.data[key] = self._sanitize_data(self._cif.data[key]) - - @classmethod - def from_str(cls, cif_string: str, **kwargs) -> CifParser: - """ - Creates a CifParser from a string. - - Args: - cif_string (str): String representation of a CIF. - **kwargs: Passthrough of all kwargs supported by CifParser. - - Returns: - CifParser - """ - stream = StringIO(cif_string) - return cls(stream, **kwargs) - - def _sanitize_data(self, data): - """ - Some CIF files do not conform to spec. This function corrects - known issues, particular in regards to Springer materials/ - Pauling files. - - This function is here so that CifParser can assume its - input conforms to spec, simplifying its implementation. - :param data: CifBlock - - Returns: - data CifBlock - """ - """ - This part of the code deals with handling formats of data as found in - CIF files extracted from the Springer Materials/Pauling File - databases, and that are different from standard ICSD formats. - """ - # check for implicit hydrogens, warn if any present - if "_atom_site_attached_hydrogens" in data.data: - attached_hydrogens = [str2float(x) for x in data.data["_atom_site_attached_hydrogens"] if str2float(x) != 0] - if len(attached_hydrogens) > 0: - self.warnings.append( - "Structure has implicit hydrogens defined, parsed structure unlikely to be " - "suitable for use in calculations unless hydrogens added." 
- ) - - # Check to see if "_atom_site_type_symbol" exists, as some test CIFs do - # not contain this key. - if "_atom_site_type_symbol" in data.data: - # Keep a track of which data row needs to be removed. - # Example of a row: Nb,Zr '0.8Nb + 0.2Zr' .2a .m-3m 0 0 0 1 14 - # 'rhombic dodecahedron, Nb14' - # Without this code, the above row in a structure would be parsed - # as an ordered site with only Nb (since - # CifParser would try to parse the first two characters of the - # label "Nb,Zr") and occupancy=1. - # However, this site is meant to be a disordered site with 0.8 of - # Nb and 0.2 of Zr. - idxs_to_remove = [] - - new_atom_site_label = [] - new_atom_site_type_symbol = [] - new_atom_site_occupancy = [] - new_fract_x = [] - new_fract_y = [] - new_fract_z = [] - - for idx, el_row in enumerate(data["_atom_site_label"]): - # CIF files from the Springer Materials/Pauling File have - # switched the label and symbol. Thus, in the - # above shown example row, '0.8Nb + 0.2Zr' is the symbol. - # Below, we split the strings on ' + ' to - # check if the length (or number of elements) in the label and - # symbol are equal. - if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len(el_row.split(" + ")): - # Dictionary to hold extracted elements and occupancies - els_occu = {} - - # parse symbol to get element names and occupancy and store - # in "els_occu" - symbol_str = data["_atom_site_type_symbol"][idx] - symbol_str_lst = symbol_str.split(" + ") - for elocc_idx, sym in enumerate(symbol_str_lst): - # Remove any bracketed items in the string - symbol_str_lst[elocc_idx] = re.sub(r"\([0-9]*\)", "", sym.strip()) - - # Extract element name and its occupancy from the - # string, and store it as a - # key-value pair in "els_occ". 
- els_occu[ - str(re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1]).replace("", "") - ] = float("0" + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[1]) - - x = str2float(data["_atom_site_fract_x"][idx]) - y = str2float(data["_atom_site_fract_y"][idx]) - z = str2float(data["_atom_site_fract_z"][idx]) - - for et, occu in els_occu.items(): - # new atom site labels have 'fix' appended - new_atom_site_label.append(f"{et}_fix{len(new_atom_site_label)}") - new_atom_site_type_symbol.append(et) - new_atom_site_occupancy.append(str(occu)) - new_fract_x.append(str(x)) - new_fract_y.append(str(y)) - new_fract_z.append(str(z)) - - idxs_to_remove.append(idx) - - # Remove the original row by iterating over all keys in the CIF - # data looking for lists, which indicates - # multiple data items, one for each row, and remove items from the - # list that corresponds to the removed row, - # so that it's not processed by the rest of this function (which - # would result in an error). - for original_key in data.data: - if isinstance(data.data[original_key], list): - for idx in sorted(idxs_to_remove, reverse=True): - del data.data[original_key][idx] - - if len(idxs_to_remove) > 0: - self.warnings.append("Pauling file corrections applied.") - - data.data["_atom_site_label"] += new_atom_site_label - data.data["_atom_site_type_symbol"] += new_atom_site_type_symbol - data.data["_atom_site_occupancy"] += new_atom_site_occupancy - data.data["_atom_site_fract_x"] += new_fract_x - data.data["_atom_site_fract_y"] += new_fract_y - data.data["_atom_site_fract_z"] += new_fract_z - # This fixes inconsistencies in naming of several magCIF tags as a result of magCIF - # being in widespread use prior to specification being finalized (on advice of Branton Campbell). - if self.feature_flags["magcif"]: - # CIF-1 style has all underscores, interim standard - # had period before magn instead of before the final - # component (e.g. 
xyz) - # we want to standardize on a specific key, to simplify - # parsing code - correct_keys = [ - "_space_group_symop_magn_operation.xyz", - "_space_group_symop_magn_centering.xyz", - "_space_group_magn.name_BNS", - "_space_group_magn.number_BNS", - "_atom_site_moment_crystalaxis_x", - "_atom_site_moment_crystalaxis_y", - "_atom_site_moment_crystalaxis_z", - "_atom_site_moment_label", - ] - - # cannot mutate dict during enumeration, so store changes we want to make - changes_to_make = {} - - for original_key in data.data: - for correct_key in correct_keys: - # convert to all underscore - trial_key = "_".join(correct_key.split(".")) - test_key = "_".join(original_key.split(".")) - if trial_key == test_key: - changes_to_make[correct_key] = original_key - - # make changes - for correct_key, original_key in changes_to_make.items(): - data.data[correct_key] = data.data[original_key] - - # renamed_keys maps interim_keys to final_keys - renamed_keys = { - "_magnetic_space_group.transform_to_standard_Pp_abc": "_space_group_magn.transform_BNS_Pp_abc" - } - changes_to_make = {} - - for interim_key, final_key in renamed_keys.items(): - if data.data.get(interim_key): - changes_to_make[final_key] = interim_key - - if len(changes_to_make) > 0: - self.warnings.append("Keys changed to match new magCIF specification.") - - for final_key, interim_key in changes_to_make.items(): - data.data[final_key] = data.data[interim_key] - - # check for finite precision frac coordinates (e.g. 0.6667 instead of 0.6666666...7) - # this can sometimes cause serious issues when applying symmetry operations - important_fracs = (1 / 3, 2 / 3) - fracs_to_change = {} - for label in ("_atom_site_fract_x", "_atom_site_fract_y", "_atom_site_fract_z"): - if label in data.data: - for idx, frac in enumerate(data.data[label]): - try: - frac = str2float(frac) - except Exception: - # coordinate might not be defined e.g. '?' 
- continue - for comparison_frac in important_fracs: - if abs(1 - frac / comparison_frac) < self._frac_tolerance: - fracs_to_change[(label, idx)] = str(comparison_frac) - if fracs_to_change: - self.warnings.append( - f"{len(fracs_to_change)} fractional coordinates rounded to ideal values to avoid issues with " - "finite precision." - ) - for (label, idx), val in fracs_to_change.items(): - data.data[label][idx] = val - - return data - - def _unique_coords( - self, - coords: list[Vector3D], - magmoms: list[Magmom] | None = None, - lattice: Lattice | None = None, - labels: dict[Vector3D, str] | None = None, - ): - """ - Generate unique coordinates using coord and symmetry positions - and also their corresponding magnetic moments, if supplied. - """ - coords_out: list[np.ndarray] = [] - labels_out = [] - labels = labels or {} - - if magmoms: - magmoms_out = [] - if len(magmoms) != len(coords): - raise ValueError - for tmp_coord, tmp_magmom in zip(coords, magmoms): - for op in self.symmetry_operations: - coord = op.operate(tmp_coord) - coord = np.array([i - math.floor(i) for i in coord]) - if isinstance(op, MagSymmOp): - # Up to this point, magmoms have been defined relative - # to crystal axis. Now convert to Cartesian and into - # a Magmom object. 
- magmom = Magmom.from_moment_relative_to_crystal_axes( - op.operate_magmom(tmp_magmom), lattice=lattice - ) - else: - magmom = Magmom(tmp_magmom) - if not in_coord_list_pbc(coords_out, coord, atol=self._site_tolerance): - coords_out.append(coord) - magmoms_out.append(magmom) - labels_out.append(labels.get(tmp_coord)) - return coords_out, magmoms_out, labels_out - - for tmp_coord in coords: - for op in self.symmetry_operations: - coord = op.operate(tmp_coord) - coord = np.array([i - math.floor(i) for i in coord]) - if not in_coord_list_pbc(coords_out, coord, atol=self._site_tolerance): - coords_out.append(coord) - labels_out.append(labels.get(tmp_coord)) - - dummy_magmoms = [Magmom(0)] * len(coords_out) - return coords_out, dummy_magmoms, labels_out - - def get_lattice( - self, - data, - length_strings=("a", "b", "c"), - angle_strings=("alpha", "beta", "gamma"), - lattice_type=None, - ): - """ - Generate the lattice from the provided lattice parameters. In - the absence of all six lattice parameters, the crystal system - and necessary parameters are parsed. 
- """ - try: - return self.get_lattice_no_exception( - data=data, angle_strings=angle_strings, lattice_type=lattice_type, length_strings=length_strings - ) - - except KeyError: - # Missing Key search for cell setting - for lattice_label in ["_symmetry_cell_setting", "_space_group_crystal_system"]: - if data.data.get(lattice_label): - lattice_type = data.data.get(lattice_label).lower() - try: - required_args = getargspec(getattr(Lattice, lattice_type)).args - - lengths = (length for length in length_strings if length in required_args) - angles = (a for a in angle_strings if a in required_args) - return self.get_lattice(data, lengths, angles, lattice_type=lattice_type) - except AttributeError as exc: - self.warnings.append(str(exc)) - warnings.warn(exc) - - else: - return None - return None - - @staticmethod - def get_lattice_no_exception( - data, length_strings=("a", "b", "c"), angle_strings=("alpha", "beta", "gamma"), lattice_type=None - ): - """ - Take a dictionary of CIF data and returns a pymatgen Lattice object. - - Args: - data: a dictionary of the CIF file - length_strings: The strings that are used to identify the length parameters in the CIF file. - angle_strings: The strings that are used to identify the angles in the CIF file. - lattice_type: The type of lattice. This is a string, and can be any of the following: - - Returns: - Lattice object - """ - lengths = [str2float(data["_cell_length_" + i]) for i in length_strings] - angles = [str2float(data["_cell_angle_" + i]) for i in angle_strings] - if not lattice_type: - return Lattice.from_parameters(*lengths, *angles) - return getattr(Lattice, lattice_type)(*(lengths + angles)) - - def get_symops(self, data): - """ - In order to generate symmetry equivalent positions, the symmetry - operations are parsed. If the symops are not present, the space - group symbol is parsed, and symops are generated. 
- """ - sym_ops = [] - for symmetry_label in [ - "_symmetry_equiv_pos_as_xyz", - "_symmetry_equiv_pos_as_xyz_", - "_space_group_symop_operation_xyz", - "_space_group_symop_operation_xyz_", - ]: - if data.data.get(symmetry_label): - xyz = data.data.get(symmetry_label) - if isinstance(xyz, str): - msg = "A 1-line symmetry op P1 CIF is detected!" - warnings.warn(msg) - self.warnings.append(msg) - xyz = [xyz] - try: - sym_ops = [SymmOp.from_xyz_str(s) for s in xyz] - break - except ValueError: - continue - if not sym_ops: - # Try to parse symbol - for symmetry_label in [ - "_symmetry_space_group_name_H-M", - "_symmetry_space_group_name_H_M", - "_symmetry_space_group_name_H-M_", - "_symmetry_space_group_name_H_M_", - "_space_group_name_Hall", - "_space_group_name_Hall_", - "_space_group_name_H-M_alt", - "_space_group_name_H-M_alt_", - "_symmetry_space_group_name_hall", - "_symmetry_space_group_name_hall_", - "_symmetry_space_group_name_h-m", - "_symmetry_space_group_name_h-m_", - ]: - sg = data.data.get(symmetry_label) - msg_template = "No _symmetry_equiv_pos_as_xyz type key found. Spacegroup from {} used." 
- - if sg: - sg = sub_spgrp(sg) - try: - spg = space_groups.get(sg) - if spg: - sym_ops = SpaceGroup(spg).symmetry_ops - msg = msg_template.format(symmetry_label) - warnings.warn(msg) - self.warnings.append(msg) - break - except ValueError: - # Ignore any errors - pass - - try: - cod_data = loadfn( - os.path.join(os.path.dirname(os.path.dirname(__file__)), "symmetry", "symm_ops.json") - ) - for d in cod_data: - if sg == re.sub(r"\s+", "", d["hermann_mauguin"]): - xyz = d["symops"] - sym_ops = [SymmOp.from_xyz_str(s) for s in xyz] - msg = msg_template.format(symmetry_label) - warnings.warn(msg) - self.warnings.append(msg) - break - except Exception: - continue - - if sym_ops: - break - if not sym_ops: - # Try to parse International number - for symmetry_label in [ - "_space_group_IT_number", - "_space_group_IT_number_", - "_symmetry_Int_Tables_number", - "_symmetry_Int_Tables_number_", - ]: - if data.data.get(symmetry_label): - try: - i = int(str2float(data.data.get(symmetry_label))) - sym_ops = SpaceGroup.from_int_number(i).symmetry_ops - break - except ValueError: - continue - - if not sym_ops: - msg = "No _symmetry_equiv_pos_as_xyz type key found. Defaulting to P1." - warnings.warn(msg) - self.warnings.append(msg) - sym_ops = [SymmOp.from_xyz_str(s) for s in ["x", "y", "z"]] - - return sym_ops - - def get_magsymops(self, data): - """ - Equivalent to get_symops except for magnetic symmetry groups. - Separate function since additional operation for time reversal symmetry - (which changes magnetic moments on sites) needs to be returned. 
- """ - mag_symm_ops = [] - bns_name = data.data.get("_space_group_magn.name_BNS") # get BNS label for MagneticSpaceGroup() - bns_num = data.data.get("_space_group_magn.number_BNS") # get BNS number for MagneticSpaceGroup() - - # check to see if magCIF file explicitly contains magnetic symmetry operations - if xyzt := data.data.get("_space_group_symop_magn_operation.xyz"): - if isinstance(xyzt, str): - xyzt = [xyzt] - mag_symm_ops = [MagSymmOp.from_xyzt_str(s) for s in xyzt] - - if data.data.get("_space_group_symop_magn_centering.xyz"): - xyzt = data.data.get("_space_group_symop_magn_centering.xyz") - if isinstance(xyzt, str): - xyzt = [xyzt] - centering_symops = [MagSymmOp.from_xyzt_str(s) for s in xyzt] - - all_ops = [] - for op in mag_symm_ops: - for centering_op in centering_symops: - new_translation = [ - i - np.floor(i) for i in op.translation_vector + centering_op.translation_vector - ] - new_time_reversal = op.time_reversal * centering_op.time_reversal - all_ops.append( - MagSymmOp.from_rotation_and_translation_and_time_reversal( - rotation_matrix=op.rotation_matrix, - translation_vec=new_translation, - time_reversal=new_time_reversal, - ) - ) - mag_symm_ops = all_ops - - # else check to see if it specifies a magnetic space group - elif bns_name or bns_num: - label = bns_name if bns_name else list(map(int, (bns_num.split(".")))) - - if data.data.get("_space_group_magn.transform_BNS_Pp_abc") != "a,b,c;0,0,0": - jonas_faithful = data.data.get("_space_group_magn.transform_BNS_Pp_abc") - msg = MagneticSpaceGroup(label, jonas_faithful) - - elif data.data.get("_space_group_magn.transform_BNS_Pp"): - return NotImplementedError("Incomplete specification to implement.") - else: - msg = MagneticSpaceGroup(label) - - mag_symm_ops = msg.symmetry_ops - - if not mag_symm_ops: - msg = "No magnetic symmetry detected, using primitive symmetry." 
- warnings.warn(msg) - self.warnings.append(msg) - mag_symm_ops = [MagSymmOp.from_xyzt_str("x, y, z, 1")] - - return mag_symm_ops - - @staticmethod - def parse_oxi_states(data): - """Parse oxidation states from data dictionary.""" - try: - oxi_states = { - data["_atom_type_symbol"][i]: str2float(data["_atom_type_oxidation_number"][i]) - for i in range(len(data["_atom_type_symbol"])) - } - # attempt to strip oxidation state from _atom_type_symbol - # in case the label does not contain an oxidation state - for i, symbol in enumerate(data["_atom_type_symbol"]): - oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float(data["_atom_type_oxidation_number"][i]) - - except (ValueError, KeyError): - oxi_states = None - return oxi_states - - @staticmethod - def parse_magmoms(data, lattice=None): - """Parse atomic magnetic moments from data dictionary.""" - if lattice is None: - raise Exception("Magmoms given in terms of crystal axes in magCIF spec.") - try: - magmoms = { - data["_atom_site_moment_label"][i]: np.array( - [ - str2float(data["_atom_site_moment_crystalaxis_x"][i]), - str2float(data["_atom_site_moment_crystalaxis_y"][i]), - str2float(data["_atom_site_moment_crystalaxis_z"][i]), - ] - ) - for i in range(len(data["_atom_site_moment_label"])) - } - except (ValueError, KeyError): - return None - return magmoms - - def _parse_symbol(self, sym): - """ - Parse a string with a symbol to extract a string representing an element. - - Args: - sym (str): A symbol to be parsed. - - Returns: - A string with the parsed symbol. None if no parsing was possible. - """ - # Common representations for elements/water in cif files - # TODO: fix inconsistent handling of water - special = { - "Hw": "H", - "Ow": "O", - "Wat": "O", - "wat": "O", - "OH": "", - "OH2": "", - "NO3": "N", - } - - parsed_sym = None - # try with special symbols, otherwise check the first two letters, - # then the first letter alone. If everything fails try extracting the - # first letters. 
- m_sp = re.match("|".join(special), sym) - if m_sp: - parsed_sym = special[m_sp.group()] - elif Element.is_valid_symbol(sym[:2].title()): - parsed_sym = sym[:2].title() - elif Element.is_valid_symbol(sym[0].upper()): - parsed_sym = sym[0].upper() - else: - m = re.match(r"w?[A-Z][a-z]*", sym) - if m: - parsed_sym = m.group() - - if parsed_sym is not None and (m_sp or not re.match(rf"{parsed_sym}\d*", sym)): - msg = f"{sym} parsed as {parsed_sym}" - warnings.warn(msg) - self.warnings.append(msg) - - return parsed_sym - - def _get_structure( - self, data: dict[str, Any], primitive: bool, symmetrized: bool, check_occu: bool = False - ) -> Structure | None: - """Generate structure from part of the cif.""" - - def get_num_implicit_hydrogens(sym): - num_h = {"Wat": 2, "wat": 2, "O-H": 1} - return num_h.get(sym[:3], 0) - - lattice = self.get_lattice(data) - - # if magCIF, get magnetic symmetry moments and magmoms - # else standard CIF, and use empty magmom dict - if self.feature_flags["magcif_incommensurate"]: - raise NotImplementedError("Incommensurate structures not currently supported.") - if self.feature_flags["magcif"]: - self.symmetry_operations = self.get_magsymops(data) - magmoms = self.parse_magmoms(data, lattice=lattice) - else: - self.symmetry_operations = self.get_symops(data) - magmoms = {} - - oxi_states = self.parse_oxi_states(data) - - coord_to_species = {} # type: ignore - coord_to_magmoms = {} - labels = {} - - def get_matching_coord(coord): - keys = list(coord_to_species) - coords = np.array(keys) - for op in self.symmetry_operations: - frac_coord = op.operate(coord) - indices = find_in_coord_list_pbc(coords, frac_coord, atol=self._site_tolerance) - if len(indices) > 0: - return keys[indices[0]] - return False - - for idx, label in enumerate(data["_atom_site_label"]): - try: - # If site type symbol exists, use it. Otherwise, we use the label. 
- symbol = self._parse_symbol(data["_atom_site_type_symbol"][idx]) - num_h = get_num_implicit_hydrogens(data["_atom_site_type_symbol"][idx]) - except KeyError: - symbol = self._parse_symbol(label) - num_h = get_num_implicit_hydrogens(label) - if not symbol: - continue - - if oxi_states is not None: - o_s = oxi_states.get(symbol, 0) - # use _atom_site_type_symbol if possible for oxidation state - if "_atom_site_type_symbol" in data.data: # type: ignore[attr-defined] - oxi_symbol = data["_atom_site_type_symbol"][idx] - o_s = oxi_states.get(oxi_symbol, o_s) - try: - el = Species(symbol, o_s) - except Exception: - el = DummySpecies(symbol, o_s) - else: - el = get_el_sp(symbol) # type: ignore - - x = str2float(data["_atom_site_fract_x"][idx]) - y = str2float(data["_atom_site_fract_y"][idx]) - z = str2float(data["_atom_site_fract_z"][idx]) - magmom = magmoms.get(label, np.array([0, 0, 0])) - - try: - occu = str2float(data["_atom_site_occupancy"][idx]) - except (KeyError, ValueError): - occu = 1 - # If check_occu is True or the occupancy is greater than 0, create comp_d - if not check_occu or occu > 0: - coord = (x, y, z) - match = get_matching_coord(coord) - comp_dict = {el: max(occu, 1e-8)} - - if num_h > 0: - comp_dict["H"] = num_h # type: ignore - self.warnings.append( - "Structure has implicit hydrogens defined, parsed structure unlikely to be " - "suitable for use in calculations unless hydrogens added." - ) - comp = Composition(comp_dict) - - if not match: - coord_to_species[coord] = comp - coord_to_magmoms[coord] = magmom - labels[coord] = label - else: - coord_to_species[match] += comp - # disordered magnetic not currently supported - coord_to_magmoms[match] = None - labels[match] = label - sum_occu = [ - sum(c.values()) for c in coord_to_species.values() if set(c.elements) != {Element("O"), Element("H")} - ] - if any(occu > 1 for occu in sum_occu): - msg = ( - f"Some occupancies ({sum_occu}) sum to > 1! 
If they are within " - "the occupancy_tolerance, they will be rescaled. " - f"The current occupancy_tolerance is set to: {self._occupancy_tolerance}" - ) - warnings.warn(msg) - self.warnings.append(msg) - - all_species = [] - all_coords = [] - all_magmoms = [] - all_hydrogens = [] - equivalent_indices = [] - all_labels = [] - - # check to see if magCIF file is disordered - if self.feature_flags["magcif"]: - for v in coord_to_magmoms.values(): - if v is None: - # Proposed solution to this is to instead store magnetic - # moments as Species 'spin' property, instead of site - # property, but this introduces ambiguities for end user - # (such as unintended use of `spin` and Species will have - # fictitious oxidation state). - raise NotImplementedError("Disordered magnetic structures not currently supported.") - - if coord_to_species.items(): - for idx, (comp, group) in enumerate( - groupby( - sorted(coord_to_species.items(), key=lambda x: x[1]), - key=lambda x: x[1], - ) - ): - tmp_coords = [site[0] for site in group] - tmp_magmom = [coord_to_magmoms[tmp_coord] for tmp_coord in tmp_coords] - - if self.feature_flags["magcif"]: - coords, magmoms, new_labels = self._unique_coords( - tmp_coords, magmoms=tmp_magmom, labels=labels, lattice=lattice - ) - else: - coords, magmoms, new_labels = self._unique_coords(tmp_coords, labels=labels) - - if set(comp.elements) == {Element("O"), Element("H")}: - # O with implicit hydrogens - im_h = comp["H"] - species = Composition({"O": comp["O"]}) - else: - im_h = 0 - species = comp - - # The following might be a more natural representation of equivalent indices, - # but is not in the format expect by SymmetrizedStructure: - # equivalent_indices.append(list(range(len(all_coords), len(coords)+len(all_coords)))) - # The above gives a list like: - # [[0, 1, 2, 3], [4, 5, 6, 7, 8, 9, 10, 11]] where the - # integers are site indices, whereas the version used below will give a version like: - # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] - # which is 
a list in the same order as the sites, but where if a site has the same integer - # it is equivalent. - equivalent_indices += len(coords) * [idx] - - all_hydrogens.extend(len(coords) * [im_h]) - all_coords.extend(coords) - all_species.extend(len(coords) * [species]) - all_magmoms.extend(magmoms) - all_labels.extend(new_labels) - - # rescale occupancies if necessary - all_species_noedit = all_species.copy() # save copy before scaling in case of check_occu=False, used below - for idx, species in enumerate(all_species): - total_occu = sum(species.values()) - if 1 < total_occu <= self._occupancy_tolerance: - all_species[idx] = species / total_occu - - if all_species and len(all_species) == len(all_coords) and len(all_species) == len(all_magmoms): - site_properties = {} - if any(all_hydrogens): - assert len(all_hydrogens) == len(all_coords) - site_properties["implicit_hydrogens"] = all_hydrogens - - if self.feature_flags["magcif"]: - site_properties["magmom"] = all_magmoms - - if len(site_properties) == 0: - site_properties = None # type: ignore - - if any(all_labels): - assert len(all_labels) == len(all_species) - else: - all_labels = None # type: ignore - - struct = Structure(lattice, all_species, all_coords, site_properties=site_properties, labels=all_labels) - - if symmetrized: - # Wyckoff labels not currently parsed, note that not all CIFs will contain Wyckoff labels - # TODO: extract Wyckoff labels (or other CIF attributes) and include as site_properties - wyckoffs = ["Not Parsed"] * len(struct) - - # space groups names are likewise not parsed (again, not all CIFs will contain this information) - # What is stored are the lists of symmetry operations used to generate the structure - # TODO: ensure space group labels are stored if present - sg = SpacegroupOperations("Not Parsed", -1, self.symmetry_operations) - struct = SymmetrizedStructure(struct, sg, equivalent_indices, wyckoffs) - - if not check_occu: - for idx in range(len(struct)): - struct[idx] = PeriodicSite( 
- all_species_noedit[idx], all_coords[idx], lattice, properties=site_properties, skip_checks=True - ) - - if symmetrized or not check_occu: - return struct - - struct = struct.get_sorted_structure() - - if primitive and self.feature_flags["magcif"]: - struct = struct.get_primitive_structure(use_site_props=True) - elif primitive: - struct = struct.get_primitive_structure() - struct = struct.get_reduced_structure() - - if self.check_cif: - cif_failure_reason = self.check(struct) - if cif_failure_reason is not None: - warnings.warn(cif_failure_reason) - - return struct - return None - - @np.deprecate( - message="get_structures is deprecated and will be removed in 2024. Use parse_structures instead." - "The only difference is that primitive defaults to False in the new parse_structures method." - "So parse_structures(primitive=True) is equivalent to the old behavior of get_structures().", - ) - def get_structures(self, *args, **kwargs) -> list[Structure]: - """ - Deprecated. Use parse_structures instead. Only difference between the two methods is the - default primitive=False in parse_structures. - So parse_structures(primitive=True) is equivalent to the old behavior of get_structures(). - """ - if len(args) > 0: # extract primitive if passed as arg - kwargs["primitive"] = args[0] - args = args[1:] - kwargs.setdefault("primitive", True) - return self.parse_structures(*args, **kwargs) - - def parse_structures( - self, - primitive: bool | None = None, - symmetrized: bool = False, - check_occu: bool = True, - on_error: Literal["ignore", "warn", "raise"] = "warn", - ) -> list[Structure]: - """Return list of structures in CIF file. - - Args: - primitive (bool): Set to True to return primitive unit cells. - Defaults to False. With magnetic CIF files, True will return primitive - magnetic cell which may be larger than nuclear primitive cell. 
- symmetrized (bool): If True, return a SymmetrizedStructure which will - include the equivalent indices and symmetry operations used to - create the Structure as provided by the CIF (if explicit symmetry - operations are included in the CIF) or generated from information - in the CIF (if only space group labels are provided). Note that - currently Wyckoff labels and space group labels or numbers are - not included in the generated SymmetrizedStructure, these will be - notated as "Not Parsed" or -1 respectively. - check_occu (bool): If False, site occupancy will not be checked, allowing unphysical - occupancy != 1. Useful for experimental results in which occupancy was allowed - to refine to unphysical values. Warning: unphysical site occupancies are incompatible - with many pymatgen features. Defaults to True. - on_error ('ignore' | 'warn' | 'raise'): What to do in case of KeyError or ValueError - while parsing CIF file. Defaults to 'warn'. - - Returns: - list[Structure]: All structures in CIF file. - """ - if os.getenv("CI") and datetime.now() > datetime(2024, 3, 1): # March 2024 seems long enough # pragma: no cover - raise RuntimeError("remove the change of default primitive=True to False made on 2023-10-24") - if primitive is None: - primitive = False - warnings.warn( - "The default value of primitive was changed from True to False in " - "https://github.com/materialsproject/pymatgen/pull/3419. CifParser now returns the cell " - "in the CIF file as is. If you want the primitive cell, please set primitive=True explicitly.", - UserWarning, - ) - if not check_occu: # added in https://github.com/materialsproject/pymatgen/pull/2836 - warnings.warn("Structures with unphysical site occupancies are not compatible with many pymatgen features.") - if primitive and symmetrized: - raise ValueError( - "Using both 'primitive' and 'symmetrized' arguments is not currently supported " - "since unexpected behavior might result." 
- ) - - structures = [] - for idx, dct in enumerate(self._cif.data.values()): - try: - struct = self._get_structure(dct, primitive, symmetrized, check_occu=check_occu) - if struct: - structures.append(struct) - except (KeyError, ValueError) as exc: - # A user reported a problem with cif files produced by Avogadro - # in which the atomic coordinates are in Cartesian coords. - msg = f"No structure parsed for section {idx + 1} in CIF.\n{exc}" - if on_error == "raise": - raise ValueError(msg) from exc - if on_error == "warn": - warnings.warn(msg) - self.warnings.append(msg) - # continue silently if on_error == "ignore" - - # if on_error == "raise" we don't get to here so no need to check - if self.warnings and on_error == "warn": - warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings)) - - if len(structures) == 0: - raise ValueError("Invalid CIF file with no structures!") - return structures - - def get_bibtex_string(self): - """ - Get BibTeX reference from CIF file. - :param data: - - Returns: - BibTeX string. - """ - try: - from pybtex.database import BibliographyData, Entry - except ImportError: - raise RuntimeError("Bibliographic data extraction requires pybtex.") - - bibtex_keys = { - "author": ("_publ_author_name", "_citation_author_name"), - "title": ("_publ_section_title", "_citation_title"), - "journal": ( - "_journal_name_full", - "_journal_name_abbrev", - "_citation_journal_full", - "_citation_journal_abbrev", - ), - "volume": ("_journal_volume", "_citation_journal_volume"), - "year": ("_journal_year", "_citation_year"), - "number": ("_journal_number", "_citation_number"), - "page_first": ("_journal_page_first", "_citation_page_first"), - "page_last": ("_journal_page_last", "_citation_page_last"), - "doi": ("_journal_DOI", "_citation_DOI"), - } - - entries = {} - - # TODO: parse '_publ_section_references' when it exists? - # TODO: CIF specification supports multiple citations. 
- - for idx, data in enumerate(self._cif.data.values()): - # convert to lower-case keys, some cif files inconsistent - data = {k.lower(): v for k, v in data.data.items()} - - bibtex_entry = {} - - for field, tags in bibtex_keys.items(): - for tag in tags: - if tag in data: - if isinstance(data[tag], list): - bibtex_entry[field] = data[tag][0] - else: - bibtex_entry[field] = data[tag] - - # convert to bibtex author format ('and' delimited) - if "author" in bibtex_entry: - # separate out semicolon authors - if isinstance(bibtex_entry["author"], str) and ";" in bibtex_entry["author"]: - bibtex_entry["author"] = bibtex_entry["author"].split(";") - - if isinstance(bibtex_entry["author"], list): - bibtex_entry["author"] = " and ".join(bibtex_entry["author"]) - - # convert to bibtex page range format, use empty string if not specified - if ("page_first" in bibtex_entry) or ("page_last" in bibtex_entry): - bibtex_entry["pages"] = bibtex_entry.get("page_first", "") + "--" + bibtex_entry.get("page_last", "") - bibtex_entry.pop("page_first", None) # and remove page_first, page_list if present - bibtex_entry.pop("page_last", None) - - # cite keys are given as cif-reference-idx in order they are found - entries[f"cifref{idx}"] = Entry("article", list(bibtex_entry.items())) - - return BibliographyData(entries).to_string(bib_format="bibtex") - - def as_dict(self): - """MSONable dict""" - dct = {} - for k, v in self._cif.data.items(): - dct[k] = {} - for k2, v2 in v.data.items(): - dct[k][k2] = v2 - return dct - - @property - def has_errors(self): - """Whether there are errors/warnings detected in CIF parsing.""" - return len(self.warnings) > 0 - - def check(self, structure: Structure) -> str | None: - """Check whether a structure constructed from CIF passes sanity checks. - - Args: - structure (Structure) : structure created from CIF - - Returns: - str | None: If any check fails, on output, returns a human-readable str for the - reason why (e.g., which elements are missing). 
Returns None if all checks pass. - - Checks: - - Composition from CIF is valid - - CIF composition contains only valid elements - - CIF and structure contain the same elements (often hydrogens - are omitted from CIFs, as their positions cannot be determined from - X-ray diffraction, needs more difficult neutron diffraction) - - CIF and structure have same relative stoichiometry. Thus - if CIF reports stoichiometry LiFeO, and the structure has - composition (LiFeO)4, this check passes. - """ - failure_reason = None - - cif_as_dict = self.as_dict() - head_key = next(iter(cif_as_dict)) - - cif_formula = None - for key in ("_chemical_formula_sum", "_chemical_formula_structural"): - if cif_as_dict[head_key].get(key): - cif_formula = cif_as_dict[head_key][key] - break - - if cif_formula is None and cif_as_dict[head_key].get("_atom_site_type_symbol"): - cif_formula = " ".join(cif_as_dict[head_key]["_atom_site_type_symbol"]) - - try: - cif_composition = Composition(cif_formula) - except Exception as exc: - return f"Cannot determine chemical composition from CIF! 
{exc}" - - try: - orig_comp = cif_composition.remove_charges().as_dict() - struct_comp = structure.composition.remove_charges().as_dict() - except Exception as exc: - return str(exc) - - orig_comp_elts = {str(elt) for elt in orig_comp} - struct_comp_elts = {str(elt) for elt in struct_comp} - - if orig_comp_elts != struct_comp_elts: - # hard failure - missing elements - - missing = set(orig_comp_elts).difference(set(struct_comp_elts)) - addendum = "from PMG structure composition" - if len(missing) == 0: - addendum = "from CIF-reported composition" - missing = set(struct_comp_elts).difference(set(orig_comp_elts)) - missing_str = ", ".join([str(x) for x in missing]) - failure_reason = f"Missing elements {missing_str} {addendum}" - - elif not all(struct_comp[elt] - orig_comp[elt] == 0 for elt in orig_comp): - # Check that stoichiometry is same, i.e., same relative ratios of elements - ratios = {elt: struct_comp[elt] / orig_comp[elt] for elt in orig_comp_elts} - - same_stoich = all( - abs(ratios[elt_a] - ratios[elt_b]) < self.comp_tol - for elt_a in orig_comp_elts - for elt_b in orig_comp_elts - ) - - if not same_stoich: - failure_reason = f"Incorrect stoichiometry:\n CIF={orig_comp}\n PMG={struct_comp}\n {ratios=}" - - return failure_reason - - -class CifWriter: - """A wrapper around CifFile to write CIF files from pymatgen structures.""" - - def __init__( - self, - struct: Structure, - symprec: float | None = None, - write_magmoms: bool = False, - significant_figures: int = 8, - angle_tolerance: float = 5, - refine_struct: bool = True, - write_site_properties: bool = False, - printout: bool = False, - pos_order: bool = False, - full_order: bool = False, - ) -> None: - """ - Args: - struct (Structure): structure to write - symprec (float): If not none, finds the symmetry of the structure - and writes the cif with symmetry information. Passes symprec - to the SpacegroupAnalyzer. See also refine_struct. - write_magmoms (bool): If True, will write magCIF file. 
Incompatible - with symprec - significant_figures (int): Specifies precision for formatting of floats. - Defaults to 8. - angle_tolerance (float): Angle tolerance for symmetry finding. Passes - angle_tolerance to the SpacegroupAnalyzer. Used only if symprec - is not None. - refine_struct: Used only if symprec is not None. If True, get_refined_structure - is invoked to convert input structure from primitive to conventional. - write_site_properties (bool): Whether to write the Structure.site_properties - to the CIF as _atom_site_{property name}. Defaults to False. - """ - if write_magmoms and symprec: - warnings.warn("Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection.") - symprec = None - - format_str = f"{{:.{significant_figures}f}}" - - block: dict[str, Any] = {} - loops = [] - spacegroup = ("P 1", 1) - if symprec is not None: - spg_analyzer = SpacegroupAnalyzer(struct, symprec, angle_tolerance=angle_tolerance) - spacegroup = (spg_analyzer.get_space_group_symbol(), spg_analyzer.get_space_group_number()) - - if refine_struct: - # Needs the refined structure when using symprec. This converts - # primitive to conventional structures, the standard for CIF. 
- struct = spg_analyzer.get_refined_structure() - - lattice = struct.lattice - comp = struct.composition - no_oxi_comp = comp.element_composition - block["_symmetry_space_group_name_H-M"] = spacegroup[0] - for cell_attr in ["a", "b", "c"]: - block["_cell_length_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) - for cell_attr in ["alpha", "beta", "gamma"]: - block["_cell_angle_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) - block["_symmetry_Int_Tables_number"] = spacegroup[1] - block["_chemical_formula_structural"] = no_oxi_comp.reduced_formula - block["_chemical_formula_sum"] = no_oxi_comp.formula - block["_cell_volume"] = format_str.format(lattice.volume) - - _, fu = no_oxi_comp.get_reduced_composition_and_factor() - block["_cell_formula_units_Z"] = str(int(fu)) - - if symprec is None: - block["_symmetry_equiv_pos_site_id"] = ["1"] - block["_symmetry_equiv_pos_as_xyz"] = ["x, y, z"] - else: - spg_analyzer = SpacegroupAnalyzer(struct, symprec) - - symm_ops: list[SymmOp] = [] - for op in spg_analyzer.get_symmetry_operations(): - v = op.translation_vector - symm_ops.append(SymmOp.from_rotation_and_translation(op.rotation_matrix, v)) - - ops = [op.as_xyz_str() for op in symm_ops] - block["_symmetry_equiv_pos_site_id"] = [f"{i}" for i in range(1, len(ops) + 1)] - block["_symmetry_equiv_pos_as_xyz"] = ops - - loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"]) - - try: - symbol_to_oxi_num = {str(el): float(el.oxi_state or 0) for el in sorted(comp.elements)} - block["_atom_type_symbol"] = list(symbol_to_oxi_num) - block["_atom_type_oxidation_number"] = symbol_to_oxi_num.values() - loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"]) - except (TypeError, AttributeError): - symbol_to_oxi_num = {el.symbol: 0 for el in sorted(comp.elements)} - - atom_site_type_symbol = [] - atom_site_symmetry_multiplicity = [] - atom_site_fract_x = [] - atom_site_fract_y = [] - atom_site_fract_z = [] - atom_site_label = 
[] - atom_site_occupancy = [] - atom_site_moment_label = [] - atom_site_moment_crystalaxis_x = [] - atom_site_moment_crystalaxis_y = [] - atom_site_moment_crystalaxis_z = [] - atom_site_properties: dict[str, list] = defaultdict(list) - count = 0 - if symprec is None: - for site in struct: - for sp, occu in sorted(site.species.items()): - atom_site_type_symbol.append(str(sp)) - atom_site_symmetry_multiplicity.append("1") - atom_site_fract_x.append(format_str.format(site.a)) - atom_site_fract_y.append(format_str.format(site.b)) - atom_site_fract_z.append(format_str.format(site.c)) - atom_site_occupancy.append(str(occu)) - site_label = f"{sp.symbol}{count}" - - if "magmom" in site.properties: - mag = site.properties["magmom"] - elif getattr(sp, "spin", None) is not None: - mag = sp.spin - else: - # Use site label if available for regular sites - site_label = site.label if site.label != site.species_string else site_label - mag = 0 - - atom_site_label.append(site_label) - - magmom = Magmom(mag) - if write_magmoms and abs(magmom) > 0: - moment = Magmom.get_moment_relative_to_crystal_axes(magmom, lattice) - atom_site_moment_label.append(f"{sp.symbol}{count}") - atom_site_moment_crystalaxis_x.append(format_str.format(moment[0])) - atom_site_moment_crystalaxis_y.append(format_str.format(moment[1])) - atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) - - if write_site_properties: - for key, val in site.properties.items(): - atom_site_properties[key].append(format_str.format(val)) - - count += 1 - else: - # The following just presents a deterministic ordering. 
- if full_order: - unique_sites = [site for site in struct] - for site in sorted( - unique_sites, - key=lambda t: ( - t.species.average_electroneg, - t.a, - t.b, - t.c, - ), - ): - for sp, occu in site.species.items(): - atom_site_type_symbol.append(str(sp)) - atom_site_symmetry_multiplicity.append("1") - atom_site_fract_x.append(format_str.format(site.a)) - atom_site_fract_y.append(format_str.format(site.b)) - atom_site_fract_z.append(format_str.format(site.c)) - site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" - atom_site_label.append(site_label) - atom_site_occupancy.append(str(occu)) - count += 1 - else: - unique_sites = [ - ( - sorted(sites, key=lambda s: tuple(round(x % 1., 7) % 1. for x in s.frac_coords))[0], - len(sites), - ) - for sites in spg_analyzer.get_symmetrized_structure().equivalent_sites - ] - if not pos_order: - for site, mult in sorted( - unique_sites, - key=lambda t: ( - t[0].species.average_electroneg, - -t[1], - t[0].a, - t[0].b, - t[0].c, - ), - ): - for sp, occu in site.species.items(): - atom_site_type_symbol.append(str(sp)) - atom_site_symmetry_multiplicity.append(f"{mult}") - atom_site_fract_x.append(format_str.format(site.a)) - atom_site_fract_y.append(format_str.format(site.b)) - atom_site_fract_z.append(format_str.format(site.c)) - site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" - atom_site_label.append(site_label) - atom_site_occupancy.append(str(occu)) - count += 1 - else: - for site, mult in sorted( - unique_sites, - key=lambda t: ( - t[0].a, - t[0].b, - t[0].c, - ), - ): - for sp, occu in site.species.items(): - atom_site_type_symbol.append(str(sp)) - atom_site_symmetry_multiplicity.append(f"{mult}") - atom_site_fract_x.append(format_str.format(site.a)) - atom_site_fract_y.append(format_str.format(site.b)) - atom_site_fract_z.append(format_str.format(site.c)) - site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" - 
atom_site_label.append(site_label) - atom_site_occupancy.append(str(occu)) - count += 1 - - block["_atom_site_type_symbol"] = atom_site_type_symbol - block["_atom_site_label"] = atom_site_label - block["_atom_site_symmetry_multiplicity"] = atom_site_symmetry_multiplicity - block["_atom_site_fract_x"] = atom_site_fract_x - block["_atom_site_fract_y"] = atom_site_fract_y - block["_atom_site_fract_z"] = atom_site_fract_z - block["_atom_site_occupancy"] = atom_site_occupancy - loop_labels = [ - "_atom_site_type_symbol", - "_atom_site_label", - "_atom_site_symmetry_multiplicity", - "_atom_site_fract_x", - "_atom_site_fract_y", - "_atom_site_fract_z", - "_atom_site_occupancy", - ] - if write_site_properties: - for key, vals in atom_site_properties.items(): - block[f"_atom_site_{key}"] = vals - loop_labels += [f"_atom_site_{key}"] - loops.append(loop_labels) - - if write_magmoms: - block["_atom_site_moment_label"] = atom_site_moment_label - block["_atom_site_moment_crystalaxis_x"] = atom_site_moment_crystalaxis_x - block["_atom_site_moment_crystalaxis_y"] = atom_site_moment_crystalaxis_y - block["_atom_site_moment_crystalaxis_z"] = atom_site_moment_crystalaxis_z - loops.append( - [ - "_atom_site_moment_label", - "_atom_site_moment_crystalaxis_x", - "_atom_site_moment_crystalaxis_y", - "_atom_site_moment_crystalaxis_z", - ] - ) - dct = {} - dct[comp.reduced_formula] = CifBlock(block, loops, comp.reduced_formula) - self._cf = CifFile(dct) - - @property - def cif_file(self): - """Returns: CifFile associated with the CifWriter.""" - return self._cf - - def __str__(self): - """Returns the CIF as a string.""" - return str(self._cf) - - def write_file(self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w") -> None: - """Write the CIF file.""" - with zopen(filename, mode=mode) as file: - file.write(str(self)) - - -def str2float(text): - """Remove uncertainty brackets from strings and return the float.""" - try: - # Note that the ending ) is sometimes missing. 
That is why the code has - # been modified to treat it as optional. Same logic applies to lists. - return float(re.sub(r"\(.+\)*", "", text)) - except TypeError: - if isinstance(text, list) and len(text) == 1: - return float(re.sub(r"\(.+\)*", "", text[0])) - except ValueError as exc: - if text.strip() == ".": - return 0 - raise exc - raise ValueError(f"{text} cannot be converted to float") From 2f40b121cdbb49dfdca2ace9973170aaaa4828e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E7=9D=BF=E6=99=BA?= Date: Mon, 14 Apr 2025 04:33:20 +1000 Subject: [PATCH 03/37] init crystalrelax recipes --- recipes/crystalrelax.yaml | 49 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 recipes/crystalrelax.yaml diff --git a/recipes/crystalrelax.yaml b/recipes/crystalrelax.yaml new file mode 100644 index 00000000..46eac663 --- /dev/null +++ b/recipes/crystalrelax.yaml @@ -0,0 +1,49 @@ +# Model arguments +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true + +# Chemical Task arguments +chem_task: crystalrelax +dataset_id_or_path: /iopsstor/store/cscs/swissai/a05/chem/binary_compound_relaxing +rewards: +- accuracy + +# Lora Arguments +# No LoRA is used here + +# Training arguments +max_steps: 1450 +per_device_train_batch_size: 2 +gradient_accumulation_steps: 8 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +learning_rate: 5.0e-7 # 1.0e-6 as in the deepseek math paper 5-e7 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05 +lr_scheduler_type: cosine +warmup_ratio: 0.03 +# GRPO specific parameters +beta: 0.001 # 0.04 as in the deepseek math paper 0.001 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05 +max_prompt_length: 600 +max_completion_length: 2200 +num_generations: 16 +use_vllm: true +vllm_device: "cuda:3" 
+vllm_gpu_memory_utilization: 0.7 +vllm_max_model_len: 3000 + +# Logging arguments +logging_strategy: steps +logging_steps: 2 +report_to: +- wandb + +save_strategy: "steps" +save_steps: 25 +seed: 42 + +# Hugging Face Hub +push_to_hub: false + # hub_model_id: llama-3-1-8b-math-orca-qlora-10k-ep1 # if not defined same as output_dir From 8cfc7756cef85e75cb0771c937cdf226b6afa022 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E7=9D=BF=E6=99=BA?= Date: Mon, 14 Apr 2025 04:33:40 +1000 Subject: [PATCH 04/37] update: accuracy metrics --- .../tasks/crystal_structure/relaxing.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/open_r1/tasks/crystal_structure/relaxing.py b/src/open_r1/tasks/crystal_structure/relaxing.py index 19db9ff7..a288f2a7 100644 --- a/src/open_r1/tasks/crystal_structure/relaxing.py +++ b/src/open_r1/tasks/crystal_structure/relaxing.py @@ -49,6 +49,10 @@ def __init__(self, **kwargs): "As 8_int 0.06170692 0.25000000 0.62500000\n" "<|im_end|>\n" ) + self.log_custom_metrics = True + self.custom_metrics = { + 'val/rewards': [], + } # Dataset here: /iopsstor/store/cscs/swissai/a05/chem/CRLLM-PubChem-compounds1M.csv @@ -122,6 +126,8 @@ def accuracy_reward(self, completions, solution, **kwargs): rewards.append(reward) except Exception as e: rewards.append(-10) + if self.log_custom_metrics: + self.custom_metrics['val/rewards'].extend(rewards) return rewards def preprocess_response(self, response): @@ -132,3 +138,19 @@ def preprocess_response(self, response): return m[-1].strip() else: return "NONE" + + def get_metrics(self) -> Dict: + """ + Get task metrics to log in WANDB. + This function takes no arguments and returns a dictionary of metrics {key[str]: value[float]}. 
+ """ + metrics = dict() + if self.log_custom_metrics: + rewards = self.custom_metrics['val/rewards'] + if rewards: + correct_count = sum(1 for r in rewards if r == 1) + total_count = len(rewards) + accuracy = correct_count / total_count if total_count > 0 else 0.0 + metrics['val/accuracy'] = accuracy + self.custom_metrics['val/rewards'] = [] + return metrics \ No newline at end of file From 472b3310484a7baf9fd8344f53b19309f60beafa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E7=9D=BF=E6=99=BA?= Date: Mon, 14 Apr 2025 05:30:41 +1000 Subject: [PATCH 05/37] fixed: init src_train_file, tgt_train_file, src_test_file, tgt_test_file --- src/open_r1/tasks/crystal_structure/relaxing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/open_r1/tasks/crystal_structure/relaxing.py b/src/open_r1/tasks/crystal_structure/relaxing.py index a288f2a7..0343cdc4 100644 --- a/src/open_r1/tasks/crystal_structure/relaxing.py +++ b/src/open_r1/tasks/crystal_structure/relaxing.py @@ -11,6 +11,10 @@ class BinaryCompoundRelaxing(RLTask): + src_train_file: str = "" + tgt_train_file: str = "" + src_test_file: str = "" + tgt_test_file: str = "" question_template: str = "" def __init__(self, **kwargs): From db78cfa4e22c63d8a8f731a4464a5648c59d907e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E7=9D=BF=E6=99=BA?= Date: Mon, 14 Apr 2025 05:35:55 +1000 Subject: [PATCH 06/37] fixed: init log_custom_metrics, custom_metrics --- src/open_r1/tasks/crystal_structure/relaxing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/open_r1/tasks/crystal_structure/relaxing.py b/src/open_r1/tasks/crystal_structure/relaxing.py index 0343cdc4..1378e2b2 100644 --- a/src/open_r1/tasks/crystal_structure/relaxing.py +++ b/src/open_r1/tasks/crystal_structure/relaxing.py @@ -8,6 +8,7 @@ from rdkit import Chem from ..base import RLTask import requests +from dataclasses import field class BinaryCompoundRelaxing(RLTask): @@ -16,6 +17,8 @@ class BinaryCompoundRelaxing(RLTask): src_test_file: str 
= "" tgt_test_file: str = "" question_template: str = "" + log_custom_metrics: bool = True + custom_metrics: dict = field(default_factory=dict) def __init__(self, **kwargs): super().__init__(**kwargs) From 04c386948ed20fe88b1a5c1d471012d041fa1e51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E7=9D=BF=E6=99=BA?= Date: Mon, 14 Apr 2025 05:44:23 +1000 Subject: [PATCH 07/37] fixed: read_files format issue fixed --- .../tasks/crystal_structure/relaxing.py | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/src/open_r1/tasks/crystal_structure/relaxing.py b/src/open_r1/tasks/crystal_structure/relaxing.py index 1378e2b2..77f28335 100644 --- a/src/open_r1/tasks/crystal_structure/relaxing.py +++ b/src/open_r1/tasks/crystal_structure/relaxing.py @@ -65,14 +65,30 @@ def __init__(self, **kwargs): def read_files(self, src_file: str, tgt_file: str) -> Dict: """Read source and target files and create dataset dictionary.""" - with open(src_file, "r", encoding="utf-8") as f: - problems = [ - self.question_template.format(self.process_line(line)) - for line in f.readlines() - ] - - with open(tgt_file, "r", encoding="utf-8") as f: - solutions = [self.process_line(line) for line in f.readlines()] + def read_records(file_path: str) -> list: + """Helper function to read multi-line records separated by blank lines.""" + with open(file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + records = [] + current_record = [] + for line in lines: + if line.strip() == "": # Blank line indicates end of a record + if current_record: + records.append("\n".join(current_record)) + current_record = [] + else: + current_record.append(line.strip()) + if current_record: # Append the last record if file doesn't end with blank line + records.append("\n".join(current_record)) + return records + # Read records from source and target files + src_records = read_records(src_file) + tgt_records = read_records(tgt_file) + + # Generate problems using the question 
template + problems = [self.question_template.format(record) for record in src_records] + # Solutions are the raw target records (assuming no further processing needed) + solutions = tgt_records return { "problem": problems, From 779500940fb856d47ba3cea44e79f2073cc585c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E7=9D=BF=E6=99=BA?= Date: Thu, 24 Apr 2025 14:45:26 +1000 Subject: [PATCH 08/37] add: document --- docs/source/tasks/crystalrelax.rst | 70 ++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 docs/source/tasks/crystalrelax.rst diff --git a/docs/source/tasks/crystalrelax.rst b/docs/source/tasks/crystalrelax.rst new file mode 100644 index 00000000..6015afb8 --- /dev/null +++ b/docs/source/tasks/crystalrelax.rst @@ -0,0 +1,70 @@ +Crystal Relaxing +=================== + +.. currentmodule:: open_r1.tasks.relaxing + +BinaryCompoundRelaxing +------------------ + +.. autoclass:: BinaryCompoundRelaxing + :members: + :show-inheritance: + +Task Description +---------------- + +The `BinaryCompoundRelaxing` task guides a language model through multiple steps of structural relaxation on perturbed binary compounds. Given a serialized CIF description of a compound, the model must iteratively propose adjustments to reduce the internal energy, documenting its reasoning within tags and outputting a final relaxed structure within tags. + +Features +-------- + +- Reads and processes variations of SMILES notations from a dataset +- Converts varying SMILES strings into a canonical form +- Uses a template to guide the model in understanding how to format the response +- Features reward functions based on exact match and validity of SMILES + +Usage Example +------------- + +.. 
code-block:: python + + from open_r1.tasks.crystal_structure.relaxing import BinaryCompoundRelaxing + + # Initialize the task, pointing to a local dataset directory + task = BinaryCompoundRelaxing(dataset_id_or_path="/path/to/cif_data") + + # Load datasets + dataset = task.load() + train_ds = dataset["train"] + test_ds = dataset["test"] + + # Compute accuracy rewards for an example prediction + completions = ["M2S serialized_cif …"] + solutions = ["M2S serialized_cif …"] + rewards = task.accuracy_reward(completions, solutions) + + +Data Format +----------- + +The task reads paired text files with multi-line CIF records separated by blank lines: + +- `src-train.txt / src-test.txt`: Each record is a serialized CIF string of a perturbed binary structure. +- `tgt-train.txt / tgt-test.txt`: Each record is the ground‑truth CIF string after DFT relaxation. + +Reward Functions +---------------- + +1. **Accuracy Reward (accuracy_reward)** + - Sends each predicted structure (extracted via tags) together with the ground truth to a scoring server at /compute_score. + - Receives an energy‑based reward (e.g., +1 for lower energy, –4 for higher energy, –10 for invalid). + +Task Example +------------ + +This example illustrates how a perturbed (unstable) crystal structure is relaxed into a lower-energy structure: + +.. 
code-block:: text + + Input: unstable Crystal structure [M2S format] + Output: relaxed Crystal structure [M2S format] \ No newline at end of file From ff78f8fcff38f9e1a9f62cab03d98a9c1486db3e Mon Sep 17 00:00:00 2001 From: Ruizhi Xu Date: Thu, 24 Apr 2025 06:29:48 +0200 Subject: [PATCH 09/37] update: create new reward werver instead of using container --- .../AIRS_preporcess/LICENSE | 0 .../AIRS_preporcess/_tokenizer.py | 0 .../AIRS_preporcess/spacegroups.txt | 0 .../app.py} | 0 .../reward_server/launch.slurm | 20 +++ .../reward_server/requirements.txt | 6 + .../reward_logs/grpo-chem-369894.err | 10 ++ .../reward_logs/grpo-chem-369894.out | 0 .../reward_server/reward_server.py | 139 ++++++++++++++++++ 9 files changed, 175 insertions(+) rename src/open_r1/tasks/crystal_structure/{ => reward_server}/AIRS_preporcess/LICENSE (100%) rename src/open_r1/tasks/crystal_structure/{ => reward_server}/AIRS_preporcess/_tokenizer.py (100%) rename src/open_r1/tasks/crystal_structure/{ => reward_server}/AIRS_preporcess/spacegroups.txt (100%) rename src/open_r1/tasks/crystal_structure/{reward_server.py => reward_server/app.py} (100%) create mode 100644 src/open_r1/tasks/crystal_structure/reward_server/launch.slurm create mode 100644 src/open_r1/tasks/crystal_structure/reward_server/requirements.txt create mode 100644 src/open_r1/tasks/crystal_structure/reward_server/reward_logs/grpo-chem-369894.err create mode 100644 src/open_r1/tasks/crystal_structure/reward_server/reward_logs/grpo-chem-369894.out create mode 100644 src/open_r1/tasks/crystal_structure/reward_server/reward_server.py diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preporcess/LICENSE b/src/open_r1/tasks/crystal_structure/reward_server/AIRS_preporcess/LICENSE similarity index 100% rename from src/open_r1/tasks/crystal_structure/AIRS_preporcess/LICENSE rename to src/open_r1/tasks/crystal_structure/reward_server/AIRS_preporcess/LICENSE diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preporcess/_tokenizer.py 
b/src/open_r1/tasks/crystal_structure/reward_server/AIRS_preporcess/_tokenizer.py similarity index 100% rename from src/open_r1/tasks/crystal_structure/AIRS_preporcess/_tokenizer.py rename to src/open_r1/tasks/crystal_structure/reward_server/AIRS_preporcess/_tokenizer.py diff --git a/src/open_r1/tasks/crystal_structure/AIRS_preporcess/spacegroups.txt b/src/open_r1/tasks/crystal_structure/reward_server/AIRS_preporcess/spacegroups.txt similarity index 100% rename from src/open_r1/tasks/crystal_structure/AIRS_preporcess/spacegroups.txt rename to src/open_r1/tasks/crystal_structure/reward_server/AIRS_preporcess/spacegroups.txt diff --git a/src/open_r1/tasks/crystal_structure/reward_server.py b/src/open_r1/tasks/crystal_structure/reward_server/app.py similarity index 100% rename from src/open_r1/tasks/crystal_structure/reward_server.py rename to src/open_r1/tasks/crystal_structure/reward_server/app.py diff --git a/src/open_r1/tasks/crystal_structure/reward_server/launch.slurm b/src/open_r1/tasks/crystal_structure/reward_server/launch.slurm new file mode 100644 index 00000000..7506de21 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/reward_server/launch.slurm @@ -0,0 +1,20 @@ +#!/bin/bash +#SBATCH --job-name=grpo-chem +#SBATCH --ntasks-per-node=1 +#SBATCH --time=00:30:00 +#SBATCH --nodes=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=./reward_logs/%x-%j.out +#SBATCH --err=./reward_logs/%x-%j.err +#SBATCH --environment=vllm071 +#SBATCH -A a-a05 + +set -xe +source ~/.bashrc +cd /Documents/sink/src/open_r1/tasks/crystal_structure/reward_server + +FLASK_IP=$(srun --ntasks=1 hostname -I | awk '{print \$1}') +echo "Flask is up at http://$FLASK_IP:9001" + +pip install -e . 
--no-deps; +python app.py diff --git a/src/open_r1/tasks/crystal_structure/reward_server/requirements.txt b/src/open_r1/tasks/crystal_structure/reward_server/requirements.txt new file mode 100644 index 00000000..d3028214 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/reward_server/requirements.txt @@ -0,0 +1,6 @@ +Flask +pymatgen +gemmi +mace +ase + diff --git a/src/open_r1/tasks/crystal_structure/reward_server/reward_logs/grpo-chem-369894.err b/src/open_r1/tasks/crystal_structure/reward_server/reward_logs/grpo-chem-369894.err new file mode 100644 index 00000000..280cbb09 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/reward_server/reward_logs/grpo-chem-369894.err @@ -0,0 +1,10 @@ ++ source /users/ruizhi_xu/.bashrc +++ test -s /users/ruizhi_xu/.alias +++ true ++ cd /Documents/sink/src/open_r1/tasks/crystal_structure/reward_server +++ srun --ntasks=1 hostname -I +++ awk '{print \$1}' +awk: 1: unexpected character '\' +slurmstepd: error: couldn't chdir to `/Documents/sink/src/open_r1/tasks/crystal_structure/reward_server': No such file or directory: going to /tmp instead +slurmstepd: error: couldn't chdir to `/Documents/sink/src/open_r1/tasks/crystal_structure/reward_server': No such file or directory: going to /tmp instead ++ FLASK_IP= diff --git a/src/open_r1/tasks/crystal_structure/reward_server/reward_logs/grpo-chem-369894.out b/src/open_r1/tasks/crystal_structure/reward_server/reward_logs/grpo-chem-369894.out new file mode 100644 index 00000000..e69de29b diff --git a/src/open_r1/tasks/crystal_structure/reward_server/reward_server.py b/src/open_r1/tasks/crystal_structure/reward_server/reward_server.py new file mode 100644 index 00000000..2ac1e969 --- /dev/null +++ b/src/open_r1/tasks/crystal_structure/reward_server/reward_server.py @@ -0,0 +1,139 @@ +from flask import Flask, request, jsonify +import gc +import random +from io import StringIO +from pymatgen.core import Structure +from pymatgen.analysis.structure_matcher import StructureMatcher 
+import gemmi +from pymatgen.io.cif import CifWriter +from AIRS_preporcess._tokenizer import CIFTokenizer +from mace.calculators import mace_mp +from ase.io import read + +def compare_internal_energy(cif1, cif2): + atoms1 = read(StringIO(cif1),format='cif') + atoms2 = read(StringIO(cif2),format='cif') + calc = mace_mp(model="large", device='cuda') + atoms1.calc = calc + atoms2.calc = calc + energy1_total = atoms1.get_potential_energy() + energy2_total = atoms2.get_potential_energy() + + energy1_per_atom = energy1_total / len(atoms1) + energy2_per_atom = energy2_total / len(atoms2) + print(";;;;;;;;;;;;;;;;;;;;;") + print("Orginal Internal Energy:", energy1_per_atom) + print("LLM Energy:", energy2_per_atom) + print(";;;;;;;;;;;;;;;;;;;;;") + if energy1_per_atom < energy2_per_atom: + return -4 + elif energy1_per_atom > energy2_per_atom: + return 1 + else: + return -10 + +app = Flask(__name__) +cif_tokenizer = CIFTokenizer() + +def parse_llm_structure(answer_text): + """ + """ + try: + return Structure.from_str(answer_text, fmt="cif") + except Exception as e: + print("Error in parse_llm_structure:", e) + return None + +def compute_score(answer_text, ground_truth, alpha=5.0): + """ + Calculate the score based on the structure match: + Logic description: + + """ + try: + answer_text = cif_tokenizer.deserialize(answer_text, ground_truth.get("ground_truth", "")) + except Exception as e: + print("format error 1", e) + return -10 + do_print = random.randint(1, 1) == 1 + if do_print: + print("-------------- START ------------------") + print("answer_text:", answer_text) + print("ground_cif:", ground_truth.get("ground_truth", "")) + + dft_cif = ground_truth.get("ground_truth", "") + if not dft_cif: + print("No ground truth CIF content provided.") + return -10 + + try: + doc = gemmi.cif.read_string(dft_cif) + doc.check_for_missing_values() + doc.check_for_duplicates() + + doc = gemmi.cif.read_string(answer_text) + doc.check_for_missing_values() + doc.check_for_duplicates() + 
except Exception as e: + print("CIF error:", e) + return -10 + + try: + dft_structure = Structure.from_str(dft_cif, fmt="cif") + if do_print: + print("dft_structure OK") + except Exception as e: + if do_print: + print("Error parsing DFT structure:", e) + return -10 + + try: + llm_structure = parse_llm_structure(answer_text) + if llm_structure is None: + return -10 + if do_print: + print("llm_structure OK") + except Exception as e: + print("Error parsing LLM-generated structure:", e) + return -10 + reward = -5 + try: + reward = compare_internal_energy(dft_cif, answer_text) + except Exception as e: + print("**************************") + print('CALC ERROR:', e) + print("**************************") + print("-------------- END ------------------") + return reward + + if do_print: + print(f"Reward: {reward}") + print("-------------- END ------------------") + return reward + +@app.route('/compute_score', methods=['POST']) +def compute_score_endpoint(): + """ + The interface /compute_score receives POST requests, and the JSON format content needs to include: + - answer_text: CIF content string generated by LLM + - ground_truth: dictionary containing the key "ground_truth", and the value is the CIF content after DFT optimization + + Return the calculation result in JSON format, for example: + { "reward": -0.123 } + """ + data = request.get_json() + if not data: + return jsonify({"error": "No JSON data provided"}), 400 + + answer_text = data.get("answer_text", "") + ground_truth = data.get("ground_truth", {}) + + if not answer_text or not ground_truth: + return jsonify({"error": "Missing required fields: answer_text and ground_truth"}), 400 + + reward = compute_score(answer_text, ground_truth) + gc.collect() + return jsonify({"reward": reward}) + +if __name__ == '__main__': + app.run(host='0.0.0.0', port=9001, debug=True) \ No newline at end of file From abb3a1cc66cc05ba3e44504852b9bbb7c4d99f59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E7=9D=BF=E6=99=BA?= Date: 
Fri, 25 Apr 2025 01:08:11 +1000 Subject: [PATCH 10/37] add: result and validation --- .../_static/structure_relaxing_result.png | Bin 0 -> 131412 bytes .../structure_relaxing_success_rate.png | Bin 0 -> 182831 bytes docs/source/tasks/crystalrelax.rst | 30 +++++++++++++++--- 3 files changed, 25 insertions(+), 5 deletions(-) create mode 100644 docs/source/tasks/_static/structure_relaxing_result.png create mode 100644 docs/source/tasks/_static/structure_relaxing_success_rate.png diff --git a/docs/source/tasks/_static/structure_relaxing_result.png b/docs/source/tasks/_static/structure_relaxing_result.png new file mode 100644 index 0000000000000000000000000000000000000000..b6b9459116356e18b04a8dda643e0ca57a466d92 GIT binary patch literal 131412 zcmeEuWmH_t(kK?(AuzbRyEDiD0|^=uAh-lVa7oaC;O_1oAb5}f!3l#q3Bg?kg1f)T zIrpA>a=y2|x7Pc8hn}@-H@&;6rK)R}Oqhn65)LLgCISKijxrdeg@AwxMnFK!M@NO% zlpm^S!!L*~T1s*Vl_Qie_)Xv|U1dvERRnf8jE?XKkq`m-R}=Uj0wOuWqd#B-1QkSz zzhEsy)_=7@LO=+yMnL&jo0stGuTL!e3xE1wSL9s8|MZxP^zYWF;9TT?!-)C69%E7^ zWrSZa9Km`n2nZxhzrKjdT8xJX2r>xDAX#m9#GkDf2^KRRAKaRsCRgq7@3m3^1NZQI4SIS9Y5gNZHX~A7O`kuu9eCSMe*OB@z*zv1i1&XlemJ09 zd*tD7&xHQ>cJTlD(V8HK(V+aldsON`jmL2tlrsH)HASQV-(sh}$N1littb)>$FTe4 z2UVxO)r+K&6QAGv^_ykk6%ip#SeNo`o32$P?XU7$_24h1OzUI1#BwzGVD!@^UGjs;Vkna}X0nGP&(E)u%GwS*-Q__@QgitH~I_cJ<}A2W$5+gMS1?MFF02 z!m;J=-VKt2HQe2u_*yz=XJ_dxUe#yai2eiXEc5%y_s<^h87aD3AY)+S;e>FV5;S~C z`^P3i7*pI6=={9FHhG@@eDh<^j^wts zI_t>~k9kCSMN<`(6t&k!|puFPIJv~pBPvPp%RX4CC%NrE4nNVw)TpO zn)q4V_(Z?KR=HJ5Y9+IN0>JZxM$C6bPTcv@@lAsL?w=aQ;UR{$k*Vk9+A;bs zJhbTeQjRCzhh2B3N~e6+6ld92Tn-kc2oojD<1+jFDMai)V3BfvXuejhkr_5SgKM(x zo7J%+tD=i{Pq;o-Jr;^MG;ncsWehxPXgHsmoP>-<@!C-7R8ZkrVKcB&i=1NT%1rf( zPH0-{>Pl|a)HqT3a%Ei}b_;%FjoX`RD59GC9CGyY=Zp(%@AHx}DUYGbTmE#XrJXZx z6L&ouc|LLg0Bkjqsj5?H5LP23$lk0x&eig1RlC%f-hjDKBkxwVPrkA+l(nUIc5aTQ zvR`pXXDfBMTf2OmF-z~0@3}mB;FP7@9(Pvz?Mx$ zK{mko?tJ+#$tlPu8|?I-DOoA8FZeBI-)yB{k#9`%JV;82WCXlm zr+#7#YIvQR^J7$c)@-xsp(Sf=0NAT{GI#urUc%YfLw~>cw$hkz5V$@Jl5t|V} 
zXFuOUa6#M?s{LYPREhe@qA-ghfEGLm<;d5th|JBwI-ZhiT1o5dj<7;57<@x7J`de!Df(Cpj9Y1<@`K)YI>R?m=N2eiR-P; zV$RZF%-ZbLwFmIPaoNZWFmPi92#I6>MA{ErhIW>y1NDov%&XaeOF4gk4P+-m*LxcO zV&OvHfwk74W6J@h3#PF`SQg&&>uJrzx?^It} zaecdIZKQY8%|b7_-@JeB#4A#DO@BYt+qUQ;E^g&?-|L{bb3vYfOuB44>KK|@YgFf6 z{bY)MG$9X zpm$vc;C|~n+XsCR63J(cpt#@hWN{KfeM%br*g2@ueU-41gu2tEhg|n;?b~ z(WAabFLpnw4OO$ih}h7>k3`;#F@0(f(&Isod@MAY1Ek3t&-|=B&HS9`L@#PN{boyw zXI;KnGY)0YLn(?w++UoNm)%-T^<(#^F%(xVSc!?B#;7#3dn+x?y46x10oP|XI2S8G z(g;9H>*COe5R87V9*)=RXf$}cDmA@)i+UC#`=%7CTY4s2@X#21oi^EIz+s$l2NLKI zHjbz{fjEN2Jd^k%)?7U6y@yNLveUv+Y35D*JCR4Z6G@><`F)Z7rTwmVvuA?&KJLNq z4rbHN50+BS2Xn_6x@TUZ_+%SH!HdgmXA$%Y3mL&4j~IZXLt#fJw_~eU6Z6)Q8RD#< zIs;ADr!kR2r(vXeO@54*yWWw~!yi-Pa~8G>f70ZayGNU7J6<-1P}zo+m`RFt1zBa? zJ1mfcw??vbc50qm~!6yP; zWu?TRN)n9kg-rAl0ptwmHt0r_OB^wtEw|_&jJ=V$=cLm5$WtTx zp7S*W@u%1dfIdKBUFK&A(Sz40#791ESSMTWECk|ALV9m~F z;cs+VpQk{VxS8_R-}33!Z3$4DvLXb&U9%3nAY4l8{$8QK79@u-iWRQ7ts~eSI^mNi zO7Kmt5a-5$UB&ciIJqMIrqTll=xCQnf4$bsswl!(S1^{_@#;~*b0KlvWrtWg6S)@F z^gijtknH{JlGe^u={pU9;FGG?<&jnnZZADMqD;IB zxQ_#^nM3_FFrd?B%*xKYIsYWwj67LqE_9!aKo#`4y6VaF2KBM6;YK!Wu`A4qs0ZT_ zp+f3Wo*;Zza6{aWwY$c?&nQ3=>5i;1)djPuqD_b9(e%nmvVruHjijua-OV9)&A1oX z$mpX~Yd=1E#qfneTYJ&F`EFZsOPn7(gqNMwrdc3VFvLqes)^RY0mBq*Cdq)`w8SSw zGPu46cX)D8Ic6ya7C@K_tjP-a*qOVTfZz+I2p=CcKi9}}3C1t^2e}mZN#4&t89zkb zgZ&1&&e~eNCWVW_6?9-J23Jshhu$vW2xnhW-w9=ik>mPdUaSBalQ`hhlG4QUF`C%X zIGmYkjr(%C!-Y4Ok+J3fOiezkPSG zNM-#tE-0?XGz?de=?<@zZkO0b$6 z6J4x)4Pi0Q+#Xcm>hT*h=Y7+OyTzMLW#Oy)bKb9Z`W`^z%}eZbH(6}zyJr6L;?0O; z!ZwGa>^TjfvEC6|9Pxr&Dka$d?3%(MaOBBHv;2jooNdT;0#TUI#(=(YcY1UWXk%WT2Py3zIeG zEM4f&svh}HFbS!I=8Wgxu5^d2r4M@iSk^pW2}^)gd|p!Pji$m{L(A4j|6s6|W|gfH z!xm$gfQ=Mr*M7ZIhPmJ)5@>|YI3?k69L6$X9wSA9P4Le5vVaE7g_!=0z2w!`H8Xcc z%vtlH6+rkLjifs-Hfu_yff`Ty?IdrbNs0R=tXRI&^Crt}5{-P4j7{79W2YsLN)!}I zFBzPEvMOBObEa5iCdcNv6UHKsI8UbtQuKEG>3)4Q$WxrakKpJlcbL5DzZyh<)vb^c zX%bS@m*<{WNI8AC`pB4=6dKqOLl^87LI{m;NgLgDXh~c2J{jiyI%Q4pP`y7_ob4f= 
zh|^LLUCL8A>=K;vxJ&AS;t`5Gh6>VvWo`E7tZ>$qCjF0{ltt@Yk0ZP{l5&a$#c4rb|LzSedNM2V0O5RaZB25r!cORabW;fw-;F%p51%8uv(^XuB; zrB2SHM?bns0?k|wJN!;tvNc5~MIVMmmcqfz`R0e1^Bg1U+cUW@@Dvj-fS0tohRyCx zjnh&J`|AvzYAaqx`93d%NSx`+mshNkLjjU2mosHNGcup1Tvqn}^WKQy=ays!^L;4r zO=z%NAUvM!e^zYHN2W(g_Mim3Qub8E+U2Y>2b-8uE+EaErLt0DcQ+haGQqGqDzC>_`Z>W|fefr^Hr~NkI^z@l}km-uQ zg=Xbd0L?usHVFVrJZS>%9bkdLp0ua>?CJP2VWW^HZo;bWuN~wzzi#`FBW#iHOzg$s zUg)=9QNW5ueVFa}afN(uS$xTzPks99`no{Pw~Gh2Y5=g~_H~+T^1~^qFImJXXC+Zj zRQat}`;7z)3wYt;c3Yi`Yyy_)b-zl|nV^_JbWFx%5$XTIIqQ4IVLX1*85`wGPo-pL zw@pyp$}C1fwVKj?ym89We)}Be6(Yy&qu>0)W&&f$;{Za25^n($?uVwV>7L5DvSV}7 zPy#E_FPnJrAva%^Fcj657Z=+g`UI_9MtvI((=-ptefI_leaS4;zMPPg3twRVT$&^# zFNW4diSf%RN2DO|as_V^jfN77jIxzbi9<`V0J-I|tC*qr#t`;;t0;yyCWmxz>AqqK zyY+D!V9*G%cXqn9bp67S&FW#;;nwH+_6bi=0L@$V`%Q_ZaUrzHcVypqgAcuZ;K+C& z8!0|)`*@zmG!HjGoUGtxb`UMWnLl@s=2rWi!~O2*&tmDvz1K9qJQXEONW!sG6APyX zPjfNR@nL>w;z;iTUvs>9ALP+usLN?d5iu!9PxX>G>o7Euu@EZv+Cnz9rIvHb=mlfn z_ma}1qua8~!BuWdRK>?6)q2a^`6A@zd^A-#pj~YNWxpyqH23+D9?x!USL$6 z2CxvXOhy9fA8{K-_q>5NIvhxQFz^-oqEG zjfThg0f^>#o2Uc0E)*$l#F(hC#&9o9_xGLnh-te^v0bvhkpw}=Rh%uoqQOD;70sxc z7?ByVuGnUU_4IN<`5Trh(qXJud5RZ2DA|pew^V2HN;!dQ}v~1N$*=&+85>1!8%{c4D1u zC4dukoVbBQQ*eAh@4Sc8CRs3_V z6B#Q&DSEZTG4LqBD;$kF91jRAM3NzsRo=8hUX*MHEb(F<*{&b(J-}rI*~zMrlkp{4 zK9#gLZ?-BiU7~gVvl6ihx>sat(=Y6LO~N{{3M_AjDm$3O&w>Au3)P(%w8acD1;(_K zalsf1h*JKc!8G^VN_h{GH>FneW+~gQltr<3#-6a+_y~u>9QJDjv9uE8b$ybhFn3za z3}Hb@Z4mVnpGEE&&$;h*@w+VRiKj$ja)x+oSjQB{jJW*3L?;v$O3aW8l@C^(SL!mzhgyJ5@&;SY6y1~367iN8XAiL_u7)noVTwq~#SDbYUWmK%8 z@zqS(N@BqL@cf0%Pf<>~&qRDtwIsM2gJN-zRlfA;Qt+q*r*KvKktaaro$XvUbxHvs z?(sG?&%q)t`h+Y?#EQM6CUlMOjfqA_7z{3()r_x>Q%JEHBcpa* zV$Szh7T-Z@L{sBb5%nkymrI~St=Hh8wI;2S0r<(6-u#w^Z<80;XrDlUDTUI@m-$d) z)A*$H3vU}G@F*`@jEQQ~K)ie_EOBf;Wz_pR$bE!2-Ly`Esko7TXVY^VUz9eV!^I)( z8KAq6^u8N-Zw|^?7K93nK#Y__uD{-$qW;0wW^lIFJ{mq6duk#@*dgki13$%A-b>95suI7q( za5^Z*ldUOI&nqpZdMclSa<>P&+RDa3bOph})xf%-v4wgYH}tZb=Mnx27Wv4B%d&^P z-82VYBx8D<_f4a|+}OAvB$HU#*Ob~jCBqqTI~m0RdCM1FN@G<9gX<2zyOkjMM8!Z* 
zjyRP+H-rf1eVqTP3h!6P>bHa650}NA*B#MEBI9c4ZwcQ-`>kqv#I&;JrSy0Ox<^uS z<}x6ptj`3B0Injf0XhJs_0LI;##MWq$KCOV_Sg*SQdc1N%5j=55y&S;yb}Q}W_&20 zuHymzj$WOTuGLYNS1{pa(`YMkFa^ERasFK$iap7Va`$)F3cf^kVh(Ifz%vbO`~f~0 zCn;zEezcQ+lTo5(XidxE{E3zf4F(fVH~Jg^4T=qoI4V{3m%50PZY;LtnZT!1n-^)> zNVb`pSn&{JQJ2)pBWOLcUf3yE?9?~*)*L* z$wz^lF;qbAfRU>{)nRw(lQhzR-BBZm>Mb{Dog|!VKnSQ`I4x&AuqH$ITa<-v8UoBP z4Df?piIQaB`1+?+_mB53qw4gY>rrEm?;{+11h6qVU$&p@OTfDEl3ps0nD@DZSl5OE zZf|dEM}tnDUnOl>Zpq!tpWmHTJ*4NWXJfP$HWjhabA>B8HTNUjRj!k$WXz#*2uZ*0Q8Q;9@iwKZMhN<75J}5ujtKU2H*NCid z3lJWQL@QB-*q@%{Ja}*OamJpQ-2!6?(53dL+wb2hNVVVxXCFK)EGdC5xX2KAt8=|* zIjS0l=al4_ndO4SYEmD((uvGv8XoX`&#gY<9j}EhLrt`c%aQwPL7Q>M?~=p4qQ)34HLjF^zWTPqL_%D zgr7v?&e{)ya$>f&93I1trXQ{xG``VG&4>_h_;r8ty{vuU$($cUCSy=cO7Xp(bVqgP zmM$n^h_v;+xOkv7AD020wUN|*c@b8RS$egsJy>dk!fpb7SzgMK;o_k@AY+{_y~C4U1|*}SW96Rvnf-t> z0>-&R?_K+2Xq78lL!a1AHk?_~qhwbGgn%{XF0*oF1VW13xy>IlB8k--V&Wo=*x;0@ zWsTBAEh=QpqkyZ#Dgh`yB;5pGVG$?Hu9DP7)0KvJ1t>zI`0@Q)19HBcBZDRZA|s90 zEPG>mYy`d>egHh~mRI@Q0oN?#{-o2Eap@P|grSjLeuofcp3O)z*%4imHLq8Ww3tYA zexgeFzuAtP%fWwpj}H)xNxyu&$%e@+GJ_(F223H zucPHqW_Ndv5aKLJZC7aSPbMcAjJZ#>wl1yZ7C7|+Bx!;o)2F1&_%-83CV)uj2`@e` zPhwcvbX8C>DH5=DOa!)Pi zs=pDX=V7R#YhU^4tQU8{5HB{)RzfOUhjGm8%zw-u)#tq48~eNnEAIloVyG|fkwj5@ z9>=mtu#tiVikiOY&*hRJbAo0p*vETgGNyAjB!4jBgr^f1wiMhbsD3bb30L>6s9o7A zm5lpx`YOxboP-dl#(2lzzLzIM#Xm)+j443_NA<7us)N?iG`Fikn+hZL4mLXJ{Q2J~ z%+shn!I4NcabLB z^4oPDZK7zXUJAITx9_FcTv>$F_i=ub4L%Y}Yd!j&kvDxrX}WFlbDK!yRbQt`S4%1p zlsiQm7|`XY#u+%d3utMv zCqju$YLe+^rV`w@vfR2SgSA!AE#2#EzdND3pa*T>z5c5Hp<8Pw{^+JZ|ii@s~v?==pNcPGh=*e(8>phfWJ% z@J_x^Un(=y{cs>)HCm$9a}#His;$HMlwo~s7-Za%ii{-G#cOVXJ5<5QWcf54NsccZ z8A%jLBFfGqPm1&dk1Z|XJAw`-#z`I?=@hKb^VF4op8`#WBNTi$l^F;14_b(>-s*M4 z*bIWZ?}agLS%IK(UTG+(4riR*prFOHTF&D=*C_OCH((q%@Wj~I7@#yeCqB{B8Nm)P z0>$#uBoc21`1c5jQ5T>nB5xuwkSlDFifn=*>0~ugUQm}jQUX;DA-QXhN2G=UyUvix z{!EX@E}Yc%^bxEY5BImNJy;QUk`h$u$O9PW%eUTDuz;LL4F!(Nkls-gAcbJ0DKE=v zz#1>u)O}U#Yr$SIKR?{N=6VgLyD60S+UaK@t@Jxq1>ucK82Ls9B_aSyFVj}6*nkdv 
zxNK539?3eZTX+9anKgD%j=O-z?8ltW44=IVsoB{^WY(VDv~gRpy>+WMx4E%%1!K*br2rAOua4jS8C(?4nZD_Ke+vMLqrF8J zL#6U>>f##3xbr{N+)kt!NPQnbtypN8rG&}9zCPcG8fDX+scmWyL0Ya{@dCSkOn(L! zQaO=aV$IIpye`oLSok+Bsg-~P-FF+o>qXy1_}vLx0v!MqLj%OT_iRKv900wN6- zOBrZ<;6Y1%3+P2~4J4goRoB{3@_dWQEXm-H^jnY!*zr7;Cv zsJT_f&f8_7(>jzlB$Z(*bgMxMR@a?GktNvJ^Z;e>MG$SY6~PJpw-=@@aT=2Cj@eaT z<|S16pL6A|bq1hK`Q6=^8-SF@fC^Na)Tv$oX+{)=P@{9{&f+$H_YR??Z!4HEGl~R1l`SzSu7(A0yiUyRNG@>kXR%kDl(Jt5*eUEdCft z#X?rLVtoX_(7-P#@H54TDF*mbH(hT@JZqj)1YR|*%=nR-uQp5iI)XGYf)p;8TB?W* z1N@K`)|NX1JU?ul6?$UT zZ7OEvk=$(6^33+RhV_|y@861f{A{H$@1T_%talFW^HA)ilJXRc8evADQl)79W zoEO-jYaJ1F-}|?CLK6ifJuD0~7Ou@Tza>s&q+?(`FWfrRPST7&j>|cF5`Bbtg(a4% ztzM{leH3FH?OIw~M7rFmJ^;5Ye@7qF9E*9~_p#8HUV(?7LCr|C%LIEl9q=kPjSM+B zK|rM@$RMhg(7MDdk0Q^x{{{M#s3P5{?{yGqqUFx(>sg)l^!v@ZrMSHHS5Z1=PZ`^5 z9o4@MjCU*LUwUg6qZD-1X)QJc-S^h(OxusfHH|%3k}-v5Ta6S#=qq^n`Pu4Y&6U{( zpN3AcPNQ)csY5#O6{j@3*W(ISSkY4UhZp#r&lK5ff?VE8o-B(SVQAwu zZY13><_S&}F^1L57xd1go>wU5EiZIAdx