diff --git a/R/cor_sort.R b/R/cor_sort.R index 7c30d22f..f686ae11 100644 --- a/R/cor_sort.R +++ b/R/cor_sort.R @@ -38,6 +38,8 @@ cor_sort.easycorrelation <- function(x, distance = "correlation", hclust_method ) # Make sure Parameter columns are character + # Was added to fix a test, but makes the function not work + # (See https://github.com/easystats/correlation/issues/259) # reordered$Parameter1 <- as.character(reordered$Parameter1) # reordered$Parameter2 <- as.character(reordered$Parameter2) @@ -76,8 +78,13 @@ cor_sort.easycormatrix <- function(x, distance = "correlation", hclust_method = #' @export cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "complete", ...) { - col_order <- .cor_sort_order(x, distance = distance, hclust_method = hclust_method, ...) - reordered <- x[col_order, col_order] + if(isSquare(x) & all(colnames(x) %in% rownames(x))) { + i <- .cor_sort_square(x, distance = distance, hclust_method = hclust_method, ...) + } else { + i <- .cor_sort_nonsquare(x, distance = "euclidean", ...) + } + + reordered <- x[i$row_order, i$col_order] # Restore class and attributes attributes(reordered) <- utils::modifyList( @@ -91,7 +98,7 @@ cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "comple # Utils ------------------------------------------------------------------- -.cor_sort_order <- function(m, distance = "correlation", hclust_method = "complete", ...) { +.cor_sort_square <- function(m, distance = "correlation", hclust_method = "complete", ...) { if (distance == "correlation") { d <- stats::as.dist((1 - m) / 2) # r = -1 -> d = 1; r = 1 -> d = 0 } else if (distance == "raw") { @@ -101,5 +108,54 @@ cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "comple } hc <- stats::hclust(d, method = hclust_method) - row.names(m)[hc$order] + row_order <- row.names(m)[hc$order] + list(row_order = row_order, col_order = row_order) +} + + +.cor_sort_nonsquare <- function(m, distance = "euclidean", ...) { + # Step 1: Perform clustering on rows and columns independently + row_dist <- dist(m, method = distance) # Distance between rows + col_dist <- dist(t(m), method = distance) # Distance between columns + + row_hclust <- stats::hclust(row_dist, method = "average") + col_hclust <- stats::hclust(col_dist, method = "average") + + # Obtain clustering orders + row_order <- row_hclust$order + col_order <- col_hclust$order + + # Reorder matrix based on clustering + clustered_matrix <- m[row_order, col_order] + + # Step 2: Refine alignment to emphasize strong correlations along the diagonal + n_rows <- nrow(clustered_matrix) + n_cols <- ncol(clustered_matrix) + + used_rows <- logical(n_rows) + refined_row_order <- integer(0) + + for (col in seq_len(n_cols)) { + max_value <- -Inf + best_row <- NA + + for (row in seq_len(n_rows)[!used_rows]) { + if (abs(clustered_matrix[row, col]) > max_value) { + max_value <- abs(clustered_matrix[row, col]) + best_row <- row + } + } + + if (!is.na(best_row)) { + refined_row_order <- c(refined_row_order, best_row) + used_rows[best_row] <- TRUE + } + } + + # Append any unused rows at the end + refined_row_order <- c(refined_row_order, which(!used_rows)) + + # Apply + m <- clustered_matrix[refined_row_order, ] + list(row_order = rownames(m), col_order = colnames(m)) } diff --git a/R/display.R b/R/display.R index f8b93f3c..986fc3ea 100644 --- a/R/display.R +++ b/R/display.R @@ -2,7 +2,8 @@ #' @name display.easycormatrix #' #' @description Export tables (i.e. data frame) into different output formats. -#' `print_md()` is a alias for `display(format = "markdown")`. +#' `print_md()` is a alias for `display(format = "markdown")`. Note that +#' you can use `format()` to get the formatted table as a dataframe. #' #' @param object,x An object returned by #' [`correlation()`][correlation] or its summary. diff --git a/tests/testthat/test-cor_sort.R b/tests/testthat/test-cor_sort.R new file mode 100644 index 00000000..d53beda9 --- /dev/null +++ b/tests/testthat/test-cor_sort.R @@ -0,0 +1,17 @@ +test_that("cor_sort", { + r <- cor(mtcars) + expect_equal(as.numeric(diag(r)), rep(1, ncol(mtcars))) + # heatmap(r, Rowv = NA, Colv = NA) + + r1 <- cor_sort(r) + expect_equal(as.numeric(diag(r)), rep(1, ncol(mtcars))) + # heatmap(r1, Rowv = NA, Colv = NA) + + r2 <- cor(mtcars[names(mtcars)[1:5]], mtcars[names(mtcars)[6:11]]) + expect_equal(rownames(r2), names(mtcars)[1:5]) + # heatmap(r2, Rowv = NA, Colv = NA) + + r3 <- cor_sort(r2) + expect_equal(all(rownames(r3) == names(mtcars)[1:5]), FALSE) + # heatmap(r3, Rowv = NA, Colv = NA) +}) \ No newline at end of file