Skip to content
Open
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
9f8b5b2
feat: add pLDDT to CL results table
jorisfu May 11, 2026
24eb514
feat: add trivial PAE based validation
jorisfu May 11, 2026
1ddc103
feat: trivial plDDT based validation
jorisfu May 12, 2026
77498d5
fix: broken formula
jorisfu May 15, 2026
0ab6fa0
fix monomer validation test
jorisfu May 15, 2026
291bd7d
refactor: expose PAE as matrix for monomers
jorisfu May 18, 2026
23d5b60
feat: PAE for multimers
jorisfu May 18, 2026
1e0e30e
feat: pLDDT for multimers
jorisfu May 19, 2026
eabd547
feat: add PAE/plDDT consistently to multimer imports
jorisfu May 19, 2026
bcb054f
feat: proper PAE validation for multimers
jorisfu May 19, 2026
58ac815
fix: adjust existing cl validation tests
jorisfu May 19, 2026
634e712
fix: some alphafold import tests
jorisfu May 19, 2026
2c66b35
tempfix: bridge monomer plots so method doesn't fail
jorisfu May 21, 2026
7ecac2b
tempfix: bridge multimer plots so method doesn't fail
jorisfu May 21, 2026
213e340
merge crosslinking
jorisfu May 21, 2026
f2518c1
chore: remove obsolete todos
jorisfu May 21, 2026
e13e2a3
chore: adjust some tests
jorisfu May 21, 2026
0c645e5
chore: adjust some tests
jorisfu May 21, 2026
c1d55a1
feat: introduce parsing of _chem_comp table in cif-files
tE3m May 10, 2026
a935874
chore: fix existing tests
jorisfu May 21, 2026
d506787
chore: test for no pLDDT data within cif
jorisfu May 21, 2026
a034f65
chore: tests for PAE based CL validation
jorisfu May 21, 2026
6ec24c0
chore: tests for pLDDT based CL validation
jorisfu May 21, 2026
2643937
feat: simple PAE scatter plot
jorisfu May 22, 2026
d5be1bd
feat: AF3 to AF2 PAE matrix translation
jorisfu May 26, 2026
5b067af
(AI) tests: PAE matrix reduction
jorisfu May 26, 2026
a1a5a91
chore: remove unused imports
jorisfu May 26, 2026
d1e3cf5
feat: only make bounds fields visible if manual bounds is selected mode
jorisfu May 26, 2026
05580e9
chore: black
jorisfu May 26, 2026
732d838
chore: clean up PAE plots and add to multimer validation
jorisfu May 27, 2026
5dea77d
feat: plddt plot and some corrections
jorisfu May 27, 2026
2fba6ad
chore: black
jorisfu May 27, 2026
c8c3eb8
fix tests
jorisfu May 28, 2026
9209eab
merge crosslinking
jorisfu May 28, 2026
5db9829
chore: docstrings
jorisfu May 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 289 additions & 12 deletions backend/protzilla/data_analysis/crosslinking_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import plotly.graph_objects as go
from plotly.graph_objects import Figure
import plotly.express as px

from backend.protzilla.data_preprocessing.plots import (
create_histograms,
Expand Down Expand Up @@ -778,6 +779,8 @@ def get_paes():
"plddt_at_position2": plddt_at_position2,
"pae_x_position1": pae_x_position1,
"pae_x_position2": pae_x_position2,
"accepted_distance_lower_bound": accepted_distance_lower_bound,
"accepted_distance_upper_bound": accepted_distance_upper_bound,
}
)

Expand All @@ -791,6 +794,8 @@ def get_paes():
"plddt_at_position2",
"pae_x_position1",
"pae_x_position2",
"accepted_distance_lower_bound",
"accepted_distance_upper_bound",
]

