Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions R/SHAPclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
#' @param iter.max maximum number of iterations allowed
#'
#' @importFrom magrittr %>%
#' @importFrom dplyr mutate
#' @importFrom dplyr mutate filter
#' @importFrom forcats fct_reorder
#' @importFrom ggplot2 ggtitle ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme geom_hline element_blank element_text element_line ylim facet_wrap ggsave
#' @importFrom plotly ggplotly
#' @importFrom tibble tibble as_tibble
Expand Down Expand Up @@ -159,7 +160,24 @@ SHAPclust <- function(task,
colnames(kmeans_fvals)[1] <- "cluster"

# save the statistical descriptions of the clusters by feature values
kmeans_fvals_desc <- psych::describeBy(kmeans_fvals, group = kmeans_fvals$cluster)
# Use psych package if available, otherwise create basic summary
kmeans_fvals_desc <- tryCatch({
if (requireNamespace("psych", quietly = TRUE)) {
psych::describeBy(kmeans_fvals, group = kmeans_fvals$cluster)
} else {
# Fallback: create basic summary using base R
warning("Package 'psych' not available. Using basic summary instead of detailed description.")
aggregate(. ~ cluster, data = kmeans_fvals, FUN = function(x) {
c(mean = mean(x, na.rm = TRUE),
sd = sd(x, na.rm = TRUE),
min = min(x, na.rm = TRUE),
max = max(x, na.rm = TRUE))
})
}
}, error = function(e) {
warning("Could not create cluster descriptions: ", conditionMessage(e))
NULL
})
shap_Mean_wide_kmeans$row_ids <- shap_Mean_wide_kmeans$row_ids - shap_Mean_wide_kmeans$row_ids[1] + 1
shap_Mean_wide_kmeans[, prediction_correctness := (truth == response)]
shap_Mean_wide_kmeans_forCM <- shap_Mean_wide_kmeans
Expand Down Expand Up @@ -189,7 +207,10 @@ SHAPclust <- function(task,
print(dt_long)
############## SHAP plots for clusters
shap_plot1 <- dt_long %>%
mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>%
# Drop rows with a missing feature or a missing/non-finite mean_phi so forcats::fct_reorder works properly
filter(!is.na(feature), !is.na(mean_phi), is.finite(mean_phi)) %>%
mutate(feature = as.character(feature)) %>%
mutate(feature = forcats::fct_reorder(feature, mean_phi, .fun = function(x) mean(x, na.rm = TRUE))) %>%
ggplot(aes(x = feature, y = Phi, color = f_val)) +
geom_violin(colour = "grey") +
geom_line(aes(group = sample_num), alpha = 0.1, size = 0.2) +
Expand Down
54 changes: 38 additions & 16 deletions R/eSHAP_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#' @param splits mlr3 object defining data splits for train and test sets
#' @param subset numeric, what percentage of the instances to use from 0 to 1 where 1 means all
#'
#' @importFrom dplyr filter mutate
#' @importFrom forcats fct_reorder
#' @importFrom magrittr %>%
#' @importFrom ggplot2 ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme element_blank geom_hline element_text element_line ylim
#' @export
Expand Down Expand Up @@ -194,28 +196,48 @@ eSHAP_plot <- function(task,
shap_Mean$correct_prediction <- factor(shap_Mean$correct_prediction, levels = c(FALSE, TRUE), labels = c("Incorrect", "Correct"))


shap_plot <- shap_Mean %>%
mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>%
# Prepare data for plotting with robust error handling
# Replace NA/NaN values in mean_phi with the per-feature mean of Phi (NA values in Phi are ignored)
shap_Mean <- shap_Mean %>%
group_by(feature) %>%
mutate(mean_phi = ifelse(is.nan(mean_phi) | is.na(mean_phi),
mean(Phi, na.rm = TRUE),
mean_phi)) %>%
ungroup()

plot_data <- shap_Mean %>%
filter(!is.na(feature)) %>%
filter(!is.na(Phi), is.finite(Phi)) %>%
filter(!is.na(f_val), is.finite(f_val)) %>%
mutate(
feature = as.character(feature),
feature = factor(feature),
Phi = as.numeric(Phi),
f_val = as.numeric(f_val),
mean_phi = as.numeric(mean_phi),
sample_num = as.integer(sample_num)
)

# Check if we have data to plot
if (nrow(plot_data) == 0) {
stop("No valid data for plotting after filtering")
}

shap_plot <- plot_data %>%
ggplot(aes(x = feature, y = Phi, color = f_val)) +
geom_violin(colour = "grey") +
geom_line(aes(group = sample_num), alpha = 0.1, size = 0.2) +
geom_line(aes(group = sample_num), alpha = 0.1, linewidth = 0.2) +
coord_flip() +
geom_jitter(aes(shape = correct_prediction, text = paste(
"Feature: ", feature,
"<br>Unscaled feature value: ", unscaled_f_val,
"<br>SHAP value: ", Phi,
"<br>Prediction correctness: ", correct_prediction,
"<br>Predicted probability: ", pred_prob,
"<br>Predicted class: ", pred_class
)),
geom_jitter(aes(shape = correct_prediction),
alpha = 0.6, size = 1.5, position = position_jitter(width = 0.2, height = 0)
) +
scale_shape_manual(values = c(4, 19), guide = FALSE) +
scale_shape_manual(values = c(4, 19), guide = "none") +
# scale_color_manual(values=c("black","grey")) +
labs(shape = "model prediction") +
scale_colour_gradient2(low = "blue", mid = "green", high = "red", midpoint = 0.5, breaks = c(0, 1), labels = c("Low", "High")) +
guides(shape = ggplot2::guide_legend(override.aes = list(fill = "black", color = "black"))) +
geom_text(aes(x = feature, y = -Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") +
# Per-feature mean_phi labels are disabled: this geom_text (y = -Inf) might cause coord_flip issues
# geom_text(aes(x = feature, y = -Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") +
theme(
axis.line.y = element_blank(), axis.ticks.y = element_blank(),
legend.position = "right"
Expand All @@ -235,10 +257,10 @@ eSHAP_plot <- function(task,
axis.line = element_line(colour = "grey"),
legend.key.width = grid::unit(2, "mm")
) +
ylim(min(shap_Mean$Phi) - 0.05, max(shap_Mean$Phi) + 0.05)
ylim(min(plot_data$Phi, na.rm = TRUE) - 0.05, max(plot_data$Phi, na.rm = TRUE) + 0.05)

# Convert ggplot to Plotly
shap_plot <- ggplotly(shap_plot, tooltip = "text")
# Convert ggplot to Plotly (simplified without text tooltips)
shap_plot <- ggplotly(shap_plot)

# Additional plot to show SHAP values vs. predicted probabilities
shap_pred_plot <- shap_Mean %>%
Expand Down