diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py
index 8d8b765d..368f3c4c 100644
--- a/backend/protzilla/data_analysis/crosslinking_validation.py
+++ b/backend/protzilla/data_analysis/crosslinking_validation.py
@@ -12,6 +12,7 @@
import plotly.graph_objects as go
from plotly.graph_objects import Figure
+import plotly.express as px
from backend.protzilla.data_preprocessing.plots import (
create_histograms,
@@ -778,6 +779,8 @@ def get_paes():
"plddt_at_position2": plddt_at_position2,
"pae_x_position1": pae_x_position1,
"pae_x_position2": pae_x_position2,
+ "accepted_distance_lower_bound": accepted_distance_lower_bound,
+ "accepted_distance_upper_bound": accepted_distance_upper_bound,
}
)
@@ -791,6 +794,8 @@ def get_paes():
"plddt_at_position2",
"pae_x_position1",
"pae_x_position2",
+ "accepted_distance_lower_bound",
+ "accepted_distance_upper_bound",
]
relevant_crosslinks_df["crosslinker_position1"] = relevant_crosslinks_df[
@@ -1073,6 +1078,280 @@ def diagrams_of_crosslinking_validation_data(
return figures
+def cl_scatterplots_plddt(
+    cl_results_df: pd.DataFrame,
+    structures_to_validate: list[str],
+    crosslinker_information: dict[str, list[float]],
+) -> list[Figure]:
+    """
+    Scatter plot for crosslinking validation data. Displays the deviation from the predicted structure
+    on the x-axis (log scaled) and the average pLDDT at the binding sites on the y-axis.
+    A dot is placed for each crosslinker identified; crosslinks flagged as valid and as not valid
+    are drawn as two separate traces.
+
+    The passed data frame is not modified: the helper columns needed for plotting are added to a copy.
+
+    :param cl_results_df: The results from the crosslinking validation step. The columns "Crosslinker",
+        "alphafold_distance", "plddt_at_position1", "plddt_at_position2", "valid_crosslink",
+        "accepted_distance_lower_bound" and "accepted_distance_upper_bound" are read.
+    :param structures_to_validate: the protein IDs included in the validation (only used for the title)
+    :param crosslinker_information: the name: (length, upper deviation, lower deviation) bounds for each CL.
+        Only the length (index 0) is used here.
+
+    :return: A list containing the single scatter plot
+    :raises KeyError: if a crosslinker name in the results is missing from ``crosslinker_information``
+    """
+
+    # Work on a copy so the caller's data frame does not gain plot-only columns.
+    plot_df = cl_results_df.copy()
+
+    plot_df["measured_distance"] = (
+        plot_df["Crosslinker"]
+        .map(lambda name: crosslinker_information[name][0])
+        .astype(float)
+    )
+    plot_df["distance_delta"] = (
+        plot_df["measured_distance"] - plot_df["alphafold_distance"]
+    ).abs()
+    plot_df["avg_plddt"] = (
+        plot_df["plddt_at_position1"] + plot_df["plddt_at_position2"]
+    ) / 2
+
+    y_label = "Avg. pLDDT at binding sites"
+
+    # NOTE(review): line breaks are written as <br>, which is what Plotly renders
+    # in hover labels. The reviewed text had raw line breaks inside the string
+    # literals (invalid Python) and a trailing empty "" fragment that may have been
+    # a stripped "<extra></extra>" -- confirm against the original source.
+    hovertemplate = (
+        "%{customdata[0]} (Length %{customdata[1]}Å)"
+        "<br>Predicted distance: %{customdata[2]:.2f}Å "
+        "(off by %{customdata[3]:.2f}Å)"
+        "<br>Accepted distance range %{customdata[4]:.2f} - %{customdata[5]:.2f} Å"
+    )
+
+    # Column order must match the customdata indices used in the hovertemplate.
+    hover_columns = [
+        "Crosslinker",
+        "measured_distance",
+        "alphafold_distance",
+        "distance_delta",
+        "accepted_distance_lower_bound",
+        "accepted_distance_upper_bound",
+    ]
+
+    fig = go.Figure()
+
+    is_valid = plot_df["valid_crosslink"]
+    for subset, trace_name in (
+        (plot_df[is_valid], "CLs matching prediction"),
+        (plot_df[~is_valid], "CLs not matching prediction"),
+    ):
+        fig.add_trace(
+            go.Scatter(
+                x=subset["distance_delta"],
+                y=subset["avg_plddt"],
+                customdata=np.stack(
+                    [subset[column] for column in hover_columns],
+                    axis=-1,
+                ),
+                mode="markers",
+                name=trace_name,
+                hovertemplate=hovertemplate,
+            )
+        )
+
+    # X axis range is given in log10 units. It should start as close to 0 as
+    # reasonable (default 10**-1 = 0.1 Å) and extend to the max value.
+    # Only strictly positive deviations can be shown on a log axis, so zeros and
+    # NaNs must not define the range (log10(0) would be -inf).
+    positive_deltas = plot_df.loc[plot_df["distance_delta"] > 0, "distance_delta"]
+    xmin = -1.0
+    xmax = 0.0  # fallback (1 Å) when there is no positive deviation to show
+    if not positive_deltas.empty:
+        xmin = min(xmin, float(np.log10(positive_deltas.min())))
+        xmax = float(np.log10(positive_deltas.max()))
+        xmax += float(np.log10(1.2))  # reasonable padding
+
+    fig.update_xaxes(
+        type="log",
+        title_text="Deviation from CL length in predicted structure (Å) (log scaled)",
+        tickmode="linear",
+        tick0=0,
+        dtick=np.log10(2),
+        range=[xmin, xmax],
+    )
+
+    # Y axis must start at 0 and extend to max avg pLDDT + padding.
+    # Without any finite value the range is left to Plotly's autoscaling.
+    yaxis_settings = {"title_text": y_label}
+    max_avg_plddt = plot_df["avg_plddt"].max()
+    if pd.notna(max_avg_plddt):
+        yaxis_settings["range"] = [0, float(max_avg_plddt) + 5]
+    fig.update_yaxes(**yaxis_settings)
+
+    fig.update_layout(
+        title=dict(
+            text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})"
+        ),
+        showlegend=True,
+    )
+
+    return [fig]
+
+
+def cl_scatterplots_pae(
+    cl_results_df: pd.DataFrame,
+    structures_to_validate: list[str],
+    crosslinker_information: dict[str, list[float]],
+    validation_criterion: CrosslinkingValidationCriterion,
+) -> list[Figure]:
+    """
+    Scatter plot for crosslinking validation data. Displays the deviation from the predicted structure
+    on the x-axis (log scaled) and the PAE value used for validation (min/max) between the binding sites on the y-axis.
+    A dot is placed for each crosslinker identified; crosslinks flagged as valid and as not valid
+    are drawn as two separate traces.
+
+    The passed data frame is not modified: the helper columns needed for plotting are added to a copy.
+
+    :param cl_results_df: The results from the crosslinking validation step. The columns "Crosslinker",
+        "alphafold_distance", "pae_x_position1", "pae_x_position2" and "valid_crosslink" are read.
+    :param structures_to_validate: the protein IDs included in the validation (only used for the title)
+    :param crosslinker_information: the name: (length, upper deviation, lower deviation) bounds for each CL.
+        Only the length (index 0) is used here.
+    :param validation_criterion: the validation criterion used for the validation. Both the enum member
+        and its raw ``.value`` are accepted; it must denote the max PAE or the min PAE criterion.
+
+    :return: A list containing the single scatter plot
+    :raises ValueError: if the validation criterion is neither the max PAE nor the min PAE criterion
+    :raises KeyError: if a crosslinker name in the results is missing from ``crosslinker_information``
+    """
+
+    # The criterion is compared by value; unwrap it if an enum member was passed.
+    criterion_value = getattr(validation_criterion, "value", validation_criterion)
+
+    # np.maximum / np.minimum propagate NaN regardless of the argument order,
+    # unlike the builtin max()/min() on two scalars.
+    if criterion_value == CrosslinkingValidationCriterion.max_pae.value:
+        pick_relevant_pae = np.maximum
+        y_label = "Max. PAE between binding sites"
+    elif criterion_value == CrosslinkingValidationCriterion.min_pae.value:
+        pick_relevant_pae = np.minimum
+        y_label = "Min. PAE between binding sites"
+    else:
+        raise ValueError("Illegal validation criterion for PAE plot")
+
+    # Work on a copy so the caller's data frame does not gain plot-only columns.
+    plot_df = cl_results_df.copy()
+
+    plot_df["measured_distance"] = (
+        plot_df["Crosslinker"]
+        .map(lambda name: crosslinker_information[name][0])
+        .astype(float)
+    )
+    plot_df["distance_delta"] = (
+        plot_df["measured_distance"] - plot_df["alphafold_distance"]
+    ).abs()
+    plot_df["relevant_pae"] = pick_relevant_pae(
+        plot_df["pae_x_position1"], plot_df["pae_x_position2"]
+    )
+
+    # NOTE(review): line breaks are written as <br>, which is what Plotly renders
+    # in hover labels. The reviewed text had raw line breaks inside the string
+    # literals (invalid Python) and a trailing empty "" fragment that may have been
+    # a stripped "<extra></extra>" -- confirm against the original source.
+    hovertemplate = (
+        "%{customdata[0]} (Length %{customdata[1]}Å)"
+        "<br>Predicted distance: %{customdata[2]:.2f}Å "
+        "(off by %{customdata[3]:.2f}Å)"
+        "<br>PAE %{customdata[4]:.2f}Å"
+    )
+
+    # Column order must match the customdata indices used in the hovertemplate.
+    hover_columns = [
+        "Crosslinker",
+        "measured_distance",
+        "alphafold_distance",
+        "distance_delta",
+        "relevant_pae",
+    ]
+
+    fig = go.Figure()
+
+    is_valid = plot_df["valid_crosslink"]
+    for subset, trace_name in (
+        (plot_df[is_valid], "CLs matching prediction"),
+        (plot_df[~is_valid], "CLs not matching prediction"),
+    ):
+        fig.add_trace(
+            go.Scatter(
+                x=subset["distance_delta"],
+                y=subset["relevant_pae"],
+                customdata=np.stack(
+                    [subset[column] for column in hover_columns],
+                    axis=-1,
+                ),
+                mode="markers",
+                name=trace_name,
+                hovertemplate=hovertemplate,
+            )
+        )
+
+    # X axis range is given in log10 units. It should start as close to 0 as
+    # reasonable (default 10**-1 = 0.1 Å) and extend to the max value.
+    # Only strictly positive deviations can be shown on a log axis, so zeros and
+    # NaNs must not define the range (log10(0) would be -inf).
+    positive_deltas = plot_df.loc[plot_df["distance_delta"] > 0, "distance_delta"]
+    xmin = -1.0
+    xmax = 0.0  # fallback (1 Å) when there is no positive deviation to show
+    if not positive_deltas.empty:
+        xmin = min(xmin, float(np.log10(positive_deltas.min())))
+        xmax = float(np.log10(positive_deltas.max()))
+        xmax += float(np.log10(1.2))  # reasonable padding
+
+    fig.update_xaxes(
+        type="log",
+        title_text="Deviation from CL length in predicted structure (Å) (log scaled)",
+        tickmode="linear",
+        tick0=0,
+        dtick=np.log10(2),
+        range=[xmin, xmax],
+    )
+
+    # Y axis must start at 0 and extend to max relevant PAE + padding.
+    # Without any finite value the range is left to Plotly's autoscaling.
+    yaxis_settings = {"title_text": y_label}
+    max_relevant_pae = plot_df["relevant_pae"].max()
+    if pd.notna(max_relevant_pae):
+        yaxis_settings["range"] = [0, float(max_relevant_pae) + 5]
+    fig.update_yaxes(**yaxis_settings)
+
+    fig.update_layout(
+        title=dict(
+            text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})"
+        ),
+        showlegend=True,
+    )
+
+    return [fig]
+
+
def monomer_diagrams(
output_crosslinking_result_df: pd.DataFrame,
structure_metadata_df: pd.DataFrame,
@@ -1101,21 +1380,20 @@ def monomer_diagrams(
crosslinker_information=crosslinker_information,
)
- # TODO: Separate Issue #429
case (
CrosslinkingValidationCriterion.max_pae.value
| CrosslinkingValidationCriterion.min_pae.value
):
- return diagrams_of_crosslinking_validation_data(
- validated_df=output_crosslinking_result_df,
+ return cl_scatterplots_pae(
+ cl_results_df=output_crosslinking_result_df,
structures_to_validate=structures_to_validate,
crosslinker_information=crosslinker_information,
+ validation_criterion=validation_criterion,
)
- # TODO: Separate Issue #429
case CrosslinkingValidationCriterion.plddt_adjusted.value:
- return diagrams_of_crosslinking_validation_data(
- validated_df=output_crosslinking_result_df,
+ return cl_scatterplots_plddt(
+ cl_results_df=output_crosslinking_result_df,
structures_to_validate=structures_to_validate,
crosslinker_information=crosslinker_information,
)
@@ -1159,21 +1437,20 @@ def multimer_diagrams(
crosslinker_information=crosslinker_information,
)
- # TODO: Separate Issue #429
case (
CrosslinkingValidationCriterion.max_pae.value
| CrosslinkingValidationCriterion.min_pae.value
):
- return diagrams_of_crosslinking_validation_data(
- validated_df=output_crosslinking_result_df,
+ return cl_scatterplots_pae(
+ cl_results_df=output_crosslinking_result_df,
structures_to_validate=structures_to_validate,
crosslinker_information=crosslinker_information,
+ validation_criterion=validation_criterion,
)
- # TODO: Separate Issue #429
case CrosslinkingValidationCriterion.plddt_adjusted.value:
- return diagrams_of_crosslinking_validation_data(
- validated_df=output_crosslinking_result_df,
+ return cl_scatterplots_plddt(
+ cl_results_df=output_crosslinking_result_df,
structures_to_validate=structures_to_validate,
crosslinker_information=crosslinker_information,
)