relevant_crosslinks_df["crosslinker_position1"] = relevant_crosslinks_df[
Expand Down Expand Up @@ -1073,6 +1078,280 @@ def diagrams_of_crosslinking_validation_data(
return figures


def cl_scatterplots_plddt(
    cl_results_df: pd.DataFrame,
    structures_to_validate: list[str],
    crosslinker_information: dict[str, list[float]],
) -> list[Figure]:
    """
    Scatter plot for crosslinking validation data. Displays the deviation from the predicted structure
    on the x-axis (log scaled) and the average pLDDT at the binding sites on the y-axis.
    One dot is placed per row of the results table, split into two traces depending on
    the ``valid_crosslink`` column.

    The input DataFrame is not modified; all helper columns are computed on a copy.

    :param cl_results_df: The results from the crosslinking validation step. The columns
        ``Crosslinker``, ``alphafold_distance``, ``plddt_at_position1``, ``plddt_at_position2``,
        ``valid_crosslink``, ``accepted_distance_lower_bound`` and
        ``accepted_distance_upper_bound`` are read.
    :param structures_to_validate: the protein IDs included in the validation (shown in the title)
    :param crosslinker_information: the name: (length, upper deviation, lower deviation) bounds for each CL.
        Only the first entry (the length) is used here.

    :return: A list containing the single scatter plot
    :raises KeyError: if a crosslinker in the results has no entry in ``crosslinker_information``
    """
    # Work on a copy so the helper columns below do not leak into the caller's results table.
    plot_df = cl_results_df.copy()

    plot_df["measured_distance"] = plot_df["Crosslinker"].map(
        lambda crosslinker: crosslinker_information[crosslinker][0]
    )
    plot_df["distance_delta"] = (
        plot_df["measured_distance"] - plot_df["alphafold_distance"]
    ).abs()
    plot_df["avg_plddt"] = (
        plot_df["plddt_at_position1"] + plot_df["plddt_at_position2"]
    ) / 2

    y_label = "Avg. pLDDT at binding sites"

    hovertemplate = (
        "%{customdata[0]} (Length %{customdata[1]}Å)"
        "<br>Predicted distance: %{customdata[2]:.2f}Å "
        "(off by %{customdata[3]:.2f}Å)"
        "<br>Accepted distance range %{customdata[4]:.2f} - %{customdata[5]:.2f} Å"
        "<extra></extra>"
    )
    # Column order must match the customdata indices used in the hover template.
    hover_columns = [
        "Crosslinker",
        "measured_distance",
        "alphafold_distance",
        "distance_delta",
        "accepted_distance_lower_bound",
        "accepted_distance_upper_bound",
    ]

    is_valid = plot_df["valid_crosslink"]

    fig = go.Figure()
    for trace_name, subset in (
        ("CLs matching prediction", plot_df[is_valid]),
        ("CLs not matching prediction", plot_df[~is_valid]),
    ):
        fig.add_trace(
            go.Scatter(
                x=subset["distance_delta"],
                y=subset["avg_plddt"],
                customdata=subset[hover_columns].to_numpy(),
                mode="markers",
                name=trace_name,
                hovertemplate=hovertemplate,
            )
        )

    # The range of a log axis is given as log10 exponents. Zero, missing or infinite
    # deviations cannot be placed on a log axis, so they are ignored for the range.
    deltas = plot_df["distance_delta"].astype(float)
    positive_deltas = deltas[(deltas > 0) & np.isfinite(deltas)]

    # X axis starts at 0.1 Å (10 ** -1) unless a smaller deviation has to be shown.
    xmin = -1.0
    if positive_deltas.empty:
        xmax = 1.0  # nothing plottable: fall back to one decade around 1 Å
    else:
        xmin = min(xmin, float(np.log10(positive_deltas.min())))
        xmax = float(np.log10(positive_deltas.max() * 1.2))  # reasonable padding

    fig.update_xaxes(
        type="log",
        title_text="Deviation from CL length in predicted structure (Å) (log scaled)",
        tickmode="linear",
        tick0=0,
        dtick=np.log10(2),
        range=[xmin, xmax],
    )

    # Y axis must start at 0 and extend to max avg pLDDT + padding.
    # Without any value the range is left to plotly's autorange.
    yaxis_settings = {"title_text": y_label}
    max_avg_plddt = plot_df["avg_plddt"].max()
    if pd.notna(max_avg_plddt):
        yaxis_settings["range"] = [0, max_avg_plddt + 5]
    fig.update_yaxes(**yaxis_settings)

    fig.update_layout(
        title=dict(
            text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})"
        ),
        showlegend=True,
    )

    return [fig]


