Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 77 additions & 68 deletions R/g_km.R
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ g_km <- function(df,
)
) +
theme_bw(base_size = font_size) +
scale_y_continuous(limits = ylim, expand = c(0.025, 0)) +
scale_y_continuous(limits = ylim) +
labs(title = title, x = xlab, y = ylab, caption = footnotes) +
theme(
axis.text = element_text(size = font_size),
Expand All @@ -352,18 +352,18 @@ g_km <- function(df,
# derive x-axis limits
if (!is.null(max_time) && !is.null(xticks)) {
gg_plt <- gg_plt + scale_x_continuous(
breaks = xticks, limits = c(min(0, xticks), max(c(xticks, max_time))), expand = c(0.025, 0)
breaks = xticks, limits = c(min(0, xticks), max(c(xticks, max_time)))
)
} else if (!is.null(xticks)) {
if (max(data$time) <= max(xticks)) {
gg_plt <- gg_plt + scale_x_continuous(
breaks = xticks, limits = c(min(0, min(xticks)), max(xticks)), expand = c(0.025, 0)
breaks = xticks, limits = c(min(0, min(xticks)), max(xticks))
)
} else {
gg_plt <- gg_plt + scale_x_continuous(breaks = xticks, expand = c(0.025, 0))
gg_plt <- gg_plt + scale_x_continuous(breaks = xticks)
}
} else if (!is.null(max_time)) {
gg_plt <- gg_plt + scale_x_continuous(limits = c(0, max_time), expand = c(0.025, 0))
gg_plt <- gg_plt + scale_x_continuous(limits = c(0, max_time))
}

# set legend position
Expand Down Expand Up @@ -427,7 +427,75 @@ g_km <- function(df,
}
if (!is.null(ggtheme)) gg_plt <- gg_plt + ggtheme

# annotate with stats (text/vlines)
# add at risk annotation table -----------------------------------------------
if (annot_at_risk) {
annot_tbl <- summary(fit_km, times = xticks, extend = TRUE)
annot_tbl <- if (is.null(fit_km$strata)) {
data.frame(
n.risk = annot_tbl$n.risk,
time = annot_tbl$time,
strata = armval
)
} else {
strata_lst <- strsplit(sub("=", "equals", levels(annot_tbl$strata)), "equals")
levels(annot_tbl$strata) <- matrix(unlist(strata_lst), ncol = 2, byrow = TRUE)[, 2]
data.frame(
n.risk = annot_tbl$n.risk,
time = annot_tbl$time,
strata = annot_tbl$strata
)
}

at_risk_tbl <- as.data.frame(tidyr::pivot_wider(annot_tbl, names_from = "time", values_from = "n.risk")[, -1])
at_risk_tbl[is.na(at_risk_tbl)] <- 0
rownames(at_risk_tbl) <- levels(annot_tbl$strata)

gg_at_risk <- df2gg(
at_risk_tbl,
font_size = font_size, col_labels = FALSE, hline = FALSE,
colwidths = rep(1, ncol(at_risk_tbl)),
add_proper_xaxis = TRUE
) +
ggplot2::labs(title = if (!is.null(title)) title else NULL, x = xlab) +
ggplot2::theme_bw(base_size = font_size) +
ggplot2::theme(
plot.title = ggplot2::element_text(size = font_size, vjust = 3, face = "bold"),
panel.border = ggplot2::element_blank(),
panel.grid = ggplot2::element_blank(),
axis.title.y = ggplot2::element_blank(),
axis.ticks.y = ggplot2::element_blank(),
axis.text.y = ggplot2::element_text(size = font_size, face = "italic", hjust = 1),
axis.text.x = ggplot2::element_text(size = font_size),
axis.line.x = ggplot2::element_line()
) +
ggplot2::coord_cartesian(clip = "off", ylim = c(0.5, nrow(at_risk_tbl)))

# 1. Get the exact x-range from the top plot (the 0-1200 range)
top_range <- layer_scales(gg_plt)$x$range$range
top_breaks <- layer_scales(gg_plt)$x$break_positions()

# 2. Force the bottom plot (table) to use the SAME range and breaks
# This ensures 0 on the top is exactly above 0 on the bottom
gg_at_risk <- gg_at_risk +
scale_x_continuous(
limits = top_range,
breaks = top_breaks
)

# 3. Force the top plot to also have no expansion so they match perfectly
gg_plt <- gg_plt + scale_x_continuous(limits = top_range)

if (!as_list) {
# Apply this to both plots
gg_plt <- cowplot::plot_grid(
gg_plt, gg_at_risk,
align = "v", axis = "lr", ncol = 1,
rel_heights = c(rel_height_plot, 1 - rel_height_plot)
)
}
}

# annotate with stats (text/vlines) -----------------------------------------
if (!is.null(annot_stats)) {
if ("median" %in% annot_stats) {
fit_km_all <- survival::survfit(
Expand Down Expand Up @@ -475,66 +543,7 @@ g_km <- function(df,
gg_plt <- gg_plt + guides(fill = guide_legend(override.aes = list(shape = NA, label = "")))
}

# add at risk annotation table
if (annot_at_risk) {
annot_tbl <- summary(fit_km, times = xticks, extend = TRUE)
annot_tbl <- if (is.null(fit_km$strata)) {
data.frame(
n.risk = annot_tbl$n.risk,
time = annot_tbl$time,
strata = armval
)
} else {
strata_lst <- strsplit(sub("=", "equals", levels(annot_tbl$strata)), "equals")
levels(annot_tbl$strata) <- matrix(unlist(strata_lst), ncol = 2, byrow = TRUE)[, 2]
data.frame(
n.risk = annot_tbl$n.risk,
time = annot_tbl$time,
strata = annot_tbl$strata
)
}

at_risk_tbl <- as.data.frame(tidyr::pivot_wider(annot_tbl, names_from = "time", values_from = "n.risk")[, -1])
at_risk_tbl[is.na(at_risk_tbl)] <- 0
rownames(at_risk_tbl) <- levels(annot_tbl$strata)

gg_at_risk <- df2gg(
at_risk_tbl,
font_size = font_size, col_labels = FALSE, hline = FALSE,
colwidths = rep(1, ncol(at_risk_tbl))
) +
labs(title = if (annot_at_risk_title) "Patients at Risk:" else NULL, x = xlab) +
theme_bw(base_size = font_size) +
theme(
plot.title = element_text(size = font_size, vjust = 3, face = "bold"),
panel.border = element_blank(),
panel.grid = element_blank(),
axis.title.y = element_blank(),
axis.ticks.y = element_blank(),
axis.text.y = element_text(size = font_size, face = "italic", hjust = 1),
axis.text.x = element_text(size = font_size),
axis.line.x = element_line()
) +
coord_cartesian(clip = "off", ylim = c(0.5, nrow(at_risk_tbl)))
gg_at_risk <- suppressMessages(
gg_at_risk +
scale_x_continuous(expand = c(0.025, 0), breaks = seq_along(at_risk_tbl) - 0.5, labels = xticks) +
scale_y_continuous(labels = rev(levels(annot_tbl$strata)), breaks = seq_len(nrow(at_risk_tbl)))
)

if (!as_list) {
gg_plt <- cowplot::plot_grid(
gg_plt,
gg_at_risk,
align = "v",
axis = "tblr",
ncol = 1,
rel_heights = c(rel_height_plot, 1 - rel_height_plot)
)
}
}

# add median survival time annotation table
# add median survival time annotation table ----------------------------------
if (annot_surv_med) {
surv_med_tbl <- h_tbl_median_surv(fit_km = fit_km, armval = armval)
bg_fill <- if (isTRUE(control_annot_surv_med[["fill"]])) "#00000020" else control_annot_surv_med[["fill"]]
Expand All @@ -547,7 +556,7 @@ g_km <- function(df,
coord_cartesian(clip = "off", ylim = c(0.5, nrow(surv_med_tbl) + 1.5))
gg_surv_med <- suppressMessages(
gg_surv_med +
scale_x_continuous(expand = c(0.025, 0)) +
scale_x_continuous() +
scale_y_continuous(labels = rev(rownames(surv_med_tbl)), breaks = seq_len(nrow(surv_med_tbl)))
)

Expand Down Expand Up @@ -582,7 +591,7 @@ g_km <- function(df,
coord_cartesian(clip = "off", ylim = c(0.5, nrow(coxph_tbl) + 1.5))
gg_coxph <- suppressMessages(
gg_coxph +
scale_x_continuous(expand = c(0.025, 0)) +
scale_x_continuous() +
scale_y_continuous(labels = rev(rownames(coxph_tbl)), breaks = seq_len(nrow(coxph_tbl)))
)

Expand Down
121 changes: 81 additions & 40 deletions R/utils_ggplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ rtable2gg <- function(tbl, fontsize = 12, colwidths = NULL, lbl_col_padding = 0)
#' if `col_labels = TRUE`). Defaults to `"bold"`.
#' @param hline (`flag`)\cr whether a horizontal line should be printed below the first row of the table.
#' @param bg_fill (`string`)\cr table background fill color.
#' @param add_proper_xaxis (`flag`)\cr whether to add a proper x-axis with column values.
#'
#' @return A `ggplot` object.
#'
Expand All @@ -157,61 +158,101 @@ rtable2gg <- function(tbl, fontsize = 12, colwidths = NULL, lbl_col_padding = 0)
#' df2gg(head(iris, 5), font_size = 15, colwidths = c(1, 1, 1, 1, 1))
#' }
#' @keywords internal
df2gg <- function(df,
colwidths = NULL,
font_size = 10,
col_labels = TRUE,
col_lab_fontface = "bold",
hline = TRUE,
bg_fill = NULL) {
# convert to text
df <- as.data.frame(apply(df, 1:2, function(x) if (is.na(x)) "NA" else as.character(x)))
df2gg <- function(df, colwidths = NULL, font_size = 10, col_labels = TRUE,
col_lab_fontface = "bold", hline = TRUE, bg_fill = NULL, add_proper_xaxis = FALSE) {
# Convert all values to character, replacing NAs with "NA"
df <- as.data.frame(apply(df, 1:2, function(x) {
if (is.na(x)) {
"NA"
} else {
as.character(x)
}
}))

# Add column labels as first row if specified
if (col_labels) {
df <- as.matrix(df)
df <- rbind(colnames(df), df)
}

# Get column widths
if (is.null(colwidths)) {
colwidths <- apply(df, 2, function(x) max(nchar(x), na.rm = TRUE))
}
tot_width <- sum(colwidths)
# Create ggplot2 object with x-axis specified in df
if (add_proper_xaxis) {
# Determine column widths if not provided
if (is.null(colwidths)) {
tot_width <- max(colnames(df) |> as.numeric(), na.rm = TRUE)
colwidths <- rep(floor(tot_width / ncol(df)), ncol(df))
} else {
tot_width <- sum(colwidths)
}

res <- ggplot(data = df) +
theme_void() +
scale_x_continuous(limits = c(0, tot_width)) +
scale_y_continuous(limits = c(1, nrow(df)))
df_long <- df |>
as.data.frame() |>
# 1. Ensure the row names ('A', 'B', 'C') are a column named 'row_name'
dplyr::mutate(row_name = row.names(df)) |>
# 2. Pivot the remaining columns (starting from '0' to the end) longer
tidyr::pivot_longer(
cols = -.data$row_name, # Select all columns EXCEPT 'row_name'
names_to = "col_name", # Name the new column containing the old column headers
values_to = "value" # Name the new column containing the data values
) |>
dplyr::arrange(.data$row_name, .data$col_name) |>
dplyr::mutate(
col_name = as.numeric(.data$col_name),
row_name = factor(.data$row_name, levels = row.names(df))
)
res <- ggplot2::ggplot(data = df_long) +
ggplot2::theme_void() +
ggplot2::annotate("text",
x = df_long$col_name, y = rev(df_long$row_name), # why rev?
label = df_long$value, size = font_size / .pt
)

if (!is.null(bg_fill)) res <- res + theme(plot.background = element_rect(fill = bg_fill))
# Create ggplot2 object with a specific x-axis based on column widths
} else {
# Determine column widths if not provided
if (is.null(colwidths)) {
colwidths <- apply(df, 2, function(x) max(nchar(x), na.rm = TRUE))
}
tot_width <- sum(colwidths)

if (hline) {
res <- res +
annotate(
"segment",
x = 0 + 0.2 * colwidths[2], xend = tot_width - 0.1 * tail(colwidths, 1),
y = nrow(df) - 0.5, yend = nrow(df) - 0.5
)
}
res <- ggplot2::ggplot(data = df) +
ggplot2::theme_void() +
ggplot2::scale_x_continuous(limits = c(0, tot_width)) +
ggplot2::scale_y_continuous(limits = c(1, nrow(df)))

for (i in seq_len(ncol(df))) {
line_pos <- c(
if (i == 1) 0 else sum(colwidths[1:(i - 1)]),
sum(colwidths[1:i])
)
res <- res +
annotate(
"text",
x = mean(line_pos),
y = rev(seq_len(nrow(df))),
label = df[, i],
size = font_size / .pt,
fontface = if (col_labels) {

for (i in seq_len(ncol(df))) {
line_pos <- c(
if (i == 1) {
0
} else {
sum(colwidths[1:(i - 1)])
},
sum(colwidths[1:i])
)
res <- res + ggplot2::annotate("text",
x = mean(line_pos), y = rev(seq_len(nrow(df))),
label = df[, i], size = font_size / .pt, fontface = if (col_labels) {
c(col_lab_fontface, rep("plain", nrow(df) - 1))
} else {
rep("plain", nrow(df))
}
)
}
}

# Add horizontal line if specified
if (hline) {
res <- res + ggplot2::annotate(
"segment",
x = 0 + 0.2 * colwidths[2], xend = tot_width - 0.1 * tail(colwidths, 1),
y = nrow(df) - 0.5, yend = nrow(df) - 0.5
)
}

# Set background fill if specified
if (!is.null(bg_fill)) {
res <- res + ggplot2::theme(plot.background = ggplot2::element_rect(fill = bg_fill))
}

res
Expand Down
5 changes: 4 additions & 1 deletion man/df2gg.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions tests/testthat/test-utils_ggplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,27 @@ testthat::test_that("df2gg works as expected", {
testthat::expect_silent(df2gg_cw <- head(iris, 5) %>% df2gg(colwidths = c(1, 1, 1, 1, 1)))
expect_snapshot_ggplot("df2gg_cw", df2gg_cw, width = 5)
})

test_that("df2gg() works with proper x-axis and without", {
# Example using proper x-axis
df <- as.data.frame(matrix(c(
# 0, 250, 500, 750, 1000 <-- (Reference)
54, 28, 10, 3, 0,
59, 35, 16, 5, 1,
54, 25, 4, 0, 0
), nrow = 3, byrow = TRUE))

# Set names manually
colnames(df) <- c("0", "250", "500", "750", "1000")
rownames(df) <- c("A", "B", "C")

# Example with proper x-axis
expect_no_error(
null <- df2gg(df, font_size = 8, add_proper_xaxis = TRUE)
)

# Example without proper x-axis
expect_no_error(
null <- df2gg(df, font_size = 8, add_proper_xaxis = FALSE, hline = FALSE)
)
})
Loading