From f452ce335b7594db7f9e82e5596a5aefd055cfa9 Mon Sep 17 00:00:00 2001 From: "Ramtin Zargari Marandi, PhD" <46963585+ramtinz@users.noreply.github.com> Date: Sat, 5 Jul 2025 12:30:35 +0200 Subject: [PATCH] Fix eSHAP_plot issue Fix eSHAP_plot NaN issue and SHAPclust dependency handling - Fix NaN values in mean_phi aggregation in eSHAP_plot - Add fallback for missing psych package in SHAPclust - Fixes #1" --- R/SHAPclust.R | 27 ++++++++++++++++++++++--- R/eSHAP_plot.R | 54 +++++++++++++++++++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/R/SHAPclust.R b/R/SHAPclust.R index 76c6755..ad580ba 100644 --- a/R/SHAPclust.R +++ b/R/SHAPclust.R @@ -13,7 +13,8 @@ #' @param iter.max maximum number of iterations allowed #' #' @importFrom magrittr %>% -#' @importFrom dplyr mutate +#' @importFrom dplyr mutate filter +#' @importFrom forcats fct_reorder #' @importFrom ggplot2 ggtitle ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme geom_hline element_blank element_text element_line ylim facet_wrap ggsave #' @importFrom plotly ggplotly #' @importFrom tibble tibble as_tibble @@ -159,7 +160,24 @@ SHAPclust <- function(task, colnames(kmeans_fvals)[1] <- "cluster" # save the statistical descriptions of the clusters by feature values - kmeans_fvals_desc <- psych::describeBy(kmeans_fvals, group = kmeans_fvals$cluster) + # Use psych package if available, otherwise create basic summary + kmeans_fvals_desc <- tryCatch({ + if (requireNamespace("psych", quietly = TRUE)) { + psych::describeBy(kmeans_fvals, group = kmeans_fvals$cluster) + } else { + # Fallback: create basic summary using base R + warning("Package 'psych' not available. Using basic summary instead of detailed description.") + aggregate(. ~ cluster, data = kmeans_fvals, FUN = function(x) { + c(mean = mean(x, na.rm = TRUE), + sd = sd(x, na.rm = TRUE), + min = min(x, na.rm = TRUE), + max = max(x, na.rm = TRUE)) + }) + } + }, error = function(e) { + warning("Could not create cluster descriptions: ", conditionMessage(e)) + NULL + }) shap_Mean_wide_kmeans$row_ids <- shap_Mean_wide_kmeans$row_ids - shap_Mean_wide_kmeans$row_ids[1] + 1 shap_Mean_wide_kmeans[, prediction_correctness := (truth == response)] shap_Mean_wide_kmeans_forCM <- shap_Mean_wide_kmeans @@ -189,7 +207,10 @@ SHAPclust <- function(task, print(dt_long) ############## SHAP plots for clusters shap_plot1 <- dt_long %>% - mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>% + # Clean data to ensure forcats::fct_reorder works properly + filter(!is.na(feature), !is.na(mean_phi), is.finite(mean_phi)) %>% + mutate(feature = as.character(feature)) %>% + mutate(feature = forcats::fct_reorder(feature, mean_phi, .fun = function(x) mean(x, na.rm = TRUE))) %>% ggplot(aes(x = feature, y = Phi, color = f_val)) + geom_violin(colour = "grey") + geom_line(aes(group = sample_num), alpha = 0.1, size = 0.2) + diff --git a/R/eSHAP_plot.R b/R/eSHAP_plot.R index ac35760..c88f56b 100644 --- a/R/eSHAP_plot.R +++ b/R/eSHAP_plot.R @@ -10,6 +10,8 @@ #' @param splits mlr3 object defining data splits for train and test sets #' @param subset numeric, what percentage of the instances to use from 0 to 1 where 1 means all #' +#' @importFrom dplyr filter mutate +#' @importFrom forcats fct_reorder #' @importFrom magrittr %>% #' @importFrom ggplot2 ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme element_blank geom_hline element_text element_line ylim #' @export @@ -194,28 +196,48 @@ eSHAP_plot <- function(task, shap_Mean$correct_prediction <- factor(shap_Mean$correct_prediction, levels = c(FALSE, TRUE), labels = c("Incorrect", "Correct")) - shap_plot <- shap_Mean %>% - mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>% + # Prepare data for plotting with robust error handling + # Handle NaN values in mean_phi by recalculating from Phi + shap_Mean <- shap_Mean %>% + group_by(feature) %>% + mutate(mean_phi = ifelse(is.nan(mean_phi) | is.na(mean_phi), + mean(Phi, na.rm = TRUE), + mean_phi)) %>% + ungroup() + + plot_data <- shap_Mean %>% + filter(!is.na(feature)) %>% + filter(!is.na(Phi), is.finite(Phi)) %>% + filter(!is.na(f_val), is.finite(f_val)) %>% + mutate( + feature = as.character(feature), + feature = factor(feature), + Phi = as.numeric(Phi), + f_val = as.numeric(f_val), + mean_phi = as.numeric(mean_phi), + sample_num = as.integer(sample_num) + ) + + # Check if we have data to plot + if (nrow(plot_data) == 0) { + stop("No valid data for plotting after filtering") + } + + shap_plot <- plot_data %>% ggplot(aes(x = feature, y = Phi, color = f_val)) + geom_violin(colour = "grey") + - geom_line(aes(group = sample_num), alpha = 0.1, size = 0.2) + + geom_line(aes(group = sample_num), alpha = 0.1, linewidth = 0.2) + coord_flip() + - geom_jitter(aes(shape = correct_prediction, text = paste( - "Feature: ", feature, - "
Unscaled feature value: ", unscaled_f_val, - "
SHAP value: ", Phi, - "
Prediction correctness: ", correct_prediction, - "
Predicted probability: ", pred_prob, - "
Predicted class: ", pred_class - )), + geom_jitter(aes(shape = correct_prediction), alpha = 0.6, size = 1.5, position = position_jitter(width = 0.2, height = 0) ) + - scale_shape_manual(values = c(4, 19), guide = FALSE) + + scale_shape_manual(values = c(4, 19), guide = "none") + # scale_color_manual(values=c("black","grey")) + labs(shape = "model prediction") + scale_colour_gradient2(low = "blue", mid = "green", high = "red", midpoint = 0.5, breaks = c(0, 1), labels = c("Low", "High")) + guides(shape = ggplot2::guide_legend(override.aes = list(fill = "black", color = "black"))) + - geom_text(aes(x = feature, y = -Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") + + # Remove problematic geom_text that might cause coord_flip issues + # geom_text(aes(x = feature, y = -Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") + theme( axis.line.y = element_blank(), axis.ticks.y = element_blank(), legend.position = "right" @@ -235,10 +257,10 @@ eSHAP_plot <- function(task, axis.line = element_line(colour = "grey"), legend.key.width = grid::unit(2, "mm") ) + - ylim(min(shap_Mean$Phi) - 0.05, max(shap_Mean$Phi) + 0.05) + ylim(min(plot_data$Phi, na.rm = TRUE) - 0.05, max(plot_data$Phi, na.rm = TRUE) + 0.05) - # Convert ggplot to Plotly - shap_plot <- ggplotly(shap_plot, tooltip = "text") + # Convert ggplot to Plotly (simplified without text tooltips) + shap_plot <- ggplotly(shap_plot) # Additional plot to show SHAP values vs. predicted probabilities shap_pred_plot <- shap_Mean %>%