Skip to content

Commit

Permalink
add non-square sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
DominiqueMakowski committed Dec 1, 2024
1 parent fced785 commit b8767d8
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 5 deletions.
62 changes: 58 additions & 4 deletions R/cor_sort.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ cor_sort.easycorrelation <- function(x, distance = "correlation", hclust_method
)

# Make sure Parameter columns are character
# Was added to fix a test, but makes the function not work
# (See https://github.com/easystats/correlation/issues/259)
# reordered$Parameter1 <- as.character(reordered$Parameter1)
# reordered$Parameter2 <- as.character(reordered$Parameter2)

Expand Down Expand Up @@ -76,8 +78,13 @@ cor_sort.easycormatrix <- function(x, distance = "correlation", hclust_method =

#' @export
cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "complete", ...) {
col_order <- .cor_sort_order(x, distance = distance, hclust_method = hclust_method, ...)
reordered <- x[col_order, col_order]
if(isSquare(x)) {
i <- .cor_sort_order_square(x, distance = distance, hclust_method = hclust_method, ...)
} else {
i <- .cor_sort_order_nonsquare(x, distance = "euclidean", ...)
}

reordered <- x[i$row_order, i$col_order]

# Restore class and attributes
attributes(reordered) <- utils::modifyList(
Expand All @@ -91,7 +98,7 @@ cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "comple
# Utils -------------------------------------------------------------------


.cor_sort_order <- function(m, distance = "correlation", hclust_method = "complete", ...) {
.cor_sort_order_square <- function(m, distance = "correlation", hclust_method = "complete", ...) {
if (distance == "correlation") {
d <- stats::as.dist((1 - m) / 2) # r = -1 -> d = 1; r = 1 -> d = 0
} else if (distance == "raw") {
Expand All @@ -101,5 +108,52 @@ cor_sort.matrix <- function(x, distance = "correlation", hclust_method = "comple
}

hc <- stats::hclust(d, method = hclust_method)
row.names(m)[hc$order]
row_order <- row.names(m)[hc$order]
list(row_order = row_order, col_order = row_order)
}


# Compute row and column orderings for a non-square matrix.
#
# Rows and columns are clustered independently (average-linkage
# `stats::hclust()` on `stats::dist()` computed with `distance`). The rows are
# then greedily re-assigned: walking the clustered columns from left to right,
# each column is paired with the not-yet-used row that holds its largest
# absolute value, which pushes strong values towards the diagonal. Rows that
# remain unpaired (more rows than columns) are appended at the end.
#
# @param m A numeric matrix (rows and columns need not be the same length).
# @param distance Distance method forwarded to `stats::dist()`.
# @param ... Currently unused.
#
# @return A list with integer vectors `row_order` and `col_order`. Both index
#   the *original* `m`, so `m[row_order, col_order]` is the sorted matrix.
.cor_sort_order_nonsquare <- function(m, distance = "euclidean", ...) {
  # hclust() needs at least two objects; a single row/column has only one order
  cluster_order <- function(mat) {
    if (nrow(mat) < 2L) {
      return(seq_len(nrow(mat)))
    }
    stats::hclust(stats::dist(mat, method = distance), method = "average")$order
  }

  # Step 1: Perform clustering on rows and columns independently
  row_order <- cluster_order(m)
  col_order <- cluster_order(t(m))

  # Absolute values laid out in clustered order (kept as a matrix)
  strength <- abs(m[row_order, col_order, drop = FALSE])

  # Step 2: Refine alignment to emphasize strong correlations along the diagonal
  n_rows <- nrow(strength)
  n_cols <- ncol(strength)

  used_rows <- logical(n_rows)
  refined <- integer(min(n_rows, n_cols))
  n_refined <- 0L

  for (col in seq_len(n_cols)) {
    candidates <- which(!used_rows)
    if (length(candidates) == 0L) {
      break # more columns than rows: every row is already placed
    }

    # which.max() keeps the first maximum (same tie-breaking as a strict `>`
    # scan) and ignores NA; it is empty when all candidates are NA
    best_row <- candidates[which.max(strength[candidates, col])]
    if (length(best_row) == 0L) {
      next
    }

    n_refined <- n_refined + 1L
    refined[n_refined] <- best_row
    used_rows[best_row] <- TRUE
  }

  # Append any unused rows at the end
  refined <- c(refined[seq_len(n_refined)], which(!used_rows))

  # `refined` indexes the clustered matrix; map it back to rows of `m`
  list(row_order = row_order[refined], col_order = col_order)
}
3 changes: 2 additions & 1 deletion R/display.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
#' @name display.easycormatrix
#'
#' @description Export tables (i.e. data frame) into different output formats.
#' `print_md()` is a alias for `display(format = "markdown")`.
#' `print_md()` is an alias for `display(format = "markdown")`. Note that
#' you can use `format()` to get the formatted table as a data frame.
#'
#' @param object,x An object returned by
#' [`correlation()`][correlation] or its summary.
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test-cor_sort.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
test_that("cor_sort", {
  # Square input: a correlation matrix has a unit diagonal
  r <- cor(mtcars)
  expect_equal(as.numeric(diag(r)), rep(1, ncol(mtcars)))

  # The sorted matrix (not the input) must still have a unit diagonal
  r1 <- cor_sort(r)
  expect_equal(as.numeric(diag(r1)), rep(1, ncol(mtcars)))

  # Non-square input (5 x 6): rows start in the original variable order ...
  r2 <- cor(mtcars[names(mtcars)[1:5]], mtcars[names(mtcars)[6:11]])
  expect_equal(rownames(r2), names(mtcars)[1:5])

  # ... and sorting is expected to move at least one row
  r3 <- cor_sort(r2)
  expect_false(all(rownames(r3) == names(mtcars)[1:5]))
})

0 comments on commit b8767d8

Please sign in to comment.