def cl_scatterplots_pae(
    cl_results_df: pd.DataFrame,
    structures_to_validate: list[str],
    crosslinker_information: dict[str, list[float]],
    validation_criterion: CrosslinkingValidationCriterion,
) -> list[Figure]:
    """
    Scatter plot for crosslinking validation data. Displays the deviation from the predicted structure
    on the x-axis (log scaled) and the PAE value used for validation (min/max) between the binding sites on the y-axis.
    One dot is placed per row of the results table, split into two traces depending on
    the ``valid_crosslink`` column.

    The input DataFrame is not modified; all helper columns are computed on a copy.

    :param cl_results_df: The results from the crosslinking validation step. The columns
        ``Crosslinker``, ``alphafold_distance``, ``pae_x_position1``, ``pae_x_position2``
        and ``valid_crosslink`` are read.
    :param structures_to_validate: the protein IDs included in the validation (shown in the title)
    :param crosslinker_information: the name: (length, upper deviation, lower deviation) bounds for each CL.
        Only the first entry (the length) is used here.
    :param validation_criterion: the validation criterion used for the validation. Either the
        ``max_pae`` / ``min_pae`` member itself or its ``.value`` is accepted.

    :return: A list containing the single scatter plot
    :raises ValueError: if the validation criterion is neither ``max_pae`` nor ``min_pae``
    :raises KeyError: if a crosslinker in the results has no entry in ``crosslinker_information``
    """
    # Callers pass the enum's value; also accept the member itself.
    criterion = getattr(validation_criterion, "value", validation_criterion)
    if criterion == CrosslinkingValidationCriterion.max_pae.value:
        use_max = True
        y_label = "Max. PAE between binding sites"
    elif criterion == CrosslinkingValidationCriterion.min_pae.value:
        use_max = False
        y_label = "Min. PAE between binding sites"
    else:
        raise ValueError("Illegal validation criterion for PAE plot")

    # Work on a copy so the helper columns below do not leak into the caller's results table.
    plot_df = cl_results_df.copy()

    plot_df["measured_distance"] = plot_df["Crosslinker"].map(
        lambda crosslinker: crosslinker_information[crosslinker][0]
    )
    plot_df["distance_delta"] = (
        plot_df["measured_distance"] - plot_df["alphafold_distance"]
    ).abs()

    # Row-wise max/min of both PAE directions. A missing value in one direction is
    # ignored, so the result no longer depends on which of the two columns holds it.
    pae_values = plot_df[["pae_x_position1", "pae_x_position2"]]
    plot_df["relevant_pae"] = (
        pae_values.max(axis=1) if use_max else pae_values.min(axis=1)
    )

    hovertemplate = (
        "%{customdata[0]} (Length %{customdata[1]}Å)"
        "<br>Predicted distance: %{customdata[2]:.2f}Å "
        "(off by %{customdata[3]:.2f}Å)"
        "<br>PAE %{customdata[4]:.2f}Å"
        "<extra></extra>"
    )
    # Column order must match the customdata indices used in the hover template.
    hover_columns = [
        "Crosslinker",
        "measured_distance",
        "alphafold_distance",
        "distance_delta",
        "relevant_pae",
    ]

    is_valid = plot_df["valid_crosslink"]

    fig = go.Figure()
    for trace_name, subset in (
        ("CLs matching prediction", plot_df[is_valid]),
        ("CLs not matching prediction", plot_df[~is_valid]),
    ):
        fig.add_trace(
            go.Scatter(
                x=subset["distance_delta"],
                y=subset["relevant_pae"],
                customdata=subset[hover_columns].to_numpy(),
                mode="markers",
                name=trace_name,
                hovertemplate=hovertemplate,
            )
        )

    # The range of a log axis is given as log10 exponents. Zero, missing or infinite
    # deviations cannot be placed on a log axis, so they are ignored for the range.
    deltas = plot_df["distance_delta"].astype(float)
    positive_deltas = deltas[(deltas > 0) & np.isfinite(deltas)]

    # X axis starts at 0.1 Å (10 ** -1) unless a smaller deviation has to be shown.
    xmin = -1.0
    if positive_deltas.empty:
        xmax = 1.0  # nothing plottable: fall back to one decade around 1 Å
    else:
        xmin = min(xmin, float(np.log10(positive_deltas.min())))
        xmax = float(np.log10(positive_deltas.max() * 1.2))  # reasonable padding

    fig.update_xaxes(
        type="log",
        title_text="Deviation from CL length in predicted structure (Å) (log scaled)",
        tickmode="linear",
        tick0=0,
        dtick=np.log10(2),
        range=[xmin, xmax],
    )

    # Y axis must start at 0 and extend to max relevant PAE + padding.
    # Without any value the range is left to plotly's autorange.
    yaxis_settings = {"title_text": y_label}
    max_relevant_pae = plot_df["relevant_pae"].max()
    if pd.notna(max_relevant_pae):
        yaxis_settings["range"] = [0, max_relevant_pae + 5]
    fig.update_yaxes(**yaxis_settings)

    fig.update_layout(
        title=dict(
            text=f"Identified Crosslinks vs. Predicted Structure ({', '.join(structures_to_validate)})"
        ),
        showlegend=True,
    )

    return [fig]


