From 9f8b5b23135d51f3326826c259be58c806c73fa9 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Mon, 11 May 2026 11:27:21 +0200 Subject: [PATCH 01/33] feat: add pLDDT to CL results table --- .../data_analysis/crosslinking_validation.py | 12 ++++++++++++ backend/protzilla/methods/data_analysis.py | 3 ++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 3a410b462..6b18f2eab 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -355,6 +355,7 @@ def monomer_validation( crosslinker_information: dict[str, list[float]], cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, + plddt_df: pd.DataFrame, ) -> dict: """ Validates crosslinking data for a monomeric protein structure by checking @@ -376,11 +377,14 @@ def monomer_validation( structure_metadata_df=structure_metadata_df, cif_df=cif_df, amino_acid_sequences_df=amino_acid_sequences_df, + plddt_df=plddt_df, valid_ids=valid_ids, id_column_name="_atom_site.pdbx_sifts_xref_db_acc", structures_to_validate=[protein_id], ) +def monomer_validation_with_pae(): + pass def get_protein_id_from_sequence(amino_acid_sequences_df, target_sequence): """ @@ -485,6 +489,7 @@ def validate_with_angstrom_deviation( crosslinker_information: dict[str, list[float]], structure_metadata_df: pd.DataFrame, cif_df: pd.DataFrame, + plddt_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, valid_ids: dict, id_column_name: str, @@ -500,6 +505,7 @@ def validate_with_angstrom_deviation( :param crosslinker_information: Dictionary mapping crosslinker names to a list of three floats: [crosslinker_length, upper_accepted_deviation, lower_accepted_deviation]. :param cif_df: DataFrame containing CIF information (predicted coordinates of all the protein's atoms). + :param plddt_df: DataFrame containing the local AlphaFold pLDDT values for each residue. :param amino_acid_sequences_df: Dataframe that contains all known amino acid sequences. :param valid_ids: Dictionary mapping protein IDs to their valid chain/entity identifiers in the CIF data. :param id_column_name: The column name in the cif_df to use for matching against valid_ids. @@ -558,6 +564,8 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: protein_sequence2 = get_protein_sequence_from_df( amino_acid_sequences_df=amino_acid_sequences_df, protein_id=protein_id2 ) + plddt_at_position1 = plddt_df.query("residueNumber == @crosslink.crosslinker_position1").iloc[0]["confidenceScore"] + plddt_at_position2 = plddt_df.query("residueNumber == @crosslink.crosslinker_position2").iloc[0]["confidenceScore"] predicted_distance = get_distance_between_two_amino_acids_in_angstrom( amino_acid_position1=crosslink.crosslinker_position1, @@ -599,6 +607,8 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: "valid_crosslink": valid, "crosslinker_position1": crosslink.crosslinker_position1, "crosslinker_position2": crosslink.crosslinker_position2, + "plddt_at_position1": plddt_at_position1, + "plddt_at_position2": plddt_at_position2, } ) @@ -608,6 +618,8 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: "valid_crosslink", "crosslinker_position1", "crosslinker_position2", + "plddt_at_position1", + "plddt_at_position2", ] relevant_crosslinks_df["crosslinker_position1"] = relevant_crosslinks_df[ diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index 52410f64c..5d55d5fe8 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -93,6 +93,7 @@ monomer_diagrams, multimer_diagrams, monomer_validation, + monomer_validation_with_pae, multimer_validation, ) from backend.protzilla.run import Run @@ -2432,7 +2433,7 @@ class CrosslinkingValidationWithAngstromDeviation( operation = "Crosslinking Validation" method_description = "Validates crosslinks within the one protein structure based on the difference between the length of the crosslinker and the distance between the amino acids which were connected by the crosslinker. (in Ångström)" calc_method = staticmethod(monomer_validation) - plot_method = staticmethod(monomer_diagrams) + # plot_method = staticmethod(monomer_diagrams) def create_form(self): return Form(label="Ångström Deviation - Monomer", input_fields=[]) From 24eb5143e8c82e147b238db7d1ae884af320709b Mon Sep 17 00:00:00 2001 From: jorisfu Date: Mon, 11 May 2026 15:13:08 +0200 Subject: [PATCH 02/33] feat: add trivial PAE based validation --- backend/protzilla/constants/option_types.py | 4 ++ .../data_analysis/crosslinking_validation.py | 67 ++++++++++++++++--- backend/protzilla/methods/data_analysis.py | 17 ++++- 3 files changed, 78 insertions(+), 10 deletions(-) diff --git a/backend/protzilla/constants/option_types.py b/backend/protzilla/constants/option_types.py index 678efe9cb..3d4cffef7 100644 --- a/backend/protzilla/constants/option_types.py +++ b/backend/protzilla/constants/option_types.py @@ -59,6 +59,10 @@ class PValueColumnName(StrEnum): protein_id = "Protein ID" ptm = "PTM" +class CrosslinkingValidationCriterion(Enum): + manual_bounds = "Manual Bounds (set below)" + max_pae = "CL length +/- maximum PAE between sites" + min_pae = "CL length +/- minimum PAE between sites" FC_SIGNIFICANCE_COLUMNS = ["Protein ID", "fc_z_score", "fc_significance"] CORRECTED_P_VALUES_COLUMNS = [ diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 6b18f2eab..9031e4564 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -2,6 +2,9 @@ import ast import math +from typing import TYPE_CHECKING + +from backend.protzilla.constants.option_types import CrosslinkingValidationCriterion import pandas as pd import numpy as np import re @@ -355,7 +358,9 @@ def monomer_validation( crosslinker_information: dict[str, list[float]], cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, + pae_df: pd.DataFrame, plddt_df: pd.DataFrame, + validation_criterion: CrosslinkingValidationCriterion ) -> dict: """ Validates crosslinking data for a monomeric protein structure by checking @@ -367,6 +372,8 @@ def monomer_validation( allowed distance boundaries (e.g., [min_dist, max_dist]). :param cif_df: DataFrame containing mmCIF information. :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. + :param pae_df: DataFrame containing AlphaFold PAE data. + :param plddt_df: DataFrame containing AlphaFold pLDDT data. :return: A dictionary containing the validation results and distance metrics. """ protein_id = structure_metadata_df["uniprot_accession"].iloc[0] @@ -377,15 +384,19 @@ def monomer_validation( structure_metadata_df=structure_metadata_df, cif_df=cif_df, amino_acid_sequences_df=amino_acid_sequences_df, + pae_df=pae_df, plddt_df=plddt_df, valid_ids=valid_ids, id_column_name="_atom_site.pdbx_sifts_xref_db_acc", structures_to_validate=[protein_id], + validation_criterion=validation_criterion ) + def monomer_validation_with_pae(): pass + def get_protein_id_from_sequence(amino_acid_sequences_df, target_sequence): """ Finds the Protein ID(s) for a given exact protein sequence. @@ -490,10 +501,12 @@ def validate_with_angstrom_deviation( structure_metadata_df: pd.DataFrame, cif_df: pd.DataFrame, plddt_df: pd.DataFrame, + pae_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, valid_ids: dict, id_column_name: str, structures_to_validate: list, + validation_criterion: CrosslinkingValidationCriterion ) -> dict: """ Validates crosslinks by comparing the crosslinker lengths with the distances between the linked @@ -506,6 +519,7 @@ def validate_with_angstrom_deviation( [crosslinker_length, upper_accepted_deviation, lower_accepted_deviation]. :param cif_df: DataFrame containing CIF information (predicted coordinates of all the protein's atoms). :param plddt_df: DataFrame containing the local AlphaFold pLDDT values for each residue. + :param pae_df: DataFrame containing the PAE values for each residue pair. :param amino_acid_sequences_df: Dataframe that contains all known amino acid sequences. :param valid_ids: Dictionary mapping protein IDs to their valid chain/entity identifiers in the CIF data. :param id_column_name: The column name in the cif_df to use for matching against valid_ids. @@ -555,6 +569,9 @@ def validate_with_angstrom_deviation( relevant_crosslinks_df, amino_acid_sequences_df ) + pae_string = str(pae_df["predicted_aligned_error"].iloc[0]) + pae_matrix = np.array(ast.literal_eval(pae_string)) + def check_crosslink(crosslink: pd.Series) -> pd.Series: protein_id1 = crosslink.Protein_id1 protein_id2 = crosslink.Protein_id2 @@ -564,8 +581,19 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: protein_sequence2 = get_protein_sequence_from_df( amino_acid_sequences_df=amino_acid_sequences_df, protein_id=protein_id2 ) - plddt_at_position1 = plddt_df.query("residueNumber == @crosslink.crosslinker_position1").iloc[0]["confidenceScore"] - plddt_at_position2 = plddt_df.query("residueNumber == @crosslink.crosslinker_position2").iloc[0]["confidenceScore"] + plddt_at_position1 = plddt_df.query( + "residueNumber == @crosslink.crosslinker_position1" + ).iloc[0]["confidenceScore"] + plddt_at_position2 = plddt_df.query( + "residueNumber == @crosslink.crosslinker_position2" + ).iloc[0]["confidenceScore"] + + pae_x_position1 = pae_matrix[ + crosslink.crosslinker_position1, crosslink.crosslinker_position2 + ] # Using position1 as scored residue + pae_x_position2 = pae_matrix[ + crosslink.crosslinker_position2, crosslink.crosslinker_position1 + ] # Using position2 as scored residue predicted_distance = get_distance_between_two_amino_acids_in_angstrom( amino_acid_position1=crosslink.crosslinker_position1, @@ -587,13 +615,30 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: f"Missing required information regarding crosslinker length " f"and/or accepted deviation for crosslinker '{crosslink.Crosslinker}'." ) - # Fallback to default deviation bounds when not explicitly provided - accepted_distance_lower_bound = crosslinker_length - ( - accepted_deviation_lower_bound or crosslinker_length - ) - accepted_distance_upper_bound = ( - accepted_deviation_upper_bound or float("inf") - ) + crosslinker_length + + accepted_distance_lower_bound: float = 0.0 + accepted_distance_upper_bound: float = 0.0 + + match validation_criterion: + case CrosslinkingValidationCriterion.manual_bounds.value: + # Fallback to default deviation bounds when not explicitly provided + accepted_distance_lower_bound = crosslinker_length - ( + accepted_deviation_lower_bound or crosslinker_length + ) + accepted_distance_upper_bound = ( + accepted_deviation_upper_bound or float("inf") + ) + crosslinker_length + + + case CrosslinkingValidationCriterion.max_pae.value: + pae_tolerance = max(pae_x_position1, pae_x_position2) + accepted_distance_lower_bound = float(max(crosslinker_length - pae_tolerance, 0.0)) + accepted_distance_upper_bound = float(crosslinker_length + pae_tolerance) + + case CrosslinkingValidationCriterion.min_pae.value: + pae_tolerance = min(pae_x_position1, pae_x_position2) + accepted_distance_lower_bound = float(max(crosslinker_length - pae_tolerance, 0.0)) + accepted_distance_upper_bound = float(crosslinker_length + pae_tolerance) valid = ( accepted_distance_lower_bound @@ -609,6 +654,8 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: "crosslinker_position2": crosslink.crosslinker_position2, "plddt_at_position1": plddt_at_position1, "plddt_at_position2": plddt_at_position2, + "pae_x_position1": pae_x_position1, + "pae_x_position2": pae_x_position2, } ) @@ -620,6 +667,8 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: "crosslinker_position2", "plddt_at_position1", "plddt_at_position2", + "pae_x_position1", + "pae_x_position2", ] relevant_crosslinks_df["crosslinker_position1"] = relevant_crosslinks_df[ diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index 5d55d5fe8..2abf128f2 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -3,6 +3,7 @@ import ast from backend.protzilla.constants.option_types import ( + CrosslinkingValidationCriterion, LogBaseWithNoneType, SimpleImputerStrategyType, ) @@ -70,6 +71,7 @@ MultiSelectField, NumberField, TextField, + FormDivider, ) from backend.protzilla.steps import Step, Section from backend.protzilla.step_manager import StepManager @@ -2436,7 +2438,20 @@ class CrosslinkingValidationWithAngstromDeviation( # plot_method = staticmethod(monomer_diagrams) def create_form(self): - return Form(label="Ångström Deviation - Monomer", input_fields=[]) + return Form( + label="Ångström Deviation For Monomer Structures", + input_fields=[ + DropdownField( + name="validation_criterion", + label="Validation criterion", + options=CrosslinkingValidationCriterion, + value=CrosslinkingValidationCriterion.manual_bounds, + ), + FormDivider( + label="Crosslinker lengths and bounds", + ), + ], + ) class CrosslinkingValidationWithAngstromDeviationForMultimer( From 1ddc1039acfac0436042b8f77c8a7d92f4f5391d Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 12 May 2026 13:39:49 +0200 Subject: [PATCH 03/33] feat: trivial plDDT based validation --- backend/protzilla/constants/option_types.py | 3 + .../data_analysis/crosslinking_validation.py | 59 ++++++++++++++++--- backend/protzilla/methods/data_analysis.py | 2 +- 3 files changed, 54 insertions(+), 10 deletions(-) diff --git a/backend/protzilla/constants/option_types.py b/backend/protzilla/constants/option_types.py index 3d4cffef7..30b2a61bc 100644 --- a/backend/protzilla/constants/option_types.py +++ b/backend/protzilla/constants/option_types.py @@ -59,10 +59,13 @@ class PValueColumnName(StrEnum): protein_id = "Protein ID" ptm = "PTM" + class CrosslinkingValidationCriterion(Enum): manual_bounds = "Manual Bounds (set below)" max_pae = "CL length +/- maximum PAE between sites" min_pae = "CL length +/- minimum PAE between sites" + plddt_adjusted = "plDDT adjusted" + FC_SIGNIFICANCE_COLUMNS = ["Protein ID", "fc_z_score", "fc_significance"] CORRECTED_P_VALUES_COLUMNS = [ diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 9031e4564..68df6da46 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -2,7 +2,7 @@ import ast import math -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from backend.protzilla.constants.option_types import CrosslinkingValidationCriterion import pandas as pd @@ -360,7 +360,7 @@ def monomer_validation( amino_acid_sequences_df: pd.DataFrame, pae_df: pd.DataFrame, plddt_df: pd.DataFrame, - validation_criterion: CrosslinkingValidationCriterion + validation_criterion: CrosslinkingValidationCriterion, ) -> dict: """ Validates crosslinking data for a monomeric protein structure by checking @@ -389,7 +389,7 @@ def monomer_validation( valid_ids=valid_ids, id_column_name="_atom_site.pdbx_sifts_xref_db_acc", structures_to_validate=[protein_id], - validation_criterion=validation_criterion + validation_criterion=validation_criterion, ) @@ -506,7 +506,7 @@ def validate_with_angstrom_deviation( valid_ids: dict, id_column_name: str, structures_to_validate: list, - validation_criterion: CrosslinkingValidationCriterion + validation_criterion: CrosslinkingValidationCriterion, ) -> dict: """ Validates crosslinks by comparing the crosslinker lengths with the distances between the linked @@ -629,16 +629,57 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: accepted_deviation_upper_bound or float("inf") ) + crosslinker_length - case CrosslinkingValidationCriterion.max_pae.value: pae_tolerance = max(pae_x_position1, pae_x_position2) - accepted_distance_lower_bound = float(max(crosslinker_length - pae_tolerance, 0.0)) - accepted_distance_upper_bound = float(crosslinker_length + pae_tolerance) + accepted_distance_lower_bound = float( + max(crosslinker_length - pae_tolerance, 0.0) + ) + accepted_distance_upper_bound = float( + crosslinker_length + pae_tolerance + ) case CrosslinkingValidationCriterion.min_pae.value: pae_tolerance = min(pae_x_position1, pae_x_position2) - accepted_distance_lower_bound = float(max(crosslinker_length - pae_tolerance, 0.0)) - accepted_distance_upper_bound = float(crosslinker_length + pae_tolerance) + accepted_distance_lower_bound = float( + max(crosslinker_length - pae_tolerance, 0.0) + ) + accepted_distance_upper_bound = float( + crosslinker_length + pae_tolerance + ) + + case CrosslinkingValidationCriterion.plddt_adjusted.value: + cl_half = crosslinker_length / 2 + + get_plddt_factor: Callable[[float], float] = lambda plddt: 1 - ( + plddt / 100 + ) + + # Strict mode: plDDT of 0 (factor 1) allows +/- half length for each half + # get_cl_half_tolerated_length_range: Callable[ + # [float, float], tuple[float, float] + # ] = lambda cl_half, plddt_factor: ( + # cl_half * (1 - plddt_factor), + # cl_half * (1 + plddt_factor), + # ) + + # Less strict mode: plDDT of 0 (factor 1) allows +/- total CL length for each half + # TODO: The calculations below get ugly when plDDT < 50 because we'd get negative lengths + get_cl_half_tolerated_length_range: Callable[ + [float, float], tuple[float, float] + ] = lambda cl_half, plddt_factor: ( + cl_half * 2 * (1 - plddt_factor), + cl_half * 2 * (1 + plddt_factor), + ) + + plddt_factor_pos1 = get_plddt_factor(plddt_at_position1) + plddt_factor_pos2 = get_plddt_factor(plddt_at_position2) + + cl_half1_min, cl_half1_max = get_cl_half_tolerated_length_range(cl_half, plddt_factor_pos1) + cl_half2_min, cl_half2_max = get_cl_half_tolerated_length_range(cl_half, plddt_factor_pos2) + + accepted_distance_lower_bound = cl_half1_min + cl_half2_min + accepted_distance_upper_bound = cl_half1_max + cl_half2_max + valid = ( accepted_distance_lower_bound diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index 2abf128f2..ed84865b2 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -2449,7 +2449,7 @@ def create_form(self): ), FormDivider( label="Crosslinker lengths and bounds", - ), + ), ], ) From 77498d5f230c1a978e074eefab45a2064ca33511 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Fri, 15 May 2026 11:12:07 +0200 Subject: [PATCH 04/33] fix: broken formula --- .../data_analysis/crosslinking_validation.py | 37 +++++-------------- 1 file changed, 9 insertions(+), 28 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 68df6da46..d4715778f 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -581,12 +581,12 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: protein_sequence2 = get_protein_sequence_from_df( amino_acid_sequences_df=amino_acid_sequences_df, protein_id=protein_id2 ) - plddt_at_position1 = plddt_df.query( + plddt_at_position1 = float(plddt_df.query( "residueNumber == @crosslink.crosslinker_position1" - ).iloc[0]["confidenceScore"] - plddt_at_position2 = plddt_df.query( + ).iloc[0]["confidenceScore"]) + plddt_at_position2 = float(plddt_df.query( "residueNumber == @crosslink.crosslinker_position2" - ).iloc[0]["confidenceScore"] + ).iloc[0]["confidenceScore"]) pae_x_position1 = pae_matrix[ crosslink.crosslinker_position1, crosslink.crosslinker_position2 @@ -648,38 +648,19 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: ) case CrosslinkingValidationCriterion.plddt_adjusted.value: - cl_half = crosslinker_length / 2 - get_plddt_factor: Callable[[float], float] = lambda plddt: 1 - ( plddt / 100 ) - # Strict mode: plDDT of 0 (factor 1) allows +/- half length for each half - # get_cl_half_tolerated_length_range: Callable[ - # [float, float], tuple[float, float] - # ] = lambda cl_half, plddt_factor: ( - # cl_half * (1 - plddt_factor), - # cl_half * (1 + plddt_factor), - # ) - - # Less strict mode: plDDT of 0 (factor 1) allows +/- total CL length for each half - # TODO: The calculations below get ugly when plDDT < 50 because we'd get negative lengths - get_cl_half_tolerated_length_range: Callable[ - [float, float], tuple[float, float] - ] = lambda cl_half, plddt_factor: ( - cl_half * 2 * (1 - plddt_factor), - cl_half * 2 * (1 + plddt_factor), - ) - plddt_factor_pos1 = get_plddt_factor(plddt_at_position1) plddt_factor_pos2 = get_plddt_factor(plddt_at_position2) - cl_half1_min, cl_half1_max = get_cl_half_tolerated_length_range(cl_half, plddt_factor_pos1) - cl_half2_min, cl_half2_max = get_cl_half_tolerated_length_range(cl_half, plddt_factor_pos2) - - accepted_distance_lower_bound = cl_half1_min + cl_half2_min - accepted_distance_upper_bound = cl_half1_max + cl_half2_max + max_half_tolerance = crosslinker_length # Note: This is quite lenient + tolerance_pos1 = plddt_factor_pos1 * max_half_tolerance + tolerance_pos2 = plddt_factor_pos2 * max_half_tolerance + accepted_distance_lower_bound = max(crosslinker_length - tolerance_pos1 - tolerance_pos2, 0) + accepted_distance_upper_bound = crosslinker_length + tolerance_pos1 + tolerance_pos2 valid = ( accepted_distance_lower_bound From 0ab6fa010729e8169d8710c933229643aef90ed8 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Fri, 15 May 2026 14:24:56 +0200 Subject: [PATCH 05/33] fix monomer validation test --- .../data_analysis/crosslinking_validation.py | 30 ++++++++++++------- .../test_crosslinking_validation.py | 28 +++++++++++++++-- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index d4715778f..8f4952412 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -581,18 +581,22 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: protein_sequence2 = get_protein_sequence_from_df( amino_acid_sequences_df=amino_acid_sequences_df, protein_id=protein_id2 ) - plddt_at_position1 = float(plddt_df.query( - "residueNumber == @crosslink.crosslinker_position1" - ).iloc[0]["confidenceScore"]) - plddt_at_position2 = float(plddt_df.query( - "residueNumber == @crosslink.crosslinker_position2" - ).iloc[0]["confidenceScore"]) + plddt_at_position1 = float( + plddt_df.query("residueNumber == @crosslink.crosslinker_position1").iloc[0][ + "confidenceScore" + ] + ) + plddt_at_position2 = float( + plddt_df.query("residueNumber == @crosslink.crosslinker_position2").iloc[0][ + "confidenceScore" + ] + ) pae_x_position1 = pae_matrix[ - crosslink.crosslinker_position1, crosslink.crosslinker_position2 + crosslink.crosslinker_position1 - 1, crosslink.crosslinker_position2 - 1 ] # Using position1 as scored residue pae_x_position2 = pae_matrix[ - crosslink.crosslinker_position2, crosslink.crosslinker_position1 + crosslink.crosslinker_position2 - 1, crosslink.crosslinker_position1 - 1 ] # Using position2 as scored residue predicted_distance = get_distance_between_two_amino_acids_in_angstrom( @@ -655,12 +659,16 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: plddt_factor_pos1 = get_plddt_factor(plddt_at_position1) plddt_factor_pos2 = get_plddt_factor(plddt_at_position2) - max_half_tolerance = crosslinker_length # Note: This is quite lenient + max_half_tolerance = crosslinker_length # Note: This is quite lenient tolerance_pos1 = plddt_factor_pos1 * max_half_tolerance tolerance_pos2 = plddt_factor_pos2 * max_half_tolerance - accepted_distance_lower_bound = max(crosslinker_length - tolerance_pos1 - tolerance_pos2, 0) - accepted_distance_upper_bound = crosslinker_length + tolerance_pos1 + tolerance_pos2 + accepted_distance_lower_bound = max( + crosslinker_length - tolerance_pos1 - tolerance_pos2, 0 + ) + accepted_distance_upper_bound = ( + crosslinker_length + tolerance_pos1 + tolerance_pos2 + ) valid = ( accepted_distance_lower_bound diff --git a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py index f62a1f716..f7456aa37 100644 --- a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py +++ b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py @@ -1,4 +1,5 @@ import pandas as pd +from backend.protzilla.constants.option_types import CrosslinkingValidationCriterion import pytest import logging from unittest.mock import patch, MagicMock @@ -33,7 +34,9 @@ (6.01, False), # outside bounds ], ) -def test_validate_with_angstrom_deviation(distance, expected): +def test_monomer_validation(distance, expected): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å ± 1 Å + # Fake AlphaFold Data with chain IDs cif_df = pd.DataFrame( { @@ -68,10 +71,24 @@ def test_validate_with_angstrom_deviation(distance, expected): {"entry_id": ["test"], "uniprot_accession": ["P12345"]} ) - crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Länge 5 Å ± 1 Å valid_ids = {"P12345": ["P12345"]} structures_to_validate = ["P12345"] + pae_df_noerror = pd.DataFrame( + { + "predicted_aligned_error": ["[[0, 0], [0, 0]]"], + "max_predicted_aligned_error": [31.75], + } + ) + + plddt_df_noerror = pd.DataFrame( + { + "residueNumber": [1, 2], + "confidenceScore": [100, 100], + # confidenceCategory is not required + } + ) + result = validate_with_angstrom_deviation( crosslinking_df=crosslinking_df, structure_metadata_df=structure_metadata_df, @@ -81,9 +98,12 @@ def test_validate_with_angstrom_deviation(distance, expected): valid_ids=valid_ids, id_column_name="_atom_site.pdbx_sifts_xref_db_acc", structures_to_validate=structures_to_validate, + pae_df=pae_df_noerror, + plddt_df=plddt_df_noerror, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) - df = result["crosslinking_result_df"] + df: pd.DataFrame = result["crosslinking_result_df"] assert "alphafold_distance" in df.columns assert "valid_crosslink" in df.columns @@ -94,6 +114,8 @@ def test_validate_with_angstrom_deviation(distance, expected): assert df.loc[0, "valid_crosslink"] == expected assert df.loc[0, "link_type"] == "intra" + # Validation with PAE + def test_modify_form_creates_crosslinker_fields(): crosslinking_df = pd.DataFrame({"Crosslinker": ["DSS", "BS3", "DSS"]}) From 291bd7d314ebe27e0697422e09d174878772d847 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Mon, 18 May 2026 11:58:15 +0200 Subject: [PATCH 06/33] refactor: expose PAE as matrix for monomers --- backend/protzilla/constants/data_types.py | 2 +- .../data_analysis/crosslinking_validation.py | 13 +++++-------- .../alphafold_protein_structure_load.py | 18 ++++++++++++++++++ backend/protzilla/methods/importing.py | 4 ++-- .../app/run-screen/node-editor/StepNode.tsx | 2 +- 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/backend/protzilla/constants/data_types.py b/backend/protzilla/constants/data_types.py index 2665cfe3e..94f4db695 100644 --- a/backend/protzilla/constants/data_types.py +++ b/backend/protzilla/constants/data_types.py @@ -23,7 +23,7 @@ class DataKey(StrEnum): GENE_MAPPING_DF = "gene_mapping_df" CIF_DF = "cif_df" AMINO_ACID_SEQUENCES_DF = "amino_acid_sequences_df" - PAE_DF = "pae_df" # pae = predicted aligned error + PAE_MATRIX = "pae_matrix" # pae = predicted aligned error PLDDT_DF = "plddt_df" # plddt = predicted local distance difference test CROSSLINKING_DF = "crosslinking_df" CONFIDENCE_DF = "confidence_df" diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 8f4952412..d4c120a78 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -358,7 +358,7 @@ def monomer_validation( crosslinker_information: dict[str, list[float]], cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, - pae_df: pd.DataFrame, + pae_matrix: np.ndarray[tuple[int, int]], plddt_df: pd.DataFrame, validation_criterion: CrosslinkingValidationCriterion, ) -> dict: @@ -372,7 +372,7 @@ def monomer_validation( allowed distance boundaries (e.g., [min_dist, max_dist]). :param cif_df: DataFrame containing mmCIF information. :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. - :param pae_df: DataFrame containing AlphaFold PAE data. + :param pae_matrix: NumPy 2D array containing AlphaFold PAE data. :param plddt_df: DataFrame containing AlphaFold pLDDT data. :return: A dictionary containing the validation results and distance metrics. """ @@ -384,7 +384,7 @@ def monomer_validation( structure_metadata_df=structure_metadata_df, cif_df=cif_df, amino_acid_sequences_df=amino_acid_sequences_df, - pae_df=pae_df, + pae_matrix=pae_matrix, plddt_df=plddt_df, valid_ids=valid_ids, id_column_name="_atom_site.pdbx_sifts_xref_db_acc", @@ -501,7 +501,7 @@ def validate_with_angstrom_deviation( structure_metadata_df: pd.DataFrame, cif_df: pd.DataFrame, plddt_df: pd.DataFrame, - pae_df: pd.DataFrame, + pae_matrix: np.ndarray[tuple[int, int]], amino_acid_sequences_df: pd.DataFrame, valid_ids: dict, id_column_name: str, @@ -519,7 +519,7 @@ def validate_with_angstrom_deviation( [crosslinker_length, upper_accepted_deviation, lower_accepted_deviation]. :param cif_df: DataFrame containing CIF information (predicted coordinates of all the protein's atoms). :param plddt_df: DataFrame containing the local AlphaFold pLDDT values for each residue. - :param pae_df: DataFrame containing the PAE values for each residue pair. + :param pae_matrix: NumPy 2D array containing the PAE values for each residue pair. :param amino_acid_sequences_df: Dataframe that contains all known amino acid sequences. :param valid_ids: Dictionary mapping protein IDs to their valid chain/entity identifiers in the CIF data. :param id_column_name: The column name in the cif_df to use for matching against valid_ids. @@ -569,9 +569,6 @@ def validate_with_angstrom_deviation( relevant_crosslinks_df, amino_acid_sequences_df ) - pae_string = str(pae_df["predicted_aligned_error"].iloc[0]) - pae_matrix = np.array(ast.literal_eval(pae_string)) - def check_crosslink(crosslink: pd.Series) -> pd.Series: protein_id1 = crosslink.Protein_id1 protein_id2 = crosslink.Protein_id2 diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index 7519c1d9b..5a7caecf8 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -11,6 +11,8 @@ from datetime import datetime, timezone import gemmi import pandas as pd +import numpy as np +import ast import requests import re @@ -426,8 +428,15 @@ def fetch_alphafold_protein_structure( messages.append(dict(level=logging.WARNING, msg=message)) data_for_visualization = None + pae_string = str(df_dict["pae_df"]["predicted_aligned_error"].iloc[0]) + pae_matrix = np.array(ast.literal_eval(pae_string)) + del df_dict["pae_df"] + return dict( **df_dict, + pae_matrix=OutputItem( + output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix + ), messages=messages, visualization=OutputItem( output_type=OutputType.VISUALIZATION, value=data_for_visualization @@ -702,12 +711,21 @@ def get_monomer_structure_dfs(entry_id: str) -> dict[str, Any]: "amino_acid_sequences_df": amino_acid_sequences_df, } check_success_of_get_df(entry_id=entry_id, df_dict=df_dict, messages=messages) + data_for_visualization = { "structure_entry_id": entry_id, "cif_df": cif_df, } + + pae_string = str(df_dict["pae_df"]["predicted_aligned_error"].iloc[0]) + pae_matrix = np.array(ast.literal_eval(pae_string)) + del df_dict["pae_df"] + return dict( **df_dict, + pae_matrix=OutputItem( + output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix + ), messages=messages, visualization=OutputItem( output_type=OutputType.VISUALIZATION, value=data_for_visualization diff --git a/backend/protzilla/methods/importing.py b/backend/protzilla/methods/importing.py index c5b887d00..49cc9a6ed 100644 --- a/backend/protzilla/methods/importing.py +++ b/backend/protzilla/methods/importing.py @@ -438,9 +438,9 @@ class AlphaFoldPredictionLoad(ImportingStep): output_keys = [ DataKey.STRUCTURE_METADATA_DF, DataKey.CIF_DF, - DataKey.PAE_DF, DataKey.PLDDT_DF, DataKey.AMINO_ACID_SEQUENCES_DF, + DataKey.PAE_MATRIX, ] plot_method = None @@ -503,7 +503,7 @@ class ImportMonomerStructurePredictionFromDisk(ImportingStep): output_keys = [ DataKey.STRUCTURE_METADATA_DF, DataKey.CIF_DF, - DataKey.PAE_DF, + DataKey.PAE_MATRIX, DataKey.PLDDT_DF, DataKey.AMINO_ACID_SEQUENCES_DF, ] diff --git a/frontend/src/components/app/run-screen/node-editor/StepNode.tsx b/frontend/src/components/app/run-screen/node-editor/StepNode.tsx index a83c617fd..7adb84d93 100644 --- a/frontend/src/components/app/run-screen/node-editor/StepNode.tsx +++ b/frontend/src/components/app/run-screen/node-editor/StepNode.tsx @@ -53,7 +53,7 @@ const DATA_TYPE_ICON_MAP: Partial> = { full_data_df: handleFullDataIcon, gene_mapping_df: handleDnaIcon, metadata_df: handleMetadataIcon, - pae_df: handlePaeIcon, + pae_matrix: handlePaeIcon, peptide_df: handlePeptidesIcon, plddt_df: handlePlddtIcon, protein_df: handleProteinIcon, From 23d5b60c543f16371cb6e6bba98877f0bbb9fe54 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Mon, 18 May 2026 18:31:08 +0200 Subject: [PATCH 07/33] feat: PAE for multimers --- .../data_analysis/crosslinking_validation.py | 71 ++++++++++++++----- .../alphafold_protein_structure_load.py | 36 +++++++++- backend/protzilla/methods/data_analysis.py | 14 +++- backend/protzilla/methods/importing.py | 1 + 4 files changed, 101 insertions(+), 21 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index d4c120a78..b1e77b4bc 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -2,8 +2,11 @@ import ast import math +from multiprocessing.sharedctypes import Value from typing import TYPE_CHECKING, Callable +from numpy.testing import assert_ + from backend.protzilla.constants.option_types import CrosslinkingValidationCriterion import pandas as pd import numpy as np @@ -459,6 +462,8 @@ def multimer_validation( cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, job_request_df: pd.DataFrame, + pae_matrix: np.ndarray[tuple[int, int]], + validation_criterion: CrosslinkingValidationCriterion, ) -> dict: """ Validates crosslinking data for a multimeric protein complex by checking @@ -476,6 +481,7 @@ def multimer_validation( :param cif_df: DataFrame containing mmCIF information. :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. :param job_request_df: DataFrame containing the loaded AlphaFold job request JSON. + :param pae_matrix: NumPy 2D array containing the PAE values for each residue pair. :return: A dictionary containing the validation results and distance metrics. """ valid_ids = get_valid_ids_per_protein_id_from_job_request( @@ -492,6 +498,8 @@ def multimer_validation( valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + pae_matrix=pae_matrix, + validation_criterion=validation_criterion, ) @@ -500,13 +508,13 @@ def validate_with_angstrom_deviation( crosslinker_information: dict[str, list[float]], structure_metadata_df: pd.DataFrame, cif_df: pd.DataFrame, - plddt_df: pd.DataFrame, - pae_matrix: np.ndarray[tuple[int, int]], amino_acid_sequences_df: pd.DataFrame, valid_ids: dict, id_column_name: str, structures_to_validate: list, validation_criterion: CrosslinkingValidationCriterion, + plddt_df: pd.DataFrame | None = None, + pae_matrix: np.ndarray[tuple[int, int]] | None = None, ) -> dict: """ Validates crosslinks by comparing the crosslinker lengths with the distances between the linked @@ -578,23 +586,39 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: protein_sequence2 = get_protein_sequence_from_df( amino_acid_sequences_df=amino_acid_sequences_df, protein_id=protein_id2 ) - plddt_at_position1 = float( - plddt_df.query("residueNumber == @crosslink.crosslinker_position1").iloc[0][ - "confidenceScore" - ] - ) - plddt_at_position2 = float( - plddt_df.query("residueNumber == @crosslink.crosslinker_position2").iloc[0][ - "confidenceScore" - ] - ) - pae_x_position1 = pae_matrix[ - crosslink.crosslinker_position1 - 1, crosslink.crosslinker_position2 - 1 - ] # Using position1 as scored residue - pae_x_position2 = pae_matrix[ - crosslink.crosslinker_position2 - 1, crosslink.crosslinker_position1 - 1 - ] # Using position2 as scored residue + def get_site_plddts(): + if plddt_df is None: + return np.nan, np.nan + + plddt_at_position1 = float( + plddt_df.query("residueNumber == @crosslink.crosslinker_position1").iloc[0][ + "confidenceScore" + ] + ) + plddt_at_position2 = float( + plddt_df.query("residueNumber == @crosslink.crosslinker_position2").iloc[0][ + "confidenceScore" + ] + ) + + return plddt_at_position1, plddt_at_position2 + + def get_paes(): + if pae_matrix is None: + return np.nan, np.nan + + pae_x_position1 = pae_matrix[ + crosslink.crosslinker_position1 - 1, crosslink.crosslinker_position2 - 1 + ] # Using position1 as scored residue + pae_x_position2 = pae_matrix[ + crosslink.crosslinker_position2 - 1, crosslink.crosslinker_position1 - 1 + ] # Using position2 as scored residue + + return pae_x_position1, pae_x_position2 + + plddt_at_position1, plddt_at_position2 = get_site_plddts() + pae_x_position1, pae_x_position2 = get_paes() predicted_distance = get_distance_between_two_amino_acids_in_angstrom( amino_acid_position1=crosslink.crosslinker_position1, @@ -631,6 +655,9 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: ) + crosslinker_length case CrosslinkingValidationCriterion.max_pae.value: + if np.isnan(pae_x_position1) or np.isnan(pae_x_position2): + raise ValueError("No PAE data given.") + pae_tolerance = max(pae_x_position1, pae_x_position2) accepted_distance_lower_bound = float( max(crosslinker_length - pae_tolerance, 0.0) @@ -640,6 +667,9 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: ) case CrosslinkingValidationCriterion.min_pae.value: + if np.isnan(pae_x_position1) or np.isnan(pae_x_position2): + raise ValueError("No PAE data given.") + pae_x_position1, pae_x_position2 = get_paes() pae_tolerance = min(pae_x_position1, pae_x_position2) accepted_distance_lower_bound = float( max(crosslinker_length - pae_tolerance, 0.0) @@ -649,10 +679,15 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: ) case CrosslinkingValidationCriterion.plddt_adjusted.value: + if np.isnan(plddt_at_position1) or np.isnan(plddt_at_position2): + raise ValueError("No pLDDT data given.") + get_plddt_factor: Callable[[float], float] = lambda plddt: 1 - ( plddt / 100 ) + plddt_at_position1, plddt_at_position2 = get_site_plddts() + plddt_factor_pos1 = get_plddt_factor(plddt_at_position1) plddt_factor_pos2 = get_plddt_factor(plddt_at_position2) diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index 5a7caecf8..052df61e2 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -21,7 +21,7 @@ from backend.protzilla.importing.fasta_import import fasta_import from backend.protzilla.networking import download_file_from_url from backend.protzilla.utilities.utilities import copy_file_to_directory -from backend.protzilla.steps import OutputItem, OutputType +from backend.protzilla.steps import Output, OutputItem, OutputType def get_monomer_metadata_df() -> pd.DataFrame: @@ -733,6 +733,31 @@ def get_monomer_structure_dfs(entry_id: str) -> dict[str, Any]: ) +def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: + """ + Extracts certain data from a full_data_df, deletes the extracted columns + and returns the "remaining" full_data_df as well as the extracted data. + + :param full_data_df: The AlphaFold3 full_data_df + :return dict: + - "full_data_df": The updated reduced full_data_df + - "pae_matrix": Numpy matrix with the PAE values for each residue pair + """ + + # Construct plDDT dataframe + # TODO: Getting pLDDT from AlphaFold3 is a bit harder as its on a per-atom level + # rather than per-residue, so we'd need to extract it from the cif file + # (column _atom_site.B_iso_or_equiv, see https://github.com/google-deepmind/alphafold3/issues/330). + # Skipping this for now. + + pae_matrix = np.array(full_data_df["pae"].iloc[0]) + full_data_df = full_data_df.drop(columns=["pae"]) + + return dict( + full_data_df=full_data_df, + pae_matrix=pae_matrix, + ) + def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: """ Writes multimer structure data from disk of a specific entry ID into dataframes. @@ -816,9 +841,18 @@ def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: "structure_entry_id": entry_id, "cif_df": cif_df, } + + unwrapped_full_data = unwrap_full_data_df(df_dict["full_data_df"]) + df_dict["full_data_df"] = unwrapped_full_data["full_data_df"] + + pae_matrix = unwrapped_full_data["pae_matrix"] + return dict( **df_dict, messages=messages, + pae_matrix=OutputItem( + output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix + ), visualization=OutputItem( output_type=OutputType.VISUALIZATION, value=data_for_visualization ), diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index ed84865b2..cc6ae8493 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -2461,10 +2461,20 @@ class CrosslinkingValidationWithAngstromDeviationForMultimer( operation = "Crosslinking Validation" method_description = "Validates crosslinks between proteins based on the difference between the length of the crosslinker and the distance between the amino acids which were connected by the crosslinker. (in Ångström)" calc_method = staticmethod(multimer_validation) - plot_method = staticmethod(multimer_diagrams) + # plot_method = staticmethod(multimer_diagrams) def create_form(self): return Form( label="Ångström Deviation - Multimer", - input_fields=[], + input_fields=[ + DropdownField( + name="validation_criterion", + label="Validation criterion", + options=CrosslinkingValidationCriterion, + value=CrosslinkingValidationCriterion.manual_bounds, + ), + FormDivider( + label="Crosslinker lengths and bounds", + ), + ], ) diff --git a/backend/protzilla/methods/importing.py b/backend/protzilla/methods/importing.py index 49cc9a6ed..51a4a5716 100644 --- a/backend/protzilla/methods/importing.py +++ b/backend/protzilla/methods/importing.py @@ -617,6 +617,7 @@ class ImportMultimerStructurePredictionFromDisk(ImportingStep): DataKey.FULL_DATA_DF, DataKey.JOB_REQUEST_DF, DataKey.AMINO_ACID_SEQUENCES_DF, + DataKey.PAE_MATRIX, ] def create_form(self): From 1e0e30e53084cb1bd05ceaa478abcc6b1c684e1a Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 19 May 2026 10:35:55 +0200 Subject: [PATCH 08/33] feat: pLDDT for multimers --- .../data_analysis/crosslinking_validation.py | 9 +++---- .../alphafold_protein_structure_load.py | 26 +++++++++++++++++++ backend/protzilla/methods/data_analysis.py | 1 - 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index b1e77b4bc..73456bdba 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -395,11 +395,6 @@ def monomer_validation( validation_criterion=validation_criterion, ) - -def monomer_validation_with_pae(): - pass - - def get_protein_id_from_sequence(amino_acid_sequences_df, target_sequence): """ Finds the Protein ID(s) for a given exact protein sequence. @@ -509,7 +504,7 @@ def validate_with_angstrom_deviation( structure_metadata_df: pd.DataFrame, cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, - valid_ids: dict, + valid_ids: dict[str, list[int]], id_column_name: str, structures_to_validate: list, validation_criterion: CrosslinkingValidationCriterion, @@ -588,6 +583,7 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: ) def get_site_plddts(): + # TODO for multimers: get pLDDT from CIF if plddt_df is None: return np.nan, np.nan @@ -605,6 +601,7 @@ def get_site_plddts(): return plddt_at_position1, plddt_at_position2 def get_paes(): + # TODO for multimers: get correct PAE index (global index, not per-chain) if pae_matrix is None: return np.nan, np.nan diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index 052df61e2..1d1391aed 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -758,6 +758,30 @@ def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: pae_matrix=pae_matrix, ) +def get_plddt_from_cif(cif_df: pd.DataFrame): + """ + For use with multimers predicted using Alphafold3. + Returns per-residue pLDDT values for the predicted structure. + Note that sine AlphaFold3 uses per-atom pLDDT, we use the pLDDT for the CA atom. + See also https://github.com/google-deepmind/alphafold3/issues/330 + + :param cif_df: the cif_df holding the _atom_site table. + :return: DataFrame containing columns + "chainID", "residueNumber", "confidenceScore", "confidenceCategory" + """ + + filtered_cif_df = cif_df[cif_df["_atom_site.label_atom_id"] == "CA"] + filtered_cif_df = filtered_cif_df[["_atom_site.auth_asym_id", "_atom_site.label_seq_id", "_atom_site.B_iso_or_equiv"]] + filtered_cif_df = filtered_cif_df.rename(columns={ + "_atom_site.auth_asym_id": "chainID", + "_atom_site.label_seq_id": "residueNumber", + "_atom_site.B_iso_or_equiv": "confidenceScore", + }) + + # TODO: ConfidenceCategory maybe yes? + + return filtered_cif_df + def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: """ Writes multimer structure data from disk of a specific entry ID into dataframes. @@ -846,10 +870,12 @@ def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: df_dict["full_data_df"] = unwrapped_full_data["full_data_df"] pae_matrix = unwrapped_full_data["pae_matrix"] + plddt_df = get_plddt_from_cif(df_dict["cif_df"]) return dict( **df_dict, messages=messages, + plddt_df=plddt_df, pae_matrix=OutputItem( output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix ), diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index cc6ae8493..e90abb8ef 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -95,7 +95,6 @@ monomer_diagrams, multimer_diagrams, monomer_validation, - monomer_validation_with_pae, multimer_validation, ) from backend.protzilla.run import Run From eabd547923b8aa870cc54f7e1dc8e56b4df55302 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 19 May 2026 11:28:33 +0200 Subject: [PATCH 09/33] feat: add PAE/plDDT consistently to multimer imports --- .../data_analysis/crosslinking_validation.py | 3 +++ .../alphafold_protein_structure_load.py | 17 ++++++++++++++--- backend/protzilla/methods/importing.py | 3 +++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 73456bdba..fa56d2f05 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -457,6 +457,7 @@ def multimer_validation( cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, job_request_df: pd.DataFrame, + plddt_df: pd.DataFrame, pae_matrix: np.ndarray[tuple[int, int]], validation_criterion: CrosslinkingValidationCriterion, ) -> dict: @@ -476,6 +477,7 @@ def multimer_validation( :param cif_df: DataFrame containing mmCIF information. :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. :param job_request_df: DataFrame containing the loaded AlphaFold job request JSON. + :param plddt_df: DataFrame containing per-residue pLDDT values. :param pae_matrix: NumPy 2D array containing the PAE values for each residue pair. :return: A dictionary containing the validation results and distance metrics. """ @@ -494,6 +496,7 @@ def multimer_validation( id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, pae_matrix=pae_matrix, + plddt_df=plddt_df, validation_criterion=validation_criterion, ) diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index 1d1391aed..52ec1a31d 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -1014,13 +1014,24 @@ def upload_multimer_prediction( } if not any(df.empty for df in df_dict.values()): - success_msg = f"Successfully loaded AlphaFold data for entry '{entry_id}'" - logger.info(success_msg) - messages.append(dict(level=logging.INFO, msg=success_msg)) + + unwrapped_full_data = unwrap_full_data_df(df_dict["full_data_df"]) + df_dict["full_data_df"] = unwrapped_full_data["full_data_df"] + + pae_matrix=OutputItem( + output_type=OutputType.JOBLIB_ARTIFACT, value=unwrapped_full_data["pae_matrix"] + ) + df_dict["pae_matrix"] = pae_matrix + df_dict["plddt_df"] = get_plddt_from_cif(df_dict["cif_df"]) + data_for_visualization = { "structure_entry_id": entry_id, "cif_df": cif_df, } + + success_msg = f"Successfully loaded AlphaFold data for entry '{entry_id}'" + logger.info(success_msg) + messages.append(dict(level=logging.INFO, msg=success_msg)) else: message = f"Could not load AlphaFold data for entry '{entry_id}'" logger.warning(message) diff --git a/backend/protzilla/methods/importing.py b/backend/protzilla/methods/importing.py index 51a4a5716..321389ebd 100644 --- a/backend/protzilla/methods/importing.py +++ b/backend/protzilla/methods/importing.py @@ -540,6 +540,8 @@ class UploadMultimerPredictions(ImportingStep): DataKey.FULL_DATA_DF, DataKey.JOB_REQUEST_DF, DataKey.AMINO_ACID_SEQUENCES_DF, + DataKey.PAE_MATRIX, + DataKey.PLDDT_DF, ] def create_form(self): @@ -618,6 +620,7 @@ class ImportMultimerStructurePredictionFromDisk(ImportingStep): DataKey.JOB_REQUEST_DF, DataKey.AMINO_ACID_SEQUENCES_DF, DataKey.PAE_MATRIX, + DataKey.PLDDT_DF, ] def create_form(self): From bcb054f7ac57693cd53586863ceaa1a2593cc552 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 19 May 2026 15:11:55 +0200 Subject: [PATCH 10/33] feat: proper PAE validation for multimers --- .../data_analysis/crosslinking_validation.py | 76 +++++++++++++++---- .../alphafold_protein_structure_load.py | 55 ++++++++------ 2 files changed, 94 insertions(+), 37 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index fa56d2f05..7658a5aa9 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -395,6 +395,7 @@ def monomer_validation( validation_criterion=validation_criterion, ) + def get_protein_id_from_sequence(amino_acid_sequences_df, target_sequence): """ Finds the Protein ID(s) for a given exact protein sequence. @@ -450,6 +451,47 @@ def get_valid_ids_per_protein_id_from_job_request( return valid_ids +def get_global_residue_index( + position_within_protein: int, # 1-based index + chain_id: str, + cif_df: pd.DataFrame, +): + """ + For multimer PAE lookup: For a position within a given protein in a chain, + get the global 0-based residue index used to find that position in the PAE matrix. + + Note: This assumes that the order of AAs in the _atom_site table corresponds + to the order of residues in the pae matrix and thus the other residue-based tables in + the cif. + + :param position_within_protein: index of the amino acid within the protein (1-based) + :param chain_id: the chain ID of the protein within the complex + :param cif_df: DataFrame containing the _atom_site table of the complex structure + """ + + # Get table with only unique chain and sequence IDs and infer global index + index_lookup_df = ( + cif_df[["_atom_site.label_asym_id", "_atom_site.label_seq_id"]] + .drop_duplicates() + .reset_index(drop=True) + ) + index_lookup_df.reset_index(inplace=True) + + index_lookup_df = index_lookup_df[ + index_lookup_df["_atom_site.label_asym_id"] == chain_id + ] + index_lookup_df = index_lookup_df[ + index_lookup_df["_atom_site.label_seq_id"] == position_within_protein + ] + + if len(index_lookup_df) != 1: + raise ValueError( + "Invalid input: CIF contains multiple atoms mapped to same chain/sequence ID pair!" + ) + + return index_lookup_df["index"].iloc[0] + + def multimer_validation( crosslinking_df: pd.DataFrame, structure_metadata_df: pd.DataFrame, @@ -585,39 +627,45 @@ def check_crosslink(crosslink: pd.Series) -> pd.Series: amino_acid_sequences_df=amino_acid_sequences_df, protein_id=protein_id2 ) - def get_site_plddts(): - # TODO for multimers: get pLDDT from CIF + def get_site_plddts(crosslink: pd.Series): if plddt_df is None: return np.nan, np.nan plddt_at_position1 = float( - plddt_df.query("residueNumber == @crosslink.crosslinker_position1").iloc[0][ - "confidenceScore" - ] + plddt_df.query( + "residueNumber == @crosslink.crosslinker_position1 and " + + "chainID == @crosslink.Chain_id1" + ).iloc[0]["confidenceScore"] ) plddt_at_position2 = float( - plddt_df.query("residueNumber == @crosslink.crosslinker_position2").iloc[0][ - "confidenceScore" - ] + plddt_df.query( + "residueNumber == @crosslink.crosslinker_position2 and " + + "chainID == @crosslink.Chain_id2" + ).iloc[0]["confidenceScore"] ) return plddt_at_position1, plddt_at_position2 - + def get_paes(): - # TODO for multimers: get correct PAE index (global index, not per-chain) if pae_matrix is None: return np.nan, np.nan + pae_index_pos1 = get_global_residue_index( + crosslink.crosslinker_position1, crosslink.Chain_id1, cif_df + ) + pae_index_pos2 = get_global_residue_index( + crosslink.crosslinker_position2, crosslink.Chain_id2, cif_df + ) pae_x_position1 = pae_matrix[ - crosslink.crosslinker_position1 - 1, crosslink.crosslinker_position2 - 1 + pae_index_pos1, pae_index_pos2 ] # Using position1 as scored residue pae_x_position2 = pae_matrix[ - crosslink.crosslinker_position2 - 1, crosslink.crosslinker_position1 - 1 + pae_index_pos2, pae_index_pos1 ] # Using position2 as scored residue return pae_x_position1, pae_x_position2 - plddt_at_position1, plddt_at_position2 = get_site_plddts() + plddt_at_position1, plddt_at_position2 = get_site_plddts(crosslink) pae_x_position1, pae_x_position2 = get_paes() predicted_distance = get_distance_between_two_amino_acids_in_angstrom( @@ -686,8 +734,6 @@ def get_paes(): plddt / 100 ) - plddt_at_position1, plddt_at_position2 = get_site_plddts() - plddt_factor_pos1 = get_plddt_factor(plddt_at_position1) plddt_factor_pos2 = get_plddt_factor(plddt_at_position2) diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index 52ec1a31d..e2a6f069b 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -330,6 +330,9 @@ def handle_alphafold_files( if temp_dir is not None: shutil.rmtree(temp_dir, ignore_errors=True) + # For consistency with multimer pLDDT + plddt_df["chainID"] = "A" + return { "cif_df": cif_df, "pae_df": pae_df, @@ -434,9 +437,7 @@ def fetch_alphafold_protein_structure( return dict( **df_dict, - pae_matrix=OutputItem( - output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix - ), + pae_matrix=OutputItem(output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix), messages=messages, visualization=OutputItem( output_type=OutputType.VISUALIZATION, value=data_for_visualization @@ -703,6 +704,9 @@ def get_monomer_structure_dfs(entry_id: str) -> dict[str, Any]: logger.exception(msg) raise RuntimeError(msg) from e + # For consistency with multimer pLDDT + plddt_df["chainID"] = "A" + df_dict = { "structure_metadata_df": monomer_metadata_df, "cif_df": cif_df, @@ -723,9 +727,7 @@ def get_monomer_structure_dfs(entry_id: str) -> dict[str, Any]: return dict( **df_dict, - pae_matrix=OutputItem( - output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix - ), + pae_matrix=OutputItem(output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix), messages=messages, visualization=OutputItem( output_type=OutputType.VISUALIZATION, value=data_for_visualization @@ -745,19 +747,20 @@ def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: """ # Construct plDDT dataframe - # TODO: Getting pLDDT from AlphaFold3 is a bit harder as its on a per-atom level + # TODO: Getting pLDDT from AlphaFold3 is a bit harder as its on a per-atom level # rather than per-residue, so we'd need to extract it from the cif file # (column _atom_site.B_iso_or_equiv, see https://github.com/google-deepmind/alphafold3/issues/330). # Skipping this for now. pae_matrix = np.array(full_data_df["pae"].iloc[0]) full_data_df = full_data_df.drop(columns=["pae"]) - + return dict( full_data_df=full_data_df, pae_matrix=pae_matrix, ) + def get_plddt_from_cif(cif_df: pd.DataFrame): """ For use with multimers predicted using Alphafold3. @@ -766,22 +769,31 @@ def get_plddt_from_cif(cif_df: pd.DataFrame): See also https://github.com/google-deepmind/alphafold3/issues/330 :param cif_df: the cif_df holding the _atom_site table. - :return: DataFrame containing columns + :return: DataFrame containing columns "chainID", "residueNumber", "confidenceScore", "confidenceCategory" """ - + filtered_cif_df = cif_df[cif_df["_atom_site.label_atom_id"] == "CA"] - filtered_cif_df = filtered_cif_df[["_atom_site.auth_asym_id", "_atom_site.label_seq_id", "_atom_site.B_iso_or_equiv"]] - filtered_cif_df = filtered_cif_df.rename(columns={ - "_atom_site.auth_asym_id": "chainID", - "_atom_site.label_seq_id": "residueNumber", - "_atom_site.B_iso_or_equiv": "confidenceScore", - }) + filtered_cif_df = filtered_cif_df[ + [ + "_atom_site.auth_asym_id", + "_atom_site.label_seq_id", + "_atom_site.B_iso_or_equiv", + ] + ] + filtered_cif_df = filtered_cif_df.rename( + columns={ + "_atom_site.auth_asym_id": "chainID", + "_atom_site.label_seq_id": "residueNumber", + "_atom_site.B_iso_or_equiv": "confidenceScore", + } + ) # TODO: ConfidenceCategory maybe yes? - + return filtered_cif_df + def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: """ Writes multimer structure data from disk of a specific entry ID into dataframes. @@ -876,9 +888,7 @@ def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: **df_dict, messages=messages, plddt_df=plddt_df, - pae_matrix=OutputItem( - output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix - ), + pae_matrix=OutputItem(output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix), visualization=OutputItem( output_type=OutputType.VISUALIZATION, value=data_for_visualization ), @@ -1018,8 +1028,9 @@ def upload_multimer_prediction( unwrapped_full_data = unwrap_full_data_df(df_dict["full_data_df"]) df_dict["full_data_df"] = unwrapped_full_data["full_data_df"] - pae_matrix=OutputItem( - output_type=OutputType.JOBLIB_ARTIFACT, value=unwrapped_full_data["pae_matrix"] + pae_matrix = OutputItem( + output_type=OutputType.JOBLIB_ARTIFACT, + value=unwrapped_full_data["pae_matrix"], ) df_dict["pae_matrix"] = pae_matrix df_dict["plddt_df"] = get_plddt_from_cif(df_dict["cif_df"]) From 58ac815ecae0df8cad754136b2b458eda7c9baef Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 19 May 2026 16:06:45 +0200 Subject: [PATCH 11/33] fix: adjust existing cl validation tests --- .../data_analysis/crosslinking_validation.py | 3 +++ .../data_analysis/test_crosslinking_validation.py | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 7658a5aa9..a5c82e0aa 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -748,6 +748,9 @@ def get_paes(): crosslinker_length + tolerance_pos1 + tolerance_pos2 ) + case _: + raise ValueError("Invalid validation strategy") + valid = ( accepted_distance_lower_bound <= predicted_distance diff --git a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py index f7456aa37..24bfc5bb1 100644 --- a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py +++ b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py @@ -6,6 +6,7 @@ import plotly.graph_objects as go from plotly.graph_objects import Figure import pandas.testing as pdt +import numpy from backend.protzilla.data_analysis.crosslinking_validation import ( @@ -41,6 +42,7 @@ def test_monomer_validation(distance, expected): cif_df = pd.DataFrame( { "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], "_atom_site.label_seq_id": [1, 2], "_atom_site.Cartn_x": [0, distance], "_atom_site.Cartn_y": [0, 0], @@ -81,10 +83,13 @@ def test_monomer_validation(distance, expected): } ) + pae_matrix_noerror = numpy.array([[0, 0], [0, 0]]) + plddt_df_noerror = pd.DataFrame( { "residueNumber": [1, 2], "confidenceScore": [100, 100], + "chainID": ["A", "A"], # confidenceCategory is not required } ) @@ -98,7 +103,7 @@ def test_monomer_validation(distance, expected): valid_ids=valid_ids, id_column_name="_atom_site.pdbx_sifts_xref_db_acc", structures_to_validate=structures_to_validate, - pae_df=pae_df_noerror, + pae_matrix=pae_matrix_noerror, plddt_df=plddt_df_noerror, validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) @@ -115,6 +120,7 @@ def test_monomer_validation(distance, expected): assert df.loc[0, "link_type"] == "intra" # Validation with PAE + # TODO: Proper cases with error def test_modify_form_creates_crosslinker_fields(): @@ -373,6 +379,7 @@ def test_validate_multimer_filters_only_pairs_within_structures_to_validate(): valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] @@ -448,6 +455,7 @@ def test_validate_multimer_no_links_between_structures_returns_empty_and_warning valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] @@ -519,6 +527,7 @@ def test_validate_multimer_duplicates_rows_for_multiple_peptide_matches_and_vali valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] @@ -858,6 +867,7 @@ def test_validate_multimer_with_invalid_crosslinks(): valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] @@ -1034,6 +1044,7 @@ def test_validate_multimer_same_protein_different_chains_intra_vs_inter(): valid_ids=valid_ids, id_column_name="_atom_site.label_entity_id", structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) result_df = out["crosslinking_result_df"] From 634e7127d6aa1ed421406f10bb566d821d26a679 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 19 May 2026 16:38:26 +0200 Subject: [PATCH 12/33] fix: some alphafold import tests --- .../test_alphafold_protein_structure_load.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py index 940bbcea2..1ddeaf018 100644 --- a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py +++ b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py @@ -1,9 +1,11 @@ +from backend.protzilla.steps import OutputItem import pandas as pd import pytest import json import logging import shutil from pathlib import Path +import numpy as np from backend.protzilla.importing.alphafold_protein_structure_load import ( @@ -129,7 +131,7 @@ def test_fetch_alphafold_returned_keys(tmp_path, monkeypatch): assert out.keys() == { "structure_metadata_df", "cif_df", - "pae_df", + "pae_matrix", "plddt_df", "amino_acid_sequences_df", "messages", @@ -202,9 +204,9 @@ def test_fetch_alphafold_dfs_exist(tmp_path, monkeypatch): assert not cif_df.empty assert any(col.startswith("_atom_site.") for col in cif_df.columns) - pae_df = out["pae_df"] - assert isinstance(pae_df, pd.DataFrame) - assert not pae_df.empty + pae_matrix = out["pae_matrix"] + assert isinstance(pae_matrix, OutputItem) + assert len(pae_matrix.value) != 0 plddt_df = out["plddt_df"] assert isinstance(plddt_df, pd.DataFrame) @@ -318,9 +320,9 @@ def test_get_prot_structure_dfs_success(tmp_path, monkeypatch): assert out["cif_df"]["_atom_site.type_symbol"].tolist() == ["N", "C"] assert out["cif_df"]["_atom_site.Cartn_x"].tolist() == ["1.0", "2.0"] - assert isinstance(out["pae_df"], pd.DataFrame) - assert not out["pae_df"].empty - assert out["pae_df"]["predicted_aligned_error"].tolist() == [0.1] + assert isinstance(out["pae_matrix"], OutputItem) + assert isinstance(out["pae_matrix"].value, np.ndarray) + assert out["pae_matrix"].value == 0.1 # 0D array (only one value) TODO: Change this to something more reasonable? idk assert isinstance(out["plddt_df"], pd.DataFrame) assert not out["plddt_df"].empty @@ -428,6 +430,7 @@ def test_get_amino_acid_sequences_df_and_handle_files(tmp_path, monkeypatch): assert isinstance(out["amino_acid_sequences_df"], pd.DataFrame) +# TODO: Add PAE handling here def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): monkeypatch.setattr(paths, "ALPHAFOLD_MULTIMER_PATH", tmp_path) @@ -447,7 +450,7 @@ def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): conf = tmp_path / "conf.json" conf.write_text('[{"residueNumber":1, "confidenceScore":99}]') full = tmp_path / "full.json" - full.write_text('{"a": [1,2]}') + full.write_text('{"a": [1,2]}') job_request = tmp_path / "job_request.json" job_request.write_text( json.dumps( From 2c66b35ecffd65d0568e24cd3cd99faad63ecb50 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 21 May 2026 10:09:50 +0200 Subject: [PATCH 13/33] tempfix: bridge monomer plots so method doesn't fail --- .../data_analysis/crosslinking_validation.py | 53 +++++++++++-------- backend/protzilla/methods/data_analysis.py | 2 +- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index a5c82e0aa..3e44c0ffe 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -1019,43 +1019,50 @@ def diagrams_of_crosslinking_validation_data( def monomer_diagrams( - crosslinking_df: pd.DataFrame, + output_crosslinking_result_df: pd.DataFrame, structure_metadata_df: pd.DataFrame, crosslinker_information: dict[str, list[float]], - cif_df: pd.DataFrame, - amino_acid_sequences_df: pd.DataFrame, + validation_criterion: CrosslinkingValidationCriterion ) -> list[Figure]: """ Generates visual diagrams to evaluate crosslinking validation results for a monomeric protein structure. - This function acts as a wrapper that first runs the crosslink validation - step via `monomer_validation`. It then extracts the resulting dataframe - of validated crosslinks and passes it to the diagram generator to create - the final plots. - - :param crosslinking_df: DataFrame containing the full set of crosslinks. + :param output_crosslinking_result_df: DataFrame containing the CL validation results. :param structure_metadata_df: DataFrame containing structural metadata; the first row's 'uniprot_accession' is used as the target. :param crosslinker_information: Dictionary mapping crosslinker names to a list of three floats: [length, upper_bound, lower_bound]. - :param cif_df: DataFrame containing parsed mmCIF structural coordinate data. - :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. :return: A list of Figure objects visualizing the crosslinking validation data. """ structures_to_validate = [structure_metadata_df["uniprot_accession"].iloc[0]] - validated_df = monomer_validation( - crosslinking_df, - structure_metadata_df, - crosslinker_information, - cif_df, - amino_acid_sequences_df, - )["crosslinking_result_df"] - return diagrams_of_crosslinking_validation_data( - validated_df=validated_df, - structures_to_validate=structures_to_validate, - crosslinker_information=crosslinker_information, - ) + + match validation_criterion: + case CrosslinkingValidationCriterion.manual_bounds.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + # TODO: Separate Issue #429 + case CrosslinkingValidationCriterion.max_pae.value | CrosslinkingValidationCriterion.min_pae.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + # TODO: Separate Issue #429 + case CrosslinkingValidationCriterion.plddt_adjusted.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + case _: + return [] def multimer_diagrams( diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index e90abb8ef..a1dafd0ab 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -2434,7 +2434,7 @@ class CrosslinkingValidationWithAngstromDeviation( operation = "Crosslinking Validation" method_description = "Validates crosslinks within the one protein structure based on the difference between the length of the crosslinker and the distance between the amino acids which were connected by the crosslinker. (in Ångström)" calc_method = staticmethod(monomer_validation) - # plot_method = staticmethod(monomer_diagrams) + plot_method = staticmethod(monomer_diagrams) def create_form(self): return Form( From 7ecac2b9ac0e61a1536e1e6c0c18ce57994885dd Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 21 May 2026 10:20:43 +0200 Subject: [PATCH 14/33] tempfix: bridge multimer plots so method doesn't fail --- .../data_analysis/crosslinking_validation.py | 55 +++++++++++-------- backend/protzilla/methods/data_analysis.py | 2 +- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 3e44c0ffe..dd50a34b5 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -1033,6 +1033,7 @@ def monomer_diagrams( first row's 'uniprot_accession' is used as the target. :param crosslinker_information: Dictionary mapping crosslinker names to a list of three floats: [length, upper_bound, lower_bound]. + :param validation_criterion: The validation criterion used for validation. :return: A list of Figure objects visualizing the crosslinking validation data. """ structures_to_validate = [structure_metadata_df["uniprot_accession"].iloc[0]] @@ -1066,29 +1067,25 @@ def monomer_diagrams( def multimer_diagrams( - crosslinking_df: pd.DataFrame, - structure_metadata_df: pd.DataFrame, + output_crosslinking_result_df: pd.DataFrame, crosslinker_information: dict[str, list[float]], - cif_df: pd.DataFrame, amino_acid_sequences_df: pd.DataFrame, job_request_df: pd.DataFrame, + validation_criterion: CrosslinkingValidationCriterion, ) -> list[Figure]: """ Generates visual diagrams to evaluate crosslinking validation results for a multimeric protein complex. This function parses an AlphaFold job request to determine the valid chain - compositions. It then runs `multimer_validation` to filter and validate - the relevant crosslinks, extracting the result to generate structural - distance and validation plots. + compositions and uses the passed result from the validation. - :param crosslinking_df: DataFrame containing the full set of crosslinks. - :param structure_metadata_df: DataFrame containing structural metadata. + :param output_crosslinking_result_df: DataFrame containing the CL validation results. :param crosslinker_information: Dictionary mapping crosslinker names to a list of three floats: [length, upper_bound, lower_bound]. - :param cif_df: DataFrame containing parsed mmCIF structural coordinate data. :param amino_acid_sequences_df: DataFrame containing known amino acid sequences. :param job_request_df: DataFrame containing the loaded AlphaFold job request JSON. + :param validation_criterion: The validation criterion used for validation. :return: A list of Figure objects visualizing the crosslinking validation data. """ valid_ids = get_valid_ids_per_protein_id_from_job_request( @@ -1096,17 +1093,29 @@ def multimer_diagrams( ) structures_to_validate = list(valid_ids.keys()) - validated_df = multimer_validation( - crosslinking_df, - structure_metadata_df, - crosslinker_information, - cif_df, - amino_acid_sequences_df, - job_request_df, - )["crosslinking_result_df"] - - return diagrams_of_crosslinking_validation_data( - validated_df=validated_df, - structures_to_validate=structures_to_validate, - crosslinker_information=crosslinker_information, - ) + match validation_criterion: + case CrosslinkingValidationCriterion.manual_bounds.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + # TODO: Separate Issue #429 + case CrosslinkingValidationCriterion.max_pae.value | CrosslinkingValidationCriterion.min_pae.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + # TODO: Separate Issue #429 + case CrosslinkingValidationCriterion.plddt_adjusted.value: + return diagrams_of_crosslinking_validation_data( + validated_df=output_crosslinking_result_df, + structures_to_validate=structures_to_validate, + crosslinker_information=crosslinker_information, + ) + + case _: + return [] diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index a1dafd0ab..48693799e 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -2460,7 +2460,7 @@ class CrosslinkingValidationWithAngstromDeviationForMultimer( operation = "Crosslinking Validation" method_description = "Validates crosslinks between proteins based on the difference between the length of the crosslinker and the distance between the amino acids which were connected by the crosslinker. (in Ångström)" calc_method = staticmethod(multimer_validation) - # plot_method = staticmethod(multimer_diagrams) + plot_method = staticmethod(multimer_diagrams) def create_form(self): return Form( From f2518c15e54ca931cac1d00f7d5402c57965261a Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 21 May 2026 10:44:14 +0200 Subject: [PATCH 15/33] chore: remove obsolete todos --- .../importing/alphafold_protein_structure_load.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index 44868b8bd..90a6089a9 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -746,12 +746,6 @@ def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: - "pae_matrix": Numpy matrix with the PAE values for each residue pair """ - # Construct plDDT dataframe - # TODO: Getting pLDDT from AlphaFold3 is a bit harder as its on a per-atom level - # rather than per-residue, so we'd need to extract it from the cif file - # (column _atom_site.B_iso_or_equiv, see https://github.com/google-deepmind/alphafold3/issues/330). - # Skipping this for now. - pae_matrix = np.array(full_data_df["pae"].iloc[0]) full_data_df = full_data_df.drop(columns=["pae"]) @@ -789,8 +783,6 @@ def get_plddt_from_cif(cif_df: pd.DataFrame): } ) - # TODO: ConfidenceCategory maybe yes? - return filtered_cif_df From e13e2a3fe1a43052e0c377d95e9d0ee1141c1f8e Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 21 May 2026 14:15:56 +0200 Subject: [PATCH 16/33] chore: adjust some tests --- backend/protzilla/constants/cif_columns.py | 63 ++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 backend/protzilla/constants/cif_columns.py diff --git a/backend/protzilla/constants/cif_columns.py b/backend/protzilla/constants/cif_columns.py new file mode 100644 index 000000000..5715f2f32 --- /dev/null +++ b/backend/protzilla/constants/cif_columns.py @@ -0,0 +1,63 @@ +from enum import StrEnum + + +ATOM_SITE_PREFIX = "_atom_site." + + +class ATOM_SITE_COLUMNS(StrEnum): + """ + Enum containing all column names that should be present in + the _atom_site. table for mmCIF files from PDB or AFDB + """ + + ID = f"{ATOM_SITE_PREFIX}id" + TYPE_SYMBOL = f"{ATOM_SITE_PREFIX}type_symbol" + LABEL_ATOM_ID = f"{ATOM_SITE_PREFIX}label_atom_id" + LABEL_ALT_ID = f"{ATOM_SITE_PREFIX}label_alt_id" + LABEL_COMP_ID = f"{ATOM_SITE_PREFIX}label_comp_id" + LABEL_ASYM_ID = f"{ATOM_SITE_PREFIX}label_asym_id" + LABEL_ENTITY_ID = f"{ATOM_SITE_PREFIX}label_entity_id" + LABEL_SEQ_ID = f"{ATOM_SITE_PREFIX}label_seq_id" + PDBX_PDB_INS_CODE = f"{ATOM_SITE_PREFIX}pdbx_PDB_ins_code" + CARTN_X = f"{ATOM_SITE_PREFIX}Cartn_x" + CARTN_Y = f"{ATOM_SITE_PREFIX}Cartn_y" + CARTN_Z = f"{ATOM_SITE_PREFIX}Cartn_z" + OCCUPANCY = f"{ATOM_SITE_PREFIX}occupancy" + B_ISO_OR_EQUIV = f"{ATOM_SITE_PREFIX}B_iso_or_equiv" + PDBX_FORMAL_CHARGE = f"{ATOM_SITE_PREFIX}pdbx_formal_charge" + AUTH_SEQ_ID = f"{ATOM_SITE_PREFIX}auth_seq_id" + AUTH_COMP_ID = f"{ATOM_SITE_PREFIX}auth_comp_id" + AUTH_ASYM_ID = f"{ATOM_SITE_PREFIX}auth_asym_id" + AUTH_ATOM_ID = f"{ATOM_SITE_PREFIX}auth_atom_id" + PDBX_PDB_MODEL_NUM = f"{ATOM_SITE_PREFIX}pdbx_PDB_model_num" + + +ATOM_SITE_LABEL_COMP_ID = ATOM_SITE_COLUMNS.LABEL_COMP_ID + +ATOM_SITE_COLUMNS_NUMERIC = [ + ATOM_SITE_COLUMNS.ID, + ATOM_SITE_COLUMNS.LABEL_SEQ_ID, + ATOM_SITE_COLUMNS.CARTN_X, + ATOM_SITE_COLUMNS.CARTN_Y, + ATOM_SITE_COLUMNS.CARTN_Z, + ATOM_SITE_COLUMNS.OCCUPANCY, + ATOM_SITE_COLUMNS.B_ISO_OR_EQUIV, + ATOM_SITE_COLUMNS.AUTH_SEQ_ID, +] + +CHEM_COMP_PREFIX = "_chem_comp." + + +class CHEM_COMP_COLUMNS(StrEnum): + """ + Enum containing all column names that should be present in + the _chem_comp. table for mmCIF files from PDB or AFDB + """ + + ID = f"{CHEM_COMP_PREFIX}id" + TYPE = f"{CHEM_COMP_PREFIX}type" + MON_NSTD_FLAG = f"{CHEM_COMP_PREFIX}mon_nstd_flag" + NAME = f"{CHEM_COMP_PREFIX}name" + PDBX_SYNONYMS = f"{CHEM_COMP_PREFIX}pdbx_synonyms" + FORMULA = f"{CHEM_COMP_PREFIX}formula" + FORMULA_WEIGHT = f"{CHEM_COMP_PREFIX}formula_weight" From 0c645e5a17972836399d297bf11c7724031783a5 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 21 May 2026 14:18:16 +0200 Subject: [PATCH 17/33] chore: adjust some tests --- .../alphafold_protein_structure_load.py | 1 - .../test_alphafold_protein_structure_load.py | 65 +++++++++++++------ 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index 90a6089a9..4c4c4a799 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -128,7 +128,6 @@ def read_alphafold_mmcif(path: Path) -> pd.DataFrame: else: col_values.append(None) data[col] = col_values - return pd.DataFrame(data) diff --git a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py index 1ddeaf018..ba3aa68c5 100644 --- a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py +++ b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py @@ -430,7 +430,6 @@ def test_get_amino_acid_sequences_df_and_handle_files(tmp_path, monkeypatch): assert isinstance(out["amino_acid_sequences_df"], pd.DataFrame) -# TODO: Add PAE handling here def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): monkeypatch.setattr(paths, "ALPHAFOLD_MULTIMER_PATH", tmp_path) @@ -438,19 +437,30 @@ def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): fasta = tmp_path / "seqs.fasta" fasta.write_text(">alpha|X\nAAAA\n") cif = tmp_path / "m.cif" + # Note that we only write the absolutely necessary columns here + # Also this is not biologically plausible cif.write_text( """ -data_test -loop_ -_atom_site.id -_atom_site.type_symbol -N N -""" + data_test + loop_ + _atom_site.id + _atom_site.label_atom_id + _atom_site.auth_asym_id + _atom_site.label_seq_id + _atom_site.B_iso_or_equiv + 1 N A 1 99.99 + 2 CA A 1 67.76 + 3 CA A 2 33.65 + 4 O A 2 5.52 + 5 N B 1 0 + 6 CA B 1 13.37 + # + """ ) conf = tmp_path / "conf.json" - conf.write_text('[{"residueNumber":1, "confidenceScore":99}]') + conf.write_text('{"chain_iptm": [0.42, 0.89]}') # Note that we do not use these metrics anywhere full = tmp_path / "full.json" - full.write_text('{"a": [1,2]}') + full.write_text('{"random_column": [1,2], "pae": [[1, 2], [3, 4]]}') job_request = tmp_path / "job_request.json" job_request.write_text( json.dumps( @@ -461,11 +471,18 @@ def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): "sequences": [ { "proteinChain": { - "sequence": "AAAA", + "sequence": "PE", "count": 1, "useStructureTemplate": True, } - } + }, + { + "proteinChain": { + "sequence": "T", + "count": 1, + "useStructureTemplate": True, + } + }, ], "dialect": "alphafoldserver", "version": 3, @@ -487,7 +504,7 @@ def _copy(src, dest_dir): out = upload_multimer_prediction( entry_id="M1", - uniprot_ids="X", + uniprot_ids="X, Y", model_used="m", amino_acid_sequences=fasta, cif_file=cif, @@ -497,30 +514,38 @@ def _copy(src, dest_dir): persist_upload=True, ) - assert isinstance(out["structure_metadata_df"], pd.DataFrame) # check metadata contents mdf = out["structure_metadata_df"] + assert isinstance(mdf, pd.DataFrame) assert mdf.iloc[0]["entry_id"] == "M1" - assert mdf.iloc[0]["uniprot_ids"] == ["X"] + assert mdf.iloc[0]["uniprot_ids"] == ["X", "Y"] assert mdf.iloc[0]["model_used"] == "m" # cif contents cif_df = out["cif_df"] assert isinstance(cif_df, pd.DataFrame) - assert list(cif_df.columns) == ["_atom_site.id", "_atom_site.type_symbol"] - assert cif_df["_atom_site.id"].tolist() == ["N"] - assert cif_df["_atom_site.type_symbol"].tolist() == ["N"] + assert list(cif_df.columns) == [ + "_atom_site.id", + "_atom_site.label_atom_id", + "_atom_site.auth_asym_id", + "_atom_site.label_seq_id", + "_atom_site.B_iso_or_equiv", + ] + # assert cif_df["_atom_site.id"].tolist() == list(range(1, 7)) + assert cif_df["_atom_site.label_atom_id"].tolist() == ["N", "CA", "CA", "O", "N", "CA"] + assert cif_df["_atom_site.auth_asym_id"].tolist() == ["A"] * 4 + ["B"] * 2 + assert cif_df["_atom_site.B_iso_or_equiv"].tolist() == [99.99, 67.76, 33.65, 5.52, 0, 13.37] # confidence JSON conf_df = out["confidence_df"] assert isinstance(conf_df, pd.DataFrame) - assert conf_df["residueNumber"].tolist() == [1] - assert conf_df["confidenceScore"].tolist() == [99] + assert conf_df["chain_iptm"].tolist() == [0.42, 0.89] # full data normalization full_df = out["full_data_df"] assert isinstance(full_df, pd.DataFrame) - assert full_df.iloc[0]["a"] == [1, 2] + assert list(full_df.columns) == ["random_column"] + assert full_df.iloc[0]["random_column"] == [1, 2] # job request JSON job_df = out["job_request_df"] From c1d55a1336142d5d1d2b4c86d670536ba646882d Mon Sep 17 00:00:00 2001 From: Tarek Massini Date: Sun, 10 May 2026 20:55:46 +0200 Subject: [PATCH 18/33] feat: introduce parsing of _chem_comp table in cif-files --- .../alphafold_protein_structure_load.py | 70 +++-- .../test_alphafold_protein_structure_load.py | 261 +++++++++++------- 2 files changed, 219 insertions(+), 112 deletions(-) diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index 4c4c4a799..c2724eb0d 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -18,6 +18,13 @@ from backend.protzilla.constants import paths from backend.protzilla.constants.protzilla_logging import logger +from backend.protzilla.constants.cif_columns import ( + ATOM_SITE_PREFIX, + ATOM_SITE_COLUMNS, + ATOM_SITE_COLUMNS_NUMERIC, + CHEM_COMP_PREFIX, + CHEM_COMP_COLUMNS, +) from backend.protzilla.importing.fasta_import import fasta_import from backend.protzilla.networking import download_file_from_url from backend.protzilla.utilities.utilities import copy_file_to_directory @@ -110,25 +117,54 @@ def read_alphafold_mmcif(path: Path) -> pd.DataFrame: block = doc.sole_block() - cat_name = "_atom_site." - if cat_name not in block.get_mmcif_category_names(): + if ATOM_SITE_PREFIX not in block.get_mmcif_category_names(): return pd.DataFrame() - table = block.find_mmcif_category(cat_name) - - columns = list(table.tags) - nrows = len(table) - data = {} - for j, col in enumerate(columns): - col_values = [] - for i in range(nrows): - row = table[i] - if j < len(row): - col_values.append(row[j]) - else: - col_values.append(None) - data[col] = col_values - return pd.DataFrame(data) + atom_site_table = block.find_mmcif_category(ATOM_SITE_PREFIX) + + atom_site_df = pd.DataFrame( + list(atom_site_table), + columns=list(atom_site_table.tags), + dtype=pd.StringDtype(), + ) + + # convert to numeric dtype for numeric columns present in the dataframe + present_numeric_columns = [ + column for column in ATOM_SITE_COLUMNS_NUMERIC if column in atom_site_table.tags + ] + atom_site_df[present_numeric_columns] = atom_site_df[present_numeric_columns].apply( + pd.to_numeric, errors="coerce" + ) + + atom_site_df = atom_site_df.convert_dtypes() + + if CHEM_COMP_PREFIX not in block.get_mmcif_category_names(): + raise ValueError( + f"Required table with prefix {CHEM_COMP_PREFIX} not found in {path}" + ) + + chem_comp_table = block.find_mmcif_category(CHEM_COMP_PREFIX) + + chem_comp_df = pd.DataFrame( + list(chem_comp_table), + columns=list(chem_comp_table.tags), + dtype=pd.StringDtype(), + )[[CHEM_COMP_COLUMNS.ID, CHEM_COMP_COLUMNS.MON_NSTD_FLAG]] + + # convert flags to native booleans + bool_map = {"y": True, "n": False, ".": pd.NA} + + chem_comp_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG] = ( + chem_comp_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].map(bool_map).astype("boolean") + ) + + # merge on the comp_id and drop the duplicate column + return atom_site_df.merge( + chem_comp_df, + how="left", + left_on=ATOM_SITE_COLUMNS.LABEL_COMP_ID, + right_on=CHEM_COMP_COLUMNS.ID, + ).drop(CHEM_COMP_COLUMNS.ID, axis=1) def get_correct_af_directories( diff --git a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py index ba3aa68c5..9a23060b6 100644 --- a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py +++ b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py @@ -31,6 +31,12 @@ check_success_of_get_df, ) from backend.protzilla.constants import paths +from backend.protzilla.constants.cif_columns import ( + ATOM_SITE_PREFIX, + ATOM_SITE_COLUMNS, + CHEM_COMP_COLUMNS, +) +from backend.protzilla.constants.data_types import DataKey def test_to_fasta_default_header_and_newline(): @@ -92,11 +98,18 @@ def test_read_alphafold_mmcif_valid_atom_site(tmp_path): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol +_atom_site.label_atom_id +_atom_site.label_comp_id _atom_site.Cartn_x -N N 1.0 -CA C 2.0 +1 N N SER 1.0 +2 C CA SER 2.0 """ ) @@ -104,14 +117,19 @@ def test_read_alphafold_mmcif_valid_atom_site(tmp_path): assert isinstance(df, pd.DataFrame) assert list(df.columns) == [ - "_atom_site.id", - "_atom_site.type_symbol", - "_atom_site.Cartn_x", + ATOM_SITE_COLUMNS.ID, + ATOM_SITE_COLUMNS.TYPE_SYMBOL, + ATOM_SITE_COLUMNS.LABEL_ATOM_ID, + ATOM_SITE_COLUMNS.LABEL_COMP_ID, + ATOM_SITE_COLUMNS.CARTN_X, + CHEM_COMP_COLUMNS.MON_NSTD_FLAG, ] assert len(df) == 2 - assert df["_atom_site.id"].tolist() == ["N", "CA"] - assert df["_atom_site.type_symbol"].tolist() == ["N", "C"] - assert df["_atom_site.Cartn_x"].tolist() == ["1.0", "2.0"] + assert df[ATOM_SITE_COLUMNS.ID].tolist() == [1, 2] + assert df[ATOM_SITE_COLUMNS.TYPE_SYMBOL].tolist() == ["N", "C"] + assert df[ATOM_SITE_COLUMNS.LABEL_ATOM_ID].tolist() == ["N", "CA"] + assert df[ATOM_SITE_COLUMNS.CARTN_X].tolist() == [1.0, 2.0] + assert df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].tolist() == [True, True] def test_fetch_alphafold_protein_structure_wrong_uniprot_id(): @@ -129,11 +147,11 @@ def test_fetch_alphafold_returned_keys(tmp_path, monkeypatch): out = fetch_alphafold_protein_structure("Q8WP00", persist_upload=True) assert out.keys() == { - "structure_metadata_df", - "cif_df", - "pae_matrix", - "plddt_df", - "amino_acid_sequences_df", + DataKey.STRUCTURE_METADATA_DF, + DataKey.CIF_DF, + DataKey.PAE_MATRIX, + DataKey.PLDDT_DF, + DataKey.AMINO_ACID_SEQUENCES_DF, "messages", "visualization", } @@ -148,16 +166,16 @@ def test_fetch_alphafold_monomer_metadata(tmp_path, monkeypatch): ) out = fetch_alphafold_protein_structure("Q8WP00", persist_upload=True) - assert isinstance(out["structure_metadata_df"], pd.DataFrame) - assert not out["structure_metadata_df"].empty - assert out["structure_metadata_df"].iloc[0]["uniprot_accession"] == "Q8WP00" + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) + assert not out[DataKey.STRUCTURE_METADATA_DF].empty + assert out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["uniprot_accession"] == "Q8WP00" assert ( - out["structure_metadata_df"].iloc[0]["model_created_date"] + out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["model_created_date"] == "2025-08-01T00:00:00Z" ) - assert out["structure_metadata_df"].iloc[0]["gene"] == "PRM1" + assert out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["gene"] == "PRM1" assert ( - out["structure_metadata_df"].iloc[0]["model_used"] + out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["model_used"] == "AlphaFold Monomer v2.0 pipeline" ) @@ -199,20 +217,20 @@ def test_fetch_alphafold_dfs_exist(tmp_path, monkeypatch): out = fetch_alphafold_protein_structure("Q8WP00", persist_upload=True) - cif_df = out["cif_df"] + cif_df = out[DataKey.CIF_DF] assert isinstance(cif_df, pd.DataFrame) assert not cif_df.empty - assert any(col.startswith("_atom_site.") for col in cif_df.columns) + assert any(col.startswith(ATOM_SITE_PREFIX) for col in cif_df.columns) - pae_matrix = out["pae_matrix"] + pae_matrix = out[DataKey.PAE_MATRIX] assert isinstance(pae_matrix, OutputItem) assert len(pae_matrix.value) != 0 - plddt_df = out["plddt_df"] + plddt_df = out[DataKey.PLDDT_DF] assert isinstance(plddt_df, pd.DataFrame) assert not plddt_df.empty - seq_df = out["amino_acid_sequences_df"] + seq_df = out[DataKey.AMINO_ACID_SEQUENCES_DF] assert isinstance(seq_df, pd.DataFrame) assert not seq_df.empty @@ -282,11 +300,18 @@ def test_get_prot_structure_dfs_success(tmp_path, monkeypatch): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol +_atom_site.label_atom_id +_atom_site.label_comp_id _atom_site.Cartn_x -N N 1.0 -CA C 2.0 +1 N N SER 1.0 +2 C CA SER 2.0 """ ) @@ -305,34 +330,41 @@ def test_get_prot_structure_dfs_success(tmp_path, monkeypatch): out = get_monomer_structure_dfs("Q8WP00") - assert isinstance(out["structure_metadata_df"], pd.DataFrame) - assert not out["structure_metadata_df"].empty - assert out["structure_metadata_df"].iloc[0]["entry_id"] == "Q8WP00" + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) + assert not out[DataKey.STRUCTURE_METADATA_DF].empty + assert out[DataKey.STRUCTURE_METADATA_DF].iloc[0]["entry_id"] == "Q8WP00" - assert isinstance(out["cif_df"], pd.DataFrame) - assert not out["cif_df"].empty - assert list(out["cif_df"].columns) == [ - "_atom_site.id", - "_atom_site.type_symbol", - "_atom_site.Cartn_x", + cif_df = out[DataKey.CIF_DF] + assert isinstance(cif_df, pd.DataFrame) + assert not cif_df.empty + assert list(cif_df.columns) == [ + ATOM_SITE_COLUMNS.ID, + ATOM_SITE_COLUMNS.TYPE_SYMBOL, + ATOM_SITE_COLUMNS.LABEL_ATOM_ID, + ATOM_SITE_COLUMNS.LABEL_COMP_ID, + ATOM_SITE_COLUMNS.CARTN_X, + CHEM_COMP_COLUMNS.MON_NSTD_FLAG, ] - assert out["cif_df"]["_atom_site.id"].tolist() == ["N", "CA"] - assert out["cif_df"]["_atom_site.type_symbol"].tolist() == ["N", "C"] - assert out["cif_df"]["_atom_site.Cartn_x"].tolist() == ["1.0", "2.0"] - - assert isinstance(out["pae_matrix"], OutputItem) - assert isinstance(out["pae_matrix"].value, np.ndarray) - assert out["pae_matrix"].value == 0.1 # 0D array (only one value) TODO: Change this to something more reasonable? idk - - assert isinstance(out["plddt_df"], pd.DataFrame) - assert not out["plddt_df"].empty - assert out["plddt_df"]["residueNumber"].tolist() == [1] - assert out["plddt_df"]["confidenceScore"].tolist() == [90] - - assert isinstance(out["amino_acid_sequences_df"], pd.DataFrame) - assert not out["amino_acid_sequences_df"].empty - assert out["amino_acid_sequences_df"]["Protein ID"].tolist() == ["Q8WP00-1"] - assert out["amino_acid_sequences_df"]["Protein Sequence"].tolist() == ["AAAA"] + assert len(cif_df) == 2 + assert cif_df[ATOM_SITE_COLUMNS.ID].tolist() == [1, 2] + assert cif_df[ATOM_SITE_COLUMNS.TYPE_SYMBOL].tolist() == ["N", "C"] + assert cif_df[ATOM_SITE_COLUMNS.LABEL_ATOM_ID].tolist() == ["N", "CA"] + assert cif_df[ATOM_SITE_COLUMNS.CARTN_X].tolist() == [1.0, 2.0] + assert cif_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].tolist() == [True, True] + + assert isinstance(out[DataKey.PAE_MATRIX], OutputItem) + assert isinstance(out[DataKey.PAE_MATRIX].value, np.ndarray) + assert out[DataKey.PAE_MATRIX].value == 0.1 # 0D array (only one value) TODO: Change this to something more reasonable? idk + + assert isinstance(out[DataKey.PLDDT_DF], pd.DataFrame) + assert not out[DataKey.PLDDT_DF].empty + assert out[DataKey.PLDDT_DF]["residueNumber"].tolist() == [1] + assert out[DataKey.PLDDT_DF]["confidenceScore"].tolist() == [90] + + assert isinstance(out[DataKey.AMINO_ACID_SEQUENCES_DF], pd.DataFrame) + assert not out[DataKey.AMINO_ACID_SEQUENCES_DF].empty + assert out[DataKey.AMINO_ACID_SEQUENCES_DF]["Protein ID"].tolist() == ["Q8WP00-1"] + assert out[DataKey.AMINO_ACID_SEQUENCES_DF]["Protein Sequence"].tolist() == ["AAAA"] assert any(d.get("level") == logging.INFO for d in out["messages"]) or any( "Successfully loaded" in d.get("msg", "") for d in out["messages"] @@ -423,11 +455,13 @@ def test_get_amino_acid_sequences_df_and_handle_files(tmp_path, monkeypatch): out = handle_alphafold_files( {}, "P", "TESTSEQ", metadata_df, "P", persist_upload=False ) - assert "amino_acid_sequences_df" in out - assert isinstance(out["cif_df"], pd.DataFrame) and out["cif_df"].empty - assert isinstance(out["pae_df"], pd.DataFrame) and out["pae_df"].empty - assert isinstance(out["plddt_df"], pd.DataFrame) and out["plddt_df"].empty - assert isinstance(out["amino_acid_sequences_df"], pd.DataFrame) + assert DataKey.AMINO_ACID_SEQUENCES_DF in out + assert isinstance(out[DataKey.CIF_DF], pd.DataFrame) and out[DataKey.CIF_DF].empty + assert isinstance(out[DataKey.PAE_DF], pd.DataFrame) and out[DataKey.PAE_DF].empty + assert ( + isinstance(out[DataKey.PLDDT_DF], pd.DataFrame) and out[DataKey.PLDDT_DF].empty + ) + assert isinstance(out[DataKey.AMINO_ACID_SEQUENCES_DF], pd.DataFrame) def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): @@ -443,6 +477,11 @@ def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): """ data_test loop_ + _chem_comp.id + _chem_comp.mon_nstd_flag + SER y + # + loop_ _atom_site.id _atom_site.label_atom_id _atom_site.auth_asym_id @@ -514,47 +553,49 @@ def _copy(src, dest_dir): persist_upload=True, ) + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) # check metadata contents - mdf = out["structure_metadata_df"] - assert isinstance(mdf, pd.DataFrame) + mdf = out[DataKey.STRUCTURE_METADATA_DF] assert mdf.iloc[0]["entry_id"] == "M1" assert mdf.iloc[0]["uniprot_ids"] == ["X", "Y"] assert mdf.iloc[0]["model_used"] == "m" # cif contents - cif_df = out["cif_df"] + cif_df = out[DataKey.CIF_DF] assert isinstance(cif_df, pd.DataFrame) assert list(cif_df.columns) == [ - "_atom_site.id", - "_atom_site.label_atom_id", - "_atom_site.auth_asym_id", - "_atom_site.label_seq_id", - "_atom_site.B_iso_or_equiv", + ATOM_SITE_COLUMNS.ID, + ATOM_SITE_COLUMNS.TYPE_SYMBOL, + ATOM_SITE_COLUMNS.LABEL_ATOM_ID, + ATOM_SITE_COLUMNS.LABEL_COMP_ID, + CHEM_COMP_COLUMNS.MON_NSTD_FLAG, ] - # assert cif_df["_atom_site.id"].tolist() == list(range(1, 7)) - assert cif_df["_atom_site.label_atom_id"].tolist() == ["N", "CA", "CA", "O", "N", "CA"] - assert cif_df["_atom_site.auth_asym_id"].tolist() == ["A"] * 4 + ["B"] * 2 - assert cif_df["_atom_site.B_iso_or_equiv"].tolist() == [99.99, 67.76, 33.65, 5.52, 0, 13.37] + assert cif_df[ATOM_SITE_COLUMNS.ID].tolist() == list(range(1, 7)) + assert cif_df[ATOM_SITE_COLUMNS.LABEL_ATOM_ID].tolist() == ["N", "CA", "CA", "O", "N", "CA"] + assert cif_df[ATOM_SITE_COLUMNS.AUTH_ASYM_ID].tolist() == ["A"] * 4 + ["B"] * 2 + assert cif_df[ATOM_SITE_COLUMNS.B_ISO_OR_EQUIV].tolist() == [99.99, 67.76, 33.65, 5.52, 0, 13.37] + assert cif_df[ATOM_SITE_COLUMNS.LABEL_COMP_ID].tolist() == ["SER"] + assert cif_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].tolist() == [True] # confidence JSON - conf_df = out["confidence_df"] + conf_df = out[DataKey.CONFIDENCE_DF] assert isinstance(conf_df, pd.DataFrame) assert conf_df["chain_iptm"].tolist() == [0.42, 0.89] # full data normalization - full_df = out["full_data_df"] + full_df = out[DataKey.FULL_DATA_DF] assert isinstance(full_df, pd.DataFrame) assert list(full_df.columns) == ["random_column"] assert full_df.iloc[0]["random_column"] == [1, 2] # job request JSON - job_df = out["job_request_df"] + job_df = out[DataKey.JOB_REQUEST_DF] assert isinstance(job_df, pd.DataFrame) assert job_df.iloc[0]["name"] == "test_job" assert job_df.iloc[0]["dialect"] == "alphafoldserver" # sequences - seqs = out["amino_acid_sequences_df"] + seqs = out[DataKey.AMINO_ACID_SEQUENCES_DF] assert isinstance(seqs, pd.DataFrame) assert seqs["Protein Sequence"].tolist() == ["AAAA"] assert any(str(v).startswith("X") for v in seqs["Protein ID"].tolist()) @@ -640,7 +681,9 @@ def test_upload_multimer_prediction_no_persist(tmp_path, monkeypatch): fasta = tmp_path / "seqs.fasta" fasta.write_text(">alpha|X\nAAAA\n") cif = tmp_path / "m.cif" - cif.write_text("data_test\nloop_\n_atom_site.id\nN\n") + cif.write_text( + "data_test\nloop_\n_chem_comp.id\n_chem_comp.mon_nstd_flag\nSER y\nloop_\n#\n_atom_site.id\n_atom_site.label_comp_id\nN SER\n" + ) conf = tmp_path / "conf.json" conf.write_text('[{"residueNumber":1, "confidenceScore":99}]') full = tmp_path / "full.json" @@ -681,10 +724,10 @@ def test_upload_multimer_prediction_no_persist(tmp_path, monkeypatch): ) # verify dataframes are returned - assert isinstance(out["structure_metadata_df"], pd.DataFrame) - assert isinstance(out["cif_df"], pd.DataFrame) - assert isinstance(out["job_request_df"], pd.DataFrame) - assert out["job_request_df"].iloc[0]["name"] == "test_job_2" + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) + assert isinstance(out[DataKey.CIF_DF], pd.DataFrame) + assert isinstance(out[DataKey.JOB_REQUEST_DF], pd.DataFrame) + assert out[DataKey.JOB_REQUEST_DF].iloc[0]["name"] == "test_job_2" # directory should still exist (created for the entry) upload_dir = tmp_path / "M2" assert not upload_dir.exists() @@ -720,7 +763,9 @@ def test_get_prot_structure_dfs_missing_fasta(tmp_path, monkeypatch): # create CIF but no FASTA cif = prot_dir / "test.cif" - cif.write_text("data_test\nloop_\n_atom_site.id\nN\n") + cif.write_text( + "data_test\nloop_\n_chem_comp.id\n_chem_comp.mon_nstd_flag\nSER y\nloop_\n#\n_atom_site.id\n_atom_site.label_comp_id\nN SER\n" + ) with pytest.raises(FileNotFoundError, match="No FASTA file found"): get_monomer_structure_dfs("NOFASTA") @@ -740,7 +785,9 @@ def test_get_prot_structure_dfs_missing_json(tmp_path, monkeypatch): # create CIF and FASTA but no JSON cif = prot_dir / "test.cif" - cif.write_text("data_test\nloop_\n_atom_site.id\nN\n") + cif.write_text( + "data_test\nloop_\n_chem_comp.id\n_chem_comp.mon_nstd_flag\nSER y\nloop_\n#\n_atom_site.id\n_atom_site.label_comp_id\nN SER\n" + ) fasta = prot_dir / "test.fasta" # valid header for parse_fasta_id (expects at least one "|" in the id) @@ -848,18 +895,30 @@ def test_get_cif_df_from_disk_multiple_cif_warns(tmp_path): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol -N N +_atom_site.label_comp_id +N N SER """ ) cif2.write_text( """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol -CA C +_atom_site.label_comp_id +CA C SER """ ) @@ -900,9 +959,15 @@ def test_get_multimer_structure_dfs_success(tmp_path, monkeypatch): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol -N N +_atom_site.label_comp_id +N N SER """ ) @@ -937,17 +1002,17 @@ def test_get_multimer_structure_dfs_success(tmp_path, monkeypatch): ) out = get_multimer_structure_dfs("M1") - assert isinstance(out["structure_metadata_df"], pd.DataFrame) - assert isinstance(out["cif_df"], pd.DataFrame) - assert isinstance(out["amino_acid_sequences_df"], pd.DataFrame) - assert isinstance(out["confidence_df"], pd.DataFrame) - assert isinstance(out["full_data_df"], pd.DataFrame) - assert isinstance(out["job_request_df"], pd.DataFrame) - - assert "chain_iptm" in out["confidence_df"].columns - assert "pae" in out["full_data_df"].columns - assert out["job_request_df"].iloc[0]["name"] == "multimer_job" - assert out["job_request_df"].iloc[0]["version"] == 3 + assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) + assert isinstance(out[DataKey.CIF_DF], pd.DataFrame) + assert isinstance(out[DataKey.AMINO_ACID_SEQUENCES_DF], pd.DataFrame) + assert isinstance(out[DataKey.CONFIDENCE_DF], pd.DataFrame) + assert isinstance(out[DataKey.FULL_DATA_DF], pd.DataFrame) + assert isinstance(out[DataKey.JOB_REQUEST_DF], pd.DataFrame) + + assert "chain_iptm" in out[DataKey.CONFIDENCE_DF].columns + assert "pae" in out[DataKey.FULL_DATA_DF].columns + assert out[DataKey.JOB_REQUEST_DF].iloc[0]["name"] == "multimer_job" + assert out[DataKey.JOB_REQUEST_DF].iloc[0]["version"] == 3 assert any(m.get("level") == logging.INFO for m in out["messages"]) or any( "Successfully loaded" in str(m.get("msg", "")) for m in out["messages"] @@ -984,9 +1049,15 @@ def test_get_multimer_structure_dfs_json_fallback_warns(tmp_path, monkeypatch): """ data_test loop_ +_chem_comp.id +_chem_comp.mon_nstd_flag +SER y +# +loop_ _atom_site.id _atom_site.type_symbol -N N +_atom_site.label_comp_id +N N SER """ ) From a935874771197a2ca873cd21982268931e5cb2a3 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 21 May 2026 16:05:54 +0200 Subject: [PATCH 19/33] chore: fix existing tests --- .../data_analysis/crosslinking_validation.py | 13 ++- .../alphafold_protein_structure_load.py | 57 ++++++++----- backend/protzilla/methods/data_analysis.py | 4 +- .../test_alphafold_protein_structure_load.py | 80 ++++++++++++++----- 4 files changed, 112 insertions(+), 42 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index cad5214b7..55df3780b 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -1082,7 +1082,7 @@ def monomer_diagrams( output_crosslinking_result_df: pd.DataFrame, structure_metadata_df: pd.DataFrame, crosslinker_information: dict[str, list[float]], - validation_criterion: CrosslinkingValidationCriterion + validation_criterion: CrosslinkingValidationCriterion, ) -> list[Figure]: """ Generates visual diagrams to evaluate crosslinking validation results @@ -1107,7 +1107,10 @@ def monomer_diagrams( ) # TODO: Separate Issue #429 - case CrosslinkingValidationCriterion.max_pae.value | CrosslinkingValidationCriterion.min_pae.value: + case ( + CrosslinkingValidationCriterion.max_pae.value + | CrosslinkingValidationCriterion.min_pae.value + ): return diagrams_of_crosslinking_validation_data( validated_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, @@ -1162,7 +1165,10 @@ def multimer_diagrams( ) # TODO: Separate Issue #429 - case CrosslinkingValidationCriterion.max_pae.value | CrosslinkingValidationCriterion.min_pae.value: + case ( + CrosslinkingValidationCriterion.max_pae.value + | CrosslinkingValidationCriterion.min_pae.value + ): return diagrams_of_crosslinking_validation_data( validated_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, @@ -1180,6 +1186,7 @@ def multimer_diagrams( case _: return [] + # Warning: Mostly AI generated def create_cl_validation_histogram( distances_valid: pd.Series, diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index c2724eb0d..b78d2f955 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -781,8 +781,11 @@ def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: - "pae_matrix": Numpy matrix with the PAE values for each residue pair """ - pae_matrix = np.array(full_data_df["pae"].iloc[0]) - full_data_df = full_data_df.drop(columns=["pae"]) + try: + pae_matrix = np.array(full_data_df["pae"].iloc[0]) + full_data_df = full_data_df.drop(columns=["pae"]) + except KeyError: + pae_matrix = None return dict( full_data_df=full_data_df, @@ -790,7 +793,7 @@ def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: ) -def get_plddt_from_cif(cif_df: pd.DataFrame): +def get_plddt_from_cif(cif_df: pd.DataFrame) -> pd.DataFrame | None: """ For use with multimers predicted using Alphafold3. Returns per-residue pLDDT values for the predicted structure. @@ -802,23 +805,26 @@ def get_plddt_from_cif(cif_df: pd.DataFrame): "chainID", "residueNumber", "confidenceScore", "confidenceCategory" """ - filtered_cif_df = cif_df[cif_df["_atom_site.label_atom_id"] == "CA"] - filtered_cif_df = filtered_cif_df[ - [ - "_atom_site.auth_asym_id", - "_atom_site.label_seq_id", - "_atom_site.B_iso_or_equiv", + try: + filtered_cif_df = cif_df[cif_df["_atom_site.label_atom_id"] == "CA"] + filtered_cif_df = filtered_cif_df[ + [ + "_atom_site.auth_asym_id", + "_atom_site.label_seq_id", + "_atom_site.B_iso_or_equiv", + ] ] - ] - filtered_cif_df = filtered_cif_df.rename( - columns={ - "_atom_site.auth_asym_id": "chainID", - "_atom_site.label_seq_id": "residueNumber", - "_atom_site.B_iso_or_equiv": "confidenceScore", - } - ) + filtered_cif_df = filtered_cif_df.rename( + columns={ + "_atom_site.auth_asym_id": "chainID", + "_atom_site.label_seq_id": "residueNumber", + "_atom_site.B_iso_or_equiv": "confidenceScore", + } + ) + return filtered_cif_df - return filtered_cif_df + except KeyError: + return None def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: @@ -911,6 +917,14 @@ def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: pae_matrix = unwrapped_full_data["pae_matrix"] plddt_df = get_plddt_from_cif(df_dict["cif_df"]) + if plddt_df is None: + messages.append( + dict( + level=logging.WARNING, + msg=f"Could not parse pLDDT values from CIF file. File is likely malformed!", + ) + ) + return dict( **df_dict, messages=messages, @@ -1072,6 +1086,13 @@ def upload_multimer_prediction( df_dict["pae_matrix"] = pae_matrix df_dict["plddt_df"] = get_plddt_from_cif(df_dict["cif_df"]) + if df_dict["plddt_df"] is None: + messages.append( + dict( + level=logging.WARNING, + msg=f"Could not parse pLDDT values from CIF file. File is likely malformed!", + ) + ) data_for_visualization = { "structure_entry_id": entry_id, "cif_df": cif_df, diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index 4c3196708..5deb6557f 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -2463,7 +2463,7 @@ def create_form(self): ), InfoField( label="Set default cross-link lengths and their upper/lower deviations in settings under 'Cross-Links Defaults'.", - ) + ), ], ) @@ -2492,6 +2492,6 @@ def create_form(self): ), InfoField( label="Set default cross-link lengths and their upper/lower deviations in settings under 'Cross-Links Defaults'.", - ) + ), ], ) diff --git a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py index 9a23060b6..31acf52b0 100644 --- a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py +++ b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py @@ -354,7 +354,9 @@ def test_get_prot_structure_dfs_success(tmp_path, monkeypatch): assert isinstance(out[DataKey.PAE_MATRIX], OutputItem) assert isinstance(out[DataKey.PAE_MATRIX].value, np.ndarray) - assert out[DataKey.PAE_MATRIX].value == 0.1 # 0D array (only one value) TODO: Change this to something more reasonable? idk + assert ( + out[DataKey.PAE_MATRIX].value == 0.1 + ) # 0D array (only one value) TODO: Change this to something more reasonable? idk assert isinstance(out[DataKey.PLDDT_DF], pd.DataFrame) assert not out[DataKey.PLDDT_DF].empty @@ -457,7 +459,7 @@ def test_get_amino_acid_sequences_df_and_handle_files(tmp_path, monkeypatch): ) assert DataKey.AMINO_ACID_SEQUENCES_DF in out assert isinstance(out[DataKey.CIF_DF], pd.DataFrame) and out[DataKey.CIF_DF].empty - assert isinstance(out[DataKey.PAE_DF], pd.DataFrame) and out[DataKey.PAE_DF].empty + assert isinstance(out["pae_df"], pd.DataFrame) and out["pae_df"].empty assert ( isinstance(out[DataKey.PLDDT_DF], pd.DataFrame) and out[DataKey.PLDDT_DF].empty ) @@ -480,26 +482,30 @@ def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): _chem_comp.id _chem_comp.mon_nstd_flag SER y + GLY y # loop_ _atom_site.id _atom_site.label_atom_id + _atom_site.label_comp_id _atom_site.auth_asym_id _atom_site.label_seq_id _atom_site.B_iso_or_equiv - 1 N A 1 99.99 - 2 CA A 1 67.76 - 3 CA A 2 33.65 - 4 O A 2 5.52 - 5 N B 1 0 - 6 CA B 1 13.37 + 1 N SER A 1 99.99 + 2 CA SER A 1 67.76 + 3 CA SER A 2 33.65 + 4 O SER A 2 5.52 + 5 N GLY B 1 0 + 6 CA GLY B 1 13.37 # """ ) conf = tmp_path / "conf.json" - conf.write_text('{"chain_iptm": [0.42, 0.89]}') # Note that we do not use these metrics anywhere + conf.write_text( + '{"chain_iptm": [0.42, 0.89]}' + ) # Note that we do not use these metrics anywhere full = tmp_path / "full.json" - full.write_text('{"random_column": [1,2], "pae": [[1, 2], [3, 4]]}') + full.write_text('{"random_column": [1,2], "pae": [[1, 2], [3, 4]]}') job_request = tmp_path / "job_request.json" job_request.write_text( json.dumps( @@ -565,17 +571,33 @@ def _copy(src, dest_dir): assert isinstance(cif_df, pd.DataFrame) assert list(cif_df.columns) == [ ATOM_SITE_COLUMNS.ID, - ATOM_SITE_COLUMNS.TYPE_SYMBOL, ATOM_SITE_COLUMNS.LABEL_ATOM_ID, ATOM_SITE_COLUMNS.LABEL_COMP_ID, + ATOM_SITE_COLUMNS.AUTH_ASYM_ID, + ATOM_SITE_COLUMNS.LABEL_SEQ_ID, + ATOM_SITE_COLUMNS.B_ISO_OR_EQUIV, CHEM_COMP_COLUMNS.MON_NSTD_FLAG, ] assert cif_df[ATOM_SITE_COLUMNS.ID].tolist() == list(range(1, 7)) - assert cif_df[ATOM_SITE_COLUMNS.LABEL_ATOM_ID].tolist() == ["N", "CA", "CA", "O", "N", "CA"] + assert cif_df[ATOM_SITE_COLUMNS.LABEL_ATOM_ID].tolist() == [ + "N", + "CA", + "CA", + "O", + "N", + "CA", + ] assert cif_df[ATOM_SITE_COLUMNS.AUTH_ASYM_ID].tolist() == ["A"] * 4 + ["B"] * 2 - assert cif_df[ATOM_SITE_COLUMNS.B_ISO_OR_EQUIV].tolist() == [99.99, 67.76, 33.65, 5.52, 0, 13.37] - assert cif_df[ATOM_SITE_COLUMNS.LABEL_COMP_ID].tolist() == ["SER"] - assert cif_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].tolist() == [True] + assert cif_df[ATOM_SITE_COLUMNS.B_ISO_OR_EQUIV].tolist() == [ + 99.99, + 67.76, + 33.65, + 5.52, + 0, + 13.37, + ] + assert cif_df[ATOM_SITE_COLUMNS.LABEL_COMP_ID].tolist() == ["SER"] * 4 + ["GLY"] * 2 + assert cif_df[CHEM_COMP_COLUMNS.MON_NSTD_FLAG].tolist() == [True] * 6 # confidence JSON conf_df = out[DataKey.CONFIDENCE_DF] @@ -600,6 +622,24 @@ def _copy(src, dest_dir): assert seqs["Protein Sequence"].tolist() == ["AAAA"] assert any(str(v).startswith("X") for v in seqs["Protein ID"].tolist()) + # pLDDT values + plddt_df = out[DataKey.PLDDT_DF] + assert isinstance(plddt_df, pd.DataFrame) + assert list(plddt_df.columns) == ["chainID", "residueNumber", "confidenceScore"] + assert plddt_df["confidenceScore"].tolist() == [ + 67.76, + 33.65, + 13.37, + ] # Keep only CA atoms + + # PAE values + pae_matrix = out[DataKey.PAE_MATRIX].value + assert isinstance(pae_matrix, np.ndarray) + assert pae_matrix[0, 0] == 1 + assert pae_matrix[0, 1] == 2 + assert pae_matrix[1, 0] == 3 + assert pae_matrix[1, 1] == 4 + upload_dir = tmp_path / "M1" assert upload_dir.exists() assert any(upload_dir.glob("*.fasta")) or any(upload_dir.glob("*.fa")) @@ -608,8 +648,6 @@ def _copy(src, dest_dir): # Additional comprehensive tests for error cases and edge cases - - def test_get_monomer_metadata_df_existing_csv(tmp_path, monkeypatch): """Test reading existing monomer metadata CSV""" csv_path = tmp_path / "alphafold_monomer_metadata.csv" @@ -674,7 +712,10 @@ def test_to_fasta_lowercase_conversion(): def test_upload_multimer_prediction_no_persist(tmp_path, monkeypatch): - """Test upload_multimer_prediction with persist_upload=False""" + """ + Test upload_multimer_prediction with persist_upload=False. + Also tests full_data without PAE values + """ monkeypatch.setattr(paths, "ALPHAFOLD_MONOMER_PATH", tmp_path) monkeypatch.setattr(paths, "ALPHAFOLD_MULTIMER_PATH", tmp_path) @@ -727,6 +768,7 @@ def test_upload_multimer_prediction_no_persist(tmp_path, monkeypatch): assert isinstance(out[DataKey.STRUCTURE_METADATA_DF], pd.DataFrame) assert isinstance(out[DataKey.CIF_DF], pd.DataFrame) assert isinstance(out[DataKey.JOB_REQUEST_DF], pd.DataFrame) + assert out[DataKey.PAE_MATRIX].value is None assert out[DataKey.JOB_REQUEST_DF].iloc[0]["name"] == "test_job_2" # directory should still exist (created for the entry) upload_dir = tmp_path / "M2" @@ -1010,7 +1052,7 @@ def test_get_multimer_structure_dfs_success(tmp_path, monkeypatch): assert isinstance(out[DataKey.JOB_REQUEST_DF], pd.DataFrame) assert "chain_iptm" in out[DataKey.CONFIDENCE_DF].columns - assert "pae" in out[DataKey.FULL_DATA_DF].columns + assert "pae" not in out[DataKey.FULL_DATA_DF].columns assert out[DataKey.JOB_REQUEST_DF].iloc[0]["name"] == "multimer_job" assert out[DataKey.JOB_REQUEST_DF].iloc[0]["version"] == 3 From d5067870a6a00f46ff4a70ba20abcd13ebab81a8 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 21 May 2026 16:13:42 +0200 Subject: [PATCH 20/33] chore: test for no pLDDT data within cif --- .../test_alphafold_protein_structure_load.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py index 31acf52b0..c64b7805e 100644 --- a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py +++ b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py @@ -646,6 +646,46 @@ def _copy(src, dest_dir): assert any(upload_dir.glob("*.json")) assert any(upload_dir.glob("*.cif")) + # Test no plDDT Data -> plddt_df should be None + cif.write_text( + """ + data_test + loop_ + _chem_comp.id + _chem_comp.mon_nstd_flag + SER y + GLY y + # + loop_ + _atom_site.id + _atom_site.label_atom_id + _atom_site.label_comp_id + _atom_site.auth_asym_id + _atom_site.label_seq_id + 1 N SER A 1 + 2 CA SER A 1 + 3 CA SER A 2 + 4 O SER A 2 + 5 N GLY B 1 + 6 CA GLY B 1 + # + """ + ) + + out = upload_multimer_prediction( + entry_id="M1", + uniprot_ids="X, Y", + model_used="m", + amino_acid_sequences=fasta, + cif_file=cif, + confidence_file=conf, + full_data_file=full, + job_request_file=job_request, + persist_upload=True, + ) + + assert out[DataKey.PLDDT_DF] is None + # Additional comprehensive tests for error cases and edge cases def test_get_monomer_metadata_df_existing_csv(tmp_path, monkeypatch): From a034f655dfc7ef66c66fa8f4a3aa916a4a5d3c19 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 21 May 2026 17:05:16 +0200 Subject: [PATCH 21/33] chore: tests for PAE based CL validation --- .../test_crosslinking_validation.py | 170 ++++++++++++++++-- .../test_alphafold_protein_structure_load.py | 12 +- 2 files changed, 161 insertions(+), 21 deletions(-) diff --git a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py index 263632a4f..426eb8d0c 100644 --- a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py +++ b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py @@ -6,7 +6,7 @@ import plotly.graph_objects as go from plotly.graph_objects import Figure import pandas.testing as pdt -import numpy +import numpy as np from backend.protzilla.data_analysis.crosslinking_validation import ( @@ -37,7 +37,7 @@ (6.01, False), # outside bounds ], ) -def test_monomer_validation(distance, expected): +def test_monomer_validation_baseline_manual_bounds(distance, expected): crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å ± 1 Å # Fake AlphaFold Data with chain IDs @@ -78,15 +78,6 @@ def test_monomer_validation(distance, expected): valid_ids = {"P12345": ["P12345"]} structures_to_validate = ["P12345"] - pae_df_noerror = pd.DataFrame( - { - "predicted_aligned_error": ["[[0, 0], [0, 0]]"], - "max_predicted_aligned_error": [31.75], - } - ) - - pae_matrix_noerror = numpy.array([[0, 0], [0, 0]]) - plddt_df_noerror = pd.DataFrame( { "residueNumber": [1, 2], @@ -105,8 +96,6 @@ def test_monomer_validation(distance, expected): valid_ids=valid_ids, id_column_name="_atom_site.pdbx_sifts_xref_db_acc", structures_to_validate=structures_to_validate, - pae_matrix=pae_matrix_noerror, - plddt_df=plddt_df_noerror, validation_criterion=CrosslinkingValidationCriterion.manual_bounds.value, ) @@ -121,8 +110,159 @@ def test_monomer_validation(distance, expected): assert df.loc[0, "valid_crosslink"] == expected assert df.loc[0, "link_type"] == "intra" - # Validation with PAE - # TODO: Proper cases with error + +@pytest.mark.parametrize( + "distance, expected", + [ + (4.99, False), + (5.0, True), + (5.1, False), + ], +) +def test_cl_validation_pae_noerrror(distance, expected): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å (± 1 Å) + pae_matrix = np.array([[np.nan, 0], [0, np.nan]]) + + # Fake AlphaFold Data with chain IDs + cif_df = pd.DataFrame( + { + "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], + "_atom_site.label_seq_id": [1, 2], + "_atom_site.Cartn_x": [0, distance], + "_atom_site.Cartn_y": [0, 0], + "_atom_site.Cartn_z": [0, 0], + "_atom_site.auth_asym_id": ["A", "A"], + "_atom_site.pdbx_sifts_xref_db_acc": ["P12345", "P12345"], + } + ) + + amino_acid_sequences_df = pd.DataFrame( + {"Protein ID": ["P12345-1"], "Protein Sequence": ["AB"]} + ) + + # Fake Crosslink Data + crosslinking_df = pd.DataFrame( + { + "Protein_id1": ["P12345"], + "Protein_id2": ["P12345"], + "Peptide1": ["A"], + "Peptide2": ["B"], + "CL_position_within_peptide1": [0], + "CL_position_within_peptide2": [0], + "Crosslinker": ["DSS"], + } + ) + + structure_metadata_df = pd.DataFrame( + {"entry_id": ["test"], "uniprot_accession": ["P12345"]} + ) + + valid_ids = {"P12345": ["P12345"]} + structures_to_validate = ["P12345"] + + result = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.min_pae.value, + pae_matrix=pae_matrix, + ) + + df: pd.DataFrame = result["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected + + +@pytest.mark.parametrize( + "distance, expected_min, expected_max", + [ + (2.0, False, False), + (3.0, False, True), + (4.0, True, True), + (5.0, True, True), + (6.0, True, True), + (7.0, False, True), + (8.0, False, False), + ], +) +def test_cl_validation_pae_haserror(distance, expected_min, expected_max): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å (± 1 Å) + pae_matrix = np.array([[np.nan, 1], [2, np.nan]]) + + # Fake AlphaFold Data with chain IDs + cif_df = pd.DataFrame( + { + "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], + "_atom_site.label_seq_id": [1, 2], + "_atom_site.Cartn_x": [0, distance], + "_atom_site.Cartn_y": [0, 0], + "_atom_site.Cartn_z": [0, 0], + "_atom_site.auth_asym_id": ["A", "A"], + "_atom_site.pdbx_sifts_xref_db_acc": ["P12345", "P12345"], + } + ) + + amino_acid_sequences_df = pd.DataFrame( + {"Protein ID": ["P12345-1"], "Protein Sequence": ["AB"]} + ) + + # Fake Crosslink Data + crosslinking_df = pd.DataFrame( + { + "Protein_id1": ["P12345"], + "Protein_id2": ["P12345"], + "Peptide1": ["A"], + "Peptide2": ["B"], + "CL_position_within_peptide1": [0], + "CL_position_within_peptide2": [0], + "Crosslinker": ["DSS"], + } + ) + + structure_metadata_df = pd.DataFrame( + {"entry_id": ["test"], "uniprot_accession": ["P12345"]} + ) + + valid_ids = {"P12345": ["P12345"]} + structures_to_validate = ["P12345"] + + result_min = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.min_pae.value, + pae_matrix=pae_matrix, + ) + + df: pd.DataFrame = result_min["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected_min + + result_max = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.max_pae.value, + pae_matrix=pae_matrix, + ) + + df: pd.DataFrame = result_max["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected_max def test_modify_form_creates_crosslinker_fields(): diff --git a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py index c64b7805e..6a51c0289 100644 --- a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py +++ b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py @@ -662,12 +662,12 @@ def _copy(src, dest_dir): _atom_site.label_comp_id _atom_site.auth_asym_id _atom_site.label_seq_id - 1 N SER A 1 - 2 CA SER A 1 - 3 CA SER A 2 - 4 O SER A 2 - 5 N GLY B 1 - 6 CA GLY B 1 + 1 N SER A 1 + 2 CA SER A 1 + 3 CA SER A 2 + 4 O SER A 2 + 5 N GLY B 1 + 6 CA GLY B 1 # """ ) From 6ec24c0b3793c4a1ccc79de03afa041d765eb2d0 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 21 May 2026 17:20:04 +0200 Subject: [PATCH 22/33] chore: tests for pLDDT based CL validation --- .../test_crosslinking_validation.py | 167 +++++++++++++++++- 1 file changed, 158 insertions(+), 9 deletions(-) diff --git a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py index 426eb8d0c..0e815a17d 100644 --- a/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py +++ b/backend/tests/protzilla/data_analysis/test_crosslinking_validation.py @@ -78,15 +78,6 @@ def test_monomer_validation_baseline_manual_bounds(distance, expected): valid_ids = {"P12345": ["P12345"]} structures_to_validate = ["P12345"] - plddt_df_noerror = pd.DataFrame( - { - "residueNumber": [1, 2], - "confidenceScore": [100, 100], - "chainID": ["A", "A"], - # confidenceCategory is not required - } - ) - result = validate_with_angstrom_deviation( crosslinking_df=crosslinking_df, structure_metadata_df=structure_metadata_df, @@ -265,6 +256,164 @@ def test_cl_validation_pae_haserror(distance, expected_min, expected_max): assert df.loc[0, "valid_crosslink"] == expected_max +@pytest.mark.parametrize( + "distance, expected", + [ + (4.99, False), + (5.0, True), + (5.1, False), + ], +) +def test_cl_validation_plddt_noerrror(distance, expected): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å (± 1 Å) + + plddt_df_noerror = pd.DataFrame( + { + "chainID": ["A", "A"], + "residueNumber": [1, 2], + "confidenceScore": [100, 100], + # confidenceCategory is not required + } + ) + + # Fake AlphaFold Data with chain IDs + cif_df = pd.DataFrame( + { + "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], + "_atom_site.label_seq_id": [1, 2], + "_atom_site.Cartn_x": [0, distance], + "_atom_site.Cartn_y": [0, 0], + "_atom_site.Cartn_z": [0, 0], + "_atom_site.auth_asym_id": ["A", "A"], + "_atom_site.pdbx_sifts_xref_db_acc": ["P12345", "P12345"], + } + ) + + amino_acid_sequences_df = pd.DataFrame( + {"Protein ID": ["P12345-1"], "Protein Sequence": ["AB"]} + ) + + # Fake Crosslink Data + crosslinking_df = pd.DataFrame( + { + "Protein_id1": ["P12345"], + "Protein_id2": ["P12345"], + "Peptide1": ["A"], + "Peptide2": ["B"], + "CL_position_within_peptide1": [0], + "CL_position_within_peptide2": [0], + "Crosslinker": ["DSS"], + } + ) + + structure_metadata_df = pd.DataFrame( + {"entry_id": ["test"], "uniprot_accession": ["P12345"]} + ) + + valid_ids = {"P12345": ["P12345"]} + structures_to_validate = ["P12345"] + + result = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.plddt_adjusted.value, + plddt_df=plddt_df_noerror, + ) + + df: pd.DataFrame = result["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected + + +# l_cl = 5, t_x = 1.25, t_y = 3.5. +# So range is 0.25 <= d <= 9.75 +@pytest.mark.parametrize( + "distance, expected", + [ + (0.0, False), + (0.24, False), + (0.25, True), + (5.0, True), + (9.0, True), + (9.74, True), + (9.75, True), + (9.76, False), + (10.0, False), + ], +) +def test_cl_validation_plddt_witherror(distance, expected): + crosslinker_information = {"DSS": [5.0, 1.0, 1.0]} # Length 5 Å (± 1 Å) + + plddt_df_noerror = pd.DataFrame( + { + "chainID": ["A", "A"], + "residueNumber": [1, 2], + "confidenceScore": [75, 30], + # confidenceCategory is not required + } + ) + + # Fake AlphaFold Data with chain IDs + cif_df = pd.DataFrame( + { + "_atom_site.label_atom_id": ["CA", "CA"], + "_atom_site.label_asym_id": ["A", "A"], + "_atom_site.label_seq_id": [1, 2], + "_atom_site.Cartn_x": [0, distance], + "_atom_site.Cartn_y": [0, 0], + "_atom_site.Cartn_z": [0, 0], + "_atom_site.auth_asym_id": ["A", "A"], + "_atom_site.pdbx_sifts_xref_db_acc": ["P12345", "P12345"], + } + ) + + amino_acid_sequences_df = pd.DataFrame( + {"Protein ID": ["P12345-1"], "Protein Sequence": ["AB"]} + ) + + # Fake Crosslink Data + crosslinking_df = pd.DataFrame( + { + "Protein_id1": ["P12345"], + "Protein_id2": ["P12345"], + "Peptide1": ["A"], + "Peptide2": ["B"], + "CL_position_within_peptide1": [0], + "CL_position_within_peptide2": [0], + "Crosslinker": ["DSS"], + } + ) + + structure_metadata_df = pd.DataFrame( + {"entry_id": ["test"], "uniprot_accession": ["P12345"]} + ) + + valid_ids = {"P12345": ["P12345"]} + structures_to_validate = ["P12345"] + + result = validate_with_angstrom_deviation( + crosslinking_df=crosslinking_df, + structure_metadata_df=structure_metadata_df, + crosslinker_information=crosslinker_information, + cif_df=cif_df, + amino_acid_sequences_df=amino_acid_sequences_df, + valid_ids=valid_ids, + id_column_name="_atom_site.pdbx_sifts_xref_db_acc", + structures_to_validate=structures_to_validate, + validation_criterion=CrosslinkingValidationCriterion.plddt_adjusted.value, + plddt_df=plddt_df_noerror, + ) + + df: pd.DataFrame = result["crosslinking_result_df"] + assert df.loc[0, "valid_crosslink"] == expected + + def test_modify_form_creates_crosslinker_fields(): crosslinking_df = pd.DataFrame({"Crosslinker": ["DSS", "BS3", "DSS"]}) From 2643937c5a3e8d2ddd83fd611404ec157560e56a Mon Sep 17 00:00:00 2001 From: jorisfu Date: Fri, 22 May 2026 15:27:50 +0200 Subject: [PATCH 23/33] feat: simple PAE scatter plot --- .../data_analysis/crosslinking_validation.py | 66 ++++++++++++++++++- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 55df3780b..a65fe5e21 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -17,6 +17,7 @@ from pandas.io.stata import stata_epoch import plotly.graph_objects as go from plotly.graph_objects import Figure +import plotly.express as px from backend.protzilla.data_preprocessing.plots import ( create_histograms, @@ -1078,6 +1079,65 @@ def diagrams_of_crosslinking_validation_data( return figures +def cl_scatterplots_pae( + cl_results_df: pd.DataFrame, + structures_to_validate: list[str], + crosslinker_information: dict[str, list[float]], + validation_criterion: CrosslinkingValidationCriterion, +) -> list[Figure]: + + figures: list[Figure] = [] + + def get_relevant_pae_value( + pae_x_1: float, + pae_x_2: float, + validation_criterion: CrosslinkingValidationCriterion, + ): + if validation_criterion == CrosslinkingValidationCriterion.max_pae.value: + return max(pae_x_1, pae_x_2) + elif validation_criterion == CrosslinkingValidationCriterion.min_pae.value: + return min(pae_x_1, pae_x_2) + else: + raise ValueError("Illegal validation criterion for PAE plot") + + cl_results_df["measured_distance"] = cl_results_df.apply( + lambda row: crosslinker_information[row["Crosslinker"]][0], axis=1 + ) + cl_results_df["distance_delta"] = abs( + cl_results_df["measured_distance"] - cl_results_df["alphafold_distance"] + ) + cl_results_df["relevant_pae"] = cl_results_df.apply( + lambda row: get_relevant_pae_value(row["pae_x_position1"], row["pae_x_position2"], validation_criterion), axis=1 + ) + + y_label = "Max. PAE between binding sites" if validation_criterion == CrosslinkingValidationCriterion.max_pae.value else "Min. PAE between binding sites" + + fig = px.scatter( + cl_results_df, + y="relevant_pae", + x="distance_delta", + color="valid_crosslink", + color_discrete_map={ + False: "red", + True: "blue", + }, + labels={ + "relevant_pae": y_label, + "distance_delta": "Deviation from CL length in predicted structure (Å)", + "valid_crosslink": "CL matches structure prediction", + "measured_distance": "CL length", + "Crosslinker": "CL type", + }, + log_x=True, + title="Identified Crosslinks", + hover_data=["relevant_pae", "distance_delta", "Crosslinker", "measured_distance"] + ) + + figures.append(fig) + + return figures + + def monomer_diagrams( output_crosslinking_result_df: pd.DataFrame, structure_metadata_df: pd.DataFrame, @@ -1106,15 +1166,15 @@ def monomer_diagrams( crosslinker_information=crosslinker_information, ) - # TODO: Separate Issue #429 case ( CrosslinkingValidationCriterion.max_pae.value | CrosslinkingValidationCriterion.min_pae.value ): - return diagrams_of_crosslinking_validation_data( - validated_df=output_crosslinking_result_df, + return cl_scatterplots_pae( + cl_results_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, crosslinker_information=crosslinker_information, + validation_criterion=validation_criterion, ) # TODO: Separate Issue #429 From d5be1bd480a670ede05195fd337df142c1dcd837 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 26 May 2026 16:15:49 +0200 Subject: [PATCH 24/33] feat: AF3 to AF2 PAE matrix translation --- .../alphafold_protein_structure_load.py | 95 ++++++++++++++++++- 1 file changed, 94 insertions(+), 1 deletion(-) diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index b78d2f955..cff7b0d04 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -479,6 +479,80 @@ def fetch_alphafold_protein_structure( ), ) +def reduce_pae_to_per_amino_acid( + pae_matrix: np.ndarray, + token_res_ids: list[int], + cif_df: pd.DataFrame, +): + """ + Reduces AlphaFold3 PAE matrices (per-token) to AlphaFold2 PAE matrices (per-amino acid). + If the number of tokens mapping to one AA equals the number of atoms (common for predicted PTMs), + the CA token gets used. Otherwise, the first token gets used. + Required for predictions with PTMs! + + :param pae_matrix: the per-token PAE matrix + :param token_res_ids: the token_res_ids table from the AF3 full_data json + :param cif_df: the atom_site table as a dataframe + + :return: the per-AA/per-residue PAE matrix + """ + + indices_to_delete = [] + + current_idx = 0 + runs = [] + + current_chain_idx = 0 + # Get all runs (start_token_idx, len, chain_idx, res_id) of same res ids into one list + while current_idx < len(token_res_ids): + start_token_idx = current_idx + res_id = token_res_ids[start_token_idx] + length = 1 + + if res_id == 1: + current_chain_idx += 1 + + while True: + current_idx += 1 + if current_idx < len(token_res_ids) and token_res_ids[current_idx] == res_id: + length += 1 + else: + break + + runs.append((start_token_idx, length, current_chain_idx, res_id)) + + for start_token_idx, length, chain_idx, res_id in runs: + if length == 1: + continue + + # Get corresponding entries of _atom_site table for the token + relevant_cif_df = cif_df[cif_df["_atom_site.label_entity_id"] == str(chain_idx)] + relevant_cif_df = relevant_cif_df[relevant_cif_df["_atom_site.label_seq_id"] == res_id] + + keep_offset = 0 # Relative index to keep within duplicate tokens for one amino acid. Default: first token + + # If we have one token per atom, we try to take the CA atom + if len(relevant_cif_df) == length: + # Reset index twice to get 0..length enumeration for atoms in index + relevant_cif_df.reset_index(drop=True, inplace=True) + relevant_cif_df.reset_index(inplace=True) + + relevant_cif_df = relevant_cif_df[relevant_cif_df["_atom_site.label_atom_id"] == "CA"] + # 0 or 2+ CA atoms -> default + if len(relevant_cif_df) == 1: + keep_offset = int(relevant_cif_df.iloc[0]["index"]) + + for duplicate_idx in range(0, length): + if duplicate_idx != keep_offset: + indices_to_delete.append(start_token_idx + duplicate_idx) + + # Apply deletion + mask = np.ones(len(pae_matrix), dtype=bool) + mask[indices_to_delete] = False + pae_matrix = pae_matrix[np.ix_(mask, mask)] + + return pae_matrix + def get_all_available_entry_ids_of_monomer_metadata() -> list[str]: """ " @@ -760,6 +834,7 @@ def get_monomer_structure_dfs(entry_id: str) -> dict[str, Any]: pae_matrix = np.array(ast.literal_eval(pae_string)) del df_dict["pae_df"] + return dict( **df_dict, pae_matrix=OutputItem(output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix), @@ -779,6 +854,7 @@ def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: :return dict: - "full_data_df": The updated reduced full_data_df - "pae_matrix": Numpy matrix with the PAE values for each residue pair + - "token_res_ids": List with the token -> AA mappings """ try: @@ -787,9 +863,17 @@ def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: except KeyError: pae_matrix = None + try: + token_res_ids = np.array(full_data_df["token_res_ids"].iloc[0]) + full_data_df = full_data_df.drop(columns=["token_res_ids"]) + except KeyError as e: + raise KeyError("Prediction data does not contain required prediction token to amino acid mapping.") from e + + return dict( full_data_df=full_data_df, pae_matrix=pae_matrix, + token_res_ids=token_res_ids, ) @@ -915,8 +999,11 @@ def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: df_dict["full_data_df"] = unwrapped_full_data["full_data_df"] pae_matrix = unwrapped_full_data["pae_matrix"] + token_res_ids = unwrapped_full_data["token_res_ids"] plddt_df = get_plddt_from_cif(df_dict["cif_df"]) + pae_matrix = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, df_dict["cif_df"]) + if plddt_df is None: messages.append( dict( @@ -1079,9 +1166,15 @@ def upload_multimer_prediction( unwrapped_full_data = unwrap_full_data_df(df_dict["full_data_df"]) df_dict["full_data_df"] = unwrapped_full_data["full_data_df"] + pae_matrix = reduce_pae_to_per_amino_acid( + unwrapped_full_data["pae_matrix"], + unwrapped_full_data["token_res_ids"], + df_dict["cif_df"], + ) + pae_matrix = OutputItem( output_type=OutputType.JOBLIB_ARTIFACT, - value=unwrapped_full_data["pae_matrix"], + value=pae_matrix, ) df_dict["pae_matrix"] = pae_matrix df_dict["plddt_df"] = get_plddt_from_cif(df_dict["cif_df"]) From 5b067af0095d430bd568e8ca252d3c5d7add1ddc Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 26 May 2026 17:10:27 +0200 Subject: [PATCH 25/33] (AI) tests: PAE matrix reduction --- .../importing/test_pae_matrix_reduction.py | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 backend/tests/protzilla/importing/test_pae_matrix_reduction.py diff --git a/backend/tests/protzilla/importing/test_pae_matrix_reduction.py b/backend/tests/protzilla/importing/test_pae_matrix_reduction.py new file mode 100644 index 000000000..327f11337 --- /dev/null +++ b/backend/tests/protzilla/importing/test_pae_matrix_reduction.py @@ -0,0 +1,160 @@ +import numpy as np +import pandas as pd +from backend.protzilla.importing.alphafold_protein_structure_load import reduce_pae_to_per_amino_acid +import pytest + +# These test cases were generated by AI (Gemini) but have been manually checked + + +@pytest.fixture +def empty_cif_df(): + """Returns an empty atom_site DataFrame with required columns.""" + return pd.DataFrame( + columns=[ + "_atom_site.label_entity_id", + "_atom_site.label_seq_id", + "_atom_site.label_atom_id", + ] + ) + + +def test_no_duplicates(empty_cif_df): + """Edge Case 1: Every residue has exactly one token. + + The PAE matrix should remain entirely untouched. + """ + pae_matrix = np.array([[1.0, 2.0], [3.0, 4.0]]) + token_res_ids = [1, 2] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, empty_cif_df) + + assert np.array_equal(result, pae_matrix) + + +def test_duplicates_fallback_to_first_token(empty_cif_df): + """Edge Case 2: Multi-token residue, but length mismatch with CIF. + + Should fall back to keeping the first token (offset 0) and deleting the rest. + """ + # 3 tokens: Residue 1 has 2 tokens, Residue 2 has 1 token. + pae_matrix = np.array([[10, 11, 12], [20, 21, 22], [30, 31, 32]]) + token_res_ids = [1, 1, 2] + + # Keeping token 0 (first of res 1) and token 2 (res 2). Token 1 should be deleted. + expected_indices = [0, 2] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, empty_cif_df) + + assert np.array_equal(result, expected_matrix) + + +def test_duplicates_keep_ca_atom(): + """Edge Case 3: Run length matches CIF length, and exactly one CA atom is found. + + Should keep the token corresponding exactly to the 'CA' atom position. + """ + # 4 tokens: Residue 1 has 3 tokens, Residue 2 has 1 token. + pae_matrix = np.diag([1.0, 2.0, 3.0, 4.0]) + token_res_ids = [1, 1, 1, 2] + + # CIF setup: 3 atoms for chain 1, residue 1. 'CA' sits at relative index 1. + cif_df = pd.DataFrame( + { + "_atom_site.label_entity_id": ["1", "1", "1"], + "_atom_site.label_seq_id": [1, 1, 1], + "_atom_site.label_atom_id": ["N", "CA", "C"], + } + ) + + # Expected: Keep global index 1 (the CA atom) and global index 3 (residue 2). + # Global indices 0 and 2 should be wiped out. + expected_indices = [1, 3] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, cif_df) + + assert np.array_equal(result, expected_matrix) + + +def test_duplicates_cif_match_but_no_ca_fallback(): + """Edge Case 4a: Run length matches CIF length, but no CA atom exists. + + Should fall back to keeping the first token (offset 0). + """ + pae_matrix = np.diag([10, 20, 30]) + token_res_ids = [1, 1, 2] + + # CIF matches length (2 atoms), but neither is 'CA' + cif_df = pd.DataFrame( + { + "_atom_site.label_entity_id": ["1", "1"], + "_atom_site.label_seq_id": [1, 1], + "_atom_site.label_atom_id": ["N", "O"], + } + ) + + # Expected to keep global index 0 (fallback) and global index 2. + expected_indices = [0, 2] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, cif_df) + + assert np.array_equal(result, expected_matrix) + + +def test_duplicates_cif_match_multiple_ca_fallback(): + """Edge Case 4b: Run length matches CIF length, but multiple CA atoms exist. + + Should fall back to keeping the first token (offset 0). + """ + pae_matrix = np.diag([10, 20, 30]) + token_res_ids = [1, 1, 2] + + # CIF matches length (2 atoms), but both claim to be 'CA' + cif_df = pd.DataFrame( + { + "_atom_site.label_entity_id": ["1", "1"], + "_atom_site.label_seq_id": [1, 1], + "_atom_site.label_atom_id": ["CA", "CA"], + } + ) + + # Expected to keep global index 0 (fallback) and global index 2. + expected_indices = [0, 2] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, cif_df) + + assert np.array_equal(result, expected_matrix) + + +def test_multiple_chains_tracking(): + """Edge Case 5: The system has multiple chains. + + Verifies that `current_chain_idx` increments whenever `res_id == 1` starts a run, + and queries the correct stringified `_atom_site.label_entity_id`. + """ + # Chain 1: res 1 (len 1), res 2 (len 1) + # Chain 2: res 1 (len 2) -> Triggered by encountering 1 again + token_res_ids = [1, 2, 1, 1] + pae_matrix = np.diag([100, 200, 300, 400]) + + cif_df = pd.DataFrame( + { + "_atom_site.label_entity_id": ["1", "1", "2", "2"], + "_atom_site.label_seq_id": [1, 2, 1, 1], + "_atom_site.label_atom_id": ["CA", "CA", "N", "CA"], + } + ) + + # Chain 1, Res 1 (idx 0): len 1 -> Keep + # Chain 1, Res 2 (idx 1): len 1 -> Keep + # Chain 2, Res 1 (idx 2, 3): len 2 -> Matches CIF length for chain '2'. + # CA is at relative index 1 (global idx 3). Global idx 2 is dropped. + expected_indices = [0, 1, 3] + expected_matrix = pae_matrix[np.ix_(expected_indices, expected_indices)] + + result = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, cif_df) + + assert np.array_equal(result, expected_matrix) From a1a5a9185f9dde9900af398f130ea1da2f1fd8bc Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 26 May 2026 17:13:56 +0200 Subject: [PATCH 26/33] chore: remove unused imports --- backend/protzilla/data_analysis/crosslinking_validation.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index a65fe5e21..0e8b8aec4 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -1,12 +1,8 @@ import itertools import ast import math -from pipes import stepkinds -from multiprocessing.sharedctypes import Value -from typing import TYPE_CHECKING, Callable - -from numpy.testing import assert_ +from typing import Callable from backend.protzilla.constants.option_types import CrosslinkingValidationCriterion import pandas as pd @@ -14,7 +10,6 @@ import re import logging -from pandas.io.stata import stata_epoch import plotly.graph_objects as go from plotly.graph_objects import Figure import plotly.express as px From d1e3cf547efa8c66a1abc912ab6678e4a2ba211f Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 26 May 2026 17:45:42 +0200 Subject: [PATCH 27/33] feat: only make bounds fields visible if manual bounds is selected mode --- backend/protzilla/methods/data_analysis.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index 5deb6557f..75bc84ced 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -2417,6 +2417,10 @@ def create_crosslink_input_fields(self, form: Form, run: Run): form.add_field(upper_bound_length_deviation_field) form.add_field(lower_bound_length_deviation_field) + bounds_visible = form["validation_criterion"].value == CrosslinkingValidationCriterion.manual_bounds.value + form[f"{crosslinker}_upper_accepted_deviation"].isVisible = bounds_visible + form[f"{crosslinker}_lower_accepted_deviation"].isVisible = bounds_visible + def collect_crosslinking_information(self, steps: StepManager, inputs) -> dict: # although crosslinker_information is not a dataframe we need to insert the user information regarding the crosslinks as a dictionary into the inputs crosslinker_to_length_and_deviation = {} From 05580e9f35e9f627b1c296fdfabe38da35830528 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Tue, 26 May 2026 17:46:10 +0200 Subject: [PATCH 28/33] chore: black --- .../alphafold_protein_structure_load.py | 28 +++++++++++++------ backend/protzilla/methods/data_analysis.py | 5 +++- .../importing/test_pae_matrix_reduction.py | 16 ++++++----- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/backend/protzilla/importing/alphafold_protein_structure_load.py b/backend/protzilla/importing/alphafold_protein_structure_load.py index cff7b0d04..38f4d1307 100644 --- a/backend/protzilla/importing/alphafold_protein_structure_load.py +++ b/backend/protzilla/importing/alphafold_protein_structure_load.py @@ -479,6 +479,7 @@ def fetch_alphafold_protein_structure( ), ) + def reduce_pae_to_per_amino_acid( pae_matrix: np.ndarray, token_res_ids: list[int], @@ -514,7 +515,10 @@ def reduce_pae_to_per_amino_acid( while True: current_idx += 1 - if current_idx < len(token_res_ids) and token_res_ids[current_idx] == res_id: + if ( + current_idx < len(token_res_ids) + and token_res_ids[current_idx] == res_id + ): length += 1 else: break @@ -524,12 +528,14 @@ def reduce_pae_to_per_amino_acid( for start_token_idx, length, chain_idx, res_id in runs: if length == 1: continue - + # Get corresponding entries of _atom_site table for the token relevant_cif_df = cif_df[cif_df["_atom_site.label_entity_id"] == str(chain_idx)] - relevant_cif_df = relevant_cif_df[relevant_cif_df["_atom_site.label_seq_id"] == res_id] + relevant_cif_df = relevant_cif_df[ + relevant_cif_df["_atom_site.label_seq_id"] == res_id + ] - keep_offset = 0 # Relative index to keep within duplicate tokens for one amino acid. Default: first token + keep_offset = 0 # Relative index to keep within duplicate tokens for one amino acid. Default: first token # If we have one token per atom, we try to take the CA atom if len(relevant_cif_df) == length: @@ -537,7 +543,9 @@ def reduce_pae_to_per_amino_acid( relevant_cif_df.reset_index(drop=True, inplace=True) relevant_cif_df.reset_index(inplace=True) - relevant_cif_df = relevant_cif_df[relevant_cif_df["_atom_site.label_atom_id"] == "CA"] + relevant_cif_df = relevant_cif_df[ + relevant_cif_df["_atom_site.label_atom_id"] == "CA" + ] # 0 or 2+ CA atoms -> default if len(relevant_cif_df) == 1: keep_offset = int(relevant_cif_df.iloc[0]["index"]) @@ -834,7 +842,6 @@ def get_monomer_structure_dfs(entry_id: str) -> dict[str, Any]: pae_matrix = np.array(ast.literal_eval(pae_string)) del df_dict["pae_df"] - return dict( **df_dict, pae_matrix=OutputItem(output_type=OutputType.JOBLIB_ARTIFACT, value=pae_matrix), @@ -867,8 +874,9 @@ def unwrap_full_data_df(full_data_df: pd.DataFrame) -> dict[str, Any]: token_res_ids = np.array(full_data_df["token_res_ids"].iloc[0]) full_data_df = full_data_df.drop(columns=["token_res_ids"]) except KeyError as e: - raise KeyError("Prediction data does not contain required prediction token to amino acid mapping.") from e - + raise KeyError( + "Prediction data does not contain required prediction token to amino acid mapping." + ) from e return dict( full_data_df=full_data_df, @@ -1002,7 +1010,9 @@ def get_multimer_structure_dfs(entry_id: str) -> dict[str, Any]: token_res_ids = unwrapped_full_data["token_res_ids"] plddt_df = get_plddt_from_cif(df_dict["cif_df"]) - pae_matrix = reduce_pae_to_per_amino_acid(pae_matrix, token_res_ids, df_dict["cif_df"]) + pae_matrix = reduce_pae_to_per_amino_acid( + pae_matrix, token_res_ids, df_dict["cif_df"] + ) if plddt_df is None: messages.append( diff --git a/backend/protzilla/methods/data_analysis.py b/backend/protzilla/methods/data_analysis.py index 75bc84ced..b7f47667c 100644 --- a/backend/protzilla/methods/data_analysis.py +++ b/backend/protzilla/methods/data_analysis.py @@ -2417,7 +2417,10 @@ def create_crosslink_input_fields(self, form: Form, run: Run): form.add_field(upper_bound_length_deviation_field) form.add_field(lower_bound_length_deviation_field) - bounds_visible = form["validation_criterion"].value == CrosslinkingValidationCriterion.manual_bounds.value + bounds_visible = ( + form["validation_criterion"].value + == CrosslinkingValidationCriterion.manual_bounds.value + ) form[f"{crosslinker}_upper_accepted_deviation"].isVisible = bounds_visible form[f"{crosslinker}_lower_accepted_deviation"].isVisible = bounds_visible diff --git a/backend/tests/protzilla/importing/test_pae_matrix_reduction.py b/backend/tests/protzilla/importing/test_pae_matrix_reduction.py index 327f11337..4daf5449e 100644 --- a/backend/tests/protzilla/importing/test_pae_matrix_reduction.py +++ b/backend/tests/protzilla/importing/test_pae_matrix_reduction.py @@ -1,6 +1,8 @@ import numpy as np import pandas as pd -from backend.protzilla.importing.alphafold_protein_structure_load import reduce_pae_to_per_amino_acid +from backend.protzilla.importing.alphafold_protein_structure_load import ( + reduce_pae_to_per_amino_acid, +) import pytest # These test cases were generated by AI (Gemini) but have been manually checked @@ -20,7 +22,7 @@ def empty_cif_df(): def test_no_duplicates(empty_cif_df): """Edge Case 1: Every residue has exactly one token. - + The PAE matrix should remain entirely untouched. """ pae_matrix = np.array([[1.0, 2.0], [3.0, 4.0]]) @@ -33,7 +35,7 @@ def test_no_duplicates(empty_cif_df): def test_duplicates_fallback_to_first_token(empty_cif_df): """Edge Case 2: Multi-token residue, but length mismatch with CIF. - + Should fall back to keeping the first token (offset 0) and deleting the rest. """ # 3 tokens: Residue 1 has 2 tokens, Residue 2 has 1 token. @@ -51,7 +53,7 @@ def test_duplicates_fallback_to_first_token(empty_cif_df): def test_duplicates_keep_ca_atom(): """Edge Case 3: Run length matches CIF length, and exactly one CA atom is found. - + Should keep the token corresponding exactly to the 'CA' atom position. """ # 4 tokens: Residue 1 has 3 tokens, Residue 2 has 1 token. @@ -79,7 +81,7 @@ def test_duplicates_keep_ca_atom(): def test_duplicates_cif_match_but_no_ca_fallback(): """Edge Case 4a: Run length matches CIF length, but no CA atom exists. - + Should fall back to keeping the first token (offset 0). """ pae_matrix = np.diag([10, 20, 30]) @@ -105,7 +107,7 @@ def test_duplicates_cif_match_but_no_ca_fallback(): def test_duplicates_cif_match_multiple_ca_fallback(): """Edge Case 4b: Run length matches CIF length, but multiple CA atoms exist. - + Should fall back to keeping the first token (offset 0). """ pae_matrix = np.diag([10, 20, 30]) @@ -131,7 +133,7 @@ def test_duplicates_cif_match_multiple_ca_fallback(): def test_multiple_chains_tracking(): """Edge Case 5: The system has multiple chains. - + Verifies that `current_chain_idx` increments whenever `res_id == 1` starts a run, and queries the correct stringified `_atom_site.label_entity_id`. """ From 732d83856253503d753193e2d460c80f23809408 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Wed, 27 May 2026 16:12:41 +0200 Subject: [PATCH 29/33] chore: clean up PAE plots and add to multimer validation --- .../data_analysis/crosslinking_validation.py | 93 ++++++++++++++----- 1 file changed, 71 insertions(+), 22 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 0e8b8aec4..5f7e6db47 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -4,6 +4,8 @@ from typing import Callable +from pandas.core.generic import validate_inclusive + from backend.protzilla.constants.option_types import CrosslinkingValidationCriterion import pandas as pd import numpy as np @@ -1107,25 +1109,72 @@ def get_relevant_pae_value( y_label = "Max. PAE between binding sites" if validation_criterion == CrosslinkingValidationCriterion.max_pae.value else "Min. PAE between binding sites" - fig = px.scatter( - cl_results_df, - y="relevant_pae", - x="distance_delta", - color="valid_crosslink", - color_discrete_map={ - False: "red", - True: "blue", - }, - labels={ - "relevant_pae": y_label, - "distance_delta": "Deviation from CL length in predicted structure (Å)", - "valid_crosslink": "CL matches structure prediction", - "measured_distance": "CL length", - "Crosslinker": "CL type", - }, - log_x=True, - title="Identified Crosslinks", - hover_data=["relevant_pae", "distance_delta", "Crosslinker", "measured_distance"] + + fig = go.Figure() + + valid_cls = cl_results_df[cl_results_df["valid_crosslink"]] + invalid_cls = cl_results_df[~cl_results_df["valid_crosslink"]] + + fig.add_trace(go.Scatter( + x = valid_cls["distance_delta"], + y = valid_cls["relevant_pae"], + customdata = np.stack(( + valid_cls['Crosslinker'], + valid_cls['measured_distance'], + valid_cls['distance_delta'], + valid_cls['relevant_pae'], + ), axis=-1), + mode = 'markers', + name = "CLs matching prediction", + hovertemplate = "%{customdata[0]} (Length %{customdata[1]}Å)
Prediction off by %{customdata[2]:.2f}Å
PAE %{customdata[3]:.2f}Å", + hoverinfo="none" + )) + + fig.add_trace(go.Scatter( + x = invalid_cls["distance_delta"], + y = invalid_cls["relevant_pae"], + customdata = np.stack(( + invalid_cls['Crosslinker'], + invalid_cls['measured_distance'], + invalid_cls['distance_delta'], + invalid_cls['relevant_pae'], + ), axis=-1), + mode = 'markers', + name = "CLs not matching prediction", + hovertemplate = "%{customdata[0]} (Length %{customdata[1]}Å)
Prediction off by %{customdata[2]:.2f}Å
PAE %{customdata[3]:.2f}Å", + hoverinfo="none" + )) + + # X axis range should start as close to 0 as reasonable and extend to max value + min_dist_delta = min(cl_results_df["distance_delta"]) + max_dist_delta = max(cl_results_df["distance_delta"]) + + xmin = -1 + if np.log10(xmin) > min_dist_delta: + xmin = np.log10(min_dist_delta) + + xmax = np.log10(max_dist_delta) + xmax += np.log10(1.2) # reasonable padding + + + fig.update_xaxes( + type="log", + title_text="Deviation from CL length in predicted structure (Å) (log scaled)", + tickmode="linear", + tick0=0, + dtick=np.log10(2), + range=[xmin, xmax], + ) + + fig.update_yaxes( + title_text=y_label, + ) + + + fig.update_layout( + title=dict( + text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})" + ), ) figures.append(fig) @@ -1219,15 +1268,15 @@ def multimer_diagrams( crosslinker_information=crosslinker_information, ) - # TODO: Separate Issue #429 case ( CrosslinkingValidationCriterion.max_pae.value | CrosslinkingValidationCriterion.min_pae.value ): - return diagrams_of_crosslinking_validation_data( - validated_df=output_crosslinking_result_df, + return cl_scatterplots_pae( + cl_results_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, crosslinker_information=crosslinker_information, + validation_criterion=validation_criterion, ) # TODO: Separate Issue #429 From 5dea77d6f77b8b87f442daed0cebe5b09fc2c91a Mon Sep 17 00:00:00 2001 From: jorisfu Date: Wed, 27 May 2026 17:05:54 +0200 Subject: [PATCH 30/33] feat: plddt plot and some corrections --- .../data_analysis/crosslinking_validation.py | 137 ++++++++++++++++-- 1 file changed, 127 insertions(+), 10 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 5f7e6db47..e5abd0ddc 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -781,6 +781,8 @@ def get_paes(): "plddt_at_position2": plddt_at_position2, "pae_x_position1": pae_x_position1, "pae_x_position2": pae_x_position2, + "accepted_distance_lower_bound": accepted_distance_lower_bound, + "accepted_distance_upper_bound": accepted_distance_upper_bound, } ) @@ -794,6 +796,8 @@ def get_paes(): "plddt_at_position2", "pae_x_position1", "pae_x_position2", + "accepted_distance_lower_bound", + "accepted_distance_upper_bound", ] relevant_crosslinks_df["crosslinker_position1"] = relevant_crosslinks_df[ @@ -1075,6 +1079,110 @@ def diagrams_of_crosslinking_validation_data( return figures +def cl_scatterplots_plddt( + cl_results_df: pd.DataFrame, + structures_to_validate: list[str], + crosslinker_information: dict[str, list[float]], +) -> list[Figure]: + + figures: list[Figure] = [] + + cl_results_df["measured_distance"] = cl_results_df.apply( + lambda row: crosslinker_information[row["Crosslinker"]][0], axis=1 + ) + cl_results_df["distance_delta"] = abs( + cl_results_df["measured_distance"] - cl_results_df["alphafold_distance"] + ) + cl_results_df["avg_plddt"] = cl_results_df.apply( + lambda row: np.average([row["plddt_at_position1"], row["plddt_at_position2"]]), axis=1 + ) + + y_label = "Avg. pLDDT at binding sites" + + fig = go.Figure() + + valid_cls = cl_results_df[cl_results_df["valid_crosslink"]] + invalid_cls = cl_results_df[~cl_results_df["valid_crosslink"]] + + hovertemplate = "%{customdata[0]} (Length %{customdata[1]}Å)" + \ + "
Predicted distance: %{customdata[2]:.2f}Å " + \ + "(off by %{customdata[3]:.2f}Å)" + \ + "
Accepted distance range %{customdata[4]:.2f} - %{customdata[5]:.2f} Å" + \ + "" + + fig.add_trace(go.Scatter( + x = valid_cls["distance_delta"], + y = valid_cls["avg_plddt"], + customdata = np.stack(( + valid_cls['Crosslinker'], + valid_cls['measured_distance'], + valid_cls['alphafold_distance'], + valid_cls['distance_delta'], + valid_cls['accepted_distance_lower_bound'], + valid_cls['accepted_distance_upper_bound'], + ), axis=-1), + mode = 'markers', + name = "CLs matching prediction", + hovertemplate = hovertemplate, + )) + + fig.add_trace(go.Scatter( + x = invalid_cls["distance_delta"], + y = invalid_cls["avg_plddt"], + customdata = np.stack(( + invalid_cls['Crosslinker'], + invalid_cls['measured_distance'], + invalid_cls['alphafold_distance'], + invalid_cls['distance_delta'], + invalid_cls['accepted_distance_lower_bound'], + invalid_cls['accepted_distance_upper_bound'], + ), axis=-1), + mode = 'markers', + name = "CLs not matching prediction", + hovertemplate = hovertemplate, + )) + + # X axis range should start as close to 0 as reasonable and extend to max value + min_dist_delta = min(cl_results_df["distance_delta"]) + max_dist_delta = max(cl_results_df["distance_delta"]) + + xmin = -1 + if np.log10(xmin) > min_dist_delta: + xmin = np.log10(min_dist_delta) + + xmax = np.log10(max_dist_delta) + xmax += np.log10(1.2) # reasonable padding + + # Y axis must start at 0 and extend to max avg pLDDT + padding + ymin = 0 + ymax = max(cl_results_df["avg_plddt"]) + 5 + + fig.update_xaxes( + type="log", + title_text="Deviation from CL length in predicted structure (Å) (log scaled)", + tickmode="linear", + tick0=0, + dtick=np.log10(2), + range=[xmin, xmax], + ) + + fig.update_yaxes( + title_text=y_label, + range=[ymin, ymax], + ) + + + fig.update_layout( + title=dict( + text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})" + ), + showlegend=True + ) + + figures.append(fig) + + return figures + def cl_scatterplots_pae( cl_results_df: pd.DataFrame, @@ -1115,19 +1223,25 @@ def get_relevant_pae_value( valid_cls = cl_results_df[cl_results_df["valid_crosslink"]] invalid_cls = cl_results_df[~cl_results_df["valid_crosslink"]] + hovertemplate = "%{customdata[0]} (Length %{customdata[1]}Å)" + \ + "
Predicted distance: %{customdata[2]:.2f}Å " + \ + "(off by %{customdata[3]:.2f}Å)" + \ + "
PAE %{customdata[4]:.2f}Å" + \ + "" + fig.add_trace(go.Scatter( x = valid_cls["distance_delta"], y = valid_cls["relevant_pae"], customdata = np.stack(( valid_cls['Crosslinker'], valid_cls['measured_distance'], + valid_cls['alphafold_distance'], valid_cls['distance_delta'], valid_cls['relevant_pae'], ), axis=-1), mode = 'markers', name = "CLs matching prediction", - hovertemplate = "%{customdata[0]} (Length %{customdata[1]}Å)
Prediction off by %{customdata[2]:.2f}Å
PAE %{customdata[3]:.2f}Å", - hoverinfo="none" + hovertemplate = hovertemplate )) fig.add_trace(go.Scatter( @@ -1136,13 +1250,13 @@ def get_relevant_pae_value( customdata = np.stack(( invalid_cls['Crosslinker'], invalid_cls['measured_distance'], + invalid_cls['alphafold_distance'], invalid_cls['distance_delta'], invalid_cls['relevant_pae'], ), axis=-1), mode = 'markers', name = "CLs not matching prediction", - hovertemplate = "%{customdata[0]} (Length %{customdata[1]}Å)
Prediction off by %{customdata[2]:.2f}Å
PAE %{customdata[3]:.2f}Å", - hoverinfo="none" + hovertemplate = hovertemplate, )) # X axis range should start as close to 0 as reasonable and extend to max value @@ -1156,6 +1270,9 @@ def get_relevant_pae_value( xmax = np.log10(max_dist_delta) xmax += np.log10(1.2) # reasonable padding + # Y axis must start at 0 and extend to max relevant PAE + padding + ymin = 0 + ymax = max(cl_results_df["relevant_pae"]) + 5 fig.update_xaxes( type="log", @@ -1168,6 +1285,7 @@ def get_relevant_pae_value( fig.update_yaxes( title_text=y_label, + range=[ymin, ymax], ) @@ -1175,6 +1293,7 @@ def get_relevant_pae_value( title=dict( text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})" ), + showlegend=True, ) figures.append(fig) @@ -1221,10 +1340,9 @@ def monomer_diagrams( validation_criterion=validation_criterion, ) - # TODO: Separate Issue #429 case CrosslinkingValidationCriterion.plddt_adjusted.value: - return diagrams_of_crosslinking_validation_data( - validated_df=output_crosslinking_result_df, + return cl_scatterplots_plddt( + cl_results_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, crosslinker_information=crosslinker_information, ) @@ -1279,10 +1397,9 @@ def multimer_diagrams( validation_criterion=validation_criterion, ) - # TODO: Separate Issue #429 case CrosslinkingValidationCriterion.plddt_adjusted.value: - return diagrams_of_crosslinking_validation_data( - validated_df=output_crosslinking_result_df, + return cl_scatterplots_plddt( + cl_results_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, crosslinker_information=crosslinker_information, ) From 2fba6ad054cc5ee5f9c3810e551b169b38372994 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Wed, 27 May 2026 17:08:03 +0200 Subject: [PATCH 31/33] chore: black --- .../data_analysis/crosslinking_validation.py | 192 ++++++++++-------- 1 file changed, 111 insertions(+), 81 deletions(-) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index e5abd0ddc..807f50c58 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -1079,6 +1079,7 @@ def diagrams_of_crosslinking_validation_data( return figures + def cl_scatterplots_plddt( cl_results_df: pd.DataFrame, structures_to_validate: list[str], @@ -1094,7 +1095,8 @@ def cl_scatterplots_plddt( cl_results_df["measured_distance"] - cl_results_df["alphafold_distance"] ) cl_results_df["avg_plddt"] = cl_results_df.apply( - lambda row: np.average([row["plddt_at_position1"], row["plddt_at_position2"]]), axis=1 + lambda row: np.average([row["plddt_at_position1"], row["plddt_at_position2"]]), + axis=1, ) y_label = "Avg. pLDDT at binding sites" @@ -1104,43 +1106,55 @@ def cl_scatterplots_plddt( valid_cls = cl_results_df[cl_results_df["valid_crosslink"]] invalid_cls = cl_results_df[~cl_results_df["valid_crosslink"]] - hovertemplate = "%{customdata[0]} (Length %{customdata[1]}Å)" + \ - "
Predicted distance: %{customdata[2]:.2f}Å " + \ - "(off by %{customdata[3]:.2f}Å)" + \ - "
Accepted distance range %{customdata[4]:.2f} - %{customdata[5]:.2f} Å" + \ - "" - - fig.add_trace(go.Scatter( - x = valid_cls["distance_delta"], - y = valid_cls["avg_plddt"], - customdata = np.stack(( - valid_cls['Crosslinker'], - valid_cls['measured_distance'], - valid_cls['alphafold_distance'], - valid_cls['distance_delta'], - valid_cls['accepted_distance_lower_bound'], - valid_cls['accepted_distance_upper_bound'], - ), axis=-1), - mode = 'markers', - name = "CLs matching prediction", - hovertemplate = hovertemplate, - )) - - fig.add_trace(go.Scatter( - x = invalid_cls["distance_delta"], - y = invalid_cls["avg_plddt"], - customdata = np.stack(( - invalid_cls['Crosslinker'], - invalid_cls['measured_distance'], - invalid_cls['alphafold_distance'], - invalid_cls['distance_delta'], - invalid_cls['accepted_distance_lower_bound'], - invalid_cls['accepted_distance_upper_bound'], - ), axis=-1), - mode = 'markers', - name = "CLs not matching prediction", - hovertemplate = hovertemplate, - )) + hovertemplate = ( + "%{customdata[0]} (Length %{customdata[1]}Å)" + + "
Predicted distance: %{customdata[2]:.2f}Å " + + "(off by %{customdata[3]:.2f}Å)" + + "
Accepted distance range %{customdata[4]:.2f} - %{customdata[5]:.2f} Å" + + "" + ) + + fig.add_trace( + go.Scatter( + x=valid_cls["distance_delta"], + y=valid_cls["avg_plddt"], + customdata=np.stack( + ( + valid_cls["Crosslinker"], + valid_cls["measured_distance"], + valid_cls["alphafold_distance"], + valid_cls["distance_delta"], + valid_cls["accepted_distance_lower_bound"], + valid_cls["accepted_distance_upper_bound"], + ), + axis=-1, + ), + mode="markers", + name="CLs matching prediction", + hovertemplate=hovertemplate, + ) + ) + + fig.add_trace( + go.Scatter( + x=invalid_cls["distance_delta"], + y=invalid_cls["avg_plddt"], + customdata=np.stack( + ( + invalid_cls["Crosslinker"], + invalid_cls["measured_distance"], + invalid_cls["alphafold_distance"], + invalid_cls["distance_delta"], + invalid_cls["accepted_distance_lower_bound"], + invalid_cls["accepted_distance_upper_bound"], + ), + axis=-1, + ), + mode="markers", + name="CLs not matching prediction", + hovertemplate=hovertemplate, + ) + ) # X axis range should start as close to 0 as reasonable and extend to max value min_dist_delta = min(cl_results_df["distance_delta"]) @@ -1151,7 +1165,7 @@ def cl_scatterplots_plddt( xmin = np.log10(min_dist_delta) xmax = np.log10(max_dist_delta) - xmax += np.log10(1.2) # reasonable padding + xmax += np.log10(1.2) # reasonable padding # Y axis must start at 0 and extend to max avg pLDDT + padding ymin = 0 @@ -1171,12 +1185,11 @@ def cl_scatterplots_plddt( range=[ymin, ymax], ) - fig.update_layout( title=dict( text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})" ), - showlegend=True + showlegend=True, ) figures.append(fig) @@ -1212,52 +1225,70 @@ def get_relevant_pae_value( cl_results_df["measured_distance"] - cl_results_df["alphafold_distance"] ) cl_results_df["relevant_pae"] = cl_results_df.apply( - lambda row: get_relevant_pae_value(row["pae_x_position1"], row["pae_x_position2"], validation_criterion), axis=1 + lambda row: get_relevant_pae_value( + row["pae_x_position1"], row["pae_x_position2"], validation_criterion + ), + axis=1, ) - y_label = "Max. PAE between binding sites" if validation_criterion == CrosslinkingValidationCriterion.max_pae.value else "Min. PAE between binding sites" - + y_label = ( + "Max. PAE between binding sites" + if validation_criterion == CrosslinkingValidationCriterion.max_pae.value + else "Min. PAE between binding sites" + ) fig = go.Figure() valid_cls = cl_results_df[cl_results_df["valid_crosslink"]] invalid_cls = cl_results_df[~cl_results_df["valid_crosslink"]] - hovertemplate = "%{customdata[0]} (Length %{customdata[1]}Å)" + \ - "
Predicted distance: %{customdata[2]:.2f}Å " + \ - "(off by %{customdata[3]:.2f}Å)" + \ - "
PAE %{customdata[4]:.2f}Å" + \ - "" - - fig.add_trace(go.Scatter( - x = valid_cls["distance_delta"], - y = valid_cls["relevant_pae"], - customdata = np.stack(( - valid_cls['Crosslinker'], - valid_cls['measured_distance'], - valid_cls['alphafold_distance'], - valid_cls['distance_delta'], - valid_cls['relevant_pae'], - ), axis=-1), - mode = 'markers', - name = "CLs matching prediction", - hovertemplate = hovertemplate - )) - - fig.add_trace(go.Scatter( - x = invalid_cls["distance_delta"], - y = invalid_cls["relevant_pae"], - customdata = np.stack(( - invalid_cls['Crosslinker'], - invalid_cls['measured_distance'], - invalid_cls['alphafold_distance'], - invalid_cls['distance_delta'], - invalid_cls['relevant_pae'], - ), axis=-1), - mode = 'markers', - name = "CLs not matching prediction", - hovertemplate = hovertemplate, - )) + hovertemplate = ( + "%{customdata[0]} (Length %{customdata[1]}Å)" + + "
Predicted distance: %{customdata[2]:.2f}Å " + + "(off by %{customdata[3]:.2f}Å)" + + "
PAE %{customdata[4]:.2f}Å" + + "" + ) + + fig.add_trace( + go.Scatter( + x=valid_cls["distance_delta"], + y=valid_cls["relevant_pae"], + customdata=np.stack( + ( + valid_cls["Crosslinker"], + valid_cls["measured_distance"], + valid_cls["alphafold_distance"], + valid_cls["distance_delta"], + valid_cls["relevant_pae"], + ), + axis=-1, + ), + mode="markers", + name="CLs matching prediction", + hovertemplate=hovertemplate, + ) + ) + + fig.add_trace( + go.Scatter( + x=invalid_cls["distance_delta"], + y=invalid_cls["relevant_pae"], + customdata=np.stack( + ( + invalid_cls["Crosslinker"], + invalid_cls["measured_distance"], + invalid_cls["alphafold_distance"], + invalid_cls["distance_delta"], + invalid_cls["relevant_pae"], + ), + axis=-1, + ), + mode="markers", + name="CLs not matching prediction", + hovertemplate=hovertemplate, + ) + ) # X axis range should start as close to 0 as reasonable and extend to max value min_dist_delta = min(cl_results_df["distance_delta"]) @@ -1268,7 +1299,7 @@ def get_relevant_pae_value( xmin = np.log10(min_dist_delta) xmax = np.log10(max_dist_delta) - xmax += np.log10(1.2) # reasonable padding + xmax += np.log10(1.2) # reasonable padding # Y axis must start at 0 and extend to max relevant PAE + padding ymin = 0 @@ -1288,7 +1319,6 @@ def get_relevant_pae_value( range=[ymin, ymax], ) - fig.update_layout( title=dict( text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})" From c8c3eb8fb1f5ada5be2bf3380e44636a03fe70a5 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 28 May 2026 10:10:42 +0200 Subject: [PATCH 32/33] fix tests --- .../importing/test_alphafold_protein_structure_load.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py index 6a51c0289..8cff2b7b7 100644 --- a/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py +++ b/backend/tests/protzilla/importing/test_alphafold_protein_structure_load.py @@ -505,7 +505,7 @@ def test_upload_multimer_prediction_basic(tmp_path, monkeypatch): '{"chain_iptm": [0.42, 0.89]}' ) # Note that we do not use these metrics anywhere full = tmp_path / "full.json" - full.write_text('{"random_column": [1,2], "pae": [[1, 2], [3, 4]]}') + full.write_text('{"random_column": [1,2], "pae": [[1, 2], [3, 4]]}, "token_res_ids": [1, 2, 1]') job_request = tmp_path / "job_request.json" job_request.write_text( json.dumps( @@ -768,7 +768,7 @@ def test_upload_multimer_prediction_no_persist(tmp_path, monkeypatch): conf = tmp_path / "conf.json" conf.write_text('[{"residueNumber":1, "confidenceScore":99}]') full = tmp_path / "full.json" - full.write_text('{"a": [1,2]}') + full.write_text('{"a": [1,2], "token_res_ids": [1]}') job_request = tmp_path / "job_request.json" job_request.write_text( json.dumps( @@ -1060,7 +1060,7 @@ def test_get_multimer_structure_dfs_success(tmp_path, monkeypatch): full_data = prot_dir / "full.json" job_request = prot_dir / "job_request.json" confidence.write_text(json.dumps({"chain_iptm": [0.75]})) - full_data.write_text(json.dumps({"pae": [[0.1, 0.2], [0.3, 0.4]]})) + full_data.write_text(json.dumps({"pae": [[0.1, 0.2], [0.3, 0.4]], "token_res_ids": [1]})) job_request.write_text( json.dumps( [ From 5db98296a6e8010f03a7a4c4d78ea38bbbf1da14 Mon Sep 17 00:00:00 2001 From: jorisfu Date: Thu, 28 May 2026 11:21:48 +0200 Subject: [PATCH 33/33] chore: docstrings --- .../data_analysis/crosslinking_validation.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 5449c35d0..368f3c4c2 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -1083,6 +1083,17 @@ def cl_scatterplots_plddt( structures_to_validate: list[str], crosslinker_information: dict[str, list[float]], ) -> list[Figure]: + """ + Scatter plot for crosslinking validation data. Displays the deviation from the predicted structure + on the x-axis (log scaled) and the average pLDDT at the binding sites on the y-axis. + A dot is placed for each crosslinker identified. + + :param cl_results_df: The results from the crosslinking validation step + :param structures_to_validate: the protein IDs included in the validation + :param crosslinker_information: the name: (length, upper devation, lower deviation) bounds for each CL. + + :return: The Plot + """ figures: list[Figure] = [] @@ -1201,6 +1212,18 @@ def cl_scatterplots_pae( crosslinker_information: dict[str, list[float]], validation_criterion: CrosslinkingValidationCriterion, ) -> list[Figure]: + """ + Scatter plot for crosslinking validation data. Displays the deviation from the predicted structure + on the x-axis (log scaled) and the PAE value used for validation (min/max) between the binding sites on the y-axis. + A dot is placed for each crosslinker identified. + + :param cl_results_df: The results from the crosslinking validation step + :param structures_to_validate: the protein IDs included in the validation + :param crosslinker_information: the name: (length, upper devation, lower deviation) bounds for each CL + :param validation_criterion: the validation criterion used for the validation + + :return: The Plot + """ figures: list[Figure] = []