diff --git a/backend/protzilla/constants/cif_columns.py b/backend/protzilla/constants/cif_columns.py new file mode 100644 index 000000000..5715f2f32 --- /dev/null +++ b/backend/protzilla/constants/cif_columns.py @@ -0,0 +1,63 @@ +from enum import StrEnum + + +ATOM_SITE_PREFIX = "_atom_site." + + +class ATOM_SITE_COLUMNS(StrEnum): + """ + Enum containing all column names that should be present in + the _atom_site. table for mmCIF files from PDB or AFDB + """ + + ID = f"{ATOM_SITE_PREFIX}id" + TYPE_SYMBOL = f"{ATOM_SITE_PREFIX}type_symbol" + LABEL_ATOM_ID = f"{ATOM_SITE_PREFIX}label_atom_id" + LABEL_ALT_ID = f"{ATOM_SITE_PREFIX}label_alt_id" + LABEL_COMP_ID = f"{ATOM_SITE_PREFIX}label_comp_id" + LABEL_ASYM_ID = f"{ATOM_SITE_PREFIX}label_asym_id" + LABEL_ENTITY_ID = f"{ATOM_SITE_PREFIX}label_entity_id" + LABEL_SEQ_ID = f"{ATOM_SITE_PREFIX}label_seq_id" + PDBX_PDB_INS_CODE = f"{ATOM_SITE_PREFIX}pdbx_PDB_ins_code" + CARTN_X = f"{ATOM_SITE_PREFIX}Cartn_x" + CARTN_Y = f"{ATOM_SITE_PREFIX}Cartn_y" + CARTN_Z = f"{ATOM_SITE_PREFIX}Cartn_z" + OCCUPANCY = f"{ATOM_SITE_PREFIX}occupancy" + B_ISO_OR_EQUIV = f"{ATOM_SITE_PREFIX}B_iso_or_equiv" + PDBX_FORMAL_CHARGE = f"{ATOM_SITE_PREFIX}pdbx_formal_charge" + AUTH_SEQ_ID = f"{ATOM_SITE_PREFIX}auth_seq_id" + AUTH_COMP_ID = f"{ATOM_SITE_PREFIX}auth_comp_id" + AUTH_ASYM_ID = f"{ATOM_SITE_PREFIX}auth_asym_id" + AUTH_ATOM_ID = f"{ATOM_SITE_PREFIX}auth_atom_id" + PDBX_PDB_MODEL_NUM = f"{ATOM_SITE_PREFIX}pdbx_PDB_model_num" + + +ATOM_SITE_LABEL_COMP_ID = ATOM_SITE_COLUMNS.LABEL_COMP_ID + +ATOM_SITE_COLUMNS_NUMERIC = [ + ATOM_SITE_COLUMNS.ID, + ATOM_SITE_COLUMNS.LABEL_SEQ_ID, + ATOM_SITE_COLUMNS.CARTN_X, + ATOM_SITE_COLUMNS.CARTN_Y, + ATOM_SITE_COLUMNS.CARTN_Z, + ATOM_SITE_COLUMNS.OCCUPANCY, + ATOM_SITE_COLUMNS.B_ISO_OR_EQUIV, + ATOM_SITE_COLUMNS.AUTH_SEQ_ID, +] + +CHEM_COMP_PREFIX = "_chem_comp." + + +class CHEM_COMP_COLUMNS(StrEnum): + """ + Enum containing all column names that should be present in + the _chem_comp. table for mmCIF files from PDB or AFDB + """ + + ID = f"{CHEM_COMP_PREFIX}id" + TYPE = f"{CHEM_COMP_PREFIX}type" + MON_NSTD_FLAG = f"{CHEM_COMP_PREFIX}mon_nstd_flag" + NAME = f"{CHEM_COMP_PREFIX}name" + PDBX_SYNONYMS = f"{CHEM_COMP_PREFIX}pdbx_synonyms" + FORMULA = f"{CHEM_COMP_PREFIX}formula" + FORMULA_WEIGHT = f"{CHEM_COMP_PREFIX}formula_weight" diff --git a/backend/protzilla/constants/data_types.py b/backend/protzilla/constants/data_types.py index 2665cfe3e..94f4db695 100644 --- a/backend/protzilla/constants/data_types.py +++ b/backend/protzilla/constants/data_types.py @@ -23,7 +23,7 @@ class DataKey(StrEnum): GENE_MAPPING_DF = "gene_mapping_df" CIF_DF = "cif_df" AMINO_ACID_SEQUENCES_DF = "amino_acid_sequences_df" - PAE_DF = "pae_df" # pae = predicted aligned error + PAE_MATRIX = "pae_matrix" # pae = predicted aligned error PLDDT_DF = "plddt_df" # plddt = predicted local distance difference test CROSSLINKING_DF = "crosslinking_df" CONFIDENCE_DF = "confidence_df" diff --git a/backend/protzilla/constants/option_types.py b/backend/protzilla/constants/option_types.py index 678efe9cb..30b2a61bc 100644 --- a/backend/protzilla/constants/option_types.py +++ b/backend/protzilla/constants/option_types.py @@ -60,6 +60,13 @@ class PValueColumnName(StrEnum): ptm = "PTM" +class CrosslinkingValidationCriterion(Enum): + manual_bounds = "Manual Bounds (set below)" + max_pae = "CL length +/- maximum PAE between sites" + min_pae = "CL length +/- minimum PAE between sites" + plddt_adjusted = "plDDT adjusted" + + FC_SIGNIFICANCE_COLUMNS = ["Protein ID", "fc_z_score", "fc_significance"] CORRECTED_P_VALUES_COLUMNS = [ "Protein ID", diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 4f02ee3bb..8d8b765df 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -1,14 +1,15 @@ import itertools import ast import math -from pipes import stepkinds +from typing import Callable + +from backend.protzilla.constants.option_types import CrosslinkingValidationCriterion import pandas as pd import numpy as np import re import logging -from pandas.io.stata import stata_epoch import plotly.graph_objects as go from plotly.graph_objects import Figure @@ -371,6 +372,9 @@ def monomer_validation( crosslinker_information: dict[str, list[float]], cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, + pae_matrix: np.ndarray[tuple[int, int]], + plddt_df: pd.DataFrame, + validation_criterion: CrosslinkingValidationCriterion, ) -> dict: """ Validates crosslinking data for a monomeric protein structure by checking @@ -382,6 +386,8 @@ def monomer_validation( allowed distance boundaries (e.g., [min_dist, max_dist]). :param cif_df: DataFrame containing mmCIF information. :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. + :param pae_matrix: NumPy 2D array containing AlphaFold PAE data. + :param plddt_df: DataFrame containing AlphaFold pLDDT data. :return: A dictionary containing the validation results and distance metrics. """ protein_id = structure_metadata_df["uniprot_accession"].iloc[0] @@ -392,9 +398,12 @@ def monomer_validation( structure_metadata_df=structure_metadata_df, cif_df=cif_df, amino_acid_sequences_df=amino_acid_sequences_df, + pae_matrix=pae_matrix, + plddt_df=plddt_df, valid_ids=valid_ids, id_column_name="_atom_site.pdbx_sifts_xref_db_acc", structures_to_validate=[protein_id], + validation_criterion=validation_criterion, ) @@ -453,6 +462,47 @@ def get_valid_ids_per_protein_id_from_job_request( return valid_ids +def get_global_residue_index( + position_within_protein: int, # 1-based index + chain_id: str, + cif_df: pd.DataFrame, +): + """ + For multimer PAE lookup: For a position within a given protein in a chain, + get the global 0-based residue index used to find that position in the PAE matrix. + + Note: This assumes that the order of AAs in the _atom_site table corresponds + to the order of residues in the pae matrix and thus the other residue-based tables in + the cif. + + :param position_within_protein: index of the amino acid within the protein (1-based) + :param chain_id: the chain ID of the protein within the complex + :param cif_df: DataFrame containing the _atom_site table of the complex structure + """ + + # Get table with only unique chain and sequence IDs and infer global index + index_lookup_df = ( + cif_df[["_atom_site.label_asym_id", "_atom_site.label_seq_id"]] + .drop_duplicates() + .reset_index(drop=True) + ) + index_lookup_df.reset_index(inplace=True) + + index_lookup_df = index_lookup_df[ + index_lookup_df["_atom_site.label_asym_id"] == chain_id + ] + index_lookup_df = index_lookup_df[ + index_lookup_df["_atom_site.label_seq_id"] == position_within_protein + ] + + if len(index_lookup_df) != 1: + raise ValueError( + "Invalid input: CIF contains multiple atoms mapped to same chain/sequence ID pair!" + ) + + return index_lookup_df["index"].iloc[0] + + def multimer_validation( crosslinking_df: pd.DataFrame, structure_metadata_df: pd.DataFrame, @@ -460,6 +510,9 @@ def multimer_validation( cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, job_request_df: pd.DataFrame, + plddt_df: pd.DataFrame, + pae_matrix: np.ndarray[tuple[int, int]], + validation_criterion: CrosslinkingValidationCriterion, ) -> dict: """ Validates crosslinking data for a multimeric protein complex by checking @@ -477,6 +530,8 @@ def multimer_validation( :param cif_df: DataFrame containing mmCIF information. :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. :param job_request_df: DataFrame containing the loaded AlphaFold job request JSON. + :param plddt_df: DataFrame containing per-residue pLDDT values. + :param pae_matrix: NumPy 2D array containing the PAE values for each residue pair. :return: A dictionary containing the validation results and distance metrics. """ valid_ids = get_valid_ids_per_protein_id_from_job_request( @@ -493,6 +548,9 @@ def multimer_validation( valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + pae_matrix=pae_matrix, + plddt_df=plddt_df, + validation_criterion=validation_criterion, ) @@ -502,9 +560,12 @@ def validate_with_angstrom_deviation( structure_metadata_df: pd.DataFrame, cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, - valid_ids: dict, + valid_ids: dict[str, list[int]], id_column_name: str, structures_to_validate: list, + validation_criterion: CrosslinkingValidationCriterion, + plddt_df: pd.DataFrame | None = None, + pae_matrix: np.ndarray[tuple[int, int]] | None = None, ) -> dict: """ Validates crosslinks by comparing the crosslinker lengths with the distances between the linked @@ -516,6 +577,8 @@ def validate_with_angstrom_deviation( :param crosslinker_information: Dictionary mapping crosslinker names to a list of three floats: [crosslinker_length, upper_accepted_deviation, lower_accepted_deviation]. :param cif_df: DataFrame containing CIF information (predicted coordinates of all the protein's atoms). + :param plddt_df: DataFrame containing the local AlphaFold pLDDT values for each residue. + :param pae_matrix: NumPy 2D array containing the PAE values for each residue pair. :param amino_acid_sequences_df: Dataframe that contains all known amino acid sequences. :param valid_ids: Dictionary mapping protein IDs to their valid chain/entity identifiers in the CIF data. :param id_column_name: The column name in the cif_df to use for matching against valid_ids. @@ -575,6 +638,47 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: amino_acid_sequences_df=amino_acid_sequences_df, protein_id=protein_id2 ) + def get_site_plddts(crosslink: pd.Series): + if plddt_df is None: + return np.nan, np.nan + + plddt_at_position1 = float( + plddt_df.query( + "residueNumber == @crosslink.crosslinker_position1 and " + + "chainID == @crosslink.Chain_id1" + ).iloc[0]["confidenceScore"] + ) + plddt_at_position2 = float( + plddt_df.query( + "residueNumber == @crosslink.crosslinker_position2 and " + + "chainID == @crosslink.Chain_id2" + ).iloc[0]["confidenceScore"] + ) + + return plddt_at_position1, plddt_at_position2 + + def get_paes(): + if pae_matrix is None: + return np.nan, np.nan + + pae_index_pos1 = get_global_residue_index( + crosslink.crosslinker_position1, crosslink.Chain_id1, cif_df + ) + pae_index_pos2 = get_global_residue_index( + crosslink.crosslinker_position2, crosslink.Chain_id2, cif_df + ) + pae_x_position1 = pae_matrix[ + pae_index_pos1, pae_index_pos2 + ] # Using position1 as scored residue + pae_x_position2 = pae_matrix[ + pae_index_pos2, pae_index_pos1 + ] # Using position2 as scored residue + + return pae_x_position1, pae_x_position2 + + plddt_at_position1, plddt_at_position2 = get_site_plddts(crosslink) + pae_x_position1, pae_x_position2 = get_paes() + predicted_distance = get_distance_between_two_amino_acids_in_angstrom( amino_acid_position1=crosslink.crosslinker_position1, amino_acid_position2=crosslink.crosslinker_position2, @@ -595,13 +699,68 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: f"Missing required information regarding crosslinker length " f"and/or accepted deviation for crosslinker '{crosslink.Crosslinker}'." ) - # Fallback to default deviation bounds when not explicitly provided - accepted_distance_lower_bound = crosslinker_length - ( - accepted_deviation_lower_bound or crosslinker_length - ) - accepted_distance_upper_bound = ( - accepted_deviation_upper_bound or float("inf") - ) + crosslinker_length + + accepted_distance_lower_bound: float = 0.0 + accepted_distance_upper_bound: float = 0.0 + + match validation_criterion: + case CrosslinkingValidationCriterion.manual_bounds.value: + # Fallback to default deviation bounds when not explicitly provided + accepted_distance_lower_bound = crosslinker_length - ( + accepted_deviation_lower_bound or crosslinker_length + ) + accepted_distance_upper_bound = ( + accepted_deviation_upper_bound or float("inf") + ) + crosslinker_length + + case CrosslinkingValidationCriterion.max_pae.value: + if np.isnan(pae_x_position1) or np.isnan(pae_x_position2): + raise ValueError("No PAE data given.") + + pae_tolerance = max(pae_x_position1, pae_x_position2) + accepted_distance_lower_bound = float( + max(crosslinker_length - pae_tolerance, 0.0) + ) + accepted_distance_upper_bound = float( + crosslinker_length + pae_tolerance + ) + + case CrosslinkingValidationCriterion.min_pae.value: + if np.isnan(pae_x_position1) or np.isnan(pae_x_position2): + raise ValueError("No PAE data given.") + pae_x_position1, pae_x_position2 = get_paes() + pae_tolerance = min(pae_x_position1, pae_x_position2) + accepted_distance_lower_bound = float( + max(crosslinker_length - pae_tolerance, 0.0) + ) + accepted_distance_upper_bound = float( + crosslinker_length + pae_tolerance + ) + + case CrosslinkingValidationCriterion.plddt_adjusted.value: + if np.isnan(plddt_at_position1) or np.isnan(plddt_at_position2): + raise ValueError("No pLDDT data given.") + + get_plddt_factor: Callable[[float], float] = lambda plddt: 1 - ( + plddt / 100 + ) + + plddt_factor_pos1 = get_plddt_factor(plddt_at_position1) + plddt_factor_pos2 = get_plddt_factor(plddt_at_position2) + + max_half_tolerance = crosslinker_length # Note: This is quite lenient + tolerance_pos1 = plddt_factor_pos1 * max_half_tolerance + tolerance_pos2 = plddt_factor_pos2 * max_half_tolerance + + accepted_distance_lower_bound = max( + crosslinker_length - tolerance_pos1 - tolerance_pos2, 0 + ) + accepted_distance_upper_bound = ( + crosslinker_length + tolerance_pos1 + tolerance_pos2 + ) + + case _: + raise ValueError("Invalid validation strategy") valid = ( accepted_distance_lower_bound @@ -615,6 +774,10 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: "valid_crosslink": valid, "crosslinker_position1": crosslink.crosslinker_position1, "crosslinker_position2": crosslink.crosslinker_position2, + "plddt_at_position1": plddt_at_position1, + "plddt_at_position2": plddt_at_position2, + "pae_x_position1": pae_x_position1, + "pae_x_position2": pae_x_position2, } ) @@ -624,6 +787,10 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: "valid_crosslink", "crosslinker_position1", "crosslinker_position2", + "plddt_at_position1", + "plddt_at_position2", + "pae_x_position1", + "pae_x_position2", ] relevant_crosslinks_df["crosslinker_position1"] = relevant_crosslinks_df[ @@ -907,69 +1074,76 @@ def diagrams_of_crosslinking_validation_data( def monomer_diagrams( - crosslinking_df: pd.DataFrame, + output_crosslinking_result_df: pd.DataFrame, structure_metadata_df: pd.DataFrame, crosslinker_information: dict[str, list[float]], - cif_df: pd.DataFrame, - amino_acid_sequences_df: pd.DataFrame, + validation_criterion: CrosslinkingValidationCriterion, ) -> list[Figure]: """ Generates visual diagrams to evaluate crosslinking validation results for a monomeric protein structure. - This function acts as a wrapper that first runs the crosslink validation - step via `monomer_validation`. It then extracts the resulting dataframe - of validated crosslinks and passes it to the diagram generator to create - the final plots. - - :param crosslinking_df: DataFrame containing the full set of crosslinks. + :param output_crosslinking_result_df: DataFrame containing the CL validation results. :param structure_metadata_df: DataFrame containing structural metadata; the first row's 'uniprot_accession' is used as the target. :param crosslinker_information: Dictionary mapping crosslinker names to a list of three floats: [length, upper_bound, lower_bound]. - :param cif_df: DataFrame containing parsed mmCIF structural coordinate data. - :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. + :param validation_criterion: The validation criterion used for validation. :return: A list of Figure objects visualizing the crosslinking validation data. """ structures_to_validate = [structure_metadata_df["uniprot_accession"].iloc[0]] - validated_df = monomer_validation( - crosslinking_df, - structure_metadata_df, - crosslinker_information, - cif_df, - amino_acid_sequences_df, - )["crosslinking_result_df"] - return diagrams_of_crosslinking_validation_data( - validated_df=validated_df, - structures_to_validate=structures_to_validate, - crosslinker_information=crosslinker_information, - ) + + match validation_criterion: + case CrosslinkingValidationCriterion.manual_bounds.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + # TODO: Separate Issue #429 + case ( + CrosslinkingValidationCriterion.max_pae.value + | CrosslinkingValidationCriterion.min_pae.value + ): + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + # TODO: Separate Issue #429 + case CrosslinkingValidationCriterion.plddt_adjusted.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + case _: + return [] def multimer_diagrams( - crosslinking_df: pd.DataFrame, - structure_metadata_df: pd.DataFrame, + output_crosslinking_result_df: pd.DataFrame, crosslinker_information: dict[str, list[float]], - cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, job_request_df: pd.DataFrame, + validation_criterion: CrosslinkingValidationCriterion, ) -> list[Figure]: """ Generates visual diagrams to evaluate crosslinking validation results for a multimeric protein complex. This function parses an AlphaFold job request to determine the valid chain - compositions. It then runs `multimer_validation` to filter and validate - the relevant crosslinks, extracting the result to generate structural - distance and validation plots. + compositions and uses the passed result from the validation. - :param crosslinking_df: DataFrame containing the full set of crosslinks. - :param structure_metadata_df: DataFrame containing structural metadata. + :param output_crosslinking_result_df: DataFrame containing the CL validation results. :param crosslinker_information: Dictionary mapping crosslinker names to a list of three floats: [length, upper_bound, lower_bound]. - :param cif_df: DataFrame containing parsed mmCIF structural coordinate data. :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. :param job_request_df: DataFrame containing the loaded AlphaFold job request JSON. + :param validation_criterion: The validation criterion used for validation. :return: A list of Figure objects visualizing the crosslinking validation data. """ valid_ids = get_valid_ids_per_protein_id_from_job_request( @@ -977,20 +1151,35 @@ def multimer_diagrams( ) structures_to_validate = list(valid_ids.keys()) - validated_df = multimer_validation( - crosslinking_df, - structure_metadata_df, - crosslinker_information, - cif_df, - amino_acid_sequences_df, - job_request_df, - )["crosslinking_result_df"] - - return diagrams_of_crosslinking_validation_data( - validated_df=validated_df, - structures_to_validate=structures_to_validate, - crosslinker_information=crosslinker_information, - ) + match validation_criterion: + case CrosslinkingValidationCriterion.manual_bounds.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + # TODO: Separate Issue #429 + case ( + CrosslinkingValidationCriterion.max_pae.value + | CrosslinkingValidationCriterion.min_pae.value + ): + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + # TODO: Separate Issue #429 + case CrosslinkingValidationCriterion.plddt_adjusted.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + case _: + return [] # Warning: Mostly AI generated diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index 4fd343fac..38f4d1307 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -11,15 +11,24 @@ from datetime import datetime, timezone import gemmi import pandas as pd +import numpy as np +import ast import requests import re from backend.protzilla.constants import paths from backend.protzilla.constants.protzilla_logging import logger +from backend.protzilla.constants.cif_columns import ( + ATOM_SITE_PREFIX, + ATOM_SITE_COLUMNS, + ATOM_SITE_COLUMNS_NUMERIC, + CHEM_COMP_PREFIX, + CHEM_COMP_COLUMNS, +) from backend.protzilla.importing.fasta_import import fasta_import from backend.protzilla.networking import download_file_from_url from backend.protzilla.utilities.utilities import copy_file_to_directory -from backend.protzilla.steps import OutputItem, OutputType +from backend.protzilla.steps import Output, OutputItem, OutputType def get_monomer_metadata_df() -> pd.DataFrame: @@ -108,26 +117,54 @@ def read_alphafold_mmcif(path: Path) -> pd.DataFrame: block = doc.sole_block() - cat_name = "_atom_site." - if cat_name not in block.get_mmcif_category_names(): + if ATOM_SITE_PREFIX not in block.get_mmcif_category_names(): return pd.DataFrame() - table = block.find_mmcif_category(cat_name) - - columns = list(table.tags) - nrows = len(table) - data = {} - for j, col in enumerate(columns): - col_values = [] - for i in range(nrows): - row = table[i] - if j < len(row): - col_values.append(row[j]) - else: - col_values.append(None) - data[col] = col_values + atom_site_table = block.find_mmcif_category(ATOM_SITE_PREFIX) + + atom_site_df = pd.DataFrame( + list(atom_site_table), + columns=list(atom_site_table.tags), + dtype=pd.StringDtype(), + ) + + # convert to numeric dtype for numeric columns present in the dataframe + present_numeric_columns = [ + column for column in ATOM_SITE_COLUMNS_NUMERIC if column in atom_site_table.tags + ] + atom_site_df[present_numeric_columns] = atom_site_df[present_numeric_columns].apply( + pd.to_numeric, errors="coerce" + ) + + atom_site_df = atom_site_df.convert_dtypes() + + if CHEM_COMP_PREFIX not in block.get_mmcif_category_names(): + raise ValueError( + f"Required table with prefix {CHEM_COMP_PREFIX} not found in {path}" + ) + + chem_comp_table = block.find_mmcif_category(CHEM_COMP_PREFIX) + + chem_comp_df = pd.DataFrame( + list(chem_comp_table), + columns=list(chem_comp_table.tags), + dtype=pd.StringDtype(), + )[[CHEM_COMP_COLUMNS.ID, CHEM_COMP_COLUMNS.MON_NSTD_FLAG]] - return pd.DataFrame(data) + # convert flags to native booleans + bool_map = {"y": True, "n": False, ".": pd.NA} + + chem_comp_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG] = ( + chem_comp_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].map(bool_map).astype("boolean") + ) + + # merge on the comp_id and drop the duplicate column + return atom_site_df.merge( + chem_comp_df, + how="left", + left_on=ATOM_SITE_COLUMNS.LABEL_COMP_ID, + right_on=CHEM_COMP_COLUMNS.ID, + ).drop(CHEM_COMP_COLUMNS.ID, axis=1) def get_correct_af_directories( @@ -328,6 +365,9 @@ def handle_alphafold_files( if temp_dir is not None: shutil.rmtree(temp_dir, ignore_errors=True) + # For consistency with multimer pLDDT + plddt_df["chainID"] = "A" + return { "cif_df": cif_df, "pae_df": pae_df, @@ -426,8 +466,13 @@ def fetch_alphafold_protein_structure( messages.append(dict(level=logging.WARNING, msg=message)) data_for_visualization = None + pae_string = str(df_dict["pae_df"]["predicted_aligned_error"].iloc[0]) + pae_matrix = np.array(ast.literal_eval(pae_string)) + del df_dict["pae_df"] + return dict( **df_dict, + pae_matrix=OutputItem(output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix), messages=messages, visualization=OutputItem( output_type=OutputType.VISUALIZATION, value=data_for_visualization @@ -435,6 +480,88 @@ def fetch_alphafold_protein_structure( ) +def reduce_pae_to_per_amino_acid( + pae_matrix: np.ndarray, + token_res_ids: list[int], + cif_df: pd.DataFrame, +): + """ + Reduces AlphaFold3 PAE matrices (per-token) to AlphaFold2 PAE matrices (per-amino acid). + If the number of tokens mapping to one AA equals the number of atoms (common for predicted PTMs), + the CA token gets used. Otherwise, the first token gets used. + Required for predictions with PTMs! + + :param pae_matrix: the per-token PAE matrix + :param token_res_ids: the token_res_ids table from the AF3 full_data json + :param cif_df: the atom_site table as a dataframe + + :return: the per-AA/per-residue PAE matrix + """ + + indices_to_delete = [] + + current_idx = 0 + runs = [] + + current_chain_idx = 0 + # Get all runs (start_token_idx, len, chain_idx, res_id) of same res ids into one list + while current_idx < len(token_res_ids): + start_token_idx = current_idx + res_id = token_res_ids[start_token_idx] + length = 1 + + if res_id == 1: + current_chain_idx += 1 + + while True: + current_idx += 1 + if ( + current_idx < len(token_res_ids) + and token_res_ids[current_idx] == res_id + ): + length += 1 + else: + break + + runs.append((start_token_idx, length, current_chain_idx, res_id)) + + for start_token_idx, length, chain_idx, res_id in runs: + if length == 1: + continue + + # Get corresponding entries of _atom_site table for the token + relevant_cif_df = cif_df[cif_df["_atom_site.label_entity_id"] == str(chain_idx)] + relevant_cif_df = relevant_cif_df[ + relevant_cif_df["_atom_site.label_seq_id"] == res_id + ] + + keep_offset = 0 # Relative index to keep within duplicate tokens for one amino acid. Default: first token + + # If we have one token per atom, we try to take the CA atom + if len(relevant_cif_df) == length: + # Reset index twice to get 0..length enumeration for atoms in index + relevant_cif_df.reset_index(drop=True, inplace=True) + relevant_cif_df.reset_index(inplace=True) + + relevant_cif_df = relevant_cif_df[ + relevant_cif_df["_atom_site.label_atom_id"] == "CA" + ] + # 0 or 2+ CA atoms -> default + if len(relevant_cif_df) == 1: + keep_offset = int(relevant_cif_df.iloc[0]["index"]) + + for duplicate_idx in range(0, length): + if duplicate_idx != keep_offset: + indices_to_delete.append(start_token_idx + duplicate_idx) + + # Apply deletion + mask = np.ones(len(pae_matrix), dtype=bool) + mask[indices_to_delete] = False + pae_matrix = pae_matrix[np.ix_(mask, mask)] + + return pae_matrix + + def get_all_available_entry_ids_of_monomer_metadata() -> list[str]: """ " Get the entry ids of all the protein structure predictions that can be found on disk. @@ -694,6 +821,9 @@ def get_monomer_structure_dfs(entry_id: str) -> dict[str, Any]: logger.exception(msg) raise RuntimeError(msg) from e + # For consistency with multimer pLDDT + plddt_df["chainID"] = "A" + df_dict = { "structure_metadata_df": monomer_metadata_df, "cif_df": cif_df, @@ -702,12 +832,19 @@ def get_monomer_structure_dfs(entry_id: str) -> dict[str, Any]: "amino_acid_sequences_df": amino_acid_sequences_df, } check_success_of_get_df(entry_id=entry_id, df_dict=df_dict, messages=messages) + data_for_visualization = { "structure_entry_id": entry_id, "cif_df": cif_df, } + + pae_string = str(df_dict["pae_df"]["predicted_aligned_error"].iloc[0]) + pae_matrix = np.array(ast.literal_eval(pae_string)) + del df_dict["pae_df"] + return dict( **df_dict, + pae_matrix=OutputItem(output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix), messages=messages, visualization=OutputItem( output_type=OutputType.VISUALIZATION, value=data_for_visualization @@ -715,6 +852,73 @@ def get_monomer_structure_dfs(entry_id: str) -> dict[str, Any]: ) +def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: + """ + Extracts certain data from a full_data_df, deletes the extracted columns + and returns the "remaining" full_data_df as well as the extracted data. + + :param full_data_df: The AlphaFold3 full_data_df + :return dict: + - "full_data_df": The updated reduced full_data_df + - "pae_matrix": Numpy matrix with the PAE values for each residue pair + - "token_res_ids": List with the token -> AA mappings + """ + + try: + pae_matrix = np.array(full_data_df["pae"].iloc[0]) + full_data_df = full_data_df.drop(columns=["pae"]) + except KeyError: + pae_matrix = None + + try: + token_res_ids = np.array(full_data_df["token_res_ids"].iloc[0]) + full_data_df = full_data_df.drop(columns=["token_res_ids"]) + except KeyError as e: + raise KeyError( + "Prediction data does not contain required prediction token to amino acid mapping." + ) from e + + return dict( + full_data_df=full_data_df, + pae_matrix=pae_matrix, + token_res_ids=token_res_ids, + ) + + +def get_plddt_from_cif(cif_df: pd.DataFrame) -> pd.DataFrame | None: + """ + For use with multimers predicted using Alphafold3. + Returns per-residue pLDDT values for the predicted structure. + Note that sine AlphaFold3 uses per-atom pLDDT, we use the pLDDT for the CA atom. + See also https://github.com/google-deepmind/alphafold3/issues/330 + + :param cif_df: the cif_df holding the _atom_site table. + :return: DataFrame containing columns + "chainID", "residueNumber", "confidenceScore", "confidenceCategory" + """ + + try: + filtered_cif_df = cif_df[cif_df["_atom_site.label_atom_id"] == "CA"] + filtered_cif_df = filtered_cif_df[ + [ + "_atom_site.auth_asym_id", + "_atom_site.label_seq_id", + "_atom_site.B_iso_or_equiv", + ] + ] + filtered_cif_df = filtered_cif_df.rename( + columns={ + "_atom_site.auth_asym_id": "chainID", + "_atom_site.label_seq_id": "residueNumber", + "_atom_site.B_iso_or_equiv": "confidenceScore", + } + ) + return filtered_cif_df + + except KeyError: + return None + + def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: """ Writes multimer structure data from disk of a specific entry ID into dataframes. @@ -798,9 +1002,31 @@ def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: "structure_entry_id": entry_id, "cif_df": cif_df, } + + unwrapped_full_data = unwrap_full_data_df(df_dict["full_data_df"]) + df_dict["full_data_df"] = unwrapped_full_data["full_data_df"] + + pae_matrix = unwrapped_full_data["pae_matrix"] + token_res_ids = unwrapped_full_data["token_res_ids"] + plddt_df = get_plddt_from_cif(df_dict["cif_df"]) + + pae_matrix = reduce_pae_to_per_amino_acid( + pae_matrix, token_res_ids, df_dict["cif_df"] + ) + + if plddt_df is None: + messages.append( + dict( + level=logging.WARNING, + msg=f"Could not parse pLDDT values from CIF file. File is likely malformed!", + ) + ) + return dict( **df_dict, messages=messages, + plddt_df=plddt_df, + pae_matrix=OutputItem(output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix), visualization=OutputItem( output_type=OutputType.VISUALIZATION, value=data_for_visualization ), @@ -946,13 +1172,38 @@ def upload_multimer_prediction( } if not any(df.empty for df in df_dict.values()): - success_msg = f"Successfully loaded AlphaFold data for entry '{entry_id}'" - logger.info(success_msg) - messages.append(dict(level=logging.INFO, msg=success_msg)) + + unwrapped_full_data = unwrap_full_data_df(df_dict["full_data_df"]) + df_dict["full_data_df"] = unwrapped_full_data["full_data_df"] + + pae_matrix = reduce_pae_to_per_amino_acid( + unwrapped_full_data["pae_matrix"], + unwrapped_full_data["token_res_ids"], + df_dict["cif_df"], + ) + + pae_matrix = OutputItem( + output_type=OutputType.JOBLIB_ARTIFACT, + value=pae_matrix, + ) + df_dict["pae_matrix"] = pae_matrix + df_dict["plddt_df"] = get_plddt_from_cif(df_dict["cif_df"]) + + if df_dict["plddt_df"] is None: + messages.append( + dict( + level=logging.WARNING, + msg=f"Could not parse pLDDT values from CIF file. File is likely malformed!", + ) + ) data_for_visualization = { "structure_entry_id": entry_id, "cif_df": cif_df, } + + success_msg = f"Successfully loaded AlphaFold data for entry '{entry_id}'" + logger.info(success_msg) + messages.append(dict(level=logging.INFO, msg=success_msg)) else: message = f"Could not load AlphaFold data for entry '{entry_id}'" logger.warning(message) diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index 5a758f5bd..b7f47667c 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -3,6 +3,7 @@ import ast from backend.protzilla.constants.option_types import ( + CrosslinkingValidationCriterion, LogBaseWithNoneType, SimpleImputerStrategyType, ) @@ -70,6 +71,7 @@ MultiSelectField, NumberField, TextField, + FormDivider, ) from backend.protzilla.steps import Step, Section from backend.protzilla.step_manager import StepManager @@ -2415,6 +2417,13 @@ def create_crosslink_input_fields(self, form: Form, run: Run): form.add_field(upper_bound_length_deviation_field) form.add_field(lower_bound_length_deviation_field) + bounds_visible = ( + form["validation_criterion"].value + == CrosslinkingValidationCriterion.manual_bounds.value + ) + form[f"{crosslinker}_upper_accepted_deviation"].isVisible = bounds_visible + form[f"{crosslinker}_lower_accepted_deviation"].isVisible = bounds_visible + def collect_crosslinking_information(self, steps: StepManager, inputs) -> dict: # although crosslinker_information is not a dataframe we need to insert the user information regarding the crosslinks as a dictionary into the inputs crosslinker_to_length_and_deviation = {} @@ -2450,9 +2459,18 @@ def create_form(self): return Form( label="Ångström Deviation - Monomer", input_fields=[ + DropdownField( + name="validation_criterion", + label="Validation criterion", + options=CrosslinkingValidationCriterion, + value=CrosslinkingValidationCriterion.manual_bounds, + ), + FormDivider( + label="Crosslinker lengths and bounds", + ), InfoField( label="Set default cross-link lengths and their upper/lower deviations in settings under 'Cross-Links Defaults'.", - ) + ), ], ) @@ -2470,8 +2488,17 @@ def create_form(self): return Form( label="Ångström Deviation - Multimer", input_fields=[ + DropdownField( + name="validation_criterion", + label="Validation criterion", + options=CrosslinkingValidationCriterion, + value=CrosslinkingValidationCriterion.manual_bounds, + ), + FormDivider( + label="Crosslinker lengths and bounds", + ), InfoField( label="Set default cross-link lengths and their upper/lower deviations in settings under 'Cross-Links Defaults'.", - ) + ), ], ) diff --git a/backend/protzilla/methods/importing.py b/backend/protzilla/methods/importing.py index f4fdd50ad..4d07d53f9 100644 --- a/backend/protzilla/methods/importing.py +++ b/backend/protzilla/methods/importing.py @@ -438,9 +438,9 @@ class AlphaFoldPredictionLoad(ImportingStep): output_keys = [ DataKey.STRUCTURE_METADATA_DF, DataKey.CIF_DF, - DataKey.PAE_DF, DataKey.PLDDT_DF, DataKey.AMINO_ACID_SEQUENCES_DF, + DataKey.PAE_MATRIX, ] plot_method = None @@ -503,7 +503,7 @@ class ImportMonomerStructurePredictionFromDisk(ImportingStep): output_keys = [ DataKey.STRUCTURE_METADATA_DF, DataKey.CIF_DF, - DataKey.PAE_DF, + DataKey.PAE_MATRIX, DataKey.PLDDT_DF, DataKey.AMINO_ACID_SEQUENCES_DF, ] @@ -540,6 +540,8 @@ class UploadMultimerPredictions(ImportingStep): DataKey.FULL_DATA_DF, DataKey.JOB_REQUEST_DF, DataKey.AMINO_ACID_SEQUENCES_DF, + DataKey.PAE_MATRIX, + DataKey.PLDDT_DF, ] def create_form(self): @@ -617,6 +619,8 @@ class ImportMultimerStructurePredictionFromDisk(ImportingStep): DataKey.FULL_DATA_DF, DataKey.JOB_REQUEST_DF, DataKey.AMINO_ACID_SEQUENCES_DF, + DataKey.PAE_MATRIX, + DataKey.PLDDT_DF, ] def create_form(self): diff --git a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py index 251b0cf6f..0e815a17d 100644 --- a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py +++ b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py @@ -1,10 +1,12 @@ import pandas as pd +from backend.protzilla.constants.option_types import CrosslinkingValidationCriterion import pytest import logging from unittest.mock import patch, MagicMock import plotly.graph_objects as go from plotly.graph_objects import Figure import pandas.testing as pdt +import numpy as np from backend.protzilla.data_analysis.crosslinking_validation import ( @@ -35,11 +37,14 @@ (6.01, False), # outside bounds ], ) -def test_validate_with_angstrom_deviation(distance, expected): +def test_monomer_validation_baseline_manual_bounds(distance, expected): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å ± 1 Å + # Fake AlphaFold Data with chain IDs cif_df = pd.DataFrame( { "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], "_atom_site.label_seq_id": [1, 2], "_atom_site.Cartn_x": [0, distance], "_atom_site.Cartn_y": [0, 0], @@ -70,7 +75,6 @@ def test_validate_with_angstrom_deviation(distance, expected): {"entry_id": ["test"], "uniprot_accession": ["P12345"]} ) - crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Länge 5 Å ± 1 Å valid_ids = {"P12345": ["P12345"]} structures_to_validate = ["P12345"] @@ -83,9 +87,10 @@ def test_validate_with_angstrom_deviation(distance, expected): valid_ids=valid_ids, id_column_name="_atom_site.pdbx_sifts_xref_db_acc", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) - df = result["crosslinking_result_df"] + df: pd.DataFrame = result["crosslinking_result_df"] assert "alphafold_distance" in df.columns assert "valid_crosslink" in df.columns @@ -97,6 +102,318 @@ def test_validate_with_angstrom_deviation(distance, expected): assert df.loc[0, "link_type"] == "intra" +@pytest.mark.parametrize( + "distance, expected", + [ + (4.99, False), + (5.0, True), + (5.1, False), + ], +) +def test_cl_validation_pae_noerrror(distance, expected): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å (± 1 Å) + pae_matrix = np.array([[np.nan, 0], [0, np.nan]]) + + # Fake AlphaFold Data with chain IDs + cif_df = pd.DataFrame( + { + "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], + "_atom_site.label_seq_id": [1, 2], + "_atom_site.Cartn_x": [0, distance], + "_atom_site.Cartn_y": [0, 0], + "_atom_site.Cartn_z": [0, 0], + "_atom_site.auth_asym_id": ["A", "A"], + "_atom_site.pdbx_sifts_xref_db_acc": ["P12345", "P12345"], + } + ) + + amino_acid_sequences_df = pd.DataFrame( + {"Protein ID": ["P12345-1"], "Protein Sequence": ["AB"]} + ) + + # Fake Crosslink Data + crosslinking_df = pd.DataFrame( + { + "Protein_id1": ["P12345"], + "Protein_id2": ["P12345"], + "Peptide1": ["A"], + "Peptide2": ["B"], + "CL_position_within_peptide1": [0], + "CL_position_within_peptide2": [0], + "Crosslinker": ["DSS"], + } + ) + + structure_metadata_df = pd.DataFrame( + {"entry_id": ["test"], "uniprot_accession": ["P12345"]} + ) + + valid_ids = {"P12345": ["P12345"]} + structures_to_validate = ["P12345"] + + result = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.min_pae.value, + pae_matrix=pae_matrix, + ) + + df: pd.DataFrame = result["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected + + +@pytest.mark.parametrize( + "distance, expected_min, expected_max", + [ + (2.0, False, False), + (3.0, False, True), + (4.0, True, True), + (5.0, True, True), + (6.0, True, True), + (7.0, False, True), + (8.0, False, False), + ], +) +def test_cl_validation_pae_haserror(distance, expected_min, expected_max): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å (± 1 Å) + pae_matrix = np.array([[np.nan, 1], [2, np.nan]]) + + # Fake AlphaFold Data with chain IDs + cif_df = pd.DataFrame( + { + "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], + "_atom_site.label_seq_id": [1, 2], + "_atom_site.Cartn_x": [0, distance], + "_atom_site.Cartn_y": [0, 0], + "_atom_site.Cartn_z": [0, 0], + "_atom_site.auth_asym_id": ["A", "A"], + "_atom_site.pdbx_sifts_xref_db_acc": ["P12345", "P12345"], + } + ) + + amino_acid_sequences_df = pd.DataFrame( + {"Protein ID": ["P12345-1"], "Protein Sequence": ["AB"]} + ) + + # Fake Crosslink Data + crosslinking_df = pd.DataFrame( + { + "Protein_id1": ["P12345"], + "Protein_id2": ["P12345"], + "Peptide1": ["A"], + "Peptide2": ["B"], + "CL_position_within_peptide1": [0], + "CL_position_within_peptide2": [0], + "Crosslinker": ["DSS"], + } + ) + + structure_metadata_df = pd.DataFrame( + {"entry_id": ["test"], "uniprot_accession": ["P12345"]} + ) + + valid_ids = {"P12345": ["P12345"]} + structures_to_validate = ["P12345"] + + result_min = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.min_pae.value, + pae_matrix=pae_matrix, + ) + + df: pd.DataFrame = result_min["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected_min + + result_max = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.max_pae.value, + pae_matrix=pae_matrix, + ) + + df: pd.DataFrame = result_max["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected_max + + +@pytest.mark.parametrize( + "distance, expected", + [ + (4.99, False), + (5.0, True), + (5.1, False), + ], +) +def test_cl_validation_plddt_noerrror(distance, expected): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å (± 1 Å) + + plddt_df_noerror = pd.DataFrame( + { + "chainID": ["A", "A"], + "residueNumber": [1, 2], + "confidenceScore": [100, 100], + # confidenceCategory is not required + } + ) + + # Fake AlphaFold Data with chain IDs + cif_df = pd.DataFrame( + { + "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], + "_atom_site.label_seq_id": [1, 2], + "_atom_site.Cartn_x": [0, distance], + "_atom_site.Cartn_y": [0, 0], + "_atom_site.Cartn_z": [0, 0], + "_atom_site.auth_asym_id": ["A", "A"], + "_atom_site.pdbx_sifts_xref_db_acc": ["P12345", "P12345"], + } + ) + + amino_acid_sequences_df = pd.DataFrame( + {"Protein ID": ["P12345-1"], "Protein Sequence": ["AB"]} + ) + + # Fake Crosslink Data + crosslinking_df = pd.DataFrame( + { + "Protein_id1": ["P12345"], + "Protein_id2": ["P12345"], + "Peptide1": ["A"], + "Peptide2": ["B"], + "CL_position_within_peptide1": [0], + "CL_position_within_peptide2": [0], + "Crosslinker": ["DSS"], + } + ) + + structure_metadata_df = pd.DataFrame( + {"entry_id": ["test"], "uniprot_accession": ["P12345"]} + ) + + valid_ids = {"P12345": ["P12345"]} + structures_to_validate = ["P12345"] + + result = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.plddt_adjusted.value, + plddt_df=plddt_df_noerror, + ) + + df: pd.DataFrame = result["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected + + +# l_cl = 5, t_x = 1.25, t_y = 3.5. +# So range is 0.25 <= d <= 9.75 +@pytest.mark.parametrize( + "distance, expected", + [ + (0.0, False), + (0.24, False), + (0.25, True), + (5.0, True), + (9.0, True), + (9.74, True), + (9.75, True), + (9.76, False), + (10.0, False), + ], +) +def test_cl_validation_plddt_witherror(distance, expected): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å (± 1 Å) + + plddt_df_noerror = pd.DataFrame( + { + "chainID": ["A", "A"], + "residueNumber": [1, 2], + "confidenceScore": [75, 30], + # confidenceCategory is not required + } + ) + + # Fake AlphaFold Data with chain IDs + cif_df = pd.DataFrame( + { + "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], + "_atom_site.label_seq_id": [1, 2], + "_atom_site.Cartn_x": [0, distance], + "_atom_site.Cartn_y": [0, 0], + "_atom_site.Cartn_z": [0, 0], + "_atom_site.auth_asym_id": ["A", "A"], + "_atom_site.pdbx_sifts_xref_db_acc": ["P12345", "P12345"], + } + ) + + amino_acid_sequences_df = pd.DataFrame( + {"Protein ID": ["P12345-1"], "Protein Sequence": ["AB"]} + ) + + # Fake Crosslink Data + crosslinking_df = pd.DataFrame( + { + "Protein_id1": ["P12345"], + "Protein_id2": ["P12345"], + "Peptide1": ["A"], + "Peptide2": ["B"], + "CL_position_within_peptide1": [0], + "CL_position_within_peptide2": [0], + "Crosslinker": ["DSS"], + } + ) + + structure_metadata_df = pd.DataFrame( + {"entry_id": ["test"], "uniprot_accession": ["P12345"]} + ) + + valid_ids = {"P12345": ["P12345"]} + structures_to_validate = ["P12345"] + + result = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.plddt_adjusted.value, + plddt_df=plddt_df_noerror, + ) + + df: pd.DataFrame = result["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected + + def test_modify_form_creates_crosslinker_fields(): crosslinking_df = pd.DataFrame({"Crosslinker": ["DSS", "BS3", "DSS"]}) @@ -353,6 +670,7 @@ def test_validate_multimer_filters_only_pairs_within_structures_to_validate(): valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] @@ -428,6 +746,7 @@ def test_validate_multimer_no_links_between_structures_returns_empty_and_warning valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] @@ -499,6 +818,7 @@ def test_validate_multimer_duplicates_rows_for_multiple_peptide_matches_and_vali valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] @@ -827,6 +1147,7 @@ def test_validate_multimer_with_invalid_crosslinks(): valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] @@ -1003,6 +1324,7 @@ def test_validate_multimer_same_protein_different_chains_intra_vs_inter(): valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] diff --git a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py index 940bbcea2..334cdbccf 100644 --- a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py +++ b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py @@ -1,9 +1,11 @@ +from backend.protzilla.steps import OutputItem import pandas as pd import pytest import json import logging import shutil from pathlib import Path +import numpy as np from backend.protzilla.importing.alphafold_protein_structure_load import ( @@ -29,6 +31,12 @@ check_success_of_get_df, ) from backend.protzilla.constants import paths +from backend.protzilla.constants.cif_columns import ( + ATOM_SITE_PREFIX, + ATOM_SITE_COLUMNS, + CHEM_COMP_COLUMNS, +) +from backend.protzilla.constants.data_types import DataKey def test_to_fasta_default_header_and_newline(): @@ -90,11 +98,18 @@ def test_read_alphafold_mmcif_valid_atom_site(tmp_path): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol +_atom_site.label_atom_id +_atom_site.label_comp_id _atom_site.Cartn_x -N N 1.0 -CA C 2.0 +1 N N SER 1.0 +2 C CA SER 2.0 """ ) @@ -102,14 +117,19 @@ def test_read_alphafold_mmcif_valid_atom_site(tmp_path): assert isinstance(df, pd.DataFrame) assert list(df.columns) == [ - "_atom_site.id", - "_atom_site.type_symbol", - "_atom_site.Cartn_x", + ATOM_SITE_COLUMNS.ID, + ATOM_SITE_COLUMNS.TYPE_SYMBOL, + ATOM_SITE_COLUMNS.LABEL_ATOM_ID, + ATOM_SITE_COLUMNS.LABEL_COMP_ID, + ATOM_SITE_COLUMNS.CARTN_X, + CHEM_COMP_COLUMNS.MON_NSTD_FLAG, ] assert len(df) == 2 - assert df["_atom_site.id"].tolist() == ["N", "CA"] - assert df["_atom_site.type_symbol"].tolist() == ["N", "C"] - assert df["_atom_site.Cartn_x"].tolist() == ["1.0", "2.0"] + assert df[ATOM_SITE_COLUMNS.ID].tolist() == [1, 2] + assert df[ATOM_SITE_COLUMNS.TYPE_SYMBOL].tolist() == ["N", "C"] + assert df[ATOM_SITE_COLUMNS.LABEL_ATOM_ID].tolist() == ["N", "CA"] + assert df[ATOM_SITE_COLUMNS.CARTN_X].tolist() == [1.0, 2.0] + assert df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].tolist() == [True, True] def test_fetch_alphafold_protein_structure_wrong_uniprot_id(): @@ -127,11 +147,11 @@ def test_fetch_alphafold_returned_keys(tmp_path, monkeypatch): out = fetch_alphafold_protein_structure("Q8WP00", persist_upload=True) assert out.keys() == { - "structure_metadata_df", - "cif_df", - "pae_df", - "plddt_df", - "amino_acid_sequences_df", + DataKey.STRUCTURE_METADATA_DF, + DataKey.CIF_DF, + DataKey.PAE_MATRIX, + DataKey.PLDDT_DF, + DataKey.AMINO_ACID_SEQUENCES_DF, "messages", "visualization", } @@ -146,16 +166,16 @@ def test_fetch_alphafold_monomer_metadata(tmp_path, monkeypatch): ) out = fetch_alphafold_protein_structure("Q8WP00", persist_upload=True) - assert isinstance(out["structure_metadata_df"], pd.DataFrame) - assert not out["structure_metadata_df"].empty - assert out["structure_metadata_df"].iloc[0]["uniprot_accession"] == "Q8WP00" + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) + assert not out[DataKey.STRUCTURE_METADATA_DF].empty + assert out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["uniprot_accession"] == "Q8WP00" assert ( - out["structure_metadata_df"].iloc[0]["model_created_date"] + out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["model_created_date"] == "2025-08-01T00:00:00Z" ) - assert out["structure_metadata_df"].iloc[0]["gene"] == "PRM1" + assert out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["gene"] == "PRM1" assert ( - out["structure_metadata_df"].iloc[0]["model_used"] + out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["model_used"] == "AlphaFold Monomer v2.0 pipeline" ) @@ -197,20 +217,20 @@ def test_fetch_alphafold_dfs_exist(tmp_path, monkeypatch): out = fetch_alphafold_protein_structure("Q8WP00", persist_upload=True) - cif_df = out["cif_df"] + cif_df = out[DataKey.CIF_DF] assert isinstance(cif_df, pd.DataFrame) assert not cif_df.empty - assert any(col.startswith("_atom_site.") for col in cif_df.columns) + assert any(col.startswith(ATOM_SITE_PREFIX) for col in cif_df.columns) - pae_df = out["pae_df"] - assert isinstance(pae_df, pd.DataFrame) - assert not pae_df.empty + pae_matrix = out[DataKey.PAE_MATRIX] + assert isinstance(pae_matrix, OutputItem) + assert len(pae_matrix.value) != 0 - plddt_df = out["plddt_df"] + plddt_df = out[DataKey.PLDDT_DF] assert isinstance(plddt_df, pd.DataFrame) assert not plddt_df.empty - seq_df = out["amino_acid_sequences_df"] + seq_df = out[DataKey.AMINO_ACID_SEQUENCES_DF] assert isinstance(seq_df, pd.DataFrame) assert not seq_df.empty @@ -280,11 +300,18 @@ def test_get_prot_structure_dfs_success(tmp_path, monkeypatch): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol +_atom_site.label_atom_id +_atom_site.label_comp_id _atom_site.Cartn_x -N N 1.0 -CA C 2.0 +1 N N SER 1.0 +2 C CA SER 2.0 """ ) @@ -303,34 +330,43 @@ def test_get_prot_structure_dfs_success(tmp_path, monkeypatch): out = get_monomer_structure_dfs("Q8WP00") - assert isinstance(out["structure_metadata_df"], pd.DataFrame) - assert not out["structure_metadata_df"].empty - assert out["structure_metadata_df"].iloc[0]["entry_id"] == "Q8WP00" + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) + assert not out[DataKey.STRUCTURE_METADATA_DF].empty + assert out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["entry_id"] == "Q8WP00" - assert isinstance(out["cif_df"], pd.DataFrame) - assert not out["cif_df"].empty - assert list(out["cif_df"].columns) == [ - "_atom_site.id", - "_atom_site.type_symbol", - "_atom_site.Cartn_x", + cif_df = out[DataKey.CIF_DF] + assert isinstance(cif_df, pd.DataFrame) + assert not cif_df.empty + assert list(cif_df.columns) == [ + ATOM_SITE_COLUMNS.ID, + ATOM_SITE_COLUMNS.TYPE_SYMBOL, + ATOM_SITE_COLUMNS.LABEL_ATOM_ID, + ATOM_SITE_COLUMNS.LABEL_COMP_ID, + ATOM_SITE_COLUMNS.CARTN_X, + CHEM_COMP_COLUMNS.MON_NSTD_FLAG, ] - assert out["cif_df"]["_atom_site.id"].tolist() == ["N", "CA"] - assert out["cif_df"]["_atom_site.type_symbol"].tolist() == ["N", "C"] - assert out["cif_df"]["_atom_site.Cartn_x"].tolist() == ["1.0", "2.0"] - - assert isinstance(out["pae_df"], pd.DataFrame) - assert not out["pae_df"].empty - assert out["pae_df"]["predicted_aligned_error"].tolist() == [0.1] + assert len(cif_df) == 2 + assert cif_df[ATOM_SITE_COLUMNS.ID].tolist() == [1, 2] + assert cif_df[ATOM_SITE_COLUMNS.TYPE_SYMBOL].tolist() == ["N", "C"] + assert cif_df[ATOM_SITE_COLUMNS.LABEL_ATOM_ID].tolist() == ["N", "CA"] + assert cif_df[ATOM_SITE_COLUMNS.CARTN_X].tolist() == [1.0, 2.0] + assert cif_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].tolist() == [True, True] + + assert isinstance(out[DataKey.PAE_MATRIX], OutputItem) + assert isinstance(out[DataKey.PAE_MATRIX].value, np.ndarray) + assert ( + out[DataKey.PAE_MATRIX].value == 0.1 + ) # 0D array (only one value) TODO: Change this to something more reasonable? idk - assert isinstance(out["plddt_df"], pd.DataFrame) - assert not out["plddt_df"].empty - assert out["plddt_df"]["residueNumber"].tolist() == [1] - assert out["plddt_df"]["confidenceScore"].tolist() == [90] + assert isinstance(out[DataKey.PLDDT_DF], pd.DataFrame) + assert not out[DataKey.PLDDT_DF].empty + assert out[DataKey.PLDDT_DF]["residueNumber"].tolist() == [1] + assert out[DataKey.PLDDT_DF]["confidenceScore"].tolist() == [90] - assert isinstance(out["amino_acid_sequences_df"], pd.DataFrame) - assert not out["amino_acid_sequences_df"].empty - assert out["amino_acid_sequences_df"]["Protein ID"].tolist() == ["Q8WP00-1"] - assert out["amino_acid_sequences_df"]["Protein Sequence"].tolist() == ["AAAA"] + assert isinstance(out[DataKey.AMINO_ACID_SEQUENCES_DF], pd.DataFrame) + assert not out[DataKey.AMINO_ACID_SEQUENCES_DF].empty + assert out[DataKey.AMINO_ACID_SEQUENCES_DF]["Protein ID"].tolist() == ["Q8WP00-1"] + assert out[DataKey.AMINO_ACID_SEQUENCES_DF]["Protein Sequence"].tolist() == ["AAAA"] assert any(d.get("level") == logging.INFO for d in out["messages"]) or any( "Successfully loaded" in d.get("msg", "") for d in out["messages"] @@ -421,11 +457,13 @@ def test_get_amino_acid_sequences_df_and_handle_files(tmp_path, monkeypatch): out = handle_alphafold_files( {}, "P", "TESTSEQ", metadata_df, "P", persist_upload=False ) - assert "amino_acid_sequences_df" in out - assert isinstance(out["cif_df"], pd.DataFrame) and out["cif_df"].empty + assert DataKey.AMINO_ACID_SEQUENCES_DF in out + assert isinstance(out[DataKey.CIF_DF], pd.DataFrame) and out[DataKey.CIF_DF].empty assert isinstance(out["pae_df"], pd.DataFrame) and out["pae_df"].empty - assert isinstance(out["plddt_df"], pd.DataFrame) and out["plddt_df"].empty - assert isinstance(out["amino_acid_sequences_df"], pd.DataFrame) + assert ( + isinstance(out[DataKey.PLDDT_DF], pd.DataFrame) and out[DataKey.PLDDT_DF].empty + ) + assert isinstance(out[DataKey.AMINO_ACID_SEQUENCES_DF], pd.DataFrame) def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): @@ -435,19 +473,41 @@ def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): fasta = tmp_path / "seqs.fasta" fasta.write_text(">alpha|X\nAAAA\n") cif = tmp_path / "m.cif" + # Note that we only write the absolutely necessary columns here + # Also this is not biologically plausible cif.write_text( """ -data_test -loop_ -_atom_site.id -_atom_site.type_symbol -N N -""" + data_test + loop_ + _chem_comp.id + _chem_comp.mon_nstd_flag + SER y + GLY y + # + loop_ + _atom_site.id + _atom_site.label_atom_id + _atom_site.label_comp_id + _atom_site.auth_asym_id + _atom_site.label_seq_id + _atom_site.B_iso_or_equiv + 1 N SER A 1 99.99 + 2 CA SER A 1 67.76 + 3 CA SER A 2 33.65 + 4 O SER A 2 5.52 + 5 N GLY B 1 0 + 6 CA GLY B 1 13.37 + # + """ ) conf = tmp_path / "conf.json" - conf.write_text('[{"residueNumber":1, "confidenceScore":99}]') + conf.write_text( + '{"chain_iptm": [0.42, 0.89]}' + ) # Note that we do not use these metrics anywhere full = tmp_path / "full.json" - full.write_text('{"a": [1,2]}') + full.write_text( + '{"random_column": [1,2], "pae": [[1, 2], [3, 4]], "token_res_ids": [1, 2, 1]}' + ) job_request = tmp_path / "job_request.json" job_request.write_text( json.dumps( @@ -458,11 +518,18 @@ def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): "sequences": [ { "proteinChain": { - "sequence": "AAAA", + "sequence": "PE", "count": 1, "useStructureTemplate": True, } - } + }, + { + "proteinChain": { + "sequence": "T", + "count": 1, + "useStructureTemplate": True, + } + }, ], "dialect": "alphafoldserver", "version": 3, @@ -484,7 +551,7 @@ def _copy(src, dest_dir): out = upload_multimer_prediction( entry_id="M1", - uniprot_ids="X", + uniprot_ids="X, Y", model_used="m", amino_acid_sequences=fasta, cif_file=cif, @@ -494,53 +561,135 @@ def _copy(src, dest_dir): persist_upload=True, ) - assert isinstance(out["structure_metadata_df"], pd.DataFrame) + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) # check metadata contents - mdf = out["structure_metadata_df"] + mdf = out[DataKey.STRUCTURE_METADATA_DF] assert mdf.iloc[0]["entry_id"] == "M1" - assert mdf.iloc[0]["uniprot_ids"] == ["X"] + assert mdf.iloc[0]["uniprot_ids"] == ["X", "Y"] assert mdf.iloc[0]["model_used"] == "m" # cif contents - cif_df = out["cif_df"] + cif_df = out[DataKey.CIF_DF] assert isinstance(cif_df, pd.DataFrame) - assert list(cif_df.columns) == ["_atom_site.id", "_atom_site.type_symbol"] - assert cif_df["_atom_site.id"].tolist() == ["N"] - assert cif_df["_atom_site.type_symbol"].tolist() == ["N"] + assert list(cif_df.columns) == [ + ATOM_SITE_COLUMNS.ID, + ATOM_SITE_COLUMNS.LABEL_ATOM_ID, + ATOM_SITE_COLUMNS.LABEL_COMP_ID, + ATOM_SITE_COLUMNS.AUTH_ASYM_ID, + ATOM_SITE_COLUMNS.LABEL_SEQ_ID, + ATOM_SITE_COLUMNS.B_ISO_OR_EQUIV, + CHEM_COMP_COLUMNS.MON_NSTD_FLAG, + ] + assert cif_df[ATOM_SITE_COLUMNS.ID].tolist() == list(range(1, 7)) + assert cif_df[ATOM_SITE_COLUMNS.LABEL_ATOM_ID].tolist() == [ + "N", + "CA", + "CA", + "O", + "N", + "CA", + ] + assert cif_df[ATOM_SITE_COLUMNS.AUTH_ASYM_ID].tolist() == ["A"] * 4 + ["B"] * 2 + assert cif_df[ATOM_SITE_COLUMNS.B_ISO_OR_EQUIV].tolist() == [ + 99.99, + 67.76, + 33.65, + 5.52, + 0, + 13.37, + ] + assert cif_df[ATOM_SITE_COLUMNS.LABEL_COMP_ID].tolist() == ["SER"] * 4 + ["GLY"] * 2 + assert cif_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].tolist() == [True] * 6 # confidence JSON - conf_df = out["confidence_df"] + conf_df = out[DataKey.CONFIDENCE_DF] assert isinstance(conf_df, pd.DataFrame) - assert conf_df["residueNumber"].tolist() == [1] - assert conf_df["confidenceScore"].tolist() == [99] + assert conf_df["chain_iptm"].tolist() == [0.42, 0.89] # full data normalization - full_df = out["full_data_df"] + full_df = out[DataKey.FULL_DATA_DF] assert isinstance(full_df, pd.DataFrame) - assert full_df.iloc[0]["a"] == [1, 2] + assert list(full_df.columns) == ["random_column"] + assert full_df.iloc[0]["random_column"] == [1, 2] # job request JSON - job_df = out["job_request_df"] + job_df = out[DataKey.JOB_REQUEST_DF] assert isinstance(job_df, pd.DataFrame) assert job_df.iloc[0]["name"] == "test_job" assert job_df.iloc[0]["dialect"] == "alphafoldserver" # sequences - seqs = out["amino_acid_sequences_df"] + seqs = out[DataKey.AMINO_ACID_SEQUENCES_DF] assert isinstance(seqs, pd.DataFrame) assert seqs["Protein Sequence"].tolist() == ["AAAA"] assert any(str(v).startswith("X") for v in seqs["Protein ID"].tolist()) + # pLDDT values + plddt_df = out[DataKey.PLDDT_DF] + assert isinstance(plddt_df, pd.DataFrame) + assert list(plddt_df.columns) == ["chainID", "residueNumber", "confidenceScore"] + assert plddt_df["confidenceScore"].tolist() == [ + 67.76, + 33.65, + 13.37, + ] # Keep only CA atoms + + # PAE values + pae_matrix = out[DataKey.PAE_MATRIX].value + assert isinstance(pae_matrix, np.ndarray) + assert pae_matrix[0, 0] == 1 + assert pae_matrix[0, 1] == 2 + assert pae_matrix[1, 0] == 3 + assert pae_matrix[1, 1] == 4 + upload_dir = tmp_path / "M1" assert upload_dir.exists() assert any(upload_dir.glob("*.fasta")) or any(upload_dir.glob("*.fa")) assert any(upload_dir.glob("*.json")) assert any(upload_dir.glob("*.cif")) + # Test no plDDT Data -> plddt_df should be None + cif.write_text( + """ + data_test + loop_ + _chem_comp.id + _chem_comp.mon_nstd_flag + SER y + GLY y + # + loop_ + _atom_site.id + _atom_site.label_atom_id + _atom_site.label_comp_id + _atom_site.auth_asym_id + _atom_site.label_seq_id + 1 N SER A 1 + 2 CA SER A 1 + 3 CA SER A 2 + 4 O SER A 2 + 5 N GLY B 1 + 6 CA GLY B 1 + # + """ + ) -# Additional comprehensive tests for error cases and edge cases + out = upload_multimer_prediction( + entry_id="M1", + uniprot_ids="X, Y", + model_used="m", + amino_acid_sequences=fasta, + cif_file=cif, + confidence_file=conf, + full_data_file=full, + job_request_file=job_request, + persist_upload=True, + ) + + assert out[DataKey.PLDDT_DF] is None +# Additional comprehensive tests for error cases and edge cases def test_get_monomer_metadata_df_existing_csv(tmp_path, monkeypatch): """Test reading existing monomer metadata CSV""" csv_path = tmp_path / "alphafold_monomer_metadata.csv" @@ -605,18 +754,23 @@ def test_to_fasta_lowercase_conversion(): def test_upload_multimer_prediction_no_persist(tmp_path, monkeypatch): - """Test upload_multimer_prediction with persist_upload=False""" + """ + Test upload_multimer_prediction with persist_upload=False. + Also tests full_data without PAE values + """ monkeypatch.setattr(paths, "ALPHAFOLD_MONOMER_PATH", tmp_path) monkeypatch.setattr(paths, "ALPHAFOLD_MULTIMER_PATH", tmp_path) fasta = tmp_path / "seqs.fasta" fasta.write_text(">alpha|X\nAAAA\n") cif = tmp_path / "m.cif" - cif.write_text("data_test\nloop_\n_atom_site.id\nN\n") + cif.write_text( + "data_test\nloop_\n_chem_comp.id\n_chem_comp.mon_nstd_flag\nSER y\nloop_\n#\n_atom_site.id\n_atom_site.label_comp_id\nN SER\n" + ) conf = tmp_path / "conf.json" conf.write_text('[{"residueNumber":1, "confidenceScore":99}]') full = tmp_path / "full.json" - full.write_text('{"a": [1,2]}') + full.write_text('{"a": [1,2], "token_res_ids": [1], "pae": [[2]]}') job_request = tmp_path / "job_request.json" job_request.write_text( json.dumps( @@ -653,10 +807,11 @@ def test_upload_multimer_prediction_no_persist(tmp_path, monkeypatch): ) # verify dataframes are returned - assert isinstance(out["structure_metadata_df"], pd.DataFrame) - assert isinstance(out["cif_df"], pd.DataFrame) - assert isinstance(out["job_request_df"], pd.DataFrame) - assert out["job_request_df"].iloc[0]["name"] == "test_job_2" + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) + assert isinstance(out[DataKey.CIF_DF], pd.DataFrame) + assert isinstance(out[DataKey.JOB_REQUEST_DF], pd.DataFrame) + assert out[DataKey.PAE_MATRIX].value is not None + assert out[DataKey.JOB_REQUEST_DF].iloc[0]["name"] == "test_job_2" # directory should still exist (created for the entry) upload_dir = tmp_path / "M2" assert not upload_dir.exists() @@ -692,7 +847,9 @@ def test_get_prot_structure_dfs_missing_fasta(tmp_path, monkeypatch): # create CIF but no FASTA cif = prot_dir / "test.cif" - cif.write_text("data_test\nloop_\n_atom_site.id\nN\n") + cif.write_text( + "data_test\nloop_\n_chem_comp.id\n_chem_comp.mon_nstd_flag\nSER y\nloop_\n#\n_atom_site.id\n_atom_site.label_comp_id\nN SER\n" + ) with pytest.raises(FileNotFoundError, match="No FASTA file found"): get_monomer_structure_dfs("NOFASTA") @@ -712,7 +869,9 @@ def test_get_prot_structure_dfs_missing_json(tmp_path, monkeypatch): # create CIF and FASTA but no JSON cif = prot_dir / "test.cif" - cif.write_text("data_test\nloop_\n_atom_site.id\nN\n") + cif.write_text( + "data_test\nloop_\n_chem_comp.id\n_chem_comp.mon_nstd_flag\nSER y\nloop_\n#\n_atom_site.id\n_atom_site.label_comp_id\nN SER\n" + ) fasta = prot_dir / "test.fasta" # valid header for parse_fasta_id (expects at least one "|" in the id) @@ -820,18 +979,30 @@ def test_get_cif_df_from_disk_multiple_cif_warns(tmp_path): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol -N N +_atom_site.label_comp_id +N N SER """ ) cif2.write_text( """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol -CA C +_atom_site.label_comp_id +CA C SER """ ) @@ -872,9 +1043,15 @@ def test_get_multimer_structure_dfs_success(tmp_path, monkeypatch): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol -N N +_atom_site.label_comp_id +N N SER """ ) @@ -885,7 +1062,9 @@ def test_get_multimer_structure_dfs_success(tmp_path, monkeypatch): full_data = prot_dir / "full.json" job_request = prot_dir / "job_request.json" confidence.write_text(json.dumps({"chain_iptm": [0.75]})) - full_data.write_text(json.dumps({"pae": [[0.1, 0.2], [0.3, 0.4]]})) + full_data.write_text( + json.dumps({"pae": [[0.1, 0.2], [0.3, 0.4]], "token_res_ids": [1]}) + ) job_request.write_text( json.dumps( [ @@ -909,17 +1088,17 @@ def test_get_multimer_structure_dfs_success(tmp_path, monkeypatch): ) out = get_multimer_structure_dfs("M1") - assert isinstance(out["structure_metadata_df"], pd.DataFrame) - assert isinstance(out["cif_df"], pd.DataFrame) - assert isinstance(out["amino_acid_sequences_df"], pd.DataFrame) - assert isinstance(out["confidence_df"], pd.DataFrame) - assert isinstance(out["full_data_df"], pd.DataFrame) - assert isinstance(out["job_request_df"], pd.DataFrame) - - assert "chain_iptm" in out["confidence_df"].columns - assert "pae" in out["full_data_df"].columns - assert out["job_request_df"].iloc[0]["name"] == "multimer_job" - assert out["job_request_df"].iloc[0]["version"] == 3 + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) + assert isinstance(out[DataKey.CIF_DF], pd.DataFrame) + assert isinstance(out[DataKey.AMINO_ACID_SEQUENCES_DF], pd.DataFrame) + assert isinstance(out[DataKey.CONFIDENCE_DF], pd.DataFrame) + assert isinstance(out[DataKey.FULL_DATA_DF], pd.DataFrame) + assert isinstance(out[DataKey.JOB_REQUEST_DF], pd.DataFrame) + + assert "chain_iptm" in out[DataKey.CONFIDENCE_DF].columns + assert "pae" not in out[DataKey.FULL_DATA_DF].columns + assert out[DataKey.JOB_REQUEST_DF].iloc[0]["name"] == "multimer_job" + assert out[DataKey.JOB_REQUEST_DF].iloc[0]["version"] == 3 assert any(m.get("level") == logging.INFO for m in out["messages"]) or any( "Successfully loaded" in str(m.get("msg", "")) for m in out["messages"] @@ -956,9 +1135,15 @@ def test_get_multimer_structure_dfs_json_fallback_warns(tmp_path, monkeypatch): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol -N N +_atom_site.label_comp_id +N N SER """ ) diff --git a/backend/tests/protzilla/importing/test_pae_matrix_reduction.py b/backend/tests/protzilla/importing/test_pae_matrix_reduction.py new file mode 100644 index 000000000..4daf5449e --- /dev/null +++ b/backend/tests/protzilla/importing/test_pae_matrix_reduction.py @@ -0,0 +1,162 @@ +import numpy as np +import pandas as pd +from backend.protzilla.importing.alphafold_protein_structure_load import ( + reduce_pae_to_per_amino_acid, +) +import pytest + +# These test cases were generated by AI (Gemini) but have been manually checked + + +@pytest.fixture +def empty_cif_df(): + """Returns an empty atom_site DataFrame with required columns.""" + return pd.DataFrame( + columns=[ + "_atom_site.label_entity_id", + "_atom_site.label_seq_id", + "_atom_site.label_atom_id", + ] + ) + + +def test_no_duplicates(empty_cif_df): + """Edge Case 1: Every residue has exactly one token. + + The PAE matrix should remain entirely untouched. + """ + pae_matrix = np.array([[1.0, 2.0], [3.0, 4.0]]) + token_res_ids = [1, 2] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, empty_cif_df) + + assert np.array_equal(result, pae_matrix) + + +def test_duplicates_fallback_to_first_token(empty_cif_df): + """Edge Case 2: Multi-token residue, but length mismatch with CIF. + + Should fall back to keeping the first token (offset 0) and deleting the rest. + """ + # 3 tokens: Residue 1 has 2 tokens, Residue 2 has 1 token. + pae_matrix = np.array([[10, 11, 12], [20, 21, 22], [30, 31, 32]]) + token_res_ids = [1, 1, 2] + + # Keeping token 0 (first of res 1) and token 2 (res 2). Token 1 should be deleted. + expected_indices = [0, 2] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, empty_cif_df) + + assert np.array_equal(result, expected_matrix) + + +def test_duplicates_keep_ca_atom(): + """Edge Case 3: Run length matches CIF length, and exactly one CA atom is found. + + Should keep the token corresponding exactly to the 'CA' atom position. + """ + # 4 tokens: Residue 1 has 3 tokens, Residue 2 has 1 token. + pae_matrix = np.diag([1.0, 2.0, 3.0, 4.0]) + token_res_ids = [1, 1, 1, 2] + + # CIF setup: 3 atoms for chain 1, residue 1. 'CA' sits at relative index 1. + cif_df = pd.DataFrame( + { + "_atom_site.label_entity_id": ["1", "1", "1"], + "_atom_site.label_seq_id": [1, 1, 1], + "_atom_site.label_atom_id": ["N", "CA", "C"], + } + ) + + # Expected: Keep global index 1 (the CA atom) and global index 3 (residue 2). + # Global indices 0 and 2 should be wiped out. + expected_indices = [1, 3] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, cif_df) + + assert np.array_equal(result, expected_matrix) + + +def test_duplicates_cif_match_but_no_ca_fallback(): + """Edge Case 4a: Run length matches CIF length, but no CA atom exists. + + Should fall back to keeping the first token (offset 0). + """ + pae_matrix = np.diag([10, 20, 30]) + token_res_ids = [1, 1, 2] + + # CIF matches length (2 atoms), but neither is 'CA' + cif_df = pd.DataFrame( + { + "_atom_site.label_entity_id": ["1", "1"], + "_atom_site.label_seq_id": [1, 1], + "_atom_site.label_atom_id": ["N", "O"], + } + ) + + # Expected to keep global index 0 (fallback) and global index 2. + expected_indices = [0, 2] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, cif_df) + + assert np.array_equal(result, expected_matrix) + + +def test_duplicates_cif_match_multiple_ca_fallback(): + """Edge Case 4b: Run length matches CIF length, but multiple CA atoms exist. + + Should fall back to keeping the first token (offset 0). + """ + pae_matrix = np.diag([10, 20, 30]) + token_res_ids = [1, 1, 2] + + # CIF matches length (2 atoms), but both claim to be 'CA' + cif_df = pd.DataFrame( + { + "_atom_site.label_entity_id": ["1", "1"], + "_atom_site.label_seq_id": [1, 1], + "_atom_site.label_atom_id": ["CA", "CA"], + } + ) + + # Expected to keep global index 0 (fallback) and global index 2. + expected_indices = [0, 2] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, cif_df) + + assert np.array_equal(result, expected_matrix) + + +def test_multiple_chains_tracking(): + """Edge Case 5: The system has multiple chains. + + Verifies that `current_chain_idx` increments whenever `res_id == 1` starts a run, + and queries the correct stringified `_atom_site.label_entity_id`. + """ + # Chain 1: res 1 (len 1), res 2 (len 1) + # Chain 2: res 1 (len 2) -> Triggered by encountering 1 again + token_res_ids = [1, 2, 1, 1] + pae_matrix = np.diag([100, 200, 300, 400]) + + cif_df = pd.DataFrame( + { + "_atom_site.label_entity_id": ["1", "1", "2", "2"], + "_atom_site.label_seq_id": [1, 2, 1, 1], + "_atom_site.label_atom_id": ["CA", "CA", "N", "CA"], + } + ) + + # Chain 1, Res 1 (idx 0): len 1 -> Keep + # Chain 1, Res 2 (idx 1): len 1 -> Keep + # Chain 2, Res 1 (idx 2, 3): len 2 -> Matches CIF length for chain '2'. + # CA is at relative index 1 (global idx 3). Global idx 2 is dropped. + expected_indices = [0, 1, 3] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, cif_df) + + assert np.array_equal(result, expected_matrix) diff --git a/frontend/src/components/app/run-screen/node-editor/StepNode.tsx b/frontend/src/components/app/run-screen/node-editor/StepNode.tsx index a83c617fd..7adb84d93 100644 --- a/frontend/src/components/app/run-screen/node-editor/StepNode.tsx +++ b/frontend/src/components/app/run-screen/node-editor/StepNode.tsx @@ -53,7 +53,7 @@ const DATA_TYPE_ICON_MAP: Partial> = { full_data_df: handleFullDataIcon, gene_mapping_df: handleDnaIcon, metadata_df: handleMetadataIcon, - pae_df: handlePaeIcon, + pae_matrix: handlePaeIcon, peptide_df: handlePeptidesIcon, plddt_df: handlePlddtIcon, protein_df: handleProteinIcon,