-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
57 lines (50 loc) · 1.91 KB
/
plot.py
File metadata and controls
57 lines (50 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import pandas as pd
from lifelines import KaplanMeierFitter
# Shared Kaplan-Meier estimator. prepare_survival_dataframe re-fits this one
# instance for every group it summarises.
kmf: KaplanMeierFitter = KaplanMeierFitter()
# Stratifications available for the survival plot, keyed by display name.
# Every entry holds:
#   column          -- DataFrame column the split is based on (None: no split)
#   group_condition -- callable(df) returning the row selector for the first
#                      group
#   labels          -- a single label (whole cohort) or two labels (the first
#                      group, then the complement of that selection)
measure_parameters = {
    "None": dict(
        column=None,
        # slice(None) selects every row of the frame.
        group_condition=lambda df: slice(None),
        labels=["None"],
    ),
    "Gender": dict(
        column="sex",
        # NOTE(review): assumes sex is coded 1 = male -- confirm against the data.
        group_condition=lambda df: df["sex"] == 1,
        labels=["Male", "Female"],
    ),
    "Blood glucose": dict(
        column="glucose",
        # NOTE(review): 5.5 is presumably mmol/L -- confirm the units.
        group_condition=lambda df: df["glucose"] < 5.5,
        labels=["Low Glucose", "High Glucose"],
    ),
    "Biological Age": dict(
        column="biologically_older",
        group_condition=lambda df: df["biologically_older"] == 0,
        labels=["Biologically younger", "Biologically older"],
    ),
}
def prepare_survival_dataframe(df, group_condition, label):
    """Fit a Kaplan-Meier curve for one group and return it as a tidy frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ``months_until_death`` (durations) and ``is_dead``
        (event indicator) columns.
    group_condition :
        Row selector applied to both columns, e.g. a boolean Series aligned
        with ``df`` or ``slice(None)`` for every row.
    label : str
        Group name; passed to the fitter and written to ``Category``.

    Returns
    -------
    pandas.DataFrame
        Indexed by the fitter's timeline, with the columns ``Survival``,
        ``Category`` and ``Time (months)``.
    """
    time_to_event = df["months_until_death"]
    event_occurred = df["is_dead"]
    kmf.fit(
        time_to_event[group_condition], event_occurred[group_condition], label=label
    )
    # Work on a copy. Renaming or adding columns on kmf.survival_function_
    # in place would alter the fitter's own result frame, whose column
    # lifelines names after ``label``. Any other code that reads kmf
    # afterwards (e.g. its plotting helpers) may rely on that frame.
    survival_function = kmf.survival_function_.copy()
    survival_function.columns = ["Survival"]
    survival_function["Category"] = label
    survival_function["Time (months)"] = survival_function.index
    return survival_function
def build_plot(df, measure_params):
    """Build the tidy Kaplan-Meier survival data for one stratification.

    Parameters
    ----------
    df : pandas.DataFrame
        Survival data with the columns ``prepare_survival_dataframe`` reads.
        When the measure splits the cohort, it also needs the measure's
        ``column``.
    measure_params : dict
        Typically one entry of ``measure_parameters``, with these keys:

        - ``group_condition``: a callable returning the row selector for the
          first group.
        - ``labels``: one label means the whole cohort; two labels mean the
          first group and its complement.
        - ``column`` (optional): the column the split is based on.

    Returns
    -------
    pandas.DataFrame
        The frames from ``prepare_survival_dataframe`` for every group,
        concatenated.
    """
    group_condition = measure_params["group_condition"](df)
    labels = measure_params["labels"]
    if len(labels) == 1:
        return prepare_survival_dataframe(df, group_condition, labels[0])

    # The second group is the complement of the first. Comparisons such as
    # ``df["glucose"] < 5.5`` evaluate to False for NaN, so ``~condition``
    # alone would silently put rows with an unknown value into the second
    # group (e.g. report missing glucose as "High Glucose"). Restrict both
    # groups to rows where the measure column is actually known.
    column = measure_params.get("column")
    if column is not None:
        known = df[column].notna()
        first_group = group_condition & known
        second_group = ~group_condition & known
    else:
        first_group, second_group = group_condition, ~group_condition

    survival_df_group1 = prepare_survival_dataframe(df, first_group, labels[0])
    survival_df_group2 = prepare_survival_dataframe(df, second_group, labels[1])
    return pd.concat([survival_df_group1, survival_df_group2])