def monomer_diagrams(
output_crosslinking_result_df: pd.DataFrame,
structure_metadata_df: pd.DataFrame,
Expand Down Expand Up @@ -1101,21 +1380,20 @@ def monomer_diagrams(
crosslinker_information=crosslinker_information,
)

# TODO: Separate Issue #429
case (
CrosslinkingValidationCriterion.max_pae.value
| CrosslinkingValidationCriterion.min_pae.value
):
return diagrams_of_crosslinking_validation_data(
validated_df=output_crosslinking_result_df,
return cl_scatterplots_pae(
cl_results_df=output_crosslinking_result_df,
structures_to_validate=structures_to_validate,
crosslinker_information=crosslinker_information,
validation_criterion=validation_criterion,
)

# TODO: Separate Issue #429
case CrosslinkingValidationCriterion.plddt_adjusted.value:
return diagrams_of_crosslinking_validation_data(
validated_df=output_crosslinking_result_df,
return cl_scatterplots_plddt(
cl_results_df=output_crosslinking_result_df,
structures_to_validate=structures_to_validate,
crosslinker_information=crosslinker_information,
)
Expand Down Expand Up @@ -1159,21 +1437,20 @@ def multimer_diagrams(
crosslinker_information=crosslinker_information,
)

# TODO: Separate Issue #429
case (
CrosslinkingValidationCriterion.max_pae.value
| CrosslinkingValidationCriterion.min_pae.value
):
return diagrams_of_crosslinking_validation_data(
validated_df=output_crosslinking_result_df,
return cl_scatterplots_pae(
cl_results_df=output_crosslinking_result_df,
structures_to_validate=structures_to_validate,
crosslinker_information=crosslinker_information,
validation_criterion=validation_criterion,
)

# TODO: Separate Issue #429
case CrosslinkingValidationCriterion.plddt_adjusted.value:
return diagrams_of_crosslinking_validation_data(
validated_df=output_crosslinking_result_df,
return cl_scatterplots_plddt(
cl_results_df=output_crosslinking_result_df,
structures_to_validate=structures_to_validate,
crosslinker_information=crosslinker_information,
)
Expand Down
Loading