diff --git a/backend/protzilla/data_analysis/crosslinking_validation.py b/backend/protzilla/data_analysis/crosslinking_validation.py index 8d8b765d..368f3c4c 100644 --- a/backend/protzilla/data_analysis/crosslinking_validation.py +++ b/backend/protzilla/data_analysis/crosslinking_validation.py @@ -12,6 +12,7 @@ import plotly.graph_objects as go from plotly.graph_objects import Figure +import plotly.express as px from backend.protzilla.data_preprocessing.plots import ( create_histograms, @@ -778,6 +779,8 @@ def get_paes(): "plddt_at_position2": plddt_at_position2, "pae_x_position1": pae_x_position1, "pae_x_position2": pae_x_position2, + "accepted_distance_lower_bound": accepted_distance_lower_bound, + "accepted_distance_upper_bound": accepted_distance_upper_bound, } ) @@ -791,6 +794,8 @@ def get_paes(): "plddt_at_position2", "pae_x_position1", "pae_x_position2", + "accepted_distance_lower_bound", + "accepted_distance_upper_bound", ] relevant_crosslinks_df["crosslinker_position1"] = relevant_crosslinks_df[ @@ -1073,6 +1078,280 @@ def diagrams_of_crosslinking_validation_data( return figures +def cl_scatterplots_plddt( + cl_results_df: pd.DataFrame, + structures_to_validate: list[str], + crosslinker_information: dict[str, list[float]], +) -> list[Figure]: + """ + Scatter plot for crosslinking validation data. Displays the deviation from the predicted structure + on the x-axis (log scaled) and the average pLDDT at the binding sites on the y-axis. + A dot is placed for each crosslinker identified. + + :param cl_results_df: The results from the crosslinking validation step + :param structures_to_validate: the protein IDs included in the validation + :param crosslinker_information: the name: (length, upper devation, lower deviation) bounds for each CL. + + :return: The Plot + """ + + figures: list[Figure] = [] + + cl_results_df["measured_distance"] = cl_results_df.apply( + lambda row: crosslinker_information[row["Crosslinker"]][0], axis=1 + ) + cl_results_df["distance_delta"] = abs( + cl_results_df["measured_distance"] - cl_results_df["alphafold_distance"] + ) + cl_results_df["avg_plddt"] = cl_results_df.apply( + lambda row: np.average([row["plddt_at_position1"], row["plddt_at_position2"]]), + axis=1, + ) + + y_label = "Avg. pLDDT at binding sites" + + fig = go.Figure() + + valid_cls = cl_results_df[cl_results_df["valid_crosslink"]] + invalid_cls = cl_results_df[~cl_results_df["valid_crosslink"]] + + hovertemplate = ( + "%{customdata[0]} (Length %{customdata[1]}Å)" + + "
Predicted distance: %{customdata[2]:.2f}Å " + + "(off by %{customdata[3]:.2f}Å)" + + "
Accepted distance range %{customdata[4]:.2f} - %{customdata[5]:.2f} Å" + + "" + ) + + fig.add_trace( + go.Scatter( + x=valid_cls["distance_delta"], + y=valid_cls["avg_plddt"], + customdata=np.stack( + ( + valid_cls["Crosslinker"], + valid_cls["measured_distance"], + valid_cls["alphafold_distance"], + valid_cls["distance_delta"], + valid_cls["accepted_distance_lower_bound"], + valid_cls["accepted_distance_upper_bound"], + ), + axis=-1, + ), + mode="markers", + name="CLs matching prediction", + hovertemplate=hovertemplate, + ) + ) + + fig.add_trace( + go.Scatter( + x=invalid_cls["distance_delta"], + y=invalid_cls["avg_plddt"], + customdata=np.stack( + ( + invalid_cls["Crosslinker"], + invalid_cls["measured_distance"], + invalid_cls["alphafold_distance"], + invalid_cls["distance_delta"], + invalid_cls["accepted_distance_lower_bound"], + invalid_cls["accepted_distance_upper_bound"], + ), + axis=-1, + ), + mode="markers", + name="CLs not matching prediction", + hovertemplate=hovertemplate, + ) + ) + + # X axis range should start as close to 0 as reasonable and extend to max value + min_dist_delta = min(cl_results_df["distance_delta"]) + max_dist_delta = max(cl_results_df["distance_delta"]) + + xmin = -1 + if np.log10(xmin) > min_dist_delta: + xmin = np.log10(min_dist_delta) + + xmax = np.log10(max_dist_delta) + xmax += np.log10(1.2) # reasonable padding + + # Y axis must start at 0 and extend to max avg pLDDT + padding + ymin = 0 + ymax = max(cl_results_df["avg_plddt"]) + 5 + + fig.update_xaxes( + type="log", + title_text="Deviation from CL length in predicted structure (Å) (log scaled)", + tickmode="linear", + tick0=0, + dtick=np.log10(2), + range=[xmin, xmax], + ) + + fig.update_yaxes( + title_text=y_label, + range=[ymin, ymax], + ) + + fig.update_layout( + title=dict( + text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})" + ), + showlegend=True, + ) + + figures.append(fig) + + return figures + + +def cl_scatterplots_pae( + cl_results_df: pd.DataFrame, + structures_to_validate: list[str], + crosslinker_information: dict[str, list[float]], + validation_criterion: CrosslinkingValidationCriterion, +) -> list[Figure]: + """ + Scatter plot for crosslinking validation data. Displays the deviation from the predicted structure + on the x-axis (log scaled) and the PAE value used for validation (min/max) between the binding sites on the y-axis. + A dot is placed for each crosslinker identified. + + :param cl_results_df: The results from the crosslinking validation step + :param structures_to_validate: the protein IDs included in the validation + :param crosslinker_information: the name: (length, upper devation, lower deviation) bounds for each CL + :param validation_criterion: the validation criterion used for the validation + + :return: The Plot + """ + + figures: list[Figure] = [] + + def get_relevant_pae_value( + pae_x_1: float, + pae_x_2: float, + validation_criterion: CrosslinkingValidationCriterion, + ): + if validation_criterion == CrosslinkingValidationCriterion.max_pae.value: + return max(pae_x_1, pae_x_2) + elif validation_criterion == CrosslinkingValidationCriterion.min_pae.value: + return min(pae_x_1, pae_x_2) + else: + raise ValueError("Illegal validation criterion for PAE plot") + + cl_results_df["measured_distance"] = cl_results_df.apply( + lambda row: crosslinker_information[row["Crosslinker"]][0], axis=1 + ) + cl_results_df["distance_delta"] = abs( + cl_results_df["measured_distance"] - cl_results_df["alphafold_distance"] + ) + cl_results_df["relevant_pae"] = cl_results_df.apply( + lambda row: get_relevant_pae_value( + row["pae_x_position1"], row["pae_x_position2"], validation_criterion + ), + axis=1, + ) + + y_label = ( + "Max. PAE between binding sites" + if validation_criterion == CrosslinkingValidationCriterion.max_pae.value + else "Min. PAE between binding sites" + ) + + fig = go.Figure() + + valid_cls = cl_results_df[cl_results_df["valid_crosslink"]] + invalid_cls = cl_results_df[~cl_results_df["valid_crosslink"]] + + hovertemplate = ( + "%{customdata[0]} (Length %{customdata[1]}Å)" + + "
Predicted distance: %{customdata[2]:.2f}Å " + + "(off by %{customdata[3]:.2f}Å)" + + "
PAE %{customdata[4]:.2f}Å" + + "" + ) + + fig.add_trace( + go.Scatter( + x=valid_cls["distance_delta"], + y=valid_cls["relevant_pae"], + customdata=np.stack( + ( + valid_cls["Crosslinker"], + valid_cls["measured_distance"], + valid_cls["alphafold_distance"], + valid_cls["distance_delta"], + valid_cls["relevant_pae"], + ), + axis=-1, + ), + mode="markers", + name="CLs matching prediction", + hovertemplate=hovertemplate, + ) + ) + + fig.add_trace( + go.Scatter( + x=invalid_cls["distance_delta"], + y=invalid_cls["relevant_pae"], + customdata=np.stack( + ( + invalid_cls["Crosslinker"], + invalid_cls["measured_distance"], + invalid_cls["alphafold_distance"], + invalid_cls["distance_delta"], + invalid_cls["relevant_pae"], + ), + axis=-1, + ), + mode="markers", + name="CLs not matching prediction", + hovertemplate=hovertemplate, + ) + ) + + # X axis range should start as close to 0 as reasonable and extend to max value + min_dist_delta = min(cl_results_df["distance_delta"]) + max_dist_delta = max(cl_results_df["distance_delta"]) + + xmin = -1 + if np.log10(xmin) > min_dist_delta: + xmin = np.log10(min_dist_delta) + + xmax = np.log10(max_dist_delta) + xmax += np.log10(1.2) # reasonable padding + + # Y axis must start at 0 and extend to max relevant PAE + padding + ymin = 0 + ymax = max(cl_results_df["relevant_pae"]) + 5 + + fig.update_xaxes( + type="log", + title_text="Deviation from CL length in predicted structure (Å) (log scaled)", + tickmode="linear", + tick0=0, + dtick=np.log10(2), + range=[xmin, xmax], + ) + + fig.update_yaxes( + title_text=y_label, + range=[ymin, ymax], + ) + + fig.update_layout( + title=dict( + text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})" + ), + showlegend=True, + ) + + figures.append(fig) + + return figures + + def monomer_diagrams( output_crosslinking_result_df: pd.DataFrame, structure_metadata_df: pd.DataFrame, @@ -1101,21 +1380,20 @@ def monomer_diagrams( crosslinker_information=crosslinker_information, ) - # TODO: Separate Issue #429 case ( CrosslinkingValidationCriterion.max_pae.value | CrosslinkingValidationCriterion.min_pae.value ): - return diagrams_of_crosslinking_validation_data( - validated_df=output_crosslinking_result_df, + return cl_scatterplots_pae( + cl_results_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, crosslinker_information=crosslinker_information, + validation_criterion=validation_criterion, ) - # TODO: Separate Issue #429 case CrosslinkingValidationCriterion.plddt_adjusted.value: - return diagrams_of_crosslinking_validation_data( - validated_df=output_crosslinking_result_df, + return cl_scatterplots_plddt( + cl_results_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, crosslinker_information=crosslinker_information, ) @@ -1159,21 +1437,20 @@ def multimer_diagrams( crosslinker_information=crosslinker_information, ) - # TODO: Separate Issue #429 case ( CrosslinkingValidationCriterion.max_pae.value | CrosslinkingValidationCriterion.min_pae.value ): - return diagrams_of_crosslinking_validation_data( - validated_df=output_crosslinking_result_df, + return cl_scatterplots_pae( + cl_results_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, crosslinker_information=crosslinker_information, + validation_criterion=validation_criterion, ) - # TODO: Separate Issue #429 case CrosslinkingValidationCriterion.plddt_adjusted.value: - return diagrams_of_crosslinking_validation_data( - validated_df=output_crosslinking_result_df, + return cl_scatterplots_plddt( + cl_results_df=output_crosslinking_result_df, structures_to_validate=structures_to_validate, crosslinker_information=crosslinker_information, )