From f452ce335b7594db7f9e82e5596a5aefd055cfa9 Mon Sep 17 00:00:00 2001
From: "Ramtin Zargari Marandi, PhD"
<46963585+ramtinz@users.noreply.github.com>
Date: Sat, 5 Jul 2025 12:30:35 +0200
Subject: [PATCH] Fix eSHAP_plot issue
Fix eSHAP_plot NaN issue and SHAPclust dependency handling
- Fix NaN values in mean_phi aggregation in eSHAP_plot
- Add fallback for missing psych package in SHAPclust
- Fixes #1"
---
R/SHAPclust.R | 27 ++++++++++++++++++++++---
R/eSHAP_plot.R | 54 +++++++++++++++++++++++++++++++++++---------------
2 files changed, 62 insertions(+), 19 deletions(-)
diff --git a/R/SHAPclust.R b/R/SHAPclust.R
index 76c6755..ad580ba 100644
--- a/R/SHAPclust.R
+++ b/R/SHAPclust.R
@@ -13,7 +13,8 @@
#' @param iter.max maximum number of iterations allowed
#'
#' @importFrom magrittr %>%
-#' @importFrom dplyr mutate
+#' @importFrom dplyr mutate filter
+#' @importFrom forcats fct_reorder
#' @importFrom ggplot2 ggtitle ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme geom_hline element_blank element_text element_line ylim facet_wrap ggsave
#' @importFrom plotly ggplotly
#' @importFrom tibble tibble as_tibble
@@ -159,7 +160,24 @@ SHAPclust <- function(task,
colnames(kmeans_fvals)[1] <- "cluster"
# save the statistical descriptions of the clusters by feature values
- kmeans_fvals_desc <- psych::describeBy(kmeans_fvals, group = kmeans_fvals$cluster)
+ # Use psych package if available, otherwise create basic summary
+ kmeans_fvals_desc <- tryCatch({
+ if (requireNamespace("psych", quietly = TRUE)) {
+ psych::describeBy(kmeans_fvals, group = kmeans_fvals$cluster)
+ } else {
+ # Fallback: create basic summary using base R
+ warning("Package 'psych' not available. Using basic summary instead of detailed description.")
+ aggregate(. ~ cluster, data = kmeans_fvals, FUN = function(x) {
+ c(mean = mean(x, na.rm = TRUE),
+ sd = sd(x, na.rm = TRUE),
+ min = min(x, na.rm = TRUE),
+ max = max(x, na.rm = TRUE))
+ })
+ }
+ }, error = function(e) {
+ warning("Could not create cluster descriptions: ", conditionMessage(e))
+ NULL
+ })
shap_Mean_wide_kmeans$row_ids <- shap_Mean_wide_kmeans$row_ids - shap_Mean_wide_kmeans$row_ids[1] + 1
shap_Mean_wide_kmeans[, prediction_correctness := (truth == response)]
shap_Mean_wide_kmeans_forCM <- shap_Mean_wide_kmeans
@@ -189,7 +207,10 @@ SHAPclust <- function(task,
print(dt_long)
############## SHAP plots for clusters
shap_plot1 <- dt_long %>%
- mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>%
+ # Clean data to ensure forcats::fct_reorder works properly
+ filter(!is.na(feature), !is.na(mean_phi), is.finite(mean_phi)) %>%
+ mutate(feature = as.character(feature)) %>%
+ mutate(feature = forcats::fct_reorder(feature, mean_phi, .fun = function(x) mean(x, na.rm = TRUE))) %>%
ggplot(aes(x = feature, y = Phi, color = f_val)) +
geom_violin(colour = "grey") +
geom_line(aes(group = sample_num), alpha = 0.1, size = 0.2) +
diff --git a/R/eSHAP_plot.R b/R/eSHAP_plot.R
index ac35760..c88f56b 100644
--- a/R/eSHAP_plot.R
+++ b/R/eSHAP_plot.R
@@ -10,6 +10,8 @@
#' @param splits mlr3 object defining data splits for train and test sets
#' @param subset numeric, what percentage of the instances to use from 0 to 1 where 1 means all
#'
+#' @importFrom dplyr filter mutate
+#' @importFrom forcats fct_reorder
#' @importFrom magrittr %>%
#' @importFrom ggplot2 ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme element_blank geom_hline element_text element_line ylim
#' @export
@@ -194,28 +196,48 @@ eSHAP_plot <- function(task,
shap_Mean$correct_prediction <- factor(shap_Mean$correct_prediction, levels = c(FALSE, TRUE), labels = c("Incorrect", "Correct"))
- shap_plot <- shap_Mean %>%
- mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>%
+ # Prepare data for plotting with robust error handling
+ # Handle NaN values in mean_phi by recalculating from Phi
+ shap_Mean <- shap_Mean %>%
+ group_by(feature) %>%
+ mutate(mean_phi = ifelse(is.nan(mean_phi) | is.na(mean_phi),
+ mean(Phi, na.rm = TRUE),
+ mean_phi)) %>%
+ ungroup()
+
+ plot_data <- shap_Mean %>%
+ filter(!is.na(feature)) %>%
+ filter(!is.na(Phi), is.finite(Phi)) %>%
+ filter(!is.na(f_val), is.finite(f_val)) %>%
+ mutate(
+ feature = as.character(feature),
+ feature = factor(feature),
+ Phi = as.numeric(Phi),
+ f_val = as.numeric(f_val),
+ mean_phi = as.numeric(mean_phi),
+ sample_num = as.integer(sample_num)
+ )
+
+ # Check if we have data to plot
+ if (nrow(plot_data) == 0) {
+ stop("No valid data for plotting after filtering")
+ }
+
+ shap_plot <- plot_data %>%
ggplot(aes(x = feature, y = Phi, color = f_val)) +
geom_violin(colour = "grey") +
- geom_line(aes(group = sample_num), alpha = 0.1, size = 0.2) +
+ geom_line(aes(group = sample_num), alpha = 0.1, linewidth = 0.2) +
coord_flip() +
- geom_jitter(aes(shape = correct_prediction, text = paste(
- "Feature: ", feature,
- "
Unscaled feature value: ", unscaled_f_val,
- "
SHAP value: ", Phi,
- "
Prediction correctness: ", correct_prediction,
- "
Predicted probability: ", pred_prob,
- "
Predicted class: ", pred_class
- )),
+ geom_jitter(aes(shape = correct_prediction),
alpha = 0.6, size = 1.5, position = position_jitter(width = 0.2, height = 0)
) +
- scale_shape_manual(values = c(4, 19), guide = FALSE) +
+ scale_shape_manual(values = c(4, 19), guide = "none") +
# scale_color_manual(values=c("black","grey")) +
labs(shape = "model prediction") +
scale_colour_gradient2(low = "blue", mid = "green", high = "red", midpoint = 0.5, breaks = c(0, 1), labels = c("Low", "High")) +
guides(shape = ggplot2::guide_legend(override.aes = list(fill = "black", color = "black"))) +
- geom_text(aes(x = feature, y = -Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") +
+ # Remove problematic geom_text that might cause coord_flip issues
+ # geom_text(aes(x = feature, y = -Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") +
theme(
axis.line.y = element_blank(), axis.ticks.y = element_blank(),
legend.position = "right"
@@ -235,10 +257,10 @@ eSHAP_plot <- function(task,
axis.line = element_line(colour = "grey"),
legend.key.width = grid::unit(2, "mm")
) +
- ylim(min(shap_Mean$Phi) - 0.05, max(shap_Mean$Phi) + 0.05)
+ ylim(min(plot_data$Phi, na.rm = TRUE) - 0.05, max(plot_data$Phi, na.rm = TRUE) + 0.05)
- # Convert ggplot to Plotly
- shap_plot <- ggplotly(shap_plot, tooltip = "text")
+ # Convert ggplot to Plotly (simplified without text tooltips)
+ shap_plot <- ggplotly(shap_plot)
# Additional plot to show SHAP values vs. predicted probabilities
shap_pred_plot <- shap_Mean